import json
import logging
from datetime import timedelta
from pathlib import Path

import dapr.ext.workflow as wf
from dapr.ext.workflow import DaprWorkflowContext

from deepbrief.activities.create_episode import generate_episode_metadata, write_episode_to_file
from deepbrief.services import get_object

logger = logging.getLogger(__name__)

# Shared retry policy: every activity scheduled by generate_episodes_workflow
# (metadata generation and episode writing) uses this same configuration.
CLASSIFY_RETRY_POLICY = wf.RetryPolicy(
    max_number_of_attempts=6,  # at most 6 attempts in total, counting the first
    first_retry_interval=timedelta(seconds=1),  # delay before the first retry
    backoff_coefficient=2.0,  # each subsequent delay grows by this factor
    max_retry_interval=timedelta(seconds=60),  # upper bound on any single delay
    retry_timeout=timedelta(minutes=5),  # give up once this overall window elapses
)


def generate_episodes_workflow(ctx: DaprWorkflowContext, input: dict):
    """
    Orchestrator that generates podcast episode overviews from transcripts.

    Transcripts are processed in batches. For each batch, one
    ``generate_episode_metadata`` activity is scheduled per transcript (fan-out).
    The workflow waits for all of them (fan-in) and passes the collected results
    to a single ``write_episode_to_file`` activity.

    Args:
        ctx: Workflow context used to schedule activities and check replay state.
        input: Workflow input; ``None`` is treated as an empty dict. Optional
            keys, with defaults in parentheses:
            ``transcripts_metadata`` ([]), ``episode_batch_size`` (3, clamped to
            at least 1), ``episodes_directory`` ("output/episodes"),
            ``episodes_storage_prefix`` ("episodes"),
            ``persist_episodes_locally`` (True), ``recordings_metadata`` ([]),
            ``papers_metadata`` ([]).

    Returns:
        ``{"episodes_metadata": [...]}`` holding the outputs of every batch's
        ``write_episode_to_file`` call, in batch order. The list is empty when
        there are no transcripts.
    """
    input = input or {}
    transcripts_metadata = input.get("transcripts_metadata") or []
    raw_batch_size = input.get("episode_batch_size")
    # An explicit null falls back to the default instead of crashing int(None).
    episode_batch_size = max(1, int(3 if raw_batch_size is None else raw_batch_size))
    episodes_directory = input.get("episodes_directory", "output/episodes")
    episodes_storage_prefix = input.get("episodes_storage_prefix", "episodes")
    persist_episodes_locally = bool(input.get("persist_episodes_locally", True))
    recordings_metadata = input.get("recordings_metadata") or []
    papers_metadata = input.get("papers_metadata") or []

    # Index paper records by "latest_id"; records without one are skipped.
    papers_metadata_lookup = {
        rec.get("latest_id"): rec for rec in papers_metadata if rec.get("latest_id")
    }

    # Ceiling division: number of batches needed to cover every transcript.
    total_episode_batches = (
        len(transcripts_metadata) + episode_batch_size - 1
    ) // episode_batch_size

    # Collected across all batches so earlier batches' episodes are not dropped.
    # It is initialised up front so an empty transcript list still yields a result.
    episodes_metadata: list = []

    if not ctx.is_replaying:
        logger.info(
            "Generating episode overviews in %s batches of up to %s transcripts each.",
            total_episode_batches,
            episode_batch_size,
        )

    # Process the transcripts in consecutive batches.
    for batch_start in range(0, len(transcripts_metadata), episode_batch_size):
        batch_index = batch_start // episode_batch_size + 1
        current_batch = transcripts_metadata[batch_start : batch_start + episode_batch_size]

        if not ctx.is_replaying:
            logger.info(
                "Processing overview batch %s/%s (%s transcripts)",
                batch_index,
                total_episode_batches,
                len(current_batch),
            )

        # Fan-out: one metadata-generation activity per transcript in the batch.
        # NOTE(review): _load_transcript_text reads local files / object storage
        # here in the orchestrator. That means it re-runs on every replay, and the
        # full transcript text is recorded in workflow history as activity input.
        # Dapr expects orchestrator code to be deterministic. Consider moving the
        # load into the activity by passing the record instead of the text.
        overview_generation_tasks = [
            ctx.call_activity(
                generate_episode_metadata,
                input={
                    "transcript": _load_transcript_text(transcript_record),
                    "paper_id": Path(transcript_record["file_path"]).stem,
                },
                retry_policy=CLASSIFY_RETRY_POLICY,
            )
            for transcript_record in current_batch
        ]

        # Fan-in: wait for every overview in this batch to complete.
        results = yield wf.when_all(overview_generation_tasks)

        # Create the episode files for this batch.
        # NOTE(review): the full transcripts_metadata (not only current_batch) is
        # passed, presumably as a lookup index. Confirm against the activity.
        batch_outputs = yield ctx.call_activity(
            write_episode_to_file,
            input={
                "article_index": papers_metadata_lookup,
                "results": results,
                "episodes_directory": episodes_directory,
                "episodes_storage_prefix": episodes_storage_prefix,
                "persist_locally": persist_episodes_locally,
                "recordings_metadata": recordings_metadata,
                "transcripts_metadata": transcripts_metadata,
            },
            retry_policy=CLASSIFY_RETRY_POLICY,
        )
        # Assumes the activity returns a list (or None) of episode metadata records.
        episodes_metadata.extend(batch_outputs or [])

        if not ctx.is_replaying:
            logger.info("Overview batch %s completed successfully.", batch_index)

    if not ctx.is_replaying:
        logger.info("All podcast episode overviews have been successfully generated.")
    return {"episodes_metadata": episodes_metadata}


def _load_transcript_text(transcript_record: dict) -> str:
    file_path = transcript_record.get("file_path")
    transcript_data = None

    if file_path and Path(file_path).exists():
        with open(file_path, encoding="utf-8") as file:
            transcript_data = json.load(file)
    else:
        storage_key = transcript_record.get("storage_key")
        if storage_key:
            raw = get_object(storage_key)
            transcript_data = json.loads(raw.decode("utf-8"))

    if not transcript_data:
        return ""

    if isinstance(transcript_data, dict):
        transcript_data = transcript_data.get("participants", [])

    return " ".join(part.get("text", "") for part in transcript_data if part.get("text"))
