from dataclasses import dataclass, field, asdict

import pandas as pd
from dotenv import load_dotenv

from NL2SQLEvaluator.hf_argument_parser import TrlParser
from NL2SQLEvaluator.orchestrator import orchestrator_entrypoint
from NL2SQLEvaluator.orchestrator_state import MultipleTasks, flatten_multiple_tasks, AvailableMetrics, \
    AvailableDialect, SingleTask, SQLInstance, DataCfg, EvalCfg
from NL2SQLEvaluator.utils import utils_read_dataset, utils_get_engine

# Load variables from a .env file into the process environment at import time.
# With override=True, python-dotenv lets values from the .env file take
# precedence over variables that are already set in the environment.
load_dotenv(override=True)


@dataclass
class ScriptArgs:
    """Run-level options for the evaluation script.

    Every field carries its description in ``metadata["help"]``. In the code
    visible in this file only ``metrics`` and ``batch_size`` are read.
    """

    output_dir: str = field(
        default="./outputs",
        metadata={"help": "Directory where evaluation outputs will be saved"}
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed for reproducibility"}
    )
    metrics: list[AvailableMetrics] = field(
        default_factory=lambda: [AvailableMetrics.EXECUTION_ACCURACY],
        # List the member values instead of interpolating the class object:
        # assuming AvailableMetrics is an Enum (its members expose ``.value``),
        # formatting the class itself only yields "<enum 'AvailableMetrics'>".
        metadata={"help": "List of evaluation metrics to compute "
                          f"(available: {[m.value for m in AvailableMetrics]})"}
    )
    batch_size: int = field(
        default=50,
        metadata={"help": "Number of samples to process in each batch"}
    )


@dataclass
class DatasetArgs:
    """Options describing the dataset to evaluate and where its databases live.

    The three ``column_name_*`` fields name the dataset columns that hold the
    target SQL, the predicted SQL and the database identifier of each row.
    """

    relative_db_base_path: str = field(
        default="data/bird_dev/dev_databases",
        metadata={"help": "Relative path to the database files directory"}
    )
    database_dialect: AvailableDialect = field(
        default=AvailableDialect.sqlite,
        # List the member values instead of interpolating the class object:
        # assuming AvailableDialect is an Enum, formatting the class itself
        # only yields "<enum 'AvailableDialect'>".
        metadata={"help": "Database dialect "
                          f"(available: {[d.value for d in AvailableDialect]})"}
    )

    dataset_path: str = field(
        default="simone-papicchio/bird",
        metadata={"help": "HuggingFace dataset path or local dataset path"}
    )
    dataset_name: str = field(
        default="bird-dev",
        metadata={"help": "Name of the dataset configuration to use"}
    )
    column_name_target: str = field(
        default="SQL",
        metadata={"help": "Column name in dataset for the target SQL queries"}
    )

    column_name_predicted: str = field(
        default="predicted_sql",
        metadata={"help": "Column name in dataset for the predicted SQL queries"}
    )

    column_name_db_id: str = field(
        default="db_id",
        metadata={"help": "Column name in dataset for the database identifiers"}
    )


@dataclass
class ModelArgs:
    """Identity and sampling settings of the model under evaluation.

    In this file the whole instance is converted with ``asdict`` and merged
    into each task's ``external_metadata``.
    """

    # --- model identity ---
    model_name: str = field(
        metadata={"help": "Human-readable name for the model"},
        default="Qwen3-Coder-30B",
    )
    model: str = field(
        metadata={"help": "Model identifier for HuggingFace or local model path"},
        default="Qwen/Qwen3-Coder-30B-A3B-Instruct",
    )

    # --- sampling parameters ---
    temperature: float = field(
        metadata={"help": "Sampling temperature for text generation (0.0-2.0)"},
        default=0.7,
    )
    top_p: float = field(
        metadata={"help": "Top-p (nucleus) sampling parameter (0.0-1.0)"},
        default=0.8,
    )
    top_k: int = field(
        metadata={"help": "Top-k sampling parameter (number of tokens to consider)"},
        default=20,
    )
    repetition_penalty: float = field(
        metadata={"help": "Penalty for token repetition (1.0 = no penalty)"},
        default=1.05,
    )

    # --- generation length ---
    max_tokens: int = field(
        metadata={"help": "Maximum number of tokens to generate"},
        default=32000,
    )


def run_evaluation(script_args: ScriptArgs, dataset_args: DatasetArgs, model_args: ModelArgs
                   ) -> tuple[dict, pd.DataFrame]:
    """Evaluate every row of the dataset and summarise the requested metrics.

    Args:
        script_args: Supplies ``batch_size`` and the ``metrics`` to report.
        dataset_args: Supplies ``dataset_path`` and the per-row column names.
        model_args: Forwarded to each task's construction.

    Returns:
        A ``(summary, df_completed_tasks)`` pair: ``summary`` maps each
        requested metric's ``value`` to the mean of its column, and
        ``df_completed_tasks`` is the DataFrame built from the flattened
        completed tasks.
    """
    # Read the dataset rows.
    rows: list[dict] = utils_read_dataset(dataset_args.dataset_path)

    # Create one orchestrator task per row, then bundle them into a batch spec.
    single_tasks = []
    for row in rows:
        single_tasks.append(
            utils_create_single_task_orchestration(row, script_args, dataset_args, model_args)
        )
    task_bundle = MultipleTasks(tasks=single_tasks, batch_size=script_args.batch_size)

    # Run the evaluation.
    finished = orchestrator_entrypoint.invoke(task_bundle)

    # Collect the results in a DataFrame.
    df_completed_tasks = pd.DataFrame(flatten_multiple_tasks(finished))

    # Average each requested metric column to build the summary.
    metric_columns = [metric.value for metric in script_args.metrics]
    metric_means = df_completed_tasks[metric_columns].mean()
    return metric_means.to_dict(), df_completed_tasks


def utils_create_single_task_orchestration(row: dict,
                                           script_args: ScriptArgs,
                                           dataset_args: DatasetArgs,
                                           model_args: ModelArgs) -> SingleTask:
    """Build the orchestrator task for a single dataset row.

    The target SQL, predicted SQL and database id are taken from the columns
    named by ``dataset_args``; every remaining column of the row is merged
    with the model arguments and stored as ``external_metadata``.

    Args:
        row: One dataset record. It is not modified.
        script_args: Source of the evaluation configuration.
        dataset_args: Source of the column names and database configuration.
        model_args: Converted with ``asdict`` and merged into the metadata;
            on key clashes the model argument wins.

    Returns:
        The ``SingleTask`` describing this row.

    Raises:
        KeyError: If the row lacks one of the configured columns.
    """

    def as_sql_instances(sql):
        # A plain string is one query; anything else is treated as an
        # iterable of query strings.
        if isinstance(sql, str):
            return SQLInstance(query=sql)
        return [SQLInstance(query=query) for query in sql]

    # Work on a shallow copy so the caller's row keeps all of its columns.
    metadata = dict(row)
    target_sql = metadata.pop(dataset_args.column_name_target)
    predicted_sql = metadata.pop(dataset_args.column_name_predicted)
    # Use the configured db-id column rather than a hard-coded 'db_id' key,
    # so the database configuration and the task agree on the same value.
    db_id = metadata.pop(dataset_args.column_name_db_id)

    return SingleTask(
        dataset_parameters=_create_datacfg(dataset_args, db_id),
        eval_parameters=_create_evalcfg(script_args),
        target_sql=as_sql_instances(target_sql),
        predicted_sql=as_sql_instances(predicted_sql),
        db_id=db_id,
        external_metadata=metadata | asdict(model_args),
    )


def _create_datacfg(dataset_args: DatasetArgs, db_id: str) -> DataCfg:
    """Assemble the ``DataCfg`` for the database identified by ``db_id``.

    The engine is obtained from ``utils_get_engine`` using the base path and
    dialect configured in ``dataset_args``.
    """
    base_path = dataset_args.relative_db_base_path
    dialect = dataset_args.database_dialect
    engine = utils_get_engine(
        relative_base_path=base_path,
        db_executor=dialect,
        db_id=db_id,
    )
    return DataCfg(
        relative_db_base_path=base_path,
        dataset_name=dataset_args.dataset_name,
        dialect=dialect,
        engine=engine,
    )


def _create_evalcfg(script_args: ScriptArgs) -> EvalCfg:
    """Wrap the metrics selected in ``script_args`` into an ``EvalCfg``."""
    selected_metrics = script_args.metrics
    return EvalCfg(metrics=selected_metrics)


if __name__ == "__main__":
    # Parse one instance of each argument dataclass.
    parser = TrlParser((ScriptArgs, DatasetArgs, ModelArgs))
    script_args, dataset_args, model_args = parser.parse_args_and_config()

    # The dataset is read inside run_evaluation from dataset_args.dataset_path,
    # so no file needs to be loaded here.
    summary, df_eval = run_evaluation(script_args, dataset_args, model_args)

    # Report the per-metric averages so the run produces a visible result.
    print(summary)
