# Approximate Distribution Metrics

`ApproxDistributionMetricsModule` compares the distribution of terminal states
generated by recent policies with the ground-truth distribution. Terminal states are collected in a replay buffer, and the empirical distribution is computed from its contents.

## Intuition

- Use this metric when the environment can enumerate its terminal states and provide the true distribution over them.
- Prefer the exact-distribution metric if its time overhead is acceptable.
- For environments that are too large or cannot be enumerated, use ELBO instead, since distribution-based metrics are not practical there.
- Choose the replay buffer size carefully so that it covers the policy's support: a buffer that is too small yields a noisy metric,
  while one that is too large reacts slowly to policy updates.
- Add `"2d_marginal_distribution"` to `metrics` to view a coarse heatmap.
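
The buffer-size trade-off can be illustrated with a minimal plain-Python sketch. It assumes the buffer behaves like a fixed-size FIFO queue, which this page does not confirm, and it uses a made-up stream of terminal-state indices; `empirical_distribution` is a hypothetical helper, not part of `gfnx`.

```python
from collections import Counter, deque


def empirical_distribution(buffer, num_states):
    """Normalized visit counts over state indices 0..num_states-1."""
    counts = Counter(buffer)
    n = len(buffer)
    return [counts[s] / n for s in range(num_states)]


# Hypothetical stream: an early policy only reaches state 0,
# a later policy alternates between states 1 and 2.
stream = [0] * 6 + [1, 2] * 3

small = deque(maxlen=4)   # assumed FIFO: keeps only the 4 newest states
large = deque(maxlen=12)  # large enough to keep the whole stream
for s in stream:
    small.append(s)
    large.append(s)

small_dist = empirical_distribution(small, num_states=3)
large_dist = empirical_distribution(large, num_states=3)

print(small_dist)  # reflects only the most recent policy
print(large_dist)  # still dominated by samples from the old policy
```

Under these assumptions, the small buffer tracks the latest policy but is estimated from very few samples, while the large buffer averages over stale samples from earlier policies.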

## Key parameters

- `metrics`: List of metric names to compute; choose from `{"tv", "kl", "jsd", "2d_marginal_distribution"}`.
- `env`: Enumerable environment for which to compute metrics.
- `buffer_size`: Maximum number of states to store in the replay buffer for empirical distribution computation.
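
For reference, here is a minimal sketch of the textbook definitions behind the `"tv"`, `"kl"`, and `"jsd"` names, written in plain Python over two probability vectors. It is not the module's implementation: the function names, the use of natural logarithms, and the argument order of the KL divergence are assumptions, since this page does not state which direction or log base the module uses.

```python
import math


def tv(p, q):
    """Total variation distance: half of the L1 distance."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def kl(p, q):
    """KL(p || q) in nats, using the convention 0 * log(0 / q) = 0.

    Returns infinity if q has zero mass where p has positive mass.
    """
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0.0:
            continue
        if qi == 0.0:
            return math.inf
        total += pi * math.log(pi / qi)
    return total


def jsd(p, q):
    """Jensen-Shannon divergence: average KL to the midpoint mixture."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


# Illustrative values only: a true distribution over three terminal
# states and an empirical one that never visited the third state.
true_dist = [0.5, 0.25, 0.25]
empirical = [0.5, 0.5, 0.0]

print("tv :", tv(true_dist, empirical))
print("kl :", kl(empirical, true_dist))  # finite: true_dist has full support
print("kl':", kl(true_dist, empirical))  # infinite: a state is missing
print("jsd:", jsd(true_dist, empirical))
```

The example shows why the direction of KL matters when the buffer has not yet covered every terminal state, and why TV and JSD, which stay bounded, can be easier to read early in training.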

## Quick start

> **Environment requirement:** you must be able to enumerate or sample the terminal distribution and capture terminal states during rollouts (for example via `gfnx.utils.forward_rollout`) before updating the metric.

```python
import jax
import gfnx

env = gfnx.HypergridEnvironment(reward_module=gfnx.EasyHypergridRewardModule())
env_params = env.init(jax.random.PRNGKey(0))

metrics = gfnx.metrics.ApproxDistributionMetricsModule(
    metrics=["tv", "kl", "jsd"],
    env=env,
    buffer_size=10_000,
)

state = metrics.init(jax.random.PRNGKey(1), metrics.InitArgs(env_params=env_params))

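# During training: add every batch of terminal states you collect.
# `trajectory` is not defined in this snippet; it is assumed to come from
# your own rollout code (for example via `gfnx.utils.forward_rollout`).
state = metrics.update(
    state,
    jax.random.PRNGKey(2),
    metrics.UpdateArgs(states=trajectory.final_env_state),  # terminal states from your rollout
)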

# When you want a report: rebuild the empirical distribution and read the metrics.
state = metrics.process(
    state,
    jax.random.PRNGKey(3),
    metrics.ProcessArgs(env_params=env_params),
)
scores = metrics.get(state)
print(scores["tv"], scores["kl"], scores["jsd"])
```
