Metadata-Version: 2.4
Name: torchtl
Version: 0.1.0
Summary: A very minimal training loop abstraction for PyTorch
Requires-Python: >=3.13
Description-Content-Type: text/markdown
License-File: LICENSE
Dynamic: license-file

## TorchTL

A very minimal training loop abstraction for PyTorch.

### Why TorchTL?

- **Minimal**: PyTorch is the only dependency
- **Flexible**: Works with existing PyTorch models; no subclassing required
- **Extensible**: Callback system for custom behavior
- **Automatic**: Handles device management, mixed precision, and gradient accumulation
- **No magic**: Simple, readable code that does what you expect

### Features

- Automatic device management (CPU/CUDA)
- Mixed precision training
- Gradient accumulation
- Gradient clipping
- Checkpoints with resume capability
- Callback system for extensibility
- Early stopping
- Learning rate scheduling
- Progress tracking
- Exponential moving average (EMA)

### Installation

```bash
pip install torchtl
```

### Quick Overview

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtl import Trainer

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device='cuda',
    mixed_precision=True
)

# train_loader and val_loader are ordinary DataLoader instances
history = trainer.fit(train_loader, val_loader, epochs=10)
```

### Basic Usage

#### Simple training loop

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchtl import Trainer

X_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

trainer = Trainer(model, optimizer, loss_fn)
trainer.fit(train_loader, epochs=10)
```

#### Training with validation

```python
X_val = torch.randn(200, 10)
y_val = torch.randn(200, 1)
val_dataset = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

history = trainer.fit(train_loader, val_loader, epochs=10)
print(f"Train losses: {history['train_loss']}")
print(f"Val losses: {history['val_loss']}")
```
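
Since `fit` returns per-epoch loss lists, picking the best epoch is a one-liner (a sketch, assuming the `history` keys shown above):

```python
best_epoch = min(range(len(history['val_loss'])), key=lambda i: history['val_loss'][i])
print(f"Best epoch: {best_epoch} (val_loss={history['val_loss'][best_epoch]:.4f})")
```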

#### Mixed precision training

Automatic mixed precision cuts memory use and speeds up training on supported GPUs:

```python
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    mixed_precision=True
)
```

#### Gradient accumulation

Gradients are accumulated over several batches before each optimizer step; with `grad_acc_steps=4`, the effective batch size is four times the loader's batch size:

```python
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    grad_acc_steps=4
)
```

#### Gradient clipping

Clipping the gradient norm before each optimizer step keeps updates stable:

```python
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    max_grad_norm=1.0
)
```
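
These are independent constructor options; a sketch assuming they combine freely (the Quick Overview already pairs `device` with `mixed_precision`):

```python
# device placement, AMP, accumulation, and clipping in one trainer
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device='cuda',
    mixed_precision=True,
    grad_acc_steps=4,
    max_grad_norm=1.0
)
```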

### Callbacks

#### Progress tracking

```python
from torchtl import ProgressCallback

trainer = Trainer(model, optimizer, loss_fn)
trainer.add_callback(ProgressCallback(print_every=100))
trainer.fit(train_loader, epochs=10)
```

#### Checkpointing

```python
from torchtl import CheckpointCallback

checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_every_n_epochs=1,
    keep_last_n=3
)

trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)
```

#### Save best model only

```python
checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)

trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)
```

#### Early stopping

```python
from torchtl import EarlyStoppingCallback, StopTraining

early_stop_cb = EarlyStoppingCallback(
    patience=5,
    monitor='val_loss',
    mode='min',
    min_delta=0.001
)

trainer.add_callback(early_stop_cb)

try:
    trainer.fit(train_loader, val_loader, epochs=100)
except StopTraining as e:
    print(f"Training stopped: {e}")
```

#### Learning rate scheduling

```python
from torchtl import LearningRateSchedulerCallback
from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
scheduler_cb = LearningRateSchedulerCallback(scheduler)

trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, epochs=20)
```

#### `ReduceLROnPlateau`

Plateau schedulers step on a monitored metric, so pass a validation loader:

```python
from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3)
scheduler_cb = LearningRateSchedulerCallback(scheduler)

trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, val_loader, epochs=20)
```

#### Multiple callbacks

```python
from torchtl import (
    ProgressCallback,
    CheckpointCallback,
    EarlyStoppingCallback,
    LearningRateSchedulerCallback
)

trainer.add_callback(ProgressCallback(print_every=50))
trainer.add_callback(CheckpointCallback('./checkpoints', save_best_only=True))
trainer.add_callback(EarlyStoppingCallback(patience=5))
trainer.add_callback(LearningRateSchedulerCallback(scheduler))

trainer.fit(train_loader, val_loader, epochs=100)
```

### Checkpoints

#### Manual save/load

```python
trainer.save_checkpoint('./checkpoint.pt')  # save the training state

trainer.load_checkpoint('./checkpoint.pt')  # restore and resume
trainer.fit(train_loader, epochs=10)
```

#### Save with extra state

```python
trainer.save_checkpoint('./checkpoint.pt', best_accuracy=0.95, notes="best model")
```

#### Load with `strict=False`

Non-strict loading is useful when the checkpoint and the current model do not match exactly (in the spirit of `load_state_dict(strict=False)`):

```python
trainer.load_checkpoint('./checkpoint.pt', strict=False)
```

### Utilities

#### Count parameters

```python
from torchtl import count_params

total_params = count_params(model)
trainable_params = count_params(model, trainable_only=True)
print(f"Total: {total_params}, Trainable: {trainable_params}")
```

#### Freeze/unfreeze layers

```python
from torchtl import freeze_layers, unfreeze_layers

freeze_layers(model)  # freeze every parameter in the model

unfreeze_layers(model, layer_names=['fc', 'classifier'])  # unfreeze only these layers

freeze_layers(model, layer_names=['conv1', 'conv2'])  # freeze only these layers
```

#### Set random seed

```python
from torchtl import set_seed

set_seed(42)  # make runs reproducible
```

#### Learning rate

```python
from torchtl import get_lr, set_lr

current_lr = get_lr(optimizer)
print(f"Current LR: {current_lr}")

set_lr(optimizer, 0.0001)  # override the optimizer's learning rate
```
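
One use for `set_lr` is a manual warmup driven by the per-epoch API; the linear schedule below is illustrative, not a TorchTL feature:

```python
from torchtl import set_lr

base_lr = 0.001
warmup_epochs = 3

for epoch in range(10):
    if epoch < warmup_epochs:
        # ramp linearly up to base_lr over the first epochs
        set_lr(optimizer, base_lr * (epoch + 1) / warmup_epochs)
    trainer.train_epoch(train_loader)
```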

#### Exponential moving average

```python
from torchtl import ExponentialMovingAverage

ema = ExponentialMovingAverage(model, decay=0.999)

for epoch in range(10):
    trainer.train_epoch(train_loader)
    ema.update()  # fold the latest weights into the moving average

ema.apply_shadow()  # swap in the EMA weights for evaluation
val_metrics = trainer.validate(val_loader)
ema.restore()  # swap the original weights back
```
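
To fold in EMA updates per batch rather than per epoch, the callback hooks (see Custom Callbacks below) can drive `ema.update()`; a minimal sketch:

```python
from torchtl import Callback, ExponentialMovingAverage

class EMACallback(Callback):
    """Illustrative callback that updates the EMA after every batch."""
    def __init__(self, ema):
        super().__init__()
        self.ema = ema

    def on_batch_end(self, trainer, batch_idx, batch, metrics):
        self.ema.update()

ema = ExponentialMovingAverage(model, decay=0.999)
trainer.add_callback(EMACallback(ema))
trainer.fit(train_loader, epochs=10)
```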

### Custom Callbacks

```python
from torchtl import Callback

class CustomCallback(Callback):
    def on_epoch_start(self, trainer):
        print(f"Starting epoch {trainer.epoch + 1}")

    def on_epoch_end(self, trainer, metrics):
        print(f"Epoch {trainer.epoch + 1} finished with loss: {metrics['loss']:.4f}")

    def on_batch_end(self, trainer, batch_idx, batch, metrics):
        if trainer.global_step % 100 == 0:
            print(f"Step {trainer.global_step}, Loss: {metrics['loss']:.4f}")

trainer.add_callback(CustomCallback())
trainer.fit(train_loader, epochs=10)
```

### Batch Format Support

TorchTL accepts batches as tuples/lists or as dictionaries.

#### Tuple/list format

```python
# the format TensorDataset yields
batch = (inputs, targets)
```

#### Dictionary format

```python
# either key naming is recognized
batch = {'inputs': inputs, 'targets': targets}
batch = {'input': inputs, 'target': targets}
```
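
Any `Dataset` that returns dictionaries produces this format via the default collate; a minimal sketch (`DictDataset` is an illustrative helper, not part of TorchTL):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    """Yields samples as dicts; the default collate batches them key-wise."""
    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return {'inputs': self.X[idx], 'targets': self.y[idx]}

loader = DataLoader(DictDataset(torch.randn(1000, 10), torch.randn(1000, 1)),
                    batch_size=32, shuffle=True)
trainer.fit(loader, epochs=10)
```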

### Misc usage

#### Custom training loop

```python
best_loss = float('inf')

for epoch in range(10):
    train_metrics = trainer.train_epoch(train_loader)
    val_metrics = trainer.validate(val_loader)

    print(f"Epoch {epoch}: Train Loss={train_metrics['loss']:.4f}, Val Loss={val_metrics['val_loss']:.4f}")

    if val_metrics['val_loss'] < best_loss:
        best_loss = val_metrics['val_loss']
        trainer.save_checkpoint('./best_model.pt')
```

#### Access internal state

```python
print(f"Current epoch: {trainer.epoch}")
print(f"Global step: {trainer.global_step}")
print(f"Device: {trainer.device}")
```

### License

Apache License 2.0.
