import abc
import datetime
import pickle
import uuid
from collections import defaultdict
from multiprocessing import shared_memory
from typing import Self

from PyQuantKit import TickData, TradeData, OrderBook, MarketData, TransactionSide

from . import LOGGER
from ..profile import Profile, DefaultProfile, CN_Profile

__all__ = ['MDS', 'MarketDataService', 'MarketDataMonitor', 'MonitorManager']


class MarketDataMonitor(object, metaclass=abc.ABCMeta):
    """
    Template for a market data monitor.

    A data monitor is a module that processes market data and generates a custom index.

    When the MDS receives a market data update, the ``__call__`` method of this monitor is triggered
    (the MonitorManager skips monitors whose ``enabled`` attribute is False).

    Note: the market data of ALL subscribed tickers is fed into every monitor,
    so an implementation should assume it needs storage for multiple tickers.

    To access the monitor, use ``monitor = MDS[monitor_id]``.
    To access the index generated by the monitor, use ``monitor.value``.
    To indicate whether the monitor is ready to use, override the ``is_ready`` property.
    The base property is read-only and always returns True, so ``monitor.is_ready = True``
    raises AttributeError unless a subclass overrides it.

    An implemented monitor should be instantiated and attached to the engine with ``MDS.add_monitor(monitor)``.
    """

    def __init__(self, name: str, monitor_id: str = None):
        """
        :param name: human-readable name of the monitor (``MDS.pop_monitor`` can look monitors up by it).
        :param monitor_id: unique id of the monitor; a random ``uuid4`` hex string is generated when None.
        """
        self.name: str = name
        self.monitor_id: str = uuid.uuid4().hex if monitor_id is None else monitor_id
        # checked by MonitorManager before dispatching market data to this monitor
        self.enabled: bool = True

    @abc.abstractmethod
    def __call__(self, market_data: MarketData, **kwargs):
        """Process one market data update (of any subscribed ticker)."""
        ...

    def __reduce__(self):
        # pickle support: the monitor is rebuilt by ``from_json`` from its default ``to_json()`` output
        return self.__class__.from_json, (self.to_json(),)

    @abc.abstractmethod
    def to_json(self, fmt='str') -> dict | str:
        """
        Serialize the monitor.

        :param fmt: output format. ``to_shm`` calls this with ``fmt='dict'`` and pickles the result.
            The default ``'str'`` presumably yields a JSON string -- confirm against implementations.
        :return: the serialized monitor, as a dict or a str depending on ``fmt``.
        """
        ...

    @classmethod
    @abc.abstractmethod
    def from_json(cls, json_message: str | bytes | bytearray | dict) -> Self:
        """Rebuild a monitor from the output of ``to_json`` (also used by ``__reduce__`` for unpickling)."""
        ...

    def to_shm(self, name: str = None) -> str:
        """
        Put the data of the monitor into python shared memory.

        This method is designed to facilitate multiprocessing. Some monitors are not
        suited to concurrent handling; such a monitor should raise NotImplementedError.

        The output of ``self.to_json(fmt='dict')`` is pickled and written into a shared
        memory block. An existing block with the same name is reused when its size matches
        exactly; otherwise it is unlinked and recreated with the new size.

        Note that this method HAS NO LOCK; use with caution.

        :param name: name of the shared memory block. Defaults to ``f'{self.monitor_id}.json'``.
        :return: the name of the shared memory block that was written.
        """
        if name is None:
            name = f'{self.monitor_id}.json'

        data = pickle.dumps(self.to_json(fmt='dict'))
        size = len(data)

        try:
            shm = shared_memory.SharedMemory(name=name)

            # an existing block cannot be resized in place, so a block of another size is replaced.
            # NOTE(review): the Python docs state the allocated size may be larger than requested on
            # some platforms (page rounding); there this check recreates the block on every call -- confirm.
            if shm.size != size:
                shm.close()
                shm.unlink()
                shm = shared_memory.SharedMemory(create=True, size=size, name=name)
        except FileNotFoundError:
            shm = shared_memory.SharedMemory(create=True, size=size, name=name)

        try:
            shm.buf[:size] = data
        finally:
            # release this process's mapping even if the write fails; the block itself is not unlinked
            shm.close()

        return name

    @classmethod
    def from_shm(cls, monitor_id: str):
        """
        Retrieve the data and update the monitor from shared memory.
        This method is designed to facilitate multiprocessing.

        The base implementation is a placeholder that does nothing and returns None.
        """
        return

    @abc.abstractmethod
    def clear(self) -> None:
        """Clear the monitor's internal state (implementation-defined)."""
        ...

    @property
    @abc.abstractmethod
    def value(self) -> dict[str, float] | float:
        """The index generated by the monitor: either a mapping of names to floats, or a single float."""
        ...

    @property
    def is_ready(self) -> bool:
        """Whether the monitor is ready to use; the base implementation always returns True."""
        return True


class MonitorManager(object):
    """
    Manage market data monitors and dispatch market data to them.

    State codes for a manager (this base implementation does not track a state;
    ``start`` and ``stop`` are no-ops that subclasses may override):
    0: idle
    1: working
    -1: terminating
    """

    def __init__(self):
        # registered monitors, keyed by monitor_id
        self.monitor: dict[str, MarketDataMonitor] = {}

    def __call__(self, market_data: MarketData):
        """Feed ``market_data`` to every registered monitor that is enabled."""
        # iterate over a snapshot of the ids: a monitor that adds / pops monitors inside its
        # callback would otherwise raise "dictionary changed size during iteration"
        for monitor_id in list(self.monitor):
            self._work(monitor_id=monitor_id, market_data=market_data)

    def add_monitor(self, monitor: MarketDataMonitor):
        """Register ``monitor`` under its ``monitor_id``, replacing any monitor with the same id."""
        self.monitor[monitor.monitor_id] = monitor

    def pop_monitor(self, monitor_id: str) -> MarketDataMonitor:
        """
        Unregister and return the monitor with ``monitor_id``.

        :raises KeyError: if no monitor is registered under ``monitor_id``.
        """
        return self.monitor.pop(monitor_id)

    def _work(self, monitor_id: str, market_data: MarketData):
        # the monitor may have been popped since the id snapshot was taken; skip it in that case
        monitor = self.monitor.get(monitor_id)
        if monitor is not None and monitor.enabled:
            monitor.__call__(market_data)

    def start(self):
        """No-op in the base manager."""
        pass

    def stop(self):
        """No-op in the base manager."""
        pass

    def clear(self):
        """Unregister all monitors."""
        self.monitor.clear()

    @property
    def values(self) -> dict[str, float]:
        """
        Merge the values of all registered monitors (enabled or not) into a single dict.

        A monitor whose ``value`` is a scalar (int / float) is stored under the monitor's ``name``;
        any other ``value`` (a mapping) is merged with ``dict.update``.
        On key collisions, monitors registered later overwrite earlier ones.
        """
        values = {}

        for monitor in self.monitor.values():
            value = monitor.value

            # ``value`` may be a plain float per MarketDataMonitor.value; dict.update would raise TypeError on it
            if isinstance(value, (int, float)):
                values[monitor.name] = value
            else:
                values.update(value)

        return values


class MarketDataService(object):
    """
    Market data service (MDS).

    Caches the latest market price / time per ticker, and the latest TradeData, TickData and
    OrderBook per ticker. Every market data update is dispatched to the registered monitors
    through the MonitorManager.
    """

    def __init__(self, profile: Profile = None, **kwargs):
        """
        :param profile: trading profile used for session / trade-time queries; a DefaultProfile is created when None.
        :param kwargs: ``cache_history`` (default False): when truthy, every market price is also
            recorded in ``market_history``. Other keyword arguments are ignored.
        """
        self.profile = DefaultProfile() if profile is None else profile
        self.cache_history = kwargs.pop('cache_history', False)

        self._market_price = {}
        # ticker -> {market_time -> market_price}, only filled when cache_history is set
        self._market_history = defaultdict(dict)
        self._market_time: datetime.datetime | None = None
        self._timestamp: float | None = None

        self._order_book: dict[str, OrderBook] = {}
        self._tick_data: dict[str, TickData] = {}
        self._trade_data: dict[str, TradeData] = {}
        self._monitor: dict[str, MarketDataMonitor] = {}
        self._monitor_manager = MonitorManager()

    def __call__(self, **kwargs):
        """Callback entry point: forwards ``market_data=...`` to ``on_market_data``; other kwargs are ignored."""
        if 'market_data' in kwargs:
            self.on_market_data(market_data=kwargs['market_data'])

    def __getitem__(self, monitor_id: str) -> MarketDataMonitor:
        """Return the registered monitor with ``monitor_id`` (raises KeyError if absent)."""
        return self.monitor[monitor_id]

    def add_monitor(self, monitor: MarketDataMonitor):
        """Register ``monitor`` with the service and with its monitor manager."""
        self.monitor[monitor.monitor_id] = monitor
        self.monitor_manager.add_monitor(monitor)

    def pop_monitor(self, monitor: MarketDataMonitor = None, monitor_id: str = None, monitor_name: str = None):
        """
        Unregister a monitor from the service and from its monitor manager.

        The monitor is identified by, in order of precedence, ``monitor_id``, ``monitor_name`` or ``monitor``.
        If several registered monitors share ``monitor_name``, the last one in registration order is popped.

        :return: the popped monitor, or None when nothing was specified or ``monitor_name`` is not registered.
        :raises KeyError: if the given ``monitor_id`` (or ``monitor.monitor_id``) is not registered.
        """
        if monitor_id is not None:
            pass
        elif monitor_name is not None:
            matched = [registered.monitor_id for registered in self.monitor.values() if registered.name == monitor_name]

            # the lookup result is what must be checked here (not the ``monitor`` argument)
            if not matched:
                LOGGER.error(f'monitor_name {monitor_name} not registered.')
                return None

            monitor_id = matched[-1]
        elif monitor is not None:
            monitor_id = monitor.monitor_id
        else:
            LOGGER.error('must assign a monitor, or monitor_id, or monitor_name to pop.')
            return None

        popped = self.monitor.pop(monitor_id)
        self.monitor_manager.pop_monitor(monitor_id)
        return popped

    def init_cn_override(self):
        """Replace the current profile with a CN_Profile."""
        self.profile = CN_Profile()

    def _on_trade_data(self, trade_data: TradeData):
        ticker = trade_data.ticker

        # the first TradeData of a ticker confirms its subscription
        if ticker not in self._trade_data:
            LOGGER.info(f'MDS confirmed {ticker} TradeData subscribed!')

        self._trade_data[ticker] = trade_data

    def _on_tick_data(self, tick_data: TickData):
        ticker = tick_data.ticker

        # the first TickData of a ticker confirms its subscription
        if ticker not in self._tick_data:
            LOGGER.info(f'MDS confirmed {ticker} TickData subscribed!')

        # NOTE: the order book cache is only fed by OrderBook updates (_on_order_book), not derived from tick data
        self._tick_data[ticker] = tick_data

    def _on_order_book(self, order_book):
        ticker = order_book.ticker

        # the first OrderBook of a ticker confirms its subscription
        if ticker not in self._order_book:
            LOGGER.info(f'MDS confirmed {ticker} OrderBook subscribed!')

        self._order_book[ticker] = order_book

    def on_market_data(self, market_data: MarketData):
        """
        Process one market data update.

        Records the latest price of its ticker and the latest market time / timestamp (shared by all tickers),
        optionally caches the price history, updates the per-type cache (TradeData / TickData / OrderBook),
        then dispatches the update to the monitor manager.
        """
        ticker = market_data.ticker
        market_time = market_data.market_time
        timestamp = market_data.timestamp
        market_price = market_data.market_price

        self._market_price[ticker] = market_price
        self._market_time = market_time
        self._timestamp = timestamp

        if self.cache_history:
            self._market_history[ticker][market_time] = market_price

        if isinstance(market_data, TradeData):
            self._on_trade_data(trade_data=market_data)
        elif isinstance(market_data, TickData):
            self._on_tick_data(tick_data=market_data)
        elif isinstance(market_data, OrderBook):
            self._on_order_book(order_book=market_data)

        self.monitor_manager.__call__(market_data=market_data)

    def get_order_book(self, ticker: str) -> OrderBook | None:
        """Return the latest OrderBook of ``ticker``, or None if none was received."""
        return self._order_book.get(ticker, None)

    def get_queued_volume(self, ticker: str, side: TransactionSide | str | int, prior: float, posterior: float = None) -> float:
        """
        get queued volume prior / posterior to given price, NOT COUNTING GIVEN PRICE!
        :param ticker: the given ticker
        :param side: the given trade side; a positive sign reads the bid book, a negative sign the ask book
        :param prior: the given price
        :param posterior: optional the given posterior price
        :return: the summed queued volume, in float; NaN if no order book of ``ticker`` was received.
        :raises ValueError: if ``side`` has a zero sign.
        """
        order_book = self.get_order_book(ticker=ticker)

        if order_book is None:
            return float('nan')

        trade_side = TransactionSide(side)

        # compare the side's sign in both branches (the enum member itself is not a number)
        if trade_side.sign > 0:
            book = order_book.bid
        elif trade_side.sign < 0:
            book = order_book.ask
        else:
            raise ValueError(f'Invalid side {side}')

        return book.loc_volume(p0=prior, p1=posterior)

    def trade_time_between(self, start_time: datetime.datetime | float, end_time: datetime.datetime | float, **kwargs) -> datetime.timedelta:
        """Delegate to ``self.profile.trade_time_between``."""
        return self.profile.trade_time_between(start_time=start_time, end_time=end_time, **kwargs)

    def in_trade_session(self, market_time: datetime.datetime | float) -> bool:
        """Delegate to ``self.profile.in_trade_session``."""
        return self.profile.in_trade_session(market_time=market_time)

    def clear(self):
        """
        Clear the cached price history, order books and all registered monitors.

        The latest market prices, market time and timestamp are kept.
        """
        # NOTE(review): _tick_data and _trade_data are not cleared here, unlike _order_book -- confirm intended.
        self._market_history.clear()
        self._order_book.clear()
        self.monitor.clear()
        self.monitor_manager.clear()

    @property
    def market_price(self) -> dict[str, float]:
        """Latest market price per ticker (the live internal dict, not a copy)."""
        return self._market_price

    @property
    def market_history(self) -> dict[str, dict[datetime.datetime, float]]:
        """Cached ticker -> {market_time -> price} history (filled only when ``cache_history`` is set)."""
        return self._market_history

    @property
    def market_time(self) -> datetime.datetime | None:
        """Latest market time, falling back to a local datetime built from the latest timestamp."""
        if self._market_time is not None:
            return self._market_time

        if self._timestamp is not None:
            return datetime.datetime.fromtimestamp(self._timestamp)

        return None

    @property
    def market_date(self) -> datetime.date | None:
        """Date of ``market_time``, or None when neither a market time nor a timestamp is known."""
        # use the property (with its timestamp fallback), not self._market_time, which may still be None
        market_time = self.market_time

        if market_time is None:
            return None

        return market_time.date()

    @property
    def timestamp(self) -> float | None:
        """Latest timestamp, falling back to the timestamp of the latest market time."""
        if self._timestamp is not None:
            return self._timestamp

        if self._market_time is not None:
            return self._market_time.timestamp()

        return None

    @property
    def session_start(self) -> datetime.time | None:
        return self.profile.session_start

    @property
    def session_end(self) -> datetime.time | None:
        return self.profile.session_end

    @property
    def session_break(self) -> tuple[datetime.time, datetime.time] | None:
        return self.profile.session_break

    @property
    def monitor(self) -> dict[str, MarketDataMonitor]:
        """Registered monitors, keyed by monitor_id."""
        return self._monitor

    @property
    def monitor_manager(self) -> MonitorManager:
        return self._monitor_manager

    @monitor_manager.setter
    def monitor_manager(self, manager: MonitorManager):
        """Replace the monitor manager: the old one is cleared and all registered monitors are added to the new one."""
        self._monitor_manager.clear()

        self._monitor_manager = manager

        for monitor in self.monitor.values():
            self._monitor_manager.add_monitor(monitor=monitor)


# module-level MarketDataService instance, created with the default profile
MDS = MarketDataService(profile=None)
