
nanoGPT, explained from scratch

Andrej Karpathy's nanoGPT is the cleanest ~300-line implementation of a GPT-style language model. Read it once and the rest of the transformer ecosystem stops feeling like magic.

What is nanoGPT?

nanoGPT is a minimal, hackable PyTorch re-implementation of GPT-2 by Andrej Karpathy. The whole model lives in one file (model.py), training in another (train.py), and sampling in a third (sample.py). It's small enough to read in an afternoon and powerful enough to reproduce GPT-2 (124M) on a single 8×A100 node.

Architecture: decoder-only transformer
Lines of code: ~300 for the model
Training data: OpenWebText / Shakespeare
Hardware: laptop → 8×A100

The mental model

A GPT is a function that takes a sequence of tokens and predicts the next token. That's it. Training teaches it to predict the next token across billions of examples; sampling repeatedly calls it, appending each prediction, to generate text. Everything else — attention, MLPs, layer norms — is plumbing that makes that one prediction more accurate.

text
tokens ──▶ token + position embeddings
              │
              ▼
        ┌───────────────────┐
        │  Transformer block│  × N layers
        │  ┌─────────────┐  │
        │  │ LayerNorm   │  │
        │  │ Causal Attn │  │   ← tokens look only at past tokens
        │  │ + residual  │  │
        │  ├─────────────┤  │
        │  │ LayerNorm   │  │
        │  │ MLP (4×)    │  │
        │  │ + residual  │  │
        │  └─────────────┘  │
        └───────────────────┘
              │
              ▼
        LayerNorm → Linear → logits over vocab → softmax → next token
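A toy stand-in can make this "function from tokens to a next-token distribution" concrete. The sketch below swaps the transformer for a bigram counter, which only looks at the last token where a GPT looks at the whole context. The names train_bigram, next_token_distribution and sample_sequence are hypothetical and are not nanoGPT code. The part that carries over unchanged is the sampling loop: predict, sample, append, repeat.

```python
import random
from collections import Counter, defaultdict

def train_bigram(tokens):
    """Count how often each token follows each other token."""
    counts = defaultdict(Counter)
    for prev, nxt in zip(tokens, tokens[1:]):
        counts[prev][nxt] += 1
    return counts

def next_token_distribution(counts, context):
    """The 'model': map a context to probabilities over the next token.
    A bigram model only uses context[-1]; a GPT would use the whole context."""
    followers = counts[context[-1]]
    total = sum(followers.values())
    return {tok: n / total for tok, n in followers.items()}

def sample_sequence(counts, prompt, max_new_tokens, rng):
    """Predict, sample, append, repeat."""
    out = list(prompt)
    for _ in range(max_new_tokens):
        probs = next_token_distribution(counts, out)
        if not probs:                       # last token was never followed by anything
            break
        tokens, weights = zip(*probs.items())
        out.append(rng.choices(tokens, weights=weights)[0])
    return out

counts = train_bigram("the cat sat on the mat".split())
print(next_token_distribution(counts, ["the"]))               # {'cat': 0.5, 'mat': 0.5}
print(sample_sequence(counts, ["cat"], 3, random.Random(0)))  # ['cat', 'sat', 'on', 'the']
```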

1. Tokenization

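Text is split into integer tokens using Byte-Pair Encoding (BPE). For GPT-2-scale training on OpenWebText, nanoGPT uses tiktoken's GPT-2 encoding with a 50,257-token vocabulary; the Shakespeare character-level demo instead gives each character its own id. Each token id becomes an index into the embedding table.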

python
import tiktoken
enc = tiktoken.get_encoding("gpt2")

ids = enc.encode("Vector databases are useful.")
# [38469, 18209, 389, 4465, 13]

enc.decode(ids)
# 'Vector databases are useful.'
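tiktoken ships GPT-2's already-learned merges, but the core BPE idea fits in a few lines: start from raw bytes and repeatedly replace the most frequent adjacent pair with a new token id. This is a toy sketch for intuition only. It skips the regex pre-splitting and the exact tie-breaking of the real GPT-2 tokenizer, and the function names are our own. On the classic string "aaabdaaabac", three merges shrink 11 bytes to 5 tokens.

```python
from collections import Counter

def most_frequent_pair(ids):
    """Most common adjacent pair (the first one seen wins ties)."""
    pairs = Counter(zip(ids, ids[1:]))
    return max(pairs, key=pairs.get)

def merge(ids, pair, new_id):
    """Replace non-overlapping occurrences of `pair` with `new_id`, left to right."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_bpe(text, num_merges):
    """Learn `num_merges` merges on top of the 256 raw byte values."""
    ids = list(text.encode("utf-8"))
    merges = {}                              # (id, id) -> new id, in creation order
    for k in range(num_merges):
        if len(ids) < 2:
            break
        pair = most_frequent_pair(ids)
        ids = merge(ids, pair, 256 + k)
        merges[pair] = 256 + k
    return ids, merges

def decode(ids, merges):
    """Expand token ids back into text using the learned merges."""
    vocab = {i: bytes([i]) for i in range(256)}
    for (a, b), new_id in merges.items():    # earlier merges are defined first
        vocab[new_id] = vocab[a] + vocab[b]
    return b"".join(vocab[i] for i in ids).decode("utf-8")

ids, merges = train_bpe("aaabdaaabac", num_merges=3)
print(ids)                  # [258, 100, 258, 97, 99]
print(decode(ids, merges))  # aaabdaaabac
```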

2. Embeddings: token + position

Each token id is looked up in a learned token embedding table to get a vector of size n_embd (e.g. 768). Transformers have no built-in sense of order, so we add a learned position embedding for each slot in the sequence. The sum is what enters the first transformer block.

python
import torch
import torch.nn as nn

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.blocks  = nn.ModuleList([Block(config) for _ in range(config.n_layer)])  # Block: section 4
        self.ln_f    = nn.LayerNorm(config.n_embd)
        self.head    = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight       # weight tying, as nanoGPT does

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size, "sequence is longer than the context window"
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)   # (B, T, n_embd); (T, n_embd) broadcasts over B
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)                          # (B, T, vocab_size)
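To see why the position embedding matters, here is a small runnable check that uses only nn.Embedding, with toy sizes chosen for illustration (vocab 65, block_size 8, n_embd 16). The same token id at two positions gets an identical token vector; only after adding the position embedding do the two inputs to the first block differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, block_size, n_embd = 65, 8, 16      # toy sizes for illustration
tok_emb = nn.Embedding(vocab_size, n_embd)
pos_emb = nn.Embedding(block_size, n_embd)

idx = torch.tensor([[7, 7, 3, 1, 0],
                    [2, 4, 4, 4, 9]])           # (B=2, T=5) token ids
B, T = idx.shape
pos = torch.arange(T)                            # positions 0..T-1
x = tok_emb(idx) + pos_emb(pos)                  # (T, n_embd) broadcasts over the batch

# Token 7 appears at positions 0 and 1 of the first sequence:
same_tok = torch.equal(tok_emb(idx)[0, 0], tok_emb(idx)[0, 1])   # identical token vectors
differs = not torch.allclose(x[0, 0], x[0, 1])                   # different block inputs
print(x.shape, same_tok, differs)   # torch.Size([2, 5, 16]) True True
```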

3. Causal self-attention (the heart)

Self-attention lets every token look at every other token and decide which ones to pay attention to. In a decoder-only GPT it's causal: token t can attend to tokens 0…t but not the future. This is enforced with a triangular mask.

For each token we project the embedding into three vectors: Query (what am I looking for?), Key (what do I offer?), and Value (what information do I carry?). The attention score between tokens i and j is the dot product Qᵢ · Kⱼ, scaled by 1/√d (d is the per-head dimension) and softmaxed over j; the resulting weights are used to average the Values.

python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must split evenly across heads"
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)  # Q, K, V in one matmul
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # lower-triangular mask: position i can see 0..i
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size))
                 .view(1, 1, config.block_size, config.block_size),
        )

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # reshape into (B, n_head, T, head_dim)
        hd = C // self.n_head
        q = q.view(B, T, self.n_head, hd).transpose(1, 2)
        k = k.view(B, T, self.n_head, hd).transpose(1, 2)
        v = v.view(B, T, self.n_head, hd).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(hd)        # (B, h, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        y   = att @ v                                            # (B, h, T, hd)
        y   = y.transpose(1, 2).contiguous().view(B, T, C)       # concatenate heads
        return self.c_proj(y)
Multi-head = parallel perspectives
Splitting the embedding into n_head smaller heads lets each head specialize: one head might track subject–verb agreement, another might track quotation pairing. The outputs are concatenated and projected back.
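A quick way to convince yourself the mask does its job is to compare the masked computation against PyTorch's fused kernel, then perturb a future position. This sketch assumes PyTorch 2.0 or later for F.scaled_dot_product_attention with is_causal=True (nanoGPT's model.py uses that call when it is available); the shapes are toy values.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    """Masked scaled dot-product attention; q, k, v have shape (B, n_head, T, head_dim)."""
    T = q.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))      # (B, h, T, T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~mask, float("-inf"))                   # hide future positions
    return F.softmax(att, dim=-1) @ v                             # (B, h, T, head_dim)

torch.manual_seed(0)
B, n_head, T, head_dim = 2, 4, 6, 8                               # toy shapes
q, k, v = (torch.randn(B, n_head, T, head_dim) for _ in range(3))

y = causal_attention(q, k, v)
y_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # PyTorch >= 2.0
matches_fused = torch.allclose(y, y_fused, atol=1e-5)

# Change the key and value at the last position only; earlier outputs must not move.
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 1.0
v2[:, :, -1] += 1.0
y2 = causal_attention(q, k2, v2)
causal_ok = torch.allclose(y[:, :, :-1], y2[:, :, :-1])
print(matches_fused, causal_ok)   # True True
```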

4. The MLP and the transformer block

After attention mixes information across tokens, a position-wise MLP (two linear layers with a GELU in between, expanding to 4× the embedding dim) processes each token independently. Each sublayer reads a LayerNorm-ed copy of its input (pre-LN) and adds its output back through a residual connection.

python
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        return self.c_proj(F.gelu(self.c_fc(x)))   # expand 4x, GELU, project back

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)     # defined in section 3
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp  = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))   # residual + attention
        x = x + self.mlp(self.ln_2(x))    # residual + MLP
        return x
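Counting parameters is a good way to check your understanding of these shapes. This sketch assumes biases on every Linear and LayerNorm (as in GPT-2) and an output head that shares its weights with the token embedding; with the GPT-2 small shape it lands on the familiar ~124M. nanoGPT's own get_num_params subtracts the position embeddings by default, which gives the second number.

```python
def linear_params(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)

def block_params(n_embd):
    ln = 2 * n_embd                                                   # LayerNorm weight + bias
    attn = linear_params(n_embd, 3 * n_embd) + linear_params(n_embd, n_embd)
    mlp = linear_params(n_embd, 4 * n_embd) + linear_params(4 * n_embd, n_embd)
    return 2 * ln + attn + mlp                                        # ln_1, ln_2, attn, mlp

def gpt_params(vocab_size, block_size, n_layer, n_embd):
    tok = vocab_size * n_embd        # also serves as the output head (weight tying)
    pos = block_size * n_embd
    ln_f = 2 * n_embd
    return tok + pos + n_layer * block_params(n_embd) + ln_f

total = gpt_params(vocab_size=50257, block_size=1024, n_layer=12, n_embd=768)
print(f"{total:,}")               # 124,439,808
print(f"{total - 1024 * 768:,}")  # 123,653,376 without position embeddings
```

Note that the MLP holds about two thirds of each block's weights (8·n_embd² versus 4·n_embd² for attention, ignoring biases).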

5. The training loop

Training is unglamorous: grab a random window of the corpus, shift it by one to make the targets, compute cross-entropy between predicted and actual next tokens, backprop, and step the optimizer (AdamW with a linear warmup followed by cosine learning-rate decay). Repeat for thousands of iterations for the Shakespeare demo, or hundreds of thousands to reproduce GPT-2.

python
import torch
import torch.nn.functional as F

# Assumes model, optimizer (e.g. torch.optim.AdamW), train_data (a 1-D LongTensor
# of token ids), block_size, batch_size and max_iters are defined elsewhere.
def get_batch(data, block_size, batch_size):
    ix = torch.randint(len(data) - block_size, (batch_size,))   # random window starts
    x = torch.stack([data[i  :i+block_size]   for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])  # shifted by 1
    return x, y

for step in range(max_iters):
    xb, yb = get_batch(train_data, block_size, batch_size)
    logits = model(xb)                                     # (B, T, vocab)
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),                  # (B*T, vocab)
        yb.view(-1),                                       # (B*T,)
    )
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
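The loop above keeps the learning rate fixed. A sketch of the warmup-plus-cosine schedule described earlier follows; the default values are illustrative assumptions modeled on a GPT-2-sized run, so check train.py or your config file before relying on them.

```python
import math

def get_lr(it, max_lr=6e-4, min_lr=6e-5, warmup_iters=2000, decay_iters=600_000):
    """Linear warmup to max_lr, cosine decay to min_lr, then hold at min_lr."""
    if it < warmup_iters:
        return max_lr * it / warmup_iters
    if it >= decay_iters:
        return min_lr
    progress = (it - warmup_iters) / (decay_iters - warmup_iters)   # goes 0 -> 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))              # goes 1 -> 0
    return min_lr + coeff * (max_lr - min_lr)

# Inside the training loop, before optimizer.step():
#     for group in optimizer.param_groups:
#         group["lr"] = get_lr(step)

for it in (0, 1000, 2000, 301_000, 600_000):
    print(it, get_lr(it))
```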
Why next-token prediction is enough
To predict the next token reliably across billions of examples, the model must implicitly learn grammar, facts, reasoning patterns, coding conventions, and more. That's why this dirt-simple objective produces such capable models at scale.
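The objective also gives a cheap sanity check. A freshly initialized model that spreads probability evenly over the vocabulary should score a cross-entropy of ln(vocab_size): about 4.17 for a 65-symbol character vocabulary and about 10.8 for GPT-2's 50,257 tokens. The sketch below verifies that on random stand-in data, and also checks that get_batch's targets really are the inputs shifted by one.

```python
import math
import torch
import torch.nn.functional as F

def get_batch(data, block_size, batch_size):
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + 1 + block_size] for i in ix])
    return x, y

torch.manual_seed(0)
vocab_size = 65                                 # toy vocabulary size (assumed)
data = torch.randint(vocab_size, (1000,))       # stand-in for a tokenized corpus
xb, yb = get_batch(data, block_size=8, batch_size=4)
assert torch.equal(xb[:, 1:], yb[:, :-1])       # targets are the inputs shifted by one

# An untrained model that assigns equal probability to every token
uniform_logits = torch.zeros(xb.numel(), vocab_size)
loss = F.cross_entropy(uniform_logits, yb.view(-1))
print(loss.item(), math.log(vocab_size))        # both about 4.174
```

If your first logged loss is far above ln(vocab_size), something in the initialization or the target shift is usually off.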

6. Sampling: generating text

To generate, feed the prompt through the model, take the logits at the last position, optionally apply temperature and top-k, sample a token, append it, and repeat until you hit a length limit or stop token.

python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens, block_size, temperature=1.0, top_k=None, eos_id=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                     # crop to context window
        logits = model(idx_cond)[:, -1, :] / temperature    # last position only
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")     # keep only the k largest
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)   # (B, 1)
        idx = torch.cat((idx, next_id), dim=1)
        if eos_id is not None and (next_id == eos_id).all():
            break                                           # every sequence emitted the stop token
    return idx
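temperature < 1: sharper distribution, more confident samples
temperature > 1: flatter distribution, more varied samples and more mistakes
top-k: sample only from the k most likely tokens
top-p (nucleus): sample from the smallest set of tokens whose cumulative probability is ≥ p (not implemented in generate above)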
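The generate function above covers temperature and top-k but not top-p. Here is one way to write nucleus filtering as a logits transform (a sketch, not nanoGPT code): sort, take the cumulative softmax, and drop every token whose higher-ranked neighbors already reach p.

```python
import torch
import torch.nn.functional as F

def top_p_filter(logits, p):
    """Set logits outside the nucleus (smallest set with cumulative prob >= p) to -inf."""
    sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    mass_above = torch.cumsum(sorted_probs, dim=-1) - sorted_probs   # prob of higher-ranked tokens
    sorted_logits = sorted_logits.masked_fill(mass_above >= p, float("-inf"))
    # Scatter back so every logit returns to its original vocabulary position
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)

# Usage inside a sampling loop, e.g. after temperature scaling:
#     logits = top_p_filter(logits / temperature, p=0.9)
#     next_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

probs = torch.tensor([0.15, 0.5, 0.05, 0.3])
filtered = top_p_filter(torch.log(probs), p=0.7)
print(torch.isfinite(filtered))   # keeps only the 0.5 and 0.3 tokens
```

Unlike top-k, the number of surviving tokens adapts to the shape of the distribution: a confident prediction keeps few candidates, a flat one keeps many.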

7. KV-cache: making decoding fast

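During generation we call the model once per new token. Naively, every call recomputes Q, K, and V for the entire prefix, even though the prefix hasn't changed, and rebuilds the full T×T attention matrix. At context length T that is O(T²) attention work per step, or O(T³) summed over a whole generation.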

The fix: cache the K and V tensors for every past token in every layer. On each new step we only compute Q/K/V for the single new token, append its K and V to the cache, and attend against the full cached K/V. Each generation step becomes O(T) instead of O(T²), and the speedup grows with context length.

Why Q isn't cached
Only the newest token needs a Query (it's the one asking "who should I attend to?"). Past tokens' Queries were used once during their own step and are never needed again. Keys and Values, however, are looked up by every future token — so we keep them around.
python
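import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    # __init__ is unchanged from section 3 (c_attn, c_proj, n_head, n_embd)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape                              # T == 1 when decoding, T > 1 at prefill
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        hd = C // self.n_head
        q = q.view(B, T, self.n_head, hd).transpose(1, 2)
        k = k.view(B, T, self.n_head, hd).transpose(1, 2)
        v = v.view(B, T, self.n_head, hd).transpose(1, 2)

        # Append new K, V to the cache for this layer
        if kv_cache is not None:
            past_k, past_v = kv_cache
            k = torch.cat([past_k, k], dim=2)          # (B, h, T_past+T, hd)
            v = torch.cat([past_v, v], dim=2)
        new_cache = (k, v)

        # The T queries are the last T positions of k/v: query i may see keys 0..T_past+i.
        # With T == 1 the mask is all ones; at prefill it is the usual causal mask.
        T_total = k.size(2)
        mask = torch.ones(T, T_total, dtype=torch.bool, device=x.device).tril(diagonal=T_total - T)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(hd)
        att = att.masked_fill(~mask, float("-inf"))
        att = F.softmax(att, dim=-1)
        y   = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y), new_cache

@torch.no_grad()
def generate_cached(model, idx, max_new_tokens):
    # forward_with_cache(idx, caches) -> (logits, caches) and sample(logits) -> (B, 1) are
    # hypothetical helpers, not part of nanoGPT; forward_with_cache must also offset the
    # position ids by the cache length. Assumes prompt length + max_new_tokens <= block_size.
    caches = [None] * len(model.blocks)                # one (K, V) pair per layer
    # 1) Prefill: run the full prompt once to populate the caches
    logits, caches = model.forward_with_cache(idx, caches)
    next_id = sample(logits[:, -1, :])
    out = [next_id]

    # 2) Decode: feed only the newest token each step
    for _ in range(max_new_tokens - 1):
        logits, caches = model.forward_with_cache(next_id, caches)
        next_id = sample(logits[:, -1, :])
        out.append(next_id)
    return torch.cat([idx] + out, dim=1)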
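nanoGPT itself ships without a KV-cache, so it is worth checking that caching is exact rather than an approximation. This toy single-head sketch (random weights, assumed sizes) computes causal attention once over the full sequence, then again one token at a time with a growing K/V cache, and compares the two.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, C = 6, 16                                          # toy length and width (assumed)
Wq, Wk, Wv = (torch.randn(C, C) / math.sqrt(C) for _ in range(3))
x = torch.randn(T, C)                                 # one sequence, one head, no batch

def attend(q, k, v):
    """Causal attention where the Tq queries are the last Tq of the Tk key positions."""
    Tq, Tk = q.size(0), k.size(0)
    att = (q @ k.T) / math.sqrt(C)
    mask = torch.ones(Tq, Tk, dtype=torch.bool).tril(diagonal=Tk - Tq)
    return F.softmax(att.masked_fill(~mask, float("-inf")), dim=-1) @ v

# 1) Full recompute: every position at once
full = attend(x @ Wq, x @ Wk, x @ Wv)

# 2) Incremental: one token per step, appending its K and V to the cache
k_cache, v_cache, steps = torch.empty(0, C), torch.empty(0, C), []
for t in range(T):
    xt = x[t:t + 1]                                   # (1, C): only the newest token
    k_cache = torch.cat([k_cache, xt @ Wk])           # cache grows to (t+1, C)
    v_cache = torch.cat([v_cache, xt @ Wv])
    steps.append(attend(xt @ Wq, k_cache, v_cache))   # 1 query vs t+1 cached keys
incremental = torch.cat(steps)

print(torch.allclose(full, incremental, atol=1e-5))   # True
```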
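Without cache: O(T²) attention work per decoding step (T = current context length)
With cache: O(T) attention work per step; each token's K and V are computed only once
Memory cost: 2 · n_layer · T · n_embd values per sequence (one K and one V tensor per layer)
Prefill vs decode: the cache is built once from the prompt, then grows by one entry per generated token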
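A quick count of query-key dot products makes the per-step difference concrete. This sketch counts attention scores only (projections and the MLP are ignored) and assumes the uncached model recomputes the full t×t score matrix at context length t.

```python
def attention_scores(T, cached):
    """Query-key dot products needed to generate tokens at context lengths 1..T."""
    total = 0
    for t in range(1, T + 1):
        # cached: 1 new query x t cached keys; uncached: full t x t score matrix
        total += t if cached else t * t
    return total

for T in (128, 1024):
    print(T, attention_scores(T, cached=False), attention_scores(T, cached=True))
```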

This is also why long contexts get expensive at inference time even on small models: the KV-cache grows linearly with context length per layer. Tricks like grouped-query attention, multi-query attention, and paged KV-caches (vLLM) all exist to shrink or better manage this exact buffer.
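Plugging numbers into the memory formula shows how quickly the buffer grows. This sketch assumes fp16 storage (2 bytes per value) and the GPT-2 small shape; the grouped-query line is a hypothetical variant with 3 K/V heads shared by 12 query heads, since GPT-2 itself uses standard multi-head attention.

```python
def kv_cache_bytes(n_layer, n_embd, seq_len, batch=1, bytes_per_value=2,
                   n_head=None, n_kv_head=None):
    """K and V caches: 2 tensors x n_layer x batch x seq_len x kv_dim values."""
    kv_dim = n_embd
    if n_head is not None and n_kv_head is not None:
        kv_dim = n_embd * n_kv_head // n_head      # GQA/MQA: fewer K/V heads than Q heads
    return 2 * n_layer * batch * seq_len * kv_dim * bytes_per_value

MiB = 1024 ** 2
# GPT-2 small shape, full 1024-token context, fp16
print(kv_cache_bytes(12, 768, 1024) / MiB)                           # 36.0
print(kv_cache_bytes(12, 768, 1024, n_head=12, n_kv_head=3) / MiB)   # 9.0
```

The cost scales with batch size as well, which is why serving many long conversations at once is dominated by this buffer.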

What nanoGPT deliberately leaves out

To stay readable, nanoGPT skips many production niceties — and that's the point. Once you understand the core, you can read papers about each addition without losing the thread.

RoPE / ALiBi: better position encoding
KV-cache: fast incremental decoding
Grouped-query attn: Llama-style KV-cache memory saving
RMSNorm + SwiGLU: Llama/Mistral primitives
Mixture of Experts: sparse scaling
RLHF / DPO: preference fine-tuning (alignment)

Try it yourself

The fastest on-ramp is the Shakespeare character-level demo: it trains in a few minutes on a single GPU (or on a laptop CPU with a scaled-down model) and produces recognizably Shakespearean gibberish. Then graduate to OpenWebText on a multi-GPU node to reproduce GPT-2.

bash
git clone https://github.com/karpathy/nanoGPT
cd nanoGPT
pip install torch numpy transformers datasets tiktoken wandb tqdm

# Prepare the tiny Shakespeare dataset
python data/shakespeare_char/prepare.py

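# Train a small model on a single GPU
python train.py config/train_shakespeare_char.py

# ...or on a CPU: no torch.compile, and a much smaller model
python train.py config/train_shakespeare_char.py --device=cpu --compile=False \
  --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 \
  --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0

# Sample from your trained model (add --device=cpu if you trained on CPU)
python sample.py --out_dir=out-shakespeare-char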
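The character-level demo skips BPE entirely: every distinct character in the corpus becomes one token. The repository's data/shakespeare_char/prepare.py builds a mapping along these lines (and also writes the encoded train/val splits to disk, which this sketch leaves out); the tiny corpus here is a stand-in.

```python
text = "hello world"                         # stand-in for the Shakespeare corpus
chars = sorted(set(text))                    # the vocabulary: every distinct character
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

print(len(chars))           # 8
print(encode("hello"))      # [3, 2, 4, 4, 5]
```

The trade-off is a tiny vocabulary and embedding table in exchange for much longer sequences than BPE would produce for the same text.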
Read in this order
model.py → train.py → sample.py. Then watch Karpathy's "Let's build GPT" video alongside the code; it builds a closely related GPT from scratch, one step at a time.