Metadata-Version: 2.4
Name: stoa-env
Version: 0.1.1
Summary: Single-Agent Reinforcement Learning with JAX
Author: Edan Toledo
License: MIT License
        
        Copyright (c) 2025 Edan Toledo
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
        
Project-URL: Homepage, https://github.com/EdanToledo/Stoa
Project-URL: Bug Tracker, https://github.com/EdanToledo/Stoa/issues
Classifier: Development Status :: 3 - Alpha
Classifier: Environment :: Console
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3.10
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: License :: OSI Approved :: MIT License
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: chex>=0.1.89
Requires-Dist: flax>=0.10.5
Requires-Dist: jax<0.6.0,>=0.4.25
Requires-Dist: jaxlib>=0.5.3
Requires-Dist: numpy>=1.26.4
Provides-Extra: brax
Requires-Dist: brax>=0.9.0; extra == "brax"
Provides-Extra: gymnax
Requires-Dist: gymnax>=0.0.6; extra == "gymnax"
Requires-Dist: gymnasium>=1.1.1; extra == "gymnax"
Provides-Extra: jumanji
Requires-Dist: jumanji==1.0.0; extra == "jumanji"
Provides-Extra: kinetix
Requires-Dist: kinetix-env; extra == "kinetix"
Provides-Extra: navix
Requires-Dist: navix>=0.7.0; extra == "navix"
Provides-Extra: pgx
Requires-Dist: pgx>=2.6.0; extra == "pgx"
Provides-Extra: xminigrid
Requires-Dist: xminigrid; extra == "xminigrid"
Provides-Extra: all
Requires-Dist: stoa-env[brax]; extra == "all"
Requires-Dist: stoa-env[gymnax]; extra == "all"
Requires-Dist: stoa-env[jumanji]; extra == "all"
Requires-Dist: stoa-env[kinetix]; extra == "all"
Requires-Dist: stoa-env[navix]; extra == "all"
Requires-Dist: stoa-env[pgx]; extra == "all"
Requires-Dist: stoa-env[xminigrid]; extra == "all"
Provides-Extra: dev
Requires-Dist: black; extra == "dev"
Requires-Dist: coverage; extra == "dev"
Requires-Dist: flake8; extra == "dev"
Requires-Dist: importlib-metadata; extra == "dev"
Requires-Dist: isort; extra == "dev"
Requires-Dist: livereload; extra == "dev"
Requires-Dist: mkdocs; extra == "dev"
Requires-Dist: mkdocs-git-revision-date-plugin; extra == "dev"
Requires-Dist: mkdocs-include-markdown-plugin; extra == "dev"
Requires-Dist: mkdocs-material; extra == "dev"
Requires-Dist: mkdocs-mermaid2-plugin; extra == "dev"
Requires-Dist: mkdocstrings; extra == "dev"
Requires-Dist: mknotebooks; extra == "dev"
Requires-Dist: mypy; extra == "dev"
Requires-Dist: nbmake; extra == "dev"
Requires-Dist: pre-commit; extra == "dev"
Requires-Dist: promise; extra == "dev"
Requires-Dist: pymdown-extensions; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Requires-Dist: pytest-mock; extra == "dev"
Requires-Dist: pytest-parallel; extra == "dev"
Requires-Dist: pytest-xdist; extra == "dev"
Requires-Dist: pytype; extra == "dev"
Requires-Dist: testfixtures; extra == "dev"
Dynamic: license-file

<p align="center">
  <a href="docs/images/stoa.png">
    <img src="docs/images/stoa.jpeg" alt="Stoa logo" width="30%"/>
  </a>
</p>

<div align="center">
  <a href="https://www.python.org/doc/versions/">
    <img src="https://img.shields.io/badge/python-3.10-blue" alt="Python Versions"/>
  </a>
  <a href="https://github.com/EdanToledo/Stoa/blob/main/LICENSE">
    <img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License"/>
  </a>
  <a href="https://github.com/psf/black">
    <img src="https://img.shields.io/badge/code%20style-black-000000" alt="Code Style"/>
  </a>
  <a href="http://mypy-lang.org/">
    <img src="https://www.mypy-lang.org/static/mypy_badge.svg" alt="MyPy" />
  </a>
</div>

<h2 align="center">
  <p>A JAX-Native Interface for Reinforcement Learning Environments</p>
</h2>

## 🚀 Welcome to Stoa

Stoa provides a lightweight, JAX-native interface for reinforcement learning environments. It defines a common abstraction layer that enables different environment libraries to work together seamlessly in JAX workflows.

> ⚠️ **Early Development** – Core abstractions are in place, but the library is still growing!

---

## 🎯 What Stoa Provides

* **Common Interface**: A standardized `Environment` base class that defines the contract for RL environments in JAX.
* **JAX-Native Design**: Pure-functional `step` and `reset` operations compatible with JAX transformations like `jit` and `vmap`.
* **Environment Wrappers**: A flexible system for composing and extending environments with additional functionality.
* **Space Definitions**: Structured representations for observation, action, and state spaces.
* **TimeStep Protocol**: A standardized `TimeStep` structure to represent environment transitions with clear termination and truncation signals.
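
To make the pure-functional contract concrete, here is a minimal sketch of a toy environment written in plain JAX. It does not use Stoa's classes: `ToyState`, `ToyTimeStep`, `toy_reset`, and `toy_step` are hypothetical names chosen for illustration, and the real `Environment` and `TimeStep` types may differ in fields and signatures. The sketch only shows the pattern: state goes in, new state comes out, so `jit` and `vmap` apply directly.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

GOAL = 3       # hypothetical goal position that terminates the episode
MAX_STEPS = 5  # hypothetical step limit that truncates the episode


class ToyState(NamedTuple):
    """Hypothetical environment state (not a Stoa type)."""
    position: jax.Array
    step_count: jax.Array


class ToyTimeStep(NamedTuple):
    """Hypothetical transition record (not Stoa's TimeStep)."""
    observation: jax.Array
    reward: jax.Array
    terminated: jax.Array  # goal reached
    truncated: jax.Array   # step limit hit before the goal


def toy_reset(key):
    # Start at a random position in {0, 1}.
    position = jax.random.randint(key, (), 0, 2)
    state = ToyState(position, jnp.zeros((), jnp.int32))
    timestep = ToyTimeStep(
        observation=position,
        reward=jnp.zeros((), jnp.float32),
        terminated=jnp.zeros((), jnp.bool_),
        truncated=jnp.zeros((), jnp.bool_),
    )
    return state, timestep


def toy_step(state, action):
    # No mutation: build and return a fresh state.
    position = state.position + action
    step_count = state.step_count + 1
    terminated = position >= GOAL
    truncated = jnp.logical_and(step_count >= MAX_STEPS, jnp.logical_not(terminated))
    timestep = ToyTimeStep(
        observation=position,
        reward=terminated.astype(jnp.float32),
        terminated=terminated,
        truncated=truncated,
    )
    return ToyState(position, step_count), timestep


# Because both functions are pure, they can be batched and compiled as-is.
batched_reset = jax.jit(jax.vmap(toy_reset))
batched_step = jax.jit(jax.vmap(toy_step))

keys = jax.random.split(jax.random.PRNGKey(0), 4)
states, timesteps = batched_reset(keys)
actions = jnp.ones((4,), dtype=jnp.int32)
states, timesteps = batched_step(states, actions)
```

Keeping `terminated` and `truncated` as separate flags lets a learner decide whether to bootstrap from the next state's value when an episode ends on a step limit.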

---

## 🛠️ Installation

Install the core `stoa` library from PyPI with pip (the distribution is named `stoa-env`):

```bash
pip install stoa-env
```

This minimal installation includes the core API and wrappers but no specific environment adapters.

### Environment Adapters

Adapters for external environment libraries are available as optional extras. You can install them individually or all at once.

**Install a specific adapter:**

```bash
# Example for Gymnax
pip install "stoa-env[gymnax]"

# Example for Brax
pip install "stoa-env[brax]"
```

**Install all available adapters:**

```bash
pip install "stoa-env[all]"
```

---

## 🧩 Available Adapters

Stoa currently supports the following JAX-native environment libraries:

* **Brax**
* **Gymnax**
* **Jumanji**
* **Kinetix**
* **Navix**
* **PGX** (Game environments)
* **MuJoCo Playground**
* **XMinigrid**

---

## ✨ Available Wrappers

Stoa provides a rich set of wrappers to modify and extend environment behavior:

* **Core Wrappers**: `AutoResetWrapper`, `RecordEpisodeMetrics`, `AddRNGKey`, `VmapWrapper`.
* **Observation Wrappers**: `FlattenObservationWrapper`, `FrameStackingWrapper`, `ObservationExtractWrapper`, `AddActionMaskWrapper`, `AddStartFlagAndPrevAction`, `AddStepCountWrapper`, `MakeChannelLast`, `ObservationTypeWrapper`.
* **Action Space Wrappers**: `MultiDiscreteToDiscreteWrapper`, `MultiBoundedToBoundedWrapper`.
* **Utility Wrappers**: `EpisodeStepLimitWrapper`, `ConsistentExtrasWrapper`.
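
The wrapper idea can be sketched without Stoa itself: a wrapper holds an inner environment, forwards `reset` and `step` to it, and adds its own bookkeeping to the state. The classes below (`CounterEnv`, `LimitState`, `StepLimit`) are hypothetical and do not mirror the signatures of `EpisodeStepLimitWrapper` or any other Stoa class; they only illustrate the composition pattern under `jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CounterEnv:
    """Hypothetical base environment: the state is a scalar that grows by `action`."""

    def reset(self, key):
        state = jnp.zeros((), jnp.int32)
        return state, {"obs": state, "truncated": jnp.zeros((), jnp.bool_)}

    def step(self, state, action):
        state = state + action
        return state, {"obs": state, "truncated": jnp.zeros((), jnp.bool_)}


class LimitState(NamedTuple):
    """Wrapper state: the inner environment's state plus a step counter."""
    inner: jax.Array
    steps: jax.Array


class StepLimit:
    """Hypothetical wrapper that flags truncation after `max_steps` steps."""

    def __init__(self, env, max_steps):
        self._env = env
        self._max_steps = max_steps

    def reset(self, key):
        inner, out = self._env.reset(key)
        return LimitState(inner, jnp.zeros((), jnp.int32)), out

    def step(self, state, action):
        inner, out = self._env.step(state.inner, action)
        steps = state.steps + 1
        truncated = jnp.logical_or(out["truncated"], steps >= self._max_steps)
        return LimitState(inner, steps), {**out, "truncated": truncated}


env = StepLimit(CounterEnv(), max_steps=3)
step = jax.jit(env.step)

state, out = env.reset(jax.random.PRNGKey(0))
flags = []
for _ in range(3):
    state, out = step(state, jnp.int32(1))
    flags.append(bool(out["truncated"]))
```

Because each wrapper nests the inner state inside its own, wrappers stack in any depth, and the order in which they are applied determines which one sees the other's output.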

---

## ⚡ Usage Example

Here's how to adapt a `gymnax` environment and compose it with several wrappers:

```python
import jax
import gymnax
from stoa import GymnaxToStoa, FlattenObservationWrapper, AutoResetWrapper, RecordEpisodeMetrics

# 1. Instantiate a base environment from a supported library
gymnax_env, env_params = gymnax.make("CartPole-v1")

# 2. Adapt the environment to the Stoa interface
env = GymnaxToStoa(gymnax_env, env_params)

# 3. Apply standard wrappers
# Note: The order of wrappers matters.
env = AutoResetWrapper(env, next_obs_in_extras=True)
env = RecordEpisodeMetrics(env)

# JIT compile the reset and step functions for performance
env.reset = jax.jit(env.reset)
env.step = jax.jit(env.step)

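# 4. Interact with the environment
rng_key = jax.random.PRNGKey(0)
rng_key, reset_key = jax.random.split(rng_key)
state, timestep = env.reset(reset_key)
total_reward = 0

for _ in range(100):
    # Split the key each step; reusing one key would sample the same action every time
    rng_key, action_key = jax.random.split(rng_key)
    action = env.action_space().sample(action_key)
    state, timestep = env.step(state, action)
    total_reward += timestep.reward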

    if timestep.last():
        # Access metrics recorded by the RecordEpisodeMetrics wrapper
        episode_return = timestep.extras['episode_metrics']['episode_return']
        print(f"Episode finished. Return: {episode_return}")

        # The state has been auto-reset, so we can continue the loop
        total_reward = 0
```

---

## 🛣️ Roadmap

* **Documentation**: Expand documentation with detailed tutorials and API references.
* **More Wrappers**: Add more common utility wrappers (e.g., observation normalization, reward clipping).
* **Integration Examples**: Provide examples of how to integrate `stoa` with popular JAX-based RL libraries.

---

## 🤝 Contributing

We're building Stoa to provide a common foundation for JAX-based RL research. Contributions are welcome!

---

## 📚 Related Projects

* **Stoix** – Distributed single-agent RL in JAX
* **Gymnax** – Classic control environments in JAX
* **Brax** – Physics-based environments in JAX
* **Jumanji** – Board games and optimization problems in JAX
* **Navix** – Grid-world environments in JAX
* **PGX** – Classic board and card game environments in JAX
* **Kinetix** – 2D physics-based control environments in JAX
