from ...data.struct import AssetWeight, AssetPrice, AssetPosition, AssetValue
from ...data.struct import AssetTrade, FundTrade, FundPosition, FundWeight
from ...data.struct import AssetTradeParam, FundTradeParam
from ...data.struct import AssetTradeCache, FundWeightItem
from ...data.constant import TradeTrigger
from . import Helper
from .asset_helper import FAHelper
import numpy as np
import math
import datetime

class AssetTrader(Helper):

    TRADE_DAY_INTERVAL = 10

    def __init__(self, asset_param: AssetTradeParam=None):
        self.asset_param = asset_param or AssetTradeParam()
        self.last_trade_date = None

    def trade_interval_allow(self, dt):
        if self.last_trade_date is None or self.last_trade_date + datetime.timedelta(days=self.TRADE_DAY_INTERVAL) < dt:
            return True
        else:
            return False

    def calc_asset_trade(self, dt,
                               cur_position: AssetPosition,
                               cur_price: AssetPrice,
                               target_allocation: AssetWeight):
        """Propose index-level trades that move the book towards the target weights.

        Every attribute of ``target_allocation`` except ``cash`` is compared
        with the current market value of the same attribute. A trade is
        proposed when the absolute gap exceeds ``MinCountAmtDiff`` times the
        total market value: buys carry a cash ``amount``, sells carry a
        ``volume`` (gap divided by the current price).

        Returns:
            ``(cur_position, None)`` when no gap exceeds ``MinActionAmtDiff``
            times the total market value. Otherwise ``(new_position, trades)``
            where ``trades`` is sorted with sells before buys and
            ``new_position`` is a copy of ``cur_position`` updated with each
            trade.
        """
        market_values = AssetValue(prices=cur_price, positions=cur_position)
        total_value = market_values.sum()

        proposed = []
        should_trade = False
        for asset_name, weight in target_allocation.__dict__.items():
            if asset_name == 'cash':
                continue
            wanted = total_value * weight
            price = cur_price.__dict__[asset_name]
            held = market_values.__dict__[asset_name]
            gap = abs(wanted - held)
            if gap > total_value * self.asset_param.MinCountAmtDiff:
                buying = wanted > held
                if buying:
                    order = AssetTrade(
                        index_id=asset_name,
                        mark_price=price,
                        amount=gap,
                        is_buy=buying,
                        submit_date=dt
                    )
                else:
                    order = AssetTrade(
                        index_id=asset_name,
                        mark_price=price,
                        volume=gap / price,
                        is_buy=buying,
                        submit_date=dt
                    )
                proposed.append(order)
            # One sufficiently large gap is enough to launch the whole batch.
            if not should_trade:
                should_trade = gap > total_value * self.asset_param.MinActionAmtDiff

        if not should_trade:
            return cur_position, None

        # False sorts before True, so sells are applied ahead of buys.
        proposed.sort(key=lambda order: order.is_buy)
        projected = cur_position.copy()
        for order in proposed:
            projected.update(order)
        return projected, proposed

    def finalize_trade(self, dt, trades: list,
                            t1_price: AssetPrice,
                            bt_position: AssetPosition):
        pendings = []
        traded_list = []
        if trades is None or len(trades) == 0:
            return pendings, traded_list
        # TODO: if some trades needs more time
        for trd in trades:
            # TODO: commision calculate
            #trd.commission = ?
            # update position
            trd.trade_price = t1_price.__dict__[trd.index_id]
            trd.volume = trd.volume if trd.volume else (trd.amount / trd.trade_price)
            trd.trade_date = dt
            if not trd.is_buy:
                if not(bt_position.__dict__[trd.index_id] - trd.volume > -1e-8):
                    #print(f'trade volume exceeds, adjusted to pos (index_id){trd.index_id} (vol){trd.volume} (is_buy){trd.is_buy} (pos){bt_position.__dict__[trd.index_id]}')
                    trd.volume = bt_position.__dict__[trd.index_id]
            trd.amount = trd.volume * trd.trade_price    
            if self.asset_param.EnableCommission:
                if trd.is_buy:
                    trd.volume = trd.amount / trd.trade_price / (1 + self.asset_param.PurchaseDiscount * self.asset_param.AssetPurchaseRate[trd.index_id])
                    trd.commission = trd.volume * trd.trade_price * self.asset_param.PurchaseDiscount * self.asset_param.AssetPurchaseRate[trd.index_id]
                else:
                    trd.commission = trd.amount * self.asset_param.RedeemDiscount * self.asset_param.AssetRedeemRate[trd.index_id]
                    trd.amount -= trd.commission
            else:
                trd.commission = 0
            bt_position.update(trd)
            traded_list.append(trd)
        
        self.last_trade_date = dt
        return pendings, traded_list

class FundTrader(AssetTrader):

    SMALL_POS_FLOAT = 1e-6
    BIG_INT = 1e10

    def __init__(self, asset_param: AssetTradeParam=None, fund_param: FundTradeParam=None):
        """Create a fund-level trader.

        ``asset_param`` is forwarded to the ``AssetTrader`` initializer. A
        falsy ``fund_param`` (including the default ``None``) is replaced by a
        freshly constructed ``FundTradeParam()``.
        """
        super().__init__(asset_param=asset_param)
        self.fund_param = fund_param if fund_param else FundTradeParam()

    def set_helper(self, fa_helper: FAHelper):
        self.fa_helper = fa_helper
        self.last_trade_date = None

    def has_expired_fund(self, cur_fund_position:FundPosition, _prep_fund_score:dict):
         # 如果持仓基金 没有分数， 返回True
        pos_fund_set = {fund_id for fund_id, fund_pos_i in cur_fund_position.funds.items() if fund_pos_i.volume > 0}
        score_set = set()
        for index_id, pos_i in _prep_fund_score.items():
            score_set.update(pos_i.keys())
        return bool(pos_fund_set.difference(score_set))

    # to be deprecated
    def calc_fund_trade(self, dt, fund_weight: FundWeight, cur_fund_position: FundPosition,
                            cur_fund_nav: dict,
                            fund_purchase_fees: dict,
                            fund_redeem_fees: dict) -> tuple:
        """Propose fund trades that move the holdings towards ``fund_weight``.

        Candidates are every fund in ``fund_weight.funds`` plus every fund in
        ``cur_fund_position.funds``. A trade is proposed when the gap between
        target and current amount exceeds ``MinCountAmtDiff`` times the total
        fund market value, or when a held fund has a target amount of zero.
        Buys carry a cash ``amount``; sells carry a ``volume`` (gap / nav).

        ``fund_purchase_fees`` and ``fund_redeem_fees`` are accepted for
        signature compatibility but are not used here.

        Returns:
            ``(new_fund_position, fund_trades)``: a copy of
            ``cur_fund_position`` updated with every proposed trade, and the
            trades sorted with sells before buys.
        """
        fund_tot_mv, _ = cur_fund_position.calc_mv_n_w(fund_navs=cur_fund_nav)
        min_count_amt = fund_tot_mv * self.fund_param.MinCountAmtDiff

        # Target funds first, then held funds; for a fund present in both,
        # the index_id recorded on the held position wins.
        all_funds = {}
        for fund_id, fund_wgt_item in fund_weight.funds.items():
            all_funds[fund_id] = fund_wgt_item.index_id
        for fund_id, fund_pos_item in cur_fund_position.funds.items():
            all_funds[fund_id] = fund_pos_item.index_id

        fund_trades = []
        for fund_id, index_id in all_funds.items():
            target_fund_amt = fund_weight.get_wgt(fund_id) * fund_tot_mv
            cur_fund_volume = cur_fund_position.get_volume(fund_id) or 0
            p = cur_fund_nav[fund_id]
            cur_fund_amt = cur_fund_volume * p
            amt_diff = abs(target_fund_amt - cur_fund_amt)

            # A full liquidation (target 0, something held) is always allowed,
            # even when the gap is below the counting threshold.
            is_cleanup = target_fund_amt == 0 and cur_fund_amt > 0
            if not (amt_diff > min_count_amt or is_cleanup):
                continue

            # TODO: commission
            is_buy = target_fund_amt > cur_fund_amt
            if is_buy:
                _trade = FundTrade(
                    fund_id=fund_id,
                    index_id=index_id,
                    mark_price=p,
                    amount=amt_diff,
                    is_buy=is_buy,
                    submit_date=dt
                )
            else:
                _trade = FundTrade(
                    fund_id=fund_id,
                    index_id=index_id,
                    mark_price=p,
                    volume=amt_diff / p,
                    is_buy=is_buy,
                    submit_date=dt
                )
            fund_trades.append(_trade)

        # False sorts before True, so sells are applied ahead of buys.
        fund_trades.sort(key=lambda x: x.is_buy)
        new_fund_position = cur_fund_position.copy()
        for _trade in fund_trades:
            new_fund_position.update(_trade)
        return new_fund_position, fund_trades

    def finalize_trade(self, dt, trades: list,
                            t1_price: AssetPrice,
                            bt_position: AssetPosition,
                            cur_fund_position: FundPosition,
                            cur_fund_nav: dict,
                            cur_fund_unit_nav: dict,
                            fund_purchase_fees: dict,
                            fund_redeem_fees: dict,
                            disproved_set: set):

        if trades is None or len(trades) == 0:
            return [], []

        pendings = []
        traded_list = []
        # TODO: if some trades needs more time
        for trd in trades:
            if trd.fund_id not in disproved_set:
                trd.is_permitted_fund = True
            else:
                trd.is_permitted_fund = False
            trd.trade_price = cur_fund_nav[trd.fund_id]
            trd.fund_unit_nav = cur_fund_unit_nav[trd.fund_id]
            trd.volume = trd.volume if trd.volume else (trd.amount / trd.trade_price)
            trd.amount = trd.amount if trd.amount else (trd.volume * trd.trade_price)
            trd.fund_unit_volume = trd.amount / trd.fund_unit_nav
            trd.trade_date = dt
            if not trd.is_buy:
                cur_vol = cur_fund_position.get_volume(trd.fund_id)
                if not((cur_vol or 0) - trd.volume > -1e-8):
                    #print(f'trade volume exceeds, adjusted to pos (fund_id){trd.fund_id} (vol){trd.volume} (is_buy){trd.is_buy} (pos){cur_vol}')
                    assert cur_vol is not None, 'sell fund with no current position!'
                    trd.volume = cur_vol
                    trd.amount = trd.volume * trd.trade_price
            if self.fund_param.EnableCommission:
                if trd.is_buy:
                    purchase_fee = fund_purchase_fees[trd.fund_id] * self.fund_param.PurchaseDiscount 
                    if np.isnan(purchase_fee):
                        #print(f'fund_id {trd.fund_id} purchase fee data not avaiable')
                        purchase_fee = 0
                    trd.volume = trd.amount / trd.trade_price / (1 + purchase_fee)
                    trd.commission = trd.volume * trd.trade_price * purchase_fee
                else:
                    redeem_fee = fund_redeem_fees[trd.fund_id] * self.fund_param.RedeemDiscount
                    if np.isnan(redeem_fee):
                        #print(f'fund_id {trd.fund_id} redeem fee data not avaiable')
                        redeem_fee = 0
                    trd.commission = trd.amount * redeem_fee
                    trd.amount -= trd.commission
            else:
                trd.commission = 0
            if trd.is_permitted_fund == False:
                pass
                #print(f'trade is not permitted : {trd}')

            trade_status = cur_fund_position.update(trd)
            if trade_status:
                traded_list.append(trd)
            else:
                pass
                #print(f'trade failed alert : {trd}')
        # get cur asset weight from fund weight
        cur_mv, cur_fund_weight = cur_fund_position.calc_mv_n_w(cur_fund_nav)
        fund_index_dic = {}
        for fund_id, fund_pos_item in cur_fund_position.funds.items():
            fund_index_dic[fund_id] = fund_pos_item.index_id
        asset_wgt = { _ : 0 for _ in set(fund_index_dic.values())}
        for fund_id, wgt_i in cur_fund_weight.items():
            index_id = fund_index_dic[fund_id]
            asset_wgt[index_id] += wgt_i
        # set index position:
        for index_id in asset_wgt:
            asset_p = t1_price.__dict__[index_id]
            amount = cur_mv * asset_wgt[index_id]
            bt_position.__dict__[index_id] = amount / asset_p
        
        self.last_trade_date = dt
        # fund cash -> asset cash
        bt_position.cash = cur_fund_position.cash
        return pendings, traded_list


    def judge_trade(self, dt, 
            index_fund_cache: list, # to change
            tar_fund_weight: FundWeight,
            tar_asset_weight: AssetWeight,
            cur_fund_position: FundPosition,
            cur_asset_position: AssetPosition,
            cur_fund_nav: dict,
            cur_fund_score: dict,
            cur_asset_price: AssetPrice) -> tuple:
        """Decide whether a rebalance should be triggered on ``dt``.

        For every non-cash attribute of ``tar_asset_weight`` an
        ``AssetTradeCache`` is filled in and appended to ``index_fund_cache``
        (the list is mutated in place). Four conditions can contribute to the
        trigger value:

          * index rebalance: the largest absolute gap between target and
            current index weight exceeds ``JudgeIndexDiff``;
          * fund selection: the lowest ranking score falls below
            ``JudgeFundSelection``;
          * fund rebalance: the lowest diversification score falls below
            ``JudgeFundRebalance``;
          * fund end: some held fund has no entry in ``cur_fund_score``.

        The selection and rebalance scores only consider indices whose target
        weight exceeds ``DiffJudgeAssetWgtRequirement``.

        ``tar_fund_weight`` is not read by this method.

        Returns:
            ``(trigger, trigger_reason)``. ``trigger`` is the sum of the
            matching ``TradeTrigger`` values, forced to 0 when
            ``trade_interval_allow(dt)`` is False. ``trigger_reason`` is built
            by ``TradeTrigger.trigger_log`` before that suppression, so it
            describes the un-suppressed trigger.
        """

        # Worst-case values seen so far across indices, plus the index each
        # one came from (the names are only used for the trigger log).
        index_rebalance_max_diff = 0
        fund_selection_min_score = 1
        fund_rebalance_min_score = 1
        fund_end = 0
        _index_rebalance_max_diff_name = ''
        _fund_selection_min_name = ''
        _fund_rebalance_min_name = ''
        fund_rank_list = []
        cur_asset_mv = AssetValue(prices=cur_asset_price, positions=cur_asset_position)
        cur_asset_weight = cur_asset_mv.get_weight()
        for index_id, index_tar_wgt in tar_asset_weight.__dict__.items():
            if index_id == 'cash':
                continue
            c = AssetTradeCache(index_id=index_id)
            c.index_tar_wgt = index_tar_wgt
            c.index_cur_wgt = cur_asset_weight.__dict__[index_id]
            c.index_diff = c.index_tar_wgt - c.index_cur_wgt
            if abs(c.index_diff) > index_rebalance_max_diff:
                index_rebalance_max_diff = abs(c.index_diff)
                _index_rebalance_max_diff_name = index_id

            # calc funds in each index
            c.fund_cur_wgts = cur_fund_position.calc_mv_n_w(cur_fund_nav, index_id)[1]
            c.cur_fund_ids = set(c.fund_cur_wgts.keys())
            # Funds of this index ordered by score, best first; rank 1 is best.
            c.fund_scores = sorted(cur_fund_score.get(index_id, {}).items(), key=lambda item: item[1], reverse=True)
            c.fund_ranks = {info[0]: rank + 1 for rank, info in enumerate(c.fund_scores)}
            # Number of funds this index should hold: capped by the helper's
            # limit and by the number of scored funds; 0 when the target
            # weight is (almost) zero.
            c.proper_fund_num = min(self.fa_helper.get_max_fund_num(index_id), len(c.fund_scores)) if index_tar_wgt > self.SMALL_POS_FLOAT else 0
            # Dimension 1: how well the currently held funds are selected
            # (rates the overall quality of the held fund pool).
            c.fund_judge_ranking_score = 0
            # Rank-based weight: 1 for rank 1, decaying with the cube root of the rank.
            wgt_func = lambda _rank: 1.0 / math.pow(_rank, 1/3)
            _rank_list = []
            for f in c.cur_fund_ids:
                # A held fund without a score gets the BIG_INT sentinel rank.
                rank_i = c.fund_ranks.get(f, self.BIG_INT)
                if rank_i == self.BIG_INT:
                    fund_end = 1
                _rank_list.append((f, rank_i))
                c.fund_judge_ranking_score += wgt_func(rank_i)

            # Best achievable score: holding exactly the top proper_fund_num funds.
            c.fund_judge_ranking_best = 0
            for i in range(0, c.proper_fund_num):
                c.fund_judge_ranking_best += wgt_func(i + 1)
            # SMALL_POS_FLOAT on both sides avoids 0/0 when nothing should be held.
            c.fund_judge_ranking = (c.fund_judge_ranking_score + self.SMALL_POS_FLOAT) / (c.fund_judge_ranking_best + self.SMALL_POS_FLOAT)
            
            # Dimension 2: how evenly the held funds are weighted
            # (rates how reasonable the current fund proportions are).
            if len(c.cur_fund_ids) != c.proper_fund_num:
                c.fund_judge_diverse = 0
            else:
                # Product of (weight * proper_fund_num) over the held funds.
                # NOTE(review): presumably the weights are normalised within
                # the index so that equal weights give 1 - confirm against
                # FundPosition.calc_mv_n_w.
                c.fund_judge_diverse = 1
                for fund_id in c.cur_fund_ids:
                    old_wgt = c.fund_cur_wgts.get(fund_id, 0)
                    c.fund_judge_diverse *= old_wgt * c.proper_fund_num

            # to start or others, cur_fund_ids is few
            # Pad the candidate set with the best-scored funds not yet held.
            # This happens after both scores above were computed, so it does
            # not affect them.
            if len(c.cur_fund_ids) < c.proper_fund_num:
                for fund_id, _score in c.fund_scores:
                    if fund_id not in c.cur_fund_ids:
                        c.cur_fund_ids.add(fund_id)
                        if len(c.cur_fund_ids) == c.proper_fund_num:
                            break
            assert len(c.cur_fund_ids) >= c.proper_fund_num, 'cur fund id candidate should be no less than proper_fund_num'
            
            # Only indices with a large enough target weight take part in the
            # fund-selection and fund-rebalance judgement.
            if index_tar_wgt > self.fund_param.DiffJudgeAssetWgtRequirement:
                if fund_selection_min_score > c.fund_judge_ranking:
                    fund_selection_min_score = c.fund_judge_ranking
                    _fund_selection_min_name = index_id
                    fund_rank_list = _rank_list

                if fund_rebalance_min_score > c.fund_judge_diverse:
                    fund_rebalance_min_score = c.fund_judge_diverse
                    _fund_rebalance_min_name = index_id
            index_fund_cache.append(c)

        # self.fund_param.DiffJudgeLambda = 0
        # Each condition adds its TradeTrigger value at most once.
        # NOTE(review): presumably these are distinct bit flags - confirm in
        # data.constant.TradeTrigger.
        trigger = 0
        trigger += TradeTrigger.IndexRebalance if index_rebalance_max_diff > self.fund_param.JudgeIndexDiff else 0
        trigger += TradeTrigger.FundSelection if fund_selection_min_score < self.fund_param.JudgeFundSelection else 0
        trigger += TradeTrigger.FundRebalance if fund_rebalance_min_score < self.fund_param.JudgeFundRebalance else 0
        trigger += TradeTrigger.FundEnd if fund_end == 1 else 0

        trigger_reason = TradeTrigger.trigger_log(trigger,
                                                  index_rebalance_max_diff,
                                                  _index_rebalance_max_diff_name,
                                                  self.fund_param.JudgeIndexDiff,
                                                  fund_selection_min_score,
                                                  _fund_selection_min_name,
                                                  self.fund_param.JudgeFundSelection,
                                                  fund_rank_list,
                                                  fund_rebalance_min_score,
                                                  _fund_rebalance_min_name,
                                                  self.fund_param.JudgeFundRebalance)
        
        
        # A trigger is dismissed when the previous trade is too recent.
        if trigger > 0:
            if self.trade_interval_allow(dt):
                #print(f'judge allow {dt} (trigger){TradeTrigger.parse(trigger)} (Fs){fund_selection_min_score}[{_fund_selection_min_name}]#{self.fund_param.JudgeFundSelection} (Fr){fund_rebalance_min_score}/[{_fund_rebalance_min_name}]#{self.fund_param.JudgeFundRebalance} (Ir){index_rebalance_max_diff}/[{_index_rebalance_max_diff_name}]#{self.fund_param.JudgeIndexDiff}')
                pass
            else:
                trigger = 0
                #print(f'judge dismiss {dt} (trigger){TradeTrigger.parse(trigger)} (Fs){fund_selection_min_score}[{_fund_selection_min_name}]#{self.fund_param.JudgeFundSelection} (Fr){fund_rebalance_min_score}/[{_fund_rebalance_min_name}]#{self.fund_param.JudgeFundRebalance} (Ir){index_rebalance_max_diff}/[{_index_rebalance_max_diff_name}]#{self.fund_param.JudgeIndexDiff}')
        return trigger, trigger_reason

    def calc_trade(self, dt,
            tar_fund_weight: FundWeight,
            tar_asset_weight: AssetWeight,
            cur_fund_position: FundPosition,
            cur_asset_position: AssetPosition,
            cur_fund_nav: dict,
            cur_fund_score: dict,
            cur_asset_price: AssetPrice) -> tuple:
        """Decide whether to rebalance on ``dt`` and, if so, build the fund trades.

        ``judge_trade`` is asked for a trigger first. Without a trigger
        nothing is traded. Otherwise ``get_trades`` builds the trades, each
        trade is stamped with the trigger value, and the trades are applied
        (sells before buys) to a copy of ``cur_fund_position``.

        Returns:
            ``(cur_fund_position, [], {})`` when no trade is triggered (the
            position object is returned as-is, not copied). Otherwise
            ``(new_fund_position, fund_trades, trigger_reason)``.
        """
        index_fund_cache = []

        trigger, trigger_reason = self.judge_trade(dt, index_fund_cache,
                                    tar_fund_weight, tar_asset_weight,
                                    cur_fund_position, cur_asset_position, cur_fund_nav,
                                    cur_fund_score, cur_asset_price)

        if trigger == 0:
            return cur_fund_position, [], {}

        fund_trades = self.get_trades(dt, index_fund_cache,
                                    tar_fund_weight, tar_asset_weight,
                                    cur_fund_position, cur_asset_position, cur_fund_nav,
                                    cur_fund_score, cur_asset_price)

        # False sorts before True, so sells are applied ahead of buys.
        fund_trades.sort(key=lambda x: x.is_buy)

        new_fund_position = cur_fund_position.copy()
        for _trade in fund_trades:
            _trade.trigger = trigger
            new_fund_position.update(_trade)

        return new_fund_position, fund_trades, trigger_reason


    def get_trades(self, dt,
            index_fund_cache, 
            tar_fund_weight: FundWeight,
            tar_asset_weight: AssetWeight,
            cur_fund_position: FundPosition,
            cur_asset_position: AssetPosition,
            cur_fund_nav: dict,
            cur_fund_score: dict,
            cur_asset_price: AssetPrice):
        """Build the fund trades that move the holdings towards ``tar_fund_weight``.

        For each entry of ``index_fund_cache`` the candidates are the index's
        target funds followed by its currently held funds. A trade is created
        when the gap between target and current amount exceeds
        ``MinCountAmtDiff`` times the total fund market value, or when a held
        fund has a target amount of zero (full liquidation). Buys carry a cash
        ``amount``; sells carry a ``volume`` (gap / nav) and an
        ``is_to_cleanup`` flag. Commission is applied later, in
        ``finalize_trade``.

        ``tar_asset_weight``, ``cur_asset_position``, ``cur_fund_score`` and
        ``cur_asset_price`` are accepted for signature compatibility but are
        not read here.

        Returns:
            The list of ``FundTrade`` objects, in candidate order.
        """
        fund_tot_mv, _ = cur_fund_position.calc_mv_n_w(fund_navs=cur_fund_nav)
        min_count_amt = fund_tot_mv * self.fund_param.MinCountAmtDiff

        fund_trades = []

        for c in index_fund_cache:
            index_id = c.index_id
            # De-duplicate while keeping insertion order. Iterating a plain
            # set would make the trade order depend on string hash
            # randomisation and so differ between runs.
            candidates = dict.fromkeys(tar_fund_weight.get_funds(index_id))
            candidates.update(dict.fromkeys(cur_fund_position.get_funds(index_id)))
            for fund_id in candidates:
                target_fund_amt = tar_fund_weight.get_wgt(fund_id) * fund_tot_mv
                cur_fund_volume = cur_fund_position.get_volume(fund_id) or 0
                p = cur_fund_nav[fund_id]
                cur_fund_amt = cur_fund_volume * p
                amt_diff = abs(target_fund_amt - cur_fund_amt)

                # A full liquidation is always allowed, even when the gap is
                # below the counting threshold.
                is_cleanup = target_fund_amt == 0 and cur_fund_amt > 0
                if not (amt_diff > min_count_amt or is_cleanup):
                    continue

                is_buy = target_fund_amt > cur_fund_amt
                if is_buy:
                    _trade = FundTrade(
                        fund_id=fund_id,
                        index_id=index_id,
                        mark_price=p,
                        amount=amt_diff,
                        is_buy=is_buy,
                        submit_date=dt
                    )
                else:
                    _trade = FundTrade(
                        fund_id=fund_id,
                        index_id=index_id,
                        mark_price=p,
                        volume=amt_diff / p,
                        is_buy=is_buy,
                        submit_date=dt,
                        is_to_cleanup=is_cleanup
                    )
                fund_trades.append(_trade)
        return fund_trades
