import os
import shutil
from glob import glob
from pathlib import Path

import joblib
import polars as pl
from phylogenie import Tree, load_nexus
from tqdm import tqdm

from bella_companion.utils import read_weights_dir, summarize_logs_dir


def summarize_logs() -> None:
    """Summarize the BEAST output of the platyrrhine analysis.

    Reads the logs found in ``$BELLA_BEAST_OUTPUT_DIR/platyrrhine`` and writes
    the following files to ``$BELLA_LOG_SUMMARIES_DIR/platyrrhine`` (the
    directory is created if it does not exist):

    - ``MLP.csv``: the table returned by ``summarize_logs_dir`` for the
      birth/death rate columns.
    - ``MLP.weights.pkl``: the object returned by ``read_weights_dir``.
    - ``trees.pkl``: the trees loaded from every ``*.trees`` file.
    - ``sample-tree.nexus``: a verbatim copy of ``0.trees``.

    Raises:
        KeyError: If ``BELLA_BEAST_OUTPUT_DIR`` or ``BELLA_LOG_SUMMARIES_DIR``
            is not set in the environment.
    """
    data_dir = Path(__file__).parent / "data"
    change_times = pl.read_csv(data_dir / "change_times.csv", has_header=False)
    # One more time bin than there are rows in the change-times file.
    n_time_bins = len(change_times) + 1

    logs_dir = Path(os.environ["BELLA_BEAST_OUTPUT_DIR"]) / "platyrrhine"
    summaries_dir = Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / "platyrrhine"

    # Column order: rate kind, then time bin, then the "0".."3" suffix.
    target_columns = [
        f"{rate}RateSPi{i}_{t}"
        for rate in ["birth", "death"]
        for i in range(n_time_bins)
        for t in ["0", "1", "2", "3"]
    ]
    summaries = summarize_logs_dir(logs_dir=logs_dir, target_columns=target_columns)
    weights = read_weights_dir(logs_dir)

    summaries_dir.mkdir(parents=True, exist_ok=True)
    summaries.write_csv(summaries_dir / "MLP.csv")
    joblib.dump(weights, summaries_dir / "MLP.weights.pkl")

    # glob() yields paths in arbitrary, filesystem-dependent order; sorting
    # makes the order of the trees in trees.pkl reproducible across runs.
    tree_files = sorted(glob(str(logs_dir / "*.trees")))

    trees: list[Tree] = []
    for tree_file in tqdm(tree_files, "Summarizing trees"):
        # NOTE(review): this keeps everything except the LAST 10 trees of each
        # file -- confirm this is intended (as opposed to `[-10:]` or `[10:]`).
        trees.extend(list(load_nexus(tree_file).values())[:-10])
    joblib.dump(trees, summaries_dir / "trees.pkl")

    shutil.copy(logs_dir / "0.trees", summaries_dir / "sample-tree.nexus")
