import logging
import time
from typing import Any, Callable, Optional, Union

import v20
from ujson import JSONDecodeError as UJSONDecodeError
from v20.errors import V20Timeout, V20ConnectionError
from v20.pricing import ClientPrice, PricingHeartbeat

from algo.config import config
from algo.data.common import PricePoint, PriceStream, CompletePricePoint, CompletePriceStream
from algo.constant import Symbol
from algo.util import d
from algo.metric import incr_message_counter


log = logging.getLogger(__name__)


def price_to_dict(price: ClientPrice):
    """
    Build the logging ``extra`` payload for a ClientPrice message.

    :param price: the price message; may be None or otherwise falsy.
    :return: ``{'type': <class name>, 'dict': <payload>}`` where the payload is
        ``price.dict()`` for a truthy ``price`` and ``price`` itself otherwise.
    """
    type_name = type(price).__name__
    serialised = price.dict() if price else price
    return dict(type=type_name, dict=serialised)


def heartbeat_to_dict(heartbeat: PricingHeartbeat):
    """
    Build the logging ``extra`` payload for a PricingHeartbeat message.

    :param heartbeat: the heartbeat message; may be None or otherwise falsy.
    :return: ``{'type': <class name>, 'dict': <payload>}`` where the payload is
        ``heartbeat.dict()`` when ``heartbeat`` is truthy, else ``heartbeat``.
    """
    if heartbeat:
        body = heartbeat.dict()
    else:
        body = heartbeat
    return dict(type=type(heartbeat).__name__, dict=body)


def get_prices_generator(stream) -> PriceStream:
    """
    Convert an Oanda pricing stream into a generator of PricePoint objects.

    Iterates ``stream.parts()``, which yields ``(msg_type, msg)`` pairs, and
    calls ``incr_message_counter()`` once per message:

    * ``'pricing.PricingHeartbeat'`` messages are only debug-logged;
    * ``'pricing.ClientPrice'`` messages are debug-logged and, when
      ``PricePoint.is_valid_price_data`` accepts them, converted with
      ``PricePoint.from_oanda_price`` and yielded;
    * any other message type is logged at INFO as
      ``'unexpected-stream-message'`` and skipped.

    :param stream: object returned by ``api.pricing.stream(...)``.
    :return: generator of PricePoint.
    """
    for msg_type, msg in stream.parts():
        # This loop runs for every streamed message and msg.dict() is not
        # free, so the debug payloads are only built when DEBUG is enabled.
        # Checked per message so a runtime level change is honoured.
        debug_enabled = log.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            log.debug(
                'got message', extra={
                    'message_type': msg_type,
                    'dict': msg.dict()
                }
            )
        incr_message_counter()
        if msg_type == 'pricing.PricingHeartbeat':
            if debug_enabled:
                log.debug('pricing.PricingHeartbeat', extra=heartbeat_to_dict(msg))
        elif msg_type == 'pricing.ClientPrice':
            if debug_enabled:
                log.debug('pricing.ClientPrice', extra=price_to_dict(msg))
            if PricePoint.is_valid_price_data(msg):
                yield PricePoint.from_oanda_price(msg)
        else:
            log.info('unexpected-stream-message', extra=d(msg))


def get_complete_prices_generator(stream) -> CompletePriceStream:
    """
    Convert an Oanda pricing stream into a generator of CompletePricePoint objects.

    Iterates ``stream.parts()``, which yields ``(msg_type, msg)`` pairs, and
    calls ``incr_message_counter()`` once per message:

    * ``'pricing.PricingHeartbeat'`` messages are only debug-logged;
    * ``'pricing.ClientPrice'`` messages are debug-logged and, when
      ``CompletePricePoint.is_valid_price_data`` accepts them, converted with
      ``CompletePricePoint.from_oanda_price`` and yielded;
    * any other message type is logged at INFO as
      ``'unexpected-stream-message'`` and skipped.

    :param stream: object returned by ``api.pricing.stream(...)``.
    :return: generator of CompletePricePoint.
    """
    for msg_type, msg in stream.parts():
        # This loop runs for every streamed message and msg.dict() is not
        # free, so the debug payloads are only built when DEBUG is enabled.
        # Checked per message so a runtime level change is honoured.
        debug_enabled = log.isEnabledFor(logging.DEBUG)
        if debug_enabled:
            log.debug(
                'got message', extra={
                    'message_type': msg_type,
                    'dict': msg.dict()
                }
            )
        incr_message_counter()
        if msg_type == 'pricing.PricingHeartbeat':
            if debug_enabled:
                log.debug('pricing.PricingHeartbeat', extra=heartbeat_to_dict(msg))
        elif msg_type == 'pricing.ClientPrice':
            if debug_enabled:
                log.debug('pricing.ClientPrice', extra=price_to_dict(msg))
            if CompletePricePoint.is_valid_price_data(msg):
                yield CompletePricePoint.from_oanda_price(msg)
        else:
            log.info('unexpected-stream-message', extra=d(msg))


# Stream type produced by either of the reader functions above.
AnyPriceStream = Union[PriceStream, CompletePriceStream]

# A reader turns a v20 pricing stream into a stream of price points.
# The argument is typed as ``Any`` on purpose: it is the object returned by
# ``api.pricing.stream(...)``, whose ``parts()`` method yields
# ``(msg_type, msg)`` pairs of PricingHeartbeat / ClientPrice messages, and we
# do not want to model Oanda's response class here.
# NOTE: this must be ``typing.Any`` — the builtin ``any`` is a function, not a type.
PriceReader = Callable[
    [Any],
    AnyPriceStream
]


def get_live_price_stream(
        symbol: Optional[Symbol] = None,
        price_reader: PriceReader = get_prices_generator,
        reconnect_delay: float = 1.0,
) -> AnyPriceStream:
    """
    Stream live prices from Oanda, reconnecting whenever the stream breaks.

    Creates a ``v20.Context`` from ``config`` and repeatedly opens a pricing
    stream (``snapshot=True``) for ``symbol.value``, or for
    ``Symbol.all_values()`` when ``symbol`` is None. Everything yielded by
    ``price_reader(stream)`` is re-yielded. The generator never finishes on
    its own. A new stream is requested when:

    * the reader is exhausted;
    * a ``V20ConnectionError`` / ``V20Timeout`` is raised;
    * a ujson ``JSONDecodeError`` is raised.

    In the two error cases it first waits ``reconnect_delay`` seconds. Any
    other exception is logged and re-raised.

    TODO: first in the queue to be covered with tests.

    :param symbol: single instrument to subscribe to; None subscribes to
        ``Symbol.all_values()``.
    :param price_reader: converts a v20 stream into price points; defaults to
        ``get_prices_generator``.
    :param reconnect_delay: seconds to sleep before reconnecting after a v20
        connection/timeout error or a JSON decode error.
    :return: generator of the items produced by ``price_reader``.
    :raises Exception: any error other than those listed above, after logging.
    """
    instruments = symbol.value if symbol is not None else Symbol.all_values()

    api = v20.Context(
        hostname=config.oanda_stream_host_name,
        token=config.oanda_token,
    )
    while True:
        # The only way out of this loop is an unexpected error (re-raised below).
        try:
            log.info('getting-new-stream')
            stream = api.pricing.stream(
                accountID=config.oanda_account_id,
                instruments=instruments,
                snapshot=True
            )

            yield from price_reader(stream)
        except (V20ConnectionError, V20Timeout) as err:
            # The stream can time out or drop; back off briefly, then reconnect.
            log.error(
                'catched-v20error',
                extra=d(err, err_type=type(err)),
                exc_info=err
            )
            time.sleep(reconnect_delay)
        except UJSONDecodeError as json_decode_error:
            # A malformed message ends the current stream. Back off before
            # reconnecting so a persistently bad feed cannot turn this loop
            # into a tight reconnect storm against the API.
            log.error(
                'json-decode-error',
                extra=d(json_decode_error),
                exc_info=json_decode_error
            )
            time.sleep(reconnect_delay)
        except Exception as unexpected_error:
            # Unknown failure: log exactly why, then let it propagate.
            log.error(
                'unexpected-error',
                extra=d(unexpected_error, err_type=type(unexpected_error)),
                exc_info=unexpected_error
            )
            raise
