import io

from aldepyde.databases._database import local_database
import operator
from contextlib import nullcontext
import re

class scop_parser(local_database):
    """Split and filter ASTRAL-style FASTA records read from ``self.fp``.

    All record data is handled as ``bytes``. ``self.fp`` is not assigned in
    this class; it is presumably provided by ``local_database`` and opened in
    binary mode -- confirm against the parent class.
    """

    # Combines the three per-field filter results according to ``mode``.
    op = {
        "and": lambda a,b,c: a and b and c,
        "or": lambda a,b,c: a or b or c
    }

    # Leading part of a record header: ">" plus identifier, one or more
    # spaces, then a class letter a-l followed by up to three ".<digits>"
    # components. The dots are escaped so only a literal "." is accepted.
    _ASTRAL_HEADER = re.compile(
        rb"(?P<id>>[a-zA-Z0-9_.]*) +(?P<cls>[a-l](?:\.[0-9]+){0,3})"
    )

    def fetch(self, url):
        """Not implemented yet; currently does nothing and returns ``None``."""
        pass

    def fetch_code(self, codes):
        """Not implemented yet; currently does nothing and returns ``None``."""
        pass

    def parse(self, text):
        """Not implemented yet; currently does nothing and returns ``None``."""
        pass

    def extract_all_scop(self):
        """Not implemented yet; currently does nothing and returns ``None``."""
        pass

    def partition_scope(self):
        """Not implemented yet; currently does nothing and returns ``None``."""
        pass

    def extract_all_astral(self):
        """Yield every record in ``self.fp`` as a single ``bytes`` object.

        A record starts at a line beginning with ``>`` and runs up to (but not
        including) the next such line. Any lines that precede the first ``>``
        line are yielded together as their own record. Nothing is yielded for
        an empty file.
        """
        chunks = []
        for line in self.fp.readlines():
            if line.startswith(b">") and chunks:
                yield b"".join(chunks)
                chunks = []
            chunks.append(line)
        # Flush the last record; skipped when the input had no lines at all.
        if chunks:
            yield b"".join(chunks)

    # TODO allow a list of search parameters. Big challenge to make efficient, but could be cute
    def partition_astral(self, destination: None | str = None, append=False,
                         class_name: str | bytes = b'', contains_id: str | bytes = b'',
                         contains_desc: str | bytes = b'', mode="and") -> dict:
        """Select records whose header matches the given filters.

        Filters are case-insensitive substring tests. An empty filter is
        treated as "not specified": it never restricts an ``"and"`` query and
        never satisfies an ``"or"`` query on its own. When no filter is given,
        every record is selected.

        Args:
            destination: Path of a file that receives the raw text of every
                selected record. When ``None`` the records are written to an
                in-memory buffer that is discarded.
            append: Open ``destination`` in append mode instead of
                overwriting it.
            class_name: Substring to look for in the record's class field.
            contains_id: Substring to look for in the record's identifier
                (which includes the leading ``>``).
            contains_desc: Substring to look for in the rest of the header
                line.
            mode: ``"and"`` to require all given filters, ``"or"`` to require
                at least one. Case-insensitive.

        Returns:
            A dict keyed by identifier (``bytes``, including the leading
            ``>``) whose values are dicts with the keys ``"class"``,
            ``"description"`` (remainder of the header line, leading
            whitespace not stripped) and ``"sequence"`` (all following lines
            with newlines removed).

        Raises:
            ValueError: If ``mode`` is not ``"and"`` or ``"or"``, or if a
                record's header does not have the expected layout.
        """
        mode = mode.lower()
        if mode not in ("and", "or"):
            raise ValueError("mode must be \"and\" or \"or\".")
        combine = scop_parser.op[mode]

        # Everything is a byte string in order to play nicely with future parent methods.
        # Normalise and lowercase the filters once, outside the record loop.
        needles = [
            (value.encode('utf-8') if isinstance(value, str) else value).lower()
            for value in (class_name, contains_id, contains_desc)
        ]
        cls_needle, id_needle, desc_needle = needles
        match_all = not any(needles)
        # An unspecified filter must not influence the outcome, so it
        # evaluates to the identity element of the chosen operator.
        neutral = mode == "and"

        def hit(needle, haystack):
            return (needle in haystack.lower()) if needle else neutral

        if destination is not None:
            file_context = open(destination, 'ab' if append else 'wb')
        else:
            file_context = nullcontext(io.BytesIO())

        ret_dict = {}
        with file_context as fp:
            for record in self.extract_all_astral():
                header = self._ASTRAL_HEADER.match(record)
                if header is None:
                    first_line = record.split(b"\n", 1)[0]
                    raise ValueError(f"Unrecognised ASTRAL header: {first_line!r}")
                entry_id = header.group("id")
                cls = header.group("cls")
                # Slice past the parsed prefix rather than substituting over
                # the whole record, so later text is never altered.
                desc, _, body = record[header.end():].partition(b"\n")
                if match_all or combine(hit(cls_needle, cls),
                                        hit(id_needle, entry_id),
                                        hit(desc_needle, desc)):
                    ret_dict[entry_id] = { # Yes, I know '>' isn't part of the FASTA identifier. This keeps things more consistent
                        "class" : cls,
                        "description" : desc,
                        "sequence" : body.replace(b"\n", b"")
                    }
                    fp.write(record)
        return ret_dict
