Metadata-Version: 2.1
Name: absl_extra
Version: 0.0.3.dev18
Summary: A wrapper to run and monitor absl app.
Author-email: Artem Sereda <artem.sereda.tub@gmail.com>
Maintainer-email: Artem Sereda <artem.sereda.tub@gmail.com>
Requires-Python: >=3.8
Description-Content-Type: text/markdown
Classifier: Development Status :: 3 - Alpha
Requires-Dist: absl_py
Requires-Dist: toolz
Requires-Dist: typing_extensions; python_version < '3.10'
Requires-Dist: black ; extra == "dev"
Requires-Dist: pytest ; extra == "dev"
Requires-Dist: chex ; extra == "dev"
Requires-Dist: absl_extra[mongo,ml_collections,slack,tensorflow,jax,flax] ; extra == "dev"
Requires-Dist: absl_extra[jax] ; extra == "flax"
Requires-Dist: flax ; extra == "flax"
Requires-Dist: clu ; extra == "flax"
Requires-Dist: jaxtyping ; extra == "jax"
Requires-Dist: jax ; extra == "jax"
Requires-Dist: jaxlib ; extra == "jax"
Requires-Dist: ml_collections ; extra == "ml_collections"
Requires-Dist: pymongo ; extra == "mongo"
Requires-Dist: slack_sdk ; extra == "slack"
Requires-Dist: tensorflow ; extra == "tensorflow" and ( sys_platform == 'linux')
Requires-Dist: tensorflow_macos ; extra == "tensorflow" and ( sys_platform == 'darwin')
Project-URL: Homepage, https://github.com/aaarrti/absl_extra
Provides-Extra: dev
Provides-Extra: flax
Provides-Extra: jax
Provides-Extra: ml_collections
Provides-Extra: mongo
Provides-Extra: slack
Provides-Extra: tensorflow

### ABSL-Extra

A collection of utilities I commonly use for running my experiments.
It will:
- Notify when execution starts, finishes, or fails.
  - By default, the notifier just logs these events to `stdout`.
  - I prefer to receive them in Slack, though (see the example below).
- Log the parsed CLI flags from `absl.flags.FLAGS` and the config values from `config_file:get_config()`.
- Inject a `pymongo.collection.Collection` if the `mongo_config` kwarg is provided.
- Select the registered task to run based on the `--task=` CLI argument.
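The start/finish/failure notifications described above amount to a wrapper around the task function. The following is a minimal sketch of that idea under stated assumptions, not absl_extra's implementation: `LoggingNotifier`, its `notify_job_*` methods, and `run_with_notifications` are hypothetical names, and the standard-library `logging` module writing to `stdout` stands in for the default notifier.

```python
import logging
import sys
from typing import Callable, List, TypeVar

T = TypeVar("T")

# Hypothetical stand-in for the default notifier: events are only logged to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)


class LoggingNotifier:
    """Hypothetical notifier that logs job lifecycle events (not absl_extra's API)."""

    def __init__(self) -> None:
        # Record emitted events so callers can inspect what was reported.
        self.events: List[str] = []

    def notify_job_started(self, name: str) -> None:
        self.events.append(f"started:{name}")
        logging.info("Job %s started.", name)

    def notify_job_finished(self, name: str) -> None:
        self.events.append(f"finished:{name}")
        logging.info("Job %s finished.", name)

    def notify_job_failed(self, name: str, exc: BaseException) -> None:
        self.events.append(f"failed:{name}")
        logging.error("Job %s failed: %r", name, exc)


def run_with_notifications(name: str, fn: Callable[[], T], notifier: LoggingNotifier) -> T:
    """Call fn, notifying on start and then on either success or failure."""
    notifier.notify_job_started(name)
    try:
        result = fn()
    except Exception as exc:
        notifier.notify_job_failed(name, exc)
        raise  # Re-raise so the failure still propagates to the caller.
    notifier.notify_job_finished(name)
    return result
```

A Slack-backed notifier would implement the same three hooks but post messages instead of logging. Re-raising inside the wrapper means a failed run is reported and still exits with an error.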

Minimal example

```python
import os
from pymongo.collection import Collection
from ml_collections import ConfigDict
from absl import logging
import tensorflow as tf

from absl_extra import tf_utils, tasks, notifier


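# Register `main` as a task. Passing `mongo_config` injects a pymongo Collection
# as the `db` argument; the Slack notifier reports start, finish, and failure.
@tasks.register_task(
    mongo_config=dict(uri=os.environ["MONGO_URI"], db_name="my_project", collection="experiment_1"),
    notifier=notifier.SlackNotifier(slack_token=os.environ["SLACK_BOT_TOKEN"], channel_id=os.environ["CHANNEL_ID"]),
)
@tf_utils.requires_gpu
def main(config: ConfigDict, db: Collection) -> None:
    # Enable mixed precision only where the hardware supports it.
    if tf_utils.supports_mixed_precision():
        tf.keras.mixed_precision.set_global_policy("mixed_float16")

    # Run the heavy work inside the GPU distribution strategy's scope.
    with tf_utils.make_gpu_strategy().scope():
        logging.info("Doing some heavy lifting...")


if __name__ == "__main__":
    # Dispatch to the task selected by the --task= CLI argument.
    tasks.run()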
```
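Task selection via `--task=` can be pictured as a name-to-function registry. The sketch below is a hypothetical, simplified stand-in for the `tasks` module used above, not its real code: this `register_task` takes only a task name (defaulting to `main`, an assumption), `run` uses `argparse` instead of `absl.flags`, and no config or database handle is injected.

```python
import argparse
from typing import Callable, Dict, List, Optional

TaskFn = Callable[[], object]

# Hypothetical registry mapping task names to callables (not absl_extra's internals).
_TASKS: Dict[str, TaskFn] = {}


def register_task(name: str = "main") -> Callable[[TaskFn], TaskFn]:
    """Return a decorator that stores the decorated function under `name`."""

    def decorator(fn: TaskFn) -> TaskFn:
        if name in _TASKS:
            raise ValueError(f"Task {name!r} is already registered.")
        _TASKS[name] = fn
        return fn

    return decorator


def run(argv: Optional[List[str]] = None) -> object:
    """Parse --task from argv (default 'main') and call the matching task."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", default="main")
    # parse_known_args leaves flags meant for other components (e.g. a config path) alone.
    args, _unknown = parser.parse_known_args(argv)
    if args.task not in _TASKS:
        raise ValueError(f"Unknown task {args.task!r}; registered: {sorted(_TASKS)}")
    return _TASKS[args.task]()


@register_task()
def train() -> str:
    return "training"


@register_task(name="evaluate")
def evaluate() -> str:
    return "evaluating"
```

For example, `run(["--task=evaluate"])` calls `evaluate`, while `run([])` falls back to the function registered under `main`. Using `parse_known_args` lets unrelated flags pass through untouched.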
