from datetime import datetime
from typing import Dict, List

from sqlalchemy import or_
from sqlalchemy.orm import aliased

from mage_ai.api.operations.constants import META_KEY_LIMIT
from mage_ai.api.resources.GenericResource import GenericResource
from mage_ai.api.utils import get_query_timestamps
from mage_ai.data_preparation.logging.logger_manager_factory import LoggerManagerFactory
from mage_ai.data_preparation.models.block.constants import LOG_PARTITION_EDIT_PIPELINE
from mage_ai.data_preparation.models.file import File
from mage_ai.data_preparation.models.pipeline import Pipeline
from mage_ai.orchestration.db import safe_db_query
from mage_ai.orchestration.db.models.schedules import (
    BlockRun,
    PipelineRun,
    PipelineSchedule,
)
from mage_ai.server.logger import Logger

# Threshold at which the log-gathering loops in LogResource stop reading further
# run rows: once at least this many log entries have been collected, they break.
MAX_LOG_FILES = 20
# Server logger named after this module; used below to warn about log filenames
# whose timestamp prefix cannot be parsed.
server_logger = Logger().new_server_logger(__name__)


class LogResource(GenericResource):
    """Resource that serves logs for a single block run or for a whole pipeline."""

    @classmethod
    @safe_db_query
    async def collection(self, query, meta, user, **kwargs):
        """Collect logs for the parent model supplied in ``kwargs``.

        Args:
            query: Request query arguments; forwarded to the pipeline log lookup.
            meta: Request metadata; forwarded to the pipeline log lookup, which
                reads ``META_KEY_LIMIT`` from it.
            user: Passed through to ``build_result_set``.
            **kwargs: Must contain ``parent_model``. A ``BlockRun`` yields its
                ``logs`` attribute, a ``Pipeline`` yields the aggregated pipeline
                logs, and any other value yields an empty list.

        Returns:
            Whatever ``build_result_set`` returns for the collected logs.

        Raises:
            KeyError: If ``parent_model`` is missing from ``kwargs``.
        """
        # NOTE: this is a classmethod, so ``self`` is the class itself.
        parent_model = kwargs['parent_model']

        arr = []
        if isinstance(parent_model, BlockRun):
            arr = parent_model.logs
        elif isinstance(parent_model, Pipeline):
            arr = await self.__pipeline_logs(parent_model, query, meta)

        return self.build_result_set(
            arr,
            user,
            **kwargs,
        )

    @staticmethod
    def _split_csv_param(query_arg, key: str) -> List[str]:
        """Return the comma-separated values stored under ``key`` in ``query_arg``.

        Only the first element of the stored value is used. A missing key, an
        empty value, or an empty/``None`` first element all produce ``[]``.
        """
        values = query_arg.get(key, [None])
        first_value = values[0] if values else None
        return first_value.split(',') if first_value else []

    @classmethod
    @safe_db_query
    async def __pipeline_logs(self, pipeline: Pipeline, query_arg, meta) -> List[Dict]:
        """Gather pipeline run, block run and edit-pipeline logs for ``pipeline``.

        Args:
            pipeline: Pipeline whose logs are requested.
            query_arg: Query arguments. Reads the comma-separated filters
                ``pipeline_schedule_id[]``, ``pipeline_run_id[]``,
                ``block_uuid[]`` and ``block_run_id[]``, plus whatever
                ``get_query_timestamps`` extracts as start/end timestamps.
            meta: Request metadata; ``META_KEY_LIMIT`` (when not ``None``) limits
                the number of run rows fetched per query.

        Returns:
            A single-element list holding a dict with the keys
            ``block_run_logs``, ``pipeline_run_logs``,
            ``total_block_run_log_count`` and ``total_pipeline_run_log_count``.
        """
        pipeline_uuid = pipeline.uuid

        start_timestamp, end_timestamp = get_query_timestamps(query_arg)
        pipeline_schedule_ids = self._split_csv_param(query_arg, 'pipeline_schedule_id[]')
        block_uuids = self._split_csv_param(query_arg, 'block_uuid[]')
        pipeline_run_ids = self._split_csv_param(query_arg, 'pipeline_run_id[]')
        block_run_ids = self._split_csv_param(query_arg, 'block_run_id[]')

        a = aliased(PipelineRun, name='a')
        b = aliased(PipelineSchedule, name='b')
        c = aliased(BlockRun, name='c')

        # Only the pipeline run attributes that are copied onto the transient
        # PipelineRun objects built below are selected.
        columns = [
            a.execution_date,
            a.pipeline_schedule_id,
            a.pipeline_uuid,
            a.variables,
        ]

        def apply_run_filters(query):
            """Apply the schedule, run-id and time-range filters shared by both queries."""
            if pipeline_schedule_ids:
                query = query.filter(a.pipeline_schedule_id.in_(pipeline_schedule_ids))
            if pipeline_run_ids:
                query = query.filter(a.id.in_(pipeline_run_ids))
            if start_timestamp:
                query = query.filter(a.execution_date >= start_timestamp)
            if end_timestamp:
                query = query.filter(a.execution_date <= end_timestamp)
            return query

        def fetch_rows(query):
            """Return the query's rows as a list, honoring the optional limit in ``meta``."""
            limit = meta.get(META_KEY_LIMIT, None)
            if limit is not None:
                # Materialize with .all() so the limited query executes here,
                # inside the safe_db_query-wrapped caller, instead of lazily
                # when the result is iterated later.
                return query.limit(limit).all()
            return query.all()

        def build_pipeline_run(row) -> PipelineRun:
            """Build a transient PipelineRun carrying the selected columns of ``row``."""
            model = PipelineRun()
            model.execution_date = row.execution_date
            model.pipeline_schedule_id = row.pipeline_schedule_id
            model.pipeline_uuid = row.pipeline_uuid
            model.variables = row.variables
            return model

        total_pipeline_run_log_count = 0
        pipeline_run_logs = []

        @safe_db_query
        def get_pipeline_runs():
            query = apply_run_filters(
                PipelineRun.
                select(*columns).
                join(b, a.pipeline_schedule_id == b.id).
                filter(b.pipeline_uuid == pipeline_uuid).
                order_by(a.created_at.desc())
            )
            # The count is taken on the unlimited query, so it reports the total
            # number of matching rows rather than the number fetched.
            return dict(
                total_pipeline_run_log_count=query.count(),
                rows=fetch_rows(query),
            )

        # Pipeline run logs are only gathered when no block-level filter is set.
        if not block_uuids and not block_run_ids:
            pipeline_run_results = get_pipeline_runs()
            total_pipeline_run_log_count = pipeline_run_results['total_pipeline_run_log_count']

            processed_pipeline_run_log_files = set()
            for row in pipeline_run_results['rows']:
                logs = await build_pipeline_run(row).logs_async()
                for logs_item in logs:
                    # De-duplicate by log file path across runs.
                    pipeline_log_file_path = logs_item.get('path')
                    if pipeline_log_file_path not in processed_pipeline_run_log_files:
                        pipeline_run_logs.append(logs_item)
                        processed_pipeline_run_log_files.add(pipeline_log_file_path)
                # Checked once per run, after all of that run's logs were added.
                if len(pipeline_run_logs) >= MAX_LOG_FILES:
                    break

        @safe_db_query
        def get_block_runs():
            query = (
                BlockRun.
                select(*(columns + [
                    c.block_uuid,
                ])).
                join(a, a.id == c.pipeline_run_id).
                join(b, a.pipeline_schedule_id == b.id).
                filter(b.pipeline_uuid == pipeline_uuid).
                order_by(c.started_at.desc())
            )

            if block_uuids:
                # Prefix match on the block UUID.
                # NOTE(review): ``%`` and ``_`` inside a requested UUID act as
                # LIKE wildcards here — confirm that this is acceptable.
                query = query.filter(
                    or_(*[c.block_uuid.like(f'{block_uuid}%') for block_uuid in block_uuids]),
                )

            if block_run_ids:
                query = query.filter(c.id.in_(block_run_ids))

            query = apply_run_filters(query)

            # As above, the count ignores the optional row limit.
            return dict(
                total_block_run_log_count=query.count(),
                rows=fetch_rows(query),
            )

        block_run_results = get_block_runs()
        total_block_run_log_count = block_run_results['total_block_run_log_count']

        block_run_logs = []

        processed_block_run_log_files = set()
        for row in block_run_results['rows']:
            block_run = BlockRun()
            block_run.block_uuid = row.block_uuid
            block_run.pipeline_run = build_pipeline_run(row)

            logs = await block_run.logs_async(repo_path=pipeline.repo_path if pipeline else None)
            # De-duplicate by log file path across block runs.
            block_log_file_path = logs.get('path')
            if block_log_file_path not in processed_block_run_log_files:
                block_run_logs.append(logs)
                processed_block_run_log_files.add(block_log_file_path)

            if len(block_run_logs) >= MAX_LOG_FILES:
                break

        def is_within_time_range(file: File) -> bool:
            """Tell whether the timestamp in ``file``'s name lies in the requested range.

            The part of the filename before the first ``.`` is parsed as
            ``%Y%m%dT%H%M%S``. Files whose name cannot be parsed are rejected
            and a warning is logged.
            """
            try:
                dsts = datetime.strptime(file.filename.split('.')[0], '%Y%m%dT%H%M%S')
            except ValueError as err:
                server_logger.warning(f'__filter: {err}')
                return False

            if start_timestamp and not start_timestamp <= dsts:
                return False
            if end_timestamp and not dsts <= end_timestamp:
                return False
            return True

        # Append the logs stored under the edit-pipeline partition for every
        # block of the pipeline; these are not subject to MAX_LOG_FILES.
        for block in pipeline.blocks_by_uuid.values():
            logger = LoggerManagerFactory.get_logger_manager(
                partition=LOG_PARTITION_EDIT_PIPELINE,
                pipeline_uuid=pipeline_uuid,
                subpartition=block.uuid,
            )
            block_run_logs += await logger.get_logs_in_subpartition_async(
                filter_func=is_within_time_range,
            )

        return [
            {
                'block_run_logs': block_run_logs,
                'pipeline_run_logs': pipeline_run_logs,
                'total_block_run_log_count': total_block_run_log_count,
                'total_pipeline_run_log_count': total_pipeline_run_log_count,
            },
        ]
