import logging
from datetime import datetime, timedelta
from typing import Any

import dapr.ext.workflow as wf
from dapr.ext.workflow import DaprWorkflowContext

from deepbrief.activities.prompts import generate_prompt
from deepbrief.activities.read import read_pdf_docling, read_pdf_pypdf
from deepbrief.activities.research import generate_search_queries, web_search
from deepbrief.activities.transcribe import generate_transcript, write_transcript_to_file

logger = logging.getLogger(__name__)

# --- shared retry policy ---
# Despite the "CLASSIFY" prefix, this policy is attached to every activity call
# and every child-workflow call in this module.
# Exponential backoff: the first retry waits 1s, each later wait is doubled
# (capped at 60s per wait), with at most 6 attempts inside a 5-minute window.
CLASSIFY_RETRY_POLICY = wf.RetryPolicy(
    max_number_of_attempts=6,
    retry_timeout=timedelta(minutes=5),
    first_retry_interval=timedelta(seconds=1),
    max_retry_interval=timedelta(seconds=60),
    backoff_coefficient=2.0,
)


def generate_transcripts_workflow(ctx: DaprWorkflowContext, input: dict):
    """
    Orchestrate transcript generation for a list of papers using batched fan-out.

    Papers are split into batches of ``transcript_batch_size``. Every paper in a
    batch runs as a ``generate_single_transcript_workflow`` child workflow, and
    the whole batch is awaited with ``wf.when_all`` before the next batch is
    scheduled, which bounds how many child workflows are in flight at once.

    Args:
        ctx: Dapr workflow context.
        input: Workflow options. All keys are optional: ``papers_metadata``,
            ``participants``, ``podcast_name``, ``host_name``,
            ``dialogue_max_rounds``, ``transcripts_directory``,
            ``transcripts_storage_prefix``, ``persist_transcripts_locally``,
            ``markdowns_directory``, ``markdowns_storage_prefix``,
            ``persist_markdowns_locally``, ``document_reader`` and
            ``transcript_batch_size`` (clamped to at least 1). ``None`` is
            treated as an empty dict.

    Returns:
        ``{"transcripts_metadata": [...]}`` with one entry per child workflow
        that returned a truthy result; papers whose child returned nothing are
        omitted.
    """
    # The workflow may be scheduled without an input payload (None); treat that
    # as "no options" instead of failing on ``.get``.
    options = input or {}

    papers_metadata = options.get("papers_metadata") or []
    participants = options.get("participants") or []
    podcast_name = options.get("podcast_name", "AI Security Voice")
    host_name = options.get("host_name", "Titan")
    dialogue_max_rounds = int(options.get("dialogue_max_rounds", 3))
    transcripts_directory = options.get("transcripts_directory", "output/transcripts")
    transcripts_storage_prefix = options.get("transcripts_storage_prefix", "transcripts")
    persist_transcripts_locally = bool(options.get("persist_transcripts_locally", False))
    markdowns_directory = options.get("markdowns_directory", "output/markdowns")
    markdowns_storage_prefix = options.get("markdowns_storage_prefix", "markdowns")
    persist_markdowns_locally = bool(options.get("persist_markdowns_locally", False))
    document_reader = str(options.get("document_reader", "docling")).lower()
    transcript_batch_size = max(1, int(options.get("transcript_batch_size", 5)))

    # Settings forwarded unchanged to every child workflow; built once rather
    # than once per paper. Each child input is {"record": ..., **shared_child_input}.
    shared_child_input = {
        "participants": participants,
        "podcast_name": podcast_name,
        "host_name": host_name,
        "dialogue_max_rounds": dialogue_max_rounds,
        "transcripts_directory": transcripts_directory,
        "transcripts_storage_prefix": transcripts_storage_prefix,
        "persist_transcripts_locally": persist_transcripts_locally,
        "markdowns_directory": markdowns_directory,
        "markdowns_storage_prefix": markdowns_storage_prefix,
        "persist_markdowns_locally": persist_markdowns_locally,
        "document_reader": document_reader,
    }

    transcripts_metadata: list[dict[str, Any]] = []

    # Ceiling division. This is 0 for an empty paper list, so the log below
    # matches the number of batches the loop actually runs.
    total_transcript_batches = (
        len(papers_metadata) + transcript_batch_size - 1
    ) // transcript_batch_size

    if not ctx.is_replaying:
        logger.info(
            "Generating transcripts in %s batches of up to %s papers each.",
            total_transcript_batches,
            transcript_batch_size,
        )

    batch_starts = range(0, len(papers_metadata), transcript_batch_size)
    for batch_index, batch_start in enumerate(batch_starts, start=1):
        current_batch = papers_metadata[batch_start : batch_start + transcript_batch_size]

        if not ctx.is_replaying:
            logger.info(
                "Processing transcription batch %s/%s (%s papers)",
                batch_index,
                total_transcript_batches,
                len(current_batch),
            )

        child_tasks = [
            ctx.call_child_workflow(
                generate_single_transcript_workflow,
                input={"record": record, **shared_child_input},
                retry_policy=CLASSIFY_RETRY_POLICY,
            )
            for record in current_batch
        ]

        # range() only yields starts < len(papers_metadata), so every batch has
        # at least one child task for when_all to wait on.
        batch_results = yield wf.when_all(child_tasks)
        # Children return None for records they skip; drop those entries.
        transcripts_metadata.extend(res for res in batch_results if res)

    return {"transcripts_metadata": transcripts_metadata}


def _to_valid_utf8(value: Any) -> str:
    """Return ``str(value)`` with characters that cannot be UTF-8 encoded replaced by ``?``."""
    # Lone surrogates (e.g. from lossy text extraction) are not UTF-8 encodable;
    # replacing them keeps activity payloads to plain, encodable text.
    return str(value).encode("utf-8", "replace").decode("utf-8")


def _format_search_context(search_results: list[dict[str, Any]]) -> str:
    """Render search results that have non-empty ``text`` as Markdown sections joined by blank lines."""
    return "\n\n".join(
        f"### Web Findings for: {res['query']}\n{res['text']}\nSources:\n{res['sources']}"
        for res in search_results
        if res.get("text")
    )


def _collect_search_sources(search_results: list[dict[str, Any]]) -> list[dict[str, str | None]]:
    """Flatten each result's ``results`` items into ``{"title", "url"}`` dicts, preserving order."""
    return [
        {"title": item.get("title"), "url": item.get("url")}
        for res in search_results
        for item in res.get("results") or []
    ]


def generate_single_transcript_workflow(ctx: DaprWorkflowContext, input: dict):
    """
    Process a single paper end-to-end and return its transcript metadata.

    The paper is first read into page documents with the configured reader
    activity (``read_pdf_docling`` for ``"docling"``, otherwise ``read_pdf_pypdf``).
    Then, for every page:

    - search queries are generated and the web searches run in parallel;
    - a prompt is built from the page text, the previous page's dialogue and
      the search findings;
    - dialogue is generated from that prompt.

    Finally, the combined dialogue is persisted with ``write_transcript_to_file``.

    Args:
        ctx: Dapr workflow context.
        input: Child-workflow input. ``record`` holds the paper metadata (``id``,
            ``latest_id``, ``file_path``, ``storage_key``, ``title``,
            ``summary``). The remaining keys mirror the parent orchestrator's
            options. Note that ``host_name`` defaults to ``"Host"`` here, while
            the parent defaults to ``"Titan"`` and always passes it explicitly.
            ``None`` is treated as an empty dict.

    Returns:
        ``None`` when there is no record or the reader returned no documents.
        Otherwise, a dict with ``article_id``, ``file_path``, ``storage_key``,
        ``markdowns_directory``, ``markdowns_storage_prefix`` and
        ``search_sources``.
    """
    options = input or {}
    record = options.get("record") or {}
    if not record:
        return None

    participants = options.get("participants") or []
    podcast_name = options.get("podcast_name", "AI Security Voice")
    host_name = options.get("host_name", "Host")
    dialogue_max_rounds = int(options.get("dialogue_max_rounds", 3))
    transcripts_directory = options.get("transcripts_directory", "output/transcripts")
    transcripts_storage_prefix = options.get("transcripts_storage_prefix", "transcripts")
    persist_transcripts_locally = bool(options.get("persist_transcripts_locally", False))
    markdowns_directory = options.get("markdowns_directory", "output/markdowns")
    markdowns_storage_prefix = options.get("markdowns_storage_prefix", "markdowns")
    persist_markdowns_locally = bool(options.get("persist_markdowns_locally", False))
    document_reader = str(options.get("document_reader", "docling")).lower()
    # Any value other than "docling" falls back to the pypdf reader.
    reader_activity = read_pdf_docling if document_reader == "docling" else read_pdf_pypdf

    canonical_id = record.get("id")
    latest_id = record.get("latest_id") or canonical_id

    documents: list[dict[str, Any]] = yield ctx.call_activity(
        reader_activity,
        input={
            "file_path": record.get("file_path"),
            "storage_key": record.get("storage_key"),
            "doc_id": latest_id,
            "markdowns_directory": markdowns_directory,
            "storage_prefix": markdowns_storage_prefix,
            "persist_locally": persist_markdowns_locally,
        },
        retry_policy=CLASSIFY_RETRY_POLICY,
    )
    if not documents:
        if not ctx.is_replaying:
            logger.warning("No pages returned for %s; skipping transcripts.", latest_id)
        return None

    # Loop-invariant: the participant names sent with every page prompt.
    participant_names = [p["name"] for p in participants if "name" in p]

    # Dialogue text of the most recently processed page only. It is reset after
    # each page, so the prompt context does not grow with document length.
    accumulated_context = ""
    transcript_parts: list[dict[str, Any]] = []
    aggregated_sources: list[dict[str, str | None]] = []
    total_iterations = len(documents)

    for iteration_index, document in enumerate(documents, start=1):
        document_with_context: dict[str, Any] = {
            "text": _to_valid_utf8(document.get("text") or ""),
            "iteration_index": iteration_index,
            "total_iterations": total_iterations,
            "context": _to_valid_utf8(accumulated_context),
            "participants": participant_names,
        }

        if iteration_index == 1:
            # `or` (not a .get default) so explicit None/empty values also fall
            # back instead of becoming the literal string "None".
            document_with_context["doc_metadata"] = {
                "title": _to_valid_utf8(record.get("title") or "Unknown Title"),
                "summary": _to_valid_utf8(record.get("summary") or "No summary available."),
            }

        generate_queries_results = yield ctx.call_activity(
            generate_search_queries,
            input={
                "context": document_with_context["text"],
                # Orchestrator code must be deterministic across replays, so use
                # the orchestration clock instead of datetime.now() (this is UTC).
                "date_time": ctx.current_utc_datetime.strftime("%B %d, %Y"),
            },
            retry_policy=CLASSIFY_RETRY_POLICY,
        )
        search_queries = (generate_queries_results or {}).get("queries") or []

        search_results: list[dict[str, Any]] = []
        if search_queries:
            # Only fan out when there is something to wait for: a when_all over
            # zero tasks may never complete and would stall the orchestration.
            raw_results = yield wf.when_all(
                [
                    ctx.call_activity(
                        web_search,
                        input={"query": query["query"]},
                        retry_policy=CLASSIFY_RETRY_POLICY,
                    )
                    for query in search_queries
                ]
            )
            # Ignore activities that returned nothing instead of failing on .get().
            search_results = [res for res in raw_results if res]

        search_context = _format_search_context(search_results)
        document_with_context["search_context"] = search_context
        if not ctx.is_replaying:
            if search_context:
                logger.info(
                    "Web search context appended for %s page %s",
                    latest_id,
                    iteration_index,
                )
            else:
                logger.info(
                    "No web search findings for %s page %s",
                    latest_id,
                    iteration_index,
                )
        aggregated_sources.extend(_collect_search_sources(search_results))

        generated_prompt = yield ctx.call_activity(
            generate_prompt,
            input=document_with_context,
            retry_policy=CLASSIFY_RETRY_POLICY,
        )
        dialogue_entry = yield ctx.call_activity(
            generate_transcript,
            input={
                "podcast_name": podcast_name,
                "host_name": host_name,
                "prompt": generated_prompt,
                "max_rounds": dialogue_max_rounds,
            },
            retry_policy=CLASSIFY_RETRY_POLICY,
        )
        # Read "participants" the same tolerant way for both logging and
        # accumulation. A missing key or None result must not raise KeyError.
        dialogue_lines = (dialogue_entry or {}).get("participants") or []
        if not ctx.is_replaying:
            if dialogue_lines:
                logger.info(
                    "Generated %s dialogue entries for %s page %s",
                    len(dialogue_lines),
                    latest_id,
                    iteration_index,
                )
            else:
                logger.warning(
                    "No dialogue entries generated for %s page %s",
                    latest_id,
                    iteration_index,
                )

        accumulated_context = "".join(
            f" {line['name']}: {line['text']}" for line in dialogue_lines
        )
        transcript_parts.extend(dialogue_lines)

    transcript_result = yield ctx.call_activity(
        write_transcript_to_file,
        input={
            "podcast_dialogue": transcript_parts,
            "transcripts_directory": transcripts_directory,
            "file_name": f"{latest_id}.json",
            "storage_prefix": transcripts_storage_prefix,
            "persist_locally": persist_transcripts_locally,
        },
        retry_policy=CLASSIFY_RETRY_POLICY,
    )
    if not ctx.is_replaying:
        logger.info("Transcript persisted for %s", latest_id)

    return {
        "article_id": latest_id,
        "file_path": transcript_result["file_path"],
        "storage_key": transcript_result["storage_key"],
        "markdowns_directory": markdowns_directory,
        "markdowns_storage_prefix": markdowns_storage_prefix,
        "search_sources": aggregated_sources,
    }
