"""Command-line interface entrypoint for the `foldifyai` package."""
from __future__ import annotations
import time 
from colorama import Fore, Style

import base64
import zipfile
import io
import pathlib
import sys
from pathlib import Path
from tqdm import tqdm
import os 
import time 
import json 
from rdkit import Chem
from rdkit.Chem import AllChem
import urllib
from foldifyai.utils import get_type, file_exists
import requests
try: 
    from logmd import LogMD
except:
    pass 


def _usage() -> None:
    """Write a one-line usage hint to stderr, naming the program as it was invoked."""
    invoked_as = pathlib.Path(sys.argv[0]).name
    # An argv[0] with no file-name component falls back to the default name.
    prog = invoked_as if invoked_as else "foldify"
    usage_line = f"Usage: {prog} <path_to_file.fasta>"
    print(usage_line, file=sys.stderr)


def compute_3d_conformer(mol, version: str = "v3") -> bool:
    """Embed one 3D conformer into *mol* with RDKit ETKDG and relax it briefly with UFF.

    Args:
        mol: RDKit molecule; it is modified in place (a conformer is added and
            tagged with ``name`` / ``coord_generation`` properties).
        version: ``"v3"`` selects ETKDGv3; any other value selects ETKDGv2.

    Returns:
        True when a conformer was embedded, optimised and tagged. False when
        embedding produced no conformer, or when RDKit raised ``RuntimeError``
        or ``ValueError`` along the way.
    """
    # Remember which parameter set is really used, so the warning and the
    # "coord_generation" tag stay truthful for unrecognised version strings.
    if version == "v3":
        options = AllChem.ETKDGv3()
        used_version = "v3"
    else:
        options = AllChem.ETKDGv2()
        used_version = "v2"

    options.clearConfs = False
    options.timeout = 3  # don't spend more than three seconds on AllChem.EmbedMolecule

    try:
        conf_id = AllChem.EmbedMolecule(mol, options)

        if conf_id == -1:
            # No retry with random coordinates is attempted: the molecule is
            # simply reported as failed.
            print(
                f"WARNING: RDKit ETKDG{used_version} failed to generate a conformer for molecule "
                f"{Chem.MolToSmiles(AllChem.RemoveHs(mol))}; no 3D conformer is available for it."
            )
            return False

        # Deliberately few iterations (33 instead of e.g. 1000) to give up quickly.
        AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=33)
    except (RuntimeError, ValueError):
        # RuntimeError: force-field issue; ValueError: sanitization issue.
        return False

    conformer = mol.GetConformer(conf_id)
    conformer.SetProp("name", "Computed")
    conformer.SetProp("coord_generation", f"ETKDG{used_version}")
    return True

def test(seq, affinity=False):
    """Dry-run RDKit preprocessing of a SMILES string and report whether it succeeds.

    The molecule is parsed, hydrogens are added, every atom gets a ``name``
    property (element symbol + canonical rank), a 3D conformer is computed via
    ``compute_3d_conformer`` and the hydrogens are stripped again. Nothing is
    returned besides the verdict.

    Args:
        seq: SMILES string to check.
        affinity: When true, the molecular weight of the hydrogen-stripped
            molecule is computed as part of the dry run.

    Returns:
        True if every step succeeded, False otherwise. Any exception is
        printed together with *seq* and reported as False.
    """
    try:
        mol = AllChem.MolFromSmiles(seq)
        if mol is None:
            # Fail with a readable message instead of relying on AddHs(None) raising.
            raise ValueError("RDKit could not parse the SMILES string")
        mol = AllChem.AddHs(mol)

        # Set atom names
        canonical_order = AllChem.CanonicalRankAtoms(mol)
        for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
            atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
            if len(atom_name) > 4:
                raise ValueError(
                    f"{seq} has an atom with a name longer than "
                    f"4 characters: {atom_name}."
                )
            atom.SetProp("name", atom_name)

        if not compute_3d_conformer(mol):
            return False

        mol_no_h = AllChem.RemoveHs(mol, sanitize=False)
        if affinity:
            # Result is not needed; the call is only checked for not raising.
            # NOTE(review): relies on AllChem exposing `Descriptors` - confirm.
            AllChem.Descriptors.MolWt(mol_no_h)
        return True
    except Exception as e:
        # Validation boundary: any RDKit failure means "do not submit this molecule".
        print(e, seq)
        return False

# Bucket that result archives are mirrored to when an S3 endpoint is configured.
_S3_BUCKET = 'dmitrij'
# File whose presence in an output directory marks a finished prediction.
_RESULT_MARKER = 'boltz2_prediction_0.pdb'


def _collect_fasta_files(input_path):
    """Return ``(files, folder)`` for a single ``.fasta`` file or a directory tree."""
    if input_path.endswith('.fasta'):
        return [input_path], input_path.replace('.fasta', '')
    by_size = sorted(Path(input_path).rglob("*.fasta"), key=lambda f: os.path.getsize(str(f)))
    return [str(f) for f in by_size], input_path


def _make_s3_client(endpoint_url):
    """Create an S3 client for *endpoint_url* using the ``r2`` boto3 profile."""
    import boto3
    return boto3.session.Session(profile_name="r2").client('s3', endpoint_url=endpoint_url)


def _zip_directory(path):
    """Return an in-memory zip, rewound to offset 0, of every file below *path*."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, 'w') as archive:
        for root, _dirs, names in os.walk(path):
            for name in names:
                full_name = os.path.join(root, name)
                archive.write(full_name, os.path.relpath(full_name, path))
    buffer.seek(0)
    return buffer


def _smiles_lines_ok(content, fasta_path):
    """Dry-run every SMILES line of *content* through ``test``; False if any is rejected."""
    all_ok = True
    for line in content.split('\n'):
        if line == '' or line.startswith('>'):
            continue
        if get_type(line) != 'SMILES':
            continue
        if test(line):
            print('ok')
        else:
            print(f"Skipping {fasta_path}. RDKit didn't like {line}. ")
            all_ok = False
    return all_ok


def _download_text(url, pbar, fasta_path, num_tokens):
    """Fetch *url*, updating the *pbar* description per chunk, and return the UTF-8 body."""
    import urllib.request

    block_size = 1024
    chunks = []
    received = 0
    with urllib.request.urlopen(url) as response:
        while True:
            data = response.read(block_size)
            if not data:
                break
            chunks.append(data)
            received += len(data)
            pbar.set_description(f"{Fore.BLUE}{Style.BRIGHT}[Foldify]{Style.RESET_ALL} {time.strftime('%d-%m %H:%M:%S')} {fasta_path} tokens={num_tokens} {received/1000}KB")
    # Decode once at the end: a multi-byte character may straddle two chunks.
    return b''.join(chunks).decode('utf-8')


def fold(args):
    """Send FASTA files to the Foldify server and unpack the returned results.

    ``args.input`` is either one ``.fasta`` file or a directory that is
    searched recursively; directory hits are processed largest file first.
    For each file the result zip is extracted to
    ``<output_dir>/<fasta path without .fasta>/``. A file is skipped when its
    result already exists (locally, or on S3 when ``args.cf`` is set) unless
    ``args.override`` is given. Failures while processing one file are printed
    and do not stop the remaining files.

    Args:
        args: Parsed CLI namespace. Attributes read: ``input``, ``logmd``,
            ``output``, ``cf``, ``override``, ``host``, ``args``, ``gpu``, ``msa``.
    """
    import urllib.parse

    files, folder = _collect_fasta_files(args.input)

    if args.logmd:
        # LogMD comes from an optional import at the top of the file. The
        # instance is only kept referenced; nothing else here uses it.
        logmd_session = LogMD()  # noqa: F841

    output_dir = args.output if args.output != '' else f"foldify_{folder.replace('/','')}"
    print(files[:5])

    use_s3 = args.cf != ''
    s3 = None  # created on first use and shared by all files
    # User-supplied query parameters are forwarded verbatim; omit them when absent
    # instead of sending the literal text "None".
    extra_query = f"{args.args}&" if args.args else ''

    pbar = tqdm(files[::-1])
    for p in pbar:
        path = f"{output_dir}/{p.replace('.fasta', '')}/"

        s3_path = None
        on_s3 = False
        if use_s3:
            # The object key keeps every other component (0th, 2nd, ...) of the zip path.
            s3_path = '/'.join((path[:-1] + '.zip').split('/')[0::2])
            if s3 is None:
                s3 = _make_s3_client(args.cf)
            on_s3 = file_exists(s3, _S3_BUCKET, s3_path)
            if on_s3 and not args.override:
                print(f"Skipping {p} found on s3 {s3_path}")
                continue

        if os.path.exists(f"{path}{_RESULT_MARKER}") and not args.override:
            print(f"Skipping {p} found {path}{_RESULT_MARKER}")
            # Present locally but missing on S3 => upload the existing result.
            if use_s3 and not on_s3:
                print(f"Uploading {p} to s3 {s3_path}")
                s3.upload_fileobj(_zip_directory(path), _S3_BUCKET, s3_path)
            continue

        url = None  # stays None until the request URL has been built
        try:
            with open(p) as handle:
                content = handle.read()
            if len(content) == 0:
                continue

            num_tokens = sum(len(line) for line in content.split('\n') if not line.startswith('>'))

            # Reject files whose ligands RDKit cannot handle before contacting the server.
            if not _smiles_lines_ok(content, p):
                continue

            encoded = urllib.parse.quote(content, safe="")
            url = f"{args.host}/fold?only_return_zip=True&seq={encoded}&{extra_query}gpu={args.gpu}&get_msa_from_server={args.msa}"

            result = _download_text(url, pbar, p, num_tokens)

            # The body is a sequence of JSON documents separated by "\n@\n";
            # the last one carries the base64-encoded result zip under "data".
            jsons = [json.loads(part) for part in result.split('\n@\n') if part != '']
            zip_in_memory = io.BytesIO(base64.b64decode(jsons[-1]['data']))
            with zipfile.ZipFile(zip_in_memory, 'r') as zip_ref:
                os.makedirs(path, exist_ok=True)
                zip_ref.extractall(path)

            if use_s3:
                zip_in_memory.seek(0)
                s3.upload_fileobj(zip_in_memory, _S3_BUCKET, s3_path)

        except Exception as e:
            # Best effort: report the failure and carry on with the next file.
            print('something wrong', e)
            if url is not None:
                print(url)
        print('')


def main() -> None:  # pragma: no cover
    """Parse the command line and run ``fold`` with the resulting options.

    ``-input`` is mandatory. When the remote host ``https://gpu1.foldify.org``
    is selected without ``-y``, the user must type ``REMOTE`` to confirm that
    sequences may be sent there; any other reply exits.
    """
    import argparse

    remote_host = 'https://gpu1.foldify.org'

    # -h is used for -host, so the automatic -h/--help pair is disabled and
    # help is offered under --help only.
    parser = argparse.ArgumentParser(description='Foldify.ai CLI', add_help=False)
    parser.add_argument('--help', action='help', help='Show this help message and exit')
    parser.add_argument('-input', '-i', type=str, required=True,
                        help='Path to a .fasta file, or a directory searched recursively for .fasta files')
    parser.add_argument('-args', type=str, help='Extra query parameters appended verbatim to the request URL')
    parser.add_argument('-logmd', action='store_true', help='Log with LogMD')
    parser.add_argument('-host', '-h', type=str, default='http://0.0.0.0:8000', help='Host URL for Foldify API')
    parser.add_argument('-output', '-o', type=str, default='', help='Output directory for results')
    parser.add_argument('-gpu', '-g', type=int, default=0, help='GPU')
    parser.add_argument('-y', action='store_true', help='Pre-accept using remote host. ')
    parser.add_argument('-cf', type=str, default='', help='Cloudflared endpoint url to store online.')
    parser.add_argument('-override', action='store_true', help='Override existing files')
    parser.add_argument('-msa', type=str, default='', help='Get msa from other ip. ')
    args = parser.parse_args()

    if args.host == remote_host and not args.y:
        print(f"Sequences will be sent to the remote host {args.host}. ")
        print("Reply `REMOTE` if you want to send sequences. ")
        if input() != 'REMOTE':
            print('Exiting.')
            sys.exit()
        print("Using remote host. ")
        print("You can skip this check with `foldify -y`")

    fold(args)


if __name__ == "__main__":  # pragma: no cover
    # main() reads its options from the real command line (sys.argv).
    main()
