# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.

# %% auto 0
# Public API of this module (generated by nbdev from the notebook's exported cells; `asdict` is a re-export).
__all__ = ['Placements', 'empty', 'get_db', 'find_var', 'find_dialog_id', 'find_msgs', 'find_msg_id', 'read_msg_ids', 'msg_idx',
           'read_msg', 'del_msg', 'add_msg', 'update_msg', 'load_gist', 'gist_file', 'import_string', 'is_usable_tool',
           'mk_toollist', 'import_gist', 'export_dialog', 'import_dialog', 'tool_info', 'asdict']

# %% ../nbs/00_core.ipynb
import json, importlib, linecache
from typing import Dict
from tempfile import TemporaryDirectory
from ipykernel_helper import *
from dataclasses import dataclass

from fastcore.utils import *
from fastcore.meta import delegates
from ghapi.all import *
from fastlite import *
from fastcore.xtras import asdict
from inspect import currentframe,Parameter,signature
from httpx import get as xget, post as xpost
from .core import __all__ as _all
from IPython.display import display,Markdown

# %% ../nbs/00_core.ipynb
# Extra names to export even though they are imported rather than defined here (presumably nbdev's `_all_`
# convention); this is why `asdict`, imported from fastcore.xtras, appears in `__all__`.
_all_ = ["asdict"]

# %% ../nbs/00_core.ipynb
def get_db(ns:dict=None):
    # Open the SolveIt database and (optionally) publish its table dataclasses into `ns`.
    # Inside SolveIt (`IN_SOLVEIT` set) the db is `<root>/data/data.db`, where root is `/app` under Docker or `.` otherwise;
    # outside SolveIt a development db at `../data/dev_data.db` is used instead.
    root = Path('/app') if Path('/.dockerenv').exists() else Path('.')
    if os.environ.get('IN_SOLVEIT', False): folder, fname = root, 'data.db'
    else: folder, fname = Path('..'), 'dev_data.db'
    db = database(folder/'data'/fname)
    # `all_dcs` runs on every call, not just when `ns` is given -- the rest of this module reads rows
    # as objects (`o.sid`, `o.msg_type`, `asdict(o)`), so keep this call unconditional.
    public = [dc for dc in all_dcs(db) if dc.__name__[0] != '_']
    if ns: ns.update({dc.__name__: dc for dc in public})
    return db

# %% ../nbs/00_core.ipynb
def find_var(var:str):
    "Search for var in all frames of the call stack"
    frame = currentframe()
    try:
        while frame:
            # Within each frame, globals are checked before locals (same precedence as before).
            dv = frame.f_globals.get(var, frame.f_locals.get(var, None))
            # Test against None rather than truthiness, so a variable holding 0, '' or [] counts as found
            # instead of being reported as missing.
            if dv is not None: return dv
            frame = frame.f_back
    finally:
        # Drop our frame reference to avoid a frame <-> local reference cycle (see `inspect` docs).
        del frame
    raise ValueError(f"Could not find {var} in any scope")

# %% ../nbs/00_core.ipynb
def find_dialog_id():
    "Get the dialog id by searching the call stack for __dialog_id."
    # The id is injected into some caller's scope as `__dialog_id`; `find_var` walks the frames to locate it
    # and raises ValueError if no frame defines it.
    dialog_id = find_var('__dialog_id')
    return dialog_id

# %% ../nbs/00_core.ipynb
def find_msgs(
    pattern:str='', # Optional text to search for
    msg_type:str=None, # optional limit by message type ('code', 'note', or 'prompt')
    limit:int=None, # Optionally limit number of returned items
    include_output:bool=True # Include output in returned dict?
)->list[dict]:
    "Find `list[dict]` of messages in current specific dialog that contain the given information. To refer to a message found later, use its `sid` field (which is the pk)."
    did = find_dialog_id()
    db = get_db()
    where, args = 'did=? AND content LIKE ?', [did, f'%{pattern}%']
    # Filter by type in SQL rather than afterwards in Python, so that `limit` counts only matching
    # messages (previously a type filter could return fewer than `limit` rows even when more existed).
    if msg_type:
        where += ' AND msg_type=?'
        args.append(msg_type)
    res = [asdict(o) for o in db.t.message(where, args, order_by='mid', limit=limit)]
    if not include_output:
        for o in res: o.pop('output', None)
    return res

# %% ../nbs/00_core.ipynb
def find_msg_id():
    "Get the message id by searching the call stack for __msg_id."
    # Raises ValueError (from `find_var`) if no frame on the call stack defines `__msg_id`.
    return find_var('__msg_id')

# %% ../nbs/00_core.ipynb
def read_msg_ids()->list[str]:
    "Get all ids in current dialog."
    dialog_id = find_dialog_id()
    # Select just the `sid` column; `mid` is the sort key, so rows come back in dialog order.
    rows = get_db().t.message('did=?', [dialog_id], select='sid', order_by='mid')
    return [row.sid for row in rows]

# %% ../nbs/00_core.ipynb
def msg_idx():
    "Get relative index of current message in dialog."
    # Returns (all sids in dialog order, position of the current message among them);
    # `list.index` raises ValueError if the current message id is not in the dialog.
    all_ids = read_msg_ids()
    current = find_msg_id()
    position = all_ids.index(current)
    return all_ids, position

# %% ../nbs/00_core.ipynb
def read_msg(n:int=-1,     # Message index (if relative, +ve is downwards)
             relative:bool=True  # Is `n` relative to current message (True) or absolute (False)?
    ):
    "Get the `Message` object indexed in the current dialog, or None if the index is out of range."
    if relative:
        ids,idx = msg_idx()
        idx = idx+n
        if not 0<=idx<len(ids): return None
    else:
        # Absolute indexing doesn't need the current message's position, so don't look it up.
        # Negative `n` counts from the end, as with list indexing; out-of-range returns None,
        # consistent with the relative branch (previously this raised IndexError).
        ids,idx = read_msg_ids(),n
        if not -len(ids)<=idx<len(ids): return None
    db = get_db()
    return db.t.message.selectone('sid=?', [ids[idx]])

# %% ../nbs/00_core.ipynb
def del_msg(
    sid:str=None, # sid (stable id -- pk) of the message to delete
):
    "Delete a message from the dialog. Be sure to pass a `sid`, not a `mid`."
    # Posts the sid to the local SolveIt server and raises on an HTTP error status.
    # NOTE(review): there is no fallback to the current message -- if `sid` is None it is posted as-is,
    # and how the server treats that isn't visible here.
    xpost('http://localhost:5001/rm_msg_', data=dict(msid=sid)).raise_for_status()

# %% ../nbs/00_core.ipynb
# Signature-only stub: `add_msg` is decorated with `@delegates(_msg)`, presumably so these optional message
# fields appear in its advertised signature in place of `**kwargs`. The body does nothing.
# NOTE(review): the defaults below are only what the signature shows; `add_msg` itself never applies them
# (apart from defaulting `msg_type` to 'note'), so fields the caller omits are simply not posted.
def _msg(
    msg_type: str='note', # Message type, can be 'code', 'note', or 'prompt'
    output:str='', # For prompts/code, initial output
    time_run: str | None = '', # When was message executed
    is_exported: int | None = 0, # Export message to a module?
    skipped: int | None = 0, # Hide message from prompt?
    i_collapsed: int | None = 0, # Collapse input?
    o_collapsed: int | None = 0, # Collapse output?
    header_collapsed: int | None = 0, # Collapse heading section?
    pinned: int | None = 0 # Pin to context?
): ...

# The valid values for `add_msg`'s `placement` argument, as a fastcore `str_enum`.
Placements = str_enum('Placements', 'add_after', 'add_before', 'update', 'at_start', 'at_end')

# %% ../nbs/00_core.ipynb
@delegates(_msg)
def add_msg(
    content:str, # Content of the message (i.e the message prompt, code, or note text)
    placement:str='add_after', # Can be 'add_after', 'add_before', 'update', 'at_start', 'at_end'
    sid:str=None, # sid (stable id -- pk) of message that placement is relative to (if None, uses current message)
    **kwargs
):
    """Add/update a message to the queue to show after code execution completes.
    Be sure to pass a `sid` (stable id) not a `mid` (which is used only for sorting, and can change).
    Sets msg_type to 'note' by default if not update placement."""
    # Validation failures are returned as strings (not raised), matching the msg_type/output checks below.
    if placement not in ('add_after', 'add_before', 'update', 'at_start', 'at_end'):
        return "placement must be 'add_after', 'add_before', 'update', 'at_start', or 'at_end'."
    # Only default the type for new messages; for updates, an omitted msg_type means "leave unchanged".
    if 'msg_type' not in kwargs and placement!='update': kwargs['msg_type']='note'
    mt = kwargs.get('msg_type',None)
    ot = kwargs.get('output',None)
    if mt and mt not in ('note', 'code', 'prompt'): return "msg_type must be 'code', 'note', or 'prompt'."
    if mt=='note' and ot: return "note messages cannot have an output."
    if mt=='code':
        # Code output must parse as JSON; ValueError covers malformed text, TypeError covers non-string values.
        try: json.loads(ot or '[]')
        except (ValueError, TypeError): return "Code output must be valid json"
    if not sid: sid = find_msg_id()
    data = dict(content=content, placement=placement, sid=sid, **kwargs)
    return xpost('http://localhost:5001/add_relative_', data=data).text

# %% ../nbs/00_core.ipynb
@delegates(add_msg)
def _add_msg_unsafe(
    content:str, # Content of the message (i.e the message prompt, code, or note text)
    run:bool=False, # For prompts, send it to the AI; for code, execute it (*DANGEROUS -- be careful of what you run!)
    **kwargs
):
    """Add/update a message to the queue to show after code execution completes, and optionally run it. Be sure to pass a `sid` (stable id) not a `mid` (which is used only for sorting, and can change).
    *WARNING*--This can execute arbitrary code, so check carefully what you run!--*WARNING"""
    # Same as `add_msg`, except that the `run` flag is forwarded along with the other fields.
    kwargs['run'] = run
    return add_msg(content, **kwargs)

# %% ../nbs/00_core.ipynb
# Signature-only stub: `update_msg` is decorated with `@delegates(_umsg)`, presumably so these updatable
# message fields appear in its advertised signature. The body does nothing. Every default is None, in line with
# `update_msg`'s contract that fields left out are not changed.
def _umsg(
    msg_type: str|None = None, # Message type, can be 'code', 'note', or 'prompt'
    output:str|None = None, # For prompts/code, the output
    time_run: str | None = None, # When was message executed
    is_exported: int | None = None, # Export message to a module?
    skipped: int | None = None, # Hide message from prompt?
    i_collapsed: int | None = None, # Collapse input?
    o_collapsed: int | None = None, # Collapse output?
    header_collapsed: int | None = None, # Collapse heading section?
    pinned: int | None = None # Pin to context?
): ...

# %% ../nbs/00_core.ipynb
@delegates(_umsg)
def update_msg(
    sid:str=None, # sid (stable id -- pk) of message to update (required unless `msg` contains a `sid`)
    content:str|None = None, # Content of the message (i.e the message prompt, code, or note text)
    msg:Optional[Dict]=None, # Dictionary of field keys/values to update
    **kwargs):
    """Update an existing message. Provide either `msg` OR field key/values to update.
    Use `content` param to update contents. Be sure to pass a `sid` (stable id -- the pk) not a `mid`
    (which is used only for sorting, and can change).
    Only include parameters to update--missing ones will be left unchanged."""
    # Keyword fields override values from `msg`; the merge builds a new dict, so the caller's `msg` isn't mutated.
    kw = (msg or {}) | kwargs
    # `sid` and `content` are named parameters, so they can only reach `kw` via `msg`. Pull them out so
    # they aren't passed to `add_msg` twice (a full message dict, e.g. from `find_msgs`, used to raise
    # "multiple values for argument 'content'"). Explicit arguments take precedence over `msg` values.
    msg_sid = kw.pop('sid', None)
    msg_content = kw.pop('content', None)
    sid = sid or msg_sid
    if content is None: content = msg_content
    if not sid: raise TypeError("update_msg needs either a dict message or `sid=...`")
    # The dialog id is not an updatable field.
    kw.pop('did', None)
    return add_msg(content, placement='update', sid=sid, **kw)

# %% ../nbs/00_core.ipynb
def load_gist(gist_id:str):
    "Retrieve a gist"
    api = GhApi()
    # Accepts a bare id or anything ending in `user/id` (e.g. a pasted gist URL). Strip trailing slashes
    # first so that `.../user/id/` doesn't produce an empty id.
    gist_id = gist_id.rstrip('/')
    if '/' in gist_id: *_,user,gist_id = gist_id.split('/')
    else: user = None
    return api.gists.get(gist_id, user=user)

# %% ../nbs/00_core.ipynb
def gist_file(gist_id:str):
    "Get the first file from a gist"
    # `files` is the gist's mapping of files; take the first entry, or None if the gist has no files.
    files = load_gist(gist_id).files
    return next(iter(files.values()), None)

# %% ../nbs/00_core.ipynb
def import_string(
    code:str, # Code to import as a module
    name:str  # Name of module to create
):
    # Execute `code` as a new module called `name`, register it in `sys.modules`, and return it.
    # The source is written to a temporary file only for the duration of the import.
    # Note: this function has no docstring, so `is_usable_tool` does not treat it as an LLM tool.
    # `import importlib` alone does not guarantee the `importlib.util` submodule is loaded.
    import importlib.util, sys
    with TemporaryDirectory() as tmpdir:
        path = Path(tmpdir) / f"{name}.py"
        # Python decodes source files as UTF-8 by default, so write them that way too (the locale
        # default could be a different encoding and corrupt non-ASCII code).
        path.write_text(code, encoding='utf-8')
        # linecache.cache storage allows inspect.getsource() after tmpdir lifetime ends
        linecache.cache[str(path)] = (len(code), None, code.splitlines(keepends=True), str(path))
        spec = importlib.util.spec_from_file_location(name, path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[name] = module
        try: spec.loader.exec_module(module)
        except BaseException:
            # Like the regular import system, don't leave a half-initialised module registered on failure.
            sys.modules.pop(name, None)
            raise
        return module

# %% ../nbs/00_core.ipynb
empty = Parameter.empty

def is_usable_tool(func:callable):
    "True if the function has a docstring and all parameters have types, meaning that it can be used as an LLM tool."
    if not func.__doc__ or not callable(func): return False
    # Some callables (certain builtins, or objects with a malformed `__signature__`) have no introspectable
    # signature; `signature` raises ValueError/TypeError for them, and they can't be described as tools anyway.
    try: params = signature(func).parameters.values()
    except (ValueError, TypeError): return False
    # `empty` is a sentinel class, so compare by identity.
    return all(p.annotation is not empty for p in params)

# %% ../nbs/00_core.ipynb
def mk_toollist(syms):
    "Markdown bullet list with a `- &`name`: doc` line for each object in `syms` that passes `is_usable_tool`."
    lines = []
    for sym in syms:
        if not is_usable_tool(sym): continue
        lines.append(f"- &`{sym.__name__}`: {sym.__doc__}")
    return "\n".join(lines)

# %% ../nbs/00_core.ipynb
def import_gist(
    gist_id:str, # user/id or just id of gist to import as a module
    mod_name:str=None, # module name to create (taken from gist filename if not passed)
    add_global:bool=True, # add module to caller's globals?
    import_wildcard:bool=False, # import all exported symbols to caller's globals
    create_msg:bool=False # Add a message that lists usable tools
):
    "Import gist directly from string without saving to disk"
    fil = gist_file(gist_id)
    mod_name = mod_name or Path(fil['filename']).stem
    module = import_string(fil['content'], mod_name)
    # The caller's module globals (the frame that called `import_gist`).
    glbs = currentframe().f_back.f_globals
    if add_global: glbs[mod_name] = module
    # Exported names: `__all__` if defined, otherwise every public name, as `from module import *` would do.
    names = getattr(module, '__all__', None)
    if names is None: names = [o for o in dir(module) if not o.startswith('_')]
    # Keep each object with the name it is exported under: constants have no `__name__`, and aliases
    # (`from x import y as z`) must be bound as `z`, not `y`.
    syms = {nm: getattr(module, nm) for nm in names}
    if import_wildcard: glbs.update(syms)
    if create_msg:
        # A module without a docstring has `__doc__` set to None, so fall back explicitly
        # (a `getattr` default would never apply and "None" would be shown).
        pref = module.__doc__ or "Tools added to dialog:"
        add_msg(f"{pref}\n\n{mk_toollist(syms.values())}")
    return module

# %% ../nbs/00_core.ipynb
# Message columns written by `export_dialog` and read back by `import_dialog`.
__EXPORT_FIELDS = {'content', 'output', 'input_tokens', 'output_tokens', 'msg_type', 'is_exported',
                   'skipped', 'pinned', 'i_collapsed', 'o_collapsed', 'header_collapsed'}

# Fields `import_dialog` passes to `add_msg` explicitly; the remaining export fields go through as extra options.
__REQUIRED_FIELDS = {'content', 'output', 'msg_type'}

# %% ../nbs/00_core.ipynb
def export_dialog(filename: str, did:int=None):
    "Export dialog messages (excluding pinned ones) to JSON"
    if did is None: did = find_dialog_id()
    db = get_db()
    # Pinned messages are left out of the export.
    msgs = db.t.message('did=? and (pinned=0 or pinned is null)', [did], order_by='mid')
    # Iterating a set of strings gives a hash-dependent order that changes between runs;
    # sort it so repeated exports of the same dialog produce identical files.
    fields = sorted(__EXPORT_FIELDS)
    msg_data = [{k:getattr(msg,k) for k in fields if hasattr(msg, k)}
                for msg in msgs]
    result = {'messages': msg_data, 'dialog_name': db.t.dialog[did].name}
    with open(filename, 'w', encoding='utf-8') as f: json.dump(result, f, indent=2)

# %% ../nbs/00_core.ipynb
def import_dialog(fname, add_header=True):
    "Import dialog messages from JSON file using `add_msg`"
    data = json.loads(Path(fname).read_text())
    msgs = data['messages']
    # Export fields other than content/output/msg_type are forwarded to `add_msg` as extra options.
    opt_fields = __EXPORT_FIELDS - __REQUIRED_FIELDS
    # Every message is appended with placement 'at_end', so add the header first and then the messages
    # in their original order. (Previously the list was reversed and the header added last, and both calls
    # passed msg_type/output positionally into `add_msg`'s placement/sid slots, which raised TypeError.)
    if add_header: add_msg(f"# Imported Dialog `{fname}`", placement='at_end', msg_type='note')
    for msg in msgs:
        opts = {k:msg[k] for k in opt_fields if k in msg}
        add_msg(msg.get('content') or '', placement='at_end', msg_type=msg.get('msg_type') or 'note',
                output=msg.get('output') or '', **opts)
    return f"Imported {len(msgs)} messages"

# %% ../nbs/00_core.ipynb
def tool_info():
    # Post a note listing the dialoghelper tools, using the same `- &`name`: description` format as `mk_toollist`.
    header = 'Tools available from `dialoghelper`:'
    tools = [
        ('find_dialog_id', 'Get the current dialog id.'),
        ('find_msg_id', 'Get the current message id.'),
        ('find_msgs', 'Find messages in current specific dialog that contain the given information.'),
        ('read_msg', 'Get the message indexed in the current dialog.'),
        ('del_msg', 'Delete a message from the dialog.'),
        ('add_msg', 'Add/update a message to the queue to show after code execution completes.'),
        ('update_msg', 'Update an existing message.'),
    ]
    lines = [header, ''] + [f'- &`{name}`: {desc}' for name, desc in tools]
    add_msg('\n'.join(lines))
