from pathlib import Path
import yaml
from typing import Dict, Any
import copy
import re

class FlowList(list):
    """Marker ``list`` subclass with no behavior of its own.

    It exists so a dedicated YAML representer can be registered for it,
    letting selected lists be dumped differently from plain ``list`` objects.
    """

def flow_style_list_representer(dumper, data):
    """Represent *data* as a standard YAML sequence forced into flow style.

    Args:
        dumper: Object exposing ``represent_sequence(tag, data, flow_style=...)``.
        data: The sequence to represent.

    Returns:
        Whatever ``dumper.represent_sequence`` returns for the sequence.
    """
    sequence_tag = 'tag:yaml.org,2002:seq'
    return dumper.represent_sequence(sequence_tag, data, flow_style=True)

# Register the representer with PyYAML so that FlowList instances are
# serialized through flow_style_list_representer.
yaml.add_representer(FlowList, flow_style_list_representer)

def wrap_contacts_with_flow_style(obj):
    """Return a rebuilt copy of *obj* with nested ``contacts`` lists wrapped.

    Dicts and lists are rebuilt recursively; any other value is returned
    unchanged. A value stored under the key ``'contacts'`` is replaced by a
    ``FlowList`` of ``FlowList`` rows when it is a non-empty list whose first
    element is itself a list. Rows that are not lists are kept as they are,
    and the contents of wrapped rows are not processed further.
    """
    if isinstance(obj, list):
        return [wrap_contacts_with_flow_style(element) for element in obj]
    if not isinstance(obj, dict):
        return obj

    wrapped = {}
    for key, value in obj.items():
        # Only the first element is inspected to decide whether this is a
        # list of contact rows.
        is_nested_contacts = (
            key == 'contacts'
            and isinstance(value, list)
            and len(value) > 0
            and isinstance(value[0], list)
        )
        if is_nested_contacts:
            rows = [FlowList(row) if isinstance(row, list) else row for row in value]
            wrapped[key] = FlowList(rows)
        else:
            wrapped[key] = wrap_contacts_with_flow_style(value)
    return wrapped

def dump_with_compact_contacts(config: Dict[str, Any], output_path: Path):
    """Write *config* to *output_path* as YAML.

    The config is first passed through ``wrap_contacts_with_flow_style`` so
    that nested ``contacts`` lists are wrapped before dumping. Everything is
    dumped with ``default_flow_style=False`` and ``sort_keys=False``, so the
    original key order is kept.

    Args:
        config: The configuration mapping to serialize.
        output_path: Destination file; it is written with ``Path.write_text``.
    """
    flow_ready = wrap_contacts_with_flow_style(config)
    output_path.write_text(
        yaml.dump(flow_ready, default_flow_style=False, sort_keys=False)
    )

def extract_ligand_base_id(template: Dict[str, Any]) -> str:
    """Return the first ligand ID in *template* with every digit removed.

    The ``sequences`` entries are scanned in order, and the first one holding
    a ``ligand`` mapping with an ``id`` key is used.

    Args:
        template: Parsed template mapping, expected to hold a ``sequences``
            list.

    Returns:
        The ligand ID with all digit characters stripped (``"LIG12"`` gives
        ``"LIG"``).

    Raises:
        ValueError: If no ligand ID is found. This includes a missing or
            empty ``sequences`` section and null entries.
    """
    # ``or []`` covers a ``sequences:`` key that is present but null, which
    # would otherwise fail with a TypeError while iterating.
    for item in template.get("sequences") or []:
        ligand = item.get("ligand") if isinstance(item, dict) else None
        # A bare ``ligand:`` entry parses to None; skip it rather than crash.
        if isinstance(ligand, dict) and "id" in ligand:
            return ''.join(c for c in ligand["id"] if not c.isdigit())
    raise ValueError("No ligand ID found in template")

def _retarget_ligand(config: Dict[str, Any], ligand_id: str, sdf_path: Path) -> None:
    """Point the ligand-related entries of *config* at *ligand_id*, in place."""
    for item in config["sequences"]:
        if "ligand" in item:
            item["ligand"]["id"] = ligand_id
            item["ligand"]["sdf"] = str(sdf_path)

    # ``or []`` tolerates sections that are present but empty: a bare
    # ``properties:`` or ``constraints:`` key parses to None.
    for prop in config.get("properties") or []:
        if "affinity" in prop:
            prop["affinity"]["binder"] = ligand_id

    for constraint in config.get("constraints") or []:
        if "pocket" in constraint and "binder" in constraint["pocket"]:
            constraint["pocket"]["binder"] = ligand_id

def generate_yamls_from_sdfs(
    template_yaml: Path,
    sdf_dir: Path,
    output_dir: Path,
    yaml_prefix: str = "config_"
) -> None:
    """Write one YAML config per ``ligand_<N>.sdf`` file found in *sdf_dir*.

    For each matching SDF file the template is deep-copied, and the ligand ID
    is set to the template's digit-stripped ligand ID followed by ``N``. Every
    ligand entry gets that ID and the SDF path; affinity binders and existing
    pocket binders are updated to match. The result is written to
    ``<output_dir>/<yaml_prefix><N>.yaml``.

    Args:
        template_yaml: Path to the template YAML file.
        sdf_dir: Directory searched (non-recursively) for ``ligand_*.sdf``.
        output_dir: Directory for the generated files; created if missing.
        yaml_prefix: Filename prefix for the generated YAML files.

    Raises:
        ValueError: If the template is not a YAML mapping, or contains no
            ligand ID (raised by ``extract_ligand_base_id``).
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # Decode explicitly as UTF-8 instead of relying on the locale's default
    # encoding, which differs between platforms.
    with open(template_yaml, encoding="utf-8") as f:
        template = yaml.safe_load(f)
    if not isinstance(template, dict):
        # An empty file parses to None; fail with a clear message.
        raise ValueError(f"Template {template_yaml} does not contain a YAML mapping")

    base_id = extract_ligand_base_id(template)
    pattern = re.compile(r"ligand_(\d+)\.sdf")
    sources_by_index: Dict[int, str] = {}

    for sdf_path in sorted(sdf_dir.glob("ligand_*.sdf")):
        # fullmatch anchors both ends, so a name such as "ligand_1.sdf.sdf"
        # is rejected instead of being treated as index 1.
        match = pattern.fullmatch(sdf_path.name)
        if not match:
            print(f"Skipping file with unexpected name: {sdf_path.name}")
            continue

        index = int(match.group(1))
        if index in sources_by_index:
            # "ligand_01.sdf" and "ligand_1.sdf" map to the same output file.
            # The later file still wins, but the overwrite is now reported.
            print(
                f"Warning: {sdf_path.name} and {sources_by_index[index]} "
                f"share index {index}; overwriting the earlier output"
            )
        sources_by_index[index] = sdf_path.name

        ligand_id = f"{base_id}{index}"
        config = copy.deepcopy(template)
        _retarget_ligand(config, ligand_id, sdf_path)

        output_file = output_dir / f"{yaml_prefix}{index}.yaml"
        dump_with_compact_contacts(config, output_file)
        print(f"Created {output_file} with ligand ID {ligand_id}")

    if not sources_by_index:
        print(f"No files named ligand_<number>.sdf found in {sdf_dir}")

def main():
    """Command-line entry point: parse arguments and generate the YAML files.

    Takes three positional arguments (template YAML, SDF directory, output
    directory) and an optional ``--yaml-prefix``, then passes them to
    ``generate_yamls_from_sdfs`` with the paths converted to ``Path`` objects.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Generate YAML files from a template and SDF files")
    positionals = (
        ("template_yaml", "Path to template YAML file"),
        ("sdf_dir", "Path to directory containing SDF files"),
        ("output_dir", "Path to output directory"),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, type=str, help=help_text)
    cli.add_argument("--yaml-prefix", type=str, default="config_", help="Prefix for output YAML filenames")
    options = cli.parse_args()

    generate_yamls_from_sdfs(
        template_yaml=Path(options.template_yaml),
        sdf_dir=Path(options.sdf_dir),
        output_dir=Path(options.output_dir),
        yaml_prefix=options.yaml_prefix,
    )

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()