import inspect
from fastapi import Request
from fastapi.responses import StreamingResponse

from ..protocol import Complete, LengthException, StopSequenceException
from ..base import get_complete, exec_complete
from . import helpers
from .protocol import (
    ErrorDetail,
    CompleteRequest,
    CompleteResponse,
    StreamCompleteResponseCompletion,
    StreamCompleteResponseError,
)

def get_arguments(request: CompleteRequest, sig: inspect.Signature) -> dict:
    """Build the keyword arguments to pass to a ``complete`` callable.

    Only parameters that actually appear in ``sig`` are included, so the
    callable receives exactly the subset of request fields it declares.

    Args:
        request: The incoming completion request.
        sig: Signature of the callable that will receive the arguments.

    Returns:
        A dict keyed by parameter name, holding the matching request values.
    """
    accepted = sig.parameters

    # (callable parameter name, request attribute name), in output order.
    field_map = (
        ('prompt', 'prompt'),
        ('model', 'model'),
        ('max_tokens', 'max_tokens_to_sample'),
        ('stop', 'stop_sequences'),
        ('temperature', 'temperature'),
        ('top_p', 'top_p'),
        ('top_k', 'top_k'),
        ('stream', 'stream'),
    )
    arguments = {
        param: getattr(request, attr)
        for param, attr in field_map
        if param in accepted
    }

    # 'user' is nested under the optional metadata object, so it cannot be
    # handled by the flat attribute table above.
    if 'user' in accepted:
        metadata = request.metadata
        arguments['user'] = metadata.user_id if metadata else None

    return arguments

async def create_completion_response(
        request: CompleteRequest,
        complete: Complete,
        arguments: dict,
    ) -> CompleteResponse:
    """Run ``complete`` to the end and return one non-streaming response.

    All chunks produced by ``exec_complete`` are concatenated into the
    ``completion`` field. A ``LengthException`` raised during iteration is
    reported as ``stop_reason='max_tokens'``; a ``StopSequenceException`` or
    normal exhaustion is reported as ``stop_reason='stop_sequence'``. In both
    exception cases the text collected so far is still returned. Any other
    exception propagates to the caller.

    Args:
        request: The incoming completion request (supplies the model name).
        complete: The completion callable to execute.
        arguments: Keyword arguments prepared for ``complete``.

    Returns:
        The assembled ``CompleteResponse``.
    """
    # Collect chunks in a list and join once, rather than rebuilding the
    # string on every iteration.
    chunks = []
    stop_reason = 'stop_sequence'
    try:
        async for chunk in exec_complete(complete, arguments):
            chunks.append(chunk)
    except LengthException:
        stop_reason = 'max_tokens'
    except StopSequenceException:
        stop_reason = 'stop_sequence'

    return CompleteResponse(
        id=f'msg_{helpers.uuid()}',
        completion=''.join(chunks),
        model=request.model,
        stop_reason=stop_reason,
        type='completion',
    )

async def create_streaming_completion_response(
        request: CompleteRequest,
        complete: Complete,
        arguments: dict,
    ) -> StreamingResponse:
    """Run ``complete`` and stream its output as server-sent events.

    Every chunk produced by ``exec_complete`` is sent as a ``completion``
    event with ``stop_reason=None``. The stream then ends with one more
    ``completion`` event carrying an empty completion and the stop reason:
    ``'max_tokens'`` if a ``LengthException`` was raised, otherwise
    ``'stop_sequence'``. Any other exception is reported in-band as an
    ``error`` event, since the HTTP response has already started.

    Args:
        request: The incoming completion request (supplies the model name).
        complete: The completion callable to execute.
        arguments: Keyword arguments prepared for ``complete``.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    def completion_event(text, stop_reason):
        # Pair the payload with the SSE event name it is sent under.
        return (
            StreamCompleteResponseCompletion(
                type='completion',
                completion=text,
                stop_reason=stop_reason,
                model=request.model,
            ),
            'completion',
        )

    async def create_streaming_events():
        try:
            async for chunk in exec_complete(complete, arguments):
                yield completion_event(chunk, None)

            yield completion_event('', 'stop_sequence')
        except LengthException:
            yield completion_event('', 'max_tokens')
        except StopSequenceException:
            yield completion_event('', 'stop_sequence')
        except Exception as e:
            # The consumer below unpacks (data, event) pairs, so the error
            # payload must carry an event name too; without it the unpack
            # fails and the client never receives the error.
            yield (
                StreamCompleteResponseError(
                    error=ErrorDetail(
                        type='server_error',
                        message=str(e),
                    )
                ),
                'error',
            )

    async def stream_response():
        async for data, event in create_streaming_events():
            yield f'event: {event}\n'
            yield f'data: {data.model_dump_json()}\n\n'
    return StreamingResponse(stream_response(), media_type='text/event-stream')

async def create_completion(request: CompleteRequest, raw_request: Request):
    """Handle a completion request, streaming or not.

    Looks up the registered ``complete`` callable, builds the arguments it
    declares (including the raw FastAPI request when it has a ``request``
    parameter), and dispatches to the streaming or non-streaming response
    builder based on ``request.stream``.

    Returns:
        The 501 error from ``helpers`` when no ``complete`` callable is
        available, the 500 error from ``helpers`` when a ``ValueError`` is
        raised while handling the request, or otherwise the response built
        by the selected handler.
    """
    try:
        complete = get_complete()
        if not complete:
            return helpers.create_501_error('complete')

        signature = inspect.signature(complete)
        arguments = get_arguments(request, signature)

        # Hand over the raw request only to callables that ask for it.
        if 'request' in signature.parameters:
            arguments['request'] = raw_request

        if request.stream:
            handler = create_streaming_completion_response
        else:
            handler = create_completion_response
        return await handler(request, complete, arguments)
    except ValueError as err:
        return helpers.create_500_error(str(err))
