Solving Models (with jax)

In [1]:
import jax
import jax.numpy as np
import jax.lax as lax
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
In [2]:
# configure matplotlib output
import matplotlib as mpl
mpl.style.use('config/clean.mplstyle') # this loads my personal plotting settings
col = mpl.rcParams['axes.prop_cycle'].by_key()['color']
%matplotlib inline
In [3]:
# if you have an HD display
%config InlineBackend.figure_format = 'retina'
In [4]:
# some warnings can get annoying
import warnings
warnings.filterwarnings('ignore')
In [5]:
import tools.models as mod

jax Basics

Almost all of the numpy API is implemented in jax. Arrays have a device that they are stored on. If you have a supported GPU, it will default to this, otherwise it'll go on the CPU (or a TPU if you're really fancy, but I've never used one). Note that there is the notion of a platform which includes cpu, gpu, and tpu. Each platform has a certain number of devices (often just one though).

In [6]:
# the output here will depend on your setup (we'll be using this "x" a lot below too)
x = np.array([1.0, 2.0, 3.0])
x.device() # a method in the jax version used here (a property, x.device, in recent versions)
StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)
In [7]:
# you can print out available devices
jax.devices()
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
In [8]:
# you can send data between devices
cpu0, *_ = jax.devices('cpu')
xc = jax.device_put(x, cpu0)

Arrays also have a dtype, which usually corresponds to those found in regular numpy. However, the defaults can be different and some are not supported on certain devices. The most common such difference is that float64 is not supported unless you explicitly enable it.

In [9]:
# create some arrays
print(np.arange(3, dtype='float32').dtype)
print(np.array([1, 2, 3]).dtype)
print(np.array([1.0, 2, 3]).dtype)
print(np.array([1, 2, 3], dtype='float32').dtype)
In [10]:
# should still be all float32
print(x.dtype, x.astype('float64').dtype)
print(xc.dtype, xc.astype('float64').dtype)
float32 float32
float32 float32

You can enable float64 by running: jax.config.update('jax_enable_x64', True)

Scalar Operations With grad

Let's start with derivatives of scalar-to-scalar functions. You can make a new function that returns the gradient by calling grad on a function.

In [11]:
# define a simple function (can be lambda too)
def f(x):
    return x**2
In [12]:
# take the gradient (derivative)
df = jax.grad(f)
df(3.0)
DeviceArray(6., dtype=float32, weak_type=True)
In [13]:
# take the second derivative
d2f = jax.grad(df)
d2f(3.0)
DeviceArray(2., dtype=float32, weak_type=True)
In [14]:
# you need to make sure your inputs are floats
try:
    df(3)
except Exception as e:
    print(e)
grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

Up until now, we've only looked at functions of one parameter. Going to many parameters is fine; we just need to tell grad which variable to take the derivative with respect to. We do this by specifying the index of the desired parameter or parameters in the argnums flag. If you give it multiple indices, it will return a tuple of derivatives. The default is argnums=0.

In [15]:
def g(x, y):
    return y*x**2 + x*y**3
dg = jax.grad(g, argnums=1)
dg(2.0, 3.0)
DeviceArray(58., dtype=float32, weak_type=True)
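To illustrate the multiple-index case, here is a minimal sketch that reuses the g from above and passes several indices to argnums at once. The name dg_both is just for illustration.

```python
import jax

def g(x, y):
    return y*x**2 + x*y**3

# differentiate with respect to both arguments at once
dg_both = jax.grad(g, argnums=(0, 1))
dgx, dgy = dg_both(2.0, 3.0)

# analytically: dg/dx = 2*x*y + y**3 = 39 and dg/dy = x**2 + 3*x*y**2 = 58
print(dgx, dgy)
```

The second entry matches the argnums=1 result above.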

Vector Operations With grad and jacobian

Now let's turn to vector-to-scalar functions. Here, calling grad will give you a function that returns a vector with the same size as the input dimension. You don't have to specify the dimensions to jax; it'll figure these out the first time the function actually gets called. Because of this, if jax isn't happy about something, it may not give you an error until you actually try to use the output of grad.

In [16]:
# the second term is just to add some complexity
def fv(x):
    return np.sum(x**2) + x[1]*x[2]
fv(x)
DeviceArray(20., dtype=float32)
In [17]:
# we use the "x" defined above as a test input
dfv = jax.grad(fv)
dfv(x)
DeviceArray([2., 7., 8.], dtype=float32)
In [18]:
# note that the jacobian of the gradient (i.e. the Hessian) is non-diagonal because of the additional term
jfv = jax.jacobian(dfv)
jfv(x)
DeviceArray([[2., 0., 0.],
             [0., 2., 1.],
             [0., 1., 2.]], dtype=float32)
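Since the Jacobian of a gradient is the Hessian, the same matrix can also be obtained with jax.hessian. A small self-contained sketch comparing the two routes (hess_a and hess_b are illustrative names):

```python
import jax
import jax.numpy as np

def fv(x):
    return np.sum(x**2) + x[1]*x[2]

x = np.array([1.0, 2.0, 3.0])

# route 1: jacobian of the gradient, as in the cell above
hess_a = jax.jacobian(jax.grad(fv))(x)

# route 2: the built-in hessian transform
hess_b = jax.hessian(fv)(x)

print(np.allclose(hess_a, hess_b))
```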
In [19]:
# you can't use grad on vector-return functions! but maybe we can find another way...
try:
    dfe = jax.grad(lambda x: x**2)
    dfe(x)
except Exception as e:
    print(e)
Gradient only defined for scalar-output functions. Output had shape: (3,).

Vectorizing With vmap

Vectorization is often necessary for performance in Python, especially in jax. Fortunately, we can do this easily with vmap. This function will take an input function and map it along a new dimension of your choice. For instance, you can use it to turn a scalar-to-scalar function into a vector-to-vector function. But there are many more possibilities.

In [20]:
# the most basic usage on the "f" defined above
fv2 = jax.vmap(f)
fv2(x)
DeviceArray([1., 4., 9.], dtype=float32)
In [21]:
# now we can do the element-by-element gradient of a vector-return function
dfv2 = jax.vmap(jax.grad(f))
dfv2(x)
DeviceArray([2., 4., 6.], dtype=float32)
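As one example of the other possibilities, vmap can also batch a vector-to-scalar function over the rows of a matrix. The sketch below reuses the fv from earlier; the matrix X and the names batch_fv and batch_dfv are made up for illustration.

```python
import jax
import jax.numpy as np

def fv(x):
    return np.sum(x**2) + x[1]*x[2]

# each row is one input vector
X = np.array([[1.0, 2.0, 3.0],
              [0.0, 1.0, 1.0]])

# map over the leading axis: one scalar per row
batch_fv = jax.vmap(fv)(X)

# same idea for the gradient: one gradient vector per row
batch_dfv = jax.vmap(jax.grad(fv))(X)

print(batch_fv.shape, batch_dfv.shape)
```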

Just-In-Time Compilation With jit

This is an approach to speeding up code by compiling it just when it is needed (rather than beforehand, like in C or C++). It's used in JavaScript, MATLAB, and other places. Calling jit on a function will return another function that does the same thing but faster. As with grad, it only actually does the compilation when you run the function for the first time.

In [22]:
# compile the vector gradient function above
jdfv = jax.jit(dfv)
jdfv(x)
DeviceArray([2., 7., 8.], dtype=float32)
In [23]:
# give it a much bigger vector for testing
x2 = np.linspace(1.0, 5.0, 100000)
In [24]:
# first do the un-jit-ed version
%timeit -n 100 dfv(x2)
3.15 ms ± 808 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [25]:
# now the jitted version (run twice for true comparison)
%timeit -n 100 jdfv(x2)
The slowest run took 31.41 times longer than the fastest. This could mean that an intermediate result is being cached.
57.3 µs ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Based on the mean timings above, the jitted version is roughly 55x faster on my computer, though the spread across runs is large. This may vary from CPU to GPU.
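One caveat when timing jax code: dispatch is asynchronous, so a call can return before the computation has actually finished. Below is a minimal sketch of a manual timing loop that waits on each result with block_until_ready. The helper time_it is hypothetical, and the numbers will depend on your hardware.

```python
import time
import jax
import jax.numpy as np

def fv(x):
    return np.sum(x**2) + x[1]*x[2]

dfv = jax.grad(fv)
jdfv = jax.jit(dfv)
x2 = np.linspace(1.0, 5.0, 100000)

# warm-up call so that compilation time is excluded from the timing
jdfv(x2).block_until_ready()

def time_it(fn, reps=100):
    # average wall-clock seconds per call, waiting for each result
    start = time.perf_counter()
    for _ in range(reps):
        fn(x2).block_until_ready()
    return (time.perf_counter() - start) / reps

t_plain = time_it(dfv)
t_jit = time_it(jdfv)
print(f"plain: {t_plain*1e6:.1f} µs, jit: {t_jit*1e6:.1f} µs")
```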

Trees in jax

Trees are very cool! They're basically arbitrarily nested dicts and lists whose leaves can be any type. If you've seen JSON before, it's very similar to that. Why are trees cool? Just as numpy lets you perform operations over an entire array at once, trees let you do the same over entire trees as well. This lets you avoid a lot of clunky and error-prone packing and unpacking.

In [26]:
# first let's define a very simple tree, which is just a dict of arrays
xd = {
    'x': np.linspace(1.0, 5.0, 5),
    'y': np.linspace(10.0, 19.0, 10),
}
xd
{'x': DeviceArray([1., 2., 3., 4., 5.], dtype=float32),
 'y': DeviceArray([10., 11., 12., 13., 14., 15., 16., 17., 18., 19.], dtype=float32)}
In [27]:
# here's a function that operates on such a tree
# you could imagine specifying a model like this
# where the input is a dict of parameters or data points
def ft(d):
    xsum = np.sum(d['x']**3)
    ysum = np.sum(d['y']**2)
    return xsum + ysum
In [28]:
# now we can take a grad with respect to a tree and get a tree of the same shape back!
dft = jax.grad(ft)
dft(xd)
{'x': DeviceArray([ 3., 12., 27., 48., 75.], dtype=float32),
 'y': DeviceArray([20., 22., 24., 26., 28., 30., 32., 34., 36., 38.], dtype=float32)}
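Because the gradient comes back with the same tree structure, it can be combined with the inputs leaf by leaf using jax.tree_util.tree_map. Here is a small sketch of a single gradient step on the tree above; the step size and the name xd_new are arbitrary choices for illustration.

```python
import jax
import jax.numpy as np

xd = {
    'x': np.linspace(1.0, 5.0, 5),
    'y': np.linspace(10.0, 19.0, 10),
}

def ft(d):
    return np.sum(d['x']**3) + np.sum(d['y']**2)

grads = jax.grad(ft)(xd)

# apply p - step*g to every matching pair of leaves
step = 0.01
xd_new = jax.tree_util.tree_map(lambda p, g: p - step*g, xd, grads)

# the leaves can also be pulled out as a flat list
leaves = jax.tree_util.tree_leaves(xd_new)
print(len(leaves), xd_new['x'])
```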

Sometimes it's easier to think in scalar terms and then vectorize things. Suppose that our model really just operates on some x and y and we want to run the model for many such pairs. Then we can first define a slightly different function in purely scalar terms.

In [29]:
def fts(d):
    x, y = d['x'], d['y']
    return x**3 + y**2

Then we use vmap to broadcast over the tree d. Below is what we'd write if we're interested in doing 10 distinct parameter sets.

In [30]:
xd1 = {'x': np.linspace(0, 1, 10), 'y': np.linspace(2, 3, 10)}
ftv1 = jax.vmap(fts)
ftv1(xd1)
DeviceArray([ 4.       ,  4.4581623,  4.949246 ,  5.481481 ,  6.0631013,
              6.702332 ,  7.4074063,  8.186558 ,  9.04801  , 10.       ],            dtype=float32)

But sometimes we just want to vary one parameter at a time. As with grad, we need to tell it which dimensions to operate on, this time using in_axes. But it's a bit more complicated.

In [31]:
xd2 = {'x': np.linspace(0, 1, 10), 'y': 2.0}
ftv2 = jax.vmap(fts, in_axes=({'x': 0, 'y': None},))
ftv2(xd2)
DeviceArray([4.       , 4.001372 , 4.010974 , 4.037037 , 4.0877914,
             4.171468 , 4.296296 , 4.4705076, 4.702332 , 5.       ],            dtype=float32)

Notice that in_axes is a tuple with one entry per function argument (in this case 1). We then have to mirror the structure of the tree itself to specify that we want to map x over axis 0 and not map y at all.
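The same pattern works in the other direction. As a sketch, here is the mirror-image case where x is held fixed and y is mapped over; the names xd3 and ftv3 are made up for this example.

```python
import jax
import jax.numpy as np

def fts(d):
    x, y = d['x'], d['y']
    return x**3 + y**2

# x is a fixed scalar, y varies along axis 0
xd3 = {'x': 1.0, 'y': np.array([2.0, 3.0])}
ftv3 = jax.vmap(fts, in_axes=({'x': None, 'y': 0},))

# 1**3 + 2**2 = 5 and 1**3 + 3**2 = 10
out3 = ftv3(xd3)
print(out3)
```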

Loops and Iteration

So... loops work a little differently in jax. It's okay to write a for loop actually, but not if you want to ultimately jit the thing. First, as we get into more advanced usage, there's a great overview of common "gotchas" in jax here: JAX - The Sharp Bits.

In [32]:
# very slow way to compute x**n
def fl(x, n):
    out = 1
    for i in range(n):
        out = x*out
    return out
In [33]:
try:
    jfl = jax.jit(fl)
    jfl(2.0, 3)
except Exception as e:
    print(e)
The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

Not a very helpful error message! But basically we tried to use a function argument as the loop bounds and then jit it. One way to take care of this is to specify that n is mostly fixed using static_argnums. This will work but will recompile for each distinct value of n that is seen.

In [34]:
jfl = jax.jit(fl, static_argnums=1)
jfl(2.0, 3)
DeviceArray(8., dtype=float32, weak_type=True)

The other way is to use one of the looping tools from lax such as scan or fori_loop. First, let's try it using fori_loop.

In [35]:
def fl_fori(x, n):
    fl_x = lambda i, v: x*v
    return lax.fori_loop(0, n, fl_x, 1.0)
jfl_fori = jax.jit(fl_fori)
jfl_fori(2.0, 3)
DeviceArray(8., dtype=float32, weak_type=True)

The scan function operates similarly, in that it carries a running state, but instead of just getting i at each iteration, you can get a slice of your input data. You also get the whole history of outputs at the end, rather than just the last value. The function you pass should accept the running state (v below) and the current slice of the input (i below), and return the updated state along with the output for that step. Either of these can be arbitrary tree types.

In [36]:
def fl_scan(x, n):
    fl_x = lambda v, i: (x*v, x*v)
    tvec = np.arange(n)
    return lax.scan(fl_x, 1.0, tvec)
jfl_scan = jax.jit(fl_scan, static_argnums=1)
jfl_scan(2.0, 3)
(DeviceArray(8., dtype=float32, weak_type=True),
 DeviceArray([2., 4., 8.], dtype=float32, weak_type=True))
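The example above ignores the scanned input, so here is a small sketch where the slice is actually used: a running sum over a data vector. The names step and data are hypothetical.

```python
import jax.numpy as np
import jax.lax as lax

def step(total, x_t):
    # total is the running state, x_t is the current slice of the input
    new_total = total + x_t
    return new_total, new_total

data = np.array([1.0, 2.0, 3.0, 4.0])

# start from a float32 zero so the state dtype matches the data
final, history = lax.scan(step, np.zeros(()), data)
print(final, history)
```

Here final is the last running total and history holds the total after each step.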

The return at the end is the final value of the running state, as well as the history of values. Ok, so both of these were designed for fixed length iterations. For variable length iterations, we can use lax.while_loop. This accepts a "condition" function that tells it whether to continue or not and a function that updates the running state.

In [37]:
def fl_while(x, n):
    f_cond = lambda iv: iv[0] < n
    fl_x = lambda iv: (iv[0] + 1, x*iv[1])
    return lax.while_loop(f_cond, fl_x, (0, 1.0))
jfl_while = jax.jit(fl_while)
jfl_while(2.0, 3)
(DeviceArray(3, dtype=int32, weak_type=True),
 DeviceArray(8., dtype=float32, weak_type=True))

Here we just reimplemented a for loop by including i in the running state and conditioning on it, but you could imagine the condition being something like a convergence check too.
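To make the convergence-check idea concrete, here is a minimal sketch that computes a square root by Newton's method and stops once the step size falls below a tolerance. The function sqrt_newton and its tol and max_iter arguments are hypothetical, not part of the notebook's models.

```python
import jax
import jax.numpy as np
import jax.lax as lax

def sqrt_newton(a, tol=1e-6, max_iter=100):
    a = np.asarray(a, dtype=np.float32)

    # state is (iteration count, current guess, size of last step)
    def cond(state):
        i, x, err = state
        return (err > tol) & (i < max_iter)

    def body(state):
        i, x, _ = state
        xp = 0.5*(x + a/x)
        return i + 1, xp, np.abs(xp - x)

    init = (np.array(0, dtype=np.int32), a, np.array(np.inf, dtype=np.float32))
    return lax.while_loop(cond, body, init)

n_iter, root, last_step = jax.jit(sqrt_newton)(2.0)
print(n_iter, root, last_step)
```

The iteration cap is there as a safeguard so the loop still terminates if the tolerance is never reached.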

Random Numbers

Random numbers work a bit differently in jax due to the different needs of Deep Learning folks. You need to generate a "key" first using a "seed" as input (below, 42, a very common seed). You then pass this key to various random number routines that are similar to those in numpy.random.

In [38]:
key = jax.random.PRNGKey(42)
print(key)
r1 = jax.random.normal(key)
r2 = jax.random.normal(key)
print(r1, r2)
[ 0 42]
-0.18471177 -0.18471177

Notice the same key leads to the same random numbers! So there must be more. To advance your key to something new you can use the split function, which is also deterministic. This will return two new keys.

In [39]:
key, subkey = jax.random.split(key)
print(key, subkey)
[2465931498 3679230171] [255383827 267815257]

So one pattern is to keep a running key and generate a new subkey when you need another random number.

In [40]:
jax.random.uniform(subkey, (5,))
DeviceArray([0.2899989 , 0.82748747, 0.22911513, 0.2819779 , 0.8697449 ],            dtype=float32)
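If you need several independent draws at once, split also accepts a count. A short sketch, assuming you want three subkeys while keeping a running key (the names subkeys and draws are illustrative):

```python
import jax

key = jax.random.PRNGKey(42)

# ask for 4 keys: keep the first as the running key, use the rest as subkeys
key, *subkeys = jax.random.split(key, 4)

# each subkey gives a different draw, but the same subkey always gives the same draw
draws = [jax.random.normal(k) for k in subkeys]
repeat = jax.random.normal(subkeys[0])
print(draws, repeat)
```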


Optimization

Optimization typically involves looping in some way, so we need to be a bit careful when using jax. But you can achieve all of the same results with much more flexibility and speed at runtime. Here we're using simple gradient descent for optimization.

Simple Function

In [41]:
# hardcoded step size
Δ = 0.1
In [42]:
# simple function
f = lambda x: -(x-1)**2 + 1
df = jax.grad(f)
vdf = jax.vmap(df)
In [43]:
# visualize value and gradient
xgrid = np.linspace(0, 2, 100)
fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(8, 3))
ax0.plot(xgrid, f(xgrid))
ax1.plot(xgrid, vdf(xgrid));
In [44]:
# we define an update step function and one to apply it with scan
def update_scalar(x, t):
    xp = x + Δ*df(x)
    return xp, xp
def maximize_func(x0, K=50):
    tvec = np.arange(K) # this just has to be the right shape (K)
    x1, xvec = lax.scan(update_scalar, x0, tvec)
    return x1, xvec
In [45]:
# jit'ed scalar and vectorized versions
jmax_func = jax.jit(maximize_func)
vmax_func = jax.jit(jax.vmap(maximize_func))
In [46]:
# run for just one starting value
x1, xvec = jmax_func(0.5)
In [47]:
# run for many starting values simultaneously
key = jax.random.PRNGKey(42)
xrand = jax.random.normal(key, (100,))
x1, xvec = vmax_func(xrand)


Now let's add in an additional parameter $\alpha$ which influences the optimization

In [48]:
# parameterized function (maximized at x=α)
f1 = lambda x, α: -(x-α)**2
df1 = jax.grad(f1)
vdf1 = jax.vmap(df1)
In [49]:
# we just need to apply α to function before optimizing
def update_scalar(α, x, t):
    xp = x + Δ*df1(x, α)
    return xp, xp
def maximize_func(α, x0, K=50):
    tvec = np.arange(K)
    upd = lambda x, t: update_scalar(α, x, t)
    x1, xvec = lax.scan(upd, x0, tvec)
    return x1, xvec
In [50]:
jmax_func = jax.jit(maximize_func)
vmax_func = jax.jit(jax.vmap(maximize_func, in_axes=(0, None)))
In [51]:
# run the scalar optimization for α = 2
x1, xvec = jmax_func(2.0, 0.5)
In [52]:
# run the optimization for a range of α's
αgrid = np.linspace(-1, 1, 20)
x1, xvec = vmax_func(αgrid, 0.0)

We can even take the derivative of the optimal solution with respect to $\alpha$, which should always be $1$ in our case since $x^{\ast}(\alpha) = \alpha$.

In [53]:
dmax_func_out = jax.grad(lambda α, x0: maximize_func(α, x0)[0], argnums=0)
vmax_func_out = jax.vmap(dmax_func_out, in_axes=(0, None))
vmax_func_out(αgrid, 0.5)
DeviceArray([0.99998564, 0.99998564, 0.99998564, 0.99998564, 0.99998564,
             0.99998564, 0.99998564, 0.99998564, 0.99998564, 0.99998564,
             0.99998564, 0.99998564, 0.99998564, 0.99998564, 0.99998564,
             0.99998564, 0.99998564, 0.99998564, 0.99998564, 0.99998564],            dtype=float32)
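The small gap from 1 can be explained by the finite number of steps. Each update maps x to (1-2Δ)x + 2Δα, so after K steps the derivative of the result with respect to α is 1 - (1-2Δ)^K. Below is a self-contained sketch that checks this; step, K, and maximize are stand-ins for Δ, the default K=50, and maximize_func above.

```python
import jax
import jax.numpy as np
import jax.lax as lax

step = 0.1  # same value as the step size Δ above
K = 50      # same as the default number of iterations

f1 = lambda x, a: -(x - a)**2
df1 = jax.grad(f1)  # derivative with respect to x

def maximize(a, x0):
    def update(x, t):
        xp = x + step*df1(x, a)
        return xp, xp
    x1, _ = lax.scan(update, x0, np.arange(K))
    return x1

# derivative of the K-step solution with respect to the parameter
numeric = jax.grad(maximize)(2.0, 0.5)

# closed form from iterating the linear update rule
closed_form = 1 - (1 - 2*step)**K
print(numeric, closed_form)
```

The two numbers should agree up to float32 precision, and both approach 1 as K grows.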

Value Functions

Below is an implementation of simple value function iteration using jax for the neoclassical growth model. This makes use of trees for parameter values and the scan function for iteration. We start by defining parameters and the utility and production functions, as well as constructing the capital grid and finding the analytic steady state.

In [54]:
from functools import partial
In [55]:
# fixed params
T = 300 # time periods
N = 100 # grid size
f_lo, f_hi = 0.5, 2.0 # proportional grid range
In [56]:
# simple parameters
par0 = {
    'β': 0.95,
    'δ': 0.1,
    'α': 0.35,
    'z': 1.0,
}
In [57]:
# find steady state
def get_kss(par):
    β, δ, z, α = par['β'], par['δ'], par['z'], par['α']
    rhs = (1-β)/β + δ
    k = (α*z/rhs)**(1/(1-α))
    return k
k_ss = get_kss(par0)
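As a quick check on the formula, the steady state should satisfy the Euler equation β(αz k^(α-1) + 1 - δ) = 1. A minimal plain-Python sketch using the same parameter values (mpk and euler are illustrative names):

```python
par0 = {'β': 0.95, 'δ': 0.1, 'α': 0.35, 'z': 1.0}

def get_kss(par):
    β, δ, z, α = par['β'], par['δ'], par['z'], par['α']
    rhs = (1-β)/β + δ
    return (α*z/rhs)**(1/(1-α))

k_ss = get_kss(par0)

# marginal product of capital at the steady state
β, δ, z, α = par0['β'], par0['δ'], par0['z'], par0['α']
mpk = α*z*k_ss**(α-1)

# the Euler equation at the steady state: this should equal 1
euler = β*(mpk + 1 - δ)
print(k_ss, euler)
```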
In [58]:
# construct capital grid
k_min = f_lo*k_ss
k_max = f_hi*k_ss
k_grid = np.linspace(k_min, k_max, N)
In [59]:
# defined functions
def util(c, ϵ=1e-6):
    c1 = np.maximum(ϵ, c)
    return np.log(c1)
def prod(k, z, α):
    return z*k**α

Now we define the actual updating rule and the outer solving function.

In [60]:
# the value function update function
def value(par, grid, st, tv):
    β = par['β']
    cp = grid['cp']
    vn = st['vn']

    # calculate optimal investment
    v1 = util(cp) + β*vn[None,:]
    ip = np.argmax(v1, axis=1)
    vp = np.max(v1, axis=1)

    # apply at values
    kp = k_grid[ip]

    # compute update errors
    err = np.max(np.abs(vp-vn))

    # return state and output
    stp = {
        'vn': vp,
    }
    out = {
        'v': vp,
        'kp': kp,
        'err': err,
    }
    return stp, out
In [61]:
# the actual solver (T is the number of iterations)
def solve(par, T):
    α = par['α']
    δ = par['δ']
    z = par['z']

    # precompute grid values
    y_grid = prod(k_grid, z, α)
    yd_grid = y_grid + (1-δ)*k_grid
    cp_grid = yd_grid[:,None] - k_grid[None,:]

    # partially apply grid
    grid = {
        'cp': cp_grid,
    }
    value1 = partial(value, par, grid)

    # scan over time (backwards)
    st0 = {
        'vn': util(y_grid),
    }
    tv = {
        't': np.arange(T)[::-1],
    }
    last, path = lax.scan(value1, st0, tv)

    return path

Now let's time the model and see how the results look!

In [62]:
# jit and run once to ensure compilation
jsolve = jax.jit(solve, static_argnums=(1,))
ret = jsolve(par0, T)
In [63]:
# get some timing numbers
%timeit -r 10 -n 10 jsolve(par0, T)
3.73 ms ± 111 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
In [64]:
# plot the log error over the whole run
In [65]:
# make sure our policy function lines up with the analytic steady state
fig, ax = plt.subplots()
ax.plot(k_grid, ret['kp'][-1,:]-k_grid);
ax.hlines(0, *ax.get_xlim(), linestyle='--', linewidth=1, color='k')
ax.scatter(k_ss, 0, color='k', zorder=10);

State Distributions

The flip side to solving value functions is constructing state distributions, either as they evolve or in steady state. Let's consider the system described in the lecture slides.

In [66]:
# algorithm params
N = 100
Δ = 0.1
In [67]:
# model params
λ = 1.1
τ = 0.1
κ = 0.02
μ = 0.0
σ = 0.3
g = np.log(λ)*τ
In [68]:
# construct grid
qmin, qmax = 0.0, 5.0
qbins = np.linspace(qmin, qmax, N+1)
qlow, qhigh = qbins[:-1], qbins[1:]
qmids = 0.5*(qlow+qhigh)
qdiff = qhigh - qlow
In [69]:
# entry distribution
ecdf = stats.norm.cdf(np.log(qbins), μ, σ)
epmf = mod.normed(np.diff(ecdf))
plt.plot(qmids, epmf);

Now we need to construct the transition matrices. The non-exit conditional transition matrix is called qstay, while the fully inclusive transition matrix is called qtmat. Notice that it's way easier to write split_range with vmap rather than directly.

In [70]:
# range distributor
def split_range(q1, q2):
    return np.maximum(0, np.minimum(q2, qhigh) - np.maximum(q1, qlow))
v_split_range = jax.vmap(split_range)
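To see what split_range does, it helps to try it on a tiny grid. The sketch below uses a hypothetical four-bin grid on [0, 4] (the names edges, lo, and hi are stand-ins for qbins, qlow, and qhigh) and computes how much of an interval falls into each bin.

```python
import jax
import jax.numpy as np

# four unit-width bins: [0,1], [1,2], [2,3], [3,4]
edges = np.linspace(0.0, 4.0, 5)
lo, hi = edges[:-1], edges[1:]

def split_range(q1, q2):
    # length of the overlap between [q1, q2] and each bin
    return np.maximum(0, np.minimum(q2, hi) - np.maximum(q1, lo))

# a single interval [0.5, 2.5] gives one overlap per bin
one = split_range(0.5, 2.5)

# vmap handles many intervals at once: one row per interval
many = jax.vmap(split_range)(np.array([0.5, 1.0]), np.array([2.5, 4.0]))
print(one, many.shape)
```

Each row sums to the length of its interval as long as the interval lies inside the grid.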
In [71]:
# various rates for a given Δ
gdel = 1 - Δ*g
λdel = λ*gdel
τdel = Δ*τ
κdel = Δ*κ
In [72]:
# construct the transition matrices!
qdrop = v_split_range(gdel*qlow, gdel*qhigh)/(gdel*qdiff[:, None])
qjump = v_split_range(λdel*qlow, λdel*qhigh)/(λdel*qdiff[:, None])
qstay = τdel*qjump + (1-τdel)*qdrop
qtmat = κdel*epmf[None, :] + (1-κdel)*qstay
qtmat = mod.normed(qtmat, axis=1)

Now let's just use the brute force method of manually iterating on transition matrices. Here we do so for both the no-exit and yes-exit versions. Notice that there is no steady state in the no-exit version.

In [73]:
def iterate_mat(dist, tmat, T=4000):
    hist = [dist]
    for i in range(T):
        dist = dist @ tmat
        hist.append(dist) # record each step so the history has T+1 rows
    return np.vstack(hist)
In [74]:
hist0 = iterate_mat(epmf, qstay)
hist1 = iterate_mat(epmf, qtmat)
dist_sim = hist1[-1, :]
fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 4))
ax0.plot(qmids, hist0[::1000, :].T);
ax1.plot(qmids, hist1[::1000, :].T);

Now let's use an eigen-solver to find the eigenvector associated with the unit eigenvalue.

In [75]:
eigval, eigvec = mod.eig(qtmat.T)
eigidx = np.abs(eigval-1).argmin()
dist_eig = mod.normed(eigvec[:, eigidx].real)
plt.plot(qmids, dist_eig, qmids, dist_sim, alpha=0.5);
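The implementation of mod.eig isn't shown here, so as a rough illustration of the same idea, here is a sketch on a made-up three-state chain using regular numpy (imported as onp to avoid clashing with the jax alias). It compares brute-force iteration against the eigenvector for the eigenvalue closest to one.

```python
import numpy as onp

# a hypothetical 3-state transition matrix (rows sum to one)
P = onp.array([[0.9, 0.1, 0.0],
               [0.2, 0.7, 0.1],
               [0.1, 0.3, 0.6]])

# brute force: start from uniform and iterate the distribution forward
dist_iter = onp.ones(3)/3
for _ in range(1000):
    dist_iter = dist_iter @ P

# eigen approach: right eigenvectors of P.T are left eigenvectors of P
eigval, eigvec = onp.linalg.eig(P.T)
idx = onp.abs(eigval - 1).argmin()
dist_eig = eigvec[:, idx].real
dist_eig = dist_eig/dist_eig.sum()  # normalize to sum to one

print(dist_iter, dist_eig)
```

Normalizing by the sum also fixes the arbitrary sign of the eigenvector returned by the solver.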

Finally, let's find the null space of $A^{\prime} - I$, which is the solution to $x A = x$.

In [76]:
qtmat1 = qtmat.T - np.eye(N)
dist_nul = mod.normed(mod.null_space(qtmat1).squeeze())
plt.plot(qmids, dist_nul, qmids, dist_sim, alpha=0.5);