Metadata-Version: 2.4
Name: reaxion
Version: 0.1.1
Summary: object-oriented ISM processes
Author-email: Mike Grudić <mike.grudic@gmail.com>
License-Expression: MIT
Project-URL: Homepage, https://github.com/mikegrudic/reaxion
Project-URL: Issues, https://github.com/mikegrudic/reaxion/issues
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: OS Independent
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax
Requires-Dist: sympy
Requires-Dist: numpy
Requires-Dist: sphinx_rtd_theme
Requires-Dist: astropy
Requires-Dist: matplotlib
Dynamic: license-file

# reaxion

[![Python package](https://github.com/mikegrudic/reaxion/actions/workflows/test.yml/badge.svg)](https://github.com/mikegrudic/reaxion/actions/workflows/test.yml)
[![Readthedocs Status][docs-badge]][docs-link]
[![codecov](https://codecov.io/github/mikegrudic/reaxion/graph/badge.svg?token=OWJQMWGABZ)](https://codecov.io/github/mikegrudic/reaxion)

[docs-link]:           https://reaxion.readthedocs.io
[docs-badge]:          https://readthedocs.org/projects/reaxion/badge

`reaxion` is a flexible, object-oriented implementation of systems of ISM microphysics and chemistry equations, with numerical solvers implemented in JAX and interfaces for embedding the equations and their Jacobians in other codes.

## Do we really need yet another ISM code?

`reaxion` might be interesting because it combines two powerful concepts:
1. **Object-oriented implementation of microphysics and chemistry via the `Process` class**, which implements methods for representing physical processes, composing them into a network in a fully-symbolic `sympy` representation. OOP is nice here because if you want to add a new process to `reaxion`, you typically only have to do it in one file. Rate expressions never have to be repeated in-code. Most processes one would want to implement follow very common patterns (e.g. 2-body processes), so class inheritance is also used to minimize new lines of code. 
Once you've constructed your system, `reaxion` can give you the symbolic equations to manipulate and analyze as you please. If you want to solve the equations numerically, `Process` has methods for substituting known values into numerical solvers. It can also automatically generate compilable implementations of the RHS of the system to embed in your choice of simulation code and plug into your choice of solver.
2. **Fast, differentiable implementation of nonlinear algebraic and differential-algebraic equation solvers with JAX**, implemented in its functional programming paradigm (e.g. `reaxion.numerics.newton_rootsolve`). These can achieve excellent numerical throughput running natively on GPUs - in fact, crunching iterates in-place is essentially the best-case application of numerics on GPUs. Differentiability enables sensitivity analysis with respect to all parameters in a single pass, instead of constructing a grid of `N` parameter variations for `N` parameters. This makes it easier in principle to directly answer questions like "How sensitive is this temperature to the abundance of C or the ionization energy of H?", etc.
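
As a rough illustration of the second point, the sketch below runs a Newton iteration in JAX on a toy scalar equation and differentiates the converged root with respect to a parameter. It does not call `reaxion.numerics.newton_rootsolve`; the names `residual` and `newton_solve`, and the cubic equation itself, are assumptions made only for this example.

```python
import jax
import jax.numpy as jnp


def residual(x, p):
    # Toy nonlinear equation f(x; p) = 0 (hypothetical, not a reaxion rate equation).
    return x**3 + p * x - 1.0


def newton_solve(p, x0=1.0, num_iter=20):
    # Fixed number of Newton steps, unrolled so JAX can differentiate through them.
    x = jnp.asarray(x0)
    dfdx = jax.grad(residual, argnums=0)
    for _ in range(num_iter):
        x = x - residual(x, p) / dfdx(x, p)
    return x


# Solve for a whole grid of parameter values at once, and get d(root)/dp as well.
solve_grid = jax.vmap(newton_solve)
sensitivity_grid = jax.vmap(jax.grad(newton_solve))

p_grid = jnp.linspace(0.5, 2.0, 4)
x_sol = solve_grid(p_grid)
sens = sensitivity_grid(p_grid)

# Cross-check against the implicit function theorem: dx/dp = -(df/dp) / (df/dx).
implicit = -x_sol / (3 * x_sol**2 + p_grid)
print(sens)
print(implicit)
```

The two printed arrays should agree to within single-precision round-off, which is the sense in which one differentiable solve replaces a grid of parameter variations.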

## Roadmap

`reaxion` is in an early prototyping phase right now. Here are some things I would eventually like to add:
* Flexible implementation of a reduced network suitable for RHD simulations in GIZMO and potentially other codes.
* Dust and radiation physics: add the dust energy equation and evolution of photon number densities to the network.
* Interfaces to convert from other existing chemistry network formats to the `Process` representation.
* Solver robustness upgrades: thermochemical networks can be quite challenging numerically, due to how steeply terms switch on with increasing `T`. It can be hard to get a solution without good initial guesses.
* If possible, glue interface allowing an existing compiled hydro code to call the JAX solvers on-the-fly.

## Installation

Clone the repo and run `pip install .` from the directory.

# Quickstart: Collisional Ionization Equilibrium

Example of using `reaxion` to solve for collisional ionization equilibrium (CIE) for a hydrogen-helium mixture and plot the ionization states as a function of temperature.


```python
%matplotlib inline
%config InlineBackend.figure_format='retina'
import numpy as np
from matplotlib import pyplot as plt
import sympy as sp
```

## Simple processes
A simple process is defined by a single reaction, with a specified rate.

Let's inspect the structure of a single process, the gas-phase recombination of H+: `H+ + e- -> H + hν` 


```python
from reaxion.processes import CollisionalIonization, GasPhaseRecombination

process = GasPhaseRecombination("H+")
print(f"Name: {process.name}")
print(f"Heating rate coefficient: {process.heat_rate_coefficient}")
print(f"Heating rate per cm^-3: {process.heat}")
print(f"Rate coefficient: {process.rate_coefficient}")
print(f"Recombination rate per cm^-3: {process.rate}")
print(f"RHS of e- number density equation: {process.network['e-']}")
```

    Name: Gas-phase recombination of H+
    Heating rate coefficient: -1.46719838641439e-26*sqrt(T)/((0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
    Heating rate per cm^-3: -1.46719838641439e-26*sqrt(T)*n_H+*n_e-/((0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
    Rate coefficient: 1.41621465870114e-10/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
    Recombination rate per cm^-3: 1.41621465870114e-10*n_H+*n_e-/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
    RHS of e- number density equation: Eq(Derivative(n_e-(t), t), -1.41621465870114e-10*n_H+*n_e-/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252))


Note that all symbolic representations assume CGS units as is standard in ISM physics.
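
Because these quantities are ordinary `sympy` expressions, they can be evaluated with standard `sympy` tools. The sketch below copies the rate coefficient printed above instead of calling `reaxion`, and turns it into a NumPy function of `T`; the variable names (`rate_coefficient`, `alpha`, `T_values`) are chosen for this example only.

```python
import numpy as np
import sympy as sp

T = sp.Symbol("T", positive=True)

# Rate coefficient copied from the printed output above (CGS units: cm^3 s^-1).
rate_coefficient = 1.41621465870114e-10 / (
    sp.sqrt(T)
    * (0.00119216696847702 * sp.sqrt(T) + 1.0) ** 1.748
    * (0.563615123664978 * sp.sqrt(T) + 1.0) ** 0.252
)

# Turn the symbolic expression into a vectorized NumPy function of T.
alpha = sp.lambdify(T, rate_coefficient, modules="numpy")

T_values = np.array([1e3, 1e4, 1e5])
alpha_values = alpha(T_values)
print(alpha_values)  # the coefficient decreases with increasing T
```

At `T = 1e4` K this evaluates to a few times `1e-13` cm^3 s^-1.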

## Composing processes
Now let's define our full network as a sum of simple processes:


```python
processes = [CollisionalIonization(s) for s in ("H", "He", "He+")] + [GasPhaseRecombination(i) for i in ("H+", "He+", "He++")]
system = sum(processes)

system.subprocesses
```




    [Collisional Ionization of H,
     Collisional Ionization of He,
     Collisional Ionization of He+,
     Gas-phase recombination of H+,
     Gas-phase recombination of He+,
     Gas-phase recombination of He++]



Summed processes keep track of all subprocesses, e.g. the total net heating rate is:


```python
system.heat
```




$\displaystyle - \frac{1.55 \cdot 10^{-26} n_{He+} n_{e-}}{T^{0.3647}} - \frac{1.2746917300104 \cdot 10^{-21} \sqrt{T} n_{H} n_{e-} e^{- \frac{157809.1}{T}}}{\frac{\sqrt{10} \sqrt{T}}{1000} + 1} - \frac{1.46719838641439 \cdot 10^{-26} \sqrt{T} n_{H+} n_{e-}}{\left(0.00119216696847702 \sqrt{T} + 1.0\right)^{1.748} \left(0.563615123664978 \sqrt{T} + 1.0\right)^{0.252}} - \frac{9.37661057635428 \cdot 10^{-22} \sqrt{T} n_{He} n_{e-} e^{- \frac{285335.4}{T}}}{\frac{\sqrt{10} \sqrt{T}}{1000} + 1} - \frac{4.9524176975855 \cdot 10^{-22} \sqrt{T} n_{He+} n_{e-} e^{- \frac{631515}{T}}}{\frac{\sqrt{10} \sqrt{T}}{1000} + 1} - \frac{5.86879354565754 \cdot 10^{-26} \sqrt{T} n_{He++} n_{e-}}{\left(0.00119216696847702 \sqrt{T} + 1.0\right)^{1.748} \left(0.563615123664978 \sqrt{T} + 1.0\right)^{0.252}}$



Summing processes also sums all chemical and gas/dust cooling/heating rates. 
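
To make the composition pattern concrete, here is a minimal sketch of how a class can support Python's built-in `sum` while tracking its parts. `ToyProcess` is a hypothetical stand-in written for this example; it is not `reaxion`'s `Process` class, and the symbols `k_ion` and `k_rec` are placeholder coefficients.

```python
import sympy as sp


class ToyProcess:
    """Hypothetical stand-in for a process; not reaxion's `Process` class."""

    def __init__(self, name, heat):
        self.name = name
        self.heat = sp.sympify(heat)  # symbolic net heating rate
        self.subprocesses = [self]    # a simple process is its own only subprocess

    def __add__(self, other):
        # sum() starts from the integer 0, so treat 0 as the additive identity.
        if isinstance(other, int) and other == 0:
            return self
        total = ToyProcess(f"{self.name} + {other.name}", self.heat + other.heat)
        total.subprocesses = self.subprocesses + other.subprocesses
        return total

    __radd__ = __add__

    def __repr__(self):
        return self.name


n_e, n_H, n_Hp = sp.symbols("n_e n_H n_Hp", positive=True)
k_ion, k_rec = sp.symbols("k_ion k_rec", positive=True)  # placeholder coefficients

ionization = ToyProcess("ionization", -k_ion * n_H * n_e)
recombination = ToyProcess("recombination", -k_rec * n_Hp * n_e)

toy_system = sum([ionization, recombination])
print(toy_system.subprocesses)
print(toy_system.heat)
```

Handling `0` in `__add__` and aliasing `__radd__` is what lets `sum(...)` work without an explicit start value.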

## Solving ionization equilibrium

We would like to solve for ionization equilibrium given a temperature $T$ and a total H number density $n_{\rm H,tot}$. We define one dictionary of those input quantities and another of initial guesses for the abundances of the species in the reduced network.


```python
Tgrid = np.logspace(3,6,10**6)
ngrid = np.ones_like(Tgrid) * 100

knowns = {"T": Tgrid, "n_Htot": ngrid}

guesses = {
    "H": 0.5*np.ones_like(Tgrid),
    "He": 1e-5*np.ones_like(Tgrid),
    "He+": 1e-5*np.ones_like(Tgrid)
}
```

Note that by default, the solver only directly solves for $n_{\rm H}$, $n_{\rm He}$ and $n_{\rm He+}$ because $n_{\rm H+}$, $n_{\rm He++}$, and $n_{\rm e-}$ are eliminated by conservation equations. So we only need initial guesses for those 3 quantities. By default the solver takes abundances $x_i = n_i / n_{\rm H,tot}$ as inputs and outputs.
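
The elimination can be sketched with plain `sympy`. The conservation equations below are written out by hand for this example and are not extracted from `reaxion`; in particular, `y` is assumed to be the He abundance per H nucleus, and the symbol names (`x_Hp`, `x_Hep`, `x_Hepp`) are chosen here for readability.

```python
import sympy as sp

x_H, x_Hp, x_He, x_Hep, x_Hepp, x_e, y = sp.symbols(
    "x_H x_Hp x_He x_Hep x_Hepp x_e y"
)

conservation = [
    sp.Eq(x_H + x_Hp, 1),                   # H nuclei
    sp.Eq(x_He + x_Hep + x_Hepp, y),        # He nuclei (y assumed: He per H nucleus)
    sp.Eq(x_e, x_Hp + x_Hep + 2 * x_Hepp),  # charge neutrality
]

# Express the three eliminated abundances in terms of the three solved ones.
eliminated = sp.solve(conservation, [x_Hp, x_Hepp, x_e], dict=True)[0]
for symbol, expression in eliminated.items():
    print(symbol, "=", expression)
```

With these assumptions the electron abundance reduces to `1 - x_H - 2*x_He - x_Hep + 2*y`, so only three unknowns remain.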

The `solve` method calls the JAX solver and computes the solution:


```python
sol = system.solve(knowns, guesses,tol=1e-3)
print(sol)
```

    {'He': Array([9.2546351e-02, 9.2546351e-02, 9.2546351e-02, ..., 2.7493625e-09,
           2.7493037e-09, 2.7492442e-09], dtype=float32), 'H': Array([9.9999994e-01, 9.9999994e-01, 9.9999994e-01, ..., 6.0612075e-07,
           6.0611501e-07, 6.0610921e-07], dtype=float32), 'He+': Array([3.1222404e-13, 3.1222396e-13, 3.1222374e-13, ..., 7.6922206e-06,
           7.6921306e-06, 7.6920396e-06], dtype=float32), 'He++': Array([0.        , 0.        , 0.        , ..., 0.09253865, 0.09253865,
           0.09253865], dtype=float32), 'H+': Array([5.9604645e-08, 5.9604645e-08, 5.9604645e-08, ..., 9.9999940e-01,
           9.9999940e-01, 9.9999940e-01], dtype=float32), 'e-': Array([5.9604957e-08, 5.9604957e-08, 5.9604957e-08, ..., 1.1850843e+00,
           1.1850843e+00, 1.1850843e+00], dtype=float32)}



```python
for i, xi in sorted(sol.items()):
    plt.loglog(Tgrid, xi, label=i)
plt.legend(labelspacing=0)
plt.ylabel("$x_i$")
plt.xlabel("T (K)")
plt.ylim(1e-4,3)
```




    (0.0001, 3)




    
![png](CIE_files/CIE_15_1.png)
    


## Generating code

Suppose you just want the RHS of the system you're solving, or its Jacobian, because you have a better solver and/or want to embed these equations in some old C or Fortran code without any dependencies. You can do that too with `generate_code`.


```python
print(system.generate_code(('H','He','He+'),language='c'))
```

    # Computes the RHS function and Jacobianto solve for [x_He, x_H, x_Heplus]
    
    # INDEX CONVENTION: (0: x_He) (1: x_H) (2: x_Heplus)
    
    x0 = 1.0/T; 
    x1 = sqrt(T); 
    x2 = pow(n_Htot, 2); 
    x3 = 1.0/((1.0/1000.0)*sqrt(10)*x1 + 1); 
    x4 = x1*x2*x3; 
    x5 = x4*exp(-285335.40000000002*x0); 
    x6 = x5*x_He; 
    x7 = x_H - 1; 
    x8 = -x7 - 2*x_He - x_Heplus + 2*y; 
    x9 = 2.3800000000000001e-11*x8; 
    x10 = 1.0/x1; 
    x11 = x2*(0.0019*pow(T, -1.5)*(1 + 0.29999999999999999*exp(-94000.0*x0))*exp(-470000.0*x0) + 1.9324160622805846e-10*x10*pow(0.00016493478118851054*x1 + 1.0, -1.7891999999999999)*pow(4.8416074481177231*x1 + 1.0, -0.21079999999999999)); 
    x12 = x11*x_Heplus; 
    x13 = -x12*x8 + x6*x9; 
    x14 = exp(-157809.10000000001*x0); 
    x15 = x14*x4; 
    x16 = x15*x_H; 
    x17 = 5.8500000000000005e-11*x16; 
    x18 = -x7; 
    x19 = pow(0.0011921669684770192*x1 + 1.0, -1.748); 
    x20 = pow(0.56361512366497779*x1 + 1.0, -0.252); 
    x21 = -x_He - x_Heplus + y; 
    x22 = x10*x2; 
    x23 = x22*pow(0.00059608348423850961*x1 + 1.0, -1.748)*pow(0.2818075618324889*x1 + 1.0, -0.252); 
    x24 = 5.664858634804579e-10*x23; 
    x25 = x24*x8; 
    x26 = exp(-631515*x0); 
    x27 = x26*x4; 
    x28 = 4.7600000000000002e-11*x6; 
    x29 = x5*x9; 
    x30 = 2*x12; 
    x31 = -x12 + 2.3800000000000001e-11*x6; 
    x32 = x11*x8 + x31; 
    x33 = x19*x20*x22; 
    x34 = x18*x33; 
    x35 = 1.4162146587011448e-10*x34; 
    x36 = -5.68e-12*x1*x2*x26*x3*x_Heplus + x21*x24;
    
    rhs_result[0] = -x13;
    rhs_result[1] = 1.4162146587011448e-10*x10*x18*x19*x2*x20*x8 - x17*x8;
    rhs_result[2] = x13 + x21*x25 - 5.68e-12*x27*x8*x_Heplus;
    
    jac_result[0] = x28 - x29 - x30;
    jac_result[1] = x31;
    jac_result[2] = x32;
    jac_result[3] = 1.1700000000000001e-10*x16 - 2.8324293174022895e-10*x34;
    jac_result[4] = 5.8500000000000005e-11*x1*x14*x2*x3*x_H - 5.8500000000000005e-11*x15*x8 - 1.4162146587011448e-10*x33*x8 - x35;
    jac_result[5] = x17 - x35;
    jac_result[6] = 1.136e-11*x1*x2*x26*x3*x_Heplus - 1.1329717269609158e-9*x21*x23 - x25 - x28 + x29 + x30;
    jac_result[7] = -x31 - x36;
    jac_result[8] = -x25 - 5.68e-12*x27*x8 - x32 - x36;


Let's break down what happened there. First, `reaxion` generates the symbolic functions needed to solve the system, just as it does before solving the system with its own solver:


```python
func, jac, _ = system.network.solver_functions(('H','He','He+'),return_jac=True)
```

Here `func` represents the set of functions $f_i$ such that $f_i = 0$ solves the system, and `jac` encodes the Jacobian $J_{ij} = \frac{\partial f_i}{\partial x_j}$ of `func` with respect to the solved variables. Note that the two share many common subexpressions. Before implementing them, one should apply common subexpression elimination (CSE) to simplify the code and evaluate the functions more efficiently:


```python
cse, (cse_func, cse_jac) = sp.cse((sp.Matrix(func),sp.Matrix(jac)))

cse
```




    [(x0, 1/T),
     (x1, sqrt(T)),
     (x2, n_Htot**2),
     (x3, 1/(sqrt(10)*x1/1000 + 1)),
     (x4, x1*x2*x3),
     (x5, x4*exp(-285335.4*x0)),
     (x6, x5*x_He),
     (x7, x_H - 1),
     (x8, -x7 - 2*x_He - x_He+ + 2*y),
     (x9, 2.38e-11*x8),
     (x10, 1/x1),
     (x11,
      x2*(0.0019*(1 + 0.3*exp(-94000.0*x0))*exp(-470000.0*x0)/T**1.5 + 1.93241606228058e-10*x10/((0.000164934781188511*x1 + 1.0)**1.7892*(4.84160744811772*x1 + 1.0)**0.2108))),
     (x12, x11*x_He+),
     (x13, -x12*x8 + x6*x9),
     (x14, exp(-157809.1*x0)),
     (x15, x14*x4),
     (x16, x15*x_H),
     (x17, 5.85e-11*x16),
     (x18, -x7),
     (x19, (0.00119216696847702*x1 + 1.0)**(-1.748)),
     (x20, (0.563615123664978*x1 + 1.0)**(-0.252)),
     (x21, -x_He - x_He+ + y),
     (x22, x10*x2),
     (x23,
      x22/((0.00059608348423851*x1 + 1.0)**1.748*(0.281807561832489*x1 + 1.0)**0.252)),
     (x24, 5.66485863480458e-10*x23),
     (x25, x24*x8),
     (x26, exp(-631515*x0)),
     (x27, x26*x4),
     (x28, 4.76e-11*x6),
     (x29, x5*x9),
     (x30, 2*x12),
     (x31, -x12 + 2.38e-11*x6),
     (x32, x11*x8 + x31),
     (x33, x19*x20*x22),
     (x34, x18*x33),
     (x35, 1.41621465870114e-10*x34),
     (x36, -5.68e-12*x1*x2*x26*x3*x_He+ + x21*x24)]




```python
cse_func
```




$\displaystyle \left[\begin{matrix}- x_{13}\\1.41621465870114 \cdot 10^{-10} x_{10} x_{18} x_{19} x_{2} x_{20} x_{8} - x_{17} x_{8}\\x_{13} + x_{21} x_{25} - 5.68 \cdot 10^{-12} x_{27} x_{8} x_{He+}\end{matrix}\right]$




```python
cse_jac
```




$\displaystyle \left[\begin{matrix}x_{28} - x_{29} - x_{30} & x_{31} & x_{32}\\1.17 \cdot 10^{-10} x_{16} - 2.83242931740229 \cdot 10^{-10} x_{34} & 5.85 \cdot 10^{-11} x_{1} x_{14} x_{2} x_{3} x_{H} - 5.85 \cdot 10^{-11} x_{15} x_{8} - 1.41621465870114 \cdot 10^{-10} x_{33} x_{8} - x_{35} & x_{17} - x_{35}\\1.136 \cdot 10^{-11} x_{1} x_{2} x_{26} x_{3} x_{He+} - 1.13297172696092 \cdot 10^{-9} x_{21} x_{23} - x_{25} - x_{28} + x_{29} + x_{30} & - x_{31} - x_{36} & - x_{25} - 5.68 \cdot 10^{-12} x_{27} x_{8} - x_{32} - x_{36}\end{matrix}\right]$



These expressions can then be converted to the syntax of the code you wish to embed them in:


```python
from sympy.codegen.ast import Assignment
for expr in cse:
    print(sp.ccode(Assignment(*expr),standard='c99'))

rhs_result = sp.MatrixSymbol('rhs_result', len(func), 1)
jac_result = sp.MatrixSymbol('jac_result', len(func),len(func))
print()
print(sp.ccode(Assignment(rhs_result, cse_func),standard='c99'))
print()
print(sp.ccode(Assignment(jac_result, cse_jac),standard='c99'))
```

    x0 = 1.0/T;
    x1 = sqrt(T);
    x2 = pow(n_Htot, 2);
    x3 = 1.0/((1.0/1000.0)*sqrt(10)*x1 + 1);
    x4 = x1*x2*x3;
    x5 = x4*exp(-285335.40000000002*x0);
    x6 = x5*x_He;
    x7 = x_H - 1;
    x8 = -x7 - 2*x_He - x_He+ + 2*y;
    x9 = 2.3800000000000001e-11*x8;
    x10 = 1.0/x1;
    x11 = x2*(0.0019*pow(T, -1.5)*(1 + 0.29999999999999999*exp(-94000.0*x0))*exp(-470000.0*x0) + 1.9324160622805846e-10*x10*pow(0.00016493478118851054*x1 + 1.0, -1.7891999999999999)*pow(4.8416074481177231*x1 + 1.0, -0.21079999999999999));
    x12 = x11*x_He+;
    x13 = -x12*x8 + x6*x9;
    x14 = exp(-157809.10000000001*x0);
    x15 = x14*x4;
    x16 = x15*x_H;
    x17 = 5.8500000000000005e-11*x16;
    x18 = -x7;
    x19 = pow(0.0011921669684770192*x1 + 1.0, -1.748);
    x20 = pow(0.56361512366497779*x1 + 1.0, -0.252);
    x21 = -x_He - x_He+ + y;
    x22 = x10*x2;
    x23 = x22*pow(0.00059608348423850961*x1 + 1.0, -1.748)*pow(0.2818075618324889*x1 + 1.0, -0.252);
    x24 = 5.664858634804579e-10*x23;
    x25 = x24*x8;
    x26 = exp(-631515*x0);
    x27 = x26*x4;
    x28 = 4.7600000000000002e-11*x6;
    x29 = x5*x9;
    x30 = 2*x12;
    x31 = -x12 + 2.3800000000000001e-11*x6;
    x32 = x11*x8 + x31;
    x33 = x19*x20*x22;
    x34 = x18*x33;
    x35 = 1.4162146587011448e-10*x34;
    x36 = -5.68e-12*x1*x2*x26*x3*x_He+ + x21*x24;
    
    rhs_result[0] = -x13;
    rhs_result[1] = 1.4162146587011448e-10*x10*x18*x19*x2*x20*x8 - x17*x8;
    rhs_result[2] = x13 + x21*x25 - 5.68e-12*x27*x8*x_He+;
    
    jac_result[0] = x28 - x29 - x30;
    jac_result[1] = x31;
    jac_result[2] = x32;
    jac_result[3] = 1.1700000000000001e-10*x16 - 2.8324293174022895e-10*x34;
    jac_result[4] = 5.8500000000000005e-11*x1*x14*x2*x3*x_H - 5.8500000000000005e-11*x15*x8 - 1.4162146587011448e-10*x33*x8 - x35;
    jac_result[5] = x17 - x35;
    jac_result[6] = 1.136e-11*x1*x2*x26*x3*x_He+ - 1.1329717269609158e-9*x21*x23 - x25 - x28 + x29 + x30;
    jac_result[7] = -x31 - x36;
    jac_result[8] = -x25 - 5.68e-12*x27*x8 - x32 - x36;
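
Note that the raw `sp.ccode` output above still contains names like `x_He+`, which are not valid C identifiers, whereas the `generate_code` output uses `x_Heplus`. One way to get there by hand is to rename the symbols before printing. The helper below is a sketch with a hypothetical name (`sanitize_symbols`); mapping `+` to `plus` follows the `generate_code` output, while mapping `-` to `minus` is an assumption made for this example.

```python
import sympy as sp


def sanitize_symbols(expr):
    # Replace characters that are illegal in C identifiers within symbol names.
    # "+" -> "plus" mirrors the generate_code output; "-" -> "minus" is assumed.
    replacements = {
        s: sp.Symbol(s.name.replace("+", "plus").replace("-", "minus"))
        for s in expr.free_symbols
    }
    return expr.xreplace(replacements)


expr = 2 * sp.Symbol("x_He+") * sp.Symbol("n_e-")
code = sp.ccode(sanitize_symbols(expr), standard="c99")
print(code)
```

The same substitution can be applied to each CSE pair and to `cse_func` and `cse_jac` before passing them to `sp.ccode`.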

