#!/usr/bin/env python
# -*- coding: UTF-8 -*-
#
# Copyright 2019-2025 NXP
#
# SPDX-License-Identifier: BSD-3-Clause

"""The module provides support for On-The-Fly encoding for RTxxx devices."""

import logging
import os
from copy import deepcopy
from dataclasses import dataclass
from struct import pack
from typing import Any, Optional, Union

from typing_extensions import Self

from spsdk.apps.utils.utils import filepath_from_config
from spsdk.crypto.crc import CrcAlg, from_crc_algorithm
from spsdk.crypto.rng import random_bytes
from spsdk.crypto.symmetric import Counter, aes_ctr_encrypt, aes_key_wrap
from spsdk.exceptions import SPSDKError, SPSDKValueError
from spsdk.fuses.fuses import FuseScript
from spsdk.utils.abstract_features import FeatureBaseClass
from spsdk.utils.binary_image import BinaryImage
from spsdk.utils.config import Config
from spsdk.utils.database import DatabaseManager, get_schema_file
from spsdk.utils.family import FamilyRevision, get_db, update_validation_schema_family
from spsdk.utils.misc import (
    Endianness,
    align_block,
    load_binary,
    reverse_bits,
    split_data,
    value_to_bytes,
    value_to_int,
    write_file,
)

logger = logging.getLogger(__name__)


class KeyBlob:
    """OTFAD KeyBlob: The class specifies AES key and counter initial value for specified address range.

    | typedef struct KeyBlob
    | {
    |     unsigned char key[kAesKeySizeBytes]; // 16 bytes, 128-bits, KEY[A15...A00]
    |     unsigned char ctr[kCtrSizeBytes];    // 8 bytes, 64-bits, CTR[C7...C0]
    |     unsigned int srtaddr;                // region start, SRTADDR[31 - 10]
    |     unsigned int endaddr;                // region end, ENDADDR[31 - 10]; lowest three bits are used as flags
    |     // end of 32-byte area covered by CRC
    |     unsigned int zero_fill;      // zeros
    |     unsigned int key_blob_crc32; // crc32 over 1st 32-bytes
    |     // end of 40 byte (5*64-bit) key blob data
    |     unsigned char expanded_wrap_data[8]; // 8 bytes, used for wrap expanded data
    |     // end of 48 byte (6*64-bit) wrap data
    |     unsigned char unused_filler[16]; // unused fill to 64 bytes
    | } keyblob_t;
    """

    # Region start address must be aligned to 1 KiB (region addresses are modulo 1024)
    _START_ADDR_MASK = 0x400 - 1
    # The exported end address word ends with the RO, ADE and VLD flag bits. Bits [9:3]
    # are always forced to 1 when the key blob is serialized (see ``plain_data``);
    # the rest is configurable.
    _END_ADDR_MASK = 0x3F8

    # Key flags mask: RO, ADE, VLD
    _KEY_FLAG_MASK = 0x07
    # This field signals that the entire set of context registers (CTXn_KEY[0-3], CTXn_CTR[0-1],
    # CTXn_RGD_W[0-1] are read-only and cannot be modified. This field is sticky and remains
    # asserted until the next system reset. SR[RRAM] provides another level of register access
    # control and is independent of the RO indicator.
    KEY_FLAG_READ_ONLY = 0x4
    # AES Decryption Enable: For accesses hitting in a valid context, this bit indicates if the fetched data is to be
    # decrypted or simply bypassed.
    KEY_FLAG_ADE = 0x2
    # Valid: This field signals if the context is valid or not.
    KEY_FLAG_VLD = 0x1

    # key length in bytes
    KEY_SIZE = 16
    # counter length in bytes
    CTR_SIZE = 8
    # len of counter init value for export
    _EXPORT_CTR_IV_SIZE = 8
    # number of 64-bit blocks that are key-wrapped (this constant seems to be fixed for SB2.1)
    _EXPORT_NBLOCKS_5 = 5
    # binary export size
    _EXPORT_KEY_BLOB_SIZE = 64
    # QSPI image alignment length, 512 is supposed to be the safe alignment level for any QSPI device
    # this means that all QSPI images generated by this tool will be sizes of multiple 512
    _IMAGE_ALIGNMENT = 512
    # Encryption block size
    _ENCRYPTION_BLOCK_SIZE = 16

    def __init__(
        self,
        start_addr: int,
        end_addr: int,
        key: Optional[bytes] = None,
        counter_iv: Optional[bytes] = None,
        key_flags: int = KEY_FLAG_VLD | KEY_FLAG_ADE,
        # for testing
        zero_fill: Optional[bytes] = None,
        crc: Optional[bytes] = None,
    ):
        """Constructor.

        :param start_addr: start address of the region; must be aligned to 1 KiB (0x400)
        :param end_addr: end address of the region
        :param key: optional AES key (16 bytes); None to use random value
        :param counter_iv: optional counter init value for AES (8 bytes); None to use random value
        :param key_flags: see KEY_FLAG_xxx constants; default flags: RO = 0, ADE = 1, VLD = 1
        :param zero_fill: optional value for zero_fill (for testing only); None to use random value (recommended)
        :param crc: optional value overriding the CRC field (for testing only); None to store the
            CRC32 computed over the first 32 bytes of the key blob (recommended)
        :raises SPSDKError: Start address is not aligned
        :raises SPSDKError: When the key or the counter init vector has invalid length
        :raises SPSDKError: When there is invalid start/end address
        :raises SPSDKError: When key_flags exceeds mask
        """
        if key is None:
            key = random_bytes(self.KEY_SIZE)
        if counter_iv is None:
            counter_iv = random_bytes(self.CTR_SIZE)
        # Each length is validated on its own: a wrong key OR a wrong counter is an error.
        if len(key) != self.KEY_SIZE:
            raise SPSDKError("Invalid key")
        if len(counter_iv) != self.CTR_SIZE:
            raise SPSDKError("Invalid counter init vector")
        if start_addr < 0 or start_addr > end_addr or end_addr > 0xFFFFFFFF:
            raise SPSDKError("Invalid start/end address")
        if (key_flags & ~self._KEY_FLAG_MASK) != 0:
            raise SPSDKError(f"key_flags exceeds mask {hex(self._KEY_FLAG_MASK)}")
        if (start_addr & self._START_ADDR_MASK) != 0:
            raise SPSDKError(
                f"Start address must be aligned to {hex(self._START_ADDR_MASK + 1)} boundary"
            )
        # The end address is intentionally not checked for alignment; ``plain_data`` forces
        # bits [9:3] of the exported end address to 1.
        self.key = key
        self.ctr_init_vector = counter_iv
        self.start_addr = start_addr
        self.end_addr = end_addr
        self.key_flags = key_flags
        self.zero_fill = zero_fill
        self.crc_fill = crc

    def __str__(self) -> str:
        """Text info about the instance."""
        msg = ""
        msg += f"Key:        {self.key.hex()}\n"
        msg += f"Counter IV: {self.ctr_init_vector.hex()}\n"
        msg += f"Start Addr: {hex(self.start_addr)}\n"
        msg += f"End Addr:   {hex(self.end_addr)}\n"
        return msg

    def plain_data(self) -> bytes:
        """Serialize the key blob into its 64-byte plain (not wrapped) form.

        Layout: key (16 B), counter IV (8 B), start address (4 B, LE), end address with
        flags (4 B, LE), zero fill (4 B), CRC (4 B), expanded wrap data (8 B of zeros),
        unused filler (16 B of zeros).

        :return: key blob exported into binary form (serialization)
        :raises SPSDKError: Invalid value of zero fill parameter
        :raises SPSDKError: Invalid value crc
        :raises SPSDKError: Invalid length binary data
        """
        result = bytes()
        result += self.key
        result += self.ctr_init_vector
        result += pack("<I", self.start_addr)
        if self.end_addr or self.key_flags:
            # Low 3 bits carry the key flags and bits [9:3] are forced to 1.
            # NOTE(review): ``end_addr - 1`` treats ``end_addr`` as exclusive here, while
            # ``contains_addr`` treats it as inclusive - confirm the intended semantics.
            end_addr_with_flags = (
                ((self.end_addr - 1) & ~self._KEY_FLAG_MASK) | self.key_flags | self._END_ADDR_MASK
            )
        else:
            # Empty slot (e.g. the zero filler key blobs created by ``Otfad``)
            end_addr_with_flags = 0
        result += pack("<I", end_addr_with_flags)
        # CRC32 (CrcAlg.CRC32_MPEG) over the first 32 bytes: key, counter, start, end
        header_crc = (
            from_crc_algorithm(CrcAlg.CRC32_MPEG)
            .calculate(result)
            .to_bytes(4, Endianness.LITTLE.value)
        )
        # zero fill: test override if given, otherwise random
        if self.zero_fill:
            if len(self.zero_fill) != 4:
                raise SPSDKError("Invalid value")
            result += self.zero_fill
        else:
            result += random_bytes(4)
        # CRC field: test override if given, otherwise the computed header CRC
        if self.crc_fill:
            if len(self.crc_fill) != 4:
                raise SPSDKError("Invalid value crc")
            result += self.crc_fill
        else:
            result += header_crc
        result += bytes(8)  # expanded_wrap_data
        result += bytes(16)  # unused filler
        if len(result) != self._EXPORT_KEY_BLOB_SIZE:
            raise SPSDKError("Invalid length binary data")
        return result

    # pylint: disable=invalid-name
    def export(
        self,
        kek: Union[bytes, str],
        iv: bytes = bytes([0xA6] * 8),
        byte_swap_cnt: int = 0,
    ) -> bytes:
        """Creates key wrap for the key blob.

        Only the first ``_EXPORT_NBLOCKS_5`` 64-bit blocks (40 bytes) of the plain data are
        wrapped; the result is zero-padded to 64 bytes.

        :param kek: key to encode; 16 bytes long (bytes or hex string)
        :param iv: initialization vector; 8 bytes. NOTE: only its length is validated, the value
            is not passed to ``aes_key_wrap`` (the default 0xA6... value is the RFC 3394 default IV)
        :param byte_swap_cnt: Encrypted keyblob reverse byte count, 0 means NO reversing is enabled
        :return: Exported key blob
        :raises SPSDKError: If length of kek is not valid
        :raises SPSDKError: If length of initialization vector is not valid
        :raises SPSDKError: If length of data is not valid
        """
        if isinstance(kek, str):
            kek = bytes.fromhex(kek)
        if len(kek) != 16:
            raise SPSDKError("Invalid length of kek")
        if len(iv) != self._EXPORT_CTR_IV_SIZE:
            raise SPSDKError("Invalid length of initialization vector")
        wrap_len = self._EXPORT_NBLOCKS_5 * 8
        plaintext = self.plain_data()  # input data to be encrypted
        if len(plaintext) < wrap_len:
            raise SPSDKError("Invalid length of data to be encrypted")

        wrap = aes_key_wrap(kek, plaintext[:wrap_len])
        if byte_swap_cnt > 0:
            # Reverse the byte order inside every ``byte_swap_cnt``-sized chunk of the wrap
            blobs = b"".join(
                wrap[i : i + byte_swap_cnt][::-1] for i in range(0, len(wrap), byte_swap_cnt)
            )
        else:
            blobs = bytes(wrap)

        # align to 64 bytes (0 padding)
        return align_block(blobs, self._EXPORT_KEY_BLOB_SIZE, padding=0)

    def _get_ctr_nonce(self) -> bytes:
        """Get the counter initial value for image encryption.

        :return: 16 bytes: CTR[C0..C7], (C0..C3 XOR C4..C7), 4 zero bytes
        :raises SPSDKError: If length of counter is not valid
        """
        #  CTRn_x[127-0] = {CTR_W0_x[C0...C3],    // 32 bits of pre-programmed CTR
        #  CTR_W1_x[C4...C7],                     // another 32 bits of CTR
        #  CTR_W0_x[C0...C3] ^ CTR_W1_x[C4...C7], // exclusive-OR of CTR values
        #  systemAddress[31-4], 0000b             // 0-modulo-16 system address */
        ctr = self.ctr_init_vector
        if len(ctr) != 8:
            raise SPSDKError("Invalid length of counter init")

        xored = bytes(low ^ high for low, high in zip(ctr[:4], ctr[4:]))
        # Bytes 12..15 stay zero in the nonce; the address is passed as the counter init value
        return bytes(ctr) + xored + bytes(4)

    @staticmethod
    def _swap_halves(block: bytes) -> bytes:
        """Reverse the byte order inside each 8-byte half of a 16-byte block."""
        return block[:8][::-1] + block[8:16][::-1]

    def contains_addr(self, addr: int) -> bool:
        """Whether key blob contains specified address.

        :param addr: to be tested
        :return: True if start_addr <= addr <= end_addr, False otherwise
        """
        return self.start_addr <= addr <= self.end_addr

    def matches_range(self, image_start: int, image_end: int) -> bool:
        """Whether key blob matches address range of the image to be encrypted.

        :param image_start: start address of the image
        :param image_end: last address of the image
        :return: True if both addresses are within the key blob, False otherwise
        """
        return self.contains_addr(image_start) and self.contains_addr(image_end)

    def encrypt_image(
        self,
        base_address: int,
        data: bytes,
        byte_swap: bool,
        counter_value: Optional[int] = None,
    ) -> bytes:
        """Encrypt specified data with AES-CTR using the key and counter of this key blob.

        :param base_address: of the data in target memory; must be 16-byte aligned and should be
            >= self.start_addr
        :param data: to be encrypted (e.g. plain image); padded to a multiple of 16 bytes first
        :param byte_swap: reverse the byte order within each 8-byte half of every 16-byte block,
            before and after encryption (presumably depends on the flash device organization)
        :param counter_value: Optional counter value; if None, start address of keyblob is used
        :return: encrypted data (length aligned to 16 bytes)
        :raises SPSDKError: If start address is not valid
        :raises SPSDKError: If length of encrypted data is not valid
        """
        if base_address % 16 != 0:
            raise SPSDKError("Invalid start address")  # Start address has to be 16 byte aligned
        block_size = self._ENCRYPTION_BLOCK_SIZE
        data = align_block(data, block_size)  # align data length
        data_len = len(data)

        # check start and end addresses
        # Support dual image boot, do not raise exception
        if not self.matches_range(base_address, base_address + data_len - 1):
            logger.warning(
                f"Image address range is not within key blob: "
                f"{hex(self.start_addr)}-{hex(self.end_addr)}."
                " Ignore this if flash remap feature is used"
            )

        # Compare with None so that an explicit counter value of 0 is honoured
        if counter_value is None:
            counter_value = self.start_addr

        counter = Counter(
            self._get_ctr_nonce(), ctr_value=counter_value, ctr_byteorder_encoding=Endianness.BIG
        )

        result = bytearray()
        for index in range(0, data_len, block_size):
            # prepare data in byte order
            plain_block = data[index : index + block_size]
            if byte_swap:
                plain_block = self._swap_halves(plain_block)
            encr_data = aes_ctr_encrypt(self.key, plain_block, counter.value)
            # fix byte order in result
            result += self._swap_halves(encr_data) if byte_swap else encr_data
            # update counter for encryption
            counter.increment(block_size)

        if len(result) != data_len:
            raise SPSDKError("Invalid length of encrypted data")
        return bytes(result)

    @property
    def is_encrypted(self) -> bool:
        """Get the required encryption or not.

        :return: True if both the ADE and VLD flags are set, False otherwise.
        """
        required = self.KEY_FLAG_ADE | self.KEY_FLAG_VLD
        return (self.key_flags & required) == required


class Otfad(FeatureBaseClass):
    """OTFAD: On-the-Fly AES Decryption Module with reflecting of NXP parts."""

    FEATURE = DatabaseManager.OTFAD

    OTFAD_DATA_UNIT = 0x400

    def __init__(
        self,
        family: FamilyRevision,
        kek: Union[bytes, str],
        table_address: int = 0,
        key_blobs: Optional[list[KeyBlob]] = None,
        key_scramble_mask: Optional[int] = None,
        key_scramble_align: Optional[int] = None,
        binaries: Optional[BinaryImage] = None,
        data_alignment: int = 512,
        otfad_table_name: str = "OTFAD_Table",
        otfad_all_name: str = "otfad_whole_image",
        generate_readme: bool = True,
        index: Optional[int] = None,
    ) -> None:
        """Constructor.

        :param family: Device family
        :param kek: KEK to encrypt OTFAD table (bytes or hex string)
        :param table_address: Absolute address of OTFAD table.
        :param key_blobs: Optional Key blobs to add to OTFAD, defaults to None
        :param key_scramble_mask: If defined, the key scrambling algorithm will be applied.
            ('key_scramble_align' must be defined also)
        :param key_scramble_align: If defined, the key scrambling algorithm will be applied.
            ('key_scramble_mask' must be defined also)
        :param binaries: Optional binary image to be encrypted
        :param data_alignment: Data alignment for the binary image
        :param otfad_table_name: Name of the OTFAD table file
        :param otfad_all_name: Name of the whole OTFAD image file
        :param generate_readme: Generate readme file
        :param index: Index of the OTFAD peripheral for fuses
        :raises SPSDKValueError: Unsupported family
        :raises SPSDKValueError: Only one of the key scrambling parameters is defined
        :raises SPSDKValueError: Invalid keyblob_byte_swap_cnt in the device database
        """
        self._key_blobs: list[KeyBlob] = []
        self.data_alignment = data_alignment
        self.otfad_table_name = otfad_table_name
        self.otfad_all_name = otfad_all_name
        self.generate_readme = generate_readme
        self.index = index

        # Scrambling needs both values. Compare against None explicitly so that a
        # legitimate value of 0 is not mistaken for "not defined".
        if (key_scramble_mask is None) != (key_scramble_align is None):
            raise SPSDKValueError("Key Scrambling is not fully defined")

        self.family = family
        self.db = get_db(family)
        self.kek = bytes.fromhex(kek) if isinstance(kek, str) else kek
        self.key_scramble_mask = key_scramble_mask
        self.key_scramble_align = key_scramble_align
        self.table_address = table_address

        self.reversed_scramble_key = self.db.get_bool(
            DatabaseManager.OTFAD, "reversed_scramble_key", False
        )
        self.blobs_min_cnt = self.db.get_int(self.FEATURE, "key_blob_min_cnt")
        self.blobs_max_cnt = self.db.get_int(self.FEATURE, "key_blob_max_cnt")
        self.byte_swap = self.db.get_bool(self.FEATURE, "byte_swap")
        self.key_blob_rec_size = self.db.get_int(self.FEATURE, "key_blob_rec_size")
        self.keyblob_byte_swap_cnt = self.db.get_int(self.FEATURE, "keyblob_byte_swap_cnt")
        if self.keyblob_byte_swap_cnt not in [0, 2, 4, 8, 16]:
            raise SPSDKValueError(
                f"Invalid value of keyblob_byte_swap_cnt: {self.keyblob_byte_swap_cnt}"
            )
        self.binaries = binaries

        for key_blob in key_blobs or []:
            self.add_key_blob(key_blob)

        # Just fill up the minimum count of key blobs with empty (disabled) key blobs
        while len(self._key_blobs) < self.blobs_min_cnt:
            self.add_key_blob(
                KeyBlob(
                    start_addr=0,
                    end_addr=0,
                    key=bytes([0] * KeyBlob.KEY_SIZE),
                    counter_iv=bytes([0] * KeyBlob.CTR_SIZE),
                    key_flags=0,
                    zero_fill=bytes([0] * 4),
                )
            )

    def __getitem__(self, index: int) -> KeyBlob:
        """Return the key blob stored at the given position.

        :param index: Position of the key blob (negative values count from the end).
        :return: The key blob at ``index``.
        """
        key_blobs = self._key_blobs
        return key_blobs[index]

    def __setitem__(self, index: int, value: KeyBlob) -> None:
        """Replace the key blob stored at the given position.

        Direct index assignment keeps the position correct for negative indexes and does not
        depend on value-based removal (which would drop the first occurrence of the object).

        :param index: Position of the key blob to replace (negative values count from the end).
        :param value: New key blob.
        :raises IndexError: If ``index`` is out of range.
        """
        self._key_blobs[index] = value

    def __len__(self) -> int:
        """Return the number of key blobs held by this OTFAD instance."""
        blob_count = len(self._key_blobs)
        return blob_count

    def add_key_blob(self, key_blob: KeyBlob) -> None:
        """Append a key blob describing one encrypted address range.

        :param key_blob: Key blob placed after the already registered ones.
        """
        self._key_blobs += [key_blob]

    def encrypt_image(self, image: bytes, base_addr: int, byte_swap: bool) -> bytes:
        """Encrypt an image chunk by chunk with every matching, encryption-enabled key blob.

        The image is processed in ``OTFAD_DATA_UNIT``-sized chunks. A chunk is encrypted only
        when it lies fully inside a key blob range and that key blob has ADE and VLD set;
        other chunks are copied unchanged.

        :param image: plain image to be encrypted
        :param base_addr: where the image will be located in target processor
        :param byte_swap: this probably depends on the flash device, how bytes are organized there
        :return: encrypted image
        """
        output = bytearray(image)
        offset = 0
        for chunk in split_data(image, self.OTFAD_DATA_UNIT):
            chunk_addr = base_addr + offset
            chunk_last = chunk_addr + len(chunk) - 1
            for key_blob in self._key_blobs:
                if not key_blob.matches_range(chunk_addr, chunk_last) or not key_blob.is_encrypted:
                    continue
                logger.debug(
                    f"Encrypting {hex(chunk_addr)}:{hex(chunk_last)}"
                    f" with keyblob: \n {str(key_blob)}"
                )
                output[offset : offset + len(chunk)] = key_blob.encrypt_image(
                    chunk_addr, chunk, byte_swap, counter_value=chunk_addr
                )
            offset += len(chunk)

        return bytes(output)

    def get_key_blobs(self) -> bytes:
        """Serialize all key blobs in plain (not wrapped) form.

        :return: Plain key blobs concatenated and padded to a multiple of 256 bytes
        """
        joined = b"".join(key_blob.plain_data() for key_blob in self._key_blobs)
        # 256-byte padding is for compatibility with elftosb (probably the FLASH sector size)
        return align_block(joined, 256)

    def encrypt_key_blobs(
        self,
        kek: Union[bytes, str],
        key_scramble_mask: Optional[int] = None,
        key_scramble_align: Optional[int] = None,
        byte_swap_cnt: int = 0,
    ) -> bytes:
        """Encrypt key blobs with specified key.

        When both ``key_scramble_mask`` and ``key_scramble_align`` are given, each key blob is
        wrapped with its own scrambled KEK: the 4 mask bytes are XORed into the 32-bit KEK word
        selected by the 2-bit field of ``key_scramble_align`` belonging to the key blob index.

        :param kek: key to encode key blobs (bytes or hex string)
        :param key_scramble_mask: 32-bit scramble key, if KEK scrambling is desired.
        :param key_scramble_align: 8-bit scramble align, if KEK scrambling is desired.
        :param byte_swap_cnt: Encrypted keyblob reverse byte count, 0 means NO reversing is enabled
        :raises SPSDKValueError: Invalid input value.
        :return: encrypted binary key blobs joined together, padded to a multiple of 256 bytes
        """
        if isinstance(kek, str):
            kek = bytes.fromhex(kek)

        mask_bytes: Optional[bytes] = None
        align = 0
        if key_scramble_mask is not None and key_scramble_align is not None:
            # Negative values are rejected too; they would otherwise fail later in ``to_bytes``
            if not 0 <= key_scramble_mask < 1 << 32:
                raise SPSDKValueError("OTFAD Key scramble mask has invalid length")
            if not 0 <= key_scramble_align < 1 << 8:
                raise SPSDKValueError("OTFAD Key scramble align has invalid length")

            logger.debug("The scrambling of keys is enabled.")

            if self.reversed_scramble_key:
                key_scramble_mask = reverse_bits(key_scramble_mask, 32)

            mask_bytes = key_scramble_mask.to_bytes(4, byteorder=Endianness.LITTLE.value)
            logger.debug(f"The inverted scramble key is: {mask_bytes.hex()}")
            align = key_scramble_align

        result = bytes()
        for i, key_blob in enumerate(self._key_blobs):
            blob_kek = kek
            if mask_bytes is not None:
                scrambled = bytearray(kek)
                # 2 bits per key blob select which 32-bit KEK word gets scrambled
                long_ix = (align >> (i * 2)) & 0x03
                for j, mask_byte in enumerate(mask_bytes):
                    scrambled[long_ix * 4 + j] ^= mask_byte
                blob_kek = bytes(scrambled)

            logger.debug(f"Used KEK for keyblob{i} encryption is: {blob_kek.hex()}")

            result += key_blob.export(blob_kek, byte_swap_cnt=byte_swap_cnt)
        # this is for compatibility with elftosb, probably need FLASH sector size
        return align_block(result, 256)

    def __repr__(self) -> str:
        """Return a short text identification of the OTFAD object."""
        description = f"Otfad object for {self.family}"
        return description

    def __str__(self) -> str:
        """Return a multi-line description of all key blobs."""
        parts = ["Key-Blob\n"]
        for idx, blob in enumerate(self._key_blobs):
            parts.append(f"Key-Blob {idx}:\n")
            parts.append(str(blob))
        return "".join(parts)

    @property
    def scramble_enabled(self) -> bool:
        """Tell whether KEK scrambling is configured (both mask and align are set)."""
        return None not in (self.key_scramble_mask, self.key_scramble_align)

    @staticmethod
    def get_blhost_script_otp_keys(
        family: FamilyRevision, otp_master_key: bytes, otfad_kek_seed: bytes
    ) -> str:
        """Build a BLHOST script that programs the OTP fuses needed by OTFAD.

        :param family: Device family.
        :param otp_master_key: OTP Master Key.
        :param otfad_kek_seed: OTFAD Key Seed.
        :return: BLHOST script that loads the keys into fuses.
        """
        script_generator = FuseScript(family, DatabaseManager.OTFAD)

        @dataclass
        class OTP:
            """Just dumb class for storing OTP values."""

            otp_master_key: bytes
            otfad_kek_seed: bytes

        # NOTE(review): the field names presumably must match what FuseScript looks up - keep them.
        fuse_values = OTP(otp_master_key=otp_master_key, otfad_kek_seed=otfad_kek_seed)
        return script_generator.generate_script(fuse_values)

    def get_blhost_script_otp_kek(self, index: int = 1) -> str:
        """Build a BLHOST script that programs only the OTFAD KEK fuses.

        :param index: Index of OTFAD peripheral [1, 2, ..., n].
        :return: BLHOST script that loads the keys into fuses, or an empty string when the
            device database does not declare OTFAD KEK fuses.
        """
        has_kek_fuses = self.db.get_bool(self.FEATURE, "has_kek_fuses", default=False)
        if has_kek_fuses:
            return FuseScript(self.family, self.FEATURE, index).generate_script(self)

        logger.debug(f"The {self.family} has no OTFAD KEK fuses")
        return ""

    def export_image(
        self,
        plain_data: bool = False,
        swap_bytes: bool = False,
        join_sub_images: bool = True,
        table_address: int = 0,
    ) -> Optional[BinaryImage]:
        """Get the OTFAD Key Blob Binary Image representation.

        Works on a deep copy of ``self.binaries``: data of the sub-images and of their direct
        segments is padded to the encryption block size and, unless ``plain_data`` is set,
        encrypted at ``table_address + absolute_address``.

        :param plain_data: Binary representation in plain data format, defaults to False
        :param swap_bytes: For some platforms the swap bytes is needed in encrypted format, defaults to False.
        :param join_sub_images: If it's True, all the binary sub-images are joined into one, defaults to True.
        :param table_address: Absolute address of OTFAD table.
        :return: OTFAD key blob data in BinaryImage, or None when no binaries are defined.
        """
        if self.binaries is None:
            return None

        def _data_parts(image: BinaryImage) -> list[BinaryImage]:
            """List the top-level sub-images, each followed by its direct segments."""
            return [part for top in image.sub_images for part in (top, *top.sub_images)]

        binaries: BinaryImage = deepcopy(self.binaries)
        for part in _data_parts(binaries):
            if part.binary:
                part.binary = align_block(part.binary, KeyBlob._ENCRYPTION_BLOCK_SIZE)

        binaries.validate()

        if not plain_data:
            for part in _data_parts(binaries):
                if part.binary:
                    part.binary = self.encrypt_image(
                        part.binary,
                        table_address + part.absolute_address,
                        swap_bytes,
                    )

        if join_sub_images:
            binaries.join_images()
            binaries.validate()

        return binaries

    def binary_image(
        self,
        plain_data: bool = False,
        data_alignment: int = 16,
        otfad_table_name: str = "OTFAD_Table",
    ) -> BinaryImage:
        """Get the OTFAD Binary Image representation.

        :param plain_data: Binary representation in plain format, defaults to False
        :param data_alignment: Alignment of data part key blobs.
        :param otfad_table_name: name of the output file that contains OTFAD table
        :return: OTFAD in BinaryImage.
        """
        otfad = BinaryImage("OTFAD", offset=self.table_address)

        # Mandatory OTFAD table: plain or wrapped key blobs
        if plain_data:
            otfad_table = self.get_key_blobs()
        else:
            otfad_table = self.encrypt_key_blobs(
                self.kek,
                self.key_scramble_mask,
                self.key_scramble_align,
                self.keyblob_byte_swap_cnt,
            )
        table_image = BinaryImage(
            otfad_table_name,
            size=self.key_blob_rec_size * self.blobs_max_cnt,
            offset=0,
            description=f"OTFAD description table for {self.family}",
            binary=otfad_table,
            alignment=256,
        )
        otfad.add_image(table_image)

        # NOTE(review): ``plain_data`` only affects the table; the data images come from
        # ``export_image`` with its default ``plain_data=False`` - confirm this is intended.
        data_images = self.export_image(table_address=self.table_address)
        if data_images:
            data_images.alignment = data_alignment
            data_images.validate()
            otfad.add_image(data_images)
        return otfad

    def export(self) -> bytes:
        """Serialize the OTFAD object into bytes (not supported).

        :return: Never returns.
        :raises NotImplementedError: Always; use ``binary_image`` to get the OTFAD data.
        """
        raise NotImplementedError()

    @classmethod
    def parse(cls, data: bytes) -> Self:
        """Create an OTFAD object from an OTFAD key blob binary (not supported).

        :param data: OTFAD keyblob in bytes.
        :return: Never returns.
        :raises NotImplementedError: Always; parsing is not implemented.
        """
        raise NotImplementedError()

    @classmethod
    def get_validation_schemas(cls, family: FamilyRevision) -> list[dict[str, Any]]:
        """Get list of validation schemas.

        When the device memory map provides the ``flexspi1_ns`` block, its base address is used
        to fill the template values of the OTFAD table, data blob and key blob addresses.

        :param family: Family for which the template should be generated.
        :return: Validation list of schemas.
        """
        database = get_db(family)
        schemas = get_schema_file(cls.FEATURE)
        sch_family = get_schema_file("general")["family"]
        update_validation_schema_family(
            sch_family["properties"], cls.get_supported_families(), family
        )
        sch_family["main_title"] = f"On-The-Fly AES decryption Configuration for {family}."
        sch_family["note"] = database.get_str(cls.FEATURE, "additional_template_text", default="")
        # Update address in the schema template (best effort)
        try:
            flexspi_base = database.device.info.memory_map.get_memory(
                block_name="flexspi1_ns"
            ).base_address
            otfad_props = schemas["otfad"]["properties"]
            otfad_props["otfad_table_address"]["template_value"] = hex(flexspi_base)
            otfad_props["data_blobs"]["items"]["properties"]["address"]["template_value"] = hex(
                flexspi_base + 0x1000
            )
            key_blob_props = otfad_props["key_blobs"]["items"]["properties"]
            key_blob_props["start_address"]["template_value"] = hex(flexspi_base + 0x1000)
            key_blob_props["end_address"]["template_value"] = hex(flexspi_base + 0x10000)
        except (SPSDKError, KeyError) as exc:
            # Keep the generic template values, but leave a trace instead of silently passing
            logger.debug(f"OTFAD schema template addresses were not updated: {exc}")

        ret = [sch_family, schemas["otfad_output"], schemas["otfad"]]
        additional_schemes = database.get_list(cls.FEATURE, "additional_template", default=[])
        ret.extend(schemas[x] for x in additional_schemes)
        return ret

    def get_config(self, data_path: str = "./") -> Config:
        """Create configuration of the Feature (not supported for OTFAD).

        :param data_path: Path to store the data files of configuration.
        :return: Never returns.
        :raises NotImplementedError: Always; configuration export is not implemented.
        """
        raise NotImplementedError

    @classmethod
    def load_from_config(cls, config: Config) -> Self:
        """Converts the configuration option into an OTFAD image object.

        The configuration provides the family, the KEK, the OTFAD table address and a list
        of key blob records. Data blobs, key scrambling, output names, data alignment,
        readme generation and the OTFAD index are optional.

        :param config: OTFAD configuration.
        :raises SPSDKValueError: When the configuration does not define any key blob.
        :return: initialized OTFAD object.
        """
        otfad_config = config.get_list_of_configs("key_blobs")
        if not otfad_config:
            # min() below would fail with an unhelpful message on an empty list
            raise SPSDKValueError("The OTFAD configuration must define at least one key blob.")
        family = FamilyRevision.load_from_config(config)
        database = get_db(family)
        kek = config.load_symmetric_key("kek", expected_size=16)
        logger.debug(f"Loaded KEK: {kek.hex()}")
        table_address = config.get_int("otfad_table_address")
        start_address = min(key_blob.get_int("start_address") for key_blob in otfad_config)

        key_scramble_mask = None
        key_scramble_align = None
        if database.get_bool(cls.FEATURE, "supports_key_scrambling", default=False):
            if "key_scramble" in config.keys():
                key_scramble = config["key_scramble"]
                key_scramble_mask = value_to_int(key_scramble["key_scramble_mask"])
                key_scramble_align = value_to_int(key_scramble["key_scramble_align"])

        # Resolve the output folder once, so every generated file (encrypted blobs, OTFAD
        # table, whole image) is placed relative to the same resolved base directory.
        output_folder = config.get_output_file_name("output_folder")

        binaries = None
        if "data_blobs" in config:
            data_blobs = config.get_list_of_configs("data_blobs")
            # The image starts at the lowest address of any key blob region or data blob
            start_address = min(
                [start_address] + [data_blob.get_int("address") for data_blob in data_blobs]
            )
            binaries = BinaryImage(
                filepath_from_config(
                    config,
                    "encrypted_name",
                    "encrypted_blobs",
                    output_folder,
                ),
                offset=start_address - table_address,
            )
            for data_blob in data_blobs:
                data = load_binary(data_blob.get_input_file_name("data"))
                address = data_blob.get_int("address")

                binary = BinaryImage(
                    os.path.basename(data_blob["data"]),
                    # Offsets of sub-images are relative to the parent binaries image
                    offset=address - table_address - binaries.offset,
                    binary=data,
                )
                binaries.add_image(binary)
        else:
            logger.warning("The OTFAD configuration has NOT any data blobs records!")

        otfad_table_name = filepath_from_config(
            config, "keyblob_name", "OTFAD_Table", output_folder
        )
        otfad_all = filepath_from_config(config, "output_name", "otfad_whole_image", output_folder)
        generate_readme = config.get("generate_readme", True)
        data_alignment = config.get_int("data_alignment", 512)
        try:
            index = config.get_int("index")
        except SPSDKError:
            # The OTFAD index is optional
            index = None

        otfad = cls(
            family=family,
            kek=kek,
            table_address=table_address,
            key_scramble_align=key_scramble_align,
            key_scramble_mask=key_scramble_mask,
            binaries=binaries,
            data_alignment=data_alignment,
            otfad_table_name=otfad_table_name,
            otfad_all_name=otfad_all,
            generate_readme=generate_readme,
            index=index,
        )

        for i, key_blob_cfg in enumerate(otfad_config):
            aes_key = value_to_bytes(key_blob_cfg["aes_key"], byte_cnt=KeyBlob.KEY_SIZE)
            aes_ctr = value_to_bytes(key_blob_cfg["aes_ctr"], byte_cnt=KeyBlob.CTR_SIZE)
            start_addr = key_blob_cfg.get_int("start_address")
            end_addr = key_blob_cfg.get_int("end_address")
            aes_decryption_enable = key_blob_cfg.get("aes_decryption_enable", True)
            valid = key_blob_cfg.get("valid", True)
            read_only = key_blob_cfg.get("read_only", True)
            flags = 0
            if aes_decryption_enable:
                flags |= KeyBlob.KEY_FLAG_ADE
            if valid:
                flags |= KeyBlob.KEY_FLAG_VLD
            if read_only:
                flags |= KeyBlob.KEY_FLAG_READ_ONLY

            otfad[i] = KeyBlob(
                start_addr=start_addr,
                end_addr=end_addr,
                key=aes_key,
                counter_iv=aes_ctr,
                key_flags=flags,
                zero_fill=bytes(4),
            )

        return otfad

    def post_export(self, output_path: str) -> list[str]:
        """Perform post export steps like saving the script files.

        Writes the whole OTFAD image (unless its name is blank) and every sub-image that
        has a non-blank name. Depending on configuration and device support it also writes:
        a readme with the memory map, an SB2.1 BD file example, and a BLHOST script for the
        KEK fuses.

        :param output_path: Folder where the readme, the BD file example and the BLHOST
            script are stored.
        :return: List of paths to all generated files.
        """
        generated_files: list[str] = []

        binary_image = self.binary_image(
            data_alignment=self.data_alignment, otfad_table_name=self.otfad_table_name
        )
        sb21_supported = self.db.get_bool(DatabaseManager.OTFAD, "sb_21_supported", default=False)
        logger.info(f" The OTFAD image structure:\n{binary_image.draw()}")
        if self.otfad_all_name != "":
            write_file(binary_image.export(), self.otfad_all_name, mode="wb")
            # The whole image is a generated file too, so report it to the caller
            generated_files.append(self.otfad_all_name)
            logger.info(f"Created OTFAD Image:\n{self.otfad_all_name}")
        else:
            logger.info("Skipping export of OTFAD image")

        memory_map = (
            "In folder is stored two kind of files:\n"
            "  -  Binary file that contains whole image data including "
            "OTFAD table and key blobs data 'otfad_whole_image.bin'.\n"
        )
        if sb21_supported:
            memory_map += "  -  Example of BD file to simplify creating the SB2.1 file from the OTFAD source files.\n"

        memory_map += (
            "  -  Set of separated binary files, one with OTFAD table, and one for each used key blob.\n"
            f"\nOTFAD memory map:\n{binary_image.draw(no_color=True)}"
        )

        bd_file_sources = "sources {"
        bd_file_section0 = "section (0) {"
        for i, image in enumerate(binary_image.sub_images):
            if image.name == "":
                logger.info(
                    f"Skipping export of {str(image)}, value is blank in the configuration file"
                )
                # The image is not written to disk, so the BD example must not reference it
                continue

            write_file(image.export(), image.name, mode="wb")
            generated_files.append(image.name)
            logger.info(f"Created OTFAD Image:\n{image.name}")
            memory_map += f"\n{image.name}:\n{str(image)}"
            if sb21_supported:
                start = image.absolute_address
                bd_file_sources += f'\n    image{i} = "{image.name}";'
                bd_file_section0 += f"\n    // Load Image: {image.name}"
                bd_file_section0 += f"\n    erase {hex(start)}..{hex(start + len(image))};"
                # BD statements are terminated by a semicolon
                bd_file_section0 += f"\n    load image{i} > {hex(start)};"

        readme_file = os.path.join(output_path, "readme.txt")

        if self.generate_readme:
            write_file(memory_map, readme_file)
            generated_files.append(readme_file)
            logger.info(f"Created OTFAD readme file:\n{readme_file}")
        else:
            logger.info("Skipping generation of OTFAD readme file")

        if sb21_supported:
            bd_file_name = os.path.join(output_path, "sb21_otfad_example.bd")
            bd_file = (
                "options {\n"
                "    flags = 0x8; // for sb2.1 use only 0x8 encrypted + signed\n"
                "    buildNumber = 0x1;\n"
                '    productVersion = "1.00.00";\n'
                '    componentVersion = "1.00.00";\n'
                '    secureBinaryVersion = "2.1";\n'
                "}\n"
            )
            bd_file += bd_file_sources + "\n}\n"
            bd_file += bd_file_section0 + "\n}\n"

            write_file(bd_file, bd_file_name)
            generated_files.append(bd_file_name)
            logger.info(f"Created OTFAD BD file example:\n{bd_file_name}")

        # NOTE(review): an index of 0 is falsy and is skipped here just like a missing
        # index; confirm that 0 is never a valid OTFAD index for devices with KEK fuses.
        if self.db.get_bool(DatabaseManager.OTFAD, "has_kek_fuses", default=False) and self.index:
            blhost_script = self.get_blhost_script_otp_kek(self.index)
            if blhost_script:
                blhost_script_filename = os.path.join(
                    output_path, f"otfad{self.index}_{self.family.name}.bcf"
                )
                write_file(blhost_script, blhost_script_filename)
                generated_files.append(blhost_script_filename)
                logger.info(f"Created OTFAD BLHOST load fuses script:\n{blhost_script_filename}")

        return generated_files
