import inspect
from typing import List
from fastapi import Request
from fastapi.responses import StreamingResponse

from ..protocol import Message, Chat, LengthException, StopSequenceException
from ..base import get_chat, exec_chat, compute_tokens
from . import helpers
from .protocol import (
    MessagesRequest,
    MessagesResponse,
    MessagesResponseUsage,
    StreamMessagesResponseMessage,
    StreamMessagesResponseMessageStart,
    ResponseMessageContentBlockText,
    StreamMessagesResponseContentBlockStart,
    StreamMessagesResponseContentBlockDelta,
    StreamMessagesResponseContentBlockStop,
    StreamMessagesResponseMessageDeltaDelta,
    StreamMessagesResponseMessageDelta,
    StreamMessagesResponseMessageStop,
)

def _flatten_blocks(blocks) -> str:
    """Concatenate content blocks into one string.

    Text blocks contribute their ``text``; any other block is represented
    by a ``[<type>]`` placeholder.
    """
    return ''.join(
        block.text if block.type == 'text' else f'[{block.type}]'
        for block in blocks)

def get_messages(request: MessagesRequest) -> List[Message]:
    """Convert the request's system prompt and messages into chat messages.

    The optional ``request.system`` (a string or a list of content blocks)
    becomes a leading ``system`` message. Every entry of
    ``request.messages`` is then appended with its content flattened to a
    single string.

    When a message carries no role, one is inferred from the message that
    was appended just before it: ``assistant`` after a ``user`` message,
    ``user`` otherwise (including at the start and right after the system
    message).

    Raises:
        ValueError: if ``messages`` is not a list, if ``system`` is neither
            a string nor a list, or if a message's content is neither a
            string nor a list.
    """
    anthropic_messages = request.messages
    if not isinstance(anthropic_messages, list):
        raise ValueError('\'messages\' must be a list.')

    messages: List[Message] = []
    if request.system:
        if isinstance(request.system, str):
            messages.append({ 'role': 'system', 'content': request.system })
        elif isinstance(request.system, list):
            messages.append({
                'role': 'system',
                'content': _flatten_blocks(request.system),
            })
        else:
            raise ValueError('\'system\' must be a string or a list of blocks.')

    for (idx, msg) in enumerate(anthropic_messages):
        # Look at the last message actually appended rather than indexing by
        # `idx`: a prepended system message shifts `messages` by one
        # relative to `anthropic_messages`, which would otherwise make the
        # alternation compare against the wrong entry.
        prev_role = messages[-1]['role'] if messages else None
        role = msg.role or ('user' if prev_role != 'user' else 'assistant')

        if isinstance(msg.content, str):
            messages.append({ 'role': role, 'content': msg.content })
            continue
        if not isinstance(msg.content, list):
            raise ValueError(f'Message at index {idx} is not a string or a list.')

        messages.append({
            'role': role,
            'content': _flatten_blocks(msg.content),
        })
    return messages

def get_arguments(request: MessagesRequest, sig: inspect.Signature) -> dict:
    """Collect the keyword arguments the chat callable is able to accept.

    Only parameters that appear in ``sig`` are included; each one is filled
    from the matching field of ``request``. ``stop`` is taken from
    ``request.stop_sequences`` and ``user`` from ``request.metadata.user_id``
    (``None`` when the request has no metadata).
    """
    accepted = sig.parameters

    # (chat parameter name, request attribute name), in output order.
    # Attributes are read lazily, only for parameters the callable declares.
    direct_fields = (
        ('model', 'model'),
        ('temperature', 'temperature'),
        ('top_p', 'top_p'),
        ('top_k', 'top_k'),
        ('max_tokens', 'max_tokens'),
        ('stop', 'stop_sequences'),
        ('stream', 'stream'),
    )
    arguments = {
        param: getattr(request, attr)
        for param, attr in direct_fields
        if param in accepted
    }

    if 'user' in accepted:
        metadata = request.metadata
        if metadata:
            arguments['user'] = metadata.user_id
        else:
            arguments['user'] = None
    return arguments

async def create_chat_message_response(
        request: MessagesRequest,
        chat: Chat,
        arguments: dict,
        messages: List[Message],
    ) -> MessagesResponse:
    """Run the chat to completion and build a single ``MessagesResponse``.

    All chunks produced by ``exec_chat`` are concatenated into one text
    block. The stop reason defaults to ``end_turn`` and becomes
    ``max_tokens`` on ``LengthException``, ``stop_sequence`` (with the
    matched sequence) on ``StopSequenceException``, and ``refusal`` on any
    other exception. Text gathered before an exception is still returned.
    Usage is computed with ``compute_tokens`` over every input message's
    content and over the generated text.
    """
    # (stop_reason, stop_sequence) pair reported in the response.
    outcome = ('end_turn', None)
    response_content = ''

    try:
        async for piece in exec_chat(chat, arguments):
            response_content += piece
    except LengthException:
        outcome = ('max_tokens', None)
    except StopSequenceException as stop:
        outcome = ('stop_sequence', stop.sequence)
    except Exception:
        outcome = ('refusal', None)

    per_message_tokens = [
        await compute_tokens(entry.get('content', ''), request.model)
        for entry in messages
    ]
    prompt_tokens = sum(per_message_tokens)
    completion_tokens = await compute_tokens(response_content, request.model)

    final_reason, matched_sequence = outcome
    usage = MessagesResponseUsage(
        input_tokens=prompt_tokens,
        output_tokens=completion_tokens,
    )
    return MessagesResponse(
        id=f'msg_{helpers.uuid()}',
        type='message',
        role='assistant',
        content=[ResponseMessageContentBlockText(
            type='text',
            text=response_content,
        )],
        model=request.model,
        stop_reason=final_reason,
        stop_sequence=matched_sequence,
        usage=usage,
    )

async def create_streaming_chat_message_response(
        request: MessagesRequest,
        chat: Chat,
        arguments: dict,
        messages: List[Message],
    ) -> StreamingResponse:
    """Run the chat and stream the result as server-sent events.

    The returned ``StreamingResponse`` (``text/event-stream``) emits, in
    order: ``message_start``, ``content_block_start``, one
    ``content_block_delta`` per chunk produced by ``exec_chat``,
    ``content_block_stop``, ``message_delta`` and ``message_stop``.

    The stop reason in ``message_delta`` is ``end_turn`` by default,
    ``max_tokens`` on ``LengthException``, ``stop_sequence`` (with the
    matched sequence) on ``StopSequenceException`` and ``refusal`` on any
    other exception, mirroring ``create_chat_message_response``. Output
    token usage is computed from the text streamed so far, including when
    generation ended with an exception.
    """
    async def create_streaming_events():
        # Yields [event_model, event_name] pairs.
        input_tokens = 0
        for msg in messages:
            input_tokens += await compute_tokens(
                msg.get('content', ''), request.model)

        yield [StreamMessagesResponseMessageStart(
                type='message_start',
                message=StreamMessagesResponseMessage(
                    type='message',
                    id=f'msg_{helpers.uuid()}',
                    role='assistant',
                    content=[],
                    model=request.model,
                    usage=MessagesResponseUsage(
                        input_tokens=input_tokens,
                        # Nothing has been generated yet.
                        output_tokens=0,
                    ),
                ),
            ), 'message_start']
        yield [StreamMessagesResponseContentBlockStart(
                type='content_block_start',
                index=0,
                content_block=ResponseMessageContentBlockText(
                    type='text',
                    text='',
                ),
            ), 'content_block_start']

        stop_reason = 'end_turn'
        stop_sequence = None
        # Accumulated as a string (not a list) so that compute_tokens
        # receives the generated text itself.
        response_content = ''
        try:
            async for chunk in exec_chat(chat, arguments):
                yield [StreamMessagesResponseContentBlockDelta(
                        type='content_block_delta',
                        index=0,
                        content_block=ResponseMessageContentBlockText(
                            type='text',
                            text=chunk,
                        )
                    ), 'content_block_delta']
                response_content += chunk
        except LengthException:
            stop_reason = 'max_tokens'
        except StopSequenceException as e:
            stop_reason = 'stop_sequence'
            stop_sequence = e.sequence
        except Exception:
            stop_reason = 'refusal'

        # The closing events are emitted after the try block rather than
        # from a `finally`: yielding from `finally` while the generator is
        # being closed (e.g. the client disconnected) raises RuntimeError.
        output_tokens = await compute_tokens(response_content, request.model)

        yield [StreamMessagesResponseContentBlockStop(
                type='content_block_stop',
                index=0,
            ), 'content_block_stop']
        yield [StreamMessagesResponseMessageDelta(
                type='message_delta',
                delta=StreamMessagesResponseMessageDeltaDelta(
                    stop_reason=stop_reason,
                    stop_sequence=stop_sequence,
                ),
                usage=MessagesResponseUsage(
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                ),
            ), 'message_delta']
        yield [StreamMessagesResponseMessageStop(
                type='message_stop',
            ), 'message_stop']

    async def stream_response():
        # Serialize each event in SSE wire format: an `event:` line followed
        # by a `data:` line and a blank line.
        async for [data, event] in create_streaming_events():
            yield f'event: {event}\n'
            yield f'data: {data.model_dump_json()}\n\n'
    return StreamingResponse(stream_response(), media_type='text/event-stream')

async def create_chat_message(request: MessagesRequest, raw_request: Request):
    """Handle a Messages request by dispatching to the configured chat.

    Returns a 501 error response when ``get_chat()`` yields nothing.
    Otherwise builds the keyword arguments the chat callable declares
    (adding the raw request and the converted messages when it accepts
    ``request`` / ``messages``) and returns either a complete or a
    streaming response depending on ``request.stream``. A ``ValueError``
    raised along the way is turned into a 500 error response carrying the
    exception message.
    """
    try:
        chat = get_chat()
        if not chat:
            return helpers.create_501_error('chat')

        chat_sig = inspect.signature(chat)
        accepted = chat_sig.parameters
        arguments = get_arguments(request, chat_sig)

        if 'request' in accepted:
            arguments['request'] = raw_request

        # Messages are always converted (and thereby validated), even when
        # the chat callable does not take them as an argument.
        messages = get_messages(request)
        if 'messages' in accepted:
            arguments['messages'] = messages

        responder = (
            create_streaming_chat_message_response
            if request.stream
            else create_chat_message_response
        )
        return await responder(request, chat, arguments, messages)
    except ValueError as e:
        return helpers.create_500_error(str(e))
