# Introduction
[Release Notes](NEWS.md)

*Symjit* is a lightweight just-in-time (JIT) compiler that directly translates *sympy* expressions into machine code without using a separate library such as LLVM. Its main utility is to generate fast numerical functions to feed into various numerical solvers provided by the NumPy/SciPy ecosystem, including numerical integration routines and ordinary differential equation (ODE) solvers.

*Symjit* has two different code-generating backends. The default is a Rust library with minimal external dependencies. The second backend is written in plain Python, relying solely on the Python standard library and NumPy. Both backends generate AMD64 (also known as x86-64) and ARM64 (also known as aarch64) machine code on Linux, Windows, and Darwin (macOS) platforms. Support for further architectures (e.g., RISC-V) is planned.

As of version 2.0, the Rust backend generates AVX-compatible code by default on x86-64/AMD64 processors, but it can fall back to SSE2 instructions if the processor does not support AVX or if explicitly requested by passing `ty='amd-sse'` to the compile functions (see below). In version 1, the default was SSE instructions. SSE2 was introduced in 2000 and is part of the baseline x86-64 instruction set, so every 64-bit x86-64 processor supports it. Intel introduced the AVX instruction set in 2011; therefore, most processors in current use support it. The Python backend uses only AVX instructions.

On ARM64 processors, both the Rust and Python backends generate code for the aarch64 instruction set. ARM32 and IA32 are not supported.

[FuncBuilder](https://github.com/siravan/funcbuilder) is a companion package that provides a more general code generator akin to [llvmlite](https://github.com/numba/llvmlite). It is currently in the early stages of development.

# Installing symjit

Installing `symjit` from the `conda-forge` channel can be achieved by adding `conda-forge` to your channels with:

```
conda config --add channels conda-forge
conda config --set channel_priority strict
```

Once the `conda-forge` channel has been enabled, `symjit` can be installed with `conda`:

```
conda install symjit
```

or with `mamba`:

```
mamba install symjit
```

It is possible to list all of the versions of `symjit` available on your platform with `conda`:

```
conda search symjit --channel conda-forge
```

or with `mamba`:

```
mamba search symjit --channel conda-forge
```

Alternatively, `mamba repoquery` may provide more information:

```
# Search all versions available on your platform:
mamba repoquery search symjit --channel conda-forge

# List packages depending on `symjit`:
mamba repoquery whoneeds symjit --channel conda-forge

# List dependencies of `symjit`:
mamba repoquery depends symjit --channel conda-forge
```

You can also install *symjit* from PyPI using pip:

```
python -m pip install symjit
```

However, the pip package may not include the correct binary Rust backend for every platform, so the conda-forge install is preferable. Alternatively, you can install *symjit* from source by cloning [symjit](https://github.com/siravan/symjit) into a `symjit` folder and then running

```
cd symjit
python -m pip install .
```

For this last option, you need a working Rust compiler and toolchain.

# Tutorial

## `compile_func`: a fast substitute for `lambdify`

*symjit* is invoked by calling one of the `compile_*` functions. The most basic is `compile_func`, which behaves similarly to the sympy `lambdify` function. While `lambdify` translates sympy expressions into regular Python functions, which in turn call numpy functions, `compile_func` returns a callable object of type `Func`, which is a thin wrapper over the JIT code generated by the backends.

A simple example is

```python
import numpy as np
from symjit import compile_func
from sympy import symbols

x, y = symbols('x y')
f = compile_func([x, y], [x+y, x*y])
assert(np.all(f(3, 5) == [8., 15.]))
```

`compile_*` functions support these [operators and functions](FUNCTIONS.md).


`compile_func` takes two mandatory arguments as `compile_func(states, eqs)`. The first one, `states`, is a list or tuple of sympy symbols. The second argument, `eqs`, is a list, a tuple, or a single expression. We can think of `states` and `eqs` as corresponding to the function signature and body, respectively.

If `states` has only one element, it can be passed directly rather than as a list. Similar to sympy `lambdify`, the output follows the form of the second argument to `compile_func`. Therefore, if `f = compile_func([x, y], [x+y, x*y])`, then `f(2, 3)` returns a list. On the other hand, if `f = compile_func([x, y], (x+y, x*y))`, then `f(2, 3)` returns the tuple `(5, 6)`. The third form is a single expression, such as `f = compile_func([x, y], sin(x+y))`, in which case the output is a scalar.

In addition, `compile_func` accepts a named argument `params`, which is a list of symbolic parameters. The output of `compile_func`, say `f`, is a callable object of type `Func`. The signature of `f` is `f(x_1,...,x_n,p_1,...,p_m)`, where `x`s are the state variables and `p`s are the parameters. Therefore, `n = len(states)` and `m = len(params)`. For example,

```python
import numpy as np
from sympy import symbols
from symjit import compile_func

x, y, a = symbols('x y a')
f = compile_func([x, y], [(x+y)**a], params=[a])
assert(np.all(f(3., 5., 2.) == [64.]))  # 2. is the value of parameter a
```

By default, `compile_func` uses the Rust backend. However, we can force the use of the Python backend by passing `backend='python'` to `compile_func`. Moreover, if the binary library containing the Rust backend is unavailable or incompatible, symjit automatically switches to the Python backend.

`compile_func` is useful for generating functions to pass to numerical integration (quadrature) routines. The following example is adapted from the SciPy documentation:

```python
import numpy as np
from scipy.integrate import nquad
from sympy import symbols, exp
from symjit import compile_func

N = 5
t, x = symbols("t x")
f = compile_func([t, x], exp(-t*x)/t**N)

sol = nquad(f, [[1, np.inf], [0, np.inf]])

np.testing.assert_approx_equal(sol[0], 1/N)
```

The output of the returned callable is a numpy array with `dtype='double'`. Note that you can call `f` by passing scalars (say, `f(1.0, 2.0)`) or numpy arrays (for example, `f(np.asarray([1., 2.]), np.asarray([3., 4.]))`). However, broadcasting is not supported. All parameters should be passed as scalars, even if the state variables are arrays. For example,

```python
import numpy as np
import matplotlib.pyplot as plt
from sympy import symbols, exp
from symjit import compile_func

x, sigma = symbols('x sigma')
f = compile_func([x], [exp(-(x-100)**2/(2*sigma**2))], params=[sigma])

t = np.arange(0, 200)
y = f(t, 25.)[0]

plt.plot(t, y)
```

The following example uses the vectorization feature to calculate the [Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set).

```python
# examples/mandelbrot.py
import numpy as np
import matplotlib.pyplot as plt
from sympy import symbols
from symjit import compile_func

x, y, a, b = symbols("x y a b")

A, B = np.meshgrid(np.arange(-2, 1, 0.002), np.arange(-1.5, 1.5, 0.002))
X = np.zeros_like(A)
Y = np.zeros_like(A)

f = compile_func([a, b, x, y], [x**2 - y**2 + a, 2*x*y + b])

for i in range(20):
    X, Y = f(A, B, X, Y)

Z = np.hypot(X, Y)

plt.imshow(Z < 2)
```

The output is:

![Mandelbrot](./figures/mandelbrot.png)
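The two expressions passed to `compile_func` are the real and imaginary parts of the complex map z → z² + c, with z = x + iy and c = a + ib, since (x + iy)² + (a + ib) = (x² - y² + a) + i(2xy + b). The following pure-Python sketch (our own illustration, not part of *symjit*) applies the same iteration to a single point. Unlike the vectorized example, it stops as soon as |z| exceeds 2, which avoids overflow for escaping points.

```python
import math

def in_mandelbrot(a, b, iterations=20):
    """Return True if c = a + ib stays bounded (|z| <= 2) for `iterations` steps."""
    x, y = 0.0, 0.0
    for _ in range(iterations):
        # Same update as [x**2 - y**2 + a, 2*x*y + b]; tuple assignment
        # evaluates the right-hand side with the old x and y.
        x, y = x * x - y * y + a, 2 * x * y + b
        if math.hypot(x, y) > 2:
            return False
    return True

print(in_mandelbrot(0.0, 0.0))   # True: z stays at 0
print(in_mandelbrot(-1.0, 0.0))  # True: z cycles between -1 and 0
print(in_mandelbrot(1.0, 0.0))   # False: z = 1, 2, 5, ... escapes
```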

## Optimization

The Rust backend supports different optimization and parallelization methods, which can be controlled using `compile_func` arguments. The options are:

* `use_simd` (default `True`): generates SIMD instructions if possible (currently AVX instructions on x86-64 processors). SIMD code can improve performance by up to 4x for certain tasks (256-bit registers hold and operate on four doubles simultaneously).
* `use_threads` (default `True`): uses multi-threading to speed up the processing of array operations with the [Rayon Rust crate](https://docs.rs/rayon/latest/rayon/).
* `cse` (default `True`): new in version 2.4; performs common-subexpression elimination (CSE), i.e., factors out repeated expressions and sub-expressions so that each is computed only once.

Note that SIMD and multi-threading optimizations only apply to vectorized calls, but common-subexpression elimination applies to both scalar and vectorized operations. 
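As a rough picture of what CSE does, consider the Gaussian `exp(-(x-100)**2/(2*sigma**2))` together with its derivative with respect to `x`; both contain the same exponential. The sketch below is written by hand in plain Python (it is not output from *symjit*) and contrasts naive evaluation with a CSE-style version that computes the shared subexpressions once.

```python
import math

def naive(x, sigma):
    # Each output recomputes the exponential from scratch.
    g = math.exp(-(x - 100) ** 2 / (2 * sigma ** 2))
    dg = -(x - 100) / sigma ** 2 * math.exp(-(x - 100) ** 2 / (2 * sigma ** 2))
    return g, dg

def with_cse(x, sigma):
    # Common subexpressions are computed once and reused.
    d = x - 100
    s2 = sigma ** 2
    e = math.exp(-d * d / (2 * s2))
    return e, -d / s2 * e

print(all(math.isclose(p, q)
          for p, q in zip(naive(110.0, 25.0), with_cse(110.0, 25.0))))  # True
```

Both functions return the same values; the CSE version evaluates `exp` once instead of twice and computes `x - 100` and `sigma**2` only once.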


## Fast Functions

The result of the `compile_*` functions is a Python object, say `f`, that encapsulates the underlying compiled code. When we call `f(...)`, `f.__call__` is invoked with the arguments. Then, `__call__` checks the type of the arguments (scalar vs. vector), packages the inputs accordingly, calls the correct compiled routine via the respective Rust routines, and finally formats the return values. All these steps add overhead. The overhead is acceptable if the compiled function is large and complex, but it becomes relatively expensive if the function is simple and lightweight. In this situation, it is faster to call the underlying compiled code directly. This is possible if the following conditions hold:

1. The output is a single **scalar** expression.
2. There are zero to eight **scalar** input arguments.
3. There are no parameters.

In most cases, *Symjit* can automatically switch a function to a fast one. However, there are situations in which using the fast function directly improves performance, for example, when passing functions to SciPy integration routines (`quad`, `nquad`, `dblquad`, `tplquad`). For these cases, we can access the fast function by calling `f.fast_func()`. The result is a `ctypes.CFUNCTYPE`-generated foreign function. For example, we can rewrite the integration example above as

```python
import numpy as np
from scipy.integrate import nquad
from sympy import symbols, exp
from symjit import compile_func

N = 5

def integrate():
    t, x = symbols("t x")
    f = compile_func([t, x], exp(-t*x)/t**N)
    # `fast` is only valid while `f` is alive; both live in this scope
    fast = f.fast_func()
    return nquad(lambda t, x: fast(t, x), [[1, np.inf], [0, np.inf]])

sol = integrate()
np.testing.assert_approx_equal(sol[0], 1/N)
```

Two points are worth noting. First, we pass a lambda function to `nquad` because of the peculiarities of `nquad` (and other SciPy integration routines) concerning the expected signature of foreign functions. We plan to generate the correct signature in a future version. Second, the lifetime of the fast function is tied to `f`. If `f` goes out of scope and is garbage-collected, the fast function becomes invalid. Therefore, never store the fast function separately from its parent `f`. This is why the example above wraps the code in the `integrate` function: it provides a scope that keeps `f` alive while `nquad` uses the fast function.
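For readers unfamiliar with `ctypes`, the following standard-library-only sketch shows what a `CFUNCTYPE`-generated function looks like. A Python lambda is wrapped as a C callback with the signature `double (*)(double, double)`; that signature is our assumption for a two-argument scalar function like the integrand above and is not taken from *symjit*. The `ctypes` documentation states a similar lifetime rule: a `CFUNCTYPE` object must stay referenced for as long as C code may call it.

```python
import ctypes
import math

# C prototype: double f(double, double)
BinaryDouble = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)

# Wrap a Python callable as a C function pointer (a stand-in for compiled code).
integrand = BinaryDouble(lambda t, x: math.exp(-t * x) / t ** 5)

# The foreign function is callable from Python; arguments become C doubles.
value = integrand(1.0, 0.0)
print(value)  # 1.0

# Underneath, it is just a raw code address; it must not outlive `integrand`.
address = ctypes.cast(integrand, ctypes.c_void_p).value
print(address is not None)  # True
```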


## Exponentiation to an Integer Power and Modular Exponentiation

Polynomial manipulation over various finite and infinite rings, such as &Zopf;<sub>p</sub> and &Zopf;, is the cornerstone of computer algebra systems. *Symjit* is primarily designed as a bridge between Sympy and numerical libraries (NumPy, SciPy, ...) and, as such, focuses on floating-point calculations. However, to assist with sympy integer calculations, version 2 can detect integer exponentiation and modular exponentiation and emit special code for them. IEEE 754 doubles can represent every integer with magnitude up to 2**53 = 9007199254740992 exactly.
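The 2**53 limit is easy to check in plain Python, whose `float` is an IEEE 754 double on all common platforms. Up to the limit, every integer converts to a double exactly; just above it, adjacent doubles are 2 apart, so odd integers get rounded.

```python
limit = 2 ** 53  # 9007199254740992

# Every integer up to 2**53 converts to a double exactly.
print(float(limit - 1) == limit - 1)  # True
print(float(limit) == limit)          # True

# Above 2**53, doubles are spaced 2 apart: 2**53 + 1 rounds to 2**53.
print(float(limit + 1) == float(limit))  # True
print(int(float(limit + 1)))             # 9007199254740992
```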

The first special form is `x**n`, where `x` is any variable or expression, and `n` is a constant integer. *Symjit* emits the corresponding code directly in the function byte stream using the exponentiation-by-squaring method. This improves performance by allowing for better register allocations.
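The following plain-Python sketch illustrates exponentiation by squaring for a non-negative integer exponent. It is a conceptual model of the technique, not the machine code *symjit* emits: it walks the bits of `n`, squaring the base at each step, so it needs about log2(n) squarings instead of n - 1 multiplications.

```python
def power_by_squaring(x, n):
    """Compute x**n for an integer n >= 0 using O(log n) multiplications."""
    result = 1.0
    base = float(x)
    while n > 0:
        if n & 1:          # lowest bit set: include the current power of x
            result *= base
        base *= base       # x, x**2, x**4, x**8, ...
        n >>= 1
    return result

print(power_by_squaring(3, 13))   # 1594323.0
print(power_by_squaring(1.5, 0))  # 1.0
```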

The second special form is `x**n % p`, where `p` is any expression. Instead of calculating `x**n` first and then applying `%` (which can easily overflow), *Symjit* incorporates modular reduction at each stage of squaring. For example,

```python
from sympy import symbols
from symjit import compile_func

x = symbols("x")
f = compile_func([x], x ** 1000 % 257)
assert(f(10) == 189)
```

Note that `10**1000` cannot be represented by a double (the maximum double value is ~1.8*10**308). Therefore, calculating `10**1000` directly would overflow.
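A plain-Python sketch of the same idea follows: the exponentiation-by-squaring loop reduces modulo `p` after every multiplication, so intermediate values stay below `p**2`. It uses `math.fmod` on floats to mimic double arithmetic and is exact for non-negative `x` as long as `p**2` stays below 2**53. This is our illustration of the technique, not *symjit*'s generated code.

```python
import math

def mod_pow(x, n, p):
    """Compute x**n % p for x >= 0 and integer n >= 0, reducing at every step."""
    result = 1.0
    base = math.fmod(float(x), p)
    while n > 0:
        if n & 1:
            result = math.fmod(result * base, p)
        base = math.fmod(base * base, p)
        n >>= 1
    return result

print(mod_pow(10, 1000, 257))  # 189.0
print(pow(10, 1000, 257))      # 189 (Python's exact integer version)

# Computing 10.0**1000 first is not an option: it overflows a double.
try:
    10.0 ** 1000
except OverflowError:
    print("overflow")
```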

## `compile_ode`: to solve ODEs

`compile_ode` returns a callable object (`OdeFunc`) suitable for passing to `scipy.integrate.solve_ivp` (SciPy's main ODE solver interface). It takes three mandatory arguments as `compile_ode(iv, states, odes)`. The first one, `iv`, is a single symbol that specifies the independent variable. The second argument, `states`, is a list of symbols defining the ODE state. The right-hand sides of the ODEs are passed as the third argument, `odes`, a list of expressions giving the derivative of each state variable with respect to the independent variable. In addition, similar to `compile_func`, `compile_ode` accepts an optional `params` argument. For example,

```python
# examples/trig.py
import scipy.integrate
import matplotlib.pyplot as plt
import numpy as np
from sympy import symbols
from symjit import compile_ode

t, x, y = symbols('t x y')
f = compile_ode(t, (x, y), (y, -x))
t_eval=np.arange(0, 10, 0.01)
sol = scipy.integrate.solve_ivp(f, (0, 10), (0.0, 1.0), t_eval=t_eval)

plt.plot(t_eval, sol.y.T)
```

Here, the ODE definition is `x' = y` and `y' = -x`, which means `x'' = -x`. The general solution is `x = a*sin(t) + b*cos(t)`, where `a` and `b` are determined by the initial values. Given the initial values `x(0) = 0` and `y(0) = 1` passed as the third argument of `solve_ivp`, the solutions are `x = sin(t)` and `y = cos(t)`. We can confirm this by running the code. The output is

![sin/cos functions](./figures/trig.png)

Note that `OdeFunc` conforms to the function form `scipy.integrate.solve_ivp` expects, i.e., it should be called as `f(t, y, *args)`.
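To see what `solve_ivp` receives, here is a hand-written equivalent of the right-hand side for the trigonometric system above, using the same `f(t, y, *args)` calling convention, together with a minimal fixed-step fourth-order Runge-Kutta (RK4) loop. It is a pure-Python sketch for checking the `sin(t)`/`cos(t)` solution, not a replacement for `solve_ivp` or `compile_ode`.

```python
import math

def rhs(t, y, *args):
    """Right-hand side of x' = y, y' = -x, in the f(t, y, *args) form."""
    x, v = y
    return [v, -x]

def rk4(f, t0, y0, t1, steps, *args):
    """Integrate y' = f(t, y, *args) from t0 to t1 with classic RK4."""
    h = (t1 - t0) / steps
    t, y = t0, list(y0)
    for _ in range(steps):
        k1 = f(t, y, *args)
        k2 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)], *args)
        k3 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)], *args)
        k4 = f(t + h, [yi + h * ki for yi, ki in zip(y, k3)], *args)
        y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += h
    return y

# Same initial values as the solve_ivp call: x(0) = 0, y(0) = 1.
x10, y10 = rk4(rhs, 0.0, [0.0, 1.0], 10.0, 1000)
print(abs(x10 - math.sin(10.0)) < 1e-6, abs(y10 - math.cos(10.0)) < 1e-6)  # True True
```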

The following example is more complicated and showcases the [Lorenz system](https://en.wikipedia.org/wiki/Lorenz_system), an important milestone in the historical development of chaos theory.

```python
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from sympy import symbols

from symjit import compile_ode

t, x, y, z = symbols("t x y z")
sigma, rho, beta = symbols("sigma rho beta")

ode = (
    sigma * (y - x),
    x * (rho - z) - y,
    x * y - beta * z
    )

f = compile_ode(t, (x, y, z), ode, params=(sigma, rho, beta))

u0 = (1.0, 1.0, 1.0)
p = (10.0, 28.0, 8 / 3)
t_eval = np.arange(0, 100, 0.01)

sol = solve_ivp(f, (0, 100.0), u0, t_eval=t_eval, args=p)

plt.plot(sol.y[0, :], sol.y[2, :])
```

The result is the famous *strange attractor*:

![the strange attractor](./figures/lorenz.png)


## `compile_jac`: calculating Jacobian

The ODE examples discussed in the previous section are non-stiff and easy to solve using explicit methods. However, not all differential equations are so accommodating! Many important equations are stiff and usually require implicit methods. Many implicit ODE solvers use the system's [Jacobian matrix](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) to improve performance.

There are different techniques for calculating the Jacobian. In the last few years, automatic differentiation (AD) methods, which operate on the abstract syntax tree or at a lower level, have gained popularity. However, if we define our model symbolically using a Computer Algebra System (CAS) such as sympy, we can calculate the Jacobian by differentiating the source symbolic expressions.

`compile_jac` is the symjit function that calculates the Jacobian of an ODE system. It has the same call signature as `compile_ode`, i.e., it is called as `compile_jac(iv, states, odes)` with an optional argument `params`. The return value (of type `JacFunc`) is a callable similar to `OdeFunc` that returns an n-by-n matrix J, where n is the number of states. The element in the ith row and jth column of J is the derivative of `odes[i]` with respect to `states[j]` (this is the definition of the Jacobian).
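As a concrete instance of this definition, the hand-derived Jacobian of the Lorenz system from the previous section is shown below in plain Python, along with a central-difference check of each column. This is our own sketch for illustration; `compile_jac` instead derives the matrix symbolically from the sympy expressions.

```python
def lorenz(t, u, sigma, rho, beta):
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

def lorenz_jac(t, u, sigma, rho, beta):
    """Closed-form Jacobian: J[i][j] = d odes[i] / d states[j]."""
    x, y, z = u
    return [
        [-sigma, sigma, 0.0],
        [rho - z, -1.0, -x],
        [y, x, -beta],
    ]

def fd_jac(f, t, u, *args, h=1e-6):
    """Approximate the Jacobian column by column with central differences."""
    n = len(u)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        up = list(u)
        up[j] += h
        um = list(u)
        um[j] -= h
        fp, fm = f(t, up, *args), f(t, um, *args)
        for i in range(n):
            J[i][j] = (fp[i] - fm[i]) / (2 * h)
    return J

p = (10.0, 28.0, 8 / 3)
u = [1.0, 2.0, 3.0]
exact = lorenz_jac(0.0, u, *p)
approx = fd_jac(lorenz, 0.0, u, *p)
err = max(abs(a - b) for ra, rb in zip(exact, approx) for a, b in zip(ra, rb))
print(err)  # small: limited by floating-point rounding in the differences
```

The finite-difference version needs 2n extra evaluations of the right-hand side per Jacobian and is only approximate, while the closed-form matrix is exact and cheap to evaluate.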

For example, we can consider the [Van der Pol oscillator](https://en.wikipedia.org/wiki/Van_der_Pol_oscillator). This system has a control parameter (mu). For small values of mu, the ODE system is not stiff and can easily be solved using explicit methods.

```python
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp
from sympy import symbols
from math import sqrt
from symjit import compile_ode, compile_jac

t, x, y, mu = symbols('t x y mu')
ode = [y, mu * ((1 - x*x) * y - x)]

f = compile_ode(t, [x, y], ode, params=[mu])
u0 = [0.0, sqrt(3.0)]
t_eval = np.arange(0, 10.0, 0.01)

sol1 = solve_ivp(f, (0, 10.0), u0, method='RK45', t_eval=t_eval, args=[5.0])

plt.plot(t_eval, sol1.y[0,:])
```

The output is

![non-stiff Van der Pol](./figures/van_der_pol_non_stiff.png)

On the other hand, as mu increases (for example, to 1e6), the system becomes very stiff. An explicit ODE solver, such as RK45 (Runge-Kutta 4/5), cannot practically solve this problem. Instead, we need an implicit method, such as the backward differentiation formula (BDF). BDF needs a Jacobian. If one is not provided, BDF calculates one numerically using finite differences. However, this technique is both inaccurate and computationally intensive. It is much better to give the solver a closed-form Jacobian, which is exactly what `compile_jac` does.

```python
jac = compile_jac(t, [x, y], ode, params=[mu])
sol2 = solve_ivp(f, (0, 10.0), u0, method='BDF', t_eval=t_eval, args=[1e6], jac=jac)

plt.plot(t_eval, sol2.y[0,:])
```

The output of the stiff system is

![stiff Van der Pol](./figures/van_der_pol_stiff.png)

# Code Generation

All `compile_*` functions accept an optional parameter `ty`, which defines the type of the code to generate. Currently, the possible values are:

* `amd`: generates 64-bit AMD64/x86-64 code. If the processor supports AVX, then this is equivalent to passing `amd-avx`; otherwise, it is equal to `amd-sse`.
* `amd-avx`: generates 64-bit AMD64/x86-64 AVX code.
* `amd-sse`: generates 64-bit AMD64/x86-64 SSE code. It requires a minimum SSE2.1 specification, which should be easily fulfilled by all except the most ancient processors.
* `arm`: generates 64-bit ARM64/aarch64 code. This option is mainly tested on Apple Silicon.
* `bytecode`: this option uses a generic and simple bytecode evaluator as a fallback option in case of unsupported instruction sets. The utility is to test correctness (see option `debug` below), not speed.
* `native` (**default**): selects the correct instruction set based on the current processor.
* `debug`: useful for debugging the generated code. It runs both the `native` and `bytecode` versions, compares the results, and panics if they differ.
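Conceptually, `native` maps the host processor to one of the other options. The sketch below is a hypothetical helper (`pick_ty` is our name, not part of *symjit*) that uses `platform.machine()` to make a similar choice. The machine strings listed are common values on Linux, Windows, and macOS, and the actual selection logic in *symjit* (including the AVX check that chooses between `amd-avx` and `amd-sse`) may differ.

```python
import platform

def pick_ty(machine=None):
    """Hypothetical mapping from a machine name to a `ty` value."""
    machine = (machine or platform.machine()).lower()
    if machine in ("x86_64", "amd64"):
        return "amd"        # AVX vs. SSE would be decided by a separate CPU check
    if machine in ("arm64", "aarch64"):
        return "arm"
    return "bytecode"       # generic fallback for unsupported instruction sets

print(pick_ty("x86_64"))   # amd
print(pick_ty("AMD64"))    # amd
print(pick_ty("arm64"))    # arm
print(pick_ty("riscv64"))  # bytecode
```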

Note that `ty='wasm'` is no longer supported in version 2. Also, as discussed above, `compile_*` functions accept a `backend` argument with possible values of `rust` and `python`.


## Code Inspection

To inspect the generated code, you can use the `dump` method of the various `Func` callables to write the binary to a file, or use `dumps` to return it as a hex string. The output of `dump` is a flat binary file with no header or other extras, which can be disassembled directly. For example,

```python
from symjit import compile_func
from sympy import symbols

x, y = symbols('x y')
f = compile_func([x, y], [x+y, x*y])
f.dump('test.bin', what='scalar')
```

Passing `what='simd'` dumps the vectorized version of the function, and passing `what='fast'` dumps the fast function.
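If you use `dumps` instead, the hex string can be turned back into a flat binary for `objdump`. The sketch below assumes that `dumps` returns plain hexadecimal digits, possibly separated by whitespace, which we have not verified. To stay self-contained, it uses the first bytes of the disassembly listing in this section (`push rbp`, `push rbx`, `mov rbp,rdi`) in place of a real `dumps` result.

```python
import os
import tempfile

# Stand-in for the string returned by f.dumps(...): push rbp; push rbx; mov rbp,rdi
hex_code = "55 53 48 8b ef"

code = bytes.fromhex(hex_code)  # whitespace between byte pairs is ignored
path = os.path.join(tempfile.mkdtemp(), "test.bin")
with open(path, "wb") as fh:
    fh.write(code)

print(len(code), code[0] == 0x55)  # 5 True
```

The resulting file can then be passed to the `objdump` command shown in this section.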

On a Linux system, we can invoke `objdump` to disassemble the output as below:

```
objdump -b binary -m i386:x86-64 -M intel -D test.bin
```

The output (assuming a Linux x86-64 machine) is

```
0000000000000000 <.data>:
   0:	55                   	push   rbp
   1:	53                   	push   rbx
   2:	48 8b ef             	mov    rbp,rdi
   5:	48 81 ec 88 00 00 00 	sub    rsp,0x88
   c:	c5 fb 10 5d 00       	vmovsd xmm3,QWORD PTR [rbp+0x0]
  11:	c5 fb 10 55 08       	vmovsd xmm2,QWORD PTR [rbp+0x8]
  16:	c5 e3 58 da          	vaddsd xmm3,xmm3,xmm2
  1a:	c5 fb 11 5d 18       	vmovsd QWORD PTR [rbp+0x18],xmm3
  1f:	c5 fb 10 5d 00       	vmovsd xmm3,QWORD PTR [rbp+0x0]
  24:	c5 fb 10 55 08       	vmovsd xmm2,QWORD PTR [rbp+0x8]
  29:	c5 e3 59 da          	vmulsd xmm3,xmm3,xmm2
  2d:	c5 fb 11 5d 20       	vmovsd QWORD PTR [rbp+0x20],xmm3
  32:	c5 f8 77             	vzeroupper
  35:	48 81 c4 88 00 00 00 	add    rsp,0x88
  3c:	5b                   	pop    rbx
  3d:	5d                   	pop    rbp
  3e:	c3                   	ret
```

Note that this is the output from an older version, and the more recent versions have a more complex prologue and epilogue. 

