Metadata-Version: 2.4
Name: xmmutablemap
Version: 0.2.0
Summary: Immutable Map, compatible with JAX & Equinox
Project-URL: Homepage, https://github.com/GalacticDynamics/xmmutablemap
Project-URL: Bug Tracker, https://github.com/GalacticDynamics/xmmutablemap/issues
Project-URL: Discussions, https://github.com/GalacticDynamics/xmmutablemap/discussions
Project-URL: Changelog, https://github.com/GalacticDynamics/xmmutablemap/releases
Author-email: Galactic Dynamics Maintainers <nstarman@users.noreply.github.com>
License: Copyright 2024 Galactic Dynamics Maintainers
        
        Permission is hereby granted, free of charge, to any person obtaining a copy of
        this software and associated documentation files (the "Software"), to deal in
        the Software without restriction, including without limitation the rights to
        use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
        of the Software, and to permit persons to whom the Software is furnished to do
        so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Classifier: Development Status :: 1 - Planning
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering
Classifier: Typing :: Typed
Requires-Python: >=3.10
Requires-Dist: jax
Requires-Dist: typing-extensions>=4.8
Description-Content-Type: text/markdown

<h1 align='center'> xmmutablemap </h1>
<h3 align="center"><code>JAX</code>-compatible Immutable Mapping</h3>

JAX prefers immutable objects, but neither Python nor JAX provides an immutable
dictionary. 😢 <br> This repository defines a lightweight immutable map (a
`Mapping` rather than a full `dict`) that JAX understands as a PyTree. 🎉 🕶️

## Installation

[![PyPI platforms][pypi-platforms]][pypi-link]
[![PyPI version][pypi-version]][pypi-link]

```bash
pip install xmmutablemap
```

<details>
  <summary>using <code>uv</code></summary>

```bash
uv add xmmutablemap
```

</details>
<details>
  <summary>from source, using pip</summary>

```bash
pip install git+https://github.com/GalacticDynamics/xmmutablemap.git
```

</details>
<details>
  <summary>building from source</summary>

```bash
cd /path/to/parent
git clone https://github.com/GalacticDynamics/xmmutablemap.git
cd xmmutablemap
pip install -e .  # editable mode
```

</details>

## Documentation

`xmmutablemap` provides the class `ImmutableMap`, which is a full implementation
of
[Python's `Mapping` ABC](https://docs.python.org/3/library/collections.abc.html#collections-abstract-base-classes).
If you've used a `dict`, then you already know how to use `ImmutableMap`! What
`ImmutableMap` adds is 1) immutability (and related benefits like hashability)
and 2) compatibility with `JAX`.

```python
from xmmutablemap import ImmutableMap

print(ImmutableMap(a=1, b=2, c=3))
# ImmutableMap({'a': 1, 'b': 2, 'c': 3})

print(ImmutableMap({"a": 1, "b": 2.0, "c": "3"}))
# ImmutableMap({'a': 1, 'b': 2.0, 'c': '3'})
```

### JAX Integration

One of the key benefits of `ImmutableMap` is its compatibility with JAX. Because
it is immutable and hashable, it can be used where a regular `dict` is rejected,
for example as a dataclass field default or as a static argument to `jax.jit`.

#### Using ImmutableMap as a Default in JAX Dataclasses

Here's an example showing how `ImmutableMap` can be used as a default value in a
dataclass, which is particularly useful with JAX:

```python
import functools
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from xmmutablemap import ImmutableMap


@functools.partial(
    jax.tree_util.register_dataclass, data_fields=["params"], meta_fields=["batch_size"]
)
@dataclass(frozen=True)
class Config:
    """Configuration with immutable default parameters."""

    # Allowed as a default: dataclasses reject a dict here, but ImmutableMap is hashable
    params: ImmutableMap[str, float] = ImmutableMap(
        learning_rate=0.001, momentum=0.9, weight_decay=1e-4
    )
    batch_size: int = 32


# Config is registered as a PyTree, so it can be passed directly to a jitted function
@jax.jit
def train_step(config: Config, data: jnp.ndarray) -> jnp.ndarray:
    """Example training step that uses config parameters."""
    lr = config.params["learning_rate"]
    return data * lr


# Run the jitted step with the default configuration
config = Config()
data = jnp.array([1.0, 2.0, 3.0])
result = train_step(config, data)
print(f"Result: {result}")
# Result: [0.001 0.002 0.003]
```

#### Key Benefits for JAX

- **Immutability**: Once created, `ImmutableMap` cannot be modified, preventing
  accidental mutations that could break JAX's functional programming model
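- **Hashability**: `ImmutableMap` instances can be hashed, so they can be used
  where JAX requires hashable values, such as static arguments to `jax.jit`,
  which become part of its compilation cache key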
- **PyTree Support**: `ImmutableMap` is registered as a JAX PyTree, so it works
  with JAX transformations such as `jit`, `grad`, and `vmap`
- **Safe Defaults**: Can be used as default values in dataclasses without the
  typical pitfalls of mutable defaults

## Development

[![Actions Status][actions-badge]][actions-link]

We welcome contributions!

<!-- prettier-ignore-start -->
[actions-badge]:            https://github.com/GalacticDynamics/xmmutablemap/workflows/CI/badge.svg
[actions-link]:             https://github.com/GalacticDynamics/xmmutablemap/actions
[pypi-link]:                https://pypi.org/project/xmmutablemap/
[pypi-platforms]:           https://img.shields.io/pypi/pyversions/xmmutablemap
[pypi-version]:             https://img.shields.io/pypi/v/xmmutablemap

<!-- prettier-ignore-end -->
