jax¶
import jax
import jax.numpy as np
import jax.lax as lax
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
# configure matplotlib output
import matplotlib as mpl
mpl.style.use('config/clean.mplstyle') # this loads my personal plotting settings
col = mpl.rcParams['axes.prop_cycle'].by_key()['color']
%matplotlib inline
# if you have an HD display
%config InlineBackend.figure_format = 'retina'
# some warnings can get annoying
import warnings
warnings.filterwarnings('ignore')
import tools.models as mod
jax Basics¶
Almost all of the numpy API is implemented in jax. Arrays have a device that they are stored on. If you have a supported GPU, it will default to this; otherwise it'll go on the CPU (or a TPU if you're really fancy, but I've never used one). Note that there is also the notion of a platform, which includes cpu, gpu, and tpu. Each platform has a certain number of devices (often just one though).
# the output here will depend on your setup (we'll be using this "x" a lot below too)
x = np.array([1.0, 2.0, 3.0])
x.device()
# you can print out available devices
print(jax.devices('cpu'))
print(jax.devices('gpu'))
# you can send data between devices
cpu0, *_ = jax.devices('cpu')
xc = jax.device_put(x, cpu0)
xc.device()
Arrays also have a dtype, which usually corresponds to those found in regular numpy. However, the defaults can be different, and some types are not supported on certain devices. The most common such difference is that float64 is not supported unless you explicitly enable it.
# create some arrays
print(np.arange(3).dtype)
print(np.arange(3, dtype='float32').dtype)
print(np.array([1, 2, 3]).dtype)
print(np.array([1.0, 2, 3]).dtype)
print(np.array([1, 2, 3], dtype='float32').dtype)
# should still be all float32
print(x.dtype, x.astype('float64').dtype)
print(xc.dtype, xc.astype('float64').dtype)
You can enable float64 by running: jax.config.update('jax_enable_x64', True)
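Just as a quick sketch (my own addition; ideally you'd set this once at startup, before creating any arrays):
# enable 64-bit types globally, then check the default float dtype
jax.config.update('jax_enable_x64', True)
print(np.array([1.0, 2.0, 3.0]).dtype)
# switch it back off so the rest of this notebook stays in float32
jax.config.update('jax_enable_x64', False)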
grad¶
Let's start with derivatives of scalar-to-scalar functions. You can make a new function that returns the gradient by calling grad on a function.
# define a simple function (can be lambda too)
def f(x):
return x**2
f(3.0)
# take the gradient (derivative)
df = jax.grad(f)
df(3.0)
# take the second derivative
d2f = jax.grad(df)
d2f(3.0)
# you need to make sure your inputs are floats
try:
df(3)
except Exception as e:
print(e)
Up until now, we've only looked at functions of one parameter. Going to many parameters is fine; we just need to tell grad which variable to take the derivative with respect to. We do this by specifying the index of the desired parameter or parameters in the argnums flag. If you give it multiple indices, it will return a tuple of derivatives. The default is argnums=0.
def g(x, y):
return y*x**2 + x*y**3
dg = jax.grad(g, argnums=1)
dg(2.0, 3.0)
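Here's a small sketch of my own for the multiple-index case, which returns one derivative per index:
# gradient with respect to both x and y at once (returns a tuple)
dg_xy = jax.grad(g, argnums=(0, 1))
dg_xy(2.0, 3.0)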
grad and jacobian¶
Now let's turn to vector-to-scalar functions. Here, calling grad will give you a function that returns a vector with the same size as the input dimension. You don't have to specify the dimensions to jax; it'll figure these out the first time the function actually gets called. Because of this, if jax isn't happy about something, it may not give you an error until you actually try to use the output of grad.
# the second term is just to add some complexity
def fv(x):
return np.sum(x**2) + x[1]*x[2]
fv(x)
# we use the "x" defined above as a test input
dfv = jax.grad(fv)
dfv(x)
# the jacobian of the gradient (i.e. the hessian of fv) is non-diagonal because of the cross term
jfv = jax.jacobian(dfv)
jfv(x)
# you can't use grad on vector-return functions! but maybe we can find another way...
try:
dfe = jax.grad(lambda x: x**2)
dfe(x)
except Exception as e:
print(e)
vmap¶
Vectorization is often necessary for performance in Python, especially in jax. Fortunately, we can do this easily with vmap. This function takes an input function and maps it along a new dimension of your choice. For instance, you can use it to turn a scalar-to-scalar function into a vector-to-vector function. But there are many more possibilities.
# the most basic usage on the "f" defined above
fv2 = jax.vmap(f)
fv2(x)
# now we can do the element-by-element gradient of a vector-return function
dfv2 = jax.vmap(jax.grad(f))
dfv2(x)
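As a taste of those other possibilities, here's a sketch of my own using in_axes to batch a dot product over the rows of a matrix while keeping the second argument fixed:
# map np.dot over axis 0 of the first argument only
A = np.arange(6.0).reshape(3, 2)
w = np.array([1.0, 2.0])
bdot = jax.vmap(np.dot, in_axes=(0, None))
bdot(A, w)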
jit¶
This is an approach to speeding up code by compiling it just when it is needed (rather than beforehand, as in C or C++). It's used in JavaScript, MATLAB, and other places. Calling jit on a function will return another function that does the same thing but faster.
As with grad, it only actually does the compilation when you run the function for the first time (and it will recompile if the input shapes or dtypes change).
# compile the vector gradient function above
jdfv = jax.jit(dfv)
jdfv(x)
# give it a much bigger vector for testing
x2 = np.linspace(1.0, 5.0, 100000)
# first do the un-jit-ed version
%timeit -n 100 dfv(x2)
# now the jitted version (run twice for true comparison)
%timeit -n 100 jdfv(x2)
On my computer this gives a ~300x speed improvement, though the exact gain will vary between CPU and GPU.
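One timing caveat of my own: jax dispatches work asynchronously, so for an honest benchmark (especially on a GPU) it's safer to block on the result.
# block until the result is actually computed, not just dispatched
%timeit -n 100 jdfv(x2).block_until_ready()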
jax trees¶
Trees are very cool! They're basically arbitrarily nested dicts and lists whose leaves can be any type. If you've seen JSON before, it's very similar to that. Why are trees cool? Just as numpy lets you perform operations over an entire array at once, jax lets you do the same over entire trees as well. This lets you avoid a lot of clunky and error-prone packing and unpacking.
# first let's define a very simple tree, which is just a dict of arrays
xd = {
'x': np.linspace(1.0, 5.0, 5),
'y': np.linspace(10.0, 19.0, 10),
}
xd
# here's a function that operates on such a tree
# you could imagine specifying a model like this
# where the input is a dict of parameters or data points
def ft(d):
xsum = np.sum(d['x']**3)
ysum = np.sum(d['y']**2)
return xsum + ysum
# now we can take a grad with respect to a tree and get a tree of the same shape back!
dft = jax.grad(ft)
dft(xd)
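You can also operate on whole trees directly. Here's a minimal sketch of my own using jax.tree_util.tree_map to apply a gradient-style update to every leaf at once:
# update every parameter in the tree in one shot
grads = dft(xd)
jax.tree_util.tree_map(lambda p, g: p - 0.01*g, xd, grads)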
Sometimes it's easier to think in scalar terms and then vectorize things. Suppose that our model really just operated on some x and y and we wanted to run the model for many such pairs. Then we can first define a slightly different function in purely scalar terms. (We'll add in a separate parameter $\alpha$ later on to make things more interesting.)
def fts(d):
x, y = d['x'], d['y']
return x**3 + y**2
Then we use vmap to broadcast over the tree d. Below is what we'd write if we're interested in doing 10 distinct parameter sets.
xd1 = {'x': np.linspace(0, 1, 10), 'y': np.linspace(2, 3, 10)}
ftv1 = jax.vmap(fts)
ftv1(xd1)
But sometimes we just want to vary one parameter at a time. As with grad, we need to tell it which dimensions to operate on, this time using in_axes. But it's a bit more complicated.
xd2 = {'x': np.linspace(0, 1, 10), 'y': 2.0}
ftv2 = jax.vmap(fts, in_axes=({'x': 0, 'y': None},))
ftv2(xd2)
Notice that in_axes is a tuple of the same length as the number of function arguments (in this case 1). We then have to mirror the structure of the tree itself to specify that we want to map x over axis 0 and not map y at all.
So... loops work a little differently in jax. It's actually okay to write a regular for loop, but not if you want to ultimately jit the thing (at least not when the loop bounds depend on the function's arguments, as below). First, as we get into more advanced usage, there's a great overview of common "gotchas" in jax here: JAX - The Sharp Bits.
# very slow way to compute x**n
def fl(x, n):
out = 1
for i in range(n):
out = x*out
return out
try:
jfl = jax.jit(fl)
jfl(2.0, 3)
except Exception as e:
print(e)
Not a very helpful error message! But basically we tried to use a function argument as the loop bounds and then jit it. One way to take care of this is to specify that n is mostly fixed using static_argnums. This will work but will recompile for each distinct value of n that is seen.
jfl = jax.jit(fl, static_argnums=1)
jfl(2.0, 3)
The other way is to use one of the looping tools from lax, such as scan or fori_loop. First, let's try it using fori_loop.
def fl_fori(x, n):
fl_x = lambda i, v: x*v
return lax.fori_loop(0, n, fl_x, 1.0)
jfl_fori = jax.jit(fl_fori)
jfl_fori(2.0, 3)
The scan function operates similarly, in that it carries a running state, but instead of just getting i at each iteration, you get a slice of your input data. You also get the whole history of outputs at the end, rather than just the last value. The function you pass should accept a running value (v below) and a slice of the input (i below). Either of these can be arbitrary tree types.
def fl_scan(x, n):
fl_x = lambda v, i: (x*v, x*v)
tvec = np.arange(n)
return lax.scan(fl_x, 1.0, tvec)
jfl_scan = jax.jit(fl_scan, static_argnums=1)
jfl_scan(2.0, 3)
The return at the end is the final value of the running state, as well as the history of values. Ok, so both of these were designed for fixed-length iterations. For variable-length iterations, we can use lax.while_loop. This accepts a "condition" function that tells it whether to continue or not and a function that updates the running state.
def fl_while(x, n):
f_cond = lambda iv: iv[0] < n
fl_x = lambda iv: (iv[0] + 1, x*iv[1])
return lax.while_loop(f_cond, fl_x, (0, 1.0))
jfl_while = jax.jit(fl_while)
jfl_while(2.0, 3)
Here we just reimplemented a for loop by including i in the running state and conditioning on it, but you could imagine the condition being something like a convergence check too.
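For instance, here's a sketch of my own with an actual convergence check: Newton iteration for a square root that stops once the update gets small.
# newton iteration for sqrt(a), stopping on a convergence check rather than a fixed count
def newton_sqrt(a, tol=1e-6):
    f_cond = lambda v: np.abs(v[1] - v[0]) > tol
    f_step = lambda v: (v[1], 0.5*(v[1] + a/v[1]))
    return lax.while_loop(f_cond, f_step, (a + 1.0, a))[1]
jax.jit(newton_sqrt)(2.0)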
Random numbers work a bit differently in jax due to the different needs of Deep Learning folks. You need to generate a "key" first using a "seed" as input (below, 42, a very common seed). You then pass this key to various random number routines that are similar to those in numpy.random.
key = jax.random.PRNGKey(42)
r1 = jax.random.normal(key)
r2 = jax.random.normal(key)
print(key)
print(r1, r2)
Notice that the same key leads to the same random numbers! So there must be more to it. To advance your key to something new, you can use the split function, which is also deterministic. This will return two new keys.
key, subkey = jax.random.split(key)
print(key, subkey)
So one pattern is to keep a running key and generate a new subkey whenever you need another random number.
jax.random.uniform(subkey, (5,))
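Putting that pattern together in a little sketch of my own: split off a fresh subkey for every draw and keep the other half as the running key.
# keep a running key, splitting off a new subkey for each draw
for i in range(3):
    key, subkey = jax.random.split(key)
    print(jax.random.normal(subkey))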
Optimization typically involves looping in some way, so we need to be a bit careful when using jax. But you can achieve all of the same results with much more flexibility and speed at runtime. Here we're using simple gradient ascent (we're maximizing) for the optimization.
# hardcoded step size
Δ = 0.1
# simple function
f = lambda x: -(x-1)**2 + 1
df = jax.grad(f)
vdf = jax.vmap(df)
# visualize value and gradient
xgrid = np.linspace(0, 2, 100)
fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(8, 3))
ax0.plot(xgrid, f(xgrid))
ax1.plot(xgrid, vdf(xgrid));
# we define an update step function and one to apply it with scan
def update_scalar(x, t):
xp = x + Δ*df(x)
return xp, xp
def maximize_func(x0, K=50):
tvec = np.arange(K) # this just has to be the right shape (K)
x1, xvec = lax.scan(update_scalar, x0, tvec)
return x1, xvec
# jit'ed scalar and vectorized versions
jmax_func = jax.jit(maximize_func)
vmax_func = jax.jit(jax.vmap(maximize_func))
# run for just one starting value
x1, xvec = jmax_func(0.5)
plt.plot(xvec);
# run for many starting values simultaneously
key = jax.random.PRNGKey(42)
xrand = jax.random.normal(key, (100,))
x1, xvec = vmax_func(xrand)
plt.plot(xvec.T);
Now let's add in an additional parameter $\alpha$ which influences the optimization.
# parameterized function (max=α)
f1 = lambda x, α: -(x-α)**2
df1 = jax.grad(f1)
vdf1 = jax.vmap(df1)
# we just need to apply α to function before optimizing
def update_scalar(α, x, t):
xp = x + Δ*df1(x, α)
return xp, xp
def maximize_func(α, x0, K=50):
tvec = np.arange(K)
upd = lambda x, t: update_scalar(α, x, t)
x1, xvec = lax.scan(upd, x0, tvec)
return x1, xvec
jmax_func = jax.jit(maximize_func)
vmax_func = jax.jit(jax.vmap(maximize_func, in_axes=(0, None)))
# run the scalar optimization for α = 2
x1, xvec = jmax_func(2.0, 0.5)
plt.plot(xvec);
# run the optimization for a range of α's
αgrid = np.linspace(-1, 1, 20)
x1, xvec = vmax_func(αgrid, 0.0)
plt.plot(xvec.T);
We can even take the derivative of the optimal solution with respect to $\alpha$, which should always be $1$ in our case since $x^{\ast}(\alpha) = \alpha$.
dmax_func_out = jax.grad(lambda α, x0: maximize_func(α, x0)[0], argnums=0)
vmax_func_out = jax.vmap(dmax_func_out, in_axes=(0, None))
vmax_func_out(αgrid, 0.5)
Below is an implementation of simple value function iteration using jax for the neoclassical growth model. This makes use of trees for parameter values and the scan function for iteration. We start by defining parameters and the utility and production functions, as well as constructing the capital grid and finding the analytic steady state.
from functools import partial
# fixed params
T = 300 # time periods
N = 100 # grid size
f_lo, f_hi = 0.5, 2.0 # proportional grid range
# simple parameters
par0 = {
'β': 0.95,
'δ': 0.1,
'α': 0.35,
'z': 1.0,
}
# find steady state
def get_kss(par):
β, δ, z, α = par['β'], par['δ'], par['z'], par['α']
rhs = (1-β)/β + δ
k = (α*z/rhs)**(1/(1-α))
return k
k_ss = get_kss(par0)
print(k_ss)
# construct capital grid
k_min = f_lo*k_ss
k_max = f_hi*k_ss
k_grid = np.linspace(k_min, k_max, N)
# define utility and production functions
def util(c, ϵ=1e-6):
c1 = np.maximum(ϵ, c)
return np.log(c1)
def prod(k, z, α):
return z*k**α
Now we define the actual updating rule and the outer solving function.
# the value function update function
def value(par, grid, st, tv):
β = par['β']
cp = grid['cp']
vn = st['vn']
# calculate optimal investment
v1 = util(cp) + β*vn[None,:]
ip = np.argmax(v1, axis=1)
vp = np.max(v1, axis=1)
# apply at values
kp = k_grid[ip]
# compute update errors
err = np.max(np.abs(vp-vn))
# return state and output
stp = {
'vn': vp,
}
out = {
'v': vp,
'kp': kp,
'err': err,
}
return stp, out
# the actual solver (T is the number of iterations)
def solve(par, T):
α = par['α']
δ = par['δ']
z = par['z']
# precompute grid values
y_grid = prod(k_grid, z, α)
yd_grid = y_grid + (1-δ)*k_grid
cp_grid = yd_grid[:,None] - k_grid[None,:]
# partially apply grid
grid = {
'cp': cp_grid,
}
value1 = partial(value, par, grid)
# scan over time (backwards)
st0 = {
'vn': util(y_grid),
}
tv = {
't': np.arange(T)[::-1],
}
last, path = lax.scan(value1, st0, tv)
return path
Now let's time the model and see how the results look!
# jit and run once to ensure compilation
jsolve = jax.jit(solve, static_argnums=(1,))
ret = jsolve(par0, T)
# get some timing numbers
%timeit -r 10 -n 10 jsolve(par0, T)
# plot the log error over the whole run
plt.plot(ret['err'])
plt.yscale('log');
# make sure our policy function lines up with the analytic steady state
fig, ax = plt.subplots()
ax.plot(k_grid, ret['kp'][-1,:]-k_grid);
ax.hlines(0, *ax.get_xlim(), linestyle='--', linewidth=1, color='k')
ax.scatter(k_ss, 0, color='k', zorder=10);
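As an extra check of my own, we can read the implied steady state off the computed policy by finding the grid point where the policy is closest to staying put, and compare it to the analytic k_ss.
# grid point where the policy most nearly reproduces the current capital stock
kp_last = ret['kp'][-1, :]
i_ss = np.argmin(np.abs(kp_last - k_grid))
print(k_grid[i_ss], k_ss)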
The flip side to solving value functions is constructing state distributions, either as they evolve or in steady state. Let's consider the system described in the lecture slides.
# algorithm params
N = 100
Δ = 0.1
# model params
λ = 1.1
τ = 0.1
κ = 0.02
μ = 0.0
σ = 0.3
g = np.log(λ)*τ
# construct grid
qmin, qmax = 0.0, 5.0
qbins = np.linspace(qmin, qmax, N+1)
qlow, qhigh = qbins[:-1], qbins[1:]
qmids = 0.5*(qlow+qhigh)
qdiff = qhigh - qlow
# entry distribution
ecdf = stats.norm.cdf(np.log(qbins), μ, σ)
epmf = mod.normed(np.diff(ecdf))
plt.plot(qmids, epmf);
Now we need to construct the transition matrices. The non-exit conditional transition matrix is called qstay, while the fully inclusive transition matrix is called qtmat. Notice that it's way easier to write split_range as a scalar function and vmap it than to vectorize it directly.
# range distributor
def split_range(q1, q2):
return np.maximum(0, np.minimum(q2, qhigh) - np.maximum(q1, qlow))
v_split_range = jax.vmap(split_range)
# various rates for a given Δ
gdel = 1 - Δ*g
λdel = λ*gdel
τdel = Δ*τ
κdel = Δ*κ
# construct the transition matrices!
qdrop = v_split_range(gdel*qlow, gdel*qhigh)/(gdel*qdiff[:, None])
qjump = v_split_range(λdel*qlow, λdel*qhigh)/(λdel*qdiff[:, None])
qstay = τdel*qjump + (1-τdel)*qdrop
qtmat = κdel*epmf[None, :] + (1-κdel)*qstay
qtmat = mod.normed(qtmat, axis=1)
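A quick sanity check of my own (assuming mod.normed rescales rows to sum to one): each row of qtmat should be a proper probability distribution.
# every row of the full transition matrix should sum to one
print(np.allclose(qtmat.sum(axis=1), 1.0))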
Now let's just use the brute force method of manually iterating on transition matrices. Here we do so for both the no-exit and yes-exit versions. Notice that there is no steady state in the no-exit version.
def iterate_mat(dist, tmat, T=4000):
hist = [dist]
for i in range(T):
dist = dist @ tmat
hist.append(dist)
return np.vstack(hist)
hist0 = iterate_mat(epmf, qstay)
hist1 = iterate_mat(epmf, qtmat)
dist_sim = hist1[-1, :]
fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 4))
ax0.plot(qmids, hist0[::1000, :].T);
ax1.plot(qmids, hist1[::1000, :].T);
Now let's use an eigen-solver to find the eigenvector associated with the unit eigenvalue.
eigval, eigvec = mod.eig(qtmat.T)
eigidx = np.abs(eigval-1).argmin()
dist_eig = mod.normed(eigvec[:, eigidx].real)
plt.plot(qmids, dist_eig, qmids, dist_sim, alpha=0.5);
Finally, let's find the null space of $A^{\prime} - I$, which is the solution to $x A = x$.
qtmat1 = qtmat.T - np.eye(N)
dist_nul = mod.normed(mod.null_space(qtmat1).squeeze())
plt.plot(qmids, dist_nul, qmids, dist_sim, alpha=0.5);
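And to confirm the three approaches agree beyond eyeballing the plots, here's a small check of my own:
# maximum pointwise gaps between the simulated, eigenvector, and null-space distributions
print(np.max(np.abs(dist_eig - dist_sim)))
print(np.max(np.abs(dist_nul - dist_sim)))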