import math
import torch
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
from torch import nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchvision
import torchvision.transforms as T
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
toPIL = T.ToPILImage()
# configure matplotlib output
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.style.use('config/clean.mplstyle') # this loads my personal plotting settings
col = mpl.rcParams['axes.prop_cycle'].by_key()['color']
%matplotlib inline
# if you have an HD display
%config InlineBackend.figure_format = 'retina'
# some warnings can get annoying
import warnings
warnings.filterwarnings('ignore')
from tools.text import process_text, generate_sequences, total_params, pad_list, sample_logits
# here you can set which device to use
device = 'cuda' # 'cpu'
First let's look at a (comparatively) simple Markov-style model that does next-token prediction. I'm going to train it on only one book, that classic outlier Moby Dick. We start by loading it in from a file I've stored locally.
with open('../data/moby_dick.txt') as fid:
moby = fid.read()
moby1 = process_text(moby)
There are some useful tools in torchtext that will help us tokenize words for training and evaluation. Here tokenizing just refers to splitting documents into distinct words, or more generally tokens.
tokenizer = get_tokenizer('basic_english')
tokens = tokenizer(moby1)
print(len(tokens))
print(' | '.join(tokens[92:180]))
Now we take these tokens and convert them into integer ids that will get fed into the first state embedding of our model.
vocab = build_vocab_from_iterator([tokens], specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])
indices = vocab(tokens)
print(len(vocab))
print(indices[:91])
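As a quick sanity check (purely for illustration), we can map those integer ids back to token strings using the vocabulary's index-to-string table and confirm the round trip works.
# map indices back to tokens and compare with the originals
itos = vocab.get_itos()
print([itos[i] for i in indices[:10]])
print(tokens[:10])  # should be identical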
This is a bit heavy compared to what we've done before, but nothing conceptually different. We have the embedding layer, then a dense linear layer mapping from window_len*embed_dim down to embed_dim, essentially compressing over the context window. Then we have a final linear layer mapping from embed_dim to vocab_size. Typically embed_dim will be much smaller than vocab_size, so this is a fan-in then fan-out type network.
class MarkovLanguage(nn.Module):
def __init__(self, window_len, vocab_size, embed_dim, dropout=0.0):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.flatten = nn.Flatten()
self.dropout0 = nn.Dropout(dropout)
self.linear0 = nn.Linear(window_len*embed_dim, embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.linear1 = nn.Linear(embed_dim, vocab_size)
self.init_weights()
def init_weights(self):
initrange = 0.5
self.embedding.weight.data.uniform_(-initrange, initrange)
self.linear0.weight.data.uniform_(-initrange, initrange)
self.linear0.bias.data.zero_()
self.linear1.weight.data.uniform_(-initrange, initrange)
self.linear1.bias.data.zero_()
    def forward(self, toks):
        x = self.embedding(toks)      # (B, W) => (B, W, E)
        x = self.dropout0(x)
        x = self.flatten(x)           # (B, W, E) => (B, W*E)
        x = F.relu(self.linear0(x))   # (B, W*E) => (B, E)
        x = self.dropout1(x)
        x = self.linear1(x)           # (B, E) => (B, V)
        return x
You'll also see that we're using dropout between each layer, and on the intermediate layer we apply a ReLU non-linear transform. The ReLU clips negative values to zero and gives the network the non-linearity it needs to express logic-type operations. We could in principle apply a non-linearity to the final layer too, but our loss takes raw logits anyway, so there's no need. Below, after a quick dropout sanity check, we get to the training code, which is pretty similar to before.
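Here's that quick check, a standalone sketch that isn't part of the model above: in training mode dropout zeroes entries at random and rescales the survivors, while in eval mode it's a no-op.
# dropout behaves differently in train vs eval mode
drop = nn.Dropout(0.5)
x = torch.ones(8)
drop.train()
print(drop(x))  # roughly half the entries zeroed, the rest scaled by 1/(1-p) = 2
drop.eval()
print(drop(x))  # identity: all ones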
def language_model(model, train_data, valid_data, lr=0.01, shuffle=False, epochs=10, batch_size=1000):
    # make data loader
    train_load = torch.utils.data.DataLoader(train_data, shuffle=shuffle, batch_size=batch_size)
    valid_x, valid_y = valid_data.tensors
    # make optimizer
    optim = torch.optim.RMSprop(model.parameters(), lr=lr)
    # define loss function (cross-entropy over the vocabulary)
    ce_loss = nn.CrossEntropyLoss()
    # track history
    history = []
    # go for many iterations
    for i in tqdm(range(epochs)):
        n_batch = 0
        train_stats = 0.0
        # iterate over batches
        model.train()
        for batch_x, batch_y in train_load:
            # compute loss
            logits = model(batch_x)
            loss = ce_loss(logits, batch_y)
            # update params
            optim.zero_grad()
            loss.backward()
            optim.step()
            # training stats
            n_batch += 1
            train_stats += loss.detach().item()
        # average training loss over batches
        train_stats /= n_batch
        # validation stats (no gradients needed)
        model.eval()
        with torch.no_grad():
            valid_logits = model(valid_x)
            valid_loss = ce_loss(valid_logits, valid_y)
        valid_stats = valid_loss.item()
        # record stats for this epoch
        stats = torch.tensor([train_stats, valid_stats])
        history.append(stats)
    # turn history into dataframe
    names = ['train_loss', 'valid_loss']
    return pd.DataFrame(torch.stack(history).numpy(), columns=names)
With that all set up, we can now decide on the particular parameters of our network. Note that given the complexity of the problem, this is pretty small!
V_moby = len(vocab) # vocabulary size
W_moby = 64 # sequence window size
E_moby = 32 # embedding size
M_moby = 2000 # validation set size
The last preparation step is to convert our large document into distinct sequences. For this, I've defined a helper function generate_sequences that splits things into overlapping sequences stepping forward one token at a time. Essentially we're saying: given a context window of W_moby tokens, predict the next token in the document.
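The real generate_sequences lives in tools.text and isn't reproduced here, but assuming it takes a (1, N) tensor of indices and returns (windows, next-token targets), a minimal sketch of that behavior might look like this.
# rough sketch of a sequence generator (an assumption about the tools.text helper, not its actual code)
def generate_sequences_sketch(indices, window_len):
    idx = indices.reshape(-1)             # flatten to (N,)
    n_seq = len(idx) - window_len         # number of usable windows
    seqs = torch.stack([idx[i:i+window_len] for i in range(n_seq)])  # (n_seq, window_len)
    targ = idx[window_len:]               # token following each window, (n_seq,)
    return seqs, targ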
# make full dataset
seqs, targ = generate_sequences(torch.tensor(indices).reshape(1, -1), W_moby)
moby_data = torch.utils.data.TensorDataset(seqs.to(device), targ.to(device))
train_split, valid_split = torch.utils.data.random_split(moby_data, [len(moby_data)-M_moby, M_moby])
moby_train = torch.utils.data.TensorDataset(*moby_data[train_split.indices])
moby_valid = torch.utils.data.TensorDataset(*moby_data[valid_split.indices])
Let's train this baby! You're gonna want a GPU for this. First let's do it without dropout to get a sense of the baseline performance.
# create model (it's a big one)
lmod = MarkovLanguage(W_moby, V_moby, E_moby).to(device)
print(sum([p.numel() for p in lmod.parameters()]))
hist = language_model(lmod, moby_train, moby_valid, epochs=10)
hist.plot();
As before, you can see that the validation loss just gets worse and worse, even as the training loss gets better and better. Now let's try it with dropout turned on.
# create model (with dropout)
lmod_drop = MarkovLanguage(W_moby, V_moby, E_moby, dropout=0.5).to(device)
hist_drop = language_model(lmod_drop, moby_train, moby_valid, epochs=10)
hist_drop.plot();
That's better! At least there's a range where the validation loss decreases, and after that it only increases slowly. Now let's draw some samples from the model. This is pretty inefficient, since we need to feed in the context, generate one token, append it to the context, and then repeat. It's inefficient because in the next step our context window is almost identical. But since the network is fully interconnected, this is really the only way.
# predict one token ahead
lmtoks = vocab.get_itos()
def predict_next(model, prompt, temp=1.0):
toks = tokenizer(prompt) # string into list of tokens
index = pad_list(vocab(toks), W_moby) # list of tokens to list of indices
vecs = torch.tensor(index, device=device).unsqueeze(0) # turn it into a tensor
pred = model(vecs).squeeze(0) # this is where we actually call the model
word = sample_logits(pred, temp=temp) # sample the next word index
return lmtoks[word] # map back from index to token
# predict many tokens ahead
def predict_many(model, prompt, n, temp=1.0):
prompt = f'{prompt}\n\n██\n\n'
for i in range(n):
word = predict_next(model, prompt, temp=temp)
prompt = f'{prompt} {word}'
return prompt
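One piece not shown here is sample_logits, which also comes from tools.text. A minimal temperature-based sampler along those lines might look like the following (an assumption about its behavior, not the actual implementation): scale the logits by the temperature, softmax into probabilities, and draw one index.
# sketch of a temperature sampler: scale logits, softmax, draw one token index
def sample_logits_sketch(logits, temp=1.0):
    probs = F.softmax(logits / temp, dim=-1)
    return torch.multinomial(probs, 1).item()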
output = predict_many(lmod, moby1[35026:36001], 250, temp=1.0)
print(output.replace('\\n', '\n\n'))
Okay, this part is quite important. Most of the time, if you're running LLMs you're going to be using big pre-trained versions from one of the big labs. Usually the architectures are close variants of the core transformer, which I'll discuss in more detail below, and the weights are distributed in big binary files. The main way to access and run these is through Hugging Face with their transformers library. Here we're going to request the Llama-2 model from Meta, specifically the 7B parameter variant. You can check out the HF page here https://huggingface.co/meta-llama/Llama-2-7b-hf.
hf_name = 'meta-llama/Llama-2-7b-hf'
hf_token = AutoTokenizer.from_pretrained(hf_name, token=True)
hf_model = AutoModelForCausalLM.from_pretrained(hf_name, device_map=device, torch_dtype=torch.float16)
Ok, that was a lot! But it all happened automatically, and we got our model. Since it's about 7B parameters and it's stored in 16-bit, it should take up about 14GB (since 16 bits is 2 bytes). Now let's run it. Generation works the same way as before: add one token at a time and resample in a loop, which generate handles for us here. Because the runtime is somewhat slow, we'll stream results as they're generated.
streamer = TextStreamer(hf_token)
prompt = moby1[35026:36001]
toks = hf_token(prompt, return_tensors='pt').to(device)
rets = hf_model.generate(toks.input_ids, streamer=streamer, max_new_tokens=500, repetition_penalty=1.1)
The results are usually pretty Moby Dick like! Lotta whale talk and Nantucket goings on. The above uses the high-level Hugging Face interface, but in some cases we might want to write our own generation code. Either way, this lets us see precisely how generation is done. It's not beyond comprehension! We're just calling the model to get logits for the next token, sampling from that distribution, and repeating.
if hf_token.pad_token is None:
hf_token.pad_token = hf_token.eos_token
def encode_hf(tokenizer, text):
data = tokenizer(
text, return_tensors='pt', padding=True, truncation=True, max_length=2048
)
return data['input_ids'].to(device)
def detok_hf(tok):
return tok.replace('▁', ' ').replace('<0x0A>', ' ')
def generate_hf(model, tokenizer, prompt, maxlen=2048, context=2048, temp=1.0):
# encode input prompt
input_ids = encode_hf(tokenizer, prompt)
print(f'{prompt}\n\n██\n\n')
# loop until limit and eos token
for i in range(maxlen):
# generate next index (no grad for memory usage)
with torch.no_grad():
output = model(input_ids)
logits = output.logits[0,-1,:]
index = sample_logits(logits, temp=temp)
# break if we hit end token
if index == tokenizer.eos_token_id:
break
# decode and return (llama not doing this right)
token = tokenizer.convert_ids_to_tokens(index)
print(detok_hf(token), end='', flush=True)
# shift and add to input_ids
trim = 1 if input_ids.size(1) == context else 0
newidx = torch.tensor([[index]], device=device)
        input_ids = torch.cat([input_ids[:, trim:], newidx], 1)
generate_hf(hf_model, hf_token, prompt, maxlen=256)
Now let's get an idea of what the building blocks of a transformer are. Let's look at the total number of parameters in the simple model.
total_params(lmod)
The main thing that transformer models add over the simple model is an attention mechanism. One issue with the simplified model is that it's kind of too wide. We do a linear map from window_len*embed_dim down to embed_dim and then do next word prediction. It would be better to add more layers, but if we do so without making each layer thinner, we risk being way overparameterized. As it is, we only have about 260k tokens in our training set. We want a way for there to be interaction terms between different word positions, but doing the full dense interaction would require a window_len^2 * embed_dim^2 ≈ 4.2e6 element matrix.
Instead we'll use what's called an attention mechanism, which is perhaps the most important element in any transformer model. This allows for some (but not arbitrary) interactions between different window positions, and it only requires 3 * embed_dim^2 parameters. With this savings, we can stack up a few more layers and still be way better off in terms of total parameter count.
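As a quick back-of-the-envelope check using the sizes defined above (just arithmetic, not part of the training pipeline), compare the dense interaction matrix with a single attention head.
# rough parameter counts: full dense interaction vs a single attention head (ignoring biases)
dense_interaction = (W_moby * E_moby)**2   # (64*32)^2, about 4.2e6
attention_params = 3 * E_moby**2           # Q, K, V projections: 3*32^2 = 3072
print(dense_interaction, attention_params)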
class SelfAttention(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
def forward(self, x):
q, k, v = self.q(x), self.k(x), self.v(x) # (B,T,E) => (B,T,E)
        w = q @ k.transpose(-2, -1) # (B,T,E) x (B,E,T) => (B,T,T)
a = F.softmax(w / math.sqrt(self.dim), dim=-1) # (B,T,T) => (B,T,T)
y = a @ v # (B,T,T) x (B,T,E) => (B,T,E)
return y
Essentially what's going on here is we're computing query $Q$, key $K$, and value $V$ as functions of our data. The matrix $A = \text{softmax}(Q K^T / \sqrt{d})$ has shape (T,T) and tells us where each position should be looking. We then compute $A V$ to generate a final weighted output, which in this case is the same shape as our input. Let's create the layer and see what it looks like.
att = SelfAttention(E_moby).to(device)
print(total_params(att))
att
Now we can just call it directly on the embedding output from our previous model.
test_emb = lmod.embedding(seqs[0,:].to(device))
test_att = att(test_emb)
test_att.shape
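The SelfAttention above lets every position attend to every other position. Decoder-style models like Llama instead apply a causal mask so each position can only attend to itself and earlier positions. Here's a minimal sketch of that variant (not the actual Llama implementation).
# self-attention with a causal (lower triangular) mask, sketch only
class CausalSelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)           # (..., T, E)
        w = q @ k.transpose(-2, -1) / math.sqrt(self.dim)   # (..., T, T)
        tlen = x.size(-2)
        mask = torch.tril(torch.ones(tlen, tlen, dtype=torch.bool, device=x.device))
        w = w.masked_fill(~mask, float('-inf'))             # block attention to future positions
        a = F.softmax(w, dim=-1)                            # rows sum to 1 over allowed positions
        return a @ v                                        # (..., T, E)
catt = CausalSelfAttention(E_moby).to(device)
catt(test_emb).shape  # same shape as the unmasked version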
Now let's try to inspect what attentions look like in practice. For this we'll use the big pretrained model. We need to tell this copy to return the attention information that we wish to inspect.
hf1_model = AutoModelForCausalLM.from_pretrained(
hf_name, device_map=device, torch_dtype=torch.float16, output_hidden_states=True, output_attentions=True
)
nlayers = len(hf1_model.model.layers)
The output of an attention layer will be $L \times L$ where $L$ is the sequence length, and there are 32 layers. Additionally, there are actually 32 attention heads per layer, whereas the simple attention above is just a single head. Below we run the model once on the previous prompt and retrieve the attentions for each layer. Finally, we show in image form the average interaction between each word position. Note that this matrix is lower triangular, since the attentions only look backwards towards previous words.
toks = hf_token(prompt, return_tensors='pt').to(device)
outp = hf1_model(toks.input_ids)
fattn = torch.stack([outp.attentions[i][0,:,:,:] for i in range(nlayers)])
print(fattn.shape)
toPIL(255*fattn.mean((0, 1)))
This isn't really language model related, but I'll talk a bit about convolution here, which is important when dealing with images.
img = torchvision.io.read_image('../data/racoon.png')
print(img.shape)
toPIL(img)
Now let's implement the simplest type of convolution: blurring. We create a convolutional layer with kernel size $C=16$, specifying that we are dealing with three input channels and three output channels (RGB colors). Additionally, passing groups=3 makes it so the blur is done separately within each color channel, and bias=False ensures we don't add any constants to the result.
C = 16
m = nn.Conv2d(3, 3, C, groups=3, bias=False)
m.weight.data.fill_(1/C**2);
bimg = m(img.float()).to(dtype=torch.uint8)
toPIL(bimg)