Differentiable Programming (with jax)

In [1]:
import jax
import jax.numpy as np
import jax.lax as lax
import matplotlib.pyplot as plt
In [2]:
# configure matplotlib output
import matplotlib as mpl
mpl.style.use('config/clean.mplstyle') # this loads my personal plotting settings
col = mpl.rcParams['axes.prop_cycle'].by_key()['color']
%matplotlib inline
In [3]:
# if you have an HD display
%config InlineBackend.figure_format = 'retina'
In [4]:
# some warnings can get annoying
import warnings
warnings.filterwarnings('ignore')

jax Basics

Almost all of the numpy API is implemented in jax. Arrays have a device that they are stored on. If you have a supported GPU, it will default to this, otherwise it'll go on the CPU (or a TPU if you're really fancy, but I've never used one). Note that there is the notion of a platform which includes cpu, gpu, and tpu. Each platform has a certain number of devices (often just one though).

In [5]:
# the output here will depend on your setup (we'll be using this "x" a lot below too)
x = np.array([1.0, 2.0, 3.0])
x.device()  # recent jax versions use x.devices() instead
GpuDevice(id=0, process_index=0)
In [6]:
# you can print out available devices
jax.devices()
[GpuDevice(id=0, process_index=0)]
In [7]:
# you can send data between devices
cpu0, *_ = jax.devices('cpu')
xc = jax.device_put(x, cpu0)
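To make the platform/device split above concrete, here is a small sketch that asks jax for its default backend, lists the CPU devices, and places an array on one of them. It assumes a jax version recent enough to have the `jax.Array.devices()` method; the names `backend`, `cpu_devices`, and `y` are just for illustration.

```python
import jax
import jax.numpy as np

# the platform jax picked by default, e.g. "cpu" or "gpu"
backend = jax.default_backend()

# devices are grouped by platform; every install has at least one CPU device
cpu_devices = jax.devices('cpu')
print(backend, jax.device_count(), cpu_devices)

# put an array explicitly on the first CPU device and check where it lives
y = jax.device_put(np.ones(3), cpu_devices[0])
print(y.devices())  # the set of devices holding y's data
```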

Arrays also have a dtype, which usually corresponds to those found in regular numpy. However, the defaults can differ, and some dtypes are not supported on certain devices. The most common difference is that float64 is not supported unless you explicitly enable it; until then, requests for float64 quietly give you float32.

In [8]:
# create some arrays
print(np.arange(3, dtype='float32').dtype)
print(np.array([1, 2, 3]).dtype)
print(np.array([1.0, 2, 3]).dtype)
print(np.array([1, 2, 3], dtype='float32').dtype)
In [9]:
# should still be all float32
print(x.dtype, x.astype('float64').dtype)
print(xc.dtype, xc.astype('float64').dtype)
float32 float32
float32 float32

You can enable float64 by running: jax.config.update('jax_enable_x64', True)
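A minimal sketch of flipping that flag at runtime, assuming it is set before the arrays in question are created. The flag is global, so the sketch switches it back off at the end to restore the float32 defaults used in the rest of this notebook.

```python
import jax
import jax.numpy as np

# turn on 64-bit types globally
jax.config.update('jax_enable_x64', True)
z = np.array([1.0, 2.0, 3.0])
zi = np.arange(3)
print(z.dtype, zi.dtype)  # float64 int64

# switch back to the default 32-bit behaviour
jax.config.update('jax_enable_x64', False)
w = np.array([1.0, 2.0, 3.0])
print(w.dtype)  # float32
```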

Scalar Operations With grad

Let's start with derivatives of scalar-to-scalar functions. Calling grad on a function returns a new function that computes its gradient.

In [10]:
# define a simple function (can be lambda too)
def f(x):
    return x**2
In [11]:
# take the gradient (derivative)
df = jax.grad(f)
df(3.0)
DeviceArray(6., dtype=float32, weak_type=True)
In [12]:
# take the second derivative
d2f = jax.grad(df)
d2f(3.0)
DeviceArray(2., dtype=float32, weak_type=True)
In [13]:
# you need to make sure your inputs are floats
try:
    df(3)
except Exception as e:
    print(e)
grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
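When you need both the function value and its derivative (say, inside an optimizer), `jax.value_and_grad` returns both from a single call. The sketch below uses a made-up function `h` and checks the result against the derivative worked out by hand.

```python
import jax
import jax.numpy as np

def h(x):
    return np.sin(x) * x**2

# returns (h(x), h'(x)) from a single call
h_and_dh = jax.value_and_grad(h)
val, slope = h_and_dh(1.5)

# by hand: d/dx [x^2 sin(x)] = 2x sin(x) + x^2 cos(x)
expected = 2 * 1.5 * np.sin(1.5) + 1.5**2 * np.cos(1.5)
print(val, slope, expected)
```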

Up until now, we've only looked at functions of one parameter. Going to many parameters is fine; we just need to tell grad which variable to take the derivative with respect to. We do this by passing the index of the desired parameter (or parameters) in the argnums flag. If you give it multiple indices, it returns a tuple of derivatives, one per index. The default is argnums=0.

In [14]:
def g(x, y):
    return y*x**2 + x*y**3
dg = jax.grad(g, argnums=1)
dg(2.0, 3.0)
DeviceArray(58., dtype=float32, weak_type=True)
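Passing a tuple to argnums gives back a tuple of partial derivatives, one per listed argument. A sketch using the same g:

```python
import jax

def g(x, y):
    return y*x**2 + x*y**3

# derivatives with respect to both x (index 0) and y (index 1)
dg_both = jax.grad(g, argnums=(0, 1))
dgdx, dgdy = dg_both(2.0, 3.0)

# by hand: dg/dx = 2xy + y^3 = 39, dg/dy = x^2 + 3xy^2 = 58
print(dgdx, dgdy)
```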

Vector Operations With grad and jacobian

Now let's turn to vector-to-scalar functions. Here, calling grad gives you a function that returns a vector with the same size as the input. You don't have to tell jax the dimensions; it figures them out the first time the function actually gets called. Because of this, if jax isn't happy about something, it may not raise an error until you actually call the function returned by grad.

In [15]:
# the second term is just to add some complexity
def fv(x):
    return np.sum(x**2) + x[1]*x[2]
fv(x)
DeviceArray(20., dtype=float32)
In [16]:
# we use the "x" defined above as a test input
dfv = jax.grad(fv)
dfv(x)
DeviceArray([2., 7., 8.], dtype=float32)
In [17]:
# note that the jacobian is non-diagonal because of the additional term
jfv = jax.jacobian(dfv)
jfv(x)
DeviceArray([[2., 0., 0.],
             [0., 2., 1.],
             [0., 1., 2.]], dtype=float32)
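Since the Jacobian of the gradient is the Hessian, jax also offers `jax.hessian` as a shortcut, along with `jax.jacfwd` and `jax.jacrev` if you want to choose forward- or reverse-mode differentiation yourself. A small sketch checking that these all agree on fv:

```python
import jax
import jax.numpy as np

def fv(x):
    return np.sum(x**2) + x[1]*x[2]

x = np.array([1.0, 2.0, 3.0])

H1 = jax.jacobian(jax.grad(fv))(x)   # same as the cell above
H2 = jax.hessian(fv)(x)              # built-in shortcut
H3 = jax.jacfwd(jax.jacrev(fv))(x)   # forward-over-reverse, spelled out

expected = np.array([[2.0, 0.0, 0.0],
                     [0.0, 2.0, 1.0],
                     [0.0, 1.0, 2.0]])
```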
In [18]:
# you can't use grad on vector-return functions! but maybe we can find another way...
try:
    dfe = jax.grad(lambda x: x**2)
    dfe(x)
except Exception as e:
    print(e)
Gradient only defined for scalar-output functions. Output had shape: (3,).
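One other way out: grad insists on a scalar output, but `jax.jacobian` does not. For an elementwise function like `x**2`, the Jacobian is diagonal and its diagonal holds the elementwise derivative. This builds a full n-by-n matrix, so treat it as a sketch for small inputs; vmap is the cheaper route.

```python
import jax
import jax.numpy as np

x = np.array([1.0, 2.0, 3.0])

# Jacobian of an elementwise square: an (n, n) matrix, diagonal here
J = jax.jacobian(lambda v: v**2)(x)
print(J)
print(np.diag(J))  # the elementwise derivative 2x
```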

Vectorizing With vmap

Vectorization is often necessary for performance in Python, especially in jax. Fortunately, we can do this easily with vmap. This function takes an input function and maps it over an axis of the inputs of your choice. For instance, you can use it to turn a scalar-to-scalar function into a vector-to-vector function, but there are many more possibilities.

In [19]:
# the most basic usage on the "f" defined above
fv2 = jax.vmap(f)
fv2(x)
DeviceArray([1., 4., 9.], dtype=float32)
In [20]:
# now we can do the element-by-element gradient of a vector-return function
dfv2 = jax.vmap(jax.grad(f))
dfv2(x)
DeviceArray([2., 4., 6.], dtype=float32)
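vmap also works on functions that already take vectors. Below, a hypothetical batch of four input vectors (one per row) goes through fv and its gradient in one call, and `in_axes=1` is used to map over columns instead of rows. The explicit loop is only there for comparison.

```python
import jax
import jax.numpy as np

def fv(x):
    return np.sum(x**2) + x[1]*x[2]

batch = np.arange(12.0).reshape(4, 3)  # 4 input vectors of length 3

values = jax.vmap(fv)(batch)           # maps over axis 0 (the default), shape (4,)
grads = jax.vmap(jax.grad(fv))(batch)  # one gradient per row, shape (4, 3)

# mapping over axis 1 of the transposed batch gives the same values
values_t = jax.vmap(fv, in_axes=1)(batch.T)

# the slow, explicit equivalent
loop_values = np.array([fv(row) for row in batch])
```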

Just-In-Time Compilation With jit

This is an approach to speeding up code by compiling it just when it is needed (rather than beforehand, as in C or C++). It's used in JavaScript, MATLAB, and other places. Calling jit on a function returns another function that does the same thing, but faster. As with grad, the compilation only actually happens the first time you call the function, and it happens again if you later call it with inputs of a different shape or dtype.

In [21]:
# compile the vector gradient function above
jdfv = jax.jit(dfv)
jdfv(x)
DeviceArray([2., 7., 8.], dtype=float32)
In [22]:
# give it a much bigger vector for testing
x2 = np.linspace(1.0, 5.0, 100000)
In [23]:
# first do the un-jit-ed version
%timeit -n 100 dfv(x2)
6.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [24]:
# now the jitted version (run twice for true comparison)
%timeit -n 100 jdfv(x2)
The slowest run took 34.34 times longer than the fastest. This could mean that an intermediate result is being cached.
111 µs ± 225 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

On my computer this gives roughly a 60x speedup in the run above (6.8 ms vs. 111 µs per call), though the jitted timings are noisy. Results will vary, especially between CPU and GPU.
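One caveat on timings like these: jax dispatches work asynchronously, so a call can return before the computation has finished. A fairer sketch calls `block_until_ready()` on each result and makes one warm-up call first, so compilation isn't counted. The number of calls and the use of `time.perf_counter` are arbitrary choices here.

```python
import time
import jax
import jax.numpy as np

def fv(x):
    return np.sum(x**2) + x[1]*x[2]

dfv = jax.grad(fv)
jdfv = jax.jit(dfv)
x2 = np.linspace(1.0, 5.0, 100000)

# warm-up: the first call traces and compiles
jdfv(x2).block_until_ready()

n_calls = 100
start = time.perf_counter()
for _ in range(n_calls):
    out = jdfv(x2).block_until_ready()  # wait for the result before moving on
elapsed = (time.perf_counter() - start) / n_calls
print(f'{elapsed * 1e6:.1f} µs per call')
```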

Trees in jax

Trees are very cool! They're basically arbitrarily nested dicts, lists, and tuples whose leaves can be any type (usually arrays or scalars). If you've seen JSON before, it's very similar. Why are trees cool? Just as numpy lets you perform an operation over an entire array at once, trees let you do the same over entire trees. This lets you avoid a lot of clunky and error-prone packing and unpacking.

In [25]:
# first let's define a very simple tree, which is just a dict of arrays
xd = {
    'x': np.linspace(1.0, 5.0, 5),
    'y': np.linspace(10.0, 19.0, 10),
}
xd
{'x': DeviceArray([1., 2., 3., 4., 5.], dtype=float32),
 'y': DeviceArray([10., 11., 12., 13., 14., 15., 16., 17., 18., 19.], dtype=float32)}
In [26]:
# here's a function that operates on such a tree
# you could imagine specifying a model like this
# where the input is a dict of parameters or data points
def ft(d):
    xsum = np.sum(d['x']**3)
    ysum = np.sum(d['y']**2)
    return xsum + ysum
In [27]:
# now we can take a grad with respect to a tree and get a tree of the same shape back!
dft = jax.grad(ft)
dft(xd)
{'x': DeviceArray([ 3., 12., 27., 48., 75.], dtype=float32),
 'y': DeviceArray([20., 22., 24., 26., 28., 30., 32., 34., 36., 38.], dtype=float32)}
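Trees also pair well with `jax.tree_util`. For example, once you have a gradient tree, you can apply one update rule to every leaf with `tree_map`. The sketch below takes a single plain gradient-descent step; the learning rate `lr` is an arbitrary illustrative value.

```python
import jax
import jax.numpy as np

def ft(d):
    return np.sum(d['x']**3) + np.sum(d['y']**2)

params = {'x': np.linspace(1.0, 5.0, 5), 'y': np.linspace(10.0, 19.0, 10)}
grads = jax.grad(ft)(params)

# the same update is applied leaf by leaf, with no manual unpacking
lr = 0.01
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# dict leaves come back in sorted key order: 'x' then 'y'
shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(new_params)]
print(shapes)  # [(5,), (10,)]
print(ft(params), ft(new_params))
```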

Sometimes it's easier to think in scalar terms, then vectorize things. Suppose that our model really just operated on a single x and y, and we want to run the model for many such pairs. Then we can first define a slightly different function in purely scalar terms.

In [28]:
def fts(d):
    x, y = d['x'], d['y']
    return x**3 + y**2

Then we use vmap to map over the leaves of the tree d. Below is what we'd write if we're interested in evaluating 10 distinct parameter sets.

In [29]:
xd1 = {'x': np.linspace(0, 1, 10), 'y': np.linspace(2, 3, 10)}
ftv1 = jax.vmap(fts)
ftv1(xd1)
DeviceArray([ 4.       ,  4.4581623,  4.949246 ,  5.481481 ,  6.0631013,
              6.702332 ,  7.4074063,  8.186558 ,  9.04801  , 10.       ],            dtype=float32)

But sometimes we just want to vary one parameter at a time. As with grad, we need to tell vmap which dimensions to operate on, this time using in_axes. But it's a bit more complicated.

In [30]:
xd2 = {'x': np.linspace(0, 1, 10), 'y': 2.0}
ftv2 = jax.vmap(fts, in_axes=({'x': 0, 'y': None},))
ftv2(xd2)
DeviceArray([4.       , 4.001372 , 4.010974 , 4.037037 , 4.0877914,
             4.171468 , 4.296296 , 4.4705076, 4.702332 , 5.       ],            dtype=float32)

Notice that in_axes is a tuple with one entry per function argument (in this case, just one). Within that entry, we mirror the structure of the tree itself to specify that we want to map x over axis 0 and not map y at all.
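The same in_axes idea applies to ordinary positional arguments, without any trees: give one entry per argument and use None for arguments that should be held fixed. A sketch with the two-argument g from earlier, mapping over x while y stays a scalar:

```python
import jax
import jax.numpy as np

def g(x, y):
    return y*x**2 + x*y**3

xs = np.linspace(0.0, 1.0, 5)

# map over axis 0 of the first argument; reuse the same y for every element
gv = jax.vmap(g, in_axes=(0, None))
out = gv(xs, 2.0)

expected = 2.0*xs**2 + xs*2.0**3
print(out)
```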

Loops and Iteration

So... loops work a little differently in jax. It's actually fine to write a for loop, but not if you want to ultimately jit the thing. As we get into more advanced usage, note that there's a great overview of common "gotchas" in jax here: JAX - The Sharp Bits.

In [31]:
# very slow way to compute x**n
def fl(x, n):
    out = 1
    for i in range(n):
        out = x*out
    return out
In [32]:
try:
    jfl = jax.jit(fl)
    jfl(2.0, 3)
except Exception as e:
    print(e)
The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

Not a very helpful error message! Basically, we tried to use a function argument as the loop bound and then jit the function. One way to fix this is to declare n as static using static_argnums. This works, but jax will recompile for each distinct value of n that it sees.

In [33]:
jfl = jax.jit(fl, static_argnums=1)
jfl(2.0, 3)
DeviceArray(8., dtype=float32, weak_type=True)
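To watch that recompilation happen, one trick is a Python side effect in the function body: it only runs while jax is tracing the function, not when a cached compiled version is reused. The `trace_log` list is a hypothetical helper for this demo; side effects like this are otherwise best avoided in jitted code.

```python
import jax

trace_log = []

def fl(x, n):
    trace_log.append(n)  # runs only while jax traces fl
    out = 1.0
    for i in range(n):
        out = x*out
    return out

jfl = jax.jit(fl, static_argnums=1)
a = jfl(2.0, 3)  # traced and compiled for n=3
b = jfl(5.0, 3)  # same n and input type: reuses the compiled code
c = jfl(2.0, 4)  # new n: traced and compiled again
print(a, b, c, trace_log)
```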

The other way is to use one of the looping tools from lax, such as scan or fori_loop. First, let's try fori_loop.

In [34]:
def fl_fori(x, n):
    fl_x = lambda i, v: x*v
    return lax.fori_loop(0, n, fl_x, 1.0)
jfl_fori = jax.jit(fl_fori)
jfl_fori(2.0, 3)
DeviceArray(8., dtype=float32, weak_type=True)
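Unlike fl_fori, the body below actually uses the loop index. It is a sketch summing the squares 0² + 1² + ... + (n-1)², where n is an ordinary traced argument, so no static_argnums is needed. The name `sum_squares` is just for illustration.

```python
import jax
import jax.lax as lax
import jax.numpy as np

def sum_squares(n):
    # body gets the loop index i and the running total, returns the new total
    body = lambda i, total: total + i**2
    return lax.fori_loop(0, n, body, np.int32(0))

jsum = jax.jit(sum_squares)
print(jsum(4))   # 0 + 1 + 4 + 9 = 14
print(jsum(10))  # 285
```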

The scan function operates similarly, in that it carries a running state, but instead of just getting the index i at each iteration, you get a slice of your input data. You also get the whole history of outputs at the end, rather than just the last value. The function you pass should accept the running state, or carry (v below), and the current slice of the input (i below), and return a pair: the new carry and the output for that step. Both can be arbitrary tree types.

In [35]:
def fl_scan(x, n):
    fl_x = lambda v, i: (x*v, x*v)
    tvec = np.arange(n)
    return lax.scan(fl_x, 1.0, tvec)
jfl_scan = jax.jit(fl_scan, static_argnums=1)
jfl_scan(2.0, 3)
(DeviceArray(8., dtype=float32, weak_type=True),
 DeviceArray([2., 4., 8.], dtype=float32, weak_type=True))
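Here is a case where scan's per-step input actually matters: a running (cumulative) sum, where each step consumes one element of xs and emits the total so far. The function name `cumsum_scan` is hypothetical; np.cumsum already does this, which makes it a convenient check.

```python
import jax
import jax.lax as lax
import jax.numpy as np

def cumsum_scan(xs):
    # carry: running total; x_i: current element of xs
    step = lambda total, x_i: (total + x_i, total + x_i)
    return lax.scan(step, np.zeros((), dtype=xs.dtype), xs)

xs = np.array([1.0, 2.0, 3.0, 4.0])
total, history = jax.jit(cumsum_scan)(xs)
# total is 10.0, history holds the partial sums 1, 3, 6, 10
print(total, history)
```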

The return at the end is the final value of the running state, as well as the history of values. Ok, so both of these were designed for fixed length iterations. For variable length iterations, we can use lax.while_loop. This accepts a "condition" function that tells it whether to continue or not and a function that updates the running state.

In [36]:
def fl_while(x, n):
    f_cond = lambda iv: iv[0] < n
    fl_x = lambda iv: (iv[0] + 1, x*iv[1])
    return lax.while_loop(f_cond, fl_x, (0, 1.0))
jfl_while = jax.jit(fl_while)
jfl_while(2.0, 3)
(DeviceArray(3, dtype=int32, weak_type=True),
 DeviceArray(8., dtype=float32, weak_type=True))

Here we just reimplemented a for loop by including i in the running state and conditioning on it, but you could imagine the condition being something like a convergence check instead.
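A sketch of that convergence-check idea: Newton's method for √a, stopping once the update falls below a tolerance. It assumes a > 0; the starting guess (a itself) and the tolerance `tol` are arbitrary illustrative choices.

```python
import jax
import jax.lax as lax
import jax.numpy as np

def newton_sqrt(a, tol=1e-6):
    a = np.asarray(a, dtype=np.float32)

    # state is (current guess, size of the last update)
    def cond(state):
        _, step = state
        return step > tol

    def body(state):
        guess, _ = state
        new_guess = 0.5 * (guess + a / guess)
        return new_guess, np.abs(new_guess - guess)

    init = (a, np.array(np.inf, dtype=np.float32))
    guess, _ = lax.while_loop(cond, body, init)
    return guess

jsqrt = jax.jit(newton_sqrt)
print(jsqrt(2.0), jsqrt(9.0))
```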

Random Numbers

Random numbers work a bit differently in jax, in part due to the different needs of Deep Learning folks: there is no hidden global random state. Instead, you first generate a "key" from a "seed" (below, 42, a very common choice). You then pass this key to various random number routines that are similar to those in numpy.random.

In [37]:
key = jax.random.PRNGKey(42)
print(key)
r1 = jax.random.normal(key)
r2 = jax.random.normal(key)
print(r1, r2)
[ 0 42]
-0.18471177 -0.18471177

Notice that the same key leads to the same random numbers! So there must be more to it. To advance your key to something new, you can use the split function, which is also deterministic. It returns two new keys.

In [38]:
key, subkey = jax.random.split(key)
print(key, subkey)
[2465931498 3679230171] [255383827 267815257]

So one pattern is to keep a running key and generate a new subkey whenever you need another random number.

In [39]:
jax.random.uniform(subkey, (5,))
DeviceArray([0.2899989 , 0.82748747, 0.22911513, 0.2819779 , 0.8697449 ],            dtype=float32)
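Two sketches of this idea: a running key that is split once per draw, and splitting off several keys at once so that vmap can draw from each of them in parallel. The seeds 0 and 1 and the sample shapes are arbitrary.

```python
import jax

# pattern 1: running key, fresh subkey for every draw
key = jax.random.PRNGKey(0)
draws = []
for _ in range(3):
    key, subkey = jax.random.split(key)
    draws.append(jax.random.normal(subkey))

# pattern 2: split into several keys at once and map over them
keys = jax.random.split(jax.random.PRNGKey(1), 4)
samples = jax.vmap(lambda k: jax.random.uniform(k, (2,)))(keys)  # shape (4, 2)

# reusing a key reproduces its numbers exactly
same = jax.random.normal(subkey) == draws[-1]
```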