Metadata-Version: 2.4
Name: jepa
Version: 0.1.6
Summary: Joint-Embedding Predictive Architecture for Self-Supervised Learning
Home-page: https://github.com/dipsivenkatesh/jepa
Author: Dilip Venkatesh
Author-email: Dilip Venkatesh <your.email@example.com>
Maintainer-email: Dilip Venkatesh <your.email@example.com>
License: MIT
Project-URL: Homepage, https://github.com/dipsivenkatesh/jepa
Project-URL: Documentation, https://jepa.readthedocs.io/
Project-URL: Repository, https://github.com/dipsivenkatesh/jepa.git
Project-URL: Bug Tracker, https://github.com/dipsivenkatesh/jepa/issues
Project-URL: Changelog, https://github.com/dipsivenkatesh/jepa/blob/main/CHANGELOG.md
Keywords: self-supervised-learning,representation-learning,deep-learning,pytorch,jepa,joint-embedding,predictive-architecture
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.12.0
Requires-Dist: numpy>=1.21.0
Requires-Dist: tqdm>=4.64.0
Requires-Dist: transformers>=4.20.0
Requires-Dist: datasets>=2.0.0
Requires-Dist: wandb>=0.13.0
Requires-Dist: scikit-learn>=1.1.0
Requires-Dist: pyyaml>=6.0
Requires-Dist: tensorboard>=2.9.0
Requires-Dist: matplotlib>=3.5.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: pytest-cov>=3.0.0; extra == "dev"
Requires-Dist: black>=22.0.0; extra == "dev"
Requires-Dist: isort>=5.10.0; extra == "dev"
Requires-Dist: flake8>=5.0.0; extra == "dev"
Requires-Dist: mypy>=0.950; extra == "dev"
Requires-Dist: pre-commit>=2.20.0; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx>=5.0.0; extra == "docs"
Requires-Dist: sphinx-rtd-theme>=1.0.0; extra == "docs"
Requires-Dist: myst-parser>=0.18.0; extra == "docs"
Requires-Dist: sphinx-autodoc-typehints>=1.19.0; extra == "docs"
Provides-Extra: all
Requires-Dist: jepa[dev,docs]; extra == "all"
Dynamic: author
Dynamic: home-page
Dynamic: license-file
Dynamic: requires-python

# JEPA Framework

<div align="center">

![JEPA Logo](https://img.shields.io/badge/JEPA-Framework-blue?style=for-the-badge)
[![PyPI version](https://badge.fury.io/py/jepa.svg)](https://badge.fury.io/py/jepa)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://jepa.readthedocs.io/)

**A powerful self-supervised learning framework for Joint-Embedding Predictive Architecture (JEPA)**

[Installation](#installation) •
[Quick Start](#quick-start) •
[Documentation](https://jepa.readthedocs.io/) •
[Examples](#examples) •
[Contributing](#contributing)

</div>

## 🚀 Overview

JEPA (Joint-Embedding Predictive Architecture) is a cutting-edge self-supervised learning framework that learns rich representations by predicting parts of the input from other parts. This implementation provides a flexible, production-ready framework for training JEPA models across multiple modalities.

### Key Features

🔧 **Modular Design**
- Flexible encoder-predictor architecture
- Support for any PyTorch model as encoder/predictor
- Easy to extend and customize for your specific needs

🌍 **Multi-Modal Support**
- **Computer Vision**: Images, videos, medical imaging
- **Natural Language Processing**: Text, documents, code
- **Time Series**: Sequential data, forecasting, anomaly detection
- **Audio**: Speech, music, environmental sounds
- **Multimodal**: Vision-language, audio-visual learning

⚡ **High Performance**
- Mixed precision training (FP16/BF16)
- Native DistributedDataParallel (DDP) support
- Memory-efficient implementations
- Optimized for both research and production

📊 **Comprehensive Logging**
- Weights & Biases integration
- TensorBoard support
- Console logging with rich formatting
- Multi-backend logging system

🎛️ **Production Ready**
- CLI interface for easy deployment
- Flexible YAML configuration system
- Comprehensive testing suite
- Docker support and containerization
- Type hints throughout

## 🏗️ Architecture

JEPA follows a simple yet powerful architecture:

```mermaid
graph LR
    A[Input Data] --> B[Context/Target Split]
    B --> C[Encoder]
    C --> D[Joint Embedding Space]
    D --> E[Predictor]
    E --> F[Target Prediction]
    F --> G[Loss Computation]
```

The model learns by:
1. **Splitting** input into context and target regions
2. **Encoding** both context and target separately  
3. **Predicting** target embeddings from context embeddings
4. **Learning** representations that capture meaningful relationships
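The four steps above can be sketched end-to-end. This is a framework-agnostic NumPy illustration, not the library's implementation: the linear maps stand in for the encoder and predictor, and the split is a simple first-half/second-half partition.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy input: one sequence of 16 patches, 8 features each
x = rng.normal(size=(16, 8))

# 1. Split input into context and target regions (first/last half here)
context, target = x[:8], x[8:]

# 2. Encode both regions (a linear map as a stand-in encoder;
#    in practice the target branch is typically a stop-gradient/EMA copy)
W_enc = rng.normal(size=(8, 4)) * 0.1
z_context = context @ W_enc
z_target = target @ W_enc

# 3. Predict target embeddings from context embeddings
W_pred = rng.normal(size=(4, 4)) * 0.1
z_pred = z_context @ W_pred

# 4. Loss: mean squared error in the joint embedding space
loss = float(np.mean((z_pred - z_target) ** 2))
print(z_pred.shape, loss > 0)
```

Crucially, the loss is computed between *embeddings*, not raw inputs, which is what distinguishes JEPA from pixel- or token-level reconstruction objectives.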

## 📦 Installation

### From PyPI (Recommended)

```bash
pip install jepa
```

### From Source

```bash
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e .
```

### Development Installation

```bash
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa
pip install -e ".[dev,docs]"
```

### Docker

```bash
docker pull dipsivenkatesh/jepa:latest
docker run -it dipsivenkatesh/jepa:latest
```

## 🚀 Quick Start

### Python API

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from jepa.models import JEPA
from jepa.models.encoder import Encoder
from jepa.models.predictor import Predictor
from jepa.trainer import create_trainer

# Toy dataset of (state_t, state_t1) pairs
dataset = TensorDataset(torch.randn(256, 16, 128), torch.randn(256, 16, 128))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Build model components
encoder = Encoder(hidden_dim=128)
predictor = Predictor(hidden_dim=128)
model = JEPA(encoder=encoder, predictor=predictor)

# Trainer with sensible defaults
trainer = create_trainer(model, learning_rate=3e-4, device="auto")

# Train for a couple of epochs
trainer.train(train_loader, num_epochs=2)

# Optional: stream metrics to Weights & Biases
trainer_wandb = create_trainer(
    model,
    learning_rate=3e-4,
    device="auto",
    logger="wandb",
    logger_project="jepa-experiments",
    logger_run_name="quickstart-run",
)

# Persist weights for downstream inference
model.save_pretrained("artifacts/jepa-small")

# Reload using the same model class
reloaded = JEPA.from_pretrained("artifacts/jepa-small", encoder=encoder, predictor=predictor)
```

### Distributed Training (DDP)

Launch multi-GPU jobs with PyTorch's launcher:

```bash
torchrun --nproc_per_node=4 scripts/train.py --config config/default_config.yaml
```

Inside your training script, enable DDP when you create the trainer:

```python
trainer = create_trainer(
    model,
    distributed=True,
    world_size=int(os.environ["WORLD_SIZE"]),
    local_rank=int(os.environ.get("LOCAL_RANK", 0)),
)
```

The trainer wraps the model in `DistributedDataParallel`, synchronizes losses, and restricts logging/checkpointing to rank zero automatically.

### Action-Conditioned Variant

Use `JEPAAction` when actions influence the next state. Provide a state encoder, an action encoder, and a predictor that consumes the concatenated `[z_t, a_t]` embedding.

```python
from jepa import JEPAAction
import torch.nn as nn

state_dim = 512
action_dim = 64

# Example encoders (replace with your own)
state_encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, state_dim),
)
action_encoder = nn.Sequential(
    nn.Linear(10, 128), nn.ReLU(),
    nn.Linear(128, action_dim),
)

# Predictor takes [state_dim + action_dim] → state_dim
predictor = nn.Sequential(
    nn.Linear(state_dim + action_dim, 512), nn.ReLU(),
    nn.Linear(512, state_dim),
)

model = JEPAAction(state_encoder, action_encoder, predictor)
```

### Command Line Interface

```bash
# Train a model
jepa-train --config config/default_config.yaml

# Train with custom parameters
jepa-train --config config/vision_config.yaml \
           --batch-size 64 \
           --learning-rate 0.001 \
           --num-epochs 100

# Evaluate a trained model
jepa-evaluate --config config/default_config.yaml \
              --checkpoint checkpoints/best_model.pth

# Generate a configuration template
jepa-train --generate-config my_config.yaml

# Get help
jepa-train --help
```

### Configuration

JEPA uses YAML configuration files for easy experiment management:

```yaml
# config/my_experiment.yaml
model:
  encoder_type: "transformer"
  encoder_dim: 768
  predictor_type: "mlp"
  predictor_hidden_dim: 2048

training:
  batch_size: 32
  learning_rate: 0.0001
  num_epochs: 100
  warmup_epochs: 10

data:
  train_data_path: "data/train"
  val_data_path: "data/val"
  sequence_length: 16

logging:
  wandb:
    enabled: true
    project: "jepa-experiments"
  tensorboard:
    enabled: true
    log_dir: "./tb_logs"
```
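A file like this parses into plain nested dictionaries with PyYAML (already a dependency); `load_config` presumably does something similar with added validation. A minimal sketch of the parsing step:

```python
import yaml

# Inline stand-in for config/my_experiment.yaml
config_text = """
model:
  encoder_type: transformer
  encoder_dim: 768
training:
  batch_size: 32
  learning_rate: 0.0001
"""

config = yaml.safe_load(config_text)

# Sections become dicts; scalars become native Python types
assert config["model"]["encoder_dim"] == 768
assert config["training"]["learning_rate"] == 1e-4
print(config["model"]["encoder_type"])
```

Prefer `yaml.safe_load` over `yaml.load`; it refuses to construct arbitrary Python objects from untrusted config files.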

## 🎯 Use Cases

### Computer Vision
- **Image Classification**: Pre-train backbones for downstream tasks
- **Object Detection**: Learn robust visual representations
- **Medical Imaging**: Analyze medical scans and imagery
- **Satellite Imagery**: Process large-scale geographic data

### Natural Language Processing
- **Language Models**: Pre-train transformer architectures
- **Document Understanding**: Learn document-level representations
- **Code Analysis**: Understand code structure and semantics
- **Cross-lingual Learning**: Build multilingual representations

### Time Series Analysis
- **Forecasting**: Pre-train models for prediction tasks
- **Anomaly Detection**: Learn normal patterns in sequential data
- **Financial Modeling**: Analyze market trends and patterns
- **IoT Sensors**: Process sensor data streams

### Multimodal Learning
- **Vision-Language**: Combine images and text understanding
- **Audio-Visual**: Learn from synchronized audio and video
- **Cross-Modal Retrieval**: Search across different modalities
- **Embodied AI**: Integrate multiple sensor modalities

## 📚 Examples

### Vision Example
```python
from jepa import JEPATrainer, load_config
import torch.nn as nn

# Custom vision encoder
class VisionEncoder(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=512):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_dim, 64, 7, 2, 3),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, hidden_dim)
        )
    
    def forward(self, x):
        return self.conv_layers(x)

# Load config and customize
config = load_config("config/vision_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=VisionEncoder)
trainer.train()
```

### NLP Example
```python
from transformers import AutoModel
import torch.nn as nn

from jepa import JEPATrainer, load_config

# Use pre-trained transformer as encoder
class TransformerEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
    
    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state.mean(dim=1)  # Pool over sequence

config = load_config("config/nlp_config.yaml")
trainer = JEPATrainer(config=config, custom_encoder=TransformerEncoder)
trainer.train()
```

### Time Series Example
```python
from jepa import create_dataset, JEPATrainer, load_config

# Create time series dataset
dataset = create_dataset(
    data_path="data/timeseries.csv",
    sequence_length=50,
    prediction_length=10,
    features=['sensor1', 'sensor2', 'sensor3']
)

config = load_config("config/timeseries_config.yaml")
trainer = JEPATrainer(config=config, train_dataset=dataset)
trainer.train()
```
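The `sequence_length`/`prediction_length` parameters suggest a sliding-window scheme over the raw series. The exact behavior of `create_dataset` may differ; this is a hedged NumPy sketch of how such (context, target) windows could be cut:

```python
import numpy as np

def make_windows(series, sequence_length, prediction_length):
    """Cut overlapping (context, target) window pairs from a 1-D series."""
    pairs = []
    step = sequence_length + prediction_length
    for start in range(len(series) - step + 1):
        context = series[start : start + sequence_length]
        target = series[start + sequence_length : start + step]
        pairs.append((context, target))
    return pairs

series = np.arange(100, dtype=float)  # stand-in for one sensor column
pairs = make_windows(series, sequence_length=50, prediction_length=10)

print(len(pairs))                          # 100 - 60 + 1 = 41 windows
print(pairs[0][0].shape, pairs[0][1].shape)
```

Each target window begins exactly where its context window ends, matching the "predict the future from the past" setup described above.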

## 📖 Documentation

- **[Full Documentation](https://jepa.readthedocs.io/)** - Complete API reference and guides
- **[Installation Guide](https://jepa.readthedocs.io/en/latest/guides/installation.html)** - Detailed installation instructions
- **[Configuration Guide](https://jepa.readthedocs.io/en/latest/guides/configuration.html)** - How to configure your experiments
- **[Training Guide](https://jepa.readthedocs.io/en/latest/guides/training.html)** - Training best practices
- **[API Reference](https://jepa.readthedocs.io/en/latest/api/)** - Complete API documentation

## 🔧 Development

### Setting up Development Environment

```bash
git clone https://github.com/dipsivenkatesh/jepa.git
cd jepa

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install in development mode
pip install -e ".[dev,docs]"

# Install pre-commit hooks
pre-commit install
```

### Running Tests

```bash
# Run all tests
pytest

# Run with coverage
pytest --cov=jepa --cov-report=html

# Run specific test
pytest tests/test_model.py::test_jepa_forward
```

### Code Quality

```bash
# Format code
black jepa/
isort jepa/

# Type checking
mypy jepa/

# Linting
flake8 jepa/
```

## 🤝 Contributing

We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for details.

### Ways to Contribute
- 🐛 **Bug Reports**: Submit detailed bug reports
- ✨ **Feature Requests**: Suggest new features or improvements  
- 📖 **Documentation**: Improve documentation and examples
- 🔧 **Code**: Submit pull requests with bug fixes or new features
- 🎯 **Use Cases**: Share your JEPA applications and results

### Development Workflow
1. Fork the repository
2. Create a feature branch: `git checkout -b feature-name`
3. Make your changes and add tests
4. Ensure all tests pass: `pytest`
5. Submit a pull request

## 📄 Citation

If you use JEPA in your research, please cite:

```bibtex
@software{jepa2025,
  title = {JEPA: Joint-Embedding Predictive Architecture Framework},
  author = {Venkatesh, Dilip},
  year = {2025},
  url = {https://github.com/dipsivenkatesh/jepa},
  version = {0.1.6}
}
```

## 📝 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgments

- Inspired by the original JEPA paper and Meta's research
- Built with PyTorch, Transformers, and other amazing open-source libraries
- Thanks to all contributors and users of the framework

## 📞 Support

- **GitHub Issues**: [Report bugs or request features](https://github.com/dipsivenkatesh/jepa/issues)
- **Documentation**: [Read the full documentation](https://jepa.readthedocs.io/)
- **Discussions**: [Join community discussions](https://github.com/dipsivenkatesh/jepa/discussions)

### Releasing a New Version (Maintainers)

```bash
rm -rf dist build *.egg-info
python -m build
twine upload dist/*
```

---

<div align="center">

**[⭐ Star this repo](https://github.com/dipsivenkatesh/jepa) | [📖 Read the docs](https://jepa.readthedocs.io/) | [🐛 Report issues](https://github.com/dipsivenkatesh/jepa/issues)**

*Built with ❤️ by [Dilip Venkatesh](https://dipsivenkatesh.github.io/)*

</div>
