from __future__ import annotations

import re

from refinery.units import Unit
from refinery.units.formats.archive.xtpyi import disassemble_code, extract_code_from_buffer

from types import CodeType


class _WrongKey(ValueError):
    pass


class kramer(Unit):
    """
    Deobfuscate Python samples obfuscated with Kramer.

    All instruction arguments of the code objects found in the input are
    scanned: the longest string argument is taken to be the encoded payload,
    and all integer arguments are collected as candidates for the key. If no
    candidate decrypts the payload to printable text, a bounded brute-force
    search over the plausible key range is performed.
    """

    # Encrypted value that decodes to a line feed when it is smaller than the
    # key (see `decrypt` below).
    # NOTE(review): if a key is <= this value, the magic is decoded as an
    # ordinary character instead; presumably Kramer keys exceed it - TODO confirm.
    _LINEBREAK_MAGIC = 950

    def process(self, data):
        payload = str()
        secret = set()
        _pyver = None

        def crawl(code: CodeType, depth=1):
            # Record the longest string argument as the payload and every integer
            # argument as a key candidate; any other non-tuple argument (e.g. a
            # nested code object) is crawled recursively.
            nonlocal payload
            for instruction in disassemble_code(code, _pyver):
                arg = instruction.argval
                if arg is None or isinstance(arg, tuple):
                    continue
                if isinstance(arg, str):
                    if len(arg) > len(payload):
                        payload = arg
                    continue
                if isinstance(arg, int):
                    secret.add(arg)
                    continue
                try:
                    crawl(arg, depth + 1)
                except Exception as E:
                    # best effort: arguments that cannot be disassembled are skipped
                    self.log_info(F'error crawling arg of type {type(arg).__name__} at depth {depth}: {E}')

        for code in extract_code_from_buffer(bytes(data)):
            _pyver = code.version
            crawl(code.container)

        if not payload:
            raise ValueError('could not find the encoded string')

        # The payload is a sequence of hex chunks joined by a non-hex separator.
        separator = re.search('[^a-fA-F0-9]+', payload)

        if not separator:
            raise ValueError('no separator detected; encoding method may have changed')

        def rotchar(c: int):
            # Step through the cycle a..z,0..9: each lowercase letter or digit maps
            # to its successor, 'z' wraps to '0' and '9' wraps to 'a'; any other
            # value is returned unchanged.
            if c in range(0x61, 0x7a) or c in range(0x30, 0x39):
                return c + 1
            if c == 0x7a:
                return 0x30
            if c == 0x39:
                return 0x61
            return c

        def decrypt(c: int, k: int):
            # Decrypt a single value with key k, raising _WrongKey when the
            # result cannot be a byte.
            if c >= k:
                out = rotchar(c - k)
                if out not in range(0x100):
                    raise _WrongKey
                return out
            if c == self._LINEBREAK_MAGIC:
                return 0x0A
            raise _WrongKey

        def decrypt_with_key(key: int):
            # Decrypt the whole payload; only accept results that consist
            # entirely of whitespace and printable ASCII.
            decrypted = bytearray(decrypt(c, key) for c in encrypted)
            if not re.fullmatch(B'[\\s!-~]+', decrypted):
                raise _WrongKey
            return decrypted

        separator = separator.group(0)
        # Each chunk is the hex-encoded UTF-8 form of one character whose code
        # point is the encrypted value. Empty chunks (caused by a leading,
        # trailing or doubled separator) carry no value and are skipped.
        encrypted = [ord(bytes.fromhex(e).decode()) for e in payload.split(separator) if e]

        values = [x for x in encrypted if x != self._LINEBREAK_MAGIC]
        if not values:
            raise ValueError('the encoded string contains no character values')

        # The key must not exceed the smallest encrypted character value, and the
        # decrypted offsets have to fit into a byte.
        ubound = min(values)
        lbound = ubound - 0xFF

        secret = {k for k in secret if lbound < k < ubound}
        self.log_debug('potential secrets from code:', secret)

        for key in sorted(secret, reverse=True):
            try:
                return decrypt_with_key(key)
            except _WrongKey:
                pass

        self.log_info(F'all candidates failed, searching [{lbound}, {ubound}]')

        for key in range(ubound, lbound - 1, -1):
            if key in secret:
                # already tried and rejected above; decryption is deterministic
                continue
            try:
                self.log_debug('attempting key:', key)
                return decrypt_with_key(key)
            except _WrongKey:
                pass

        raise RuntimeError('could not find decryption key')
