import logging
from datetime import datetime, timedelta
from typing import List
from operator import itemgetter
import pytz

from influxdb_client.client.write_api import SYNCHRONOUS
import boto3
from botocore.exceptions import ClientError
import ujson

from algo.config import config
from algo.data.connection import DataConnection
from algo.data.common import CompletePricePoint

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


class InfluxDataReporter(DataConnection):  # pylint: disable=too-few-public-methods
    """Report price points to InfluxDB, one write per point.

    ``self._client`` is expected to be set up by ``DataConnection.__init__``;
    this class adds a write API on top of it and maps a
    ``CompletePricePoint`` to an Influx record.
    """

    def __init__(self, url, token, org):
        super().__init__(url, token, org)
        # SYNCHRONOUS write options: each write() is performed as part of the
        # call instead of being queued for background batching.
        self._write_api = self._client.write_api(write_options=SYNCHRONOUS)

    def report(self, price: CompletePricePoint):
        """Write ``price`` to the ``tsdata`` bucket as a ``price`` measurement.

        The instrument value is stored as a tag; ask and bid are stored as
        fields, timestamped with ``price.time``.
        """
        instrument_tags = {"instrument": price.instrument.value}
        record = {
            "measurement": 'price',
            "tags": instrument_tags,
            "time": price.time,
            "fields": {"ask": price.ask, "bid": price.bid},
        }
        self._write_api.write(bucket='tsdata', record=[record])
        log.debug('influx-data-sent')


# Minimum time between two PutLogEvents calls. Chosen to stay well below the
# CloudWatch Logs limit of 5 requests per second (as noted when this was
# written), leaving plenty of headroom.
DATA_SENDING_INTERVAL = timedelta(seconds=5)
# botocore error code returned by create_log_group / create_log_stream when the
# resource already exists; callers treat it as success rather than failure.
ALREADY_EXISTS_ERROR_CODE = 'ResourceAlreadyExistsException'


class CloudWatchDataReporter:
    """Buffer price points and ship them to a CloudWatch Logs stream in batches.

    Each ``report()`` call appends to an in-memory buffer. The buffer is
    flushed with a single ``PutLogEvents`` request once more than
    ``DATA_SENDING_INTERVAL`` has elapsed since the previous flush (or since
    construction). The log group and stream are created on construction;
    "already exists" errors are tolerated.
    """

    def __init__(self, log_group_name, log_stream_name) -> None:
        self._log_group_name = log_group_name
        self._log_stream_name = log_stream_name

        # Price points accumulated since the last flush.
        self._price_stack: List[CompletePricePoint] = []
        self._last_report_time = datetime.now()
        # Token passed as ``sequenceToken`` to the next PutLogEvents call;
        # None until looked up (and for a stream that never received events).
        self._last_sequence_token = None
        self._api = self._init_connection(log_group_name, log_stream_name)

    @staticmethod
    def _create_if_missing(create, description, **kwargs):
        """Call a CloudWatch Logs ``create_*`` method, tolerating "already exists".

        :param create: bound boto3 client method, e.g. ``api.create_log_group``.
        :param description: log message emitted when the call fails.
        :param kwargs: keyword arguments forwarded to ``create``.
        :raises ClientError: for any error code other than
            ``ALREADY_EXISTS_ERROR_CODE``.
        """
        try:
            create(**kwargs)
        except ClientError as exc:
            extra = {
                'e': exc, 't': type(exc),
                'error-details': exc.response
            }
            code = exc.response.get('Error', {}).get('Code')
            if code == ALREADY_EXISTS_ERROR_CODE:
                # Expected whenever the reporter restarts: not an error.
                log.info(description, extra=extra)
                return
            log.error(description, extra=extra)
            raise

    def _init_connection(self, log_group_name, log_stream_name):
        """Create a ``logs`` client and make sure the group and stream exist.

        :returns: the boto3 CloudWatch Logs client.
        :raises ClientError: on any creation failure other than "already exists".
        """
        api = boto3.client('logs')
        self._create_if_missing(
            api.create_log_group, 'error creating log group',
            logGroupName=log_group_name,
        )
        self._create_if_missing(
            api.create_log_stream, 'error creating log stream',
            logGroupName=log_group_name,
            logStreamName=log_stream_name,
        )
        return api

    def report(self, price: CompletePricePoint):
        """Buffer ``price`` and flush the buffer if the send interval has elapsed."""
        self._price_stack.append(price)
        self._send_data()

    def _init_sequence_token(self):
        """Look up the stream's upload sequence token if none is known yet."""
        if self._last_sequence_token is not None:
            return
        streams = self._api.describe_log_streams(
            logGroupName=self._log_group_name, logStreamNamePrefix=self._log_stream_name
        )
        # The stream name is matched as a *prefix*, so other streams may be
        # returned as well: pick the exact match. uploadSequenceToken is absent
        # on first run for a new stream, leaving the token as None.
        for stream in streams.get('logStreams', []):
            if stream.get('logStreamName') == self._log_stream_name:
                self._last_sequence_token = stream.get('uploadSequenceToken')
                return

    def _send_data(self):
        """Flush buffered prices with one PutLogEvents call once the interval has elapsed."""
        now = datetime.now()
        elapsed = now - self._last_report_time
        if elapsed <= DATA_SENDING_INTERVAL or not self._price_stack:
            return

        # Only fetch the sequence token when actually about to send; doing it
        # on every report() issued a DescribeLogStreams call per price point
        # whenever the token was unknown.
        self._init_sequence_token()

        # TBD: need to catch potential errors and reconnect
        log.info(
            'sending-cloudtrail-message',
            extra={
                'dt': elapsed,
                'data-count': len(self._price_stack)
            }
        )
        self._last_report_time = now

        # PutLogEvents expects events in chronological order.
        records_batch = sorted(
            (
                {
                    'timestamp': int(p.time.astimezone(pytz.utc).timestamp() * 1000),
                    'message': ujson.dumps(p.dict),
                }
                for p in self._price_stack
            ),
            key=itemgetter('timestamp'),
        )
        # The buffer is dropped before sending, so a failed request loses
        # these points (see TBD above).
        self._price_stack = []

        log_events_args = {
            'logGroupName': self._log_group_name,
            'logStreamName': self._log_stream_name,
            'logEvents': records_batch,
        }
        if self._last_sequence_token is not None:
            log_events_args['sequenceToken'] = self._last_sequence_token

        resp = self._api.put_log_events(**log_events_args)
        log.debug('put_log_event_response', extra={'response': resp})
        self._last_sequence_token = resp.get('nextSequenceToken')


def get_influx_data_reporter() -> InfluxDataReporter:
    """Build an ``InfluxDataReporter`` from the InfluxDB settings in ``config``."""
    url, token, org = config.influx_url, config.influx_token, config.influx_org
    return InfluxDataReporter(url, token, org)


def get_cloud_watch_reporter() -> CloudWatchDataReporter:
    """Build a ``CloudWatchDataReporter`` for the log group/stream named in ``config``."""
    group_name = config.aws_cloud_watch_log_group_name
    stream_name = config.aws_cloud_watch_log_stream_name
    return CloudWatchDataReporter(group_name, stream_name)
