Let's start with a basic LLM: GPT-2 small. Architecturally, it has three basic modules:
- Tokenizer and Embedding
- Decoder
- Classifier and Sampling
I will focus on inference and dissect the inputs and outputs, and their shapes, at each stage. I will use basic PyTorch to work out the tensor shapes.
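As a quick aside, real token ids would come from GPT-2's BPE tokenizer; here is a minimal sketch, assuming OpenAI's tiktoken library is installed:

import tiktoken
enc = tiktoken.get_encoding("gpt2")
ids = enc.encode("Hello, world!")  # a list of ints, each in [0, 50257)

For shape tracing, random token ids work just as well, so the walkthrough below uses torch.randint instead.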
import torch
import torch.nn as nn
# GPT-2 small hyperparameters
B = 1              # batch size
T = 1024           # context length (GPT-2's maximum)
E = 768            # embedding / model dimension
H = 12             # number of attention heads
vocab_size = 50257
# token embeddings
token_ids = torch.randint(low=0, high=vocab_size, size=(B, T))
emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=E)
token_emb = emb(token_ids)
assert token_emb.shape == (B, T, E)
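# sanity check (not part of the model): an embedding lookup is just row
# indexing into the (vocab_size, E) weight matrix
assert torch.equal(token_emb, emb.weight[token_ids])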
# position embeddings - GPT-2 learns these (one vector per position, up to T)
pos = torch.arange(T)
p_emb = nn.Embedding(num_embeddings=T, embedding_dim=E)
pos_emb = p_emb(pos)
pos_emb = pos_emb.unsqueeze(0)  # add a batch dim so it broadcasts over B
assert pos_emb.shape == (1, T, E)
# input embedding
x_emb = token_emb + pos_emb
assert x_emb.shape == (B, T, E)
# layernorm - GPT-2 is pre-LN, so normalization comes before attention
layer_norm = nn.LayerNorm(normalized_shape=E)
x_ln_emb = layer_norm(x_emb)
assert x_ln_emb.shape == (B, T, E)
# project x_ln_emb to q, k, v
projection = nn.Linear(in_features=E, out_features=3*E)
qkv = projection(x_ln_emb)
assert qkv.shape == (B, T, 3*E)
# split qkv into q, k and v
q, k, v = qkv.chunk(3, dim=-1)
assert q.shape == (B, T, E)
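# sanity check: chunk slices the last dimension, so q is simply the first E
# columns of the fused qkv projection
assert torch.equal(q, qkv[..., :E])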
# split E across H heads, each of dimension D_h
D_h = E // H
assert D_h == 64
# reshape to (B, T, H, D_h) and move the head axis before time, so attention
# becomes a batched matmul over B*H independent heads
q = q.view(B, T, H, D_h).transpose(1, 2)
k = k.view(B, T, H, D_h).transpose(1, 2)
v = v.view(B, T, H, D_h).transpose(1, 2)
assert q.shape == (B, H, T, D_h)
# attention - a single vectorized operation across all heads
attn_scores = q @ k.transpose(-2, -1) / (D_h ** 0.5)
assert attn_scores.shape == (B, H, T, T)
# causal mask - GPT-2 is autoregressive, so each position may attend only to
# itself and earlier positions
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))
# softmax over the key dimension turns scores into attention weights
attn_weights = attn_scores.softmax(dim=-1)
assert attn_weights.shape == (B, H, T, T)
attn_out = attn_weights @ v
assert attn_out.shape == (B, H, T, D_h)
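# equivalent einsum view of the same contraction (q: query positions,
# k: key positions, d: per-head dimension)
assert torch.allclose(attn_out, torch.einsum("bhqk,bhkd->bhqd", attn_weights, v))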
# transpose and reshape - this is the concat operation
attn_out = attn_out.transpose(1, 2).reshape(B, T, -1) # (B, T, H*D_h)
assert attn_out.shape == (B, T, E)
# linear projection before exiting MHA block - mixing information across all heads
out_projection = nn.Linear(in_features=E, out_features=E)
x_out = out_projection(attn_out)
assert x_out.shape == (B, T, E)
# add the residual connection (the attention sub-block's input, x_emb)
x_int = x_emb + x_out
assert x_int.shape == (B, T, E)
# mlp sub-block - layernorm, then a two-layer MLP; GPT-2 expands to 4*E in
# the hidden layer with a GELU, then projects back down to E
mlp_ln = nn.LayerNorm(normalized_shape=E)
x_ln_int = mlp_ln(x_int)
assert x_ln_int.shape == (B, T, E)
mlp = nn.Sequential(
    nn.Linear(in_features=E, out_features=4 * E),
    nn.GELU(),
    nn.Linear(in_features=4 * E, out_features=E),
)
x_mlp = mlp(x_ln_int)
assert x_mlp.shape == (B, T, E)
# add residual
x_decode = x_int + x_mlp
assert x_decode.shape == (B, T, E)
# repeat the block above for each of GPT-2 small's 12 decoder layers - the
# shape stays (B, T, E) throughout (see the sketch below)
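# a minimal sketch of the full stack, assuming a pre-LN block like the one
# traced step by step above; DecoderBlock is an illustrative helper built on
# nn.MultiheadAttention, not GPT-2's exact implementation
class DecoderBlock(nn.Module):
    def __init__(self, E, H):
        super().__init__()
        self.ln1 = nn.LayerNorm(E)
        self.attn = nn.MultiheadAttention(E, H, batch_first=True)
        self.ln2 = nn.LayerNorm(E)
        self.mlp = nn.Sequential(nn.Linear(E, 4 * E), nn.GELU(), nn.Linear(4 * E, E))

    def forward(self, x):
        h = self.ln1(x)
        n = x.size(1)
        causal = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        a, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        x = x + a
        return x + self.mlp(self.ln2(x))

blocks = nn.ModuleList(DecoderBlock(E, H) for _ in range(12))  # 12 layers in GPT-2 small
x_stack = x_emb
for block in blocks:
    x_stack = block(x_stack)
assert x_stack.shape == (B, T, E)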
# Final normalization
final_ln = nn.LayerNorm(normalized_shape=E)
x_ln_final = final_ln(x_decode)
assert x_ln_final.shape == (B, T, E)
# final projection - the classifier head maps each position to vocabulary
# logits (GPT-2 ties this weight matrix to the token embedding and uses no bias)
final_proj = nn.Linear(in_features=E, out_features=vocab_size, bias=False)
logits = final_proj(x_ln_final)
assert logits.shape == (B, T, vocab_size)
# softmax over the vocabulary turns logits into next-token probabilities at every position
token_probs = logits.softmax(dim=-1)
assert token_probs.shape == (B, T, vocab_size)
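# sampling - for autoregressive generation only the last position's
# distribution is used; a minimal sketch assuming greedy decoding (real
# samplers apply temperature / top-k / top-p before picking a token)
next_token_probs = token_probs[:, -1, :]  # (B, vocab_size)
next_token = next_token_probs.argmax(dim=-1, keepdim=True)
assert next_token.shape == (B, 1)
# the sampled id is appended to the context and the forward pass repeats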