import time, inspect
from typing import Any, Optional, List, AsyncGenerator, Coroutine
from .helpers import asyncgen, aggregate
from .services import add_history
from .protocol import (
    Chat,
    ChatAsyncGenerator,
    Complete,
    CompleteAsyncGenerator,
    Embed,
    EmbedAsync,
    Tokens,
    TokensAsync,
    Model,
    ModelInfo,
    RequestItem,
    ResponseItem,
    HistoryItem,
)

# chat

# Registered chat function; stays None until set_chat() is called.
_chat: Optional[Chat] = None

def set_chat(chat: Chat) -> Chat:
    '''Register ``chat`` as the chat function and return it.

    Returning the argument (like the other ``set_*`` functions do) means
    this is safe to use as a decorator: the decorated function is kept
    instead of being replaced by ``None``.

    :param chat: The chat function to register.
    :returns: ``chat`` itself.
    '''
    global _chat
    _chat = chat
    return _chat

def get_chat() -> Optional[Chat]:
    '''Return the registered chat function, or ``None`` if none is set.'''
    return _chat

async def exec_chat(chat: Chat, args: dict) -> AsyncGenerator[str, None]:
    '''Run ``chat`` with ``args`` and stream its output chunks.

    Each chunk is yielded as soon as it is produced. When the stream ends
    (normally, by an exception, or because the generator is closed before
    it is exhausted), the request and whatever response text was produced
    so far are recorded via ``add_history`` with type ``'chat'``.

    :param chat: The chat function to execute.
    :param args: Arguments forwarded to ``chat`` through ``asyncgen``;
        ``args['messages']`` (default ``[]``) is recorded as the request.
    :yields: Output chunks as strings.
    '''
    request_item = RequestItem(
        content=args.get('messages', []),
        timestamp=time.time(),
    )
    # Collect chunks and join once at the end rather than growing a string
    # with ``+=``, which can be quadratic in the total response length.
    chunks: List[str] = []
    try:
        async for chunk in aggregate(asyncgen(chat, args)):
            chunks.append(chunk)
            yield chunk
    finally:
        # Runs even on early close, so partial responses are still logged.
        response_item = ResponseItem(
            content=''.join(chunks),
            timestamp=time.time(),
        )
        add_history(HistoryItem(
            request=request_item,
            response=response_item,
            type='chat'
        ))

# complete

# Registered completion function; stays None until set_complete() is called.
_complete: Optional[Complete] = None

def set_complete(complete: Complete) -> Complete:
    '''Register ``complete`` as the completion function and return it.

    :param complete: The completion function to register.
    :returns: ``complete`` itself, so this can be used as a decorator.
    '''
    global _complete
    _complete = complete
    return _complete

def get_complete() -> Optional[Complete]:
    '''Return the registered completion function, or ``None`` if none is set.'''
    return _complete

async def exec_complete(complete: Complete, args: dict) -> AsyncGenerator[str, None]:
    '''Run ``complete`` with ``args`` and stream its output chunks.

    Each chunk is yielded as soon as it is produced. When the stream ends
    (normally, by an exception, or because the generator is closed before
    it is exhausted), the request and whatever response text was produced
    so far are recorded via ``add_history`` with type ``'complete'``.

    :param complete: The completion function to execute.
    :param args: Arguments forwarded to ``complete`` through ``asyncgen``;
        ``args['prompt']`` (default ``''``) is recorded as the request.
    :yields: Output chunks as strings.
    '''
    request_item = RequestItem(
        content=args.get('prompt', ''),
        timestamp=time.time(),
    )
    # Collect chunks and join once at the end rather than growing a string
    # with ``+=``, which can be quadratic in the total response length.
    chunks: List[str] = []
    try:
        async for chunk in aggregate(asyncgen(complete, args)):
            chunks.append(chunk)
            yield chunk
    finally:
        # Runs even on early close, so partial responses are still logged.
        response_item = ResponseItem(
            content=''.join(chunks),
            timestamp=time.time(),
        )
        add_history(HistoryItem(
            request=request_item,
            response=response_item,
            type='complete'
        ))

# embed

# Registered embedding function; stays None until set_embed() is called.
_embed: Optional[Embed] = None

def set_embed(embed: Embed) -> Embed:
    '''Register ``embed`` as the embedding function and return it.

    :param embed: The embedding function to register.
    :returns: ``embed`` itself, so this can be used as a decorator.
    '''
    global _embed
    _embed = embed
    return _embed

def get_embed() -> Optional[Embed]:
    '''Return the registered embedding function, or ``None`` if none is set.'''
    return _embed

def exec_embed(embed: Embed, args: dict) -> Coroutine[Any, Any, List[List[float]]]:
    '''Execute the embed function and return an awaitable of its result.

    ``embed`` may be sync or async. If calling it yields an awaitable
    (a coroutine function, an async callable object, a partial of one...),
    that awaitable is returned unchanged. A plain sync result is wrapped in
    a coroutine so that callers can always ``await`` the return value, as
    the annotation promises. The function itself is called right here, so
    a sync ``embed`` raises from this call rather than on ``await``.

    :param embed: The embedding function to call.
    :param args: Keyword arguments passed to ``embed``.
    :returns: An awaitable resolving to the embeddings.
    '''
    result = embed(**args)
    if inspect.isawaitable(result):
        return result

    async def _resolved() -> List[List[float]]:
        return result

    return _resolved()

# models

# Explicit model list; while None, get_models() derives the list from the
# registered chat/complete/embed functions instead.
_models: Optional[List[Model]] = None

def set_models(models: List[Model]) -> List[Model]:
    '''Set the explicit list of models and return it.

    :param models: The models to advertise.
    :returns: ``models`` itself.
    '''
    global _models
    _models = models
    return _models

def get_models() -> List[Model]:
    '''Return the list of available models.

    If a list was set with ``set_models``, it is returned (an empty list if
    it is empty). Otherwise the list is derived from the registered chat,
    complete and embed functions. Each one that is set contributes one
    model, in that order. The model's id is the function's ``__model__``
    attribute, falling back to its ``__name__``. The owner is its
    ``__owner__`` attribute, falling back to ``'svllm'``.
    '''
    if _models is not None:
        return _models or []

    def _create_model(base: Optional[ModelInfo]) -> Optional[Model]:
        '''Build a Model entry for ``base``, or return None if it is unset.'''
        if base is None:
            return None
        # Resolve the fallback lazily: getattr's default argument is
        # evaluated eagerly, which would raise AttributeError for callables
        # lacking ``__name__`` (e.g. functools.partial) even when
        # ``__model__`` is present.
        if hasattr(base, '__model__'):
            model_id = base.__model__
        else:
            model_id = base.__name__
        owned_by = getattr(base, '__owner__', 'svllm')
        return Model({ 'id': model_id, 'owned_by': owned_by })

    handlers = (get_chat(), get_complete(), get_embed())
    models = [_create_model(handler) for handler in handlers]
    return [m for m in models if m is not None]

# tokens

def default_tokens(text: str) -> int:
    '''Fallback token counter that treats every character as one token.

    :param text: The text to measure.
    :returns: The number of characters in ``text``.
    '''
    character_count = len(text)
    return character_count

# Active token-counting function; starts out as the character-count fallback.
_tokens: Tokens = default_tokens

def set_tokens(tokens: Tokens) -> Tokens:
    '''Install ``tokens`` as the token-counting function and hand it back.

    :param tokens: The token-counting function to install.
    :returns: ``tokens`` itself.
    '''
    global _tokens
    _tokens = tokens
    return tokens

def get_tokens() -> Tokens:
    '''Return the currently installed token-counting function.'''
    current = _tokens
    return current

async def compute_tokens(text: str, model: str) -> int:
    '''Count the tokens in ``text`` with the installed tokens function.

    The tokens function receives only the keyword arguments it declares
    among ``text`` and ``model``. If it accepts ``**kwargs``, it receives
    both. It may be sync or async: any awaitable it returns is awaited.

    :param text: The text to count tokens for.
    :param model: The model name, forwarded if the function accepts it.
    :returns: The token count reported by the tokens function.
    '''
    params = inspect.signature(_tokens).parameters
    # A ``**kwargs`` parameter can absorb any keyword, so forward both.
    accepts_any = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    token_args = {}
    if accepts_any or 'text' in params:
        token_args['text'] = text
    if accepts_any or 'model' in params:
        token_args['model'] = model
    result = _tokens(**token_args)
    # Inspect the call result rather than using ``iscoroutinefunction`` so
    # async callable objects and partials of coroutine functions also work.
    if inspect.isawaitable(result):
        result = await result
    return result
