# `jax`

In [1]:

```
import jax
import jax.numpy as np
import jax.lax as lax
import matplotlib.pyplot as plt
```

In [2]:

```
# configure matplotlib output
import matplotlib as mpl
mpl.style.use('config/clean.mplstyle') # this loads my personal plotting settings
col = mpl.rcParams['axes.prop_cycle'].by_key()['color']
%matplotlib inline
```

In [3]:

```
# if you have an HD display
%config InlineBackend.figure_format = 'retina'
```

In [4]:

```
# some warnings can get annoying
import warnings
warnings.filterwarnings('ignore')
```

## `jax` Basics

Almost all of the `numpy` API is implemented in `jax`. Arrays have a device that they are stored on. If you have a supported GPU, arrays will default to it; otherwise they'll go on the CPU (or a TPU if you're really fancy, but I've never used one). Note that there is the notion of a **platform**, which includes `cpu`, `gpu`, and `tpu`. Each platform has a certain number of devices (often just one, though).

In [5]:

```
# the output here will depend on your setup (we'll be using this "x" a lot below too)
x = np.array([1.0, 2.0, 3.0])
x.device()
```

Out[5]:

In [6]:

```
# you can print out available devices
print(jax.devices('cpu'))
print(jax.devices('gpu'))
```

In [7]:

```
# you can send data between devices
cpu0, *_ = jax.devices('cpu')
xc = jax.device_put(x, cpu0)
xc.device()
```

Out[7]:

Arrays also have a `dtype`, which usually corresponds to those found in regular `numpy`. However, the defaults can be different and some are not supported on certain devices. The most common such difference is that `float64` is not supported unless you explicitly enable it.

In [8]:

```
# create some arrays
print(np.arange(3).dtype)
print(np.arange(3, dtype='float32').dtype)
print(np.array([1, 2, 3]).dtype)
print(np.array([1.0, 2, 3]).dtype)
print(np.array([1, 2, 3], dtype='float32').dtype)
```

In [9]:

```
# should still be all float32
print(x.dtype, x.astype('float64').dtype)
print(xc.dtype, xc.astype('float64').dtype)
```

You can enable `float64` by running: `jax.config.update('jax_enable_x64', True)`
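As a quick sketch (my own example, not from the cells above): the flag must be set before any arrays are created for it to take effect.

```python
import jax
jax.config.update('jax_enable_x64', True)  # must run before creating arrays

import jax.numpy as np
print(np.array([1.0, 2.0]).dtype)  # now defaults to float64
```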

## `grad`

Let's start with derivatives of a scalar-to-scalar function. You can make a new function that returns the gradient by calling `grad` on a function.

In [10]:

```
# define a simple function (can be lambda too)
def f(x):
    return x**2
f(3.0)
```

Out[10]:

In [11]:

```
# take the gradient (derivative)
df = jax.grad(f)
df(3.0)
```

Out[11]:

In [12]:

```
# take the second derivative
d2f = jax.grad(df)
d2f(3.0)
```

Out[12]:

In [13]:

```
# you need to make sure your inputs are floats
try:
    df(3)
except Exception as e:
    print(e)
```

Up until now, we've only looked at functions of one parameter. Going to many parameters is fine, we just need to tell `grad` which variable to take the derivative with respect to. We do this by specifying the index of the desired parameter or parameters in the `argnums` flag. If you give it multiple indices, it will return a tuple of derivatives. The default is `argnums=0`.

In [14]:

```
def g(x, y):
    return y*x**2 + x*y**3
dg = jax.grad(g, argnums=1)
dg(2.0, 3.0)
```

Out[14]:
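Passing multiple indices gives you both partials at once; a small sketch using the same `g`:

```python
import jax

def g(x, y):
    return y*x**2 + x*y**3

# derivative with respect to both x and y in one call
dg_both = jax.grad(g, argnums=(0, 1))
dg_both(2.0, 3.0)  # returns the tuple (39.0, 58.0)
```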

## `grad` and `jacobian`

Now let's turn to vector-to-scalar functions. Here, calling `grad` will give you a function that returns a vector with the same size as the input dimension. You don't have to specify the dimensions to `jax`; it'll figure these out the first time the function actually gets called. Because of this, if `jax` isn't happy about something, it may not give you an error until you actually try to use the output of `grad`.

In [15]:

```
# the second term is just to add some complexity
def fv(x):
    return np.sum(x**2) + x[1]*x[2]
fv(x)
```

Out[15]:

In [16]:

```
# we use the "x" defined above as a test input
dfv = jax.grad(fv)
dfv(x)
```

Out[16]:

In [17]:

```
# note that the jacobian is non-diagonal because of the additional term
jfv = jax.jacobian(dfv)
jfv(x)
```

Out[17]:

In [18]:

```
# you can't use grad on vector-return functions! but maybe we can find another way...
try:
    dfe = jax.grad(lambda x: x**2)
    dfe(x)
except Exception as e:
    print(e)
```
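One such way is `jacobian`, which does handle vector-return functions; a quick sketch on the same squaring function:

```python
import jax
import jax.numpy as np

x = np.array([1.0, 2.0, 3.0])

# the jacobian of an elementwise square is a diagonal matrix with 2*x on the diagonal
jsq = jax.jacobian(lambda x: x**2)
jsq(x)
```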

## `vmap`

Vectorization is often necessary for performance in Python, especially in `jax`. Fortunately, we can do this easily with `vmap`. This function will take an input function and map it along a new dimension of your choice. For instance, you can use it to turn a scalar-to-scalar function into a vector-to-vector function. But there are many more possibilities.

In [19]:

```
# the most basic usage on the "f" defined above
fv2 = jax.vmap(f)
fv2(x)
```

Out[19]:

In [20]:

```
# now we can do the element-by-element gradient of a vector-return function
dfv2 = jax.vmap(jax.grad(f))
dfv2(x)
```

Out[20]:

## `jit`

This is an approach to speeding up code by compiling it just when it is needed (rather than beforehand, like in C or C++). It's used in JavaScript, MATLAB, and other places. Calling `jit` on a function will return another function that does the same thing but faster. As with `grad`, it only actually does the compilation when you run the function for the first time.
In [21]:

```
# compile the vector gradient function above
jdfv = jax.jit(dfv)
jdfv(x)
```

Out[21]:

In [22]:

```
# give it a much bigger vector for testing
x2 = np.linspace(1.0, 5.0, 100000)
```

In [23]:

```
# first do the un-jit-ed version
%timeit -n 100 dfv(x2)
```

In [24]:

```
# now the jitted version (run twice for true comparison)
%timeit -n 100 jdfv(x2)
```

On my computer this gives a ~300x speed improvement, though the gain will vary between CPU and GPU.
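One caveat worth noting (my own aside, not from the timings above): `jax` dispatches work asynchronously, so for fair benchmarks it's safer to block on the result. A sketch:

```python
import jax
import jax.numpy as np

def fv(x):
    return np.sum(x**2) + x[1]*x[2]

jdfv = jax.jit(jax.grad(fv))
x2 = np.linspace(1.0, 5.0, 100000)

jdfv(x2).block_until_ready()        # first call also pays the compilation cost
res = jdfv(x2).block_until_ready()  # steady-state timings should block like this
```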

## `jax` Trees

Trees are very cool! They're basically arbitrarily nested `dict`s and `list`s whose leaves can be any type. If you've seen JSON before, it's very similar to that. Why are trees cool? Just as `numpy` lets you perform operations over an entire array at once, trees let you do the same over entire trees as well. This lets you avoid a lot of clunky and error-prone packing and unpacking.

In [25]:

```
# first let's define a very simple tree, which is just a dict of arrays
xd = {
'x': np.linspace(1.0, 5.0, 5),
'y': np.linspace(10.0, 19.0, 10),
}
xd
```

Out[25]:

In [26]:

```
# here's a function that operates on such a tree
# you could imagine specifying a model like this
# where the input is a dict of parameters or data points
def ft(d):
    xsum = np.sum(d['x']**3)
    ysum = np.sum(d['y']**2)
    return xsum + ysum
```

In [27]:

```
# now we can take a grad with respect to a tree and get a tree of the same shape back!
dft = jax.grad(ft)
dft(xd)
```

Out[27]:
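Beyond `grad`, you can map your own functions over whole trees with `jax.tree_util.tree_map`; a sketch using the same kind of dict-of-arrays tree:

```python
import jax
import jax.numpy as np

xd = {
    'x': np.linspace(1.0, 5.0, 5),
    'y': np.linspace(10.0, 19.0, 10),
}

# apply a function leaf-by-leaf; returns a tree with the same structure
sq = jax.tree_util.tree_map(lambda a: a**2, xd)
```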

Sometimes it's easier to think in scalar terms, then vectorize things. Suppose that our model really just operated on some `x` and `y` and we want to run the model for many such pairs. Then we can first define a slightly different function in purely scalar terms.

In [28]:

```
def fts(d):
    x, y = d['x'], d['y']
    return x**3 + y**2
```

Then we use `vmap` to broadcast over the tree `d`. Below is what we'd write if we're interested in doing 10 distinct parameter sets.

In [29]:

```
xd1 = {'x': np.linspace(0, 1, 10), 'y': np.linspace(2, 3, 10)}
ftv1 = jax.vmap(fts)
ftv1(xd1)
```

Out[29]:

But sometimes we just want to vary one parameter at a time. As with `grad`, we need to tell it which dimensions to operate on, this time using `in_axes`. But it's a bit more complicated.

In [30]:

```
xd2 = {'x': np.linspace(0, 1, 10), 'y': 2.0}
ftv2 = jax.vmap(fts, in_axes=({'x': 0, 'y': None},))
ftv2(xd2)
```

Out[30]:

Notice that `in_axes` is a list of the same length as the number of function arguments (in this case 1). We then have to mirror the structure of the tree itself to specify that we want to map `x` over axis 0 and not map `y` at all.

So... loops work a little differently in `jax`. It's actually okay to write a `for` loop, but not if you want to ultimately `jit` the thing. First, as we get into more advanced usage, there's a great overview of common "gotchas" in `jax` here: JAX - The Sharp Bits.

In [31]:

```
# very slow way to compute x**n
def fl(x, n):
    out = 1
    for i in range(n):
        out = x*out
    return out
```

In [32]:

```
try:
    jfl = jax.jit(fl)
    jfl(2.0, 3)
except Exception as e:
    print(e)
```

Not a very helpful error message! But basically we tried to use a function argument as the loop bounds and then `jit` it. One way to take care of this is to specify that `n` is mostly fixed using `static_argnums`. This will work, but it will recompile for each distinct value of `n` that is seen.

In [33]:

```
jfl = jax.jit(fl, static_argnums=1)
jfl(2.0, 3)
```

Out[33]:

The other way is to use one of the looping tools from `lax`, such as `scan` or `fori_loop`. First, let's try it using `fori_loop`.
In [34]:

```
def fl_fori(x, n):
    fl_x = lambda i, v: x*v
    return lax.fori_loop(0, n, fl_x, 1.0)
jfl_fori = jax.jit(fl_fori)
jfl_fori(2.0, 3)
```

Out[34]:

The `scan` function operates similarly, in that it carries a running state, but instead of just getting `i` at each iteration, you can get a slice of your input data. You also get the whole history of outputs at the end, rather than just the last value. The function you pass should accept a running value (`v` below) and a state (`i` below). Either of these can be arbitrary tree types.

In [35]:

```
def fl_scan(x, n):
    fl_x = lambda v, i: (x*v, x*v)
    tvec = np.arange(n)
    return lax.scan(fl_x, 1.0, tvec)
jfl_scan = jax.jit(fl_scan, static_argnums=1)
jfl_scan(2.0, 3)
```

Out[35]:

The return at the end is the final value of the running state, as well as the history of values. Okay, so both of these were designed for fixed-length iterations. For variable-length iterations, we can use `lax.while_loop`. This accepts a "condition" function that tells it whether to continue or not, and a function that updates the running state.

In [36]:

```
def fl_while(x, n):
f_cond = lambda iv: iv[0] < n
fl_x = lambda iv: (iv[0] + 1, x*iv[1])
return lax.while_loop(f_cond, fl_x, (0, 1.0))
jfl_while = jax.jit(fl_while)
jfl_while(2.0, 3)
```

Out[36]:

Here we just reimplemented a for loop by including `i` in the running state and conditioning on it, but you could imagine the condition being something like a convergence check too.

Random numbers work a bit differently in `jax` due to the different needs of deep learning folks. You need to generate a "key" first using a "seed" as input (below, `42`, a very common seed). You then pass this key to various random number routines that are similar to those in `numpy.random`.

In [37]:

```
key = jax.random.PRNGKey(42)
r1 = jax.random.normal(key)
r2 = jax.random.normal(key)
print(key)
print(r1, r2)
```

Notice the same key leads to the same random numbers! So there must be more. To advance your key to something new you can use the `split` function, which is also deterministic. This will return two new keys.

In [38]:

```
key, subkey = jax.random.split(key)
print(key, subkey)
```

So one pattern is to keep a running `key` and generate a new `subkey` when you need another random number.

In [39]:

```
jax.random.uniform(subkey, (5,))
```

Out[39]:
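You can also split into more than two keys at once by passing a count; a quick sketch:

```python
import jax

key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 3)  # three fresh keys in one call
draws = [jax.random.normal(k) for k in keys]
```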
