import json
import logging
from datetime import timedelta
from pathlib import Path
from typing import Any

import dapr.ext.workflow as wf
from dapr.ext.workflow import DaprWorkflowContext

from deepbrief.activities.speak import convert_transcript_to_audio
from deepbrief.services import get_object

logger = logging.getLogger(__name__)

# Retry settings attached to the activity calls scheduled from this module.
CLASSIFY_RETRY_POLICY = wf.RetryPolicy(
    max_number_of_attempts=6,  # upper bound on attempts per activity call
    first_retry_interval=timedelta(seconds=1),  # delay before the first retry
    backoff_coefficient=2.0,  # multiplier applied to successive delays
    max_retry_interval=timedelta(seconds=60),  # ceiling for any single delay
    retry_timeout=timedelta(minutes=5),  # total time budget for retrying
)


def _recording_file_name(transcript_record: dict[str, Any]) -> str:
    """Return the base file name to use for a transcript's recording.

    Uses the stem of the record's ``file_path`` when present, otherwise its
    ``article_id``, otherwise the literal ``"transcript"``.
    """
    transcript_path = transcript_record.get("file_path")
    if transcript_path:
        return Path(transcript_path).stem
    return transcript_record.get("article_id") or "transcript"


def generate_recordings_workflow(ctx: DaprWorkflowContext, input: dict):
    """
    Orchestrator that generates audio recordings from transcripts.

    Transcripts are processed in consecutive batches; every transcript in a
    batch is scheduled as a ``convert_transcript_to_audio`` activity and the
    batch is awaited with ``wf.when_all`` before the next one starts.

    Args:
        ctx: Workflow context used to schedule activities and to detect replay.
        input: Workflow parameters (``None`` is treated as an empty dict).
            Recognised keys and their defaults:

            - ``transcripts_metadata``: list of transcript records (``[]``).
            - ``recordings_batch_size``: transcripts per batch (``3``; values
              below 1 are clamped to 1, ``None`` falls back to the default).
            - ``recordings_directory``: ``"output/recordings"``.
            - ``host_config`` / ``participants_config``: ``{}``.
            - ``audio_model``: ``"eleven_flash_v2_5"``.
            - ``recordings_storage_prefix``: ``"recordings"``.
            - ``persist_recordings_locally``: ``False``.

    Yields:
        One ``wf.when_all`` task per batch.

    Returns:
        ``{"recordings_metadata": [...]}`` with one entry per transcript whose
        activity returned a truthy result.
    """
    params = input or {}

    transcripts_metadata = params.get("transcripts_metadata") or []
    # An explicit ``None`` (e.g. from a serialized optional field) must fall
    # back to the default instead of crashing in ``int(None)``.
    raw_batch_size = params.get("recordings_batch_size")
    recordings_batch_size = 3 if raw_batch_size is None else max(1, int(raw_batch_size))
    recordings_directory = params.get("recordings_directory", "output/recordings")
    host_config = params.get("host_config", {})
    participants_config = params.get("participants_config", {})
    audio_model = params.get("audio_model", "eleven_flash_v2_5")
    recordings_storage_prefix = params.get("recordings_storage_prefix", "recordings")
    persist_recordings_locally = bool(params.get("persist_recordings_locally", False))

    # Ceiling division; reported as at least 1 even when there is nothing to do.
    total_recording_batches = max(
        1, (len(transcripts_metadata) + recordings_batch_size - 1) // recordings_batch_size
    )
    recordings_metadata: list[dict[str, Any]] = []

    # Logging is guarded by ``is_replaying`` so replays do not duplicate messages.
    if not ctx.is_replaying:
        logger.info(
            "Generating recordings in %s batches of up to %s transcripts each.",
            total_recording_batches,
            recordings_batch_size,
        )

    for batch_start in range(0, len(transcripts_metadata), recordings_batch_size):
        batch_index = batch_start // recordings_batch_size + 1
        current_batch = transcripts_metadata[batch_start : batch_start + recordings_batch_size]

        if not ctx.is_replaying:
            logger.info(
                "Processing audio batch %s/%s (%s transcripts)",
                batch_index,
                total_recording_batches,
                len(current_batch),
            )

        # NOTE(review): _load_transcript_parts performs file/storage reads inside
        # the orchestrator and runs again on every replay. Workflow code is
        # expected to be deterministic; consider moving the load into the
        # activity (pass the record instead of its parts).
        audio_conversion_tasks = [
            ctx.call_activity(
                convert_transcript_to_audio,
                input={
                    "transcript_parts": _load_transcript_parts(transcript_record),
                    "recordings_directory": recordings_directory,
                    "file_name": _recording_file_name(transcript_record),
                    "host_config": host_config,
                    "participants_config": participants_config,
                    "model": audio_model,
                    "storage_prefix": recordings_storage_prefix,
                    # NOTE(review): "persiste_locally" looks like a typo for
                    # "persist_locally"; kept as-is because the activity's
                    # expected key is not visible here - confirm and align.
                    "persiste_locally": persist_recordings_locally,
                },
                retry_policy=CLASSIFY_RETRY_POLICY,
            )
            for transcript_record in current_batch
        ]

        results = yield wf.when_all(audio_conversion_tasks)
        for transcript_record, audio_result in zip(current_batch, results, strict=False):
            if not audio_result:
                # Surface dropped transcripts instead of skipping them silently.
                if not ctx.is_replaying:
                    logger.warning(
                        "No audio result returned for transcript %s; skipping.",
                        transcript_record.get("article_id")
                        or transcript_record.get("file_path"),
                    )
                continue
            recordings_metadata.append(
                {
                    "article_id": transcript_record.get("article_id"),
                    "recording_file_path": audio_result.get("file_path"),
                    "recording_storage_key": audio_result.get("storage_key"),
                    "transcript_file_path": transcript_record.get("file_path"),
                    "transcript_storage_key": transcript_record.get("storage_key"),
                }
            )

        if not ctx.is_replaying:
            logger.info("Audio batch %s completed successfully.", batch_index)

    if not ctx.is_replaying:
        logger.info("All audio files have been successfully generated.")

    return {"recordings_metadata": recordings_metadata}


def _load_transcript_parts(record: dict[str, Any]) -> list[dict[str, Any]]:
    """Load the parsed JSON transcript referenced by *record*.

    The local ``file_path`` is tried first. If it is absent or the file does
    not exist, the transcript is fetched via ``get_object`` using the record's
    ``storage_key``.

    Args:
        record: Transcript metadata; may carry ``file_path`` and/or
            ``storage_key``.

    Returns:
        The decoded JSON content of the transcript.

    Raises:
        RuntimeError: If no transcript could be located from either source.
    """
    file_path = record.get("file_path")
    if file_path:
        # Open directly rather than checking exists() first, so the file cannot
        # disappear between the check and the read.
        try:
            with Path(file_path).open("r", encoding="utf-8") as file:
                return json.load(file)
        except (FileNotFoundError, NotADirectoryError):
            pass  # not available locally; fall back to object storage

    storage_key = record.get("storage_key")
    if storage_key:
        data = get_object(storage_key)
        return json.loads(data.decode("utf-8"))

    if file_path:
        # Distinguish "path given but missing on disk" from "nothing given".
        raise RuntimeError(
            f"Transcript file '{file_path}' not found and record has no storage_key"
        )
    raise RuntimeError("Transcript record missing both file_path and storage_key")
