"""
Data utilities for training
Provides dataset classes for raw text, multimodal (text + image), and pre-tokenized data
"""
import torch
from torch.utils.data import Dataset
from typing import Optional, Dict, Any


class TextDataset(Dataset):
    """
    Dataset that tokenizes raw text lazily, one record per access.

    Dict records are read from ``text_column``, falling back to the
    ``'text'`` key and then to the empty string. Any other record type is
    converted with ``str()``. Every sample is tokenized with padding to
    ``max_length`` and truncation enabled, and the leading dimension of
    size 1 is squeezed away.

    Args:
        dataset: HuggingFace dataset or list of text examples
        tokenizer: Tokenizer to use for encoding
        max_length: Maximum sequence length
        text_column: Name of the text column in dataset
    """

    def __init__(
        self,
        dataset,
        tokenizer,
        max_length=1024,
        text_column='text'
    ):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_column = text_column

    def __len__(self):
        return len(self.dataset)

    def _extract_text(self, record):
        """Return the text payload of a single raw record."""
        if not isinstance(record, dict):
            return str(record)
        # The 'text' key is the secondary fallback when text_column is absent.
        fallback = record.get('text', '')
        return record.get(self.text_column, fallback)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self._extract_text(self.dataset[idx]),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' yields a batch of one; drop that axis.
        return {
            field: encoding[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }


class MultimodalDataset(Dataset):
    """
    Dataset pairing a text field with an image field from each record.

    The text is tokenized with padding to ``max_length`` and truncation
    enabled. The image is passed to ``image_processor`` as ``images=``.
    The leading dimension of every returned tensor is squeezed with
    ``squeeze(0)``.

    Args:
        dataset: HuggingFace dataset with text and image columns
        tokenizer: Tokenizer for text
        image_processor: Processor for images
        max_length: Maximum text sequence length
        text_column: Name of text column
        image_column: Name of image column
    """

    def __init__(
        self,
        dataset,
        tokenizer,
        image_processor,
        max_length=512,
        text_column='text',
        image_column='image'
    ):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.text_column = text_column
        self.image_column = image_column

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]

        # Text branch: fixed-length token ids plus attention mask.
        tok_out = self.tokenizer(
            record[self.text_column],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Image branch: presumably returns a leading batch axis of size 1,
        # which squeeze(0) below removes — confirm against the processor used.
        img_out = self.image_processor(
            images=record[self.image_column],
            return_tensors='pt'
        )

        sample = {
            name: tok_out[name].squeeze(0)
            for name in ('input_ids', 'attention_mask')
        }
        sample['pixel_values'] = img_out['pixel_values'].squeeze(0)
        return sample


class PreTokenizedDataset(Dataset):
    """
    Dataset for pre-tokenized sequences
    Useful when working with large datasets where tokenization is done offline

    Each sequence is truncated or right-padded to exactly ``max_length``
    tokens and returned as ``torch.long`` tensors. The attention mask is
    built from the sequence length rather than from token values. A genuine
    token whose ID equals the padding ID is therefore still marked as real (1).

    Args:
        token_ids: Indexable collection of token-ID sequences (lists, tuples,
            or 1-D tensors), one sequence per item
        max_length: Maximum sequence length
        pad_token_id: Token ID written into padded positions (keyword-only,
            default 0, which matches the previous hard-coded padding value)
    """

    def __init__(self, token_ids, max_length=1024, *, pad_token_id=0):
        """
        Args:
            token_ids: List or tensor of token-ID sequences
            max_length: Maximum sequence length
            pad_token_id: Token ID used to fill padded positions
        """
        self.token_ids = token_ids
        self.max_length = max_length
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.token_ids)

    def __getitem__(self, idx):
        # Normalize Python sequences and non-long integer tensors to torch.long
        # so the output dtype does not depend on whether padding was applied.
        tokens = torch.as_tensor(self.token_ids[idx], dtype=torch.long)

        # Truncate if needed (the slice is a no-op for short sequences).
        tokens = tokens[:self.max_length]
        num_real = len(tokens)
        num_pad = self.max_length - num_real

        # Mark real positions by count, not by comparing against the pad ID:
        # whether token ID 0 is a real vocabulary entry is tokenizer-dependent.
        attention_mask = torch.zeros(self.max_length, dtype=torch.long)
        attention_mask[:num_real] = 1

        # Right-pad with pad_token_id up to max_length.
        if num_pad > 0:
            padding = torch.full((num_pad,), self.pad_token_id, dtype=torch.long)
            tokens = torch.cat([tokens, padding])

        return {
            'input_ids': tokens,
            'attention_mask': attention_mask
        }
