from __future__ import annotations

import hashlib
import shutil
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence

from PIL import Image

from ..defaults import DEFAULT_GENERATOR
from ..generator import GenerationResult, build_generator
from .schemas import (
    CampaignConfig,
    CampaignRoute,
    DeterministicBatchSpec,
    ManifestRouteEntry,
    ManifestVariant,
    PlacementManifest,
    PlacementRef,
    ReviewState,
    RouteSource,
)
from .workspace import CampaignWorkspace

# Provider identifiers certified for deterministic seeding. Each entry is
# treated as a case-insensitive prefix by enforce_deterministic_provider.
SUPPORTED_DETERMINISTIC_PROVIDERS = {
    "mock",
    "openrouter:gemini-2.5-flash-image-preview",
}


class DeterministicProviderError(RuntimeError):
    """Signal that a requested provider is not certified for deterministic seeding.

    Subclasses ``RuntimeError`` so generic runtime-error handlers still catch it.
    """


@dataclass(slots=True)
class VariantPlan:
    route: CampaignRoute
    placement: PlacementRef
    placement_id: str
    variant_index: int
    prompt: str
    provider: str
    seed: int
    output_path: Path
    thumbnail_path: Path
    manifest_file: Path
    provider_params: Dict[str, object] = field(default_factory=dict)


@dataclass(slots=True)
class GenerationStats:
    campaign_id: str
    generated: int = 0
    warnings: List[str] = field(default_factory=list)
    events: List[Dict[str, object]] = field(default_factory=list)

    def extend(self, warnings: Iterable[str]) -> None:
        for warning in warnings:
            if warning not in self.warnings:
                self.warnings.append(warning)


def enforce_deterministic_provider(provider: str) -> None:
    """Reject providers that are not certified for deterministic seeding.

    The comparison is case-insensitive and prefix-based: ``provider`` is
    accepted when it starts with any entry of ``SUPPORTED_DETERMINISTIC_PROVIDERS``.

    Raises:
        DeterministicProviderError: If no certified prefix matches.
    """
    # NOTE(review): prefix matching also admits names such as "mockingbird";
    # confirm whether a separator (e.g. ":") should be required after the prefix.
    lowered = provider.lower()
    if any(lowered.startswith(prefix.lower()) for prefix in SUPPORTED_DETERMINISTIC_PROVIDERS):
        return
    raise DeterministicProviderError(
        f"Provider '{provider}' is not certified for deterministic seeding."
    )


def compute_variant_seed(
    campaign_id: str,
    route_id: str,
    placement_id: str,
    variant_index: int,
    seed_base: Optional[int] = None,
) -> int:
    """Return a reproducible 32-bit seed for one variant.

    The identifiers and ``seed_base`` (any falsy value counts as 0) are joined
    with ":" and hashed with SHA-256. The first four digest bytes, read as a
    big-endian unsigned integer, form the seed (``0 <= seed < 2**32``).
    """
    parts = (campaign_id, route_id, placement_id, variant_index, seed_base or 0)
    key = ":".join(f"{part}" for part in parts)
    leading_bytes = hashlib.sha256(key.encode("utf-8")).digest()[:4]
    return int.from_bytes(leading_bytes, "big")


def ensure_parent(path: Path) -> None:
    """Create the directory that will contain ``path``, including missing ancestors.

    Already-existing directories are accepted without error.
    """
    parent_dir = path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)


def write_thumbnail(source: Path, destination: Path, size: int = 320) -> None:
    """Save a downscaled copy of ``source`` at ``destination``.

    Pillow's ``Image.thumbnail`` shrinks the image in place to fit inside a
    ``size`` x ``size`` box, keeping the aspect ratio. Pillow picks the output
    format from ``destination``'s extension. Missing parent directories are
    created first.
    """
    ensure_parent(destination)
    bounds = (size, size)
    with Image.open(source) as picture:
        picture.thumbnail(bounds)
        picture.save(destination)


def build_prompt(route: CampaignRoute, placement: PlacementRef) -> str:
    """Compose the generator prompt for a route/placement pair.

    The stripped route template comes first. When there are any route prompt
    tokens or placement copy tokens, a blank line follows, then each token on
    its own line (route tokens before placement tokens). Inputs are not mutated.
    """
    body = route.prompt_template.strip()
    extras = [*route.prompt_tokens, *placement.copy_tokens]
    if not extras:
        return body
    return body + "\n\n" + "\n".join(extras)


def plan_generation(
    workspace: CampaignWorkspace,
    config: CampaignConfig,
    routes: Sequence[str] | None,
    placements: Sequence[str] | None,
    variants_override: Optional[int],
    provider_override: Optional[str],
) -> List[VariantPlan]:
    """Expand a campaign config into one VariantPlan per output image.

    Args:
        workspace: Campaign workspace providing routes and output locations.
        config: Campaign configuration listing placements and defaults.
        routes: Route ids to plan for. ``None`` or empty plans every route
            yielded by ``workspace.iter_routes()``.
        placements: Optional placement filter, matched against the effective
            placement id (``override_id`` or ``template_id``) or the
            ``template_id``. ``None`` or empty keeps every placement.
        variants_override: Variant count used for every placement when truthy;
            otherwise the placement's own count, then the config default.
        provider_override: Replaces ``config.default_provider`` for placements
            that do not pin their own provider.

    Returns:
        Plans ordered by route, then placement, then variant index.

    Raises:
        RuntimeError: If a requested route id is not present in the workspace.
        DeterministicProviderError: If a resolved provider is not certified
            for deterministic seeding.
    """
    available_routes: Dict[str, CampaignRoute] = {route.route_id: route for route in workspace.iter_routes()}
    selected_routes = routes or list(available_routes)
    plans: List[VariantPlan] = []
    default_provider = provider_override or config.default_provider
    enforce_deterministic_provider(default_provider)

    for route_id in selected_routes:
        route = available_routes.get(route_id)
        if route is None:
            raise RuntimeError(f"Route '{route_id}' not found in campaign workspace")
        for placement in config.placements:
            placement_id = placement.override_id or placement.template_id
            if placements and placement_id not in placements and placement.template_id not in placements:
                continue
            variant_count = variants_override or placement.variants or config.variant_defaults.count
            provider = placement.provider or default_provider
            enforce_deterministic_provider(provider)
            # The values below are identical for every variant of this
            # (route, placement) pair, so compute them once.
            prompt = build_prompt(route, placement)
            image_dir = workspace.root / "images" / route.route_id / placement_id
            thumbnail_dir = workspace.thumbnails_dir / route.route_id / placement_id
            manifest_path = workspace.placement_manifest_path(placement_id)
            for index in range(int(variant_count)):
                file_name = f"v{index + 1:03d}.png"
                plans.append(
                    VariantPlan(
                        route=route,
                        placement=placement,
                        placement_id=placement_id,
                        variant_index=index,
                        prompt=prompt,
                        provider=provider,
                        seed=compute_variant_seed(config.campaign_id, route.route_id, placement_id, index),
                        output_path=image_dir / file_name,
                        thumbnail_path=thumbnail_dir / file_name,
                        manifest_file=manifest_path,
                        # Give each plan its own copy: the dict is handed to the
                        # generator as provider_options, and a shared instance
                        # would let one plan's mutation leak into its siblings.
                        provider_params=dict(config.variant_defaults.provider_params),
                    )
                )
    return plans


def plan_from_batch_spec(
    workspace: CampaignWorkspace,
    spec: DeterministicBatchSpec,
) -> List[VariantPlan]:
    """Expand a deterministic batch spec into one VariantPlan per output image.

    For every spec route, a manual ``CampaignRoute`` is built. Its summary is
    the first prompt line truncated to 120 characters, or the route id when the
    prompt is empty. The spec prompt is used verbatim as the plan prompt.

    Each spec placement inherits copy tokens and (if it has none of its own) a
    provider from the configured placement with the same effective id, when one
    exists. The final provider falls back to ``spec.provider`` and then to
    ``config.default_provider``.

    Raises:
        DeterministicProviderError: If a resolved provider is not certified
            for deterministic seeding.
    """
    config = workspace.load_config()
    # Effective placement id -> configured placement; later duplicates win.
    placement_lookup: Dict[str, PlacementRef] = {
        (configured.override_id or configured.template_id): configured
        for configured in config.placements
    }

    plans: List[VariantPlan] = []
    for spec_route in spec.routes:
        route_id = spec_route.route_id
        prompt_text = spec_route.prompt
        if prompt_text:
            summary = prompt_text.splitlines()[0][:120]
        else:
            summary = route_id
        campaign_route = CampaignRoute(
            route_id=route_id,
            name=route_id.replace("_", " ").title(),
            source=RouteSource.MANUAL,
            summary=summary,
            prompt_template=prompt_text,
            prompt_tokens=spec_route.copy_tokens,
            copy_tokens=spec_route.copy_tokens,
        )
        # NOTE(review): seed_offset + index is passed as the *variant index*,
        # so seed streams overlap across seed_base values (seed_base=1/v0 gets
        # the same seed as seed_base=0/v1). compute_variant_seed's seed_base
        # argument would give disjoint streams but would change existing
        # seeds; confirm which behavior is intended.
        seed_offset = spec_route.seed_base or 0

        for spec_placement in spec.placements:
            placement_id = spec_placement.placement_id
            variant_total = spec_placement.variants or spec.variants_per_placement
            configured = placement_lookup.get(placement_id)
            if configured:
                inherited_tokens = configured.copy_tokens
                inherited_provider = configured.provider
            else:
                inherited_tokens = []
                inherited_provider = None
            reference = PlacementRef(
                template_id=spec_placement.template_id,
                override_id=placement_id,
                variants=variant_total,
                copy_tokens=inherited_tokens,
                provider=spec_placement.provider or inherited_provider,
            )
            provider = reference.provider or spec.provider or config.default_provider
            enforce_deterministic_provider(provider)

            for index in range(int(variant_total)):
                seed = compute_variant_seed(spec.campaign_id, route_id, placement_id, seed_offset + index)
                file_name = f"v{index + 1:03d}.png"
                output_path = workspace.root / "images" / route_id / placement_id / file_name
                thumb_path = workspace.thumbnails_dir / route_id / placement_id / file_name
                manifest_path = workspace.placement_manifest_path(placement_id)
                plans.append(
                    VariantPlan(
                        route=campaign_route,
                        placement=reference,
                        placement_id=placement_id,
                        variant_index=index,
                        prompt=prompt_text,
                        provider=provider,
                        seed=seed,
                        output_path=output_path,
                        thumbnail_path=thumb_path,
                        manifest_file=manifest_path,
                        provider_params=dict(spec.provider_params),
                    )
                )
    return plans


def execute_generation(
    workspace: CampaignWorkspace,
    plans: Sequence[VariantPlan],
    generator_kind: Optional[str],
) -> GenerationStats:
    """Render each plan, move the image into place and record it in its manifest.

    Plans are processed in order. For each plan, one image is requested from
    the generator and moved to ``plan.output_path``. Its thumbnail is written,
    the placement manifest is upserted, and a "succeeded" event is appended to
    the returned stats.

    Args:
        workspace: Campaign workspace that receives images and manifests.
        plans: Variant plans to execute.
        generator_kind: Generator backend name; ``DEFAULT_GENERATOR`` when falsy.

    Returns:
        Stats with the generated count, de-duplicated warnings and one event
        per successful plan.

    Raises:
        RuntimeError: If the generator returns no images for a plan.
    """
    stats = GenerationStats(campaign_id=workspace.campaign_id)
    if not plans:
        return stats

    temp_dir = workspace.root / ".tmp" / datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
    try:
        generator = build_generator(temp_dir, generator_kind or DEFAULT_GENERATOR)
        for plan in plans:
            result = generator.generate(
                plan.prompt,
                1,
                seed=plan.seed,
                provider=plan.provider,
                provider_options=plan.provider_params,
            )
            stats.extend(result.warnings)
            if not result.images:
                raise RuntimeError("Generator returned no images")
            artifact = result.images[0]
            ensure_parent(plan.output_path)
            ensure_parent(plan.thumbnail_path)
            # Path.replace overwrites an existing destination on POSIX and
            # Windows alike (os.replace semantics), so no prior unlink is needed
            # and there is no window in which the output file is missing.
            artifact.processed_path.replace(plan.output_path)
            write_thumbnail(plan.output_path, plan.thumbnail_path)
            _upsert_manifest(workspace, plan, plan.output_path)
            stats.events.append(
                {
                    # The format has no sub-second field, so microseconds never appear.
                    "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
                    "campaign_id": workspace.campaign_id,
                    "route_id": plan.route.route_id,
                    "placement_id": plan.placement_id,
                    "variant_index": plan.variant_index,
                    "status": "succeeded",
                    "provider": plan.provider,
                    "prompt": plan.prompt,
                    "seed": plan.seed,
                }
            )
            stats.generated += 1
    finally:
        # Remove scratch output even when a plan fails, so .tmp/ does not
        # accumulate one orphaned directory per failed run.
        shutil.rmtree(temp_dir, ignore_errors=True)
    return stats


def _upsert_manifest(workspace: CampaignWorkspace, plan: VariantPlan, image_path: Path) -> None:
    """Insert or replace the manifest record for ``plan``'s variant and save it.

    Steps:
    - Load the placement manifest, or start an empty one when
      ``plan.manifest_file`` does not exist yet.
    - Find or create the entry for the plan's route.
    - Replace the variant with the same index, or append the new record and
      re-sort the route's variants by index.

    The written record always has ``ReviewState.PENDING``.

    Args:
        workspace: Campaign workspace used to load and save manifests.
        plan: Executed plan whose variant is being recorded.
        image_path: Final image location; stored relative to ``workspace.root``.
    """
    # One timestamp per upsert keeps created_at and updated_at consistent.
    now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    if plan.manifest_file.exists():
        manifest = workspace.load_manifest(plan.placement_id)
    else:
        manifest = PlacementManifest(
            campaign_id=workspace.campaign_id,
            placement_id=plan.placement_id,
            template_id=plan.placement.template_id,
            routes=[],
            updated_at=now,
        )

    route_entry = next(
        (entry for entry in manifest.routes if entry.route_id == plan.route.route_id),
        None,
    )
    if route_entry is None:
        route_entry = ManifestRouteEntry(
            route_id=plan.route.route_id,
            summary=plan.route.summary,
            status=plan.route.status,
            variants=[],
        )
        manifest.routes.append(route_entry)

    variant_id = f"{plan.route.route_id}-{plan.placement_id}-v{plan.variant_index + 1:03d}"
    record = ManifestVariant(
        variant_id=variant_id,
        index=plan.variant_index,
        file=str(image_path.relative_to(workspace.root)),
        thumbnail=str(plan.thumbnail_path.relative_to(workspace.root)),
        provider=plan.provider,
        prompt=plan.prompt,
        seed=plan.seed,
        # NOTE(review): plan.provider_params is not persisted here; confirm
        # whether `params` should record it so a variant can be regenerated
        # from the manifest alone.
        params={},
        review_state=ReviewState.PENDING,
        artifacts=[],
        created_at=now,
    )

    # Replace an existing variant with the same index; otherwise append and
    # keep the list ordered by index.
    variants = route_entry.variants
    for position, existing in enumerate(variants):
        if existing.index == record.index:
            variants[position] = record
            break
    else:
        variants.append(record)
        variants.sort(key=lambda item: item.index)

    manifest.updated_at = now
    workspace.save_manifest(manifest)
