

def _id():
    return str(uuid.uuid4())


def start_trace(name: str, metadata: Optional[Dict[str, Any]] = None) -> str:
    """Begin a new trace and make it the active one.

    Any previously active trace state is overwritten. The new trace starts
    with a single root span, which is also pushed onto the open-span stack
    together with its start time.

    Args:
        name: Name recorded on the root span.
        metadata: Metadata stored on the root span; an empty dict is used
            when it is omitted or falsy.

    Returns:
        The newly generated trace id.
    """
    with _LOCK:
        new_trace_id = _id()
        root_span_id = _id()

        root = dict(
            span_id=root_span_id,
            parent_span_id=None,
            name=name,
            type="root",
            timestamp=_now(),
            input="",
            output="",
            metadata=metadata or {},
            usage={},
            duration_ms=0,
            success=True,
        )

        _ACTIVE_TRACE["trace_id"] = new_trace_id
        _ACTIVE_TRACE["spans"] = [root]
        # The stack entry keeps the wall-clock start so the duration can be computed on close.
        _ACTIVE_TRACE["stack"] = [{"span_id": root_span_id, "start": time.time()}]

    return new_trace_id


def start_span(name: str, input: Optional[str] = None) -> str:
    """Open a child span under the innermost currently open span.

    The new span is appended to the active trace's span list and pushed onto
    the open-span stack with its start time.

    Args:
        name: Name recorded on the span.
        input: Input text stored on the span; an empty string when omitted
            or falsy.

    Returns:
        The new span's id.

    Raises:
        IndexError: if the open-span stack is empty (e.g. after end_trace()).
    """
    with _LOCK:
        open_stack = _ACTIVE_TRACE["stack"]
        parent_id = open_stack[-1]["span_id"]
        child_id = _id()

        _ACTIVE_TRACE["spans"].append(
            dict(
                span_id=child_id,
                parent_span_id=parent_id,
                name=name,
                type="child",
                timestamp=_now(),
                input=input or "",
                output="",
                metadata={},
                usage={},
                duration_ms=0,
                success=True,
            )
        )
        open_stack.append({"span_id": child_id, "start": time.time()})

    return child_id


def end_span(span_id: Optional[str] = None, output: Optional[str] = None):
    """Finish an open span, recording its duration and optional output.

    Args:
        span_id: Span to close. When ``None``, the most recently opened span
            (top of the open-span stack) is closed.
        output: Output text stored on the span. Only applied when truthy, so
            an empty string leaves any existing output unchanged.

    Does nothing when no span is open. If ``span_id`` names a span that is not
    currently open (already closed, or unknown), its recorded duration is left
    untouched instead of being reset to 0; ``output`` is still applied if a
    span with that id exists.
    """
    with _LOCK:
        stack = _ACTIVE_TRACE["stack"]
        if not stack:
            return

        if span_id is None:
            # Auto-close the innermost open span.
            entry = stack.pop()
        else:
            # Search from the top so the innermost matching entry is removed.
            entry = None
            for idx in range(len(stack) - 1, -1, -1):
                if stack[idx]["span_id"] == span_id:
                    entry = stack.pop(idx)
                    break

        sid = entry["span_id"] if entry is not None else span_id
        # Only a span that was actually open has a meaningful duration.
        dur = int((time.time() - entry["start"]) * 1000) if entry is not None else None

        for sp in _ACTIVE_TRACE["spans"]:
            if sp["span_id"] == sid:
                if output:
                    sp["output"] = output
                if dur is not None:
                    sp["duration_ms"] = dur
                return


def end_trace(final_output: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Close all still-open spans, detach the active trace and return it.

    Args:
        final_output: Output stored on the root span (the first span of the
            trace) when truthy.

    Returns:
        ``{"trace_id": ..., "spans": [...]}`` for the trace that was active,
        or ``None`` if no trace is active. The global trace state is reset
        before returning.
    """
    with _LOCK:
        if not _ACTIVE_TRACE["trace_id"]:
            return None

        # Close the remaining spans inline rather than via end_span():
        # end_span() acquires _LOCK itself, which would deadlock here if
        # _LOCK is a non-reentrant threading.Lock.
        spans = _ACTIVE_TRACE["spans"]
        spans_by_id = {sp["span_id"]: sp for sp in spans}
        stack = _ACTIVE_TRACE["stack"]
        while stack:
            entry = stack.pop()
            span = spans_by_id.get(entry["span_id"])
            if span is not None:
                span["duration_ms"] = int((time.time() - entry["start"]) * 1000)

        # The root span is the first one created by start_trace().
        if final_output and spans:
            spans[0]["output"] = final_output

        bundle = {
            "trace_id": _ACTIVE_TRACE["trace_id"],
            "spans": spans,
        }

        # Reset global state so a new trace can start cleanly.
        _ACTIVE_TRACE["trace_id"] = None
        _ACTIVE_TRACE["spans"] = []
        _ACTIVE_TRACE["stack"] = []

        return bundle
        

def track_function(name: str = ""):
    """
    Minimal instrumentation decorator.

    Prints the call's inputs to stdout, followed by either its output or its
    error, each together with the call latency in milliseconds.

    Captures:
      - function name
      - input args & kwargs
      - output
      - errors (message and formatted traceback)
      - latency

    Works both with and without arguments::

        @track_function
        def f(...): ...

        @track_function("custom_name")
        def g(...): ...

    Args:
        name: Label used in the printed lines; defaults to the wrapped
            function's ``__name__``. When the decorator is applied bare, this
            argument receives the decorated function instead.

    Returns:
        A decorator, or the wrapped function when applied bare.
    """
    # Applied bare (@track_function): the "name" argument is really the function.
    bare_func = name if callable(name) else None
    label = "" if bare_func is not None else name

    def decorator(func):
        func_name = label or func.__name__

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic, so latency is immune to wall-clock changes.
            start = time.perf_counter()

            input_data = {
                "args": args,
                "kwargs": kwargs
            }

            try:
                result = func(*args, **kwargs)
            except Exception as e:
                latency = int((time.perf_counter() - start) * 1000)
                error_data = {
                    "status": "error",
                    "error": str(e),
                    "stacktrace": traceback.format_exc(),
                    "latency_ms": latency
                }

                print(f"[instrument:{func_name}] INPUT =", input_data)
                print(f"[instrument:{func_name}] ERROR =", error_data)

                # Bare raise re-raises the exception with its original traceback.
                raise
            else:
                latency = int((time.perf_counter() - start) * 1000)
                output_data = {
                    "status": "success",
                    "output": result,
                    "latency_ms": latency
                }

                print(f"[instrument:{func_name}] INPUT =", input_data)
                print(f"[instrument:{func_name}] OUTPUT =", output_data)

                return result

        return wrapper

    if bare_func is not None:
        return decorator(bare_func)
    return decorator


class VeriskTracer:
    """
    Context manager for automatic trace handling.

    On entry a trace is started with ``start_trace``. On exit it is finished
    with ``end_trace`` and the bundle is enriched with ``user_id``,
    ``session_id`` and ``trace_name``. The bundle is then passed to
    ``send_to_sqs``. An exception raised inside the ``with`` body is recorded
    on the root span (output text, ``success=False``, ``error_type``
    metadata). The exception is not suppressed.

    Usage:
        with VeriskTracer("trace_name", user_id="...", metadata={...}) as tracer:
            # code
    """

    def __init__(self, name: str, user_id: Optional[str] = None, session_id: Optional[str] = None, metadata: Optional[Dict] = None):
        self.name = name
        self.user_id = user_id
        # A fresh session id is generated when the caller does not supply one.
        self.session_id = session_id or str(uuid.uuid4())
        self.metadata = metadata or {}
        self.trace_id = None  # assigned in __enter__

    def __enter__(self):
        # Start the trace
        self.trace_id = start_trace(self.name, self.metadata)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        final_output = f"Error: {exc_val}" if exc_type else None

        # End the trace and get the bundle
        bundle = end_trace(final_output)
        if not bundle:
            return None

        if exc_type:
            self._mark_root_failed(bundle, exc_type)

        # start_trace does not accept user/session info, so attach it here.
        bundle["user_id"] = self.user_id
        bundle["session_id"] = self.session_id
        bundle["trace_name"] = self.name

        # Send to SQS
        send_to_sqs(bundle)
        # Returning None (falsy) lets any exception from the body propagate.
        return None

    @staticmethod
    def _mark_root_failed(bundle: Dict[str, Any], exc_type: type) -> None:
        """Flag the bundle's root span (the one without a parent) as failed."""
        root = next(
            (sp for sp in bundle.get("spans") or [] if sp.get("parent_span_id") is None),
            None,
        )
        if root is None:
            return
        root["success"] = False
        # Build a new dict rather than mutating in place: the root's metadata
        # may be the same object the caller passed in.
        root["metadata"] = {**(root.get("metadata") or {}), "error_type": exc_type.__name__}