"""Functions to batch process trades into dataframes for analysis.
"""
from __future__ import annotations
from typing import Tuple, List

from .live_odds import live_odds
from .qpbanker import win_probability, expected_value, average_odds
from .optimization import _pareto_filter

import polars as pl
import numpy as np
from itertools import combinations
from tqdm import tqdm


def _all_subsets(lst): return [list(x) for r in range(
    1, len(lst)+1) for x in combinations(lst, r)]  # list subsets of a list


def _process_single_qp_trade(banker: int, covered: List[int], odds_pla: List[float], odds_qpl: List[float], rebate: float) -> Tuple[int, List[int], float, float, float]:
    """Score a single qp trade (one banker together with a set of covered horses).

    Args:
        banker (int): Horse number used as the banker.
        covered (List[int]): Horse numbers covered alongside the banker.
        odds_pla (List[float]): Place (PLA) odds, passed through to the scoring helpers.
        odds_qpl (List[float]): Quinella-place (QPL) odds, passed through to the scoring helpers.
        rebate (float): Rebate rate forwarded to ``expected_value``.

    Returns:
        Tuple[int, List[int], float, float, float]: ``(banker, covered,
        win_prob, exp_value, ave_odds)`` in that order, matching the
        Banker/Covered/WinProb/ExpValue/AvgOdds row layout.
    """
    win_prob = win_probability(odds_pla, banker, covered)
    exp_value = expected_value(odds_pla, odds_qpl, banker, covered, rebate)
    ave_odds = average_odds(odds_qpl, banker, covered)
    return (banker, covered, win_prob, exp_value, ave_odds)


def generate_all_qp_trades(date: str, venue_code: str, race_number: int, rebate: float = 0.12, fit_harville: bool = False) -> pl.DataFrame:
    """Generate all possible qp tickets for the specified race.

    Every horse is tried as the banker, and for each banker every non-empty
    subset of the remaining horses is tried as the covered set. Each
    (banker, covered) combination is scored by ``_process_single_qp_trade``.

    Args:
        date (str): Date in 'YYYY-MM-DD' format.
        venue_code (str): Venue code, e.g., 'ST' for Shatin, 'HV' for Happy Valley.
        race_number (int): Race number.
        rebate (float, optional): The rebate percentage. Defaults to 0.12.
        fit_harville (bool, optional): Whether to fit the odds using Harville model. Defaults to False.

    Returns:
        pl.DataFrame: One row per (banker, covered) combination with columns
        Banker, Covered, WinProb, ExpValue, AvgOdds and NumCovered. If the
        race has fewer than two runners the frame is empty but still carries
        these columns.
    """
    odds = live_odds(date, venue_code, race_number,
                     odds_type=['PLA', 'QPL'], fit_harville=fit_harville)
    odds_pla, odds_qpl = odds['PLA'], odds['QPL']

    # Horse numbers are 1-based. Plain Python ints (instead of numpy scalars
    # from np.arange) keep the Banker/Covered columns as ordinary integers.
    candidates = list(range(1, len(odds_pla) + 1))

    results = [_process_single_qp_trade(banker, covered, odds_pla, odds_qpl, rebate)
               for banker in tqdm(candidates, desc="Processing bankers")
               for covered in _all_subsets([c for c in candidates if c != banker])]

    if not results:
        # With no rows polars would infer 'Covered' as Null dtype and
        # `.list.len()` would fail, so return a correctly typed empty frame.
        return pl.DataFrame(schema={
            'Banker': pl.Int64,
            'Covered': pl.List(pl.Int64),
            'WinProb': pl.Float64,
            'ExpValue': pl.Float64,
            'AvgOdds': pl.Float64,
            'NumCovered': pl.UInt32,
        })

    # Each result tuple is one row; state that explicitly rather than letting
    # polars infer the orientation from the data shape.
    df = (pl.DataFrame(results, schema=['Banker', 'Covered', 'WinProb', 'ExpValue', 'AvgOdds'],
                       orient='row')
          .with_columns(pl.col('Covered').list.len().alias('NumCovered')))

    return df


def generate_pareto_qp_trades(date: str, venue_code: str, race_number: int, rebate: float = 0.12, groupby: List[str] | None = None, fit_harville: bool = False) -> pl.DataFrame:
    """Generate qp tickets that are Pareto optimal for the specified race.

    All trades are generated with ``generate_all_qp_trades`` and then
    filtered to those that are Pareto optimal on (WinProb, ExpValue), both
    maximized.

    Args:
        date (str): Date in 'YYYY-MM-DD' format.
        venue_code (str): Venue code, e.g., 'ST' for Shatin, 'HV' for Happy Valley.
        race_number (int): Race number.
        rebate (float, optional): The rebate percentage. Defaults to 0.12.
        groupby (List[str], optional): Columns to group by when determining Pareto optimality.
            Defaults to None, meaning no grouping (global optimal).
        fit_harville (bool, optional): Whether to fit the odds using Harville model. Defaults to False.

    Returns:
        pl.DataFrame: DataFrame with all Pareto trades and their metrics.
    """
    if groupby is None:
        # Fresh list per call rather than a shared mutable default argument.
        groupby = []
    # The keyword must be `fit_harville`; generate_all_qp_trades has no
    # `harville_fit` parameter, so the old call raised TypeError.
    df = generate_all_qp_trades(date, venue_code, race_number, rebate, fit_harville=fit_harville)
    pareto_df = _pareto_filter(df, groupby=groupby, by=[
                               'WinProb', 'ExpValue'], maximize=True)
    return pareto_df
