"""
Distributed trainer with FSDP (FullyShardedDataParallel wrapper API), DeepSpeed ZeRO, DDP/DataParallel, and torch.compile support
"""
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
import functools
from typing import Optional, Dict, Any
from pathlib import Path
import logging
import datetime
from tqdm import tqdm
import wandb

logger = logging.getLogger(__name__)


class DistributedTrainer:
    """
    Trainer for the diffusion language model supporting several execution modes.

    The mode is chosen once in ``_init_distributed``:

    * launched with ``RANK``/``WORLD_SIZE`` set (e.g. ``torchrun``): a process
      group is joined and the model is wrapped with DeepSpeed ZeRO
      (``train_config.use_deepspeed``), FSDP (``train_config.use_fsdp``) or DDP;
    * several visible GPUs but no launcher environment: single-process
      ``torch.nn.DataParallel``;
    * otherwise: plain single-device training on ``cuda:0`` or CPU.

    DeepSpeed/FSDP settings in ``train_config`` only take effect in the
    distributed mode. ``_deepspeed_enabled`` / ``_fsdp_enabled`` record what is
    actually in use, and every other method branches on those flags rather
    than on the raw config values.
    """
    def __init__(
        self,
        model: torch.nn.Module,
        train_config,
        model_config,
        train_dataloader,
        eval_dataloader=None,
    ):
        """Build the full training stack (device, wrapper, optimizer, ...).

        Args:
            model: The unwrapped model. It is moved to the selected device and
                possibly wrapped, so ``self.model`` may differ from it.
            train_config: Training hyper-parameters read via attribute access
                (learning rate, step counts, parallelism flags, ...).
            model_config: Model hyper-parameters; ``vocab_size``,
                ``use_energy_model`` and ``num_diffusion_steps`` are read.
            train_dataloader: Iterable of dict batches containing an
                ``input_ids`` tensor and optionally ``attention_mask``.
            eval_dataloader: Optional iterable with the same batch format.
        """
        logger.info("Initializing DistributedTrainer...")
        self.model = model
        self.train_config = train_config
        self.model_config = model_config
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.use_data_parallel = False  # Set by _init_distributed when DataParallel is chosen
        # Counters exist before any setup step so helpers can always read them.
        self.global_step = 0
        self.epoch = 0

        # Initialize distributed
        logger.info("Step 1/7: Initializing distributed environment...")
        self._init_distributed()
        logger.info("✓ Distributed environment initialized")

        # Setup model for distributed training
        logger.info("Step 2/7: Setting up distributed model...")
        self._setup_distributed_model()
        logger.info("✓ Distributed model setup complete")

        # Setup optimizer and scheduler
        logger.info("Step 3/7: Setting up optimizer...")
        self._setup_optimizer()
        logger.info("✓ Optimizer created")

        logger.info("Step 4/7: Setting up learning rate scheduler...")
        self._setup_scheduler()
        logger.info("✓ Scheduler created")

        # Setup mixed precision
        logger.info("Step 5/7: Setting up mixed precision...")
        self._setup_mixed_precision()
        logger.info("✓ Mixed precision configured")

        # Setup loss function
        logger.info("Step 6/7: Setting up loss function...")
        self._setup_loss()
        logger.info("✓ Loss function ready")

        # Compile model if requested (disabled by default for CUDA compatibility)
        if getattr(train_config, 'use_torch_compile', False):
            logger.info("Compiling model with torch.compile...")
            self._compile_model()
            logger.info("✓ Model compilation complete")

        # Setup logging
        logger.info("Step 7/7: Setting up training logs...")
        self._setup_logging()
        logger.info("✓ Logging configured")

        logger.info("="*60)
        logger.info("✅ DistributedTrainer initialization complete!")
        logger.info(f"Training on: {self.device}")
        logger.info(f"World size: {self.world_size}")
        logger.info(f"Using DataParallel: {self.use_data_parallel}")
        logger.info("="*60)

    # ------------------------------------------------------------------ setup

    def _init_distributed(self):
        """Pick the execution mode and set the rank/device attributes.

        Sets ``is_distributed``, ``use_data_parallel``, ``rank``,
        ``world_size``, ``local_rank``, ``device`` and the effective backend
        flags ``_deepspeed_enabled`` / ``_fsdp_enabled``.

        If ``init_process_group`` fails, training falls back to a single device.
        The process-group timeout also bounds every later collective, not just
        initialization. It is therefore read from
        ``train_config.distributed_timeout_seconds`` (default 1800 s) instead
        of being a short hard-coded value. Otherwise a rank that spends a while
        writing a checkpoint or compiling would make its peers abort.
        """
        # Configure logging if not already configured
        if not logging.getLogger().handlers:
            logging.basicConfig(
                format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                datefmt='%m/%d/%Y %H:%M:%S',
                level=logging.INFO
            )

        # Check for explicit distributed environment (e.g., torchrun)
        has_explicit_env = "RANK" in os.environ and "WORLD_SIZE" in os.environ

        # Check for multi-GPU but no distributed env (Kaggle/Colab scenario)
        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        has_multi_gpu = num_gpus > 1 and not has_explicit_env

        self.is_distributed = False
        if has_explicit_env:
            # Explicit distributed environment (torchrun, etc.)
            self.is_distributed = True
            if not dist.is_initialized():
                timeout_s = getattr(self.train_config, "distributed_timeout_seconds", 1800)
                try:
                    dist.init_process_group(
                        backend=self.train_config.distributed_backend,
                        timeout=datetime.timedelta(seconds=timeout_s),
                    )
                except Exception as e:
                    logger.warning(f"Failed to initialize distributed training: {e}")
                    logger.info("Falling back to single-device training")
                    self.is_distributed = False
        elif has_multi_gpu:
            # Use DataParallel for Kaggle/Colab multi-GPU (single process)
            logger.info(f"🎮 Detected {num_gpus} GPUs. Using DataParallel for multi-GPU training...")
            logger.info("💡 For true distributed training, use: torchrun --nproc_per_node=N script.py")
            # DataParallel is single-process, so no process group is needed.
            self.use_data_parallel = True

        if self.is_distributed:
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
            self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
            if torch.cuda.is_available():
                torch.cuda.set_device(self.local_rank)
                self.device = torch.device(f"cuda:{self.local_rank}")
            else:
                # e.g. the gloo backend on a CPU-only host
                self.device = torch.device("cpu")
            if self.rank == 0:
                logger.info(f"✅ Initialized distributed training with {self.world_size} GPUs")
        else:
            # Single device training
            self.rank = 0
            self.world_size = 1
            self.local_rank = 0
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")
                torch.cuda.set_device(0)
            else:
                self.device = torch.device("cpu")
            logger.info(f"✅ Initialized single-device training on {self.device}")

        # DeepSpeed / FSDP are only set up inside a process group, so record
        # what will really be used (DeepSpeed takes precedence, as in
        # _setup_distributed_model).
        wants_deepspeed = bool(getattr(self.train_config, "use_deepspeed", False))
        wants_fsdp = bool(getattr(self.train_config, "use_fsdp", False))
        self._deepspeed_enabled = self.is_distributed and wants_deepspeed
        self._fsdp_enabled = self.is_distributed and wants_fsdp and not wants_deepspeed
        if (wants_deepspeed or wants_fsdp) and not self.is_distributed:
            logger.warning("use_deepspeed/use_fsdp ignored: not running in a distributed process group")

    def _setup_distributed_model(self):
        """Move the model to ``self.device`` and wrap it for the active mode.

        Depending on the mode, the wrapper is DataParallel, DeepSpeed, FSDP or
        DDP. No wrapper is used for a single device.
        """
        logger.info("Setting up model on devices...")
        self.model = self.model.to(self.device)

        # Check if we should use DataParallel
        if self.use_data_parallel and torch.cuda.device_count() > 1:
            logger.info(f"Wrapping model with DataParallel for {torch.cuda.device_count()} GPUs...")
            self.model = torch.nn.DataParallel(self.model)
            logger.info("✓ DataParallel initialized - model will use all available GPUs")
            return

        if not self.is_distributed:
            # Single device - no distributed wrapper needed
            logger.info("✓ Model ready on single device")
            return

        if self._deepspeed_enabled:
            logger.info("Initializing DeepSpeed ZeRO...")
            self._setup_deepspeed()
            logger.info("✓ DeepSpeed initialized")
        elif self._fsdp_enabled:
            logger.info("Initializing FSDP...")
            self._setup_fsdp()
            logger.info("✓ FSDP initialized")
        else:
            # Standard DDP
            logger.info("Initializing DDP...")
            ddp_kwargs = {}
            if self.device.type == "cuda":
                # device_ids must be omitted for CPU modules
                ddp_kwargs = {"device_ids": [self.local_rank], "output_device": self.local_rank}
            self.model = torch.nn.parallel.DistributedDataParallel(self.model, **ddp_kwargs)
            logger.info("✓ DDP initialized")

    def _setup_fsdp(self):
        """Wrap ``self.model`` with FullyShardedDataParallel (FSDP1 wrapper API).

        Each ``DiffusionTransformerBlock`` becomes its own FSDP unit. The
        configured mixed-precision, sharding and prefetch options are applied,
        and activation checkpointing is optionally added afterwards.

        Raises:
            ValueError: if ``train_config.fsdp_sharding_strategy`` is unknown.
        """
        # Define auto wrap policy for transformer blocks
        from parallel_llm.core.diffusion_transformer import DiffusionTransformerBlock

        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={DiffusionTransformerBlock},
        )

        # Mixed precision configuration (any other value keeps full precision)
        mp_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(
            self.train_config.mixed_precision
        )
        mp_policy = None
        if mp_dtype is not None:
            mp_policy = MixedPrecision(
                param_dtype=mp_dtype,
                reduce_dtype=mp_dtype,
                buffer_dtype=mp_dtype,
            )

        # Sharding strategy
        sharding_strategy_map = {
            "full": ShardingStrategy.FULL_SHARD,
            "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
            "no_shard": ShardingStrategy.NO_SHARD,
        }
        strategy_name = self.train_config.fsdp_sharding_strategy
        if strategy_name not in sharding_strategy_map:
            raise ValueError(
                f"Unknown fsdp_sharding_strategy {strategy_name!r}; "
                f"expected one of {sorted(sharding_strategy_map)}"
            )
        sharding_strategy = sharding_strategy_map[strategy_name]

        # Wrap model with FSDP
        self.model = FSDP(
            self.model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=mp_policy,
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE if self.train_config.fsdp_backward_prefetch else None,
            forward_prefetch=self.train_config.fsdp_forward_prefetch,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=True,
            # Keep the original (named) parameters visible instead of flattened
            # FlatParameters. The name-based weight-decay split in
            # _setup_optimizer and torch.compile both depend on this.
            use_orig_params=True,
        )

        # Apply activation checkpointing
        if self.train_config.gradient_checkpointing:
            check_fn = lambda submodule: isinstance(submodule, DiffusionTransformerBlock)
            non_reentrant_wrapper = functools.partial(
                checkpoint_wrapper,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            apply_activation_checkpointing(
                self.model,
                checkpoint_wrapper_fn=non_reentrant_wrapper,
                check_fn=check_fn,
            )

        if self.rank == 0:
            logger.info(f"FSDP initialized with {sharding_strategy}")

    def _setup_deepspeed(self):
        """Wrap the model with a DeepSpeed ZeRO engine.

        The DeepSpeed engine owns the optimizer and LR scheduler, so
        ``self.model``, ``self.optimizer`` and ``self.scheduler`` are all
        replaced here.
        """
        import deepspeed

        ds_config = {
            "train_batch_size": self.train_config.batch_size * self.world_size,
            "train_micro_batch_size_per_gpu": self.train_config.batch_size,
            "gradient_accumulation_steps": 1,
            "optimizer": {
                "type": "AdamW",
                "params": {
                    "lr": self.train_config.learning_rate,
                    "betas": [self.train_config.adam_beta1, self.train_config.adam_beta2],
                    "eps": self.train_config.adam_epsilon,
                    "weight_decay": self.train_config.weight_decay,
                }
            },
            "scheduler": {
                "type": "WarmupDecayLR",
                "params": {
                    "total_num_steps": self.train_config.num_train_steps,
                    "warmup_num_steps": self.train_config.warmup_steps,
                    # The scheduler overrides the optimizer lr; without these
                    # it would warm up to DeepSpeed's default peak of 1e-3.
                    "warmup_min_lr": 0.0,
                    "warmup_max_lr": self.train_config.learning_rate,
                }
            },
            "fp16": {"enabled": self.train_config.mixed_precision == "fp16"},
            "bf16": {"enabled": self.train_config.mixed_precision == "bf16"},
            "zero_optimization": {
                "stage": self.train_config.zero_stage,
                "offload_optimizer": {
                    "device": "cpu" if self.train_config.zero_offload_optimizer else "none",
                    "pin_memory": True,
                },
                "offload_param": {
                    "device": "cpu" if self.train_config.zero_offload_params else "none",
                    "pin_memory": True,
                },
                "overlap_comm": True,
                "contiguous_gradients": True,
                "reduce_bucket_size": 5e8,
                "stage3_prefetch_bucket_size": 5e8,
                "stage3_param_persistence_threshold": 1e6,
            },
            "gradient_clipping": self.train_config.max_grad_norm,
            "wall_clock_breakdown": False,
        }

        self.model, self.optimizer, _, self.scheduler = deepspeed.initialize(
            model=self.model,
            model_parameters=[p for p in self.model.parameters() if p.requires_grad],
            config=ds_config,
        )

        if self.rank == 0:
            logger.info(f"DeepSpeed ZeRO-{self.train_config.zero_stage} initialized")

    def _setup_optimizer(self):
        """Create AdamW with two parameter groups: with and without weight decay.

        Parameters whose name contains "bias", "norm" or "embed" get no weight
        decay. Under FSDP this relies on ``use_orig_params=True`` so the
        original parameter names are visible.
        """
        if self._deepspeed_enabled:
            logger.info("Skipping optimizer setup (DeepSpeed will handle it)")
            return  # DeepSpeed handles optimizer

        try:
            logger.info("Collecting model parameters for optimizer...")
            # Separate parameters with and without weight decay
            decay_params = []
            no_decay_params = []

            param_count = 0
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    param_count += 1
                    if "bias" in name or "norm" in name or "embed" in name:
                        no_decay_params.append(param)
                    else:
                        decay_params.append(param)

            logger.info(f"Found {param_count} trainable parameters ({len(decay_params)} with decay, {len(no_decay_params)} without)")

            optimizer_grouped_parameters = [
                {"params": decay_params, "weight_decay": self.train_config.weight_decay},
                {"params": no_decay_params, "weight_decay": 0.0},
            ]

            logger.info("Creating AdamW optimizer...")
            self.optimizer = torch.optim.AdamW(
                optimizer_grouped_parameters,
                lr=self.train_config.learning_rate,
                betas=(self.train_config.adam_beta1, self.train_config.adam_beta2),
                eps=self.train_config.adam_epsilon,
            )
            logger.info(f"Optimizer created with learning rate: {self.train_config.learning_rate}")
        except Exception as e:
            logger.error(f"Failed to setup optimizer: {e}")
            raise

    def _setup_scheduler(self):
        """Create the LR scheduler: optional linear warmup, then cosine/linear/constant.

        The main schedule spans ``num_train_steps - warmup_steps`` steps
        (at least 1).
        """
        if self._deepspeed_enabled:
            logger.info("Skipping scheduler setup (DeepSpeed will handle it)")
            return  # DeepSpeed handles scheduler

        try:
            # Calculate main scheduler iterations (ensure it's positive)
            main_iters = max(1, self.train_config.num_train_steps - self.train_config.warmup_steps)
            logger.info(f"Setting up {self.train_config.lr_scheduler} scheduler for {main_iters} iterations...")

            if self.train_config.lr_scheduler == "cosine":
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    self.optimizer,
                    T_max=main_iters,
                )
            elif self.train_config.lr_scheduler == "linear":
                self.scheduler = torch.optim.lr_scheduler.LinearLR(
                    self.optimizer,
                    start_factor=1.0,
                    end_factor=0.1,  # Don't go to 0.0 to avoid issues
                    total_iters=main_iters,
                )
            else:
                self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)

            # Add warmup
            if self.train_config.warmup_steps > 0:
                logger.info(f"Adding warmup for {self.train_config.warmup_steps} steps...")
                warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                    self.optimizer,
                    start_factor=0.1,  # Start from small value, not 0.0
                    end_factor=1.0,
                    total_iters=self.train_config.warmup_steps,
                )
                self.scheduler = torch.optim.lr_scheduler.SequentialLR(
                    self.optimizer,
                    schedulers=[warmup_scheduler, self.scheduler],
                    milestones=[self.train_config.warmup_steps],
                )
        except Exception as e:
            logger.error(f"Failed to setup scheduler: {e}")
            raise

    def _setup_loss(self):
        """Create the diffusion loss (label smoothing 0.1)."""
        from .losses import DiffusionLoss
        self.loss_fn = DiffusionLoss(
            vocab_size=self.model_config.vocab_size,
            use_energy_model=self.model_config.use_energy_model,
            label_smoothing=0.1  # Small amount of label smoothing
        )

    def _setup_mixed_precision(self):
        """Select the gradient scaler and autocast dtype for the active mode.

        * DeepSpeed: precision comes from the ``fp16``/``bf16`` sections of the
          DeepSpeed config, so nothing is set here.
        * FSDP: parameters are cast by the FSDP ``MixedPrecision`` policy. fp16
          additionally needs loss scaling, which for sharded gradients is done
          by ``ShardedGradScaler``.
        * Single device / DataParallel / DDP: the forward pass runs under
          ``torch.autocast`` (see ``_autocast``), and fp16 uses ``GradScaler``.
        """
        mixed_precision = self.train_config.mixed_precision
        self.scaler = None
        self._amp_dtype = None  # dtype for torch.autocast; None disables autocast

        if self._deepspeed_enabled:
            return  # Handled by DeepSpeed
        if self._fsdp_enabled:
            if mixed_precision == "fp16":
                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
                self.scaler = ShardedGradScaler()
            return

        if mixed_precision == "fp16":
            # Disabled (pass-through) scaler on CPU instead of a runtime warning.
            self.scaler = torch.cuda.amp.GradScaler(enabled=self.device.type == "cuda")
            self._amp_dtype = torch.float16
        elif mixed_precision == "bf16":
            if self.use_data_parallel:
                # NOTE(review): DataParallel replica threads may re-enter
                # autocast with the default CUDA dtype (float16) rather than
                # bfloat16, which would give unscaled fp16. Stay in fp32 there.
                logger.warning("bf16 autocast is not applied under DataParallel; training in fp32")
            else:
                self._amp_dtype = torch.bfloat16

    def _compile_model(self):
        """Compile the (possibly wrapped) model with torch.compile; skipped for DeepSpeed."""
        if not self._deepspeed_enabled:
            self.model = torch.compile(
                self.model,
                mode=self.train_config.torch_compile_mode,
                fullgraph=False,
                dynamic=True,
            )
            if self.rank == 0:
                logger.info(f"Model compiled with mode={self.train_config.torch_compile_mode}")

    def _setup_logging(self):
        """Initialise Weights & Biases on rank 0 when ``train_config.use_wandb`` is set.

        Any failure, including a 10 s init timeout where ``SIGALRM`` is usable,
        is logged and training continues without wandb. ``_wandb_enabled``
        records whether ``wandb.log`` may be called later.
        """
        self._wandb_enabled = False
        if not getattr(self.train_config, 'use_wandb', False):
            logger.info("Wandb logging disabled in config")
            return
        if self.rank != 0:
            logger.info("Wandb logging skipped (not rank 0)")
            return

        logger.info("Initializing Weights & Biases (may take few seconds)...")
        try:
            self._init_wandb(timeout_seconds=10)
        except Exception as e:
            logger.warning(f"Wandb initialization failed or timed out: {e}")
            logger.info("Continuing training without wandb logging...")
            return
        self._wandb_enabled = True
        logger.info("✓ Wandb initialized successfully")

    def _init_wandb(self, timeout_seconds: int = 10) -> None:
        """Call ``wandb.init``, bounded by a SIGALRM timeout when one can be armed.

        Raises:
            TimeoutError: if the alarm fires before ``wandb.init`` returns.
        """
        import signal

        init_kwargs = {
            "project": self.train_config.wandb_project,
            "config": {
                "model": self.model_config.__dict__,
                "training": self.train_config.__dict__,
            },
        }

        def _on_timeout(signum, frame):
            raise TimeoutError(f"Wandb initialization timed out after {timeout_seconds} seconds")

        try:
            previous_handler = signal.signal(signal.SIGALRM, _on_timeout)
        except (AttributeError, ValueError):
            # No SIGALRM (e.g. Windows) or not on the main thread: no timeout.
            wandb.init(**init_kwargs)
            return

        signal.alarm(timeout_seconds)
        try:
            wandb.init(**init_kwargs)
        finally:
            signal.alarm(0)  # Cancel timeout
            # signal.signal returns None when the old handler was not set from Python.
            signal.signal(
                signal.SIGALRM,
                signal.SIG_DFL if previous_handler is None else previous_handler,
            )

    # ---------------------------------------------------------------- helpers

    def _unwrap_model(self) -> torch.nn.Module:
        """Return the user model beneath torch.compile / DDP / DataParallel / DeepSpeed.

        FSDP is left in place; its attribute access already forwards to the
        wrapped module.
        """
        model = getattr(self.model, "_orig_mod", self.model)  # torch.compile wrapper
        if getattr(self, "_deepspeed_enabled", False) or isinstance(
            model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
        ):
            model = model.module
        return model

    def _fsdp_root(self):
        """Return the root FSDP module when FSDP is active, else None."""
        if not getattr(self, "_fsdp_enabled", False):
            return None
        model = getattr(self.model, "_orig_mod", self.model)  # torch.compile wrapper
        return model if isinstance(model, FSDP) else None

    def _autocast(self):
        """Return the autocast context for forward + loss (a no-op unless enabled).

        Autocast is only turned on for CUDA devices, and only when
        ``_setup_mixed_precision`` selected a dtype.
        """
        amp_dtype = getattr(self, "_amp_dtype", None)
        return torch.autocast(
            device_type=self.device.type,
            dtype=amp_dtype,
            enabled=amp_dtype is not None and self.device.type == "cuda",
        )

    def _is_due(self, interval: Optional[int]) -> bool:
        """True when ``global_step`` is a positive multiple of ``interval``; falsy/<=0 disables."""
        return bool(interval) and interval > 0 and self.global_step % interval == 0

    def _move_batch_to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Copy tensor values of ``batch`` to ``self.device``; other values pass through."""
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    def _set_sampler_epoch(self):
        """Call ``sampler.set_epoch(self.epoch)`` when the train sampler supports it.

        ``DistributedSampler`` needs this at the start of every epoch, or it
        reuses the same shuffle order.
        """
        sampler = getattr(self.train_dataloader, "sampler", None)
        set_epoch = getattr(sampler, "set_epoch", None)
        if callable(set_epoch):
            set_epoch(self.epoch)

    def _clip_gradients(self):
        """Clip gradients to ``train_config.max_grad_norm``; no-op when it is None or <= 0.

        Under FSDP each rank holds only a shard of every gradient, so the
        global norm must come from ``FSDP.clip_grad_norm_``, which communicates.
        ``torch.nn.utils.clip_grad_norm_`` would only see the local shard.
        """
        max_norm = self.train_config.max_grad_norm
        if max_norm is None or max_norm <= 0:
            # A max norm of 0 would scale every gradient to zero.
            return
        fsdp_module = self._fsdp_root()
        if fsdp_module is not None:
            fsdp_module.clip_grad_norm_(max_norm)
        else:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)

    def _backward_and_step(self, loss: torch.Tensor):
        """Backpropagate ``loss`` and apply one optimizer + scheduler step."""
        if self._deepspeed_enabled:
            # The DeepSpeed engine handles scaling, clipping, stepping and the LR schedule.
            self.model.backward(loss)
            self.model.step()
            return

        if self.scaler is not None:
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)  # clip on true gradient values
            self._clip_gradients()
            self.scaler.step(self.optimizer)
            self.scaler.update()  # adapt the loss scale (e.g. back off after inf/NaN)
        else:
            loss.backward()
            self._clip_gradients()
            self.optimizer.step()

        self.optimizer.zero_grad()
        self.scheduler.step()

    # --------------------------------------------------------------- training

    def train(self):
        """Run optimisation until ``train_config.num_train_steps`` steps are done.

        ``train_dataloader`` is iterated repeatedly; ``self.epoch`` counts the
        fully consumed passes. Logging, evaluation and checkpointing run every
        ``logging_steps`` / ``eval_steps`` / ``save_steps`` steps; a falsy or
        non-positive interval disables that action.

        Raises:
            RuntimeError: if a full pass over ``train_dataloader`` yields no
                batches. Without this check the loop would never terminate.
        """
        logger.info("Starting training...")
        num_train_steps = self.train_config.num_train_steps

        # Create progress bar with clean single-line updates
        pbar = tqdm(
            total=num_train_steps,
            desc="Training",
            disable=self.rank != 0,
            dynamic_ncols=True,
            position=0,
            leave=True,
            mininterval=0.5,  # Update at most twice per second
            bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
        )

        self.model.train()

        while self.global_step < num_train_steps:
            self._set_sampler_epoch()
            batches_seen = 0
            for batch in self.train_dataloader:
                batches_seen += 1
                batch = self._move_batch_to_device(batch)

                loss = self.training_step(batch)
                self._backward_and_step(loss)
                self.global_step += 1

                # Only rank 0 reports; .item() synchronises with the device.
                loss_value = loss.item() if self.rank == 0 else None

                if self.rank == 0 and self._is_due(self.train_config.logging_steps):
                    self._log_metrics({"train/loss": loss_value})

                # Evaluation and checkpointing may involve collectives: all ranks enter.
                if self._is_due(self.train_config.eval_steps):
                    self.evaluate()

                if self._is_due(self.train_config.save_steps):
                    self.save_checkpoint()

                if self.rank == 0:
                    pbar.update(1)
                    pbar.set_postfix({
                        'loss': f'{loss_value:.4f}',
                        'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
                    })

                if self.global_step >= num_train_steps:
                    break
            else:
                # The dataloader was exhausted before the step budget was reached.
                if batches_seen == 0:
                    raise RuntimeError("train_dataloader yielded no batches; cannot make progress")
                self.epoch += 1

        if self.rank == 0:
            pbar.close()
            logger.info("Training completed!")

    def training_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute the masked-diffusion loss for one batch (forward pass only).

        A uniformly random timestep in ``[0, num_diffusion_steps)`` is drawn
        per sequence. Independently, about 15% of the positions (never the
        first or last) are replaced by the model's ``mask_token_id`` (default
        0). The loss compares the predictions against the original ids.

        Args:
            batch: Must contain a 2-D ``input_ids`` tensor on ``self.device``;
                an ``attention_mask`` is forwarded to the model if present.

        Returns:
            The scalar loss produced by ``self.loss_fn``.
        """
        input_ids = batch['input_ids']  # (batch, seq_len)
        attention_mask = batch.get('attention_mask')  # (batch, seq_len)

        batch_size, seq_len = input_ids.shape

        # Sample random timestep for diffusion
        max_timestep = self.model_config.num_diffusion_steps
        timestep = torch.randint(0, max_timestep, (batch_size,), device=self.device)

        # NOTE(review): the masking ratio is fixed and not derived from
        # `timestep`; confirm this is intended for the diffusion schedule.
        noise_ratio = 0.15  # 15% masking like BERT
        mask_positions = torch.rand(batch_size, seq_len, device=self.device) < noise_ratio

        # Never mask the first/last position (assumed BOS/EOS).
        mask_positions[:, 0] = False
        mask_positions[:, -1] = False

        # Apply masking. Look up mask_token_id on the user model, because
        # DDP/DataParallel wrappers do not forward attribute access.
        noisy_input_ids = input_ids.clone()
        mask_token_id = getattr(self._unwrap_model(), 'mask_token_id', 0)
        noisy_input_ids[mask_positions] = mask_token_id

        # Clamp input_ids to valid range [0, vocab_size-1]
        max_token_id = self.model_config.vocab_size - 1
        input_ids = torch.clamp(input_ids, 0, max_token_id)
        noisy_input_ids = torch.clamp(noisy_input_ids, 0, max_token_id)

        with self._autocast():
            logits, confidence = self.model(
                input_ids=noisy_input_ids,
                timestep=timestep,
                attention_mask=attention_mask,
                return_confidence=True
            )

            # The loss compares predictions at masked positions to original tokens
            loss = self.loss_fn(
                logits=logits,
                targets=input_ids,
                mask_positions=mask_positions,
                confidence=confidence
            )

        return loss

    def evaluate(self):
        """Average ``training_step`` loss over ``eval_dataloader`` and log it as ``eval/loss``.

        In distributed mode the per-rank sums are all-reduced, so the logged
        value covers every rank's eval shard. The model is always returned to
        train mode, and nothing is logged if no batches were seen.
        """
        if self.eval_dataloader is None:
            return

        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for batch in self.eval_dataloader:
                    batch = self._move_batch_to_device(batch)
                    total_loss += self.training_step(batch).item()
                    num_batches += 1
        finally:
            self.model.train()

        if self.is_distributed:
            stats = torch.tensor([total_loss, float(num_batches)], dtype=torch.float64, device=self.device)
            dist.all_reduce(stats)
            total_loss, num_batches = stats[0].item(), int(stats[1].item())

        if num_batches == 0:
            logger.warning("eval_dataloader yielded no batches; skipping eval loss logging")
            return

        self._log_metrics({"eval/loss": total_loss / num_batches})

    def save_checkpoint(self):
        """Save model, optimizer, scheduler (and scaler) state under ``checkpoint-<step>``.

        Must be called on every rank. Gathering FSDP state and writing a
        DeepSpeed checkpoint are collectives; in other modes only rank 0
        does any work. Files go to ``<output_dir>/checkpoint-<global_step>/``.
        """
        output_dir = Path(self.train_config.output_dir)
        checkpoint_dir = output_dir / f"checkpoint-{self.global_step}"
        training_state = {
            "global_step": self.global_step,
            "epoch": self.epoch,
        }

        if self._deepspeed_enabled:
            if self.rank == 0:
                checkpoint_dir.mkdir(parents=True, exist_ok=True)
                logger.info(f"Saving checkpoint to {checkpoint_dir}")
            # Collective: every rank writes its own ZeRO partitions.
            self.model.save_checkpoint(
                str(output_dir), tag=checkpoint_dir.name, client_state=training_state
            )
            if self.rank == 0:
                torch.save(training_state, checkpoint_dir / "training_state.pt")
                logger.info(f"✓ Checkpoint saved at step {self.global_step}")
            return

        fsdp_module = self._fsdp_root()
        if fsdp_module is not None:
            # Gathering a full state dict is a collective, so all ranks take part;
            # with rank0_only the result is only materialised on rank 0.
            from torch.distributed.fsdp import (
                FullOptimStateDictConfig,
                FullStateDictConfig,
                StateDictType,
            )
            with FSDP.state_dict_type(
                fsdp_module,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
                FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
            ):
                model_state = fsdp_module.state_dict()
                optimizer_state = FSDP.optim_state_dict(fsdp_module, self.optimizer)
        elif self.rank != 0:
            return  # Only rank 0 saves
        else:
            # Plain model, DDP or DataParallel (possibly compiled): save unprefixed keys.
            model_state = self._unwrap_model().state_dict()
            optimizer_state = self.optimizer.state_dict()

        if self.rank != 0:
            return  # Only rank 0 writes files

        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving checkpoint to {checkpoint_dir}")

        torch.save(model_state, checkpoint_dir / "pytorch_model.bin")
        torch.save(optimizer_state, checkpoint_dir / "optimizer.pt")
        torch.save(self.scheduler.state_dict(), checkpoint_dir / "scheduler.pt")
        if self.scaler is not None:
            torch.save(self.scaler.state_dict(), checkpoint_dir / "scaler.pt")
        torch.save(training_state, checkpoint_dir / "training_state.pt")

        logger.info(f"✓ Checkpoint saved at step {self.global_step}")

    def _log_metrics(self, metrics: Dict[str, Any]):
        """On rank 0, add ``step`` to ``metrics`` and send it to wandb (if active) or the logger."""
        if self.rank != 0:
            return
        metrics["step"] = self.global_step
        if getattr(self, "_wandb_enabled", False):
            wandb.log(metrics)
        else:
            logger.info(f"Step {self.global_step}: {metrics}")
