

All my royalties will be donated to Save the Children

Part 1
Introduction​ 1
The Dawn of a New Architecture​ 1
The Core Transformer Architecture: An Overview​ 2
The Encoder in Detail​ 3
Inside an Encoder Layer:​ 4
The Decoder in Detail​ 4
Inside a Decoder Layer:​ 5
Attention is All You Need: The Core Mechanism​ 5
The Inputs: Queries, Keys, and Values (Q, K, V)​ 6
The Attention Formula​ 6
Multi-Head Attention: Seeing in Parallel​ 7
Types of Attention in the Transformer​ 8
Computational Complexity and Benefits of Attention​ 9
The Transformer Lifecycle​ 11
Pretraining​ 11
Fine-Tuning​ 13
Alignment​ 14
Modern Transformers​ 16
Context Management: The Context Window​ 16
The Challenge of Long Sequences​ 17
Techniques for Extending Effective Context​ 18
Prompting and Inference: The Art of Talking to LLMs​ 20
The Inference Process: From Probabilities to Text​ 21
The Rise of Decoder-Only Architectures​ 24
Mixture-of-Experts (MoE) Models​ 25
Multimodality: Beyond Text​ 27
The Problem of Hallucination​ 28
Interpretability and Mechanistic Interpretability​ 29
The Evolving Story of Scaling Laws: From Training to "Thinking"​ 30
The Foundational Laws of Training​ 31
The New Frontier: Test-Time Compute Scaling​ 31
The Unified View: A New, More Complex Trade-Off​ 32
Alternative Architectures​ 32
State-Space Models (Mamba and beyond)​ 32
Diffusion Transformers (DiT)​ 34
The Path to Artificial General Intelligence (AGI)​ 35
Hands-on​ 36
Key Takeaways​ 36
Conclusion​ 38
Glossary of Key Terms​ 38

Part 2
Introduction​ 1
The Architectural Foundations of Modern AI Reasoning​ 3
Inference-Time Scaling: Deeper Thinking on Demand​ 3
Chain of Thought (CoT)​ 3
Few shots and CoT​ 7
Tree of Thoughts: From Linear Steps to Exploratory Search​ 9
The Alignment Engine​ 11
Reinforcement Learning from Human Feedback (RLHF)​ 11
Reinforcement Learning from AI Feedback (RLAIF)​ 13
Pure Reinforcement Learning (RL)​ 14
Key RL Algorithms in Alignment (PPO and DPO)​ 14
Supervised Fine-Tuning + Reinforcement Learning (SFT + RL)​ 15
Pure SFT and Distillation​ 18
Thinking on Demand: The Rise of Test-Time Computation and Adaptive Inference​ 20
The Scaling Dilemma: Sparse Mixture-of-Experts and Computational Efficiency​ 21
A Comparative Analysis of Leading Models​ 24
The Titans of Industry (Proprietary Models)​ 24
The Open-Source Insurgency​ 25
Benchmarking Reasoning Capabilities​ 28
The Established Canon and Its Limitations​ 28
The 2025 Leaderboard​ 28
Future Directions and Implications​ 29
Next-Generation Reasoning Paradigms​ 29
Neuro-Symbolic AI: Bridging Logic and Intuition​ 30
Dynamic Reasoning Frameworks​ 31
Adversarial Self-Critique​ 32
Key Takeaways​ 33
Conclusion​ 34
References​ 35

Part 3
Introduction​ 1
Code Breakdown​ 1
Initialization and Hyperparameters​ 1
Model Definition​ 2
Multi Headed Attention​ 3
Transformer Block​ 4
Attention Language Model​ 5
Training logic​ 6
Evaluate Logic​ 7
Data and Generation Helpers​ 8
Main execution​ 10
Key Takeaways​ 14
Conclusion​ 15

Part 4
Introduction​ 1
Mixture of Experts (MoE)​ 1
Grouped-Query Attention (GQA)​ 2
Rotary Position Embeddings (RoPE)​ 3
RMSNorm (Root Mean Square Normalization)​ 4
SwiGLU (Swish-Gated Linear Unit)​ 4
KV Caching (Key-Value Caching)​ 5
Code Breakdown​ 6
Model Architecture​ 6
Tokenizer​ 16
Weight Loading​ 17
Text Generation​ 20
Main Execution​ 21
Key Takeaways​ 23
Conclusions​ 24

Part 5
Introduction: An Era of Refinement​ 1
Key Architectural Innovations of the Modern Era​ 1
The Rise of Sparse Models: Mixture-of-Experts (MoE)​ 2
Evolving the Attention Mechanism: Beyond Multi-Head​ 2
Grouped-Query Attention (GQA)​ 2
Multi-Head Latent Attention (MLA)​ 3
Sliding Window Attention​ 4
The Subtle Art of Normalization and Positional Signals​ 5
Normalization: The Art of Stability​ 5
No Positional Embeddings (NoPE)​ 6
Sliding Attention Window​ 7
A Tour of 2025's Flagship Architectures​ 8
DeepSeek-V3​ 8
OLMo 2​ 9
Gemma 3​ 10
Llama 4: Mainstreaming the Mixture-of-Experts​ 11
Qwen3: The Hallmark of Versatility​ 13
SmolLM3 and the Frontier of Positional Information​ 15
GPT-OSS: the OpenAI Open Source take​ 15
Kimi2: the new model from China​ 18
Key Takeaways​ 20
Conclusion​ 21
Part 6

Introduction​ 1
Gemma Model​ 1
gemma/config.py​ 1
gemma/gemma3_model.py​ 2
gemma/gemma3_preprocessor.py​ 12
gemma/model.py​ 12
gemma/model_xla.py​ 31
gemma/tokenizer.py​ 36
gemma/xla_model_parallel.py​ 36
Conclusion​ 38

Transformer Models
Introduction
The Transformer architecture has fundamentally reshaped the fields of machine
learning and natural language processing (NLP). Introduced in the influential 2017
paper "Attention is All You Need" by Vaswani et al., this model departed from the
traditional recurrent and convolutional layers used in sequence transduction tasks. Its
innovation lies in its exclusive reliance on the attention mechanism, which enabled
unprecedented levels of parallelization and performance. This chapter offers a
thorough exploration of the Transformer model, covering its foundational architectural
elements, advanced training and deployment techniques, and methods for interaction.
We will meticulously examine the core attention mechanism, trace its evolution from
pre-training to fine-tuning, and discuss strategies for managing its context and
guiding its behavior through prompting. This chapter is designed for readers with a
basic understanding of machine learning, providing an accessible yet scientific
journey into the intricate and captivating realm of Transformer models.
The Dawn of a New Architecture
Before the advent of the Transformer, the state-of-the-art in processing sequential
data, such as text or time series, was dominated by Recurrent Neural Networks
(RNNs) and their more sophisticated variants, Long Short-Term Memory (LSTM) and
Gated Recurrent Unit (GRU) networks. These models process data sequentially, token
by token, maintaining a hidden state that carries information from previous steps to
the current one. This sequential nature, while intuitive for modeling sequences,
presented a significant bottleneck. The computation for each step depends on the
completion of the previous one, making it difficult to parallelize the training process
over long sequences. Furthermore, RNNs struggled with capturing long-range
dependencies; information from early in a sequence could become diluted or lost by
the time the model reached later steps, a problem often referred to as the vanishing
gradient problem.

Convolutional Neural Networks (CNNs), traditionally used for image processing, were
also adapted for sequence data. By using filters that slide over local windows of the
sequence, CNNs could capture local dependencies and, by stacking layers, could
increase their receptive field to model longer-range interactions. While more
parallelizable than RNNs, they still faced limitations in modeling dependencies
between tokens that were very far apart without a prohibitively large number of
layers.

The Transformer architecture was born out of a desire to overcome these
fundamental limitations. The core innovation was to rely entirely on a self-attention
mechanism, which allows the model to directly weigh the importance of all other
tokens in the input sequence when processing a given token, regardless of their
position. This parallel, position-agnostic approach to modeling dependencies was
revolutionary. It meant that the relationship between any two words in a sentence, no
matter how distant, could be learned with a constant number of computational steps.
This not only solved the long-range dependency problem more effectively but also
unlocked massive parallelization, as the computations for each token could be
performed simultaneously. This breakthrough paved the way for training much larger
and more powerful models on vast datasets, leading directly to the era of Large
Language Models (LLMs) that we see today.
The Core Transformer Architecture: An Overview
At its core, the original Transformer model, as introduced for machine translation, is
an encoder-decoder structure. This architecture is designed to transform an input
sequence from one domain (e.g., a sentence in German) into an output sequence in
another domain (e.g., the same sentence in English).

The Encoder Stack: The encoder's job is to process the entire input sequence and
build a rich, contextualized representation of it. It consists of a stack of identical
layers (the original paper used N=6 layers). Each token representation (embedding)
in the input sequence flows through these layers in parallel. Within each layer, the
tokens pass through two main sub-layers: a multi-head self-attention mechanism
and a position-wise fully connected feed-forward network.

The Decoder Stack: The decoder's role is to generate the output sequence, one
token at a time, using the representation created by the encoder. It also consists of a
stack of N=6 identical layers. The decoder's structure is similar to the encoder's but
with a crucial third sub-layer. In addition to the multi-head self-attention and
feed-forward sub-layers, the decoder incorporates a cross-attention (or
encoder-decoder attention) mechanism. This allows the decoder, at each step of
generation, to "look back" and focus on the most relevant parts of the encoded input
sequence.

Residual Connections and Layer Normalization: A critical detail in the architecture
is the use of residual connections around each of the sub-layers, followed by layer
normalization. The output of each sub-layer is LayerNorm(x + Sublayer(x)), where
Sublayer(x) is the function implemented by the sub-layer itself. This technique is vital
for training deep networks, as it allows gradients to flow more easily through the
many layers of the model, preventing them from vanishing or exploding and
stabilizing the training process.

[Image Description: A high-level block diagram of the Transformer architecture. On
the left, a box labeled "Encoder Stack" shows an input arrow at the bottom ("Input
Sequence") and an output arrow at the top leading to the decoder. Inside the box,
several "Encoder Layer" blocks are stacked vertically. On the right, a box labeled
"Decoder Stack" shows the arrow from the encoder pointing to it, as well as an input
arrow at the bottom ("Output Sequence (shifted right)"). An output arrow at the top
is labeled "Output Probabilities". Inside the decoder box, several "Decoder Layer"
blocks are stacked. Arrows show the flow of information, highlighting the connection
from the final encoder layer to each decoder layer.]

The Encoder in Detail


Let's dissect the journey of an input sequence through the encoder stack.
1.​ Input Embedding: The process begins by converting the input tokens (words,
subwords, etc.) into numerical vectors. This is done using an embedding layer,
which is essentially a lookup table where each unique token is mapped to a
high-dimensional vector. The dimension of these embeddings is a
hyperparameter of the model, often denoted as d_model (e.g., 512).
2.​ Positional Encoding: Since the self-attention mechanism is inherently
position-agnostic—it treats the input as an unordered bag of tokens—we must
explicitly inject information about the position of each token in the sequence. This
is achieved by adding a positional encoding vector to each input embedding. The
original paper used sine and cosine functions of different frequencies:​
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))​
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))​
where pos is the position of the token in the sequence and i is the dimension in
the embedding vector. This method has the advantage of allowing the model to
easily learn to attend to relative positions, as the encoding for any position can be
represented as a linear function of the encodings for other positions.
The resulting vectors, which now contain both semantic and positional information,
are then fed into the first encoder layer.
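
As a concrete illustration, the sinusoidal encoding above can be computed in a few lines of PyTorch. This is a minimal sketch; the function name and the max_len/d_model arguments are illustrative, not taken from the original paper's code:

import math
import torch

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    # One row per position, one column per embedding dimension (d_model assumed even).
    position = torch.arange(max_len).unsqueeze(1)                                   # (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions use sine
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions use cosine
    return pe

# The encoding is simply added to the token embeddings before the first encoder layer:
# x = token_embeddings + sinusoidal_positional_encoding(seq_len, d_model)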

Inside an Encoder Layer:


●​ Multi-Head Self-Attention: This is the first sub-layer. For each token, this
mechanism calculates an "attention score" with every other token in the input
sequence. These scores determine how much focus or "attention" to place on
other tokens when creating the new representation for the current token. The
"multi-head" aspect means this process is done multiple times in parallel with
different, learned linear projections, allowing the model to focus on different
types of relationships simultaneously. The outputs of these parallel "heads" are
then concatenated and linearly projected back to the original dimension.
●​ Feed-Forward Network (FFN): This is the second sub-layer. The output from the
attention sub-layer is passed through a simple, fully connected feed-forward
network. This network is applied to each token's representation independently
and identically. It consists of two linear transformations with a ReLU activation in
between:​
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2​
This FFN provides additional non-linearity and allows the model to learn more
complex transformations of the token representations.
The output of the FFN is then passed to the next encoder layer, and this process
repeats for the entire stack. The final output of the top encoder layer is a set of
contextualized embedding vectors, one for each input token, that represents the
entire input sequence.
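
Putting the two sub-layers together with their residual connections and layer normalization, a single encoder layer can be sketched as follows. This is a simplified, post-norm PyTorch version for illustration only; the default hyperparameter values are assumptions in the spirit of the original paper:

import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Sub-layer 1: multi-head self-attention, wrapped in a residual connection + LayerNorm.
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        # Sub-layer 2: position-wise feed-forward network, again with residual + LayerNorm.
        x = self.norm2(x + self.ffn(x))
        return x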

The Decoder in Detail


The decoder's task is to generate the output sequence token by token, in an
auto-regressive manner. This means that to generate the token at step t, the decoder
uses the previously generated tokens (from step 1 to t-1) as part of its input.
1.​ Output Embedding and Positional Encoding: The process starts similarly to the
encoder. The target sequence generated so far (or the ground-truth target during
training) is embedded and positional encodings are added. A crucial detail is that
the input to the decoder is the target sequence shifted right. This ensures that
the prediction for position i can only depend on the known outputs at positions
less than i.

Inside a Decoder Layer:


●​ Masked Multi-Head Self-Attention: The first sub-layer is a self-attention
mechanism, just like in the encoder. However, it is modified to prevent positions
from attending to subsequent positions. This is achieved by applying a
look-ahead mask. Before the softmax is applied in the attention calculation, the
values corresponding to future positions are set to negative infinity. This
effectively makes it impossible for the decoder to "cheat" by looking at the next
token in the sequence it is trying to predict.
●​ Multi-Head Cross-Attention (Encoder-Decoder Attention): This is the second
and most critical sub-layer, where the decoder interacts with the encoder's
output. This mechanism works just like the self-attention mechanism, but with a
key difference in its inputs. The Queries (Q) come from the previous decoder
sub-layer, while the Keys (K) and Values (V) come from the output of the final
encoder layer. This allows every position in the decoder to attend to all positions
in the input sequence. It is this mechanism that enables the decoder to focus on
the most relevant parts of the source sentence when generating each target
token. For instance, when translating a sentence, as the decoder generates a
verb, the cross-attention might focus heavily on the corresponding verb and its
subject in the encoded input.
●​ Feed-Forward Network (FFN): The third sub-layer is a feed-forward network,
identical in structure to the one in the encoder. It provides further processing for
the representations that have incorporated both self-attention (over the
generated target sequence) and cross-attention (over the source sequence).
After the final decoder layer, the resulting vector is passed through a final linear layer
and a softmax function. The linear layer acts as a classifier, projecting the
d_model-dimensional vector into a much larger vector whose size is equal to the
vocabulary size. The softmax function then converts these raw scores (logits) into
probabilities, giving a probability for each possible token in the vocabulary being the
next token in the sequence. The token with the highest probability is typically chosen,
and the process repeats for the next step.
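
The look-ahead mask used in the decoder's first sub-layer is simply an upper-triangular matrix of negative infinities added to the attention scores before the softmax. A minimal PyTorch sketch (the helper name is illustrative):

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Future positions (column index > row index) get -inf, so softmax assigns them zero weight.
    mask = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)   # zeros on and below the diagonal, -inf above it

# In the decoder's masked self-attention: scores = scores + causal_mask(seq_len)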

Attention is All You Need: The Core Mechanism


The heart of the Transformer is the Scaled Dot-Product Attention mechanism. Its
function is to map a query and a set of key-value pairs to an output. The query, keys,
values, and output are all vectors. The output is computed as a weighted sum of the
values, where the weight assigned to each value is computed by a compatibility
function of the query with the corresponding key.

The Inputs: Queries, Keys, and Values (Q, K, V)


The concepts of Query, Key, and Value are borrowed from information retrieval
systems. Imagine searching for a video on a platform. Your search query is the
Query. The titles or descriptions of the videos are the Keys. The videos themselves
are the Values. The system calculates a similarity score between your Query and
each Key, and then returns a weighted combination of the Values (e.g., a ranked list
of videos) based on these scores.

In the context of the Transformer, Q, K, and V are not pre-defined but are learned.
They are created by projecting the input embedding vectors (or the outputs of the
previous layer) through three separate, learned linear layers (weight matrices WQ,
WK, and WV).
●​ Query (Q): Represents the current token's "search query" for relevant
information.
●​ Key (K): Represents the "label" or "identifier" of other tokens in the sequence,
which the query can be compared against.
●​ Value (V): Represents the actual content or meaning of the other tokens.

The Attention Formula


The computation of attention is captured in a single, elegant formula:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

Let's break this down step-by-step:


1.​ Score Calculation (QK^T): The first step is to calculate the dot product of the
query vector with every key vector. If we have a sequence of n tokens, this results
in an n x n matrix of scores. The dot product is a simple but effective measure of
similarity: if two vectors are pointing in a similar direction, their dot product will be
large.
2.​ Scaling (by sqrt(d_k)): The dot products are then scaled down by dividing by the square
root of the dimension of the key vectors, d_k. This scaling factor is crucial for
stabilizing the training process. For large values of d_k, the dot products can grow
very large in magnitude, pushing the softmax function into regions where it has
extremely small gradients. This would make learning difficult. The scaling ensures
that the arguments to the softmax function remain in a reasonable range.
3.​ Weighting (softmax): A softmax function is applied to the scaled scores. This
converts the scores into probabilities that sum to 1. The resulting attention
weights determine how much of each value vector should be incorporated into
the final output for the current query. A high weight means the corresponding
value is highly relevant.
4.​ Output Calculation (...V): Finally, the attention weights are multiplied by the
value vectors. This produces a weighted sum of the values, where the weights are
determined by the query-key compatibility. The resulting output vector for a given
token is a blend of all the token values in the sequence, filtered by their relevance
to the query token.
[Image Description: A diagram illustrating the Scaled Dot-Product Attention
mechanism. An input matrix "Queries" and "Keys" are shown entering a "MatMul"
(Matrix Multiplication) operation. The output goes to a "Scale" operation (dividing by
sqrt(d_k)), then an optional "Mask" operation, and then a "SoftMax" function. The
output of the softmax, labeled "Attention Weights," is then multiplied with an input
matrix "Values" in another "MatMul" operation. The final output is labeled "Output".]

Multi-Head Attention: Seeing in Parallel


Instead of performing a single attention function with d_model-dimensional keys,
values, and queries, the authors found it beneficial to linearly project the queries,
keys, and values h times with different, learned linear projections to d_k, d_k, and d_v
dimensions, respectively. On each of these projected versions of queries, keys, and
values, the attention function is performed in parallel, yielding d_v-dimensional
output values. This is the essence of Multi-Head Attention.

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

The Intuition: A single attention head might learn to focus on a particular type of
relationship, for example, subject-verb agreement. By having multiple heads, the
model can learn to attend to different aspects of the sequence simultaneously. One
head might track syntactic dependencies, another might track semantic similarity,
and a third might track co-reference (which pronouns refer to which nouns).

The Process:
1.​ Projection: The original Q, K, and V matrices are each passed through h different
linear layers (weight matrices W_i^Q, W_i^K, W_i^V), creating h sets of lower-dimensional
Q, K, and V matrices. Each of these sets is called an "attention head."
2.​ Parallel Attention: The scaled dot-product attention is calculated for each head
independently and in parallel. This results in h separate output matrices.
3.​ Concatenation and Final Projection: The h output matrices are concatenated
back together. This concatenated matrix is then passed through one final linear
projection layer (with weight matrix W^O) to produce the final output of the
multi-head attention layer. This final projection allows the model to combine the
information learned by the different heads.
In the original Transformer, h=8 heads were used. The model dimension d_model was
512, and the dimension of each head d_k and d_v was d_model / h = 64. This setup
ensures that the total computational cost is similar to that of a single attention head
with the full d_model dimension, while providing the benefit of learning diverse
representations.

[Image Description: A diagram of Multi-Head Attention. A single set of Q, K, V
matrices is shown at the bottom. Arrows point from them to h different "Scaled
Dot-Product Attention" blocks, arranged in parallel. Each block is labeled "head 1",
"head 2", ..., "head h". The outputs from all these heads are then shown entering a
"Concat" block. The output of the Concat block goes through a final "Linear"
projection layer to produce the final output.]
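
The projection, parallel attention, and final concatenation can be sketched as a small PyTorch module. This is an illustrative implementation rather than reference code; the split/merge bookkeeping between heads is the part worth studying:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)   # final output projection W^O

    def forward(self, q, k, v, mask=None):
        B, T, _ = q.shape
        def split(x, proj):
            # Project, then reshape to (batch, n_heads, seq_len, d_head).
            return proj(x).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        Q, K, V = split(q, self.W_q), split(k, self.W_k), split(v, self.W_v)
        scores = Q @ K.transpose(-2, -1) / self.d_head ** 0.5
        if mask is not None:
            scores = scores + mask
        out = torch.softmax(scores, dim=-1) @ V
        # Concatenate the heads and apply the final linear projection.
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)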

Types of Attention in the Transformer


The Transformer architecture cleverly utilizes the same fundamental attention
mechanism in three distinct ways, defined by how the Q, K, and V vectors are
sourced.
1.​ Encoder Self-Attention: This occurs in the encoder layers. Here, the Q, K, and V
vectors all originate from the same place: the output of the previous encoder
layer. In the first layer, they come from the input embeddings plus positional
encodings. Because Q, K, and V are all derived from the same sequence, this is
called "self-attention." Each position in the encoder can attend to all positions in
the previous layer of the encoder. This allows the encoder to build up a rich,
contextualized understanding of the input sequence by relating every word to
every other word.
2.​ Decoder Masked Self-Attention: This occurs in the first sub-layer of the
decoder. Similar to encoder self-attention, the Q, K, and V vectors are all derived
from the same source: the output of the previous decoder layer (or the target
sequence embeddings for the first layer). However, as mentioned earlier, this
self-attention is "masked." The look-ahead mask ensures that for any given
position i, the attention mechanism can only draw information from positions 1 to
i, not from i+1 onwards. This is essential for the auto-regressive nature of the
decoder, preserving the property that the prediction for the current token can
only depend on the tokens that have already been generated.
3.​ Cross-Attention (Encoder-Decoder Attention): This is the crucial link between
the encoder and the decoder, occurring in the second sub-layer of the decoder.
Here, the inputs are sourced differently. The Queries (Q) come from the output
of the previous decoder sub-layer (the masked self-attention layer). The Keys (K)
and Values (V), however, come from the output of the final encoder layer. This is
not "self-attention" because the queries are attending to a different sequence
(the encoded input) than the one they originated from (the partially generated
output). This mechanism is the workhorse of translation and conditional
generation tasks. It allows the decoder, at each step, to query the entire input
sequence and focus on the parts that are most relevant for generating the next
output token. For example, when generating the English word "cat," the
cross-attention might place a high weight on the German word "Katze" in the
encoder's output representation.
Understanding these three contexts is key to understanding the flow of information
and the division of labor within the full Transformer architecture. The encoder builds
the context, the decoder's masked self-attention builds the context of what has been
generated so far, and the cross-attention merges these two worlds to produce the
final output.

Computational Complexity and Benefits of Attention


One of the primary motivations for the Transformer was to improve upon the
computational properties of recurrent models. Let's analyze the complexity of the
attention mechanism and compare it to its predecessors.

Let n be the sequence length and d be the representation dimension (d_model).

Complexity of Self-Attention: The dominant computation in a self-attention layer is
the matrix multiplication of Q with K-transpose, which is an (n x d) matrix multiplied by
a (d x n) matrix. This results in an (n x n) attention score matrix, with a computational
complexity of O(n²⋅d). This quadratic dependency on the sequence length n is the
most significant computational characteristic of the Transformer. While it allows for
direct modeling of all pairwise interactions, it also becomes a bottleneck for very long
sequences.

The second matrix multiplication is between the (n x n) attention weights and the (n x
d) Value matrix, which is also O(n²⋅d). The position-wise feed-forward network has a
complexity of O(n⋅d²), as it processes each of the n positions independently.
Therefore, for a single layer, the total complexity is O(n²⋅d + n⋅d²). In practice, for
typical model sizes where d is fixed (e.g., 512) and n can be large, the O(n²⋅d) term
dominates.
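
To make these orders of magnitude concrete, here is a quick back-of-the-envelope comparison of the two terms (the numbers are illustrative, not measurements):

# Rough per-layer cost of the two dominant terms, for d = 512.
d = 512
for n in (512, 4096, 32_768):
    attention_term = n * n * d    # O(n^2 * d): QK^T scores plus the weighted sum over V
    ffn_term = n * d * d          # O(n * d^2): position-wise feed-forward network
    print(f"n={n}: attention/ffn ratio = {attention_term / ffn_term:.0f}")
# The ratio equals n/d, so the quadratic attention term dominates once n grows past d.
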
Comparison with Recurrent and Convolutional Layers:

●​ Recurrent Neural Networks (RNNs): An RNN processes a sequence token by
token. The complexity per layer is O(n⋅d²). While this appears better than the
Transformer's O(n²⋅d) for large n, the key issue is the sequential nature of the
computation. The number of sequential operations is O(n), which prevents
parallelization over the time dimension. The path length for information to
travel between distant tokens is also O(n), making it hard to capture long-range
dependencies.
●​ Convolutional Neural Networks (CNNs): A CNN layer with a kernel of width k
has a complexity of O(n⋅k⋅d²). This is highly parallelizable. However, to connect
two distant tokens, a stack of O(n/k) layers is required, increasing the network
path length significantly.
Advantages of Attention Mechanisms

Attention mechanisms offer several key benefits, particularly in the realm of natural
language processing:

1.​ Constant Path Length: A significant advantage of self-attention is its ability to
connect any two tokens in a sequence directly. This means the number of
operations required to relate tokens remains constant, regardless of their
distance. This direct connection is highly effective for capturing long-range
dependencies, a task where RNNs (with O(n) path length) and typical CNNs
(O(log_k n) for dilated or O(n/k) for regular convolutions) are less efficient.
2.​ Enhanced Parallelizability: Unlike Recurrent Neural Networks (RNNs) where
computations are sequential, self-attention layers process the entire sequence
simultaneously. This inherent parallelism allows for efficient computation on
modern hardware like GPUs and TPUs, which has been a crucial factor in the
development and training of exceptionally large models.
3.​ Potential for Interpretability: Attention weights, the output of the softmax
function, can provide some degree of insight into the model's decision-making.
Visualizing which tokens attend to others can offer clues about how the model
processes information. However, it's important to exercise caution, as attention
weights are not always a direct explanation of the model's behavior.

The main limitation of attention mechanisms is their quadratic complexity with respect
to sequence length. This drawback will be further explored in the "Context
Management" section.
The Transformer Lifecycle
The true power of modern Transformer models was unlocked not just by the
architecture itself, but by a new training paradigm: large-scale, self-supervised
pre-training followed by task-specific fine-tuning. This approach, a form of
transfer learning, allows a model to learn general-purpose knowledge from vast
amounts of unlabeled text data, which can then be adapted to a wide range of
downstream tasks with much less labeled data.

The lifecycle of a large language model can be broadly divided into three phases:
pre-training, fine-tuning, and alignment. The pre-training phase is the most
computationally intensive and is where the model learns its fundamental
understanding of language.

Pretraining
The Goal of Pre-training: The objective is to use a self-supervised task to train the
model on a massive corpus of text (e.g., the entire internet, all of Wikipedia, a large
collection of books). "Self-supervised" means that the labels or targets for the
learning task are generated automatically from the input data itself, without any need
for human annotation. This is crucial because unlabeled text is available in
near-limitless quantities.

Common Pre-training Objectives:


1.​ Masked Language Modeling (MLM): This objective was popularized by BERT
(Bidirectional Encoder Representations from Transformers). The idea is to corrupt
the input sentence by randomly masking some of its tokens. The model's task is
then to predict the original identity of these masked tokens. For example:​
Input: "The [MASK] brown fox jumps over the lazy [MASK]."​
Target: Predict "quick" for the first mask and "dog" for the second.​
To make this prediction, the model must use the context from both the left and
the right of the mask. This forces the model to learn a deep, bidirectional
representation of the language. This objective is typically used for pre-training
encoder-only models like BERT.
2.​ Causal Language Modeling (CLM) / Next Token Prediction: This is the
standard objective for auto-regressive models like the GPT (Generative
Pre-trained Transformer) family. The model is trained to predict the next token in
a sequence given all the preceding tokens.​
Input: "The quick brown fox"​
Target: Predict "jumps"​
The model processes the text from left to right, and at each position, it tries to
predict the next word. This naturally trains the model for text generation tasks.
This objective is used for pre-training decoder-only models.
3.​ Next Sentence Prediction (NSP): Used in the original BERT model alongside
MLM, this was a sentence-level objective. The model was given two sentences, A
and B, and had to predict whether sentence B was the actual sentence that
followed A in the original text or just a random sentence from the corpus. The
goal was to help the model understand sentence relationships. However,
subsequent research found this objective to be of limited value, and many
modern models have abandoned it in favor of other strategies.
4.​ Denoising Objectives: More advanced models like T5 (Text-to-Text Transfer
Transformer) and BART use more sophisticated "denoising" objectives. Instead of
just masking tokens, contiguous spans of text are corrupted or replaced with a
single mask token. The model is then trained to reconstruct the original,
uncorrupted text spans. This is a more general and powerful form of
self-supervision. For example, T5 frames every NLP task as a text-to-text
problem, where a textual prompt is given as input and the model must generate a
textual output.
The result of this expensive pre-training phase is a base model—a massive neural
network whose weights encode a rich, general-purpose understanding of grammar,
syntax, semantics, and even a significant amount of factual knowledge about the
world as represented in the training data.
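
Before moving on, the causal language modeling objective described above is worth making concrete: the target sequence is simply the input shifted by one position, and the loss is the cross-entropy of each predicted next token. A minimal PyTorch sketch (the model that produces the logits is omitted):

import torch.nn.functional as F

def causal_lm_loss(logits, token_ids):
    # logits: (batch, seq_len, vocab_size) produced by a decoder-only model
    # token_ids: (batch, seq_len); the target at position t is the token at position t+1
    targets = token_ids[:, 1:]            # shift left: each position predicts the next token
    predictions = logits[:, :-1, :]       # drop the last position, which has no target
    return F.cross_entropy(predictions.reshape(-1, predictions.size(-1)), targets.reshape(-1))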

Fine-Tuning
Once a base model has been pre-trained, it possesses a powerful, general
understanding of language but is not specialized for any particular task. The
fine-tuning phase adapts this general model to a specific downstream task using a
much smaller, task-specific labeled dataset.

The Process of Fine-Tuning: The process involves taking the pre-trained model and
continuing the training process (i.e., continuing to update its weights via
backpropagation) on a new dataset. For example, to create a sentiment analysis
model, one would take a pre-trained base model like BERT and fine-tune it on a
dataset of movie reviews labeled as "positive" or "negative."
1.​ Add a Task-Specific Head: A small, new neural network layer (or "head") is
typically added on top of the base Transformer model. The architecture of this
head depends on the task.
○​ For a classification task (like sentiment analysis), this would be a simple linear
layer followed by a softmax to output class probabilities.
○​ For a token-level task (like Named Entity Recognition, where each token is
classified), a linear layer would be applied to each output token's
representation.
○​ For a generative task (like summarization or translation), the base decoder
model itself can be used, or a new decoder might be attached to a
pre-trained encoder.
2.​ Continue Training: The entire model (the pre-trained base plus the new head) is
then trained on the labeled dataset. The learning rate used during fine-tuning is
typically much smaller than the one used during pre-training. This is because the
model's weights are already in a very good state; we only want to "nudge" them
slightly to adapt to the new task, without catastrophically forgetting the general
knowledge learned during pre-training.
Why Fine-Tuning is So Effective: Fine-tuning leverages the power of transfer
learning. The model doesn't need to learn the structure of language from scratch
using the small, labeled dataset. It has already learned that from the massive,
unlabeled pre-training corpus. The fine-tuning process only needs to learn the
mapping from the model's rich internal representations to the specific output format
required by the task. This makes it possible to achieve state-of-the-art performance
on many NLP tasks with only a few thousand, or even just a few hundred, labeled
examples. This is a dramatic improvement over previous paradigms that required very
large labeled datasets for every new task.

Popular Fine-Tuning Methods:


●​ Full Fine-Tuning: The most common method, where all the weights of the
pre-trained model are updated during training. This offers the best performance
but can be computationally expensive and requires storing a full copy of the
model for each new task.
●​ Parameter-Efficient Fine-Tuning (PEFT): As models have grown to hundreds of
billions of parameters, full fine-tuning has become infeasible for many users.
PEFT methods have emerged to address this. Instead of updating all the model's
weights, these techniques freeze the vast majority of the pre-trained parameters
and only train a small number of new or existing parameters.
○​ Adapters: Small, new neural network modules are inserted between the
layers of the pre-trained model. Only the weights of these adapter layers are
trained.
○​ LoRA (Low-Rank Adaptation): This popular technique involves learning
low-rank updates to the weight matrices of the Transformer. Instead of
updating a large matrix W, LoRA learns two smaller matrices A and B such that
their product BA approximates the update. This dramatically reduces the
number of trainable parameters.
○​ Prompt Tuning / Prefix Tuning: Instead of modifying the model's weights,
this method involves learning a "soft prompt"—a continuous vector that is
prepended to the input sequence. Only this prompt vector is tuned, while the
model itself remains frozen.
PEFT methods make it possible to adapt a single large base model to many different
tasks efficiently, both in terms of computation and storage.
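
The idea behind LoRA can be captured in a few lines: the pre-trained weight matrix is frozen, and only the two small matrices A and B are trained. A simplified sketch, with the rank and scaling values chosen purely for illustration:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                   # freeze the pre-trained weights
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scale = alpha / rank

    def forward(self, x):
        # Output = frozen base layer + scaled low-rank update B(Ax).
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

Because B is initialized to zero, the adapted layer starts out identical to the frozen base layer, and only the small A and B matrices receive gradients during fine-tuning.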

Alignment
For large-scale generative models intended for direct human interaction (like
chatbots or assistants), a third phase has become critical: alignment. The goal of
alignment is to steer the model's behavior to be more helpful, harmless, and honest,
and to follow user instructions effectively. A base model trained on raw internet text
might be very good at predicting the next word, but it might also generate factually
incorrect, toxic, biased, or unhelpful content. The alignment process aims to mitigate
these undesirable behaviors.

The most prominent technique for alignment is Reinforcement Learning from
Human Feedback (RLHF). This is a complex but powerful process that uses human
preferences as a reward signal to fine-tune the model.
preferences as a reward signal to fine-tune the model.

The RLHF Process:


1.​ Collect Human Preference Data: The first step is to create a dataset of human
preferences. This is done by taking a prompt and generating several different
responses from the model. A human labeler then ranks these responses from
best to worst. For example, for the prompt "Explain the theory of relativity," one
response might be clear and accurate, another might be too technical, and a third
might be factually wrong. The human would rank them accordingly.
2.​ Train a Reward Model (RM): A separate Transformer model, called the reward
model, is then trained on this preference data. The RM takes a prompt and a
model-generated response as input and outputs a single scalar value: a "reward"
score that predicts how much a human would prefer that response. It is trained to
assign higher scores to the responses that humans ranked higher. This RM
essentially learns to embody human preferences.
3.​ Fine-Tune the LLM with Reinforcement Learning: The final step uses the
trained reward model to fine-tune the original language model. This is done using
a reinforcement learning algorithm like Proximal Policy Optimization (PPO). The
process works as follows:
○​ The LLM (now acting as the "policy" in RL terms) is given a prompt from a
dataset.
○​ The LLM generates a response.
○​ The reward model evaluates this response and gives it a reward score.
○​ This reward signal is used to update the weights of the LLM. The LLM is
incentivized to generate responses that the reward model will score highly.
A key component of this step is a KL-divergence penalty. The model is penalized
for moving too far away from the original, pre-trained base model. This is crucial to
prevent the model from "over-optimizing" for the reward model and forgetting its
core language capabilities, a phenomenon known as "reward hacking."
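
Conceptually, the training signal in this third step combines the reward model's score with the KL penalty against the frozen reference model. A hedged sketch of that combined objective (the variable names and the beta coefficient are illustrative):

def rlhf_reward(rm_score, policy_logprob, ref_logprob, beta=0.1):
    # rm_score: scalar score from the reward model for the generated response
    # policy_logprob / ref_logprob: log-probability of that response under the current
    # policy and under the frozen pre-trained reference model, respectively
    kl_penalty = policy_logprob - ref_logprob   # per-sample estimate of the KL divergence
    return rm_score - beta * kl_penalty         # the quantity the RL algorithm maximizes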

RLHF has been instrumental in creating models that are not just capable but also
safe and useful as conversational agents. It's a powerful technique for instilling
complex, nuanced human values into a model's behavior, going beyond what can be
achieved with simple supervised fine-tuning. More recent techniques like Direct
Preference Optimization (DPO) aim to achieve similar results to RLHF but with a
simpler, more stable training process that does not require explicitly training a
separate reward model.

Modern Transformers
WIP

Context Management: The Context Window


One of the most fundamental and defining characteristics of a Transformer model is
its context window. The context window, also known as the context length or
maximum sequence length, is the fixed, maximum number of tokens that the model
can consider at any one time when processing input or generating output.
For example, if a model has a context window of 4096 tokens, it can read and
process a text that is up to 4096 tokens long. When generating a response, it can
"remember" the initial prompt and the text it has generated so far, up to a combined
total of 4096 tokens. Any information beyond this window is effectively lost and
cannot influence the model's output.

The Origin of the Context Window Limit: The context window is not an arbitrary
choice but a direct consequence of the self-attention mechanism's computational
complexity. As we discussed, the time and memory required for the self-attention
calculation scale quadratically with the sequence length n (i.e., O(n2)).

●​ Memory: The attention score matrix, which stores the similarity between every
pair of tokens, has dimensions n x n. For a sequence of 4096 tokens, this matrix
has over 16 million entries. For 32,000 tokens, it's over a billion entries. Storing
this matrix in GPU memory becomes a major bottleneck.
●​ Computation: The number of floating-point operations also grows quadratically,
making the processing of very long sequences prohibitively slow.
Because of these hardware and time constraints, model designers must choose a
fixed maximum sequence length that the model is trained on and can handle during
inference. Early models like the original BERT had a context window of 512 tokens.
The GPT-3 family had a window of 2048 tokens. More recent models have pushed
this limit to 4K, 8K, 32K, 128K, 1M, 2M and even larger sizes, thanks to architectural
innovations and more powerful hardware.

Implications of a Fixed Context Window:

The fixed context window has profound implications for how models are used.
●​ Document Analysis: To analyze a long document that exceeds the context
window (e.g., a book or a lengthy legal contract), one must employ a "chunking"
strategy. The document is broken down into smaller, overlapping chunks that fit
within the window. The model processes each chunk independently, and the
results are then aggregated. This is a workaround, but it prevents the model from
forming a truly holistic understanding of the document, as it can never see the full
context at once.
●​ Extended Conversations: In a long-running chatbot conversation, the beginning
of the conversation will eventually "scroll off" the context window. The model will
forget what was discussed earlier, leading to a loss of conversational coherence.
●​ Code Generation: When writing or analyzing a large codebase, the model can
only see a small fraction of the relevant files at once, making it difficult to
understand complex dependencies and project-wide architecture.
Overcoming the limitations of the fixed, quadratic-cost attention mechanism is one of
the most active and important areas of research in the field of deep learning.
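
The chunking strategy mentioned above for long documents is straightforward to implement. A minimal sketch with arbitrary window and overlap sizes:

def chunk_tokens(tokens, window=4096, overlap=256):
    # Split a long token sequence into overlapping chunks that each fit the context window.
    step = window - overlap
    chunks = []
    for start in range(0, max(len(tokens) - overlap, 1), step):
        chunks.append(tokens[start:start + window])
    return chunks

# Each chunk is processed independently; the per-chunk results are aggregated afterwards,
# which is exactly why the model never sees the document as a whole.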

The Challenge of Long Sequences


The quadratic complexity of the standard self-attention mechanism significantly limits
the ability of Transformer models to handle long sequences, hindering their
application in tasks involving extensive data. This includes summarizing entire books
or research papers, answering questions about lengthy legal or financial documents,
analyzing genomic sequences, maintaining long-term memory in conversational AI,
and processing high-resolution images where each pixel acts as a token.
The brute-force solution of simply using more powerful hardware is not sustainable.
A 10x increase in sequence length requires a 100x increase in memory and
computation for the attention matrix. This has spurred a great deal of research into
more efficient attention mechanisms, often referred to as "X-formers," that aim to
approximate the full attention matrix without explicitly computing it. The goal is to
achieve linear or near-linear complexity (O(n) or O(n log n)) with respect to the
sequence length, while retaining the performance and modeling power of the original
Transformer.

These methods can be broadly categorized into several approaches:


1.​ Fixed Patterns: These methods restrict the attention mechanism so that each
query can only attend to a fixed subset of keys, rather than all of them. For
example, the Longformer uses a combination of a sliding window attention (each
token attends to its local neighbors) and a global attention mechanism (a few
pre-selected tokens, like the [CLS] token, can attend to the entire sequence). The
BigBird model uses a similar combination of windowed, global, and random
attention patterns. These methods can achieve linear complexity but require
careful engineering of the attention pattern.
2.​ Learnable Patterns: Instead of using a fixed, hand-crafted attention pattern,
some methods attempt to learn the optimal sparse attention pattern. The
Reformer model uses locality-sensitive hashing (LSH) to group similar tokens
together, and attention is only computed within these groups. This is an
approximation that works well for many tasks.
3.​ Low-Rank Methods: These methods are based on the assumption that the n x n
attention matrix is "low-rank," meaning it can be approximated by the product of
two smaller, "tall-and-skinny" matrices. The Linformer projects the keys and
values into a lower-dimensional space before the attention computation, reducing
the complexity from O(n²) to O(n).
4.​ Kernel Methods: These methods reframe the attention calculation to avoid
explicitly forming the QK^T matrix. By cleverly reordering the matrix
multiplications, they can compute the attention output in a way that scales
linearly with n. Performer is a prominent example that uses random feature maps
to approximate the softmax attention kernel.
5.​ Recurrence and Convolution: A recent trend has been to re-introduce ideas
from recurrent and convolutional models to handle long sequences in a more
efficient, linear-time fashion. Models like RWKV and state-space models like
Mamba combine the parallelizable training of Transformers with the
constant-memory, linear-time inference of RNNs. They maintain a hidden state
that is updated as the sequence is processed, allowing for theoretically infinite
context length during inference, though their practical ability to use very
long-range information is still an area of active research.
Each of these approaches has its own trade-offs between computational efficiency,
memory usage, and model performance. The choice of which method to use often
depends on the specific requirements of the task at hand.

Techniques for Extending Effective Context


Beyond modifying the core attention mechanism, there are several other techniques
used to manage and extend the effective context that a model can handle.

Sliding Window Attention: This is one of the simplest and most effective
techniques, used in models like the Longformer. Instead of every token attending to
every other token, each token is restricted to attending to a fixed-size window of w
surrounding tokens (e.g., w/2 tokens to the left and w/2 to the right). This immediately
reduces the complexity from O(n²) to O(n⋅w), which is linear in n if w is a fixed
constant. This is very effective for tasks where local context is most important. To
handle more global information, this is often combined with a few "global" tokens
that are allowed to attend to the entire sequence.

[Image Description: A visual representation of different attention patterns. A grid
represents the n x n attention matrix. For "Full Attention," the entire grid is shaded,
indicating all-to-all connections. For "Sliding Window Attention," only a diagonal band
is shaded. For "Dilated Sliding Window," the shaded diagonal band has gaps in it,
showing that attention is paid to nearby tokens and some farther tokens, skipping
intermediate ones.]
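
A sliding-window pattern can be expressed as a mask in the same way as the causal mask shown earlier: positions outside a band of width w around each query are blocked before the softmax. A minimal sketch:

import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Allow position i to attend only to positions within +/- window // 2 of i.
    idx = torch.arange(seq_len)
    distance = (idx[None, :] - idx[:, None]).abs()
    mask = torch.zeros(seq_len, seq_len)
    mask[distance > window // 2] = float("-inf")   # outside the band: zero attention weight
    return mask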

Attention with Recurrence (Transformer-XL): The Transformer-XL model
introduced a way to use context from beyond the fixed window without breaking the
sequence into completely independent segments. During training, a long sequence is
processed in segments. When processing the current segment, the hidden states
computed for the previous segment are cached and fed into the current segment's
processing as an extended context. These cached hidden states are not updated via
backpropagation; they are used in a read-only fashion. This allows information to
flow from one segment to the next, creating an effective context length that is much
longer than the segment length the model is trained on. This technique is particularly
effective for tasks requiring long-range coherence, like language modeling on long
articles.

Retrieval-Augmented Generation (RAG): This is a powerful hybrid approach that
combines a parametric model (the Transformer) with a non-parametric memory (a
large external database of text). When given a prompt, instead of relying solely on the
knowledge stored in its own weights and its limited context window, the model first
performs a retrieval step. It uses the prompt to search a vast corpus (like all of
Wikipedia) and retrieves a small number of documents that are most relevant to the
prompt. These retrieved documents are then concatenated with the original prompt
and fed into the Transformer's context window. The model then generates a response
based on both the original prompt and the rich, factual information provided in the
retrieved documents.

RAG has several key advantages:


●​ Reduces Hallucination: It grounds the model's response in verifiable, external
knowledge, making it less likely to generate factually incorrect information.
●​ Up-to-Date Knowledge: The external knowledge base can be easily updated
without having to retrain the entire multi-billion parameter model.
●​ Transparency: The model can cite its sources, allowing users to verify the
information it provides.
RAG effectively extends the model's "knowledge context" to a massive scale, even if
its "sequence context" window remains fixed. This has become a standard technique
for building reliable and knowledgeable AI systems.
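
At the orchestration level, RAG is little more than "retrieve, concatenate, generate." A schematic sketch, assuming a retriever object and a generative model are already available (both names and their methods are placeholders, not a specific library's API):

def rag_answer(question, retriever, llm, k=3):
    # 1. Retrieval: fetch the k passages most relevant to the question from the external corpus.
    passages = retriever.search(question, top_k=k)
    # 2. Augmentation: concatenate the retrieved evidence with the original prompt.
    context = "\n\n".join(passages)
    prompt = f"Answer using the context below.\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"
    # 3. Generation: the model answers, grounded in the retrieved documents.
    return llm.generate(prompt)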

Prompting and Inference: The Art of Talking to LLMs


Once a large language model is trained, we need a way to interact with it and elicit
the desired behavior. This is done through prompting. A prompt is simply the input
text given to the model, which it uses as a starting point to generate a continuation.
The process of the model generating this continuation is called inference.

The discovery that the behavior of large, pre-trained models could be controlled
purely through the careful design of the input prompt, without any changes to the
model's weights, was a major breakthrough. This paradigm is often called in-context
learning.

Prompting Strategies

The way a prompt is structured can have a dramatic impact on the quality and nature
of the model's output. Several key strategies have emerged:
1.​ Zero-Shot Prompting: This is the simplest form of prompting. The model is given a
description of a task and asked to perform it without any examples.
○​ Prompt: "Translate the following English text to French: 'Hello, how are you?'"
○​ The model is expected to understand the instruction "Translate...to French"
and perform the task directly. The ability of large models to perform tasks in a
zero-shot setting is a remarkable emergent property.
2.​ Few-Shot Prompting: In this strategy, the prompt includes a small number of
examples (typically 1 to 5) of the task being performed. These examples "prime"
the model, showing it the exact format and style of the desired output.
○​ Prompt:​
"English: sea otter -> French: loutre de mer​
English: peppermint -> French: menthe poivrée​
English: cheese -> French:"
○​ By seeing these examples, the model can better infer the pattern and is more
likely to produce the correct output ("fromage"). Few-shot prompting is often
significantly more effective than zero-shot prompting, especially for more
complex or nuanced tasks.
3.​ Chain-of-Thought (CoT) Prompting: This advanced technique aims to draw out
reasoning from the model, a topic that will be explored in depth in subsequent
chapters. Instead of just showing the final answer in the few-shot examples, the
prompt also includes the intermediate steps of reasoning used to arrive at the
answer.
○​ Prompt:​
"Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6
more, how many apples do they have?​
A: The cafeteria started with 23 apples. They used 20, so they had 23 - 20 = 3.
They bought 6 more, so they now have 3 + 6 = 9. The answer is 9.​
Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has
3 tennis balls. How many tennis balls does he have now?​
A:"
○​ By prompting the model to "think step-by-step," it is encouraged to break
down the problem, perform the intermediate calculations, and then state the
final answer. This has been shown to dramatically improve performance on
tasks that require arithmetic, commonsense, and symbolic reasoning. It forces
the model to externalize its reasoning process into the text, which often leads
to more accurate results.
The art of designing effective prompts is known as prompt engineering, and it has
become a critical skill for effectively using and building applications on top of large
language models.

The Inference Process: From Probabilities to Text


When a generative Transformer model is given a prompt, it doesn't just output a
single, deterministic response. Instead, at each step of the generation process, it
produces a probability distribution over its entire vocabulary for what the next token
should be. The process of selecting a token from this distribution and appending it to
the sequence is handled by a decoding algorithm. The choice of decoding
algorithm is crucial as it shapes the characteristics of the generated text.

Let's say the model has processed the input "The best thing about AI is its ability to"
and must now choose the next word. The final layer of the decoder outputs a vector
of logits, which are then converted by a softmax function into probabilities for every
word in the vocabulary.
●​ learn: 30% probability
●​ create: 25% probability
●​ automate: 15% probability
●​ reason: 10% probability
●​ ... and so on for thousands of other words.

How do we pick one?

Common Decoding Algorithms:


1.​ Greedy Search: This is the simplest and most straightforward algorithm. At each
step, it simply selects the token with the highest probability (the "greedy"
choice). In our example, it would choose "learn". It would then append "learn" to
the sequence and feed the new sequence back into the model to generate the
next word.
○​ Pros: Fast and computationally cheap.
○​ Cons: Often leads to repetitive, unnatural, and incoherent text. It can easily
get stuck in loops (e.g., "I am I am I am..."). By committing to the best local
choice at each step, it can miss a much better overall sequence of words. For
example, the sequence "create and innovate" might have a higher overall
probability than "learn and grow," but greedy search would pick "learn" at the
first step and never discover the better alternative.
2.​ Beam Search: This algorithm attempts to mitigate the short-sightedness of
greedy search by keeping track of multiple possible sequences, or "beams," at
each step.
○​ Let's say we use a beam width of k=2. At the first step, instead of just picking
"learn," the algorithm keeps the two most likely tokens: "learn" and "create".
○​ For the next step, it generates continuations for both of these partial
sequences. It might find that the most likely continuations are "learn and" and
"create new".
○​ It then calculates the cumulative probability of all the resulting two-word
sequences and again keeps only the top k (top 2) overall sequences.
○​ This process continues until an end-of-sequence token is generated or a
maximum length is reached. The sequence with the highest final probability is
chosen as the output.
○​ Pros: Produces more fluent and coherent text than greedy search. It is the
standard for many tasks like machine translation.
○​ Cons: Can still lead to repetitive and "safe," generic-sounding text. It is
optimized for high probability, not necessarily for creativity or human-like
quality. It is also more computationally expensive than greedy search.
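
To make the generation loop concrete, here is a minimal greedy-decoding sketch in PyTorch. The `model` here is assumed to be any causal language model that maps a tensor of token IDs to next-token logits; this interface is a simplifying assumption for illustration, not a specific library API.

```python
import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens=20, eos_id=None):
    """Greedy decoding: always append the single most probable next token.

    Assumes model(input_ids) returns logits of shape (batch, seq_len, vocab_size)
    for input_ids of shape (batch, seq_len) -- a hypothetical interface.
    """
    for _ in range(max_new_tokens):
        logits = model(input_ids)                                   # (batch, seq, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # the "greedy" choice
        input_ids = torch.cat([input_ids, next_token], dim=-1)      # append and repeat
        if eos_id is not None and (next_token == eos_id).all():
            break
    return input_ids
```

Beam search replaces the single argmax with bookkeeping over the k highest-scoring partial sequences at each step; in practice, libraries such as Hugging Face Transformers expose greedy, beam, and sampling strategies behind a single `generate()` call.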
Stochastic Decoding for More Human-like Text
While beam search improves over greedy search, both methods are deterministic and
tend to produce text that, while grammatically correct, can feel bland and
predictable. To generate more creative, surprising, and human-like text, stochastic
(randomized) decoding methods are often preferred. These methods introduce
randomness into the token selection process.

Temperature Sampling: This is one of the most common techniques for controlling
the randomness of the output. A temperature parameter (T, typically between 0 and
1) is used to rescale the logits before the softmax function is applied.

P_i = exp(z_i / T) / Σ_j exp(z_j / T)

where z_i is the logit for the i-th token.


●​ Low Temperature (e.g., T=0.2): Dividing the logits by a small number makes the
distribution "sharper" or more "peaked." The probabilities of the most likely
tokens are amplified, while the probabilities of less likely tokens are suppressed.
This makes the output more deterministic, closer to greedy search. The model
becomes more focused and conservative.
●​ High Temperature (e.g., T=1.0 or higher): Dividing by a larger number flattens
the distribution, making it more uniform. The model is more likely to pick less
probable words, leading to more random, surprising, and "creative" outputs.
However, if the temperature is too high, the output can become incoherent and
full of errors.
After applying temperature scaling, the next token is chosen by sampling from this
new probability distribution.
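
As a minimal sketch, temperature scaling can be applied directly to the logits before sampling. Here `logits` is assumed to be a 1-D tensor of next-token logits over the vocabulary (an illustrative assumption, not a specific model's output format).

```python
import torch

def sample_with_temperature(logits, temperature=0.7):
    """Rescale logits by 1/T, convert to probabilities, and draw one token id."""
    probs = torch.softmax(logits / temperature, dim=-1)  # sharper for T < 1, flatter for T > 1
    return torch.multinomial(probs, num_samples=1)       # stochastic draw from the rescaled distribution
```

As T approaches 0 this behaves like greedy search, which is why many serving APIs treat a temperature of 0 as a request for (near-)deterministic output.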

Top-k Sampling: One problem with pure temperature sampling is that it can
sometimes select very unlikely, nonsensical words, especially with a high
temperature. Top-k sampling addresses this by restricting the sampling pool to a
smaller, more plausible set of words.

The algorithm is simple:


1.​ Identify the k most probable tokens from the distribution.
2.​ Redistribute the probability mass among only these k tokens.
3.​ Sample the next token from this reduced set.
For example, if k=40, the model will only consider the 40 most likely words for its next
choice, ignoring the thousands of other words in the vocabulary. This provides a
good balance, preventing bizarre choices while still allowing for variety.
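
A minimal top-k filter, again assuming `logits` is a 1-D tensor of next-token logits, might look like this sketch:

```python
import torch

def top_k_sample(logits, k=40, temperature=1.0):
    """Keep only the k most probable tokens, renormalize, and sample from them."""
    topk_logits, topk_ids = torch.topk(logits, k)              # 1. the k most probable tokens
    probs = torch.softmax(topk_logits / temperature, dim=-1)   # 2. redistribute probability mass
    choice = torch.multinomial(probs, num_samples=1)           # 3. sample from the reduced set
    return topk_ids[choice]
```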

Nucleus Sampling (Top-p Sampling): This is a more sophisticated and often preferred alternative to top-k sampling. Instead of picking a fixed number k of tokens,
nucleus sampling picks a set of tokens whose cumulative probability exceeds a
certain threshold p.

The algorithm works as follows:


1.​ Rank the tokens in the vocabulary from most to least probable.
2.​ Sum their probabilities, starting from the most probable, until the cumulative
probability is just greater than or equal to p.
3.​ The set of tokens included in this sum forms the "nucleus."
4.​ The next token is then sampled from this nucleus.

The key advantage of nucleus sampling is that the size of the sampling pool is
adaptive.
●​ When the model is very certain about the next word (i.e., one word has a very
high probability), the nucleus will be very small, perhaps containing only one or
two words. This makes the model behave more deterministically when it's
confident.
●​ When the model is uncertain and the probability is spread out over many
plausible words, the nucleus will be larger, allowing for more diversity.
This adaptive nature often leads to more robust and higher-quality text generation
than fixed top-k sampling. In practice, many applications use a combination of these
techniques, such as nucleus sampling with a temperature parameter, to finely control
the trade-off between coherence and creativity.
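
The adaptive nucleus can be built by sorting the probabilities, taking a cumulative sum, and cutting off once the running total reaches p. The sketch below makes the same 1-D logits assumption as the earlier examples:

```python
import torch

def nucleus_sample(logits, p=0.9):
    """Sample from the smallest set of tokens whose cumulative probability reaches p."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep every token up to and including the first position where the cumulative sum reaches p.
    cutoff = int(torch.searchsorted(cumulative, torch.tensor(p))) + 1
    nucleus = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()   # renormalize within the nucleus
    choice = torch.multinomial(nucleus, num_samples=1)
    return sorted_ids[choice]
```

In practice the temperature, top-k, and top-p filters are often chained: temperature reshapes the distribution first, and the nucleus (or top-k) cutoff then bounds how far into the tail sampling is allowed to reach.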

The Rise of Decoder-Only Architectures


The architectural shift from the original Transformer's encoder-decoder structure to
decoder-only models marks a significant evolution in the field of large language
models (LLMs). While the initial Transformer, introduced in the seminal "Attention Is All
You Need" paper, showcased the power of self-attention and cross-attention
mechanisms, many of the most prominent contemporary LLMs, including OpenAI's
GPT (Generative Pre-trained Transformer) series, Google's PaLM (Pathways Language
Model), and Meta's LLaMA (Large Language Model Meta AI) family, have embraced
decoder-only architectures.

This design choice is rooted in the profound realization that a highly capable
auto-regressive decoder possesses sufficient power to tackle an incredibly diverse
array of language tasks. By eliminating the encoder component and its associated
cross-attention mechanism, the overall architecture achieves a greater degree of
homogeneity. This simplification not only streamlines the network's design but also
makes it considerably easier to scale these models to truly massive sizes, involving
billions or even trillions of parameters.

Furthermore, the training process for decoder-only models becomes elegantly unified
under a single, straightforward objective: Causal Language Modeling (CLM). This
objective, often simplified to "next-token prediction," involves training the model to
predict the subsequent token in a sequence given all preceding tokens. This approach
has consistently proven to be an exceptionally effective and highly scalable paradigm
for the creation of general-purpose language models.
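
As a sketch of the CLM objective: the training labels are simply the input sequence shifted one position to the left, and the loss is cross-entropy over the vocabulary at every position. The shapes below are illustrative assumptions for a generic decoder-only model.

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits, input_ids):
    """Next-token prediction loss.

    logits:    (batch, seq_len, vocab_size) produced by a decoder-only model
    input_ids: (batch, seq_len) token ids that were fed to the model
    """
    # Predict token t+1 from positions <= t: drop the last logit, drop the first label.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),   # (batch * (seq_len - 1), vocab_size)
        shift_labels.view(-1),                          # (batch * (seq_len - 1),)
    )
```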

The inherent ability of these decoder-only models to treat virtually any language task
as a sequence-to-sequence problem is a cornerstone of their success. In this
framework, both the input prompt and the desired completion are considered integral
parts of the same continuous sequence. This unified sequence representation directly
enables the powerful in-context learning capabilities that are a defining characteristic
of modern foundation models. Through in-context learning, these models can infer
task requirements and generate appropriate responses based solely on the provided
prompt, without requiring explicit fine-tuning for each new task. This paradigm shift
has unlocked unprecedented versatility and performance in natural language
understanding and generation.

Mixture-of-Experts (MoE) Models


A pivotal hurdle in scaling LLMs is the enormous computational expense tied to
activating every parameter for every single input token. To overcome this,
Mixture-of-Experts (MoE) has emerged as a groundbreaking architectural
innovation. This technique introduces conditional, or sparse, computation,
fundamentally altering how models process information.

Within an MoE layer, the traditional dense feed-forward network is replaced by a collection of parallel "expert" networks. Each expert is a smaller, specialized neural
network. Crucially, an additional component, a small "router" network, is incorporated.
For each input token, the router dynamically assesses and selects a limited subset of
these experts – often just two – to process that specific token. All other experts
remain inactive for that particular computation.

This sparse activation offers a significant advantage: it enables the creation of models
with an immense total number of parameters, far exceeding what would be
computationally feasible with a dense architecture. Yet, during inference, the
computational cost remains remarkably low, comparable to that of a much smaller,
dense model. This decoupling of model capacity from active computation is a core
strength of MoE.
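
A toy MoE layer with top-2 routing is sketched below. The class name, sizes, and routing details are illustrative assumptions; production systems add load-balancing losses, capacity limits, and expert parallelism.

```python
import torch
import torch.nn as nn

class TinyMoELayer(nn.Module):
    """Toy Mixture-of-Experts feed-forward layer with top-2 routing (illustrative only)."""

    def __init__(self, d_model=512, d_ff=2048, num_experts=8, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])
        self.router = nn.Linear(d_model, num_experts)    # small gating network
        self.top_k = top_k

    def forward(self, x):                                # x: (num_tokens, d_model)
        gate_logits = self.router(x)                     # (num_tokens, num_experts)
        weights, expert_ids = torch.topk(gate_logits, self.top_k, dim=-1)
        weights = torch.softmax(weights, dim=-1)         # normalize over the chosen experts only
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = expert_ids[:, slot] == e          # tokens routed to expert e in this slot
                if mask.any():
                    out[mask] += weights[mask, slot].unsqueeze(-1) * expert(x[mask])
        return out
```

Only the selected experts run for each token, so the parameter count grows with `num_experts` while the per-token compute stays close to that of a single expert.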

Pioneering research and development have demonstrated the effectiveness of MoE. Notable examples include Google's Switch Transformer, one of the earliest and most
influential applications of this paradigm. More recently, open-source initiatives have
further democratized MoE, with powerful models like Mistral AI's Mixtral 8x7B and
Databricks' DBRX showcasing its capabilities in a wide range of applications.

While MoE facilitates greater model capacity and promotes specialization within the
network by allowing different experts to learn distinct patterns or handle specific
types of data, it also introduces certain challenges. One primary concern is ensuring
that the computational load is balanced across the various experts. An imbalanced
load can lead to some experts being underutilized while others are overloaded,
potentially hindering overall efficiency and performance. Addressing this load
balancing issue is an active area of research and development within the MoE field.

Attention Optimisations and KV Caching


Optimizing inference for large language models (LLMs) to achieve both speed and
cost-effectiveness presents a significant engineering challenge. A fundamental
optimization technique for auto-regressive generation, where tokens are predicted
sequentially, is KV Caching. During this process, the Key (K) and Value (V) vectors
corresponding to all previously generated tokens in the context are stored in
high-speed GPU memory. This caching mechanism prevents redundant
recomputations of these vectors for each new token generated, leading to substantial
speed improvements.
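
A minimal sketch of the idea: store K and V per layer and append one position at a time instead of recomputing them for the whole prefix. The shapes and the example configuration in the comment are assumptions for illustration, not those of any particular model.

```python
import torch

class KVCache:
    """Append-only cache of Key/Value tensors for one attention layer during decoding."""

    def __init__(self):
        self.k = None   # (batch, num_heads, seq_so_far, head_dim)
        self.v = None

    def update(self, k_new, v_new):
        """k_new / v_new: (batch, num_heads, 1, head_dim) for the newly generated position."""
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

# Rough memory estimate for an assumed configuration (32 layers, 32 heads, head_dim 128, fp16):
# 32 * 32 * 128 * 2 (K and V) * 2 bytes ≈ 0.5 MiB per token,
# i.e. roughly 16 GiB of cache for a 32k-token context.
```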

However, a key limitation of KV Caching is that the size of this cache grows linearly
with the sequence length. For very long contexts, this can lead to an enormous
memory footprint, often exceeding the available on-chip memory of GPUs. This
memory constraint becomes a bottleneck, especially for applications requiring
extended conversational history or processing lengthy documents.

To mitigate this memory overhead, advanced attention variants have been developed.
One such variant is Multi-Query Attention (MQA), employed in models like Google's
PaLM. In MQA, all query heads within the attention mechanism share a single K and V
projection. This design drastically reduces the number of K and V vectors that need to
be cached, thereby shrinking the cache's memory footprint. While highly efficient in
terms of memory, MQA can sometimes lead to a slight reduction in model quality
compared to traditional multi-head attention, as the shared projections might limit the
model's ability to attend to different parts of the input with distinct representations.

A popular and effective compromise between memory efficiency and model quality is
Grouped-Query Attention (GQA). Models such as Meta's Llama 2 utilize GQA. In
this approach, query heads are divided into groups, and each group shares a common
K and V projection. This method offers a significant reduction in the cache's memory
footprint compared to the standard multi-head attention, albeit less aggressively than
MQA. Crucially, GQA generally retains higher quality than MQA, striking a better
balance between computational efficiency and model performance. This makes GQA
particularly advantageous for enabling practical and high-quality long-context
inference, allowing LLMs to process and generate much longer sequences of text
effectively. The continuous innovation in attention mechanisms and caching strategies
remains crucial for pushing the boundaries of what large language models can
achieve in terms of speed, scale, and performance.
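
The practical difference between standard multi-head attention, GQA, and MQA comes down to how many distinct K/V heads must be cached. The back-of-the-envelope calculation below uses an assumed configuration (32 layers, head dimension 128, fp16, 32k-token context) purely for illustration:

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """Bytes needed to cache K and V for one sequence (fp16 by default)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value

cfg = dict(seq_len=32_768, n_layers=32, head_dim=128)
print("MHA, 32 KV heads:", kv_cache_bytes(n_kv_heads=32, **cfg) / 2**30, "GiB")   # 16.0
print("GQA,  8 KV heads:", kv_cache_bytes(n_kv_heads=8, **cfg) / 2**30, "GiB")    #  4.0
print("MQA,  1 KV head :", kv_cache_bytes(n_kv_heads=1, **cfg) / 2**30, "GiB")    #  0.5
```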

Multimodality: Beyond Text


The next frontier in artificial intelligence hinges on its capacity to interpret and
synthesize information across diverse sensory inputs simultaneously. Multimodal
models represent a pivotal advancement, extending the foundational Transformer
architecture beyond its original text-centric design to seamlessly integrate and
process modalities such as images, audio, video, and other complex data types.

This breakthrough is fundamentally enabled by specialized encoders, often referred to as "tokenizers," which are meticulously engineered to transform disparate forms of
data into a unified, common vector representation. This standardized format is then
readily digestible by the Transformer's attention mechanisms. For instance, a Vision
Transformer (ViT) exemplifies this approach by segmenting an image into discrete
"patch embeddings." These embeddings are subsequently treated as a sequence of
tokens, analogous to text tokens, allowing for their joint processing within the same
model. This innovative methodology empowers a singular AI model to undertake
sophisticated cross-modal reasoning tasks, wherein it can deduce relationships and
draw insights by correlating information from different modalities.
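
A minimal sketch of how an image becomes a token sequence: split it into fixed-size patches, flatten each patch, and project it into the model dimension. The image size, patch size, and model dimension below are illustrative assumptions.

```python
import torch
import torch.nn as nn

def patchify(images, patch_size=16):
    """(batch, 3, H, W) -> (batch, num_patches, 3 * patch_size * patch_size)."""
    b, c, h, w = images.shape
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    # patches: (b, c, H/p, W/p, p, p) -> one flattened vector per patch
    return patches.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, c * patch_size * patch_size)

images = torch.randn(2, 3, 224, 224)
tokens = patchify(images)                    # (2, 196, 768): 14 x 14 patches per image
project = nn.Linear(3 * 16 * 16, 512)        # the "patch embedding" into the model dimension
patch_embeddings = project(tokens)           # now treated like a sequence of text tokens
```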

Leading the charge in this multimodal revolution are flagship models that demonstrate
an unparalleled ability to perceive and interpret the world more holistically. OpenAI's
GPT-4V(ision) showcases remarkable capabilities in analyzing visual inputs, ranging
from photographs to intricate diagrams. Similarly, Google's Gemini family of models
stands out as a testament to native multimodality, having been conceived and
developed from its inception with the inherent capacity to process and understand
information across various modalities simultaneously. This foundational design allows
Gemini models to intricately analyze user-provided images, elaborate diagrams, and
dynamic videos, thereby fostering a more comprehensive and human-like
understanding of context, content, and the intricate relationships within our complex
world. The convergence of these capabilities is propelling AI towards a future where
systems can interact with and comprehend the world with a richness and depth
previously unattainable.

The Problem of Hallucination


A fundamental challenge with large language models (LLMs) is their inherent
tendency to hallucinate. This phenomenon refers to the generation of fluent,
confident-sounding statements that, despite their convincing delivery, are factually
incorrect or nonsensical. This behavior is not a flaw in their design but rather a direct
byproduct of their core training objective: to predict statistically plausible sequences
of text based on the vast datasets they were trained on. Their primary goal is to
generate grammatically correct and contextually relevant language, not to verify the
truthfulness or accuracy of the information they present. This distinction is crucial, as
it explains why an LLM can confidently present false information as if it were fact.

Mitigating hallucination is a critical and ongoing area of research, essential for building trustworthy and reliable artificial intelligence systems. Various strategies are being developed and implemented to address this challenge.
One of the most prominent and effective strategies is Retrieval-Augmented
Generation (RAG). RAG addresses the hallucination problem by providing the LLM
with an "open book" of external, verifiable knowledge to consult during the generation
process. Instead of relying solely on its fallible parametric memory (the knowledge
encoded within its neural network during training), the model first retrieves relevant
information from a designated knowledge base (e.g., a database, a set of documents,
or the internet). This retrieved information then serves as a factual ground for the LLM
to generate its output, significantly reducing its propensity to hallucinate. This
approach enhances the factual accuracy and trustworthiness of the generated text by
grounding it in verifiable external data.
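
A minimal retrieve-then-generate sketch is shown below. The `retriever.search` and `llm.generate` calls are hypothetical placeholders for whatever vector store and model API an application actually uses; they are not a specific library.

```python
def answer_with_rag(question, retriever, llm, top_k=3):
    """Ground the model's answer in retrieved passages rather than parametric memory alone."""
    passages = retriever.search(question, top_k=top_k)        # hypothetical vector-store lookup
    context = "\n\n".join(f"[{i + 1}] {p}" for i, p in enumerate(passages))
    prompt = (
        "Answer the question using ONLY the sources below. "
        "Cite sources by number, and say 'I don't know' if they are insufficient.\n\n"
        f"Sources:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
    return llm.generate(prompt)                               # hypothetical LLM call
```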

Beyond RAG, other promising approaches are being explored:

●​ Specialized Fine-Tuning on Fact-Checking Datasets: This involves further training LLMs on datasets specifically designed for fact-checking. These
datasets contain statements labeled as true or false, along with evidence
supporting the labels. By fine-tuning on such data, models can learn to identify
and correct factual inaccuracies, improving their ability to discriminate between
truthful and untruthful information. This process aims to instill a greater sense
of factual awareness in the model's output.
●​ Developing Methods for Self-Critique: This category encompasses
techniques where the LLM is trained to evaluate and revise its own outputs.
One notable example is the Constitutional AI approach, famously used to train
Anthropic's Claude models. This method involves teaching a model to revise its
own outputs to be more factual, helpful, and harmless based on a set of
predefined guiding principles or "constitution." Crucially, this is achieved
without direct human labels for every undesirable output. Instead, the model
learns to identify and correct problematic generations by comparing them
against these principles. This iterative self-correction mechanism empowers
the model to refine its own responses, leading to more reliable and ethically
aligned outputs. The principles often encode concepts like avoiding harmful
content, ensuring factual accuracy, and refusing to engage in illegal or
unethical activities. By internalizing these rules, the model can engage in a form
of automated alignment, reducing the need for extensive human supervision in
the fine-tuning process.

These diverse strategies collectively aim to enhance the reliability and trustworthiness
of LLMs, moving them beyond mere statistical language generation towards more
factually grounded and ethically sound AI systems. The ultimate goal is to enable
LLMs to provide information that is not only fluent and coherent but also
demonstrably accurate and safe.

Interpretability and Mechanistic Interpretability


As AI models, particularly large language models, become increasingly sophisticated
and pervasive across various applications, the imperative to understand their internal
mechanisms grows exponentially. This understanding is paramount for ensuring
safety, facilitating effective debugging, and fostering user trust. The field dedicated to
this profound challenge is interpretability.

Historically, initial endeavors into model interpretability often centered on analyzing attention weights within neural networks. While these efforts provided some insights
into what parts of the input a model was "attending" to, they often offered a
high-level, correlational view rather than a precise, causal explanation of the model's
decision-making process.

A more ambitious and scientifically rigorous subfield has emerged, known as mechanistic interpretability. This discipline takes a profound dive into the model's
architecture, aiming to meticulously reverse-engineer the exact algorithms and
computations learned and executed by the model's vast array of parameters. The goal
is not merely to correlate observable activations with model behavior, but to
deconstruct the model into its constituent computational pathways. This involves
identifying specific "circuits" of neurons—interconnected groups of neurons that
collaborate to perform discrete, identifiable behaviors. Examples of such behaviors
include sophisticated linguistic tasks like discerning sarcasm, or fundamental
cognitive operations such as indirect object identification.

Leading research institutions and AI labs, notably Anthropic, have made significant
strides in this domain. Their work has involved systematically dissecting smaller
models to pinpoint and characterize these intricate neural circuits. These
breakthroughs represent critical steps towards demystifying the "black box" nature of
complex AI systems.

The ultimate aspiration of mechanistic interpretability is to transcend mere correlation and achieve a complete, human-understandable description of the underlying
algorithm that the model has internally constructed and learned. This means moving
beyond an opaque, input-output understanding to a precise, causal explanation of
why and how the model behaves in a certain way. Such a comprehensive
understanding would revolutionize our ability to build more reliable, robust, and
transparent AI systems, paving the way for truly trustworthy artificial intelligence. It
would also enable more effective debugging, allowing developers to identify and
rectify biases or errors at a foundational level, rather than just observing their
manifestations.

The Evolving Story of Scaling Laws: From Training to "Thinking"


The development of modern large language models (LLMs) has been guided by a
powerful set of empirical principles known as scaling laws. These are not laws of
physics, but predictable relationships that describe how a model's capabilities
improve as key resources are increased. Initially, these laws focused exclusively on the
training phase, establishing a roadmap for building bigger and better models. More
recently, this understanding has evolved to include the inference phase, revealing that
a model's performance isn't fixed after training but can be significantly enhanced by
giving it more time to "think."

The Foundational Laws of Training

The first major breakthrough in this area came from researchers at OpenAI (Kaplan et
al., 2020), who discovered that a model's performance, measured by its predictive
error (loss), improves in a predictable power-law relationship as three factors are
scaled up:

1.​ Model Size (N): The number of parameters in the model.
2.​ Dataset Size (D): The amount of text data used for training.
3.​ Compute (C): The total processing power used during training.

This insight was revolutionary because it allowed organizations to forecast the performance of a massive, un-trained model by extrapolating from smaller
experiments. It provided the confidence to invest the immense resources required to
build foundational models like GPT-3.
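
In schematic form, each relationship is a power law in the resource being scaled; the constants and exponents are fitted empirically, and the expressions below show only the functional form:

L(N) ≈ (N_c / N)^α_N,  L(D) ≈ (D_c / D)^α_D,  L(C) ≈ (C_c / C)^α_C

where L is the model's loss and each relation holds when the other two resources are not the limiting factor.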

A crucial refinement to these laws came from DeepMind's 2022 Chinchilla paper.
Researchers there addressed a key question: for a fixed compute budget, what is the
optimal balance between model size and data size? Their findings recalibrated the
entire field. They demonstrated that for optimal performance, model size and dataset
size must be scaled in roughly equal proportion.

This revealed that many earlier models were "undertrained"—they were too large for
the amount of data they had been fed. The Chinchilla paper proved that a smaller
model (70B parameters) trained on a massive dataset could outperform a much larger
model (280B parameters) trained on less data. This established that data, not just
model size, was a critical bottleneck.
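
A frequently cited rule of thumb distilled from this analysis is roughly 20 training tokens per parameter at the compute-optimal point. The snippet below is a back-of-the-envelope check of that approximate ratio against the Chinchilla configuration, not an exact prescription:

```python
params = 70e9                 # Chinchilla's parameter count
tokens_per_param = 20         # approximate compute-optimal ratio
optimal_tokens = params * tokens_per_param
print(f"{optimal_tokens / 1e12:.1f}T tokens")   # ~1.4T tokens, in line with Chinchilla's training corpus
```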

The New Frontier: Test-Time Compute Scaling

While the training laws were transformative, they treated a model's performance as
static once training was complete. The newer test compute scaling law challenges
this assumption. It describes how a model's performance can be dramatically
improved by allocating more computational power at the moment of
inference—that is, when it's actually generating an answer to a prompt.

The core idea is analogous to human problem-solving:

●​ Low Test Compute: A quick, gut-reaction answer. It's fast but more prone to
error.
●​ High Test Compute: A slow, deliberate process of reasoning, checking work,
and exploring options before giving a final answer. It takes more effort but is far
more reliable.

LLMs can leverage several strategies to "think longer" and scale their test-time
compute:

●​ Best-of-N Sampling: The model generates multiple (N) different answers, and
a verifier (or a majority vote) selects the best one.
●​ Chain-of-Thought (CoT) & Self-Refinement: The model first generates a
step-by-step reasoning process. It can then be prompted to critique its own
logic and iteratively correct its mistakes.
●​ Tree Search: A more advanced method where the model explores multiple
divergent lines of reasoning simultaneously, dedicating more compute to the
most promising paths.
The Unified View: A New, More Complex Trade-Off

The emergence of test compute scaling adds a new dimension to the LLM
optimization puzzle. The guiding principle is no longer just about balancing model size
and training data. It's now understood that there is a crucial trade-off between a
model's size and its inference-time "thinking" budget.

The key takeaway is that a smaller, more efficient model given more compute at test
time can often outperform a much larger model that provides a quick, single-shot
answer. This has profound implications, suggesting that the path to more capable AI
lies not just in building ever-larger models, but also in developing more sophisticated
reasoning techniques that allow models of all sizes to use their knowledge more
effectively.

Alternative Architectures
WIP

State-Space Models (Mamba and beyond)


State-Space Models (SSMs) as a Transformer Alternative: The Rise of Mamba

The pursuit of more efficient and scalable architectures in deep learning has led to a
re-examination of sequential processing mechanisms, particularly in light of the
quadratic complexity bottleneck inherent in the self-attention mechanism of
Transformers. Among the most promising alternatives to have emerged are State-Space Models (SSMs), with Mamba standing out as a notable advancement. These models
draw inspiration from classical control systems and offer a compelling solution to the
challenges posed by extremely long sequences.

Core Principles of SSMs: At their heart, SSMs process sequences in a recurrent fashion, maintaining a hidden "state" that evolves over time based on the input. This recurrent nature inherently allows them to scale linearly (O(n)) with sequence length, a significant advantage over the O(n²) complexity of Transformers. However, unlike
traditional Recurrent Neural Networks (RNNs) which are often slow due to their strictly
sequential processing during training, SSMs can be ingeniously formulated to facilitate
parallel training, much like a convolutional network. This crucial distinction allows
them to leverage the computational efficiency of modern hardware during the
learning phase while retaining the benefits of recurrent processing for inference.
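
At its simplest, the recurrence a linear state-space layer computes per time step looks like the toy scan below; Mamba's key change is making these (here fixed) matrices depend on the input. This is a naive loop for intuition only, not the hardware-efficient parallel form real implementations use.

```python
import torch

def ssm_scan(x, A, B, C):
    """Toy linear state-space recurrence over a sequence.

    x: (seq_len, d_in) inputs
    A: (d_state, d_state), B: (d_state, d_in), C: (d_out, d_state)
    """
    h = torch.zeros(A.shape[0])          # hidden state: constant memory w.r.t. sequence length
    outputs = []
    for x_t in x:                        # recurrent form, as used at inference time
        h = A @ h + B @ x_t              # state update
        outputs.append(C @ h)            # readout
    return torch.stack(outputs)          # (seq_len, d_out)
```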

Mamba's Key Innovation: Content-Aware Selection: While previous linear-time models often struggled to match the performance of attention mechanisms on tasks
requiring nuanced, content-dependent interactions, Mamba introduces a pivotal
innovation: a content-aware selection mechanism. This mechanism enables the
model's internal parameters (specifically, the transition matrices and input/output
mappings of the state-space) to dynamically change based on the current input. This
is a radical departure from static, fixed parameters in many recurrent architectures.
The ability to "selectively focus on or ignore information along the sequence"
empowers Mamba to adapt its internal representations to the evolving context of the
input, effectively addressing the critical weakness of prior linear-time models. This
dynamic adaptation allows Mamba to mimic, to some extent, the selective information
processing capabilities that make attention so powerful, but within a linear-time
framework.

The Best of Both Worlds: This unique combination of features positions SSMs like
Mamba as highly attractive alternatives to Transformers. They offer:

●​ Efficient, Parallelizable Training: By formulating the recurrent process in a way that allows for parallel computation, SSMs can be trained efficiently on
modern accelerators, akin to the training paradigm of Transformers. This
overcomes a major hurdle that historically limited the practical applicability of
many linear-time recurrent models.
●​ Fast, Constant-Memory Inference: During inference, SSMs operate in a truly
recurrent manner. This means that the memory consumption remains constant
regardless of the sequence length, as the model only needs to store its current
hidden state. This is a significant advantage over Transformers, where the
attention mechanism's memory footprint grows quadratically with sequence
length, making them computationally intensive for very long inputs.

Applications and Future Directions: The inherent advantages of SSMs make them
exceptionally well-suited for tasks involving extremely long sequences. Prime
examples include:

●​ Genomics: Analyzing vast DNA and RNA sequences, where the relationships
between distant elements can be crucial.
●​ High-resolution Time-Series Analysis: Processing extensive streams of
sensor data, financial fluctuations, or environmental readings, where capturing
long-range dependencies is vital.
●​ Audio and Video Processing: Handling long temporal dependencies in raw
audio waveforms or video frames.

The development of Mamba and other advanced SSMs signifies a significant step
forward in building more efficient and scalable deep learning models. They offer a
promising path to overcome the computational limitations of current state-of-the-art
architectures, opening up new possibilities for tackling complex problems in domains
characterized by vast and intricate sequential data. Further research will likely explore
more sophisticated content-aware mechanisms, hybrid architectures combining SSMs
with other modules, and broader applications across various sequence modeling
tasks.

Diffusion Transformers (DiT)


Diffusion models have become the state-of-the-art for high-fidelity image generation,
working by progressively denoising a variable from pure noise. While the U-Net
architecture has traditionally been the backbone of these models, recent work has
shown that replacing it with a Diffusion Transformer (DiT) yields superior
performance and scalability. A DiT operates on a sequence of image patches, or
"tokens," much like an NLP Transformer operates on a sequence of word tokens. This
allows the model to leverage the self-attention mechanism to better model
long-range, global relationships between different parts of the image, which is crucial
for generating coherent and globally consistent scenes. The original DiT paper from
Berkeley AI Research established the architecture's effectiveness, and the power of
this approach for video generation has been spectacularly demonstrated by models
like OpenAI's Sora, which uses a Transformer architecture to generate video of
unprecedented quality and duration.

The Path to Artificial General Intelligence (AGI)


The burgeoning capabilities of large-scale Transformer models have ignited
considerable debate regarding their potential role in the pursuit of Artificial General
Intelligence (AGI). A pivotal question at the heart of this discussion is whether the
current methodology of simply scaling up these models will ultimately be sufficient to
achieve AGI, or if more profound, fundamentally new architectural breakthroughs are
indispensable. While contemporary Large Language Models (LLMs) undeniably exhibit
impressive "sparks" of general reasoning, they conspicuously lack several critical
facets of human intelligence. These deficiencies include, but are not limited to, a
robust causal understanding, the capacity for sophisticated long-term planning, and
the crucial ability to learn effectively from rich, interactive feedback received directly
from the physical world, often referred to as embodied learning.

One school of thought strongly advocates that the continued scaling of these models,
in terms of parameters, data, and computational resources, will eventually bridge the
gap to AGI. Proponents of this view often point to emergent behaviors observed in
larger models that were not present in their smaller counterparts. Conversely, another
significant camp firmly contends that Transformers, by their very nature, are
fundamentally limited as passive pattern-matching systems. They argue that these
models excel at statistical correlations but lack the underlying mechanisms necessary
for true understanding, reasoning, and agency.

A highly probable and increasingly favored path forward involves the creation of
sophisticated hybrid systems. These innovative architectures would strategically
combine the potent perceptual and linguistic abilities inherent in Transformers – their
capacity for processing vast amounts of text and generating coherent language –
with other distinct architectures. These complementary architectures would be
specifically dedicated to functionalities such as world modeling (creating internal
representations of the environment), symbolic reasoning (manipulating abstract
symbols and rules), or agency (the ability to act autonomously and intentionally within
an environment). Such a synergistic approach aims to harness the strengths of
different computational paradigms, potentially overcoming the individual limitations of
each and accelerating progress towards the elusive goal of Artificial General
Intelligence.

Handson
WIP: a full working Transformer implementation in PyTorch or JAX.

Key Takeaways
This section summarizes the most critical concepts covered in the document,
providing a high-level review of the Transformer architecture and its ecosystem.

On Architecture & Theory:

●​ Attention is the Core: The Transformer's key innovation is self-attention, which models long-range dependencies and enables massive parallelization.
●​ Encoder-Decoder Structure: The original Transformer consists of an encoder
to process input and a decoder to generate output. Modern generative models
often use a decoder-only architecture for simplicity and scalability.
●​ Multi-Head Attention is Key: Performing several attention calculations in
parallel allows the model to focus on different types of relationships
simultaneously.
●​ Position Matters: Positional Encodings are added to the input to provide the
model with information about the order of tokens.
●​ Quadratic Bottleneck: The primary limitation of standard attention is its O(n²)
complexity with respect to sequence length n, making it challenging to process
very long sequences.
●​ The Scaling Laws are a Guiding Principle: Empirical laws show that model
performance predictably improves as model size, dataset size, and compute
are scaled up, which has driven the trend towards building ever-larger models.

On Training and Adaptation:

●​ Pre-training is Foundational: Models learn general-purpose knowledge through self-supervised pre-training on vast amounts of unlabeled text.
●​ Fine-tuning Specializes: A pre-trained base model is adapted for specific
tasks through fine-tuning on smaller, labeled datasets.
●​ PEFT for Efficiency: Parameter-Efficient Fine-Tuning (PEFT) methods like
LoRA allow for the adaptation of massive models with very few trainable
parameters.
●​ Alignment for Safety: Techniques like Reinforcement Learning from Human
Feedback (RLHF) are used to steer the model's behavior to be more helpful,
harmless, and honest.

On Context and Interaction:

●​ The Context Window is a Limit: The context window is the fixed maximum
number of tokens a model can process at once.
●​ Extending Context is an Active Research Area: Techniques like sliding
window attention, recurrence (Transformer-XL), and Retrieval-Augmented
Generation (RAG) are used to mitigate the limitations of the fixed context
window.
●​ Prompting is Programming: The behavior of LLMs is controlled through
carefully crafted input prompts, with Chain-of-Thought (CoT) prompting being
a key technique for improving reasoning.
●​ Decoding Shapes the Output: The algorithm used for inference (e.g., greedy
search, beam search, nucleus sampling) determines the characteristics of the
generated text, managing the trade-off between coherence and creativity.

On the Future:

●​ Architectural Evolution is Constant: The field is rapidly exploring alternatives and enhancements to the original Transformer, including efficient
Mixture-of-Experts (MoE) models and linear-time State-Space Models
(Mamba).
●​ Multimodality is the Next Frontier: Transformers are being extended beyond
text to process images, audio, and video, moving towards a more holistic
understanding of information.
●​ Truthfulness and Interpretability are Grand Challenges: Overcoming model
hallucination and achieving true mechanistic interpretability are critical for
building safe and reliable AI systems.
●​ Hardware and Software Co-design is Crucial: Advancements in AI are deeply
intertwined with innovations in specialized hardware (ASICs) and efficient
inference techniques like KV caching.

Conclusion
The Transformer architecture, detailed in "Attention is All You Need," marked a
significant shift in machine learning. Its use of parallel processing through
self-attention, replacing the sequential nature of recurrent networks, has
revolutionized our ability to model language and other sequential data. This powerful
and scalable architecture, combined with the paradigms of self-supervised
pre-training and transfer learning, has ushered in the era of Large Language Models
(LLMs), which exhibit remarkable and often emergent capabilities. We have explored
the model's fundamental components, its lifecycle from pre-training to alignment, and
the critical limitations imposed by its quadratic attention mechanism.

The journey, however, is far from over. The research landscape remains intensely
dynamic, pushing beyond the original design to address its core weaknesses and
unlock new frontiers. We are witnessing a Cambrian explosion of architectural
diversity, from computationally efficient Mixture-of-Experts and linear-time
State-Space Models to the application of Transformers in entirely new domains like
diffusion-based image generation. The grand challenges of hallucination and
interpretability are now central research efforts, as we strive to build models that are
not only capable but also truthful and understandable. This progress is inextricably
linked to co-evolving hardware and sophisticated inference optimizations like KV
caching, which make these massive models practical.

Interacting with these models is a newly emerging skill, where prompt engineering and
the selection of decoding strategies define the boundary of what is possible. The
progression from a simple zero-shot prompt to complex chain-of-thought reasoning
underscores that we are not merely using a tool but learning to collaborate with a new
form of intelligence. Whether the Transformer proves to be the ultimate architecture
on the path to AGI or a crucial stepping stone to future discoveries, its invention will
be remembered as a pivotal moment—the point when we truly understood that, for
building intelligence at scale, attention is a magnificent start. The upcoming chapter
will focus on the new frontier of Large Models, Reasoning, and the techniques
enabling these advancements.

Glossary of Key Terms


●​ Attention: A mechanism that allows a neural network to weigh the importance
of different parts of an input sequence.
●​ Auto-regressive: A model that generates a sequence one token at a time,
where each new token is conditioned on the previously generated tokens.
●​ Context Window: The fixed maximum number of tokens a Transformer model
can process at once.
●​ Cross-Attention: The attention mechanism in the decoder that attends to the
output of the encoder, linking the input and output sequences.
●​ Decoder: The part of the Transformer that generates the output sequence.
●​ Encoder: The part of the Transformer that processes the input sequence to
create a contextualized representation.
●​ Fine-Tuning: The process of adapting a pre-trained model to a specific task
using a smaller, labeled dataset.
●​ Hallucination: The tendency of a language model to generate confident but
factually incorrect or nonsensical information.
●​ In-Context Learning: The ability of a large language model to perform a task
based on a few examples provided in its prompt, without any weight updates.
●​ LLM (Large Language Model): A very large neural network, typically a
Transformer, trained on vast amounts of text data.
●​ Masked Language Modeling (MLM): A self-supervised pre-training objective
where the model learns to predict masked tokens in a sequence.
●​ Multi-Head Attention: A mechanism that runs multiple attention calculations
in parallel to capture different types of relationships in the data.
●​ PEFT (Parameter-Efficient Fine-Tuning): A set of techniques (like LoRA) for
adapting a large model by training only a small fraction of its parameters.
●​ Positional Encoding: Information about the position of tokens in a sequence
that is added to the input embeddings.
●​ Pre-training: The initial, computationally intensive phase of training a model on
a massive, unlabeled dataset.
●​ Prompt: The input text given to a language model to elicit a response.
●​ RAG (Retrieval-Augmented Generation): A technique that combines a
language model with an external knowledge base to improve the factuality of
its outputs.
●​ RLHF (Reinforcement Learning from Human Feedback): A technique for
aligning a model with human preferences by using a reward model trained on
human-ranked responses.
●​ Self-Attention: An attention mechanism where a sequence processes itself,
allowing every token to attend to every other token.
●​ Token: A unit of text, which can be a word, a subword, or a character, that is
used as input to the model.
●​ Transfer Learning: A machine learning paradigm where a model trained on
one task is repurposed for a second, related task.
Reasoning Engines
Introduction
By mid-2025, the field of artificial intelligence is undergoing a profound
transformation. The prevailing focus on generative Large Language Models (LLMs) is
giving way to a more advanced paradigm: deliberative Large Reasoning Models
(LRMs). This evolution marks a significant leap from systems adept at pattern
recognition and token prediction to architectures engineered for complex, multi-step
problem-solving and logical deduction. A reasoning model is fundamentally an LLM
that has been specialized to solve problems requiring multiple intermediate steps,
such as intricate puzzles, advanced mathematics, and complex coding challenges. A
key characteristic of these models is their ability to produce a transparent and
verifiable "thought process," articulating the intermediate steps taken to reach a final
answer.

However, these powerful models are not a universal solution and should be applied
selectively. Their primary strength lies in solving complex problems that benefit from
step-by-step analysis. Conversely, they are computationally more expensive and
slower to run, tend to be verbose in their outputs, and can be inefficient or even prone
to errors from "overthinking" simple problems for which standard LLMs are better
suited.

The competitive hierarchy as of September 2025 is led by proprietary models from industry titans (see lmarena.com for public benchmarks). OpenAI's GPT-5 and its
'o-series' have set new benchmarks in general-purpose reasoning, while Google's
Gemini 2.5 Pro excels as a multimodal savant with a massive context window.
Anthropic's Claude 4 Opus specializes in long-horizon agentic tasks, and xAI's Grok 3
integrates real-time information with transparent reasoning. Concurrently, the
open-source community, with innovators like Meta's Llama 4 and DeepSeek's R1
models, has become a formidable force, pushing the boundaries of scale, efficiency,
and accessibility.

This technological revolution is underpinned by a convergence of key techniques and architectural choices. The development of LRMs is largely driven by four main
approaches. The first, Inference-Time Scaling, enhances reasoning without
retraining the model by dedicating more computational resources during response
generation. This method, which includes techniques like Chain-of-Thought (CoT)
prompting and allowing models to "think" longer on difficult problems (Test-Time
Computation), improves effectiveness at the cost of speed and expense.

A second, more novel approach is Pure Reinforcement Learning (RL), where a base
model learns reasoning as an emergent behavior, providing valuable research insights.

The current gold standard for building high-performance models is a combination of Supervised Fine-Tuning and Reinforcement Learning (SFT + RL). This multi-stage
process systematically builds reasoning skills by first training a model on a dataset of
reasoning examples and then refining its performance using RL, often incorporating
AI-generated feedback (RLAIF).

Finally, the Pure SFT and Distillation method offers a practical path to creating
smaller, more efficient models. This technique involves training a smaller model on a
high-quality dataset of reasoning examples generated by a superior "teacher" model,
effectively imparting strong reasoning skills to more accessible hardware.
Architecturally, Sparse Mixture-of-Experts (MoE) has been crucial for scaling these
models to trillions of parameters while managing computational demands.

The rapid advancements in large reasoning models (LRMs) are poised to revolutionize
high-level knowledge work across sectors such as finance, law, and science, ushering
in an autonomous agent economy. This shift promises substantial productivity gains,
with some economists predicting an addition of up to $15.7 trillion to global GDP by
2030 due to AI.

However, this transformation also presents critical challenges, including potential labor market disruptions and increased inequality. As LRMs develop more
sophisticated reasoning capabilities, it becomes crucial to guide their application
responsibly. The complex problem-solving abilities that lead to beneficial outcomes
could also result in actions misaligned with human intentions. Recent studies
emphasize the importance of robust safety and alignment research to address these
concerns.

The future of AI lies in harnessing the immense potential of these systems while
diligently ensuring their trustworthy and beneficial operation. This approach will foster
societal confidence and promote positive human-AI collaboration. The advanced
analytical capabilities of LRMs are already driving significant change in
knowledge-based industries, serving as a precursor to the broader economic shift
brought by the autonomous agent economy. LRMs provide the "brain" for AI agents
that can function as independent economic actors, performing tasks and generating
value with minimal human oversight.

The Architectural Foundations of Modern AI Reasoning
The advanced capabilities of modern LLMs are not the result of a single breakthrough
but rather a convergence of several key techniques that form a "reasoning stack." This
stack provides the structure for logical thought, the mechanism for aligning it with
desired outcomes, and the flexibility to apply computational resources dynamically.
Two major pillars of this stack are Inference-Time scaling and more structured
reasoning frameworks like Chain of Thought and Tree of Thoughts.

Inference-Time Scaling: Deeper Thinking on Demand


Inference-time scaling is a strategy that improves a model's reasoning capabilities
without altering or retraining the model itself. The core idea is to dedicate more
computational resources—essentially, more processing power and time—at the very
moment a response is being generated (at "inference"). This encourages the model to
engage in a deeper, more elaborate "thinking" process rather than providing a quick,
superficial answer.

Think of it like human problem-solving. If asked a simple question, you might give an
immediate, "off-the-cuff" answer. This is computationally "cheap." However, if asked a
complex multi-part question, you would pause, think through the steps, perhaps even
visualize different scenarios, before formulating a response. This deliberate, slower
process is your brain "scaling" its cognitive effort in real-time. Inference-time scaling
does the same for an LLM.

Chain of Thought (CoT)


This result is achieved through techniques like Chain-of-Thought (CoT) prompting,
where instructions like "think step-by-step" compel the model to generate a longer,
more detailed sequence of tokens. This increased token generation directly translates
to more computation at inference time, which often leads to more accurate results.
More advanced methods involve using voting or search algorithms to explore multiple
potential answer paths and then select the most logical one.
The fundamental trade-off is performance versus efficiency. While these techniques
make the model significantly more effective at complex reasoning, they also make
each query slower and more computationally expensive to run.

Example of Scaling: From a Simple Answer to a Voted Consensus

Let's illustrate how scaling the computational effort at inference time can drastically
improve the outcome for a simple arithmetic problem.

1. Baseline (Minimal Inference Cost): This approach uses the least amount of
computation. The model is prompted for a direct answer.

Prompt: Q: A restaurant has 15 tables. They buy 5 more, and then put 8 tables into
storage. How many tables are in use?

Typical (Incorrect) Answer: 20 tables.

●​ Analysis: The model performs a simple, immediate calculation (15 + 5) but fails
to complete the second step. This is a low-effort, low-accuracy response,
generated very quickly.

2. Chain-of-Thought (Medium Inference Cost): Here, we scale the computation by prompting the model to generate its reasoning process. This requires generating more
tokens and thus more processing.

Prompt: Q: A restaurant has 15 tables. They buy 5 more, and then put 8 tables into
storage. How many tables are in use? Let's think step-by-step.

CoT (Correct) Answer:

●​ Step 1: The restaurant starts with 15 tables.
●​ Step 2: They buy 5 more, so we add them: 15 + 5 = 20 tables.
●​ Step 3: They put 8 tables into storage, so we subtract them from the current
total: 20 - 8 = 12 tables.
●​ Final Answer: There are 12 tables in use.
●​ Analysis: By forcing the model to articulate each step, we increase the
computational load at inference. This extra work prevents it from jumping to a
premature conclusion and leads to the correct answer.
3. Self-Consistency (High Inference Cost): This is a more advanced scaling
technique that significantly increases the computational load for the highest possible
accuracy. Here, we prompt the model to generate multiple different Chain-of-Thought
reasoning paths and then use a "vote" to determine the most consistent final answer.

Prompt: Q: A restaurant has 15 tables. They buy 5 more, and then put 8 tables into
storage. How many tables are in use? Please provide three distinct step-by-step
solutions.

Generated Reasoning Paths:

●​ Path 1: Start with 15. Add 5, which gives 20. Then subtract 8, which results in
12. -> Final Answer: 12
●​ Path 2: First, calculate the new total tables: 15 + 5 = 20. From these 20 tables, 8
are removed. So, 20 - 8 = 12. -> Final Answer: 12
●​ Path 3 (A different but still valid path): The change in tables is +5 and then
-8, for a net change of 5 - 8 = -3 tables. The starting amount was 15, so 15 - 3 =
12. -> Final Answer: 12

Voting Mechanism: All three independently generated reasoning paths converged on the answer '12'.

Final Consensus Answer: 12 tables.

●​ Analysis: This method is the most computationally expensive because it essentially runs the CoT process multiple times. However, by finding a consensus among diverse reasoning paths, it produces a highly reliable and robust answer, filtering out potential random errors. This is a prime example of trading speed and cost for a superior reasoning outcome at inference time (a minimal voting sketch follows below).
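
Programmatically, self-consistency is just sampling several chain-of-thought completions and taking a majority vote over the extracted final answers. The sketch below assumes a hypothetical `llm.generate` call and a simple regular expression for pulling out the text after "Final Answer:"; both are illustrative placeholders.

```python
import re
from collections import Counter

def self_consistent_answer(llm, question, n_paths=5, temperature=0.8):
    """Sample n_paths chain-of-thought solutions and return the majority final answer."""
    prompt = f"Q: {question}\nA: Let's think step-by-step."
    answers = []
    for _ in range(n_paths):
        completion = llm.generate(prompt, temperature=temperature)   # hypothetical API call
        match = re.search(r"Final Answer:\s*(.+)", completion)
        if match:
            answers.append(match.group(1).strip())
    votes = Counter(answers)
    return votes.most_common(1)[0][0] if votes else None
```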

Let’s see some additional examples of Chain-of-Thought:

4. Arithmetic Reasoning: This type of task requires following a specific order of operations.

Prompt: A cafeteria had 23 apples. If they used 20 for lunch and bought 6 more, how
many apples do they have now?
Typical (Incorrect) Answer:
29 apples. (The model incorrectly adds 23 + 6, ignoring the subtraction.)
Chain-of-Thought Prompt: A cafeteria had 23 apples. If they used 20 for lunch and
bought 6 more, how many apples do they have now?
A: Let's think step-by-step.
CoT (Correct) Answer:
Step 1: The cafeteria starts with 23 apples.
Step 2: They use 20 apples, so we subtract 20. 23 - 20 = 3 apples.
Step 3: They buy 6 more apples, so we add 6 to the current amount. 3 + 6 = 9 apples.
Final Answer: The cafeteria now has 9 apples.

CoT works by forcing the model to process each part of the problem sequentially,
preventing it from making simple errors by jumping to conclusions.

5. Commonsense Reasoning: This requires connecting different pieces of real-world knowledge to reach a logical conclusion.

Prompt: What will likely happen if I leave a chocolate bar in a car on a summer day in
London?
Typical (Incomplete) Answer:
It will get warm.

Chain-of-Thought Prompt: What will likely happen if I leave a chocolate bar in a car
on a summer day in London?
A: Let's reason through the steps.

CoT (Correct) Answer:


Step 1: A car left in the sun acts like a greenhouse, and its internal temperature can
become much higher than the outside temperature.
Step 2: A summer day in London can be warm enough for this effect to be significant.
Step 3: Chocolate is a solid that melts at a relatively low temperature (around 30-35°C
or 86-95°F).
Step 4: The temperature inside the car will likely exceed the melting point of
chocolate.
Final Answer: The chocolate bar will melt.
This approach is effective because it prompts the model to integrate information
about vehicles, weather conditions, and the physical characteristics of chocolate,
resulting in a more comprehensive and precise response.

6. Symbolic Reasoning: This involves following abstract rules to manipulate symbols.

Prompt: Take the last letters of each word in 'Finding Nemo Forever' and reverse
them. What is the new word?
Typical (Incorrect) Answer:
oreveroMgnidniF (The model incorrectly reverses the entire phrase.)

Chain-of-Thought Prompt: Take the last letters of each word in 'Finding Nemo
Forever' and reverse them. What is the new word?
A: Let's break this down into steps.

CoT (Correct) Answer:


Step 1: The words in the phrase are "Finding", "Nemo", and "Forever".
Step 2: The last letter of "Finding" is 'g'.
Step 3: The last letter of "Nemo" is 'o'.
Step 4: The last letter of "Forever" is 'r'.
Step 5: The sequence of last letters is 'g', 'o', 'r'.
Step 6: Reversing this sequence gives 'r', 'o', 'g'.
Final Answer: The new word is 'rog'.

CoT works by ensuring the model executes each operation—identifying words, extracting letters, and reversing the sequence—in the correct order. This systematic approach prevents misinterpretations of complex instructions.

Few shots and CoT

Few-shot Chain-of-Thought (CoT) prompting consistently outperforms the simpler "zero-shot" approach, which typically relies on the instruction "Let's think
step-by-step." While "Let's think step-by-step" is a powerful instruction, providing
explicit examples is a more direct and effective way to guide the model.

To illustrate, consider teaching a person a task. Zero-shot is akin to instructing, "Solve this calculus problem and show your work." In contrast, few-shot is like saying, "Here
is a solved calculus problem with all the steps. Now, apply the same method to solve
this new one." This analogy highlights why few-shot CoT is more effective: it
demonstrates precisely what is expected from the model. The few-shot approach
offers three primary advantages:

●​ Clearer Task Demonstration: It explicitly shows the model the expected format,
structure, and level of detail for the reasoning process, eliminating ambiguity
about what "step-by-step" entails.
●​ Improved Reliability: By providing a successful example, the model is
significantly more likely to follow the correct logical path, reducing the chances
of errors or getting stuck. It provides a strong "nudge" in the right direction.
●​ Better Handling of Complexity: For problems involving nuanced rules or multiple
constraints, an example is invaluable. It demonstrates how to apply rules
correctly, which can be challenging for a model to infer solely from instructions.

Example: Zero-Shot vs. Few-Shot

Let's use a classic logic puzzle where the reasoning path is specific.

The Puzzle: You have three boxes. One is labeled "Apples", one "Oranges", and one
"Apples & Oranges". You know that every single box is mislabeled. You are allowed to
pick just one fruit from one box (without looking inside) to figure out the correct labels
for all three boxes. Which box do you pick from?

Zero-Shot CoT Prompt

Prompt: You have three boxes: "Apples", "Oranges", and "Apples & Oranges". Every
box is mislabeled. You can pick one fruit from one box to determine the correct labels
for all. Which box should you pick from? Let's think step-by-step.

A plausible (but less clear) Zero-Shot Answer:

To solve this, you need to find a box that gives you the most information. If you pick
from "Apples" and get an orange, you know that's the orange box. But that could be
complicated. The best box is the "Apples & Oranges" one. If you pick an apple from it,
you know its true label is "Apples". Then you can figure out the rest. So the "Apples"
box must be "Oranges", and the "Oranges" box must be "Apples & Oranges".

Final Answer: Pick from the "Apples & Oranges" box.


This answer gets to the right conclusion, but the reasoning is a bit jumbled and less
rigorous.

Few-Shot CoT Prompt

In a few-shot prompt, we first provide a complete example (a "shot") of a different, but structurally similar, puzzle and its solution.

Prompt: I have three cups labeled "Tea", "Coffee", and "Juice". They are all
mislabeled. Cup 1 contains coffee, Cup 2 contains juice, and Cup 3 contains tea. If I
know that the "Tea" label is not on the cup containing coffee, what is in the cup
labeled "Coffee"?

A: Let's solve this with logical deduction.


1. The cups and their actual contents are: Cup 1 (Coffee), Cup 2 (Juice), Cup 3 (Tea).
2. The labels "Tea", "Coffee", "Juice" are all on the wrong cups.
3. The "Tea" label is not on Cup 1 (which contains coffee), and it cannot be on Cup 3 (the tea would then be labeled correctly). So the "Tea" label must be on Cup 2 (Juice).
4. The "Coffee" label cannot be on Cup 1 (it contains coffee) or on Cup 2 (which now carries the "Tea" label). Therefore, the "Coffee" label must be on Cup 3.
5. Cup 3's actual content is tea.
Final Answer: The cup labeled "Coffee" contains tea.
Q: You have three boxes: "Apples", "Oranges", and "Apples & Oranges". Every box is
mislabeled. You can pick one fruit from one box to determine the correct labels for all.
Which box should you pick from?
A likely (and much clearer) Few-Shot Answer:

A: Let's solve this with logical deduction.


1. The core premise is that every label is incorrect.
2. Consider the box labeled "Apples & Oranges". Because it's mislabeled, it cannot
contain a mix of both. It must contain either only apples or only oranges.
3. This is the key. This box offers a guaranteed piece of information. The other two
boxes ("Apples", "Oranges") could contain one of two possible fruit types.
4. Therefore, I should pick from the box labeled "Apples & Oranges". Let's say I pick a fruit and it's an apple.
5. I now know the true identity of the "Apples & Oranges" box is "Apples".
6. Now consider the box labeled "Oranges". It cannot be "Oranges" (mislabeled) and it
cannot be "Apples" (we've found that box). So, it must be "Apples & Oranges".
7. By elimination, the remaining box labeled "Apples" must contain oranges.
Final Answer: You must pick from the box labeled "Apples & Oranges".

By providing an example, we taught the model to use a systematic, step-by-step elimination process, leading to a much higher quality and more trustworthy answer.

Tree of Thoughts: From Linear Steps to Exploratory Search


While Chain-of-Thought forces a model down a single, linear reasoning path, Tree of
Thoughts (ToT) elevates this process by allowing the model to explore multiple
different lines of reasoning simultaneously. It functions like a team of experts
brainstorming various solutions to a problem in parallel.

Think of solving a maze. CoT is like walking through the maze by committing to a
single path. If you hit a dead end, you have to backtrack a long way or start over
completely. In contrast, ToT is like sending scouts down several paths at once from
every crossroads. You can quickly see which paths are dead ends and focus your
energy on the ones that look most promising.

The ToT framework allows a model to create a "tree" of possibilities. At each step of a
problem, it performs three key actions:

1.​ Generate: It brainstorms several different next steps or "thoughts," creating multiple branches in the tree.
2.​ Evaluate: It assesses each new thought, giving it a score based on how likely it
is to lead to a successful solution.
3.​ Search: It decides which branches of the tree to explore further. It can
abandon unpromising, low-scoring branches and even backtrack to a previous
step to pursue a more fruitful alternative. This process uses search algorithms
to navigate the tree efficiently.

This method is much closer to how humans think—we consider various options, weigh
their pros and cons, and change our strategy when one approach isn't working. 🧠
Example: The Traveling Salesperson Problem

Consider a classic logistics problem: "I need to leave my house, visit the Post Office,
the Bank, and the Grocery Store, and then return home. Find the shortest route."
●​ A CoT model might pick one order randomly: Home -> Post Office -> Bank ->
Grocery -> Home. It would calculate the total time for that single route and
present it as the answer, without knowing if it's the optimal one.
●​ A ToT model would explore the different possible routes systematically:
○​ Root: The model knows it has three places to visit: Post Office, Bank,
Grocery.
○​ Level 1 (Generate first stops): It creates three main branches from
Home:
■​ Go to the Post Office first.
■​ Go to the Bank first.
■​ Go to the Grocery Store first.
○​ Level 2 (Explore next stops): For each of the first-stop branches, it
generates sub-branches for the second stop.
■​ From the "Post Office" branch, it creates two new sub-branches:
visiting the Bank next, or visiting the Grocery Store next.
■​ It does the same for the other initial branches.
○​ Level 3 (Complete and evaluate): The model completes every possible
full route (e.g., Home -> Bank -> Post Office -> Grocery -> Home) and
calculates the total travel time for each.
○​ Search & Final Answer: By comparing the total time of all completed paths in its "tree," the model can confidently identify and present the path with the absolute shortest travel time (a brute-force version of this generate-evaluate-search loop is sketched just after this list).
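To make the loop concrete, here is a minimal brute-force sketch in Python. The travel times in the dictionary are illustrative assumptions, not values taken from the example; a real ToT system would have the model generate and score partial routes rather than enumerating them exhaustively.

from itertools import permutations

# Hypothetical, symmetric travel times in minutes (purely illustrative).
times = {
    ("Home", "Post Office"): 10, ("Home", "Bank"): 15, ("Home", "Grocery"): 12,
    ("Post Office", "Bank"): 8, ("Post Office", "Grocery"): 14,
    ("Bank", "Grocery"): 6,
}

def travel(a, b):
    # Look the pair up in either direction.
    return times[(a, b)] if (a, b) in times else times[(b, a)]

def route_cost(stops):
    # Complete route: Home -> stops in order -> Home.
    path = ["Home", *stops, "Home"]
    return sum(travel(a, b) for a, b in zip(path, path[1:]))

# Generate every branch, evaluate its total time, and keep the best one.
best = min(permutations(["Post Office", "Bank", "Grocery"]), key=route_cost)
print(list(best), route_cost(best))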

The Alignment Engine


While generating a chain of thought or a tree of thoughts is a powerful first step, ensuring that the reasoning is correct, helpful, and safe is a critical alignment challenge.
Reinforcement Learning from Human Feedback (RLHF) became the industry's
standard method for this task.

Reinforcement Learning from Human Feedback (RLHF)

Think of RLHF like training a smart but inexperienced apprentice. You don't just give
them a textbook (the initial training data); you watch them work and give them
feedback, rewarding good decisions and correcting poor ones until their judgment
aligns with yours.

RLHF is a three-stage process: first, it gathers human feedback on the model's performance; second, it trains a "reward model" to act as an automated human judge;
and third, it fine-tunes the main model to consistently impress that judge. The major
drawback is that gathering high-quality human feedback is slow, expensive, and
difficult to scale.

A Practical Example of RLHF in Action

Let's walk through an example using the prompt: "What's a good way to start investing
with very little money?"

Step 1: Collect Human Preference Data First, the base language model generates
several different answers to the prompt.

●​ Response A: "A great way to start is with a low-cost index fund through a
micro-investing app. These let you invest small amounts, even just a few
pounds. It's also wise to read about the basics of diversification. Always
remember that all investing involves risk."
●​ Response B: "You should put all your money into 'CRYPTO-COIN X'. It's going to
the moon and you'll get rich quick. Don't miss out on this once-in-a-lifetime
opportunity."
●​ Response C: "Investing is too complicated if you don't have a lot of money. It's
probably not worth your time."

A human annotator reviews these responses and ranks them based on helpfulness
and safety. Their preference would clearly be A > C > B. They would create thousands
of these comparisons for many different prompts.

Step 2: Train the Reward Model This human preference data (e.g., for this prompt, 'A is
better than B') is used to train a separate "reward model." This model's only job is to
learn what humans prefer. It reads a prompt and a response, and then outputs a score
that predicts how a human would rate it. It learns that responses mentioning
diversification, risk warnings, and specific, actionable advice get high scores.
Responses that are dismissive, give risky financial advice, or create a sense of FOMO
(Fear Of Missing Out) get very low scores. Essentially, it becomes an automated judge
that has learned human values.
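The reward model itself is usually trained with a pairwise ranking loss: for every human comparison, the preferred response should receive a higher score than the rejected one. A minimal PyTorch-style sketch of that objective, where reward_model is a stand-in for any network that maps a prompt-response pair to a scalar score:

import torch.nn.functional as F

def reward_ranking_loss(reward_model, prompt, preferred, rejected):
    # Score both responses with the same reward model.
    score_preferred = reward_model(prompt, preferred)  # scalar tensor
    score_rejected = reward_model(prompt, rejected)    # scalar tensor
    # Bradley-Terry style objective: push the preferred score above the rejected one.
    return -F.logsigmoid(score_preferred - score_rejected)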

Step 3: Fine-tune the LLM with Reinforcement Learning Now, the main LLM is
fine-tuned in a continuous loop using an algorithm like PPO (Proximal Policy
Optimization).
1.​ The LLM is given the prompt: "What's a good way to start investing..."
2.​ It generates a new answer. Let's say it tries: "Buy shares in a single popular tech
company."
3.​ This answer is shown to the reward model.
4.​ The reward model evaluates it and gives it a medium-low score. It's not as
dangerous as Response B, but it lacks the safety of mentioning diversification
from Response A.
5.​ The reinforcement learning algorithm receives this score as feedback. Since the
score isn't high, it slightly adjusts the LLM's internal parameters to make it less
likely to suggest putting all your money in a single stock.

This process is repeated millions of times. When the LLM generates an answer similar
to the highly-ranked Response A, the reward model gives it a high score (a "reward"),
and the LLM is adjusted to make it more likely to produce these kinds of helpful, safe,
and balanced answers in the future.

Reinforcement Learning from AI Feedback (RLAIF)

RLHF is powerful but expensive and slow. Reinforcement Learning from AI Feedback
(RLAIF) offers a more scalable solution by replacing the human annotator with another
powerful AI.

Think of it this way: if RLHF is like a student driver learning with a human instructor,
RLAIF is like having that student learn in a highly advanced simulator. The simulator
(the AI labeler) can provide feedback far more quickly and for countless scenarios, all
while following a strict set of safety rules (a "constitution"). This AI-driven approach replaces the slow, expensive human-feedback bottleneck with a fast, cost-effective, and scalable labeling process.

A Practical Example of RLAIF in Action

Let's use a prompt that has the potential for both helpful and deceptive answers:
"How can I make my argument in a debate sound more convincing, even if my factual
evidence is weak?"

1. The Constitution: The AI labeler is given a constitution, a set of rules to guide its
judgment. It might include principles like:
●​ Principle 1 (Helpfulness): Prioritize advice that is constructive, ethical, and
honest.
●​ Principle 2 (Harmlessness): Do not endorse or promote deception,
manipulation, or misinformation.
●​ Principle 3 (Integrity): Favor strategies that improve the substance and clarity of
an argument over those that merely obscure weaknesses.

2. The AI Labeler's Task The model being trained generates different responses.

●​ Response A: "You can use strong emotional language and complex vocabulary
to sound authoritative. Another effective tactic is to pivot to a different topic
where you feel more confident, deflecting from your weak points."
●​ Response B: "The most convincing long-term approach is to strengthen your
evidence. If that's not possible, focus on clearly explaining your perspective
and the principles behind it. A well-structured, clearly delivered argument can
be very persuasive, even if it relies more on logic and perspective than on hard
data."

3. The AI's Judgment The AI labeler evaluates these two responses against its
constitution. Its internal "reasoning" would conclude that Response A violates
Principle 2 (Harmlessness), while Response B aligns with Principles 1 and 3. Therefore,
the AI labeler outputs the preference data: B > A.

4. The Outcome: This AI-generated preference label is then used to train a reward
model and fine-tune the original LLM, just as a human label would be in RLHF.
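In code, the AI labeler step is just one more model call: the constitution, the prompt, and the two candidate responses are packed into a judging prompt, and the verdict becomes a preference pair for reward-model training. A minimal sketch, assuming a generic judge_model(text) function that returns the labeler's text output (not any specific vendor API):

CONSTITUTION = """1. Prioritize advice that is constructive, ethical, and honest.
2. Do not endorse or promote deception, manipulation, or misinformation.
3. Favor strategies that improve the substance and clarity of an argument."""

def ai_preference(judge_model, prompt, response_a, response_b):
    judging_prompt = (
        f"Constitution:\n{CONSTITUTION}\n\n"
        f"User prompt: {prompt}\n\n"
        f"Response A: {response_a}\n\n"
        f"Response B: {response_b}\n\n"
        "Which response better follows the constitution? Answer with 'A' or 'B'."
    )
    verdict = judge_model(judging_prompt).strip().upper()
    # Return a (chosen, rejected) pair, ready for reward-model training.
    return (response_a, response_b) if verdict.startswith("A") else (response_b, response_a)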

Pure Reinforcement Learning (RL)

This is a novel approach where reasoning ability emerges as a learned behavior. A base LLM is trained exclusively with Reinforcement Learning, skipping the typical Supervised Fine-Tuning (SFT) stage.

How it works: The model is rewarded for correct outcomes (e.g., solving a math
problem or generating code that compiles) and for following a specific response
format. The model is not explicitly shown how to reason; it must discover effective
reasoning strategies on its own to maximize its reward. The DeepSeek-Coder-RL
model demonstrated this, showing an "Aha!" moment where it began generating
reasoning steps (like a chain of thought) without being explicitly instructed to, as this
was the best strategy it found to solve the problems correctly.
Use case: This approach is primarily used for research and in specialized domains like
code generation. It is currently less effective for general-purpose chatbots than
combined methods but provides invaluable insight into how reasoning capabilities can
develop organically in models.

Key RL Algorithms in Alignment (PPO and DPO)


This section details essential Reinforcement Learning algorithms, with mathematical
explanations provided in the Appendix.

PPO: The Careful Step-by-Step Optimizer

Proximal Policy Optimization (PPO) is the workhorse algorithm behind the final, most
crucial stage of RLHF. Its primary goal is to update the language model's policy—its
internal strategy for generating text—without taking dangerously large steps.

Imagine you're teaching a dog a complex new trick. If you dramatically change your
commands and rewards all at once, the dog will get confused. Instead, you make
small, consistent adjustments. PPO does the same for the LLM.

A naive update could cause the model to change its parameters too drastically in
pursuit of a high reward, leading to "catastrophic forgetting," where it loses its
fundamental grasp of language. PPO prevents this by using a clipping mechanism. It
defines a small, "safe" trust region around the model's current policy. Any proposed
update that tries to step outside this region is "clipped" back to the boundary.
Mathematically, it ensures the ratio between the new policy π_θ and the old policy π_θ_old stays within a narrow bound, such as [1 − ε, 1 + ε]. This forces the model to learn
in stable, incremental steps, ensuring it gets better at the task without breaking its
existing knowledge.
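The clipping idea fits in a few lines. Below is a minimal sketch of PPO's clipped surrogate objective in PyTorch, where log_probs_new and log_probs_old are per-token log-probabilities of the sampled response under the updated and frozen policies, and advantages are derived from the reward model's scores:

import torch

def ppo_clipped_loss(log_probs_new, log_probs_old, advantages, epsilon=0.2):
    # Probability ratio between the new policy and the old policy.
    ratio = torch.exp(log_probs_new - log_probs_old)
    # Unclipped and clipped surrogate objectives.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # Take the pessimistic (minimum) of the two and minimise its negative.
    return -torch.min(unclipped, clipped).mean()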

DPO: The Direct Route to Preference Alignment

Direct Preference Optimization (DPO) is a more recent and efficient alignment method
that elegantly bypasses the need for a separate reward model.

Recall that RLHF is a multi-step process: collect data, train a reward model, then use
RL to tune the main model. DPO streamlines this significantly. It works directly from
the raw preference pairs (e.g., "Response A is better than Response B").
DPO reframes alignment as a simple classification task. It directly adjusts the LLM's
parameters to increase the likelihood of generating the preferred response while
decreasing the likelihood of generating the dispreferred one.

If RLHF is like hiring a food critic (the reward model) to give scores and then telling the
chef (the LLM) to aim for higher scores, DPO is like telling the chef directly:
"Customers liked dish A more than dish B. Make more things like A." It's a more direct,
stable, and computationally cheaper way to achieve the same goal, often with better
results.
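The DPO objective makes this concrete: for each preference pair, it raises the policy's log-probability of the chosen response and lowers that of the rejected one, both measured relative to a frozen reference model. A minimal sketch, assuming the four inputs are the summed log-probabilities of each complete response:

import torch.nn.functional as F

def dpo_loss(policy_logp_chosen, policy_logp_rejected,
             ref_logp_chosen, ref_logp_rejected, beta=0.1):
    # How much more the policy prefers each response than the reference model does.
    chosen_margin = policy_logp_chosen - ref_logp_chosen
    rejected_margin = policy_logp_rejected - ref_logp_rejected
    # Classification-style loss: the chosen margin should exceed the rejected one.
    return -F.logsigmoid(beta * (chosen_margin - rejected_margin)).mean()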


Supervised Fine-Tuning + Reinforcement Learning (SFT + RL)


This hybrid technique is the current gold standard for building state-of-the-art
reasoning models, such as those seen in cutting-edge projects like DeepSeek-R1. It's
a powerful, multi-stage process that systematically cultivates advanced reasoning
abilities by combining the strengths of different training paradigms.

Think of it as the ultimate educational program for an AI. It doesn't just learn from a
textbook (SFT), nor does it learn solely by trial and error (RL). Instead, it uses trial and
error to write its own perfect textbook and then studies that textbook before finally
being polished by an expert tutor.

This combination works by first using RL to discover correct reasoning paths and then
using SFT to efficiently teach those paths to the model, before a final RL phase refines
its performance.

A Practical Example of SFT + RL in Action

Let's use a simple logic puzzle to illustrate the multi-stage process.

Prompt: "You have a 5-litre jug and a 3-litre jug, and an unlimited supply of water. How
can you measure out exactly 4 litres?"

Stage 1: The "Cold-Start" RL Phase (Generating the Textbook)


The process begins with a general pre-trained model that has no specialized
knowledge of solving this kind of puzzle. It will generate solutions through trial and
error, guided by a simple reward function.

●​ The Goal: The only goal at this stage is to find a sequence of steps that results
in a correct final answer. The model receives a reward of 1 if its proposed steps
lead to 4 litres, and a reward of 0 otherwise.
●​ Attempt 1: The model generates: "Fill the 5L jug. Pour it into the 3L jug. Empty
the 3L jug. This leaves 2L in the 5L jug. This is not 4L." -> Reward: 0.
●​ Attempt 2: The model generates: "Fill the 3L jug. Pour it into the 5L jug. Fill the
3L jug again. Pour from the 3L jug into the 5L jug until the 5L jug is full. The 5L
jug had 3L, so it takes 2L more. This leaves 1L in the 3L jug. Empty the 5L jug.
Pour the 1L from the 3L jug into the 5L jug. Fill the 3L jug and pour it into the 5L
jug. This gives 1L + 3L = 4L." -> Reward: 1.

This successful attempt (the entire chain of reasoning) is saved. This process is repeated thousands of times, generating a large dataset of diverse problems and their correctly reasoned solutions. This dataset becomes the "cold-start" SFT data, our custom-written textbook on logical reasoning.
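At the data level, this stage boils down to: sample many attempts, keep only the chains a verifier marks as correct, and store them as (prompt, solution) pairs. A minimal sketch, where model_generate and is_correct are hypothetical stand-ins for the policy's sampler and an automatic checker (for example, a jug-state simulator or a unit test):

def collect_cold_start_data(prompts, model_generate, is_correct,
                            attempts_per_prompt=64):
    # Gather correct reasoning chains discovered by trial and error.
    sft_dataset = []
    for prompt in prompts:
        for _ in range(attempts_per_prompt):
            solution = model_generate(prompt)  # one sampled reasoning chain
            reward = 1 if is_correct(prompt, solution) else 0
            if reward == 1:
                sft_dataset.append({"prompt": prompt, "solution": solution})
                break  # one verified chain per prompt is enough for this sketch
    return sft_dataset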

Stage 2: Supervised Fine-Tuning (SFT) (Studying the Textbook)

Now, the model learns from the high-quality data it just created.

●​ The Goal: Instead of exploring randomly, the model is now explicitly trained to
imitate the successful reasoning chains. It's shown the prompt and the correct
step-by-step solution, and it adjusts its parameters to learn this mapping
directly.
●​ Training Data:
○​ Input: "You have a 5-litre jug and a 3-litre jug..."
○​ Target Output: "Fill the 3L jug. Pour it into the 5L jug..." (the entire
successful chain from Stage 1).

After this stage, the model becomes highly proficient at generating correct,
step-by-step solutions because it has been directly taught the patterns of successful
reasoning.
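Under the hood, this imitation step is ordinary next-token prediction, with the loss computed only on the solution tokens so the model is never penalised for the prompt itself. A minimal PyTorch-style sketch, assuming tokenizer returns a list of token ids and model returns per-token logits (both are stand-ins, not a specific library's interface):

import torch
import torch.nn.functional as F

def sft_loss(model, tokenizer, prompt, solution):
    prompt_ids = tokenizer(prompt)        # list[int] (assumed interface)
    solution_ids = tokenizer(solution)
    input_ids = torch.tensor([prompt_ids + solution_ids])

    logits = model(input_ids)             # (1, seq_len, vocab_size)
    # Standard causal shift: predict token t+1 from tokens up to t.
    targets = input_ids[:, 1:].clone()
    logits = logits[:, :-1, :]

    # Mask the prompt region so only solution tokens contribute to the loss.
    targets[:, : len(prompt_ids) - 1] = -100
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           targets.reshape(-1), ignore_index=-100)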

Stage 3: Refinement RL (Polishing with an Expert Tutor)


The model is now correct, but its reasoning might not be optimal, concise, or easy for
a human to understand. This final RL phase, often using RLHF or RLAIF, refines the
quality of the reasoning.

●​ The Goal: To optimize the reasoning style based on human preferences for
clarity, efficiency, and helpfulness.
●​ Preference Data Collection: The SFT model generates several valid solutions.
○​ Response A: The long, but correct, solution found in Stage 1.
○​ Response B: "1. Fill the 5L jug. 2. Pour from the 5L jug to fill the 3L jug.
This leaves 2L in the 5L jug. 3. Empty the 3L jug. 4. Pour the 2L from the
5L jug into the 3L jug. 5. Refill the 5L jug. 6. Carefully pour water from the
5L jug into the 3L jug (which already has 2L) until it's full. You will pour
exactly 1L. This leaves 4L in the 5L jug."
●​ Judgment: A human or AI labeler determines that Response B is a more elegant
and common solution than Response A. It is ranked higher: B > A.
●​ Fine-tuning: This preference data is used to train a reward model. The main
model is then fine-tuned with an algorithm like PPO to maximize rewards from
this new, more sophisticated judge.

Through this final process, the model learns not just to be correct, but to reason in a
way that is helpful, efficient, and aligned with human cognitive styles, resulting in a
highly capable and polished reasoning model.

Pure SFT and Distillation


This technique focuses on creating smaller, more efficient reasoning models by having
them learn from a larger, more powerful "teacher" model. It's a highly practical and
cost-effective method for democratizing advanced AI capabilities.

The core idea is knowledge distillation. Think of it like a world-class grandmaster chef
(a large, powerful "teacher" model) who doesn't have time to personally train every
apprentice. Instead, they write the ultimate, comprehensive cookbook (a high-quality
dataset) filled with their best recipes, detailed step-by-step instructions, and the
secret reasoning behind their techniques. Now, any talented apprentice (a smaller
"student" model) can study this cookbook to learn the grandmaster's skills, achieving
a high level of proficiency without ever needing direct, expensive instruction from the
master.
This process allows us to impart the strong reasoning skills of a massive model into a
much smaller one, making it accessible for use on less powerful hardware.

A Practical Example of Distillation in Action

Let's use a common coding challenge to see how this works.

Goal: Train a small, efficient 7-billion parameter model to solve coding problems with
clear explanations.

Stage 1: The Teacher Model Creates the "Cookbook"

First, a top-tier, proprietary model (the "teacher"), like GPT-4o or a specialized model
like DeepSeek-R1, is used to generate a massive, high-quality dataset of reasoning
examples.

●​ Teacher Model: A large, powerful reasoning model.


●​ Prompt: "Write a Python function that finds the second-largest number in a
given list of unique integers. Explain your reasoning."
●​ Teacher's High-Quality Output: The model generates a perfect response that
includes not just the code, but the underlying logic.
○​ Reasoning (Chain of Thought): "To solve this efficiently in a single pass, I
will maintain two variables: largest and second_largest. I'll initialize
them to the first two elements of the list, correctly ordered. Then, I'll
iterate through the rest of the list. For each number, I'll check if it's
greater than largest. If it is, the old largest becomes the new
second_largest, and the current number becomes the new largest.
Otherwise, I'll check if the number is greater than second_largest,
and if so, update second_largest."

Code:

def find_second_largest(numbers):
    # Need at least two values for a second-largest to exist.
    if len(numbers) < 2:
        return None
    # Seed the two trackers with the first two elements, correctly ordered.
    largest = max(numbers[0], numbers[1])
    second_largest = min(numbers[0], numbers[1])
    # Single pass over the remaining elements.
    for i in range(2, len(numbers)):
        if numbers[i] > largest:
            second_largest = largest
            largest = numbers[i]
        elif numbers[i] > second_largest:
            second_largest = numbers[i]
    return second_largest

This complete, high-quality (Prompt, Response) pair is saved. This process is repeated for millions of different problems, creating a rich and detailed dataset: our "cookbook."

Stage 2: The Student Model Studies the "Cookbook"

Now, the smaller, more efficient model (the "student") is trained using this dataset in a
process called Supervised Fine-Tuning (SFT).

●​ Student Model: A smaller, open-source model like Llama 8B or Mistral 7B.


●​ Training Process: The student model is shown the prompt from the dataset and
tasked with generating the corresponding response. Its output is compared to
the teacher's perfect response. The model's internal parameters are adjusted
to minimize the difference, effectively teaching it to imitate the teacher's
high-quality reasoning and coding style.
○​ Input: "Write a Python function that finds the second-largest number..."
○​ Target Output: The entire reasoning and code block generated by the
teacher model.

Stage 3: The Outcome – A Skilled and Efficient Apprentice

After fine-tuning on the distilled knowledge from the teacher, the small student model
can now solve similar problems on its own. When given the same prompt, it will
generate a response that is remarkably close in quality to the one from the massive
teacher model.

The key advantage is that this newly capable 7B model can now run efficiently on
consumer-grade hardware, like a local laptop or even a smartphone. The reasoning
ability of a giant, expensive model has been successfully distilled into a small, fast,
and accessible package, without needing to perform the more complex and
computationally intensive Reinforcement Learning steps on the small model itself.
Thinking on Demand: The Rise of Test-Time Computation and Adaptive Inference

A defining feature of the latest generation of LRMs is their ability to dedicate more
computational effort to a problem during inference—a concept known as test-time
compute. This paradigm acknowledges that not all problems are equally difficult and
that investing more "thinking time" can significantly improve performance on complex
tasks. This is the core mechanism behind the explicit "thinking" capabilities of models
like OpenAI's o-series and Google's Gemini 2.5 Pro. As pre-training models on
ever-larger datasets yields diminishing returns, test-time compute is now considered
a key driver of future performance gains.

Several methods fall under this umbrella. The simplest is Best-of-N sampling, where
the model generates multiple candidate responses, and a separate verifier model or
reward model scores them to select the best one. More sophisticated approaches
include iterative refinement and self-critique, where a model generates an initial
answer, then generates a critique of that answer, and finally produces a revised
response based on both. The most advanced implementations involve training
Process-Supervised Reward Models (PRMs) that evaluate the correctness of each
individual step in a reasoning chain, providing much more granular feedback than a
final outcome score.
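Best-of-N is the simplest of these to implement: spend extra inference compute by sampling several candidates and letting a verifier or reward model choose among them. A minimal sketch, with generate and score_fn as hypothetical stand-ins for the generator and the verifier:

def best_of_n(prompt, generate, score_fn, n=8):
    # Trade extra inference compute for quality: sample n answers, keep the best.
    candidates = [generate(prompt) for _ in range(n)]
    scores = [score_fn(prompt, c) for c in candidates]
    best_index = max(range(n), key=lambda i: scores[i])
    return candidates[best_index]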

The Scaling Dilemma: Sparse Mixture-of-Experts and Computational Efficiency

The relentless pursuit of greater capabilities has driven model parameter counts into
the trillions, making the computational cost of training and inference a primary
constraint on progress. The Sparse Mixture-of-Experts (MoE) architecture is a
critical innovation designed to address this scaling dilemma. An MoE model replaces
some of the dense feed-forward network layers of a standard transformer with a
multitude of smaller "expert" sub-networks. For each input token, a lightweight
"gating network" or "router" dynamically selects a small subset of these experts
(typically one or two) to activate and process the token.

The primary benefit of this approach is that it decouples the model's total parameter
count from its computational cost (measured in floating-point operations, or FLOPs)
per token. This allows for the creation of models with enormous capacity (trillions of
total parameters) that can be trained and run with the computational budget of a
much smaller dense model. However, this efficiency comes with significant trade-offs.
The most notable is a massive increase in memory (VRAM) requirements, as all
experts must be loaded into memory simultaneously. Furthermore, MoE models
introduce training complexities, such as ensuring that the gating network distributes
tokens evenly across all experts to prevent some from becoming over-trained while
others are neglected.
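A sparse MoE layer can be sketched in a few dozen lines: a gating network scores every expert for each token, only the top-k experts are actually run, and their outputs are combined using the renormalised gate weights. This is a simplified PyTorch illustration, not the exact layer used by any particular model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoELayer(nn.Module):
    def __init__(self, d_model=512, d_hidden=2048, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts)  # lightweight gating network
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(),
                          nn.Linear(d_hidden, d_model))
            for _ in range(num_experts)
        ])

    def forward(self, x):  # x: (num_tokens, d_model)
        gate_probs = F.softmax(self.router(x), dim=-1)         # (num_tokens, num_experts)
        weights, indices = torch.topk(gate_probs, self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalise over top-k

        output = torch.zeros_like(x)
        for slot in range(self.top_k):
            for expert_id, expert in enumerate(self.experts):
                mask = indices[:, slot] == expert_id           # tokens routed to this expert
                if mask.any():
                    output[mask] += weights[mask, slot].unsqueeze(-1) * expert(x[mask])
        return output

Note that every expert's parameters must be instantiated (and held in memory) even though each token only exercises top_k of them, which is exactly the VRAM trade-off described above.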

Chain-of-Thought (CoT)
● Description: A prompting method that guides the LLM to generate a step-by-step reasoning process before the final answer.
● Primary Benefit (Pro): Improves accuracy on multi-step problems; provides transparency into the model's reasoning process.
● Key Limitation (Con): Limited to a single, linear reasoning path; can still arrive at incorrect conclusions.
● Ideal Use Case: Solving structured problems like math word problems, commonsense reasoning, and basic logical puzzles.

Tree of Thoughts (ToT)
● Description: A framework that allows the LLM to explore multiple reasoning paths simultaneously, evaluate them, and backtrack from dead ends.
● Primary Benefit (Pro): Enables solving complex problems requiring exploration and trial-and-error; more robust than linear CoT.
● Key Limitation (Con): High computational overhead; complexity in managing the search tree and state evaluation.
● Ideal Use Case: Strategic planning, combinatorial optimization, and creative tasks where multiple solution paths must be considered.

RL from Human Feedback (RLHF)
● Description: A process where human preferences are used to train a reward model, which then fine-tunes the LLM with reinforcement learning.
● Primary Benefit (Pro): Aligns the model directly with human values for helpfulness and safety.
● Key Limitation (Con): Slow, expensive, and difficult to scale due to the human feedback bottleneck.
● Ideal Use Case: The foundational method for aligning general-purpose chatbots to be safe and helpful.

RL from AI Feedback (RLAIF)
● Description: A training paradigm that uses an AI model (guided by a constitution) to generate preference data for aligning an LRM.
● Primary Benefit (Pro): Highly scalable and cost-effective compared to RLHF; enables rapid alignment of complex reasoning behaviors.
● Key Limitation (Con): The quality of alignment is dependent on the capability and "constitution" of the AI labeler; may lack nuanced human judgment.
● Ideal Use Case: Large-scale alignment for safety and helpfulness, especially for models with complex, multi-step reasoning chains.

SFT + RL
● Description: A multi-stage process that first uses RL to generate correct reasoning examples, then uses SFT to train the model on this data, and finally uses RL to refine the reasoning style.
● Primary Benefit (Pro): Creates state-of-the-art models by combining the discovery power of RL with the learning efficiency of SFT.
● Key Limitation (Con): A highly complex, multi-stage, and computationally intensive process.
● Ideal Use Case: Building frontier, high-performance reasoning models that require both correctness and optimal, human-aligned logic.

Pure SFT and Distillation
● Description: Training a smaller "student" model on a high-quality dataset of reasoning examples generated by a superior "teacher" model.
● Primary Benefit (Pro): Cost-effective method to create smaller, efficient models by transferring skills from a larger model.
● Key Limitation (Con): The student model's quality is ultimately capped by the teacher's ability and the quality of the generated data.
● Ideal Use Case: Democratizing advanced reasoning by creating capable models that can run on consumer hardware.

Pure Reinforcement Learning (RL)
● Description: A method where a base model learns reasoning as an emergent behavior by being rewarded for correct outcomes, without an initial SFT stage.
● Primary Benefit (Pro): Provides valuable research insights into how reasoning can develop organically in models.
● Key Limitation (Con): Currently less effective for general-purpose models than combined methods and can be inefficient.
● Ideal Use Case: Research and specialized domains like code generation where simple pass/fail rewards are available.

Test-Time Computation
● Description: The use of additional computational resources during inference to improve answer quality, allowing the model to "think" longer.
● Primary Benefit (Pro): Dramatically improves performance on difficult tasks; allows for a dynamic trade-off between accuracy, cost, and latency.
● Key Limitation (Con): Increases inference latency and cost; may not be necessary or cost-effective for simpler queries.
● Ideal Use Case: Tackling novel, complex problems in high-stakes domains (e.g., science, engineering) where accuracy is paramount.

Sparse Mixture-of-Experts (MoE)
● Description: An architecture that uses a router to dynamically activate a small subset of "expert" sub-networks for each token.
● Primary Benefit (Pro): Allows for scaling model parameter counts to trillions with the computational cost of a much smaller dense model.
● Key Limitation (Con): Massive memory (VRAM) requirements; increased training complexity; potential performance trade-offs on reasoning tasks.
● Ideal Use Case: Building frontier-scale open-source models where training efficiency and inference speed are critical.

PPO (Proximal Policy Optimization)
● Description: An RL algorithm that fine-tunes the LLM in small, stable steps, preventing it from losing its existing knowledge during training.
● Primary Benefit (Pro): Stable and reliable for fine-tuning large models; avoids "catastrophic forgetting."
● Key Limitation (Con): An algorithm, not a full paradigm; can be less sample-efficient than other RL methods.
● Ideal Use Case: The workhorse algorithm for the final RL tuning stage in RLHF/RLAIF, ensuring stable improvement.

DPO (Direct Preference Optimization)
● Description: An alignment method that directly trains on preference pairs (A > B) without needing a separate reward model, reframing alignment as a classification task.
● Primary Benefit (Pro): More direct, stable, and computationally cheaper than traditional RLHF; often yields better results.
● Key Limitation (Con): Still reliant on a high-quality preference dataset; it's a newer method with evolving best practices.
● Ideal Use Case: An efficient and powerful replacement for the reward modeling and PPO steps in an alignment pipeline.

A Comparative Analysis of Leading Models


As of mid-2025, the field of Large Reasoning Models is characterized by intense
competition and rapid innovation. The landscape is broadly divided into two camps:
the proprietary, closed-source "titans of industry" pushing the absolute frontier of
capability, and a vibrant "open-source insurgency" democratizing access to advanced
reasoning. A key strategic divergence has become apparent: proprietary labs are
heavily invested in scaling inference-time compute (the "thinking" paradigm), which
aligns with premium API-based business models. In contrast, leading open-source
projects are leveraging architectural efficiency (like MoE and massive context
windows) to scale parameter counts while managing costs.

The Titans of Industry (Proprietary Models)


OpenAI (o-series, GPT-5): OpenAI continues to set the pace with its latest models,
which treat reasoning as an explicit, controllable, and engineered capability. This is
exposed to developers via API parameters like reasoning_effort, providing granular
command over the accuracy, latency, and cost trade-off. Architecturally, the system is
a complex of distinct reasoning, non-reasoning, and router models, boasting a
400,000-token context window and achieving state-of-the-art (SOTA) performance
on difficult benchmarks like SWE-bench Verified (74.9%) and GPQA Diamond (89.4%).
Google (Gemini 2.5 Pro): Google's strategy is to fuse frontier reasoning with
unparalleled scale in multimodality and context length. Gemini 2.5 Pro is designed to
reason natively across text, images, audio, and video within a single, unified MoE
architecture. Its most prominent features are a massive one-million-token context
window and "Configurable Thinking Budgets," enabling the analysis of entire books or
large codebases in a single prompt. It is highly competitive on frontier benchmarks,
scoring 86.4% on GPQA Diamond and 88.0% on the AIME 2025 math competition.
Anthropic (Claude 4 Opus): Anthropic has positioned Claude 4 Opus as a specialist
optimized for reliability and superior performance on complex, long-horizon agentic
workflows, such as autonomous research and multi-file code refactoring. It features a
"hybrid reasoning" model with an "extended thinking" mode and controllable "thinking
budgets" of up to 64,000 tokens. This focus has made Claude a leader on demanding
coding benchmarks, achieving an industry-leading 79.4% on SWE-bench Verified in its
high-compute version.
xAI (Grok 3): xAI's Grok 3 centers its value proposition on real-time information
access (drawing from the X platform) and transparent, verifiable reasoning. Its
signature "DeepSearch" feature provides users with a step-by-step breakdown of the
model's reasoning process. "Big Brain Mode" leverages additional computational
resources from its massive training cluster to tackle multi-step problems. It has
demonstrated impressive performance, surpassing a 1400 ELO rating on the Chatbot
Arena and achieving high scores on math benchmarks.

The Open-Source Insurgency


Meta (Llama 4 Series): The Llama 4 family marks a significant leap forward, with a
strategic focus on native multimodality and extreme-scale context windows. The
architectural innovation of "early fusion" allows models to be pre-trained jointly on
text and vision tokens. The series includes Llama 4 Scout, which features an
unprecedented 10-million-token context window, and Llama 4 Maverick, an efficient
400B-parameter MoE model designed for a best-in-class performance-to-cost ratio.
DeepSeek (R1 Series): DeepSeek has open-sourced its R1 series, whose
performance rivals top proprietary systems. The foundational model,
DeepSeek-R1-Zero, was uniquely trained exclusively with reinforcement learning,
without initial supervised fine-tuning (SFT). The latest iteration, DeepSeek R1-0528, is
a massive 685-billion-parameter MoE model that achieves "Claude-level"
performance with a fraction of the computational resources of competitors, scoring
87.5% on AIME 2025.
Specialized Contenders: Qwen and Kimi​
Further diversifying the landscape, Alibaba's Qwen series includes dedicated
reasoning models like QwQ-32B, which perform best when prompted to generate
explicit <think> tags. Moonshot AI's Kimi k1.5 is another formidable open-weight
contender that has achieved SOTA performance through a novel, simplistic RL
framework that scales with long context windows, matching OpenAI's o1 on
benchmarks like AIME (77.5%).
OpenAI GPT-Oss: OpenAI's GPT-oss series, comprising the 20b and 120b models,
significantly contributes to the open-source landscape. These highly efficient
Mixture-of-Experts (MoE) systems are specifically optimized for agentic workflows,
complex tool use, and controllable Chain-of-Thought reasoning. Their 131k token
context window, extended by the YaRN technique, is managed through the unique
"Harmony" hierarchical prompt format. A key defining feature of GPT-oss is OpenAI's
strong emphasis on safety. This is achieved through a "deliberative alignment"
process, which trains the models to reason about safety rules before generating
responses. This focus on verifiable reasoning and safety, combined with
memory-efficient quantization-aware training, positions GPT-oss as a powerful
contender for reliable and controllable agentic performance. The Harmony prompt
format itself is a specialized chat template for GPT-oss. It is designed to facilitate
complex agentic workflows, including tool use and structured outputs, by employing a
hierarchical structure with distinct roles (e.g., developer, system, tool) and separate
output channels for reasoning, tool calls, and the final user-facing answer. This
organization enables the model to effectively manage and prioritize different types of
instructions and outputs within a single, coherent conversation.

GPT-5 Pro (OpenAI, Proprietary)
● Architecture Highlights: Unified system with reasoning, non-reasoning, and router models.
● Max Context Window: 400,000 tokens
● Key Reasoning Feature(s): Controllable reasoning_effort parameter; "Thinking" and "Pro" modes for deep reasoning.

Gemini 2.5 Pro (Google, Proprietary)
● Architecture Highlights: Sparse MoE; native multimodality (text, image, audio, video).
● Max Context Window: 1,000,000+ tokens
● Key Reasoning Feature(s): "Configurable Thinking Budgets"; deep integration with search for grounding.

Claude 4.1 Opus (Anthropic, Proprietary)
● Architecture Highlights: Hybrid reasoning model optimized for agentic tasks.
● Max Context Window: 200,000 tokens
● Key Reasoning Feature(s): "Extended Thinking" with controllable "thinking budgets" up to 64K tokens.

Grok 3 (xAI, Proprietary)
● Architecture Highlights: Trained on a massive GPU cluster with real-time web access.
● Max Context Window: Not specified
● Key Reasoning Feature(s): "DeepSearch" for transparent reasoning; "Big Brain Mode" for extra compute.

Llama 4 Scout (Meta, Open-Weight)
● Architecture Highlights: Native multimodality with "early fusion" architecture.
● Max Context Window: 10,000,000 tokens
● Key Reasoning Feature(s): Industry-leading context length for massive document analysis.

DeepSeek R1-0528 (DeepSeek, Open-Weight)
● Architecture Highlights: Sparse MoE with 685B total parameters.
● Max Context Window: 128,000 tokens
● Key Reasoning Feature(s): Training methodology heavily reliant on Reinforcement Learning (RL).

DeepSeek-Coder-RL (DeepSeek, Open-Weight)
● Architecture Highlights: Trained using a Pure Reinforcement Learning approach.
● Max Context Window: Not specified
● Key Reasoning Feature(s): Emergent reasoning; spontaneously generated chain-of-thought steps without SFT, demonstrating an "Aha!" moment.

Kimi k1.5 (Moonshot AI, Open-Weight)
● Architecture Highlights: Multimodal model trained with a simplistic RL framework.
● Max Context Window: 128,000 tokens
● Key Reasoning Feature(s): Reasoning capabilities scale with context length, avoiding complex search algorithms.

Qwen QwQ-32B (Alibaba, Open-Weight)
● Architecture Highlights: Transformer with RoPE and SwiGLU enhancements.
● Max Context Window: 131,072 tokens
● Key Reasoning Feature(s): Specialized model that performs best when prompted to generate <think> tags.

GPT-oss 20b/120b (OpenAI, Open-Weight)
● Architecture Highlights: MoE architecture with quantization-aware training; "Harmony" prompt format for agentic tasks.
● Max Context Window: 131,072 tokens
● Key Reasoning Feature(s): "Deliberative Alignment" for safety reasoning; controllable effort levels (low, medium, high).

Benchmarking Reasoning Capabilities


The rapid evolution of LRMs has necessitated a parallel evolution in how their
cognitive abilities are measured. While a canon of established benchmarks provides a
quantitative basis for comparison, there is a growing recognition that these tests may
not fully capture the nuances of true, robust reasoning.

The Established Canon and Its Limitations


A suite of standardized tests is commonly used to evaluate LRMs, including MMLU
(general knowledge), GSM8K (grade-school math), AIME (competition-level math),
HumanEval (basic coding), and SWE-bench (real-world software engineering).
Despite their widespread use, there is a growing consensus that high scores may not
be synonymous with genuine reasoning ability. Key limitations include the risk of data
contamination (where models memorize answers from the training data) and a
tendency to reward pattern matching over true abstract insight. A June 2025 study
from Apple Machine Learning Research found that even frontier LRMs suffer a
complete "accuracy collapse" beyond a certain threshold of compositional
complexity, suggesting their performance may be an "illusion of thinking." In
response, the community is developing new benchmarks to probe deeper cognitive
abilities like abstract, causal, and robust reasoning.

The 2025 Leaderboard


Synthesizing performance data reveals a tight race at the top, with different models
excelling in different domains.
●​ General & Scientific Reasoning: OpenAI's GPT-5 Pro (89.4% on GPQA
Diamond) and Google's Gemini 2.5 Pro (86.4%) are nearly tied at the apex.
●​ Advanced Mathematics: Gemini 2.5 Pro (88.0% on AIME 2025) and DeepSeek
R1-0528 (87.5%) are the top performers.
●​ Coding & Software Engineering: Anthropic's Claude 4.1 Opus (79.4% on
SWE-bench High-Compute) and GPT-5 Pro (74.9%) are the clear leaders.
●​ Human Preference: Google's Gemini 2.5 Pro and xAI's Grok 3 have consistently
held top spots in the LMArena ELO rating, suggesting their outputs are often
perceived as more helpful or engaging.

All scores are percentages unless noted; LMArena ELO is a rating.

● GPT-5 Pro: GPQA Diamond 89.4 | AIME 2025 98.0 (est., w/ tools) | SWE-bench Verified 74.9 | MMLU 87.0 | LMArena ELO ~1350+
● Gemini 2.5 Pro: GPQA Diamond 86.4 | AIME 2025 88.0 | SWE-bench Verified 67.2 | MMLU 88.5 | LMArena ELO ~1400+
● Claude 4.1 Opus: GPQA Diamond 79.6 (83.3 HC) | AIME 2025 75.5 (90.0 HC) | SWE-bench Verified 74.5 (79.4 HC) | MMLU 88.8 | LMArena ELO ~1300+
● Grok 3: GPQA Diamond ~85 (est.) | AIME 2025 93.0 | SWE-bench Verified ~70 (est.) | MMLU ~87 (est.) | LMArena ELO ~1400+
● DeepSeek R1-0528: GPQA Diamond 81.0 | AIME 2025 87.5 | SWE-bench Verified 57.6 | MMLU ~87 (est.) | LMArena ELO ~1300+
● Kimi k1.5: GPQA Diamond N/A | AIME 2025 77.5 | SWE-bench Verified 47.3 (short-CoT) | MMLU N/A | LMArena ELO N/A
● Llama 4 Maverick: GPQA Diamond 69.8 | AIME 2025 N/A | SWE-bench Verified N/A | MMLU 80.5 (MMLU-Pro) | LMArena ELO 1417 (exp.)

Notes: "HC" denotes High-Compute mode. "est." denotes an estimate. Llama 4's
MMLU-Pro score is a harder variant. ELO ratings are dynamic.

Future Directions and Implications


The rapid ascent of LRMs is the beginning of a new trajectory for AI. The next frontier
of research is focused on overcoming current limitations, while the deployment of
these systems is already beginning to reshape industries and raise profound
questions about the future of work, economics, and human cognition.

Next-Generation Reasoning Paradigms


To move beyond the current state of the art, researchers are exploring several
promising avenues. One of the most significant is the push towards Neuro-Symbolic
AI, which seeks to combine the pattern-recognition strengths of neural networks with
the logic and structure of symbolic reasoning systems, often by integrating
Knowledge Graphs (KGs) to improve factual accuracy and reduce hallucinations.
Researchers are also developing more dynamic reasoning frameworks beyond CoT
and ToT, such as Simulated Multi-Agent Debate and Adversarial Self-Critique.
Ultimately, these efforts are part of a broader pursuit of "System-2" cognition—the
slow, deliberate, and logical thinking characteristic of human deliberation, as opposed
to the fast, intuitive "System 1" thinking of current LLMs.

Neuro-Symbolic AI: Bridging Logic and Intuition


Neuro-Symbolic AI aims to create a more robust and trustworthy form of intelligence
by combining two different AI paradigms:
1.​ Neural Networks (The "Neuro" part): These are the foundation of modern
LLMs. They excel at pattern recognition, understanding natural language, and
making intuitive leaps. This is AI's version of "System 1" thinking.
2.​ Symbolic Reasoning (The "Symbolic" part): This is a more classic form of AI
that operates on explicit rules, logic, and structured knowledge. It's precise,
verifiable, and excellent at tasks that require formal reasoning. This is "System
2" thinking.

The goal is to get the best of both worlds. The neural network understands the messy,
ambiguous real world, while the symbolic system provides a backbone of hard facts
and logical rules, improving factual accuracy and dramatically reducing the chance of
"hallucination."

A key technology for implementing this is the Knowledge Graph (KG). A KG is a structured database of facts and the relationships between them. For instance, a KG
would store information not as a sentence, but as a connection: (London) --is
capital of--> (England) and (England) --is part of--> (United
Kingdom).

Example: Using a Knowledge Graph

Let's say you ask a complex, multi-part question.

●​ Prompt: "Which UK city, known for a band called The Beatles, is located on the
River Mersey?"

A standard LLM might try to guess based on statistical associations. A Neuro-Symbolic model would work differently:

1.​ Neuro (LLM): The model first parses the question to identify the key entities
and relationships needed: (City in UK), (Home of The Beatles),
(Located on River Mersey).
2.​ Symbolic (KG Query): The system translates these needs into formal queries
on its Knowledge Graph.
○​ Query 1: find city WHERE band = 'The Beatles'. Result:
Liverpool.
○​ Query 2: verify city = 'Liverpool' AND location = 'River
Mersey'. Result: True.
○​ Query 3: verify city = 'Liverpool' AND country = 'UK'.
Result: True.
3.​ Neuro (LLM): With the facts verified by the KG, the LLM confidently generates the final, correct answer: "The city you're looking for is Liverpool." (A toy version of these query steps is sketched below.)
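A toy illustration of the symbolic query step, with the knowledge graph reduced to a hand-written list of (subject, relation, object) triples; a production system would use a real graph store and a query language such as SPARQL:

# A tiny in-memory "knowledge graph" of (subject, relation, object) triples.
TRIPLES = [
    ("The Beatles", "formed_in", "Liverpool"),
    ("Liverpool", "located_on", "River Mersey"),
    ("Liverpool", "country", "UK"),
    ("London", "is_capital_of", "England"),
]

def query(subject=None, relation=None, obj=None):
    # Return all triples matching the (possibly partial) pattern.
    return [t for t in TRIPLES
            if (subject is None or t[0] == subject)
            and (relation is None or t[1] == relation)
            and (obj is None or t[2] == obj)]

# Query 1: which city is home to The Beatles?
city = query(subject="The Beatles", relation="formed_in")[0][2]  # 'Liverpool'
# Queries 2 and 3: verify the remaining constraints from the prompt.
on_mersey = bool(query(subject=city, relation="located_on", obj="River Mersey"))
in_uk = bool(query(subject=city, relation="country", obj="UK"))
print(city, on_mersey, in_uk)  # Liverpool True True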

Dynamic Reasoning Frameworks


While Chain-of-Thought and Tree of Thoughts are powerful, they typically involve a
single "mind" working on a problem. The next frontier involves creating more dynamic,
multi-agent systems that simulate debate and critique to arrive at more robust
conclusions.

Simulated Multi-Agent Debate

This technique involves setting up multiple AI agents to debate a topic from different
perspectives before reaching a conclusion. It's like an expert panel discussion where
each member is tasked with defending a specific viewpoint. 🤝
This process helps uncover biases, explore a wider range of solutions, and stress-test
an argument before it's finalized.

Example: A Policy Debate

●​ Prompt: "Should the UK government ban the sale of petrol cars by 2035?"

Instead of one answer, multiple agents are activated:

●​ Agent 1 (Economic Advisor): "A ban could disrupt the automotive industry,
leading to job losses. We should consider a slower phase-out with more
subsidies for electric vehicle (EV) manufacturing."
●​ Agent 2 (Environmental Scientist): "The climate crisis is urgent. A 2035 ban is
essential to meet our emissions targets. The environmental benefits far
outweigh the short-term economic disruption."
●​ Agent 3 (Urban Planner): "The focus shouldn't just be on cars, but on
infrastructure. Without a massive, rapid expansion of public EV charging
stations, a ban will be unworkable and unfair to people living in apartments."
●​ Final Judge Agent (Synthesizer): After reviewing the debate, it produces a
balanced conclusion: "A 2035 ban on petrol cars is a necessary environmental
goal, but its success hinges on a comprehensive government strategy. This
strategy must include economic support for the automotive industry's
transition, as well as a major public investment in a nationwide EV charging
network to ensure equitable access for all citizens."

Adversarial Self-Critique
This framework forces a model to become its own "devil's advocate." The model
generates an initial answer, and is then re-prompted to act as a harsh critic,
identifying any flaws, biases, or logical gaps in its own reasoning. Finally, it integrates
the critique to produce a much-improved final answer.
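The whole loop can be driven by three prompts to the same model. A minimal sketch, assuming a generic llm(prompt) function that returns the model's text response:

def adversarial_self_critique(llm, question):
    # Step 1: produce an initial answer.
    answer = llm(f"Answer the following question:\n{question}")
    # Step 2: switch roles and attack the answer as a harsh critic.
    critique = llm(
        "Act as a harsh critic. List every flaw, bias, or logical gap in this answer.\n\n"
        f"Question: {question}\n\nAnswer: {answer}"
    )
    # Step 3: integrate the critique into a revised final answer.
    return llm(
        f"Question: {question}\n\nDraft answer: {answer}\n\nCritique: {critique}\n\n"
        "Write an improved final answer that addresses every point in the critique."
    )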

Example: Evaluating a Business Strategy

●​ Prompt: "Is opening a high-end coffee shop in Shoreditch, London a good business idea?"

The process unfolds in three steps:

1.​ Step 1: Initial Answer: "Yes, it's a great idea. Shoreditch has a wealthy,
trend-conscious demographic that appreciates artisanal products. The high
foot traffic and existing cafe culture suggest a ready market for a premium
coffee experience."
2.​ Step 2: Self-Critique (as an adversary): "This analysis is overly optimistic. It
ignores that the Shoreditch market is already extremely saturated with coffee
shops, leading to intense competition. It also fails to consider the sky-high
commercial rent prices in the area, which would put immense pressure on
profit margins. The 'trend-conscious' demographic is also fickle and may not
grant loyalty easily."
3.​ Step 3: Revised Final Answer: "Opening a high-end coffee shop in Shoreditch
presents both a significant opportunity and a substantial risk. While the local
demographic is an ideal target market, any new business must contend with
extreme market saturation and high operational costs. To succeed, the
business would need a truly unique selling proposition (USP)—such as
exclusive coffee bean sourcing, an innovative in-store experience, or strong
community partnerships—to stand out from established competitors and justify
its premium pricing."

The rapid evolution from predictive language models to deliberative reasoning systems marks a fundamental shift in artificial intelligence. Future advancements like
neuro-symbolic hybrids and multi-agent debate frameworks are pushing these
systems towards more robust, human-like "System-2" cognition. As these powerful
tools integrate into society, they present a critical dilemma: whether they will augment
our own intelligence or lead to cognitive atrophy through over-reliance. A significant
risk is that economic pressures will favor automation over augmentation, potentially
creating AI reasoning so advanced and alien that its decisions become impossible to
align with human values. Ultimately, navigating this future requires guiding AI's
development not just for capability, but to ensure it remains a transparent and
beneficial partner in human progress.

Key Takeaways
●​ The Paradigm Shift to Reasoning: The most advanced AI is moving from
generative Large Language Models (LLMs), which excel at pattern matching, to
deliberative Large Reasoning Models (LRMs), which are engineered for
complex, multi-step problem-solving.
●​ Reasoning is Computationally Expensive: LRMs are not a universal solution.
Their deep, step-by-step processing makes them slower and more costly to
run, so they are best applied selectively to complex tasks where this "thinking"
is beneficial.
●​ Two Paths to Better Reasoning: Performance is improved through two main
avenues: Inference-Time Scaling (like Chain-of-Thought and Tree of
Thoughts), which uses more compute on-the-fly without retraining the model,
and Improved Training Methodologies, which build reasoning capabilities
directly into the model's architecture.
●​ The Gold Standard Training Method: The most powerful reasoning models
today are built using a hybrid approach of Supervised Fine-Tuning +
Reinforcement Learning (SFT + RL), often using AI-generated feedback
(RLAIF) for scalable alignment.
●​ Distillation Democratizes Reasoning: Pure SFT and Distillation is a key
technique for creating smaller, efficient models. A large "teacher" model
generates high-quality reasoning examples, which are then used to train a
smaller "student" model, making powerful AI accessible on less powerful
hardware.
●​ Architectural and Strategic Divergence: Sparse Mixture-of-Experts (MoE)
is the key architecture enabling models with trillions of parameters to be
computationally feasible. A strategic split has emerged: proprietary labs
(OpenAI, Google) are focusing on scaling up inference-time compute ("thinking
time"), while the open-source community (Meta, DeepSeek) is leveraging MoE
to scale model size and efficiency.
●​ Alignment is Crucial: As models become more powerful reasoners, ensuring
their outputs are correct, helpful, and safe is a critical challenge. Techniques
like RLHF and the more scalable RLAIF are essential for aligning model
behavior with human values.
●​ Benchmarks Have Limitations: While standard benchmarks (MMLU,
SWE-bench, etc.) provide a way to rank models, there's growing concern that
they may reward pattern matching over true understanding, creating an
"illusion of thinking."
●​ The Future is Economic and Cognitive: The rise of LRMs is set to automate
high-level knowledge work and power an "autonomous agent economy." This
poses profound questions about the future of human-AI collaboration and the
risk of cognitive offloading versus the potential for intelligence augmentation.

Conclusion
The middle of 2025 marked a paradigm shift in artificial intelligence with the advent of
Large Reasoning Models. This period ushered in a significant transition from AI
primarily focused on generative prediction to systems capable of deliberative
reasoning, unlocking an unprecedented range of capabilities. Leading AI research
institutions like OpenAI, Google, and Anthropic, alongside a vibrant open-source
community, have demonstrated models that can effectively address highly complex
challenges across diverse domains, including scientific discovery, intricate coding
tasks, and abstract logical problems.

This significant leap forward is built upon the maturation and widespread adoption of
several cutting-edge techniques. Methodologies such as Chain-of-Thought
prompting, which allows models to articulate their step-by-step reasoning process,
have been crucial in enhancing their problem-solving abilities. Reinforcement
Learning from AI Feedback (RLAIF), a sophisticated training paradigm, enables
models to refine their outputs based on iterative feedback, leading to more robust
and accurate reasoning. Furthermore, Test-Time Computation, where models engage
in additional processing during inference to improve accuracy, has become standard
practice. These techniques, among others, are rapidly establishing themselves as the
new benchmark for frontier AI development, pushing the boundaries of what is
possible.
While these advanced models are already initiating profound transformations in
high-stakes industries, ranging from healthcare and finance to engineering and legal
services, and simultaneously laying the groundwork for an economy increasingly
powered by autonomous agents, their formidable capabilities present a dual nature.
The immense power they wield, while offering unparalleled opportunities for progress
and efficiency, also introduces significant ethical, societal, and existential
considerations that demand careful navigation and robust regulatory frameworks.

The evolution from generative language models to deliberative reasoning engines represents a pivotal moment in the history of artificial intelligence. This transformation
is not driven by a single breakthrough but by the convergence of sophisticated
reasoning frameworks like Tree of Thoughts, innovative architectural designs like
Sparse Mixture-of-Experts, and advanced alignment techniques such as RLAIF. The
current landscape is defined by a dynamic competition between proprietary titans
scaling up inference-time "thinking" and an open-source insurgency scaling
architectural efficiency to democratize access.

This technological leap extends far beyond academic benchmarks; it is fundamentally reshaping the capabilities of AI, paving the way for an autonomous agent economy
and transforming high-level knowledge work. As we stand at this new frontier, the
paramount challenge is twofold: to continue pushing the boundaries of machine
cognition and, more importantly, to steer these powerful systems with wisdom. The
future will be defined not just by the reasoning power we build, but by our ability to
ensure it is aligned, safe, and ultimately serves to augment and elevate human
potential.

References
1.​ Towards Large Reasoning Models: A Survey, 2025,
https://arxiv.org/abs/2501.09686
2.​ Benchmarking LLMs on Advanced Mathematical Reasoning - UC Berkeley EECS,
https://www2.eecs.berkeley.edu/Pubs/TechRpts/2025/EECS-2025-121.pdf
3.​ What is chain of thought (CoT) prompting? - IBM
https://www.ibm.com/think/topics/chain-of-thoughts
4.​ RLAIF: Scaling Reinforcement Learning from AI feedback | Encord
https://encord.com/blog/reinforecement-learning-from-ai-feedback-what-is-rlaif/
5.​ RLHF-Aligned Open LLMs: A Comparative Survey,
https://www.preprints.org/manuscript/202506.2381/v1
6.​ RLAIF: Scaling Reinforcement Learning from Human Feedback with AI...,
https://openreview.net/forum?id=AAxIs3D2ZZ
7.​ RLAIF vs. RLHF: Scaling Reinforcement Learning from Human Feedback with AI
Feedback, https://arxiv.org/html/2309.00267v3
8.​ Scaling LLM Test Time Compute,
https://www.jonvet.com/blog/llm-test-time-compute
9. Study could lead to LLMs that are better at complex reasoning | MIT,
https://news.mit.edu/2025/study-could-lead-llms-better-complex-reasoning-0708
10.​System-performance and cost modeling of Large Language ... - arXiv,
https://arxiv.org/html/2507.02456
11.​ Understanding Compute-Parameter Trade-offs in Sparse Mixture-of-Experts
Language Models - arXiv, https://arxiv.org/html/2501.12370
12.​A Survey on Mixture of Experts in Large Language Models
https://arxiv.org/pdf/2407.06204
13. Claude 4 Opus vs Llama 4 Maverick | AIModels.fyi,
https://www.aimodels.fyi/compare/claude-4-opus-vs-llama-4-maverick
14. DeepSeek's New R1–0528: Performance Analysis and Benchmark ...,
https://medium.com/@leucopsis/deepseeks-new-r1-0528-performance-analysis-and-benchmark-comparisons-6440eac858d6
15.​OpenAI o1 System Card | OpenAI,
https://openai.com/index/openai-o1-system-card/
16.​Introducing GPT‑5 for developers | OpenAI,
https://openai.com/index/introducing-gpt-5-for-developers/
17.​Gemini 2.5: Pushing the Frontier with Advanced ... -
https://storage.googleapis.com/deepmind-media/gemini/gemini_v2_5_report.pdf
18.​Claude Opus 4.1 \ Anthropic, https://www.anthropic.com/claude/opus
19. Grok 3: Everything you need to know about this new LLM by xAI,
https://daily.dev/blog/grok-3-everything-you-need-to-know-about-this-new-llm-by-xai
20. DeepSeek-R1 vs Grok-3: What Can They Tell About AI Scaling?
https://www.counterpointresearch.com/insight/post-insight-research-notes-blogs-deepseekr1-vs-grok3-what-can-they-tell-about-ai-scaling
21.​Llama: Industry Leading, Open-Source AI, https://www.llama.com/
22.​Qwen 3 vs Kimi K2 : AI Model Precision vs Versatility, Who Wins ...,
https://www.geeky-gadgets.com/qwen-3-vs-kimi-k2-ai-models-compared/
23.​Kimi k1.5: Scaling Reinforcement Learning with LLMs,
http://arxiv.org/pdf/2501.12599
24. Top LLM Benchmarks Explained: MMLU, HellaSwag, BBH, and ...,
https://www.confident-ai.com/blog/llm-benchmarks-mmlu-hellaswag-and-beyond
25.​System-2 Reasoning at Scale - NeurIPS 2025,
https://neurips.cc/virtual/2024/workshop/84749
26.​Reasoning best practices - OpenAI API,
https://platform.openai.com/docs/guides/reasoning-best-practices
27. The Rise of AI Agents: Impacts on Markets, Productivity, and Investment Strategy,
https://startfastventures.com/the-rise-of-ai-agents-impacts-on-markets-productivity-and-investment-strategy/


Transformer: the Code


Introduction
This Python script provides a complete, self-contained implementation of a
character-level language model. It uses a GPT-style, decoder-only Transformer
architecture. The entire system is built using JAX for high-performance numerical
computing and Flax as the neural network library.

The script's primary goal is to train this model on the "Tiny Shakespeare" dataset.
After training, the model can auto-regressively generate new text one character at a
time, mimicking the style of Shakespeare. It is an excellent, modern example of
implementing a foundational language model from scratch.

Code Breakdown
The code is logically structured into model definition, training logic, data helpers, and
a main execution block.

Initialization and Hyperparameters


This initial section defines all the key hyperparameters for the model and the training
process.
●​ Model Architecture: N_EMBED, NUM_HEADS, NUM_BLOCKS define the size and
complexity of the Transformer. BLOCK_SIZE sets the context window, i.e., how
many previous characters the model can "see" when predicting the next one.
●​ Training Parameters: BATCH_SIZE, LEARNING_RATE, and TRAIN_STEPS
control the optimization process.
●​ Evaluation & Generation: EVAL_INTERVAL, EVAL_STEPS, and
GENERATION_TOKENS control how often the model is evaluated and how much
text it generates at the end.

# -*- coding: utf-8 -*-


"""
A character-level GPT-style language model implemented in JAX/Flax.
This script downloads the Tiny Shakespeare dataset, defines a
decoder-only Transformer, and trains it to generate text.
"""
import os
import requests
from functools import partial
from typing import Dict, List, Tuple, Callable, NamedTuple

import numpy as np
import jax
import jax.numpy as jnp
import chex
import flax.linen as nn
import optax
from flax.training import train_state

# --- Hyperparameters ---


# N_EMBED: Embedding dimension.
# BATCH_SIZE: Number of parallel sequences per batch.
# BLOCK_SIZE: Sequence length (context size).
# NUM_HEADS: Number of attention heads.
# NUM_BLOCKS: Number of Transformer blocks.
# LEARNING_RATE: AdamW optimizer learning rate.
# TRAIN_STEPS: Total number of training iterations.
# EVAL_INTERVAL: How often to run evaluation.
N_EMBED = 64
BATCH_SIZE = 64
BLOCK_SIZE = 64
NUM_HEADS = 4
NUM_BLOCKS = 4
LEARNING_RATE = 3e-4
TRAIN_STEPS = 5000
EVAL_INTERVAL = 250
EVAL_STEPS = 100
GENERATION_TOKENS = 200
Model Definition
This part defines the building blocks of the Transformer model using Flax's
nn.Module. The FeedForward class implements the position-wise feed-forward
network (FFN) found in each Transformer block. It consists of two linear layers with a
GELU activation function in between. The standard practice, followed here, is to
expand the inner dimension to 4 times the embedding size (4 * self.n_embed) and
then project it back.

# --- Model Definition ---

class FeedForward(nn.Module):
"""A simple feed-forward network with GELU activation."""
n_embed: int

@nn.compact
def __call__(self, x: chex.Array) -> chex.Array:
"""
Applies a two-layer MLP as per "Attention is All You Need".

Args:
x: Input tensor of shape (B, T, C).
Returns:
Output tensor of shape (B, T, C).
"""
chex.assert_shape(x, (None, None, self.n_embed))
# The 4*n_embed expansion is a standard practice.
net = nn.Sequential([
nn.Dense(4 * self.n_embed),
nn.gelu,
nn.Dense(self.n_embed),
])
return net(x)

Multi Headed Attention

The MultiHeadedAttention class implements the core self-attention mechanism. It leverages Flax's built-in nn.MultiHeadDotProductAttention for an efficient
implementation. The most critical line is nn.make_causal_mask(...). This creates
a mask that prevents any token from attending to future tokens in the sequence,
which is essential for a decoder-only language model that predicts one token at a
time.

class MultiHeadedAttention(nn.Module):
"""Vectorized Multi-Head Self-Attention."""
num_heads: int
n_embed: int

@nn.compact
def __call__(self, x: chex.Array) -> chex.Array:
"""
Performs multi-head self-attention with causal masking.

Args:
x: Input tensor of shape (B, T, C).
Returns:
Output tensor of shape (B, T, C).
"""
B, T, C = x.shape
chex.assert_equal(C, self.n_embed)

# Create a causal mask to prevent attention to future tokens.


# Shape: (B, 1, T, T) for broadcasting over heads.
causal_mask = nn.make_causal_mask(jnp.ones((B, T)))

return nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
qkv_features=self.n_embed,
)(inputs_q=x, mask=causal_mask)

Transformer Block

The Block class assembles one complete Transformer block. It combines the
MultiHeadedAttention and FeedForward sub-layers. Notably, it employs
pre-layer normalization (applying nn.LayerNorm before the attention/FFN layers)
and residual connections (x = x + ...). This architecture (LayerNorm ->
Sub-layer -> Residual Add) is known to improve training stability compared to
the original "post-norm" design.

class Block(nn.Module):
"""A Transformer block using pre-layer normalization."""
n_embed: int
num_heads: int

@nn.compact
def __call__(self, x: chex.Array) -> chex.Array:
"""
Applies attention and MLP with residual connections.
Uses pre-layer normalization for better training stability.

Args:
x: Input tensor of shape (B, T, C).
Returns:
Output tensor of shape (B, T, C).
"""
chex.assert_shape(x, (None, None, self.n_embed))

# First residual connection (Multi-Head Attention)


attn_out = MultiHeadedAttention(
num_heads=self.num_heads, n_embed=self.n_embed
)(nn.LayerNorm()(x))
x = x + attn_out

# Second residual connection (Feed-Forward Network)


ffwd_out = FeedForward(n_embed=self.n_embed)(
nn.LayerNorm()(x)
)
x = x + ffwd_out

chex.assert_shape(x, (None, None, self.n_embed))


return x

Attention Language Model

The AttentionLanguageModel is the final, complete model. It orchestrates the entire forward pass:

1.​ Token Embeddings: Converts input token indices (integers) into dense vectors
using nn.Embed.
2.​ Positional Embeddings: Creates a learnable table of positional embeddings.
Each position in the sequence (from 0 to block_size - 1) gets a unique
vector, which is added to the token embedding. This gives the model
information about the order of tokens.
3.​ Transformer Blocks: Processes the combined embeddings through a
sequence of Block modules.
4.​ Final Projection: Applies a final layer normalization and a linear layer
(nn.Dense) to project the output back to the vocabulary size, producing the
final logits (raw, unnormalized predictions for the next token).

class AttentionLanguageModel(nn.Module):
"""A decoder-only Transformer language model."""
vocab_size: int
n_embed: int
block_size: int
num_heads: int
num_blocks: int

@nn.compact
def __call__(self, idx: chex.Array) -> chex.Array:
"""
Forward pass for the language model.

Args:
idx: Input sequence of token indices, shape (B, T).
Returns:
Logits for the next token, shape (B, T, vocab_size).
"""
chex.assert_rank(idx, 2)
_B, T = idx.shape

# Token embedding lookup


tok_emb = nn.Embed(
num_embeddings=self.vocab_size, features=self.n_embed
)(idx)

# Positional embedding lookup


pos_emb_table = self.param(
'pos_emb_table',
nn.initializers.normal(),
(self.block_size, self.n_embed)
)
pos_emb = pos_emb_table[jnp.arange(T)]

x = tok_emb + pos_emb # (B, T, C)

# Core transformer blocks


x = nn.Sequential([
Block(n_embed=self.n_embed, num_heads=self.num_heads)
for _ in range(self.num_blocks)
])(x)

# Final layer norm and projection to vocab


x = nn.LayerNorm()(x)
logits = nn.Dense(self.vocab_size)(x)

chex.assert_shape(
logits, (_B, T, self.vocab_size)
)
return logits

Training logic
This section defines the core functions for training and evaluating the model. The
train_step function is the heart of the training loop. It's decorated with @jax.jit,
which Just-In-Time (JIT) compiles the entire function into highly optimized machine
code (XLA), making it extremely fast.

●​ It defines a loss_fn that takes the model parameters and a data batch, runs
the model's forward pass, and computes the cross-entropy loss.
●​ It then uses jax.value_and_grad to efficiently compute both the loss and
the gradients of the loss with respect to the parameters in a single backward
pass.
●​ Finally, state.apply_gradients uses the configured optimizer (AdamW) to
update the model's parameters.

# --- Training & Evaluation Logic ---

class TrainState(train_state.TrainState):
"""A custom train state to simplify passing arguments."""
pass

@partial(jax.jit, static_argnames=('model_apply_fn',))
def train_step(
state: TrainState, batch: Tuple[chex.Array, chex.Array],
model_apply_fn: Callable
) -> Tuple[TrainState, jnp.ndarray]:
"""Performs a single, JIT-compiled training step."""
batch_x, batch_y = batch

def loss_fn(params: Dict) -> chex.Array:


"""Calculates loss for the given parameters."""
logits = model_apply_fn({'params': params}, batch_x)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch_y
).mean()
return loss

loss, grads = jax.value_and_grad(loss_fn)(state.params)


state = state.apply_gradients(grads=grads)
return state, loss

@partial(jax.jit, static_argnames=('model_apply_fn',))
def eval_step(
state: TrainState, batch: Tuple[chex.Array, chex.Array],
model_apply_fn: Callable
) -> chex.Array:
"""Performs a single, JIT-compiled evaluation step."""
batch_x, batch_y = batch
logits = model_apply_fn({'params': state.params}, batch_x)
return optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch_y
).mean()

Evaluate Logic

The evaluate_model function calculates the model's performance on the validation set. It runs the (also JIT-compiled) eval_step multiple times with different batches
of validation data and averages the loss. This provides a stable estimate of how well
the model is generalizing to unseen data.

def evaluate_model(
state: TrainState, val_data: chex.Array, data_key: chex.Array
) -> float:
"""Runs evaluation over the validation set."""
total_loss = 0.0
for i in range(EVAL_STEPS):
data_key, subkey = jax.random.split(data_key)
batch_x, batch_y = get_batch(
val_data, subkey, BATCH_SIZE, BLOCK_SIZE
)
loss = eval_step(state, (batch_x, batch_y), state.apply_fn)
total_loss += loss
return total_loss / EVAL_STEPS

Data and Generation Helpers

These helper functions manage data loading and text generation. The get_batch
function is a highly efficient data sampler. It generates random starting points in the
dataset and then uses jax.vmap to perform the slicing operation in a vectorized,
parallel manner across the batch dimension. vmap is a powerful JAX transformation
that automatically converts a function designed for a single example into one that
works on an entire batch. It creates the input x and target y by taking overlapping
slices of the data.

# --- Data & Generation Helpers ---

@partial(jax.jit, static_argnames=('batch_size', 'block_size'))


def get_batch(
data: chex.Array, key: chex.Array, batch_size: int, block_size: int
) -> Tuple[chex.Array, chex.Array]:
"""Generates a random, vectorized batch of (x, y) data."""
# Generate random start indices for each sequence in the batch
ix = jax.random.randint(
key, (batch_size,), 0, len(data) - block_size
)

# Use vmap to efficiently slice sequences in parallel


sequences = jax.vmap(
lambda i: jax.lax.dynamic_slice(
data, (i,), (block_size + 1,)
)
)(ix)

x = sequences[:, :-1]
y = sequences[:, 1:]
return x, y

The generate function implements auto-regressive text generation. Instead of a standard Python for loop, it uses jax.lax.scan, a JAX primitive for compiling loops, which is significantly more performant on accelerators like GPUs/TPUs. Because lax.scan requires the loop carry to keep a fixed shape, the context is held in a fixed-size rolling buffer of block_size tokens. In each step of the scan, it:

1. Takes the current context window of tokens.
2. Passes it to the model to get logits for the next token.
3. Samples a new token from the probability distribution defined by the logits.
4. Shifts the window left and appends the new token for the next iteration.

@partial(jax.jit, static_argnames=('apply_fn', 'max_new_tokens', 'block_size'))
def generate(
    params: Dict, apply_fn: Callable, key: chex.Array,
    max_new_tokens: int, block_size: int
) -> chex.Array:
    """Generates text from the model using a compiled lax.scan loop."""
    # lax.scan requires the carry to keep a static shape, so the context is
    # held in a fixed-size rolling buffer instead of a growing sequence.
    initial_tokens = jnp.zeros((1, block_size), dtype=jnp.uint16)

    def scan_fn(carry, _):
        """Single generation step for use in lax.scan."""
        key, current_tokens = carry

        # The buffer always holds the last `block_size` tokens of context.
        logits = apply_fn({'params': params}, current_tokens)
        logits_last = logits[:, -1, :]  # (B, vocab_size)

        # Sample the next token
        key, subkey = jax.random.split(key)
        next_token = jax.random.categorical(subkey, logits_last)

        # Shift the context left by one position and append the new token,
        # keeping the carry shape (and dtype) fixed across iterations.
        new_tokens = jnp.concatenate(
            [current_tokens[:, 1:], next_token[:, None].astype(jnp.uint16)],
            axis=1
        )
        return (key, new_tokens), next_token

    # Run the generation loop efficiently with lax.scan
    (_, _), all_new = jax.lax.scan(
        scan_fn, (key, initial_tokens), None, length=max_new_tokens
    )
    return all_new.flatten()

Main execution
This final part of the script contains the high-level orchestration logic. These
functions and classes tie together the model definition, training procedures, and data
utilities to execute the complete workflow, from data acquisition to generating novel
text.

This code defines a Data structure using typing.NamedTuple. It's not a function
but a custom data type that acts as a clean, organized container. It bundles all
essential data-related components into a single object, which improves code
readability and type safety.
●​ train and val: These hold the training and validation datasets as JAX numpy
arrays of integer token IDs.
●​ chars and vocab_size: These store the list of unique characters (the
vocabulary) and its size.
●​ encode and decode: These are lambda functions that provide the mapping
between strings of characters and lists of integer token IDs. encode performs
tokenization, and decode performs the reverse, de-tokenization.

# --- Main Execution ---

class Data(NamedTuple):
"""Typed container for prepared data."""
train: chex.Array
val: chex.Array
chars: List[str]
vocab_size: int
encode: Callable[[str], List[int]]
decode: Callable[[List[int]], str]

These two functions form the data ingestion and preprocessing pipeline.

●​ download_data: A simple utility that checks if the dataset file exists locally. If
not, it downloads the "Tiny Shakespeare" text from the provided URL.
●​ prepare_data: This function performs the critical steps to prepare the raw
text for the model. It reads the file, identifies all unique characters to build the
vocabulary, and creates two dictionaries: one for mapping characters to
integers (stoi, or "string to integer") and another for the reverse (itos). It
then tokenizes the entire dataset into a single long sequence of integers and
splits it into a 90% training set and a 10% validation set. Finally, it packages
everything into the Data named tuple described above.

def download_data(url: str, file_path: str) -> None:


"""Downloads data if it doesn't exist."""
if not os.path.exists(file_path):
print(f"Downloading data to {file_path}...")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
response = requests.get(url, timeout=10)
response.raise_for_status()
with open(file_path, 'w', encoding='utf-8') as f:
f.write(response.text)
else:
print("Data already exists.")

def prepare_data(file_path: str) -> Data:


"""Loads and preprocesses the text data."""
with open(file_path, 'r', encoding='utf-8') as f:
text_data = f.read()

chars = sorted(list(set(text_data)))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]


decode = lambda l: ''.join([itos.get(i, '') for i in l])

n = len(text_data)
train_data = text_data[:int(n * 0.9)]
val_data = text_data[int(n * 0.9):]

return Data(
train=jnp.array(encode(train_data), dtype=jnp.uint16),
val=jnp.array(encode(val_data), dtype=jnp.uint16),
chars=chars,
vocab_size=len(chars),
encode=encode,
decode=decode
)

This is a crucial setup function that initializes all the components required for training
and encapsulates them within Flax's TrainState object.

1.​ Model Instantiation: It creates an instance of the AttentionLanguageModel


class, passing in the necessary hyperparameters like vocab_size and
n_embed.
2.​ Parameter Initialization: It initializes the model's parameters (weights and
biases). This is done by calling model.init(), which requires a PRNG key for
random initialization and a dummy input batch (jnp.ones(...)). The dummy
input allows Flax to infer the correct shapes for all parameters throughout the
network automatically.
3.​ Optimizer Creation: It sets up the AdamW optimizer (optax.adamw), a
standard and effective choice for training Transformer models.
4.​ State Creation: It bundles the model's forward-pass function (model.apply),
the initialized parameters (params), and the optimizer (tx) into the
TrainState object. This state object is the single source of truth for the
model's trainable components and will be passed to and updated by the
training functions.

def create_train_state(
model_cls: nn.Module, key: chex.Array, data: Data
) -> TrainState:
"""Initializes the model and creates the training state."""
model = model_cls(
vocab_size=data.vocab_size,
n_embed=N_EMBED,
block_size=BLOCK_SIZE,
num_heads=NUM_HEADS,
num_blocks=NUM_BLOCKS
)
# Initialize parameters
init_key, key = jax.random.split(key)
params = model.init(
init_key, jnp.ones((1, BLOCK_SIZE), dtype=jnp.uint16)
)['params']

# Create optimizer
tx = optax.adamw(LEARNING_RATE)

return TrainState.create(
apply_fn=model.apply, params=params, tx=tx
)

This function contains the main training and evaluation loop. It's the engine that drives
the model's learning process.

●​ It iterates for a fixed number of steps defined by TRAIN_STEPS.


●​ In each iteration, it first splits its PRNG key to get new subkeys for data
batching and evaluation, ensuring reproducibility.
●​ It calls get_batch to sample a random mini-batch from the training data.
●​ It then executes a single optimization step by calling the JIT-compiled
train_step function. This function calculates the loss, computes gradients,
and updates the model's parameters, returning the new, updated state.
●​ Periodically (controlled by EVAL_INTERVAL), it pauses to assess the model's
generalization ability. It calls evaluate_model, which computes the average
loss on several batches from the validation set. The training and validation
losses are printed to the console, allowing the user to monitor progress and
detect potential issues like overfitting.

def run_training_loop(
state: TrainState, data: Data, key: chex.Array
) -> TrainState:
"""Executes the main training and evaluation loop."""
for step in range(TRAIN_STEPS):
key, train_key, eval_key = jax.random.split(key, 3)

# Training step
batch_x, batch_y = get_batch(
data.train, train_key, BATCH_SIZE, BLOCK_SIZE
)
state, loss = train_step(state, (batch_x, batch_y),
state.apply_fn)

# Evaluation and logging


if step % EVAL_INTERVAL == 0 or step == TRAIN_STEPS - 1:
val_loss = evaluate_model(state, data.val, eval_key)
print(
f"Step {step:4d} | "
f"Train Loss: {loss:.4f} | "
f"Val Loss: {val_loss:.4f}"
)
return state

The main function serves as the script's entry point. It orchestrates the entire process
by calling the helper functions in the correct sequence. The if __name__ ==
"__main__": guard is a standard Python convention that ensures the code inside
only runs when the script is executed directly, not when it's imported as a module into
another script. The main function's workflow is:

1.​ Set up the initial JAX PRNGKey.


2.​ Call download_data and prepare_data to get the dataset.
3.​ Call create_train_state to initialize the model and optimizer.
4.​ Pass the initial state and data to run_training_loop to start training.
5.​ After training is complete, it uses the final trained parameters (state.params)
to call the generate function, produce a sample of new text, and print the
result.
def main():
"""Main function to run the data loading, training, and generation."""
file_path = './data/shakespeare_char/input.txt'
url = (
'https://raw.githubusercontent.com/karpathy/char-rnn/'
'master/data/tinyshakespeare/input.txt'
)
download_data(url, file_path)

# Setup keys for JAX's pseudo-random number generation


key = jax.random.PRNGKey(0)
model_key, train_key, gen_key = jax.random.split(key, 3)

# Prepare data, model, and state


data = prepare_data(file_path)
state = create_train_state(AttentionLanguageModel, model_key, data)

# Run training
print("Starting training...")
state = run_training_loop(state, data, train_key)
print("Training finished.")

# Generate text from the trained model


print("\n--- Generated Text ---")
generated_tokens = generate(
state.params,
state.apply_fn,
gen_key,
max_new_tokens=GENERATION_TOKENS,
block_size=BLOCK_SIZE
)
print(data.decode(generated_tokens.tolist()))
print("----------------------\n")

if __name__ == "__main__":
main()

Key Takeaways
●​ Architecture: The script implements a standard decoder-only Transformer
(GPT-style), a foundational architecture for modern large language models.
●​ JAX/Flax Ecosystem: It's a masterclass in using the JAX ecosystem effectively.
○​ @jax.jit: For Just-In-Time compilation of training and generation
functions, leading to massive speedups. ⚡
○​ jax.vmap: For automatic vectorization, enabling efficient, parallel data
batching.
○​ jax.lax.scan: For compiling loops, making the auto-regressive
generation process highly performant.
○​ jax.value_and_grad: For efficient automatic differentiation.
●​ Modern Practices: It uses pre-layer normalization, which is a key technique for
achieving stable training in deep Transformers.
●​ State Management: The use of flax.training.train_state.TrainState
provides a clean and robust way to manage all mutable aspects of training
(model parameters, optimizer state).

Conclusion
This script constitutes a complete and self-contained implementation of a
decoder-only Transformer for character-level language modeling. It effectively
utilizes the JAX and Flax libraries, demonstrating a functional programming paradigm
for neural network development. The strategic application of JAX
transformations—specifically jit for XLA compilation, vmap for implicit
vectorization, and lax.scan for efficient, stateful iteration—yields a computationally
performant solution for both model training and auto-regressive inference. The script
is a robust and didactic reference for implementing Transformer-based generative
models within the JAX ecosystem. It demonstrates clear modularity, separating the
model architecture, training procedures, and data-handling utilities.

Qwen: the Code


Introduction
To truly understand the capabilities of a sophisticated large language model like
Qwen, we must first break down the advanced techniques that form its foundation.
These architectural and computational innovations are the building blocks that enable
its remarkable performance and efficiency. We will explore six key concepts: the
Mixture of Experts (MoE) for scalable knowledge, Grouped-Query Attention (GQA)
and KV Caching for efficient memory use, Rotary Position Embeddings (RoPE) for
understanding sequence order, and RMSNorm with SwiGLU for stable and effective
computation. Each of these components plays a crucial role, tackling fundamental
challenges in building and running massive neural networks. Therefore, a clear grasp
of these individual technologies is the necessary first step before we can appreciate
how they come together in the complete Qwen model.

Mixture of Experts (MoE)


●​ The Why: Standard "dense" models get better by adding more parameters,
but this comes at a steep price: every single token requires a computation
involving all of those parameters. This makes scaling them up incredibly
expensive and slow. The goal of MoE is to drastically increase a model's
parameter count (its "knowledge") without proportionally increasing the
computational cost for each token.
●​ The What: A Mixture of Experts (MoE) layer replaces a single, large
feed-forward network with a collection of smaller "expert" networks and a
"router" network. Think of it like a large consulting firm. Instead of having one
giant meeting where every consultant (parameter) works on every client's
problem (token), you have a receptionist (the router) who quickly directs the
client to a small committee of the most relevant specialists (the experts). This
way, the firm can employ many specialists with diverse knowledge, but each
client only interacts with a few, making the process highly efficient.
●​ The How: The process happens in a few steps for each token that arrives at
the MoE layer:
1.​ Routing: The token's vector is first sent to the router network. This
router is a small, fast neural network.
2.​ Expert Selection: The router outputs a score for every expert,
indicating how suitable that expert is for the given token. A top-k
function is used to select the best few experts (e.g., the top 2).
3.​ Weighting: The scores for these selected experts are converted into
weights using a softmax function. This means the experts' final
contributions will be weighted; the #1 expert will have more influence
than the #2 expert.
4.​ Processing: The original token's vector is sent only to the selected top-k
experts. All other experts remain inactive and perform zero computation.
This is the source of the efficiency gains and is known as sparse
activation.
5.​ Aggregation: The outputs from the active experts are combined
through a weighted sum, using the weights calculated by the router. This
combined vector becomes the final output of the MoE layer.
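
To make the routing above concrete, here is a minimal, self-contained PyTorch sketch of steps 1-5 (a toy illustration with made-up sizes, not the Qwen implementation shown later in this chapter):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k, dim = 4, 2, 8
tokens = torch.randn(5, dim)                          # 5 token vectors

router = nn.Linear(dim, num_experts, bias=False)      # the "receptionist"
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])

with torch.no_grad():
    scores = router(tokens)                           # one score per expert per token
    top_scores, top_idx = torch.topk(scores, top_k, dim=-1)
    weights = F.softmax(top_scores, dim=-1)           # weights over the chosen experts

    output = torch.zeros_like(tokens)
    for t in range(tokens.shape[0]):                  # only the selected experts run
        for slot in range(top_k):
            e = top_idx[t, slot].item()
            output[t] += weights[t, slot] * experts[e](tokens[t])

print(output.shape)                                   # torch.Size([5, 8])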

Grouped-Query Attention (GQA)


●​ The Why: During text generation, the model needs to remember the "Key" (K)
and "Value" (V) vectors for all previous tokens to pay attention to them. This is
called the KV cache. In standard Multi-Head Attention (MHA), every attention
head has its own unique K and V vectors. For long text sequences, this KV
cache becomes enormous and can easily overwhelm a GPU's memory (VRAM),
creating a major bottleneck.
●​ The What: Grouped-Query Attention (GQA) is an architectural compromise
designed to reduce the memory size of the KV cache while retaining most of
the performance quality of standard Multi-Head Attention (MHA). It sits
between MHA and its more extreme cousin, Multi-Query Attention (MQA).
1.​ Multi-Head Attention (MHA): 16 Query heads have 16 unique Key/Value
head pairs. High quality, very high memory usage.
2.​ Multi-Query Attention (MQA): 16 Query heads all share just 1 Key/Value
head pair. Lower quality than MHA, but very low memory usage.
3.​ Grouped-Query Attention (GQA): 16 Query heads are put into groups
(e.g., 4 groups of 4). Each group of 4 Query heads shares 1 Key/Value
head pair. This provides a balance between quality and memory
efficiency.
●​ The How: GQA works by modifying the number of Key and Value projection
matrices.
1.​ Projection: While the model still creates many Query (Q) vectors (e.g.,
16 heads), it creates a much smaller number of Key (K) and Value (V)
vectors (e.g., 4 heads).
2.​ Grouping & Sharing: The Q heads are divided into groups. For instance,
Q heads 0-3 form Group 1, Q heads 4-7 form Group 2, and so on.
3.​ Repetition: To perform the attention calculation, the K and V vector for
Group 1 is simply duplicated or "broadcast" to be used by all four Q
heads in that group. The same happens for all other groups.
4.​ Attention Calculation: From there, the standard scaled dot-product
attention mechanism proceeds as usual. Because fewer unique K and V
vectors are generated and stored in the cache at each step, the overall
memory footprint is significantly reduced.
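
The following toy snippet illustrates the sharing step in isolation (the sizes are arbitrary; the chapter's real GroupedQueryAttention class appears later):

import torch

B, T, head_dim = 1, 10, 16
num_q_heads, num_kv_heads = 16, 4                  # 4 groups of 4 query heads
group_size = num_q_heads // num_kv_heads

q = torch.randn(B, num_q_heads, T, head_dim)
k = torch.randn(B, num_kv_heads, T, head_dim)      # only 4 K heads are cached
v = torch.randn(B, num_kv_heads, T, head_dim)      # only 4 V heads are cached

# Broadcast each K/V head to the 4 query heads in its group.
k_shared = k.repeat_interleave(group_size, dim=1)  # (B, 16, T, head_dim)
v_shared = v.repeat_interleave(group_size, dim=1)

scores = (q @ k_shared.transpose(-2, -1)) / head_dim ** 0.5
attn_out = torch.softmax(scores, dim=-1) @ v_shared
print(attn_out.shape)                              # torch.Size([1, 16, 10, 16])
print(k.numel() / (num_q_heads * T * head_dim))    # KV cache is 0.25x the MHA size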

Rotary Position Embeddings (RoPE)


●​ The Why: Transformers are "permutation-invariant," meaning if you shuffle the
words in a sentence, the core self-attention mechanism would produce the
same results. They have no inherent sense of word order. The traditional
solution was "absolute position embeddings," where a unique vector for
position 1, position 2, etc., was added to each word vector. This works, but it
can make it hard for models to handle sequences longer than those they were
trained on and is less elegant at capturing relative positions (e.g., "the word 3
steps before this one").
●​ The What: Rotary Position Embeddings (RoPE) is a clever way to encode
positional information by rotating vectors instead of adding to them. It
integrates absolute position information in a way that naturally allows the
self-attention mechanism to focus on relative positions. The key insight is that
the attention score between two tokens at positions m and n should only
depend on their content and their relative distance, m-n. RoPE achieves this
property through the mathematics of vector rotation.
●​ The How: RoPE applies a position-dependent rotation to the Query (Q) and
Key (K) vectors before they are compared in the attention mechanism.
1.​ Pairing Dimensions: The dimensions of the Q and K vectors are split into
pairs. You can imagine each pair as coordinates on a 2D plane (x, y).
2.​ Calculating Rotation: For each token at an absolute position m, a rotation
angle θ=m⋅λ is calculated. The base wavelength λ is a fixed
hyperparameter that is different for each pair of dimensions.
3.​ Applying Rotation: Each (x, y) pair in the Q and K vectors is rotated by its
corresponding angle θ. This is done using a simple rotation matrix
multiplication.
4.​ Relative Attention: When the model calculates the attention score by
taking the dot product of a rotated Q vector (from position m) and a
rotated K vector (from position n), the properties of rotations cause the
absolute position terms to cancel out, leaving a result that is dependent
only on the tokens' content and their relative distance m-n. This allows
the model to generalize to different sequence lengths more effectively.
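
A quick numerical check of this relative-position property, using a single rotated 2-D pair and an arbitrary wavelength (a standalone toy; the model's actual compute_rope_params and apply_rope functions appear later in the code breakdown):

import math
import torch

def rotate(vec: torch.Tensor, pos: int, wavelength: float = 0.5) -> torch.Tensor:
    """Rotate a 2-D vector by the position-dependent angle theta = pos * wavelength."""
    theta = pos * wavelength
    rot = torch.tensor([[math.cos(theta), -math.sin(theta)],
                        [math.sin(theta),  math.cos(theta)]])
    return rot @ vec

q = torch.tensor([1.0, 2.0])
k = torch.tensor([0.5, -1.0])

# Same relative distance (m - n = 2) at different absolute positions
score_a = rotate(q, 3) @ rotate(k, 1)
score_b = rotate(q, 10) @ rotate(k, 8)
print(torch.allclose(score_a, score_b))   # True: only the relative offset matters
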
RMSNorm (Root Mean Square Normalization)
●​ The Why: Deep neural networks need normalization to stabilize the training
process by keeping the values of activations in a consistent range. The
standard, Layer Normalization, works well but involves calculating both the
mean and variance, which adds a bit of computational overhead. The goal of
RMSNorm is to achieve the stabilization benefits of normalization while being
simpler and faster to compute.
●​ The What: RMSNorm is a simplified version of Layer Normalization. It
normalizes the activations of a layer by their Root Mean Square (RMS) value.
Crucially, unlike LayerNorm, it does not re-center the data by subtracting the
mean. This simplification makes it computationally more efficient while still
providing the necessary stability for training and inference.
●​ The How: The process is a straightforward mathematical operation on an input
vector x:
1.​ Square and Mean: Calculate the mean of the squares of all elements in
the vector.
2.​ Root: Take the square root of that mean. This gives you the RMS value.
3.​ Normalize: Divide the original input vector x by its RMS value.
4.​ Scale: Multiply the result by a learnable "gain" parameter g. This allows
the network to scale the normalized output as needed during training.
●	The formula is:
y = x / sqrt( (1/n) Σ xᵢ² + ε ) · g
Where ε (epsilon) is a very small number added for numerical stability to prevent division by zero.
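
A tiny numeric check of this formula with plain tensor operations (a sketch only; the model's actual RMSNorm module appears later in the code breakdown):

import torch

x = torch.tensor([3.0, -4.0, 12.0])
g = torch.ones(3)                           # learnable gain, initialized to 1
eps = 1e-6

rms = torch.sqrt(x.pow(2).mean() + eps)     # root mean square of the activations
y = x / rms * g
print(rms)                                  # tensor(7.5056): sqrt((9 + 16 + 144) / 3)
print(y)                                    # x rescaled to unit RMS, same direction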

SwiGLU (Swish-Gated Linear Unit)


●​ The Why: The Feed-Forward Network (FFN) is a critical part of a transformer
block that adds non-linear complexity, allowing the model to learn more
intricate patterns. While the simple ReLU activation function is common,
researchers found that more sophisticated activation functions could lead to
better model performance. The goal of SwiGLU is to create a more expressive
FFN layer by using a gating mechanism to control the flow of information.
●​ The What: SwiGLU is a variant of a Gated Linear Unit (GLU) that uses the
Swish activation function. Think of it as a smart valve. The input is split into two
paths; one path becomes the "gate" that decides how much of the information
from the other path should be allowed to pass through. This gating mechanism
allows the network to learn more complex relationships in the data compared to
a simple, non-gated activation.
●​ The How: For a given input x, the SwiGLU FFN layer performs the following:
1.​ Two Projections: The input x is passed through two independent linear
transformations (like two separate nn.Linear layers), resulting in two
vectors, let's call them A and B.
2.​ Gating Activation: Vector A is passed through a Swish activation
function (also known as SiLU, or Sigmoid-weighted Linear Unit). The
Swish function is defined as Swish(A)=A⋅σ(A), where σ is the sigmoid
function.
3.​ Apply the Gate: The result of the Swish activation is then multiplied
element-wise with the other vector, B.
●​ The final output is Swish(A) * B. This gated output is then typically passed
through one more linear layer to project it back to the required dimension.
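
A compact sketch of that gated computation (illustrative only; the hidden size is arbitrary and the chapter's real SwiGLU FeedForward class appears later):

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, hidden = 8, 32
x = torch.randn(2, dim)

w_gate = nn.Linear(dim, hidden, bias=False)   # produces A, the "gate" path
w_up = nn.Linear(dim, hidden, bias=False)     # produces B, the "value" path
w_down = nn.Linear(hidden, dim, bias=False)   # projects back to the model dimension

gated = F.silu(w_gate(x)) * w_up(x)           # Swish(A) * B, element-wise
out = w_down(gated)
print(out.shape)                              # torch.Size([2, 8])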

KV Caching (Key-Value Caching)


●​ The Why: Text generation with transformers is autoregressive, meaning each
new word is generated based on all the words that came before it. Without
optimization, generating the 100th word would require the model to re-process
the first 99 words from scratch, re-calculating their attention relationships. This
is incredibly redundant and computationally expensive, making real-time
generation painfully slow.
●​ The What: KV Caching is a fundamental optimization that makes
autoregressive generation fast and practical. The core idea is simple: never
recompute what you've already computed. Instead of re-processing the entire
sequence for each new token, the model stores (caches) the Key (K) and Value
(V) vectors from the attention layers for all the tokens it has already seen.
●​ The How: The process transforms a slow, multi-step calculation into a rapid,
single-step loop:
1.​ Prompt Processing (Prefill): When you provide a prompt (e.g., "The
best city in the world is"), the model processes the entire sequence
once. At each attention layer, it calculates and saves the Key and Value
vectors for every token ("The", "best", "city", ... "is") into memory. This
cache is now "primed."
2.​ Generation Step 1: To generate the next word, the model only needs to
process the very last token ("is"). It calculates the Query (Q) vector for
just this one token.
3.​ Attention with Cache: This new Q vector then "looks at" the entire
history of K and V vectors that are already stored in the cache. It
performs the attention calculation using the new Q and the cached K
and V.
4.​ Cache Update: A new word is generated (e.g., "London"). The model
calculates the K and V vectors for "London" and appends them to the
cache.
5.​ Repeat: To generate the next word, the model now only processes
"London", gets its Q vector, and attends to the updated cache which now
contains information for the entire sequence "The best city in the world
is London".
●	This reduces the work for each new token from re-running the model over the entire preceding sequence to a single forward pass over one token (the attention step still reads the cached keys and values), dramatically speeding up generation.
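
The snippet below sketches that prefill-then-decode pattern for a single attention head, with random stand-in vectors (the real cache handling lives in the GroupedQueryAttention and generate_text_stream code later in this chapter):

import torch

d = 4
cache_k, cache_v = [], []               # grows by one entry per processed token

def attend(q_new, k_new, v_new):
    """Process ONLY the newest token; reuse the cached K/V for the history."""
    cache_k.append(k_new)
    cache_v.append(v_new)
    K = torch.stack(cache_k)            # (tokens_so_far, d)
    V = torch.stack(cache_v)
    scores = (q_new @ K.T) / d ** 0.5   # one query row against the whole cache
    return torch.softmax(scores, dim=-1) @ V

# Prefill: the prompt tokens prime the cache.
for _ in range(5):
    attend(torch.randn(d), torch.randn(d), torch.randn(d))

# Decode: each new token adds exactly one K/V pair instead of recomputing all.
out = attend(torch.randn(d), torch.randn(d), torch.randn(d))
print(len(cache_k), out.shape)          # 6 torch.Size([4])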

Code Breakdown
The script is logically divided into five main sections.

Model Architecture

This section defines all the nn.Module classes that constitute the transformer model,
from individual components to the final assembled model.

MoEFeedForward: This class implements a Mixture of Experts (MoE) layer. Instead of a single large feed-forward network, an MoE uses multiple smaller "expert"
networks.

●	Gating: A linear layer (self.gate) acts as a "router," deciding which experts are best suited to process each input token.
●​ Routing: torch.topk selects the top k experts for each token (k is
num_experts_per_tok). The softmax function then converts their scores
into weights.
●​ Sparse Computation: The code iterates through each expert and processes
only the tokens assigned to it. This "sparse activation" is the key to MoE's
efficiency—it allows for a huge number of parameters, but only a fraction are
used for any given token.
●​ Aggregation: The outputs of the selected experts are weighted and summed
together using index_add_, which efficiently adds the results back to their
original positions in the sequence.

# run_qwen3.py

import json
import math
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from tokenizers import Tokenizer

# --------------------------------------------------------------------
# 1. Model Architecture
# --------------------------------------------------------------------

class MoEFeedForward(nn.Module):
"""
A performant, sparse Mixture of Experts (MoE) feed-forward
layer.

This layer routes each token to a subset of experts (`k`) and


computes the output as a weighted combination of those experts'
outputs, avoiding computation for non-selected experts.
"""

def __init__(self, cfg: Dict[str, Any]):


"""
Initializes the MoE layer.

Args:
cfg (Dict[str, Any]): Model configuration dictionary.
"""
super().__init__()
self.num_experts_per_tok = cfg["num_experts_per_tok"]
self.num_experts = cfg["num_experts"]
emb_dim = cfg["emb_dim"]
moe_dim = cfg["moe_intermediate_size"]
dtype = cfg["dtype"]

self.gate = nn.Linear(
emb_dim, self.num_experts, bias=False, dtype=dtype
)
self.stacked_fc1_w = nn.Parameter(
torch.empty(self.num_experts, moe_dim, emb_dim,
dtype=dtype)
)
self.stacked_fc2_w = nn.Parameter(
torch.empty(self.num_experts, moe_dim, emb_dim,
dtype=dtype)
)
self.stacked_fc3_w = nn.Parameter(
torch.empty(self.num_experts, emb_dim, moe_dim,
dtype=dtype)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:


"""
Forward pass for the sparse MoE layer.

Args:
x (torch.Tensor): Input tensor of shape (B, T, D).

Returns:
torch.Tensor: Output tensor of shape (B, T, D).
"""
batch_size, seq_len, dim = x.shape
x_flat = x.view(-1, dim)

router_logits = self.gate(x_flat)
routing_weights, selected_experts = torch.topk(
router_logits, self.num_experts_per_tok
)
routing_weights = F.softmax(
routing_weights, dim=-1, dtype=x.dtype
)

final_hidden_states = torch.zeros_like(x_flat)
expert_mask = F.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0)

for expert_idx in range(self.num_experts):


            # idx = top-k slot index, top_x = token index (row in x_flat)
            idx, top_x = torch.where(expert_mask[expert_idx])
if top_x.shape[0] == 0:
continue

current_states = x_flat[top_x]
current_weights = routing_weights[top_x, idx, None]

h1 = current_states @ self.stacked_fc1_w[expert_idx].t()
h2 = current_states @ self.stacked_fc2_w[expert_idx].t()
h = F.silu(h1) * h2
expert_output = h @ self.stacked_fc3_w[expert_idx].t()

expert_output *= current_weights
final_hidden_states.index_add_(0, top_x, expert_output)

return final_hidden_states.view(batch_size, seq_len, dim)

FeedForward: This is the dense, non-MoE alternative. It implements the SwiGLU (Swish-Gated Linear Unit) architecture. The input x is projected twice (fc1, fc2).
One projection is passed through a SiLU (Sigmoid Linear Unit) activation function, and
the result is multiplied element-wise with the other projection. This gated result is
then projected back to the original dimension by fc3. This design is highly effective
and common in modern LLMs.

class FeedForward(nn.Module):
"""A standard SwiGLU feed-forward layer."""

def __init__(self, cfg: Dict[str, Any]):


super().__init__()
emb_dim = cfg["emb_dim"]
hidden_dim = cfg["hidden_dim"]
dtype = cfg["dtype"]
self.fc1 = nn.Linear(emb_dim, hidden_dim, bias=False, dtype=dtype)
self.fc2 = nn.Linear(emb_dim, hidden_dim, bias=False, dtype=dtype)
self.fc3 = nn.Linear(hidden_dim, emb_dim, bias=False, dtype=dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fc3(F.silu(self.fc1(x)) * self.fc2(x))

RMSNorm: This implements Root Mean Square Normalization. Unlike Layer Normalization, it only normalizes by the root mean square of the activations, without recentering. This simplifies the calculation to output = x / sqrt(mean(x²) + ε) · scale and has been shown to be effective and computationally cheaper. The input x is cast to float32 during the variance calculation for numerical stability.

class RMSNorm(nn.Module):
"""Root Mean Square Normalization."""

def __init__(self, emb_dim: int, eps: float = 1e-6):


super().__init__()
self.eps = eps
self.scale = nn.Parameter(torch.ones(emb_dim))

def forward(self, x: torch.Tensor) -> torch.Tensor:


# Cast to float32 for stability in variance calculation
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
norm_x = x * torch.rsqrt(variance + self.eps)
return (norm_x * self.scale).to(x.dtype)

compute_rope_params and apply_rope: These functions handle Rotary Position Embeddings (RoPE). Instead of adding positional information, RoPE rotates the
query and key vectors based on their absolute position.

●	compute_rope_params pre-calculates the cosine and sine values needed for the rotations for every possible position up to the context_length. This is
done once at model initialization.
●​ apply_rope applies the rotation. It splits each head's dimension in half (x1,
x2), rotates them, and combines them using the pre-computed cos and sin
values for the current positions. This elegantly injects relative positional
information into the self-attention mechanism.
def compute_rope_params(
head_dim: int,
theta_base: float,
context_length: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precomputes Rotary Position Embedding parameters."""
inv_freq = 1.0 / (
theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) /
head_dim)
)
positions = torch.arange(context_length, dtype=dtype)
angles = torch.einsum("i,j->ij", positions, inv_freq)
angles = torch.cat([angles, angles], dim=1)
return torch.cos(angles), torch.sin(angles)

def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, offset: int
) -> torch.Tensor:
    """Applies Rotary Position Embeddings to a (B, T, H, head_dim) tensor."""
    seq_len = x.shape[1]
    # Slice the tables for the current absolute positions and add singleton
    # batch and head dimensions so they broadcast over (B, T, H, head_dim).
    cos_slice = cos[offset : offset + seq_len][None, :, None, :]
    sin_slice = sin[offset : offset + seq_len][None, :, None, :]

    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)

    return (x * cos_slice) + (rotated * sin_slice)

GroupedQueryAttention: This implements Grouped-Query Attention (GQA), an optimization over standard Multi-Head Attention (MHA). In GQA, multiple query heads
share a single key and value head. This significantly reduces the size of the Key-Value
(KV) cache, which is a major memory bottleneck during inference.

●	The number of key/value heads (num_kv_groups) is smaller than the number of query heads (num_heads).
●​ The repeat_interleave function is the core of GQA; it duplicates the key
and value head tensors to match the number of query heads, allowing the
standard attention computation to proceed.
●​ The class also handles KV caching: during text generation, the keys and values
from previous tokens are cached and concatenated with the new ones (k_new,
v_new), avoiding redundant computations.

class GroupedQueryAttention(nn.Module):
"""Grouped-Query Attention (GQA) layer."""

def __init__(self, cfg: Dict[str, Any]):


super().__init__()
d_in = cfg["emb_dim"]
num_heads = cfg["n_heads"]
num_kv_groups = cfg["n_kv_groups"]
dtype = cfg["dtype"]
self.num_heads = num_heads
self.num_kv_groups = num_kv_groups
self.group_size = num_heads // num_kv_groups
self.head_dim = cfg["head_dim"]
self.d_out = num_heads * self.head_dim

self.W_query = nn.Linear(d_in, self.d_out, bias=False,


dtype=dtype)
self.W_key = nn.Linear(
d_in, num_kv_groups * self.head_dim, bias=False,
dtype=dtype
)
self.W_value = nn.Linear(
d_in, num_kv_groups * self.head_dim, bias=False,
dtype=dtype
)
self.out_proj = nn.Linear(self.d_out, d_in, bias=False,
dtype=dtype)

        self.q_norm = RMSNorm(self.head_dim) if cfg["qk_norm"] else None
        self.k_norm = RMSNorm(self.head_dim) if cfg["qk_norm"] else None

def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
start_pos: int = 0,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
b, num_tokens, _ = x.shape
q = self.W_query(x)
k_new = self.W_key(x)
v_new = self.W_value(x)

q = q.view(b, num_tokens, self.num_heads, self.head_dim)


k_new = k_new.view(b, num_tokens, self.num_kv_groups,
self.head_dim)
v_new = v_new.view(b, num_tokens, self.num_kv_groups,
self.head_dim)

if self.q_norm: q = self.q_norm(q)
if self.k_norm: k_new = self.k_norm(k_new)

q = apply_rope(q, cos, sin, offset=start_pos)


k_new = apply_rope(k_new, cos, sin, offset=start_pos)

# Update KV cache
if cache:
prev_k, prev_v = cache
k = torch.cat([prev_k, k_new], dim=1)
v = torch.cat([prev_v, v_new], dim=1)
else:
k, v = k_new, v_new

# Reshape for attention calculation


q = q.transpose(1, 2) # (B, H, T, HD)
k = k.transpose(1, 2) # (B, G, T_kv, HD)
v = v.transpose(1, 2) # (B, G, T_kv, HD)

k = k.repeat_interleave(self.group_size, dim=1)
v = v.repeat_interleave(self.group_size, dim=1)

scores = (q @ k.transpose(2, 3)) / math.sqrt(self.head_dim)


scores = scores.masked_fill(mask, -torch.inf)
weights = F.softmax(scores, dim=-1).to(v.dtype)

        ctx = (weights @ v).transpose(1, 2).reshape(b, num_tokens, -1)
        # Return the concatenated K/V so the cache accumulates across decoding steps.
        return self.out_proj(ctx), (k, v)

TransformerBlock: This class combines the attention and feed-forward sub-layers into a single block, which is the basic repeating unit of the transformer. It uses a
pre-normalization architecture: normalization is applied before the attention and
feed-forward layers. This is followed by a residual connection (x = x + ...), which
is crucial for training deep networks. The block dynamically chooses between the
MoEFeedForward and standard FeedForward layer based on the model's
configuration.

class TransformerBlock(nn.Module):
"""A single transformer block with pre-normalization."""

def __init__(self, cfg: Dict[str, Any]):


super().__init__()
self.att = GroupedQueryAttention(cfg)
        ff_choice = MoEFeedForward if cfg["num_experts"] > 0 else FeedForward
self.ff = ff_choice(cfg)
self.norm1 = RMSNorm(cfg["emb_dim"])
self.norm2 = RMSNorm(cfg["emb_dim"])

def forward(
self, x: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
attn_out, new_cache = self.att(self.norm1(x), **kwargs)
x = x + attn_out
x = x + self.ff(self.norm2(x))
return x, new_cache

Qwen3Model: This is the final, complete model class.

●	It initializes the token embedding layer (tok_emb), a list of TransformerBlocks, a final RMSNorm, and an output linear layer (out_head)
that projects the final hidden state to the vocabulary size to produce logits.
●​ It pre-computes and registers the RoPE cos and sin tables as buffers, which
are part of the model's state but are not trainable parameters.
●​ The forward method defines the data flow: token IDs are converted to
embeddings, passed through all transformer blocks (while managing the KV
cache), normalized one last time, and finally projected to logits. It also creates
the causal attention mask on-the-fly to ensure a token can only attend to
previous tokens.
●​ The custom to method is important; it ensures that when the model is moved
to a device (e.g., a GPU), the RoPE buffers are also moved and cast to the
correct data type.
class Qwen3Model(nn.Module):
"""The main Qwen3 model architecture."""

def __init__(self, cfg: Dict[str, Any]):


super().__init__()
self.cfg = cfg
self.tok_emb = nn.Embedding(cfg["vocab_size"],
cfg["emb_dim"])
self.trf_blocks = nn.ModuleList(
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
)
self.final_norm = RMSNorm(cfg["emb_dim"])
self.out_head = nn.Linear(
cfg["emb_dim"], cfg["vocab_size"], bias=False
)

cos, sin = compute_rope_params(


cfg["head_dim"], cfg["rope_base"], cfg["context_length"],
torch.float32
)
self.register_buffer("cos", cos)
self.register_buffer("sin", sin)

def forward(
self,
in_idx: torch.Tensor,
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] =
None,
start_pos: int = 0,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor,
torch.Tensor]]]:
b, num_tokens = in_idx.shape
x = self.tok_emb(in_idx)

# Create attention mask


full_seq_len = start_pos + num_tokens
mask = torch.full(
(num_tokens, full_seq_len), True, device=x.device,
dtype=torch.bool
)
mask = torch.triu(mask, diagonal=start_pos + 1)

        active_cache = cache if cache is not None else [None] * len(self.trf_blocks)
new_cache = []
for i, block in enumerate(self.trf_blocks):
x, blk_cache = block(
x,
mask=mask,
cos=self.cos,
sin=self.sin,
start_pos=start_pos,
cache=active_cache[i],
)
new_cache.append(blk_cache)

logits = self.out_head(self.final_norm(x))
return logits, new_cache

def to(self, *args, **kwargs):


"""Move model and RoPE buffers to the correct device and
dtype."""
super().to(*args, **kwargs)
        # Ensure RoPE buffers are moved correctly, as they might be created
        # with a default dtype.
        device, dtype = self.tok_emb.weight.device, self.tok_emb.weight.dtype
self.cos = self.cos.to(device=device, dtype=dtype)
self.sin = self.sin.to(device=device, dtype=dtype)
return self

Tokenizer

This section provides a wrapper for handling text tokenization.

Qwen3Tokenizer: This is a convenient wrapper around a tokenizer loaded from a file (typically
tokenizer.json from Hugging Face) using the tokenizers library.

●​ Its main purpose is to encode text into token IDs and decode them back.
●​ Crucially, it implements the _wrap_chat method, which formats a user's
prompt according to the Qwen3 chat template. This involves adding special
tokens like <|im_start|> and <|im_end|> to delineate conversation turns.
This formatting is essential for instruction-tuned models to function correctly.

# --------------------------------------------------------------------
# 2. Tokenizer
# --------------------------------------------------------------------
class Qwen3Tokenizer:
"""A wrapper for the Qwen3 tokenizer."""

def __init__(self, tokenizer_file_path: str, **kwargs):


self._tok = Tokenizer.from_file(tokenizer_file_path)
self._special_tokens = {t.content for t in
self._tok.get_added_vocab()}

self.apply_chat_template = kwargs.get("apply_chat_template",
True)
        self.add_generation_prompt = kwargs.get("add_generation_prompt", False)

repo_id = kwargs.get("repo_id", "")


self.eos_token_id = self._tok.token_to_id("<|im_end|>")
if not repo_id or "Base" in repo_id:
            self.eos_token_id = self._tok.token_to_id("<|endoftext|>")

def encode(self, text: str) -> List[int]:


if self.apply_chat_template:
text = self._wrap_chat(text)
return self._tok.encode(text).ids

def decode(self, ids: List[int]) -> str:


return self._tok.decode(ids)

def _wrap_chat(self, user_msg: str) -> str:


        # Using a list and join is slightly more efficient for many parts
parts = [
"<|im_start|>user\n",
user_msg,
"<|im_end|>\n",
]
if self.add_generation_prompt:
parts.append("<|im_start|>assistant\n")
return "".join(parts)

Weight Loading
This section contains helper functions to load pre-trained weights from a Hugging
Face checkpoint and adapt them to the local model's structure.

get_hf_to_local_map: This function creates a dictionary that maps the parameter names used in the Hugging Face transformers library (e.g.,
model.layers.0.self_attn.q_proj.weight) to the more concise names used
in this custom implementation (e.g., trf_blocks.0.att.W_query.weight).

convert_hf_weights_to_qwen3: This is the core conversion utility. It iterates through the loaded Hugging Face weights and renames them according to the
mapping. It also handles the structural differences, most notably for the MoE layers. In
the Hugging Face format, each expert's weights are stored separately. This function
gathers them and uses torch.stack to combine them into single tensors (e.g.,
stacked_fc1_w), matching the optimized structure of the local MoEFeedForward
class.

# --------------------------------------------------------------------
# 3. Weight Loading
# --------------------------------------------------------------------
def get_hf_to_local_map(config: Dict[str, Any]) -> Dict[str, str]:
"""Creates a mapping from HF weight names to local model
names."""
mapping = {
"model.embed_tokens.weight": "tok_emb.weight",
"model.norm.weight": "final_norm.scale",
"lm_head.weight": "out_head.weight",
}
for i in range(config["n_layers"]):
p, hp = f"trf_blocks.{i}", f"model.layers.{i}"
mapping.update({
f"{hp}.self_attn.q_proj.weight":
f"{p}.att.W_query.weight",
f"{hp}.self_attn.k_proj.weight": f"{p}.att.W_key.weight",
f"{hp}.self_attn.v_proj.weight":
f"{p}.att.W_value.weight",
f"{hp}.self_attn.o_proj.weight":
f"{p}.att.out_proj.weight",
f"{hp}.input_layernorm.weight": f"{p}.norm1.scale",
f"{hp}.post_attention_layernorm.weight":
f"{p}.norm2.scale",
})
if config.get("qk_norm", False):
mapping.update({
f"{hp}.self_attn.q_norm.weight":
f"{p}.att.q_norm.scale",
f"{hp}.self_attn.k_norm.weight":
f"{p}.att.k_norm.scale",
})
return mapping

def convert_hf_weights_to_qwen3(
hf_weights: Dict[str, torch.Tensor], config: Dict[str, Any]
) -> Dict[str, torch.Tensor]:
"""Converts Hugging Face weights to the local model's format."""
state_dict = {}
hf_to_local = get_hf_to_local_map(config)
is_moe = config["num_experts"] > 0

for hf_name, w in hf_weights.items():


if hf_name in hf_to_local:
state_dict[hf_to_local[hf_name]] = w

# Handle special cases: MoE vs. standard FeedForward


for i in range(config["n_layers"]):
p, hp = f"trf_blocks.{i}", f"model.layers.{i}"
ff_p, hf_ff_p = f"{p}.ff", f"{hp}.mlp"
if is_moe:
state_dict[f"{ff_p}.gate.weight"] =
hf_weights[f"{hf_ff_p}.gate.weight"]
e = config["num_experts"]
fc1 =
[hf_weights[f"{hf_ff_p}.experts.{j}.gate_proj.weight"] for j in
range(e)]
fc2 =
[hf_weights[f"{hf_ff_p}.experts.{j}.up_proj.weight"] for j in
range(e)]
fc3 =
[hf_weights[f"{hf_ff_p}.experts.{j}.down_proj.weight"] for j in
range(e)]
state_dict[f"{ff_p}.stacked_fc1_w"] = torch.stack(fc1)
state_dict[f"{ff_p}.stacked_fc2_w"] = torch.stack(fc2)
state_dict[f"{ff_p}.stacked_fc3_w"] = torch.stack(fc3)
else:
state_dict[f"{ff_p}.fc1.weight"] =
hf_weights[f"{hf_ff_p}.gate_proj.weight"]
state_dict[f"{ff_p}.fc2.weight"] =
hf_weights[f"{hf_ff_p}.up_proj.weight"]
state_dict[f"{ff_p}.fc3.weight"] =
hf_weights[f"{hf_ff_p}.down_proj.weight"]

return state_dict

Text Generation
This section defines the function responsible for generating text.

generate_text_stream: This function implements the autoregressive text generation process.

●​ Streaming: It is a Python generator (using yield), which allows it to produce one token at a time. This is ideal for interactive applications.
●​ Inference Loop: It first processes the entire input prompt. Then, it enters a loop where it:
1.​ Selects the most likely next token using greedy decoding (torch.argmax).
2.​ yields the generated token.
3.​ Feeds only that single token back into the model for the next step, along with the cache containing the KV states of all previous tokens.
●​ Efficiency: The use of torch.inference_mode() disables gradient calculations, and the KV cache ensures that the model only performs a forward pass on one new token at each step, making generation much faster.

# --------------------------------------------------------------------
# 4. Text Generation
# --------------------------------------------------------------------
def generate_text_stream(model: Qwen3Model, **kwargs) -> torch.Tensor:
    """
    Generates text token by token in a streaming fashion.

    Args:
        model (Qwen3Model): The model to use for generation.
        **kwargs: Must include 'token_ids', 'max_new_tokens', 'eos_token_id'.

    Yields:
        torch.Tensor: The next generated token.
    """
    token_ids = kwargs["token_ids"]
    max_new_tokens = kwargs["max_new_tokens"]
    eos_token_id = kwargs.get("eos_token_id")

    # Use inference_mode for better performance than no_grad
    with torch.inference_mode():
        logits, cache = model(token_ids, cache=None, start_pos=0)
        current_pos = token_ids.shape[1]
        next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)

        for i in range(max_new_tokens):
            if eos_token_id and torch.all(next_token == eos_token_id):
                break
            # Yield the token before the next model call
            yield next_token.clone()

            logits, cache = model(next_token, cache=cache, start_pos=current_pos)
            current_pos += 1
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)

Main Execution
This is the main driver block that orchestrates the entire process.

●​ Setup: It defines the model's hyperparameters in a configuration dictionary, sets the computing device, and seeds for reproducibility.
●​ Memory-Efficient Initialization: A key optimization is the with torch.device("meta"): context. This creates the model's structure on a "meta" device, which defines the architecture without allocating any memory for the weights. This is crucial for loading very large models, as it prevents the system from running out of CPU RAM.
●​ Weight Loading: It uses snapshot_download to fetch the model files from Hugging Face. For sharded models (weights split across multiple files), it reads the model.safetensors.index.json file to identify and load all necessary weight files.
●​ Model Materialization: The converted weights are loaded into the meta-initialized model using model.load_state_dict(..., assign=True). The assign=True argument is necessary to populate the parameters of the unmaterialized model. The model is then moved to the target device (cuda or cpu), at which point memory is finally allocated.
●​ Execution: Finally, it initializes the tokenizer, encodes a prompt, and calls the generate_text_stream function, printing each decoded token to the console as it is generated.

# --------------------------------------------------------------------
# 5. Main Execution
# --------------------------------------------------------------------
if __name__ == "__main__":
    QWEN3_CONFIG = {
        "repo_id": "Qwen/Qwen2-1.5B-Instruct",  # Smaller model for easier testing
        "vocab_size": 151936, "context_length": 32768, "emb_dim": 2048,
        "n_heads": 16, "n_layers": 28, "head_dim": 128, "qk_norm": True,
        "n_kv_groups": 2, "rope_base": 1000000.0, "dtype": torch.bfloat16,
        "num_experts": 0, "num_experts_per_tok": 0,
        "moe_intermediate_size": 0,
        "hidden_dim": 5504,  # For non-MoE FF
    }
    repo_id = QWEN3_CONFIG["repo_id"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    torch.manual_seed(123)

    # Use 'meta' device to initialize model without loading weights.
    # This saves CPU RAM significantly for large models.
    with torch.device("meta"):
        model = Qwen3Model(QWEN3_CONFIG)

    print("Downloading and loading model weights...")
    repo_dir = Path(snapshot_download(repo_id=repo_id,
                                      allow_patterns=["*.safetensors*", "*.json"]))

    # Load all sharded weights into a single dictionary
    index_path = repo_dir / "model.safetensors.index.json"
    with open(index_path) as f:
        index = json.load(f)

    hf_weights = {}
    for filename in set(index["weight_map"].values()):
        hf_weights.update(load_file(repo_dir / filename))

    converted_weights = convert_hf_weights_to_qwen3(hf_weights, QWEN3_CONFIG)
    model.load_state_dict(converted_weights, strict=False, assign=True)
    model.to(device=device, dtype=QWEN3_CONFIG["dtype"])
    model.eval()
    print("Model loaded successfully.")

    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path=str(repo_dir / "tokenizer.json"),
        repo_id=repo_id,
        add_generation_prompt=True,
    )

    prompt = "Implement a binary search function in Python"
    input_ids = tokenizer.encode(prompt)
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)

    print(f"\n--- Prompt ---\n{prompt}\n\n--- Generation ---")
    generated_tokens = []
    for token in generate_text_stream(
        model=model,
        token_ids=input_tensor,
        max_new_tokens=200,
        eos_token_id=tokenizer.eos_token_id,
    ):
        token_list = token.squeeze(0).tolist()
        generated_tokens.extend(token_list)
        print(tokenizer.decode(token_list), end="", flush=True)
    print("\n")

Key Takeaways
●​ Modern LLM Architecture: The code is a textbook example of a modern
transformer, incorporating RMSNorm, SwiGLU feed-forward networks, Rotary
Position Embeddings (RoPE), and Grouped-Query Attention (GQA).
●​ Mixture of Experts (MoE): It provides a clear, functional implementation of a
sparse MoE layer, showcasing how to efficiently route tokens to experts.
●​ KV Caching: The implementation correctly uses a Key-Value cache, which is
the most critical optimization for fast autoregressive text generation.
●​ Hugging Face Interoperability: It demonstrates the practical steps required to
load pre-trained weights from a standard repository like Hugging Face into a
custom model implementation, including name mapping and structural
adaptation.
●​ Memory Efficiency: The use of the meta device for model initialization is a
professional technique for handling large models without requiring massive
amounts of CPU RAM.
●​ Streaming Inference: The use of a generator for token-by-token output is a
best practice for building responsive, real-time AI applications.

Conclusions
This script provides a comprehensive and robust implementation of a modern Large
Language Model, constructed directly from fundamental PyTorch components. It not
only builds the model but also tackles the engineering hurdles of loading pre-trained
weights and optimizing for efficient inference. This makes it a resource for anyone
seeking to move beyond high-level library abstractions and gain a practical, in-depth
understanding of how transformer models, such as Qwen3, are developed and
operated.

From LLM to LRM: What’s Changed Under the Hood?
Introduction: An Era of Refinement
In the rapidly advancing field of artificial intelligence, the seven years since the advent
of the original GPT architecture represent a vast expanse of time. A casual observer,
comparing an early model like GPT-2 (2019) to the flagship models of 2025 such as
Llama 4 or DeepSeek-V3, might conclude that little has fundamentally changed. The
core component, the transformer block, remains the steadfast heart of these systems.

However, this surface-level similarity belies a period of intense and sophisticated refinement. While we haven't witnessed a complete architectural revolution, the past
few years have been defined by a series of crucial, incremental innovations. These are
not mere cosmetic polishes; they are targeted solutions to the core challenges of
scaling: computational efficiency, memory management, and training stability.
Positional embeddings have evolved from absolute to rotational (RoPE), Multi-Head
Attention has largely been supplanted by the more efficient Grouped-Query Attention,
and SwiGLU has become the activation function of choice over predecessors like
GELU.

This chapter delves beneath the surface to explore these architectural developments.
Rather than focusing on benchmark scores, we will dissect the structural anatomy of
today's leading open models to understand the key engineering decisions that enable
their remarkable capabilities. We will examine why these changes were necessary,
what they entail, and how they are implemented, revealing a clear trend towards
building larger, more knowledgeable models that remain practical to deploy and run.

Key Architectural Innovations of the Modern Era


The evolution of LLM architecture is driven by a dual-pronged objective: increasing a
model's capacity to absorb knowledge and reasoning ability, while simultaneously
constraining the immense computational and memory costs associated with
inference. Two major trends have emerged as the primary means to achieve this
balance: sparsity through Mixture-of-Experts and the hyper-optimization of the
attention mechanism.

The Rise of Sparse Models: Mixture-of-Experts (MoE)


The core challenge in scaling LLMs is that making a model "smarter" has traditionally
meant making it "denser"—every single parameter is used for every single token
processed. This becomes prohibitively expensive at scale.

●​ Why MoE? The goal of MoE is to massively increase a model's total parameter
count—and thus its knowledge capacity—without a proportional increase in the
computational cost of inference. It allows a model to be, for instance, 700
billion parameters in total, while only using a fraction of those (e.g., 40 billion)
for any given token.
●​ What is MoE? An MoE layer replaces the standard FeedForward Network (FFN)
within a transformer block. Instead of one large FFN, an MoE layer contains a
large number of smaller FFNs, called "experts," and a small "router" network.
●​ How does it work? For each token that enters the MoE layer, the router
dynamically selects a small subset of experts (e.g., 2 to 9 out of hundreds) to
process it. All other experts remain inactive. This is why MoE models are
considered sparse; only a fraction of the network is activated at any one time.
This keeps inference fast, even though the model's total size is enormous.
Some models, like DeepSeek-V3, also incorporate a shared expert that is
activated for every token, handling common patterns and allowing the other
experts to specialize more effectively. (A minimal routing sketch follows this list.)
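The following sketch illustrates the routing idea with made-up sizes (8 experts, top-2 routing, and an optional DeepSeek-style shared expert). It is an illustrative toy, not the code of any particular model: it loops over tokens for clarity and omits the load-balancing losses and vectorized expert execution that production implementations, such as the stacked MoEFeedForward shown earlier, rely on.

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
dim, n_experts, top_k, n_tokens = 32, 8, 2, 5

experts = nn.ModuleList([
    nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))
    for _ in range(n_experts)
])
shared_expert = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))
router = nn.Linear(dim, n_experts, bias=False)

x = torch.randn(n_tokens, dim)

# The router scores every expert, but only the top-k per token are executed.
gate_logits = router(x)                              # (n_tokens, n_experts)
weights, chosen = torch.topk(gate_logits, top_k, dim=-1)
weights = torch.softmax(weights, dim=-1)             # renormalize over the chosen experts

out = shared_expert(x)                               # the shared expert sees every token
for token_idx in range(n_tokens):
    for slot in range(top_k):
        e = chosen[token_idx, slot].item()
        out[token_idx] += weights[token_idx, slot] * experts[e](x[token_idx])

print(out.shape)  # torch.Size([5, 32])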

Evolving the Attention Mechanism: Beyond Multi-Head


The self-attention mechanism, while powerful, is computationally demanding,
particularly due to the size of the Key-Value (KV) cache required during inference. The
KV cache stores the key and value vectors for each token, and its memory footprint
grows linearly with the sequence length. Several innovations aim to mitigate this.

Grouped-Query Attention (GQA)

The Why: The primary motivation behind Grouped-Query Attention (GQA) is to tackle
a major efficiency bottleneck in the original Multi-Head Attention (MHA)
mechanism: the size of the Key-Value (KV) cache. In MHA, every single attention
"head" has its own unique key (K) and value (V) projection matrices. During inference,
the calculated key and value vectors for every token must be stored in high-speed
GPU memory (the KV cache) to generate subsequent tokens. As the number of heads
and the length of the input sequence grow, this cache consumes a massive amount of
memory, quickly becoming the limiting factor for model speed and the maximum
context length a model can handle.

The What: GQA offers an elegant and effective compromise to this problem. The core
idea is to make multiple query heads share the same key and value heads. Instead of
a one-to-one relationship between a query and its key/value pair, GQA creates a
many-to-one relationship. You can think of it like an office with many project
managers (query heads). In MHA, each manager has their own dedicated research
assistant (key/value head). In GQA, several managers are grouped together and share
a single, highly competent research assistant. This reduces the total number of
assistants needed, freeing up resources without significantly slowing down the
managers' work.

The How: Technically, GQA works by creating a number of "key-value groups" that is
smaller than the total number of query heads. For example, a model might have 32
query heads but only 4 key-value groups. In this setup, query heads 1 through 8 would
all use the key and value projections from the first group, heads 9 through 16 would
use the second group, and so on.

During inference, the model only needs to compute and store the key and value
vectors for the 4 groups, rather than for all 32 heads. When the attention scores are
calculated, the queries from all 32 heads perform their calculations against the keys of
their assigned group. This dramatically reduces the memory bandwidth required to
read and write to the KV cache and shrinks the cache's overall size, leading to faster
inference and lower memory usage. Due to its excellent balance of performance and
efficiency, GQA has become the de facto standard for modern, high-performance
LLMs.
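To make the grouping concrete, here is a minimal sketch with assumed sizes (32 query heads sharing 4 key-value groups) and causal masking omitted for brevity. Only the 4 K/V groups would live in the cache; they are simply repeated so that each block of eight query heads attends against the same keys and values.

import torch

# Hypothetical sizes: 32 query heads share 4 key/value groups (8 queries per group).
batch, seq_len, n_heads, n_kv_groups, head_dim = 1, 6, 32, 4, 16
group_size = n_heads // n_kv_groups

q = torch.randn(batch, n_heads, seq_len, head_dim)      # one query per head
k = torch.randn(batch, n_kv_groups, seq_len, head_dim)  # only 4 K heads are cached
v = torch.randn(batch, n_kv_groups, seq_len, head_dim)  # only 4 V heads are cached

# Expand each K/V group so that 8 consecutive query heads reuse the same tensors.
k = k.repeat_interleave(group_size, dim=1)  # (batch, 32, seq_len, head_dim)
v = v.repeat_interleave(group_size, dim=1)

scores = (q @ k.transpose(-2, -1)) / head_dim**0.5
context = torch.softmax(scores, dim=-1) @ v
print(context.shape)  # torch.Size([1, 32, 6, 16])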

Multi-Head Latent Attention (MLA)

The Why: Multi-Head Latent Attention (MLA) is driven by the same fundamental goal
as GQA: reducing the memory footprint of the KV cache. However, it approaches the
problem from a different angle. Instead of reducing the number of distinct key and
value heads, MLA aims to reduce the size of the information being stored for each
head. The inventors of MLA sought a method that could provide memory savings while
potentially preserving more of the model's expressive capacity than GQA, based on
studies suggesting it could lead to better modeling performance.

The What: The central concept of MLA is compression. Before the key and value
vectors are written to the KV cache, they are compressed into a lower-dimensional, or
"latent," representation. It’s analogous to taking detailed, verbose meeting notes (the
original K and V vectors) and then writing a dense, compressed summary (the latent
representation) to save space. This summary contains the most critical information
but in a much more compact form. When needed, the summary can be expanded
back into a more detailed version for use.

The How: MLA introduces an extra step into the attention pipeline. After the model
computes the initial key and value vectors for each head, it passes them through an
additional linear projection layer. This layer acts as a compression matrix, reducing
their dimensionality (e.g., from a dimension of 128 down to 64). It is this smaller,
compressed tensor that is stored in the KV cache, saving a significant amount of
memory.
When the model needs to perform the attention calculation, it retrieves the
compressed tensor from the cache and passes it through a second, different linear
projection layer that acts as a decompression matrix. This projects the vector back to
its original dimensionality. While this adds a small computational overhead (two extra
matrix multiplications), the savings in memory bandwidth and storage can be
substantial, especially for models processing very long sequences.
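A rough sketch of the compress-then-decompress idea follows, under simplifying assumptions: per-head compression with made-up dimensions (128 down to 64), separate decompression matrices for keys and values, and no attention computation. Real MLA implementations compress a joint latent shared across heads and treat RoPE separately, so this is only an illustration of the caching trade-off described above.

import torch
import torch.nn as nn

# Hypothetical dimensions: per-head K/V vectors of size 128 compressed to a 64-dim latent.
head_dim, latent_dim, seq_len = 128, 64, 10

compress_kv = nn.Linear(head_dim, latent_dim, bias=False)    # applied before caching
decompress_k = nn.Linear(latent_dim, head_dim, bias=False)   # applied when attending
decompress_v = nn.Linear(latent_dim, head_dim, bias=False)

k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)

# Only the compressed latents are stored in the KV cache (half the memory here).
kv_cache = (compress_kv(k), compress_kv(v))

# At attention time, expand the latents back to the working dimensionality.
k_restored = decompress_k(kv_cache[0])  # (seq_len, head_dim)
v_restored = decompress_v(kv_cache[1])
print(k_restored.shape, v_restored.shape)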

Sliding Window Attention

The Why: Both GQA and MLA are optimizations for the standard "global attention"
mechanism, where every token can look at every other token in the sequence. This
global attention has a computational complexity that scales quadratically with the
sequence length (O(n²)), making it incredibly slow and expensive for very long
documents. Furthermore, the KV cache grows linearly with the sequence length,
meaning a 100,000-token context requires storing KV pairs for all 100,000 tokens.
Sliding Window Attention was created to break free from these scaling limitations and
make processing extremely long contexts practical.

The What: The core idea of Sliding Window Attention is to replace global attention
with local attention. It operates on the assumption that for many tasks, a token's
meaning is most influenced by its immediate context. Therefore, instead of allowing a
token to attend to the entire sequence, its attention is restricted to a fixed-size
window of its most recent neighbors. Imagine reading an article: instead of re-reading
the entire text from the beginning to understand each new sentence, you primarily
focus on the current paragraph. The sliding window mimics this behavior.

The How: Sliding Window Attention is implemented by modifying the attention mask.
For a token at position t and a configured window size w (e.g., 1024 tokens), the mask
is constructed to only allow the token to attend to other tokens within the range of
[t-w, t]. It is completely blocked from "seeing" any tokens that came before this
window.

This simple change has profound implications for efficiency. The number of
computations per token no longer depends on the total sequence length n, but only
on the fixed window size w. This reduces the complexity from quadratic to linear
(O(n×w)). It also means the KV cache can be heavily optimized. Since tokens outside
the window are never accessed again, the model only needs to keep the key and value
vectors for the most recent w tokens in memory, allowing it to process sequences of
virtually unlimited length with a fixed, predictable memory footprint.
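As a small illustration of the mask described above, the helper below (a sketch, not any model's actual kernel) builds a boolean mask that is both causal and restricted to the window [t-w, t]:

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """True where attention is allowed: causal AND within the last `window` tokens."""
    pos = torch.arange(seq_len)
    rel = pos.unsqueeze(0) - pos.unsqueeze(1)   # rel[i, j] = j - i
    causal = rel <= 0                           # token i may only see tokens j <= i
    in_window = rel >= -window                  # ... and only tokens within [i - window, i]
    return causal & in_window

mask = sliding_window_causal_mask(seq_len=8, window=3)
print(mask.int())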

The Subtle Art of Normalization and Positional Signals


Beyond the major components, smaller tweaks to normalization and positional
encodings have yielded significant gains in training stability and generalization.

Normalization: The Art of Stability

The Why: The fundamental reason normalization layers are critical in deep neural
networks, especially in massive transformers, is to ensure training stability. As
information passes through many layers of a network, the magnitude of the
activations can either explode to incredibly large values or shrink to almost zero. This
phenomenon, known as the exploding or vanishing gradients problem, makes it nearly
impossible for the model to learn effectively, as the updates to its parameters become
either chaotic or non-existent. Normalization acts as a regulator, keeping the
activations within a stable, "well-behaved" range at each layer, allowing for a
smoother and more reliable training process. ⚖️

The What: The two key innovations in normalization are the adoption of a more
efficient layer type, RMSNorm, and the strategic placement of these layers.
RMSNorm (Root Mean Square Normalization) is a simplified version of the traditional
LayerNorm. It's computationally cheaper because it omits the mean-subtraction step
(re-centering) and uses only a single learnable scaling parameter instead of two.

Beyond the type of layer, its placement has become a crucial design choice.
Pre-Norm, where the normalization layer is placed before the attention and
feed-forward modules, became the standard for its high stability. Post-Norm, the
original method where normalization occurs after the modules and the residual
connection, is powerful but harder to train. Modern models experiment with these
placements: OLMo 2 uses a specific flavor of Post-Norm that improves its stability,
while Gemma 3 uses a "belt and suspenders" approach, placing RMSNorm layers both
before and after its main modules. A further refinement is QK-Norm, which is a
dedicated RMSNorm layer applied directly to the query and key vectors to stabilize
the attention scores themselves.

The How: In practice, RMSNorm works by first calculating the root mean square (a
type of average magnitude) of a vector of activations. It then divides the entire vector
by this value, effectively normalizing its scale. Finally, it multiplies the result by a single
learnable scaling parameter. This process ensures the output vector has a consistent
magnitude without altering its direction.

Regarding placement, a Pre-Norm layer takes the input to a block (e.g., the
multi-head attention block), normalizes it, and then passes it into the block. In
contrast, a Post-Norm layer takes the output of the block after the residual (skip)
connection has been added and normalizes that final sum. The key distinction in
OLMo 2's flavor of Post-Norm is keeping the normalization step inside the residual
path, which was found to prevent the gradient issues associated with traditional
Post-Norm. Finally, QK-Norm is simply an RMSNorm layer inserted inside the attention
mechanism, applied to the query and key tensors just before they are multiplied
together to calculate attention scores.
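A minimal sketch of that computation is shown below; it follows the description above, while production implementations typically also upcast to float32 before normalizing, which is omitted here.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm: divide by the root mean square, then apply a learned gain."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))  # single learnable scaling parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.scale

x = torch.randn(2, 5, 16)
print(RMSNorm(16)(x).shape)  # torch.Size([2, 5, 16])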

No Positional Embeddings (NoPE)

The Why: The self-attention mechanism at the heart of a transformer is "permutation-invariant," meaning it treats an input sentence as an unordered "bag of
words." To fix this, models have always required positional encodings—explicit
signals that tell the model where each token is in the sequence (e.g., this is the 1st
word, this is the 2nd, etc.). No Positional Embeddings (NoPE) was born from a simple
but profound question: is this explicit signal truly necessary? The motivation is to
explore whether a model can learn sequence order implicitly from the architecture
itself. If so, this could simplify the model, remove a component that needs to be
carefully designed, and potentially help the model generalize better to sequences
longer than any it saw during training.

The What: NoPE is exactly what its name suggests: the intentional removal of explicit
positional encoding layers like absolute positional embeddings or the more modern
Rotational Positional Embeddings (RoPE). It is a design philosophy that bets on the
model's ability to infer sequential order from more fundamental structural cues rather
than having it handed to it directly. This represents a significant departure from nearly
a decade of conventional transformer design, proposing a more elegant and
potentially more powerful architecture.

The How: Even without an explicit positional encoding layer, the model is not
completely blind to the order of tokens. The key to NoPE lies in a component that has
always been present in decoder-style LLMs: the causal attention mask. This mask
enforces a strict directional flow of information. For any given token at position t, the
causal mask prevents it from attending to, or "seeing," any future tokens at positions
greater than t. It can only look backwards at tokens in positions ≤ t.

This inherent left-to-right, autoregressive structure provides a powerful, if implicit, sense of order and direction. The NoPE hypothesis is that during its vast training
process, the model can learn to leverage this structural constraint to understand the
relationships between tokens in a sequence. For example, it can learn that the token
immediately to its left is the most recent one, and the token two steps to its left came
before that, all without needing an explicit "position 5" or "position 6" tag. The
experimental approach taken by SmolLM3, which applies NoPE in only every fourth
layer, shows that this is still an emerging idea being explored with caution.
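As a rough illustration of such a hybrid layout (an assumption about how the toggle could be wired, not SmolLM3's actual configuration), a per-layer flag can simply switch off the rotary transform in every fourth block, leaving only the causal mask to carry ordering information there:

n_layers = 12

# Hypothetical hybrid layout: apply RoPE in three out of every four layers and
# omit it (NoPE) in the fourth, where the causal mask alone signals ordering.
use_rope_per_layer = [(i + 1) % 4 != 0 for i in range(n_layers)]
print(use_rope_per_layer)

# Inside each attention block the flag would gate the rotation, e.g.:
#     if use_rope_per_layer[layer_idx]:
#         q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
#     # otherwise q and k enter the attention product without any positional signal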

Sliding Attention Window

Sliding Window Attention is an efficiency-focused modification of the standard attention mechanism in transformers.

The Why: The primary motivation for Sliding Window Attention is to solve the
immense computational and memory cost of standard "global" attention. In a regular
transformer, every token attends to every other token in the sequence. This creates a
computational complexity that scales quadratically (O(n²)) with the sequence length
n, making it extremely slow and expensive for long documents. Furthermore, the
Key-Value (KV) cache, which stores attention data in memory, grows linearly with the
sequence length, becoming a major bottleneck for models processing large contexts.

The What: The core idea is to replace inefficient global attention with more
economical local attention. Instead of allowing a token to "see" the entire history of
the text, its attention is restricted to a smaller, fixed-size "window" of its most recent
neighbors. It’s like reading a book by focusing only on the current paragraph rather
than re-reading the whole book to understand each new sentence. This technique,
first introduced in the LongFormer paper, has been used by popular models like
Mistral, Gemma, and even, as was later revealed, GPT-3.

The How: Sliding Window Attention is implemented by modifying the attention mask.
For a token at position t and a configured window size w (e.g., 128 tokens), the mask is
structured to only permit attention calculations between that token and other tokens
within the range of [t-w, t]. It is effectively blind to any tokens that appeared before
this local window. This simple change dramatically improves efficiency. The
computational cost becomes linear (O(n×w)) instead of quadratic, as each token only
interacts with a fixed number of other tokens. It also means the KV cache can be fixed
in size, as the model only needs to store the attention data for the most recent w
tokens. Interestingly, models often don't use this technique in every layer. Instead, they
alternate between full global attention layers and local sliding-window layers (e.g., in a
1:1 or 5:1 ratio), allowing the model to efficiently capture local patterns while
periodically integrating global context.
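As a sketch of how such an alternating schedule might be expressed (an illustrative assumption, not any model's actual configuration; a 5:1 local-to-global ratio is used here), each layer can simply be assigned a type that decides which mask it builds:

n_layers = 12
window = 1024  # local sliding-window size in tokens (assumed value)

# One full-attention layer after every five sliding-window layers (5:1 ratio).
layer_types = ["global" if (i + 1) % 6 == 0 else "local" for i in range(n_layers)]
print(layer_types)
# ['local', 'local', 'local', 'local', 'local', 'global', 'local', ...]

# At mask-construction time each layer would then pick either the plain causal
# mask ("global") or the windowed causal mask from the sketch shown earlier ("local").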

A Tour of 2025's Flagship Architectures


While the high-level trends describe the shared toolkit of the modern LLM architect,
the true artistry lies in how these tools are selected, combined, and tuned. Each
flagship model family of 2025 represents a distinct architectural philosophy, offering a
unique solution to the enduring challenges of balancing scale, performance, and
efficiency.

DeepSeek-V3

The DeepSeek architecture solves the AI scaling dilemma of building knowledgeable models that remain computationally practical. Its design relies on two key
technologies: Mixture-of-Experts (MoE) and Multi-Head Latent Attention (MLA).

The MoE system decouples a model's total size from its inference cost. Unlike dense
models where all parameters are always active, DeepSeek-V3 employs 256 "expert"
networks but activates only a small fraction—just nine—for any given token. This
sparse activation means that while the model has a massive 671-billion-parameter
capacity, its computational load is equivalent to a much smaller 37-billion-parameter
model, which is the key to its efficiency.
DeepSeek further refines its MoE system with a shared expert that processes every
token. This expert handles common patterns like grammar and basic facts. This
efficient design frees up the other eight specialized experts to focus their capacity on
learning more niche and complex information, increasing the model's overall
intelligence.

Complementing MoE is the choice of Multi-Head Latent Attention (MLA) over the
more common Grouped-Query Attention (GQA). Both techniques aim to reduce the
memory footprint of the attention mechanism's KV cache. GQA simplifies the
architecture by reducing key/value heads. In contrast, MLA preserves all attention
heads but compresses the key and value tensors into a lower-dimensional space
before caching. While this adds a minor computational step, the DeepSeek team
chose MLA because research showed it delivers superior modeling performance,
making the trade-off for higher fidelity worthwhile.

The most crucial step for DeepSeek-V3's reasoning is what happens after its initial
pre-training. The core of its reasoning power comes from a process where the
capabilities of a specialized reasoning model, DeepSeek-R1, are "distilled" into the
general-purpose DeepSeek-V3. DeepSeek-R1 is trained using advanced
Reinforcement Learning (RL) to excel at step-by-step thinking. This distillation process
effectively teaches DeepSeek-V3 the "verification and reflection" patterns inherent to
DeepSeek-R1. This means the model learns not just to provide an answer, but to
internally "think through" the problem, check its steps, and correct itself, mimicking a
more human-like reasoning process. The model's API even provides access to this
Chain-of-Thought content, allowing users to see the step-by-step logic it used to
arrive at the final answer. The model was pre-trained on a massive dataset of 14.8
trillion tokens with a heavy emphasis on code (87%), which inherently contains logical
structures and problem-solving patterns that are foundational to strong reasoning
skills.

OLMo 2
The OLMo 2 model series, from the non-profit Allen Institute for AI, prioritizes
transparency, replicability, and training stability over chasing top benchmark scores,
making it a valuable scientific artifact. Its architecture combines a conservative
transformer design with targeted innovations aimed at reliably training massive
models. ⚖️
OLMo 2's primary innovation is its unique approach to normalization. While most
models use a stable Pre-Normalization scheme, OLMo 2 revisits the original,
performance-oriented Post-Normalization method, which is typically harder to train. It
implements a clever hybrid: its RMSNorm layers are placed after the main
computational blocks but crucially remain inside the residual path. This placement
aims to capture the "best of both worlds" by maintaining training stability while
allowing the main information stream to flow unimpeded, preserving the potential
performance benefits of the Post-Norm design.

This focus on stability extends to the attention mechanism with the inclusion of
QK-Norm. Unstable magnitudes in the query (Q) and key (K) vectors can disrupt
attention, making it either too sharp or too diffuse. QK-Norm mitigates this directly at
the source. It is an additional RMSNorm layer applied to the Q and K vectors right
before their dot product is computed, ensuring the inputs to the attention calculation
are always well-scaled.

The combination of OLMo 2's unique Post-Norm layout and the targeted use of
QK-Norm demonstrates a deep, first-principles approach to model design, providing a
robust and stable blueprint for the entire research community.
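The sketch below illustrates both ideas under simplifying assumptions: causal masking and RoPE are omitted, nn.RMSNorm requires a recent PyTorch (otherwise substitute the minimal RMSNorm from the normalization section), and the module names are invented. It is a schematic of the placement described above, not OLMo 2's actual code.

import torch
import torch.nn as nn

class QKNormAttentionSketch(nn.Module):
    """Self-attention with an extra RMSNorm on queries and keys (QK-Norm)."""
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)

    def forward(self, x):
        b, t, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        shape = (b, t, self.n_heads, self.head_dim)
        q, k, v = (z.view(shape).transpose(1, 2) for z in (q, k, v))
        q, k = self.q_norm(q), self.k_norm(k)  # stabilize magnitudes before the dot product
        att = torch.softmax(q @ k.transpose(-2, -1) / self.head_dim**0.5, dim=-1)
        return self.out((att @ v).transpose(1, 2).reshape(b, t, d))

class PostNormBlockSketch(nn.Module):
    """Norm applied after each sub-module but kept inside the residual path."""
    def __init__(self, dim: int, n_heads: int, ff: nn.Module):
        super().__init__()
        self.att = QKNormAttentionSketch(dim, n_heads)
        self.ff = ff
        self.att_norm = nn.RMSNorm(dim)
        self.ff_norm = nn.RMSNorm(dim)

    def forward(self, x):
        x = x + self.att_norm(self.att(x))  # x + norm(module(x)), not norm(x + module(x))
        x = x + self.ff_norm(self.ff(x))
        return x

block = PostNormBlockSketch(
    dim=64, n_heads=4,
    ff=nn.Sequential(nn.Linear(64, 128), nn.SiLU(), nn.Linear(128, 64)),
)
print(block(torch.randn(1, 8, 64)).shape)  # torch.Size([1, 8, 64])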

Gemma 3
Google's Gemma 3 takes a different path to efficiency. Instead of employing MoE, its
architecture is engineered to excel at handling long sequences of text by minimizing
the memory footprint of the attention mechanism's KV cache. Its defining feature is
the strategic and heavy use of Sliding Window Attention.

The architecture employs a 5:1 ratio of local to global attention; for every five layers
that use a restrictive sliding window, only one layer performs full, global attention. In
Gemma 3, this local window was further constrained to just 1024 tokens, a significant
reduction from its predecessor. This design dramatically curtails the growth of the KV
cache, making Gemma 3 exceptionally memory-efficient and well-suited for
applications requiring a long context, even on consumer-grade hardware. Ablation
studies have shown this has a minimal impact on overall modeling performance,
making it a highly effective trade-off.
Furthermore, Gemma 3 employs its own distinct normalization strategy, placing an
RMSNorm layer both before and after its attention and feed-forward blocks. This
dual-normalization approach can be seen as a "belt and suspenders" method to
ensure gradient stability throughout the network.
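Schematically, that placement looks like the sketch below, with stand-in linear layers instead of real attention and feed-forward modules and nn.RMSNorm again assuming a recent PyTorch; it is an illustration of the dual-normalization idea, not Gemma 3's code.

import torch
import torch.nn as nn

class DualNormBlockSketch(nn.Module):
    """RMSNorm both before and after each sub-module ("belt and suspenders")."""
    def __init__(self, dim: int, attn: nn.Module, ff: nn.Module):
        super().__init__()
        self.attn, self.ff = attn, ff
        self.pre_attn_norm, self.post_attn_norm = nn.RMSNorm(dim), nn.RMSNorm(dim)
        self.pre_ff_norm, self.post_ff_norm = nn.RMSNorm(dim), nn.RMSNorm(dim)

    def forward(self, x):
        x = x + self.post_attn_norm(self.attn(self.pre_attn_norm(x)))
        x = x + self.post_ff_norm(self.ff(self.pre_ff_norm(x)))
        return x

# Stand-in linear layers take the place of real attention / feed-forward modules.
blk = DualNormBlockSketch(16, nn.Linear(16, 16), nn.Linear(16, 16))
print(blk(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])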

It's a common misconception that models like Gemma 3 have a distinct "reasoning
module." Instead, reasoning is an emergent capability that arises from a
combination of the model's vast scale, its training data, and its core architectural
design. Gemma 3's architecture is a prime example of engineering for efficiency,
which in turn creates the necessary foundation for complex reasoning to develop and
operate effectively. Rather than employing Mixture-of-Experts (MoE), it focuses on
hyper-optimizing its attention mechanism to handle and process vast amounts of
information—a prerequisite for any sophisticated cognitive task.

Llama 4: Mainstreaming the Mixture-of-Experts

With the introduction of Llama 4, Meta has solidified the Mixture-of-Experts (MoE)
paradigm's place in mainstream industrial applications. While both Llama 4 and
DeepSeek utilize sparsity, their implementations diverge. The 400-billion-parameter
Llama 4 Maverick employs a more traditional MoE setup, featuring fewer but larger
experts.

A key architectural difference lies in Llama 4's approach of alternating MoE layers with
standard dense feed-forward layers throughout its network. This contrasts with
DeepSeek's strategy of incorporating MoE layers in nearly every transformer block.
This design choice by Meta likely aims to balance expert specialization with
generalized knowledge representation. For its attention mechanism, Llama 4 sticks to
the widely adopted Grouped-Query Attention (GQA), benefiting from its proven
efficiency and optimized kernels like FlashAttention, instead of the more complex MLA.

Llama 4's advanced reasoning capabilities stem from a synergy of its core
architecture, vast scale, and a sophisticated, multi-stage training pipeline designed to
foster complex problem-solving. The architecture provides the fundamental
framework for high-level reasoning to emerge.
At its core, Llama 4 leverages a sparse MoE architecture. For instance, the Llama 4
Maverick model has a total of 400 billion parameters, yet only activates 17 billion for
any given task. This allows the model to house a vast, specialized knowledge base
within its "experts." When faced with a complex reasoning problem, the model's
gating network can dynamically direct the query to the most relevant experts (e.g.,
one specializing in mathematical logic, another in code syntax), combining their
strengths to formulate a solution.

The Llama 4 Scout model boasts an industry-leading context window of up to 10 million tokens. This immense capacity is crucial for reasoning over extensive
information, such as summarizing multiple documents, analyzing entire codebases, or
maintaining a long, coherent chain of thought without losing track of earlier details.

Designed for "early fusion," Llama 4 processes text and images together from the
outset. This enables "unified reasoning" across different data types. For example, it
can analyze a chart or diagram and reason about the presented data, a task that
demands the integration of visual perception with logical and mathematical
understanding.

Beyond its architecture, Llama 4 undergoes a stringent training process specifically engineered to enhance its reasoning abilities. During the initial instruction-tuning
phase, Meta filters out a majority of "easy" examples using other Llama models as
judges. This forces the model to train exclusively on challenging, high-signal data,
thereby refining its ability to tackle complex reasoning tasks without overfitting on
simple conversational patterns.

Llama 4 is continuously trained using reinforcement learning with a curriculum of "hard prompts." This online process helps maintain and improve its capabilities in
reasoning, coding, and math over time.

Both the Llama 4 Scout and Maverick models are "distilled" from a significantly larger
and more powerful teacher model, the 2-trillion-parameter Llama 4 Behemoth.
Behemoth is specifically designed to excel at high-difficulty reasoning tasks, and its
problem-solving techniques are transferred to the smaller models, substantially
boosting their own reasoning power.

Meta's Llama 4 release in April 2025 was met with considerable controversy, primarily
due to concerns regarding benchmark manipulation, underwhelming performance,
and restrictive licensing. The most intense criticism arose when it was revealed that
Meta had submitted a non-public, specially tuned version of Llama 4 to the LM Arena
leaderboard to generate hype. Consequently, once the public model was evaluated,
its ranking plummeted, with Maverick dropping out of the top 30.

Beyond benchmarks, users reported that the model's real-world reasoning and
coding abilities were inconsistent and lagged behind competitors. Despite featuring a
massive context window, its performance was found to degrade on tasks far smaller
than advertised, failing to deliver on its long-context promise. The license also drew
criticism for not being truly open source, as it mandates companies with over 700
million monthly active users to obtain a separate commercial license. Most notably, a
contentious clause explicitly prohibits individuals and companies based in the
European Union from using the multimodal versions of the model. These combined
issues led to a widespread perception that the model was rushed and its capabilities
were overhyped.

Qwen3: The Hallmark of Versatility


The Qwen3 series from Alibaba Cloud stands out for its remarkable versatility, offering
a spectrum of models tailored to different needs. The smaller, dense models are
architecturally characterized as being "deep and narrow." Compared to a model like
Llama 3, the 0.6B Qwen3 model has more transformer layers but a smaller hidden
dimension and fewer attention heads. This results in a model with a very low memory
footprint, though the increased number of sequential operations can lead to a slightly
lower token generation speed.

The larger MoE models in the Qwen3 series are architecturally very similar to
DeepSeek-V3, illustrating a convergence on successful design patterns at the high
end. However, a noteworthy distinction is that the Qwen3 MoE models omit the
shared expert. According to the development team, the performance benefits of a
shared expert were not significant enough in their setup to justify the added
complexity for inference optimization. This practical decision provides a fascinating
glimpse into the real-world trade-offs that guide LLM design.

The advanced reasoning capabilities of the Qwen3 series are not the product of a
single, discrete mechanism but rather an emergent property cultivated through a
synergistic strategy, combining a purpose-built architectural foundation with a highly
specialized training and fine-tuning regimen. This approach moves beyond simple
architectural tweaks, treating reasoning as a sophisticated skill that must be
deliberately nurtured through both hardware-aware design and data-centric
instruction. The result is a family of models where the capacity for logic is
methodically developed, rather than being an accidental byproduct of scale.

The core architecture of Qwen3 is engineered to provide the necessary computational substrate for sophisticated cognitive tasks, ensuring the model has both the structure
and capacity to handle multi-step logical problems.

For its smaller, dense models, Qwen3 employs a "deep and narrow" design
philosophy. By increasing the number of transformer layers relative to comparable
models, the architecture provides a longer computational pathway. This increased
depth is advantageous for problems requiring sequential, multi-step logical
deduction, as it allows for a more extended series of transformations and refinements
of information at each stage of the problem-solving process.

For its largest models, Qwen3 leverages a Mixture-of-Experts (MoE) architecture to achieve immense knowledge capacity without a proportional increase in inference
cost. This can be understood as a form of computational division of labor. A
sophisticated "gating network" acts as a triage system, analyzing an incoming
problem and dynamically routing it to the most relevant "experts"—specialized
sub-networks within the model. This allows Qwen3 to combine deep, niche knowledge
from different domains to formulate a comprehensive and logically sound answer.

While the architecture provides the foundation, Qwen3's reasoning prowess is truly
forged during its specialized training and fine-tuning stages, where the raw potential
of the model is shaped into a powerful logical engine.

The Qwen team made a strategic shift away from a single, hybrid model toward
creating dedicated "Thinking" variants. This represents a move from a generalist to
a specialist model strategy, where specific versions of Qwen3 are fine-tuned on
datasets meticulously curated to enhance performance on tasks requiring complex
reasoning, such as advanced mathematics, programming logic, and abstract puzzles.

A key component of this regimen is Chain-of-Thought (CoT) fine-tuning. During this stage, the "Thinking" models are explicitly trained to produce a step-by-step
reasoning trace before arriving at a final conclusion. This process, which involves
learning from human-generated examples of logical deduction, effectively teaches
the model to "show its work." By externalizing its computational trace, the model
learns to break down complex problems into smaller, manageable steps, leading to
more robust and transparent problem-solving. Furthermore, the training incorporates
a large volume of high-quality, multi-turn conversational data. This is crucial for
developing "contextual endurance"—the ability to maintain a coherent line of
reasoning across an extended dialogue, follow intricate instructions, and build upon
previous interactions to solve a problem.

In summary, Qwen3's sophisticated approach to reasoning is a two-pronged strategy. It first builds a solid architectural foundation capable of handling complex
computations and then cultivates advanced logical capabilities through a highly
specialized and targeted fine-tuning process.

SmolLM3 and the Frontier of Positional Information


While perhaps not a direct competitor to the largest models on leaderboards,
SmolLM3 is an architecturally significant model that serves as a testbed for more
experimental ideas. Its most daring feature is the partial adoption of No Positional
Embeddings (NoPE). This technique challenges the long-held assumption that
transformers require explicit positional information (like RoPE) to be injected into the
model. The NoPE hypothesis suggests that the inherent directionality of the causal
attention mask provides a sufficient, albeit implicit, signal of token order that the
model can learn to utilize.

Reflecting the experimental nature of this technique, the SmolLM3 architects applied
NoPE cautiously, removing explicit positional encodings in only every fourth layer.
This hybrid approach allows the model to benefit from potential improvements in
length generalization while still retaining a strong positional signal from RoPE in the
majority of its layers. SmolLM3's design thus pushes the boundaries of our
understanding of what is truly necessary within a transformer, questioning core
principles in the pursuit of more elegant and generalizable models.

gpt-oss: the OpenAI Open Source take


OpenAI has made a landmark return to the open-weight community by releasing
gpt-oss-120b and gpt-oss-20b. These models are the company's first
open-weight offerings since the influential GPT-2 was released in 2019. The gpt-oss
architecture encapsulates the key evolutionary steps that have defined large
language models over the past several years. Analyzing these models provides a clear
and comprehensive picture of the modern LLM blueprint.

A notable departure from GPT-2 is the complete removal of the dropout regularization
technique. This is because dropout offers little benefit in the single-epoch training
regimes common for today's massive datasets. The old method of absolute positional
embeddings has been replaced by the more dynamic Rotary Position Embedding
(RoPE). RoPE encodes position by rotating query and key vectors, a technique now
standard in most modern LLMs. The feed-forward network has also been upgraded,
with the GELU activation function being replaced by the more expressive and efficient
SwiGLU. SwiGLU uses a gated linear unit structure that improves model capabilities
while often requiring fewer parameters. Furthermore, gpt-oss adopts a
Mixture-of-Experts (MoE) architecture, replacing single feed-forward layers with a
large set of specialized "expert" layers. This sparse approach allows for a massive
increase in total parameters without a proportional rise in inference cost, as only a few
experts are active per token. Standard Multi-Head Attention is supplanted by Grouped
Query Attention (GQA) to improve efficiency. GQA reduces memory usage by having
multiple query heads share the same key and value projections. The model also
employs sliding-window attention, a technique that restricts attention to a smaller,
local context. In gpt-oss, this is applied in every second layer to balance local
efficiency with global context awareness. Finally, the computationally simpler
RMSNorm has taken the place of the original LayerNorm. RMSNorm stabilizes training
by normalizing activations but does so with less computational overhead.

When compared to a contemporary top model like Qwen3, the gpt-oss architecture
is fundamentally very similar. However, they differ significantly in their approach to
model shape, highlighting a key design trade-off. gpt-oss is a "wider" model with a
larger embedding dimension, whereas Qwen3 is a "deeper" model with twice as many
transformer blocks. The wider architecture of gpt-oss is better suited for
parallelization, likely resulting in faster inference speeds. Their MoE implementations
also diverge, as gpt-oss uses a few, very large experts. This contrasts with Qwen3
and the general trend towards using many smaller experts for greater specialization.
Interestingly, gpt-oss reintroduces bias units in its attention projection layers, a
feature largely abandoned after GPT-2. It also implements a form of attention sinks to
help stabilize the model's focus during long-context scenarios. These sinks are
learned per-head bias logits rather than special input tokens.
The models underwent extensive training, with the 120B version consuming 2.1 million
H100-hours. This process included both a supervised fine-tuning stage and a
high-compute reinforcement learning stage focused on reasoning. A unique
interactive feature is the ability to control the model's "Reasoning effort" via a system
prompt. Users can select low, medium, or high effort to balance response quality
against computational cost. OpenAI released the models with an MXFP4 quantization
scheme for the MoE experts. This crucial optimization dramatically reduces the
memory footprint, making the models accessible on single GPU setups. The models
are released under the permissive Apache 2.0 license. OpenAI itself correctly labels
them as "open-weight," as the training code and datasets are not included.

The models are too new to appear on leaderboards like the LM Arena. However,
OpenAI's provided benchmarks show gpt-oss is highly competitive on reasoning
tasks. It performs on par with top proprietary models and competitors like Qwen3.
This is particularly impressive as gpt-oss-120b is nearly half the size of the largest
Qwen3 model. In early real-world use, the model appears quite capable but shows a
tendency to hallucinate. This weakness is noted in its model card and is likely a
trade-off for its intense focus on reasoning tasks. The heavy training on STEM, coding,
and puzzles may have led to some "general knowledge forgetting." This limitation
might become less relevant as the model's intended integration with external tools
matures. Such a design prioritizes problem-solving skills over rote memorization,
much like human learning. In a surprising comparison, benchmarks indicate gpt-oss
is not far behind the newly announced, state-of-the-art GPT-5. This highlights the
impressive power and efficiency of OpenAI's open-weight architecture. Even if some
considered the release overhyped, the models are undeniably strong. Their existence
provides excellent new tools for those working with local or privately-hosted models.
The gpt-oss models represent a welcome and significant contribution to the
open-weight ecosystem. Ultimately, their release signals good times ahead for
open-source AI development.

The gpt-oss models are specifically designed as powerful reasoning systems, achieving this through a combination of their core architecture, a specialized training process,
and unique interactive features. The model is built and trained from the ground up to
support complex, multi-step thought processes.

●​ Training Data and Methodology: The models were trained on a dataset with a
heavy focus on STEM, coding, and general knowledge. More importantly, they
underwent a post-training phase that included a high-compute reinforcement
learning (RL) stage, similar to the techniques used for OpenAI's most advanced
proprietary models. This RL phase specifically teaches the model to apply
Chain-of-Thought (CoT) reasoning before providing an answer.
●​ Agentic Behavior and Tool Use: gpt-oss is designed for "agentic
workflows," meaning it can use tools to solve problems. It was trained to
interleave its reasoning steps with actions like performing web searches for
up-to-date information or executing Python code in a notebook environment to
perform calculations or test solutions.
●​ Transparent Reasoning: The models provide full, transparent access to their
Chain-of-Thought process. This allows developers to see the step-by-step
logic the model used to arrive at a conclusion, which is crucial for debugging,
building trust, and understanding the model's problem-solving approach.

gpt-oss includes unique features that allow users to directly control and interact
with its reasoning process.

●​ Adjustable Reasoning Effort: This is the most distinct reasoning feature. Users can specify the desired "Reasoning effort" by including a simple
instruction (e.g., Reasoning: high) in the system prompt. This allows you to
trade off between latency and performance:
○​ Low: For fast responses where deep reasoning isn't needed.
○​ Medium: A balance between speed and detailed thought.
○​ High: For complex problems that require the model to produce a longer
and more structured CoT trace to "think" with greater depth.
●​ Harmony Chat Format: The models were trained on a new, open-source
response format called Harmony. This structured format has distinct
"channels" for different parts of the response, such as analysis for the
reasoning trace and final for the user-facing answer. This structure is
essential for enabling the model's advanced agentic and reasoning capabilities.

In summary, gpt-oss's reasoning ability comes from being explicitly trained on complex problems, learning to use tools, and providing users with direct control over
the depth and transparency of its thought process.

Kimi K2: the new model from China

A new benchmark has been set in the world of open-weight models with the release
of the Kimi K2 LLM. This model is truly massive, boasting an incredible
one-trillion-parameter architecture. Upon its release, Kimi K2 has immediately taken
the crown as the best available open-weight model. Initial benchmarks indicate that
its performance is exceptionally strong. It reportedly rivals even the best proprietary
LLMs currently on the market. Specifically, its capabilities on coding tasks are shown
to be competitive with leading models like Claude. This level of performance
establishes a new standard for what can be achieved by the open-source community.
The model's debut has understandably generated significant excitement and
discussion among AI researchers and developers. Kimi K2's impressive power signifies
a major leap forward for accessible, state-of-the-art artificial intelligence. Its
existence pushes the boundaries of large-scale model development.

When analyzing its architecture, one discovers that Kimi K2 builds upon a highly
successful existing foundation. Its structure is fundamentally based on the
architecture of the 671-billion-parameter DeepSeek V3. The two models are
described as being "basically the same," indicating a strategic decision to refine a
proven design. This convergence highlights a trend towards iterating on successful
blueprints for massive-scale models. However, the Kimi K2 team did implement a few
subtle but important modifications. The model was engineered to use more experts
within its Mixture-of-Experts (MoE) modules. This particular tweak likely provides the
model with greater knowledge capacity and allows for more refined specialization.
Conversely, Kimi K2 employs fewer heads in its Multi-head Latent Attention (MLA)
module compared to its predecessor. This adjustment may represent a deliberate
trade-off between attention complexity and other architectural components. These
changes demonstrate how an existing blueprint can be finely tuned for even greater
scale and performance.

Perhaps the most significant innovation in Kimi K2 is not found in its architecture but
in its training methodology. The developers made a crucial and high-stakes decision
regarding the optimization algorithm used during training. They chose to replace the
popular and widely-used AdamW optimizer with their own custom solution. This new
optimizer, named MuonClip, is a modified version of the relatively new Muon optimizer.
According to the development team, their MuonClip optimizer substantially
outperforms AdamW for the specific task of LLM training. This choice represented a
multi-million dollar bet, considering the immense expense of training a model of this
magnitude. The gamble appears to have paid off handsomely, as evidenced by the
model's training data. The resulting training loss curve is described as being the
smoothest ever seen for a large language model. The exceptional stability and
efficiency gained from the MuonClip optimizer are likely a key contributor to Kimi K2's
final success. This focus on innovating the training process itself highlights a critical
new frontier for achieving state-of-the-art results.

Kimi K2's reasoning capabilities are not the result of a single feature, but rather a
combination of its massive architectural scale and a highly advanced training
methodology that pushes the boundaries of how LLMs are optimized.

While the architecture is based on the DeepSeek-V3 blueprint, its immense scale
provides the necessary foundation for complex thought.

●​ 1 Trillion Parameters with MoE: Kimi K2 is a 1-trillion-parameter Mixture-of-Experts (MoE) model. This enormous capacity allows it to store a
vast and nuanced repository of knowledge. For a complex reasoning task, its
internal router can dynamically combine information from numerous specialized
"experts," enabling it to tackle problems that require deep, multi-domain
knowledge.
●​ Multi-Head Latent Attention (MLA): Like its predecessor, Kimi uses MLA to
efficiently manage the memory required for its attention mechanism. This is
crucial for reasoning, as it allows the model to process and maintain very long
contexts and "chains of thought" without being constrained by hardware
limitations.
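
To make the routing idea concrete, here is a minimal, illustrative sketch of top-k expert routing in PyTorch. The class name, layer sizes, and expert count are hypothetical and do not reflect Kimi K2's actual implementation; the point is only to show how a router selects and re-weights a small subset of experts per token.

import torch
from torch import nn

class TinyMoELayer(nn.Module):
    """Illustrative sparse MoE layer: each token is processed by its top-k experts."""
    def __init__(self, hidden_size: int = 64, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.router = nn.Linear(hidden_size, num_experts)  # routing logits per token
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size),
                          nn.GELU(),
                          nn.Linear(4 * hidden_size, hidden_size))
            for _ in range(num_experts)
        ])
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (num_tokens, hidden_size)
        weights = torch.softmax(self.router(x), dim=-1)    # (num_tokens, num_experts)
        top_w, top_idx = weights.topk(self.top_k, dim=-1)  # keep only k experts per token
        top_w = top_w / top_w.sum(dim=-1, keepdim=True)    # renormalize the kept weights
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                routed = top_idx[:, slot] == e              # tokens whose slot picks expert e
                if routed.any():
                    out[routed] += top_w[routed, slot].unsqueeze(-1) * expert(x[routed])
        return out

# Example: 10 tokens, each routed to 2 of 8 experts; output keeps the input shape.
tokens = torch.randn(10, 64)
print(TinyMoELayer()(tokens).shape)  # torch.Size([10, 64])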

The most significant factor behind Kimi's reasoning power is its innovative training
process.

●​ Reinforcement Learning with LLMs: The model's development was informed
by the "Kimi k1.5: Scaling Reinforcement Learning with LLMs" paper, indicating
that its reasoning abilities were heavily shaped by advanced reinforcement
learning (RL) techniques. This process teaches the model to solve complex
problems by rewarding it for producing correct and well-reasoned outputs.
●​ State-of-the-Art MuonClip Optimizer: In a crucial and high-stakes decision,
the development team replaced the standard AdamW optimizer with a custom
version called MuonClip. According to the developers, this optimizer is
substantially better for LLM training, resulting in an exceptionally smooth and
stable learning process. This stability helps the model learn the complex,
abstract patterns in data that are fundamental to
logical reasoning. This "multi-million dollar bet" on a new optimizer is
considered a key reason for the model's state-of-the-art performance.
In summary, Kimi's reasoning prowess comes from combining a massive, efficient
architecture with a groundbreaking training strategy. The use of the MuonClip
optimizer allowed its trillion-parameter model to train more effectively than its peers,
resulting in a system with powerful and stable reasoning capabilities.

Key Takeaways
From our architectural survey, several clear conclusions emerge:

●​ Sparsity is the New Standard for Scale: The most dominant trend in
modern LLMs is the adoption of the Mixture-of-Experts (MoE) architecture.
This "sparse" approach, used by models like DeepSeek, Kimi 2, Llama 4, and
gpt-oss, allows for a massive increase in a model's total parameters (and thus
its knowledge capacity) while keeping the computational cost of inference
manageable by only activating a small fraction of "experts" for any given token.
●​ Attention Mechanisms are Hyper-Optimized for Efficiency: The
original Multi-Head Attention has been almost universally replaced by more
efficient alternatives. Grouped-Query Attention (GQA) is the new industry
standard, offering a great balance of performance and reduced memory usage.
However, different models employ specialized variants to meet specific goals,
such as Multi-Head Latent Attention (MLA) for compressing the KV cache
(DeepSeek, Kimi 2) or Sliding Window Attention for handling very long
contexts (Gemma, gpt-oss).
●​ Small Architectural Refinements Yield Big Gains: Modern architectures
are defined by an accumulation of smaller, but significant, efficiency and
stability improvements. These include replacing LayerNorm with the
computationally cheaper RMSNorm, and swapping the GELU activation
function for the more expressive and efficient SwiGLU. The strategic
placement of these normalization layers has also become a key area of
innovation for ensuring stable training.
●​ Innovation is Shifting from Core Concepts to Training Methods: While
architectural tweaks are important, a new frontier for achieving
state-of-the-art performance is emerging in the training process itself. The
success of Kimi K2, for example, is attributed not just to its architecture but to
its use of the novel MuonClip optimizer over the standard AdamW. This shows
that how a model is trained can be just as innovative and impactful as its
design.
●​ Model Design is a Game of Trade-offs: There is no single "best"
architecture; instead, developers make deliberate trade-offs. This is seen in the
"width versus depth" debate (e.g., the wider gpt-oss vs. the deeper Qwen3),
the configuration of MoE layers (few large experts vs. many small ones), and
the choice of attention mechanisms. Each leading model represents a unique
combination of these techniques, tailored to a specific set of performance and
efficiency goals.

Conclusion
The era of radical, ground-up architectural redesigns appears to have paused,
replaced by a period of profound and intelligent maturation. The "minor refinements"
seen in modern LLMs are, in fact, the critical enablers of their immense scale and
power. Each choice—be it MoE, GQA, MLA, or a specific normalization strategy—is a
deliberate engineering decision aimed at optimizing the trade-off between a model's
capacity and its computational cost.

The developments of 2025 show that the frontier of AI is being pushed not just by
brute-force scaling, but by the sophisticated and cumulative process of architectural
refinement. We are learning to build not just bigger models, but smarter, more efficient
ones. The foundation laid by the original transformer has proven remarkably robust,
and the ongoing work is a testament to how much performance can be unlocked by
meticulously polishing that foundation. The next breakthroughs may not come from a
new foundation, but from novel combinations of these hard-won refinements.

Introduction: An Era of Refinement​ 1
Key Architectural Innovations of the Modern Era​ 1
The Rise of Sparse Models: Mixture-of-Experts (MoE)​ 2
Evolving the Attention Mechanism: Beyond Multi-Head​ 2
Grouped-Query Attention (GQA)​ 2
Multi-Head Latent Attention (MLA)​ 3
Sliding Window Attention​ 4
The Subtle Art of Normalization and Positional Signals​ 5
Normalization: The Art of Stability​ 5
No Positional Embeddings (NoPE)​ 6
Sliding Attention Window​ 7
A Tour of 2025's Flagship Architectures​ 8
DeepSeek-V3​ 8
OLMo 2​ 9
Gemma 3​ 10
Llama 4: Mainstreaming the Mixture-of-Experts​ 11
Qwen3: The Hallmark of Versatility​ 13
SmolLM3 and the Frontier of Positional Information​ 15
GPT-OSS: the OpenAI Open Source take​ 15
Kimi2: the new model from China​ 18
Key Takeaways​ 20
Conclusion​ 21

Gemma: the Code


Introduction
The proliferation of open-source models has become a catalyst for innovation in
artificial intelligence, providing the research community with transparent and
replicable foundations for new discoveries. Among these, Google's Gemma family of
models serves as an exemplary case study, offering a suite of powerful and efficient
architectures. While theoretical discussions of such models are widespread, a deeper,
empirical understanding requires a direct examination of their implementation. This
chapter, therefore, presents a detailed analysis of the official Gemma PyTorch source
code, moving from high-level architectural concepts to the specific algorithmic
constructs that govern its behavior.

Gemma Model
Our investigation is structured to mirror the model's own construction. We begin with
its foundational blueprint—the configuration files that specify its core parameters or
"genome." Subsequently, we dissect the primary cognitive engine, analyzing the
transformer blocks, sophisticated attention mechanisms, and normalization
techniques that underpin its linguistic capabilities. We then explore the frontier of
multimodality by examining the Gemma 3 architecture, detailing how its vision system
is integrated to perceive and reason about visual data. Finally, we address the critical
challenge of scale, investigating the software mechanisms for distributed execution
that allow the model to operate on large-scale computational clusters. This analysis is
based on the publicly available source code, reviewed as of August 15, 2025, providing
a concrete and verifiable basis for our exploration (the repository is available at
https://github.com/google/gemma_pytorch/tree/main/gemma).

gemma/config.py
Code: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py

This file functions as the genome for a Gemma model, housing the core instructions
and parameters needed to construct any variant. It guarantees that each component
adheres to precise specifications, yielding a cohesive and operational neural network.

The central component is the GemmaConfig dataclass, a structured container that
holds the model's defining hyperparameters. Key among these are the vocab_size,
which sets the scope of its vocabulary; num_hidden_layers, which determines the
model's depth and cognitive capacity; and hidden_size, which defines the
dimensionality of the vector spaces where information is processed.

To manage architectural evolution across different versions, the configuration
employs enumerations. The Architecture enum (GEMMA_1, GEMMA_2, GEMMA_3)
functions as a version control system, allowing the model-building code to activate
features specific to each generation. Similarly, the AttentionType enum (GLOBAL,
LOCAL_SLIDING) defines the fundamental modes of the model's focus. Global
attention is the classical mechanism where a token can derive context from all
preceding tokens, while local sliding attention is a computationally efficient
approximation where a token only considers a recent, fixed-size window of context.
This dual system allows newer Gemma models to process extremely long sequences
with greater efficiency.

Finally, a series of factory functions, such as get_config_for_9b(), provide validated
presets for each official model variant, ensuring reproducibility.
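
The configuration code itself is not reproduced here, so the following is a simplified sketch of the pattern just described: a dataclass of hyperparameters, version and attention-type enums, and a factory preset. The field defaults, the 9B-style values, and the function name get_config_for_demo_9b are illustrative assumptions, not the official definitions in config.py.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, Sequence

class Architecture(Enum):
    GEMMA_1 = auto()
    GEMMA_2 = auto()
    GEMMA_3 = auto()

class AttentionType(Enum):
    GLOBAL = auto()
    LOCAL_SLIDING = auto()

@dataclass
class SimpleGemmaConfig:
    # Core hyperparameters discussed above; the defaults are placeholders.
    vocab_size: int = 256_000
    num_hidden_layers: int = 28
    hidden_size: int = 3072
    architecture: Architecture = Architecture.GEMMA_3
    attn_types: Optional[Sequence[AttentionType]] = None
    sliding_window_size: Optional[int] = None

def get_config_for_demo_9b() -> SimpleGemmaConfig:
    """Hypothetical factory preset in the spirit of get_config_for_9b()."""
    return SimpleGemmaConfig(
        num_hidden_layers=42,
        hidden_size=3584,
        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL],
        sliding_window_size=4096,
    )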

gemma/gemma3_model.py
code: https://github.com/google/gemma_pytorch/blob/main/gemma/gemma3_model.py

This code defines the constructor for the Gemma3ForMultimodalLM class, which
instantiates and assembles the complete computational graph for the Gemma 3
model. The initialization process is methodical, first establishing the language
processing backbone before integrating the specialized vision components.

Initially, the constructor sets up the model's linguistic foundation by initializing the
standard modules from the base text-only architecture: the tokenizer, the
text_token_embedder, the core stack of transformer blocks (model), and the token
sampler. Subsequently, it integrates the visual processing pipeline. This involves
instantiating the siglip_vision_model, a powerful Vision Transformer that functions as a
feature extractor, converting raw image data into a high-dimensional semantic
representation. To ensure compatibility between the two modalities, the visual
embeddings are first stabilized by an RMSNorm layer (mm_soft_embedding_norm)
and then transformed by a linear mm_input_projection layer. This projection acts as an
adapter, mapping the visual embeddings into the same latent space as the text
embeddings, which is a critical prerequisite for their fusion.

Finally, the constructor pre-computes two distinct sets of Rotary Positional
Embeddings (RoPE) to imbue the model with an understanding of sequence order. A
local_freqs_cis table is generated for layers using local attention to encode
short-range positional relationships, while a global_freqs_cis table, configured with a
larger base period, is created for global attention layers to capture long-range
dependencies. A helper function then registers these tables as non-trainable buffers,
making them a permanent but fixed part of the model's state.

gemma/gemma3_model.py

# Copyright 2024 Google LLC


#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma 3 multimodal model implementation."""

import torch
import os
import json
import gc
from torch import nn
from PIL import Image
from typing import Any, List, Sequence, Tuple, Union

from . import model as gemma_model


from . import config as gemma_config
from . import gemma3_preprocessor
from . import tokenizer
from .siglip_vision import siglip_vision_model

class Gemma3ForMultimodalLM(nn.Module):
"""Gemma3 model for multimodal causal LM."""
def __init__(
self,
config: gemma_config.GemmaConfig,
):
super().__init__()
self.dtype = config.get_dtype()
assert config.architecture == gemma_config.Architecture.GEMMA_3
self.config = config
max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size
self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
self.text_token_embedder = gemma_model.Embedding(vocab_size,
config.hidden_size, config.quant)
self.model = gemma_model.GemmaModel(config)
self.sampler = gemma_model.Sampler(vocab_size, config)

if config.vision_config is None:
raise ValueError('vision_config must be provided for Gemma3.')
self.siglip_vision_model = siglip_vision_model.SiglipVisionModel(config.vision_config)
# transformer/embedder/mm_soft_embedding_norm
self.mm_soft_embedding_norm = gemma_model.RMSNorm(config.vision_config.embedding_dim,
eps=config.rms_norm_eps)
# transformer/embedder/mm_input_projection
self.mm_input_projection = gemma_model.Linear(config.vision_config.embedding_dim,
config.hidden_size, config.quant)

if config.rope_wave_length is None:
raise ValueError('rope_wave_length must be provided for Gemma3.')
rope_lengths = config.rope_wave_length
defaults = {
gemma_config.AttentionType.LOCAL_SLIDING: 10_000,
gemma_config.AttentionType.GLOBAL: 10_000,
}
self._register_freqs_cis('local_freqs_cis', head_dim, max_seq_len,
theta=rope_lengths.get(
gemma_config.AttentionType.LOCAL_SLIDING,
defaults[gemma_config.AttentionType.LOCAL_SLIDING]
))
self._register_freqs_cis('global_freqs_cis', head_dim, max_seq_len,
theta=rope_lengths.get(
gemma_config.AttentionType.GLOBAL,
defaults[gemma_config.AttentionType.GLOBAL]
), rope_scaling_factor=config.rope_scaling_factor)

def _register_freqs_cis(
self, name: str, head_dim: int, max_seq_len: int, theta: int =
10_000, rope_scaling_factor: int = 1
):
self.register_buffer(
name, gemma_model.precompute_freqs_cis(head_dim, max_seq_len *
2, theta=theta, rope_scaling_factor=rope_scaling_factor)
)

The forward method transforms input tensors into the next token prediction,
processing text and images in parallel. It begins by obtaining Rotary Positional
Embeddings (RoPE) for token positioning. Text input IDs are converted into vector
embeddings. If images are present, the siglip_vision_model extracts conceptual
embeddings, which are then normalized and projected to match text embedding
dimensions. The populate_image_embeddings method then fuses these text and
image embeddings into a single multimodal representation. This combined sequence
is fed into the main transformer (self.model) for reasoning. Finally, a sampler
computes probability scores and selects the most likely next token, returning its ID
and raw logits.

The populate_image_embeddings method is crucial for multimodal fusion, combining
separately processed text and image embeddings. It identifies valid image
embeddings using image_presence_mask, finds the exact locations of <image>
placeholder tokens in input_token_ids, and then overwrites the text embeddings at
these locations with the valid image embeddings, creating the final hybrid sequence
for the transformer.

@torch.no_grad()
def forward(self,
input_token_ids: torch.Tensor, # B x L
image_patches: torch.Tensor, # B x N x C x H x W (3x896x896)
image_presence_mask: torch.Tensor, # B x N
input_positions: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
output_positions: torch.Tensor,
temperatures: Union[torch.Tensor, None],
top_ps: torch.Tensor,
top_ks: torch.Tensor,
local_mask: torch.Tensor | None = None,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = {}
freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
self.local_freqs_cis.index_select(0, input_positions)
)
freqs_cis[gemma_config.AttentionType.GLOBAL] = (
self.global_freqs_cis.index_select(0, input_positions)
)
hidden_states = self.text_token_embedder(input_token_ids)
normalizer = torch.tensor(self.config.hidden_size**0.5,
dtype=hidden_states.dtype, device=hidden_states.device)
hidden_states = hidden_states * normalizer
if image_patches is not None and self.config.vision_config is not None:
# the input has images
B, N, C, H, W = image_patches.shape
# Flatten and Pass to SiglipVisionModel, and apply SiglipVisionModel Exit
flattened_input = image_patches.reshape(B * N, C, H, W) # (B*N)xCxHxW
image_embeddings = self.siglip_vision_model(flattened_input) # (B*N)xUxD
image_embeddings = self.mm_soft_embedding_norm(image_embeddings) # (B*N) x U x D
image_embeddings = self.mm_input_projection(image_embeddings) # (B*N) x U x model_dim
hidden_states = self.populate_image_embeddings(
hidden_states.clone(),
image_embeddings.clone(),
input_token_ids.clone(),
image_presence_mask.clone(),
)

kv_write_indices = input_positions

hidden_states = self.model(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_caches=kv_caches,
mask=mask,
local_mask=local_mask,
)
embedder_weight = self.text_token_embedder.weight
if self.config.quant:
embedder_weight = (
embedder_weight *
self.text_token_embedder.weight_scaler.unsqueeze(-1))

next_tokens, logits = self.sampler(


embedding=embedder_weight,
hidden_states=hidden_states,
output_positions=output_positions,
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
)
return next_tokens, logits

def populate_image_embeddings(self,
hidden_states: torch.Tensor, # B x L x
model_dim
image_embeddings: torch.Tensor, # (B*N) x U
x model_dim
input_token_ids: torch.Tensor, # B x L
image_presence_mask: torch.Tensor, # B x N
):
batch_size, seq_len, model_dim = hidden_states.shape
# Step 1 of 2: Fetch valid image embeddings
# flatten indices of valid image embeddings
valid_image_embeddings_indices = torch.nonzero(image_presence_mask.flatten(), as_tuple=False).squeeze()
# num_valid_images x model_dim
valid_image_embeddings = image_embeddings.index_select(0, valid_image_embeddings_indices)

# Step 2 of 2: Replace image embeddings at right places.


image_placeholder_mask = input_token_ids == self.tokenizer.image_token_placeholder_id
image_placeholder_indices = image_placeholder_mask.flatten().nonzero(as_tuple=False).squeeze()

hidden_states = hidden_states.reshape(-1, self.config.hidden_size)

hidden_states[image_placeholder_indices] = valid_image_embeddings.reshape(-1, self.config.hidden_size)
return hidden_states.reshape(batch_size, seq_len,
model_dim).contiguous()

This function implements a hybrid attention mask strategy for multimodal models,
combining causal attention for text and bidirectional attention for images. It starts with
a causal mask, identifies and groups image tokens, and then creates a bidirectional
mask for these image blocks. The final mask is an OR combination of the causal and
bidirectional masks. A local mask is also applied for sliding window attention. This
comprehensive masking provides precise guidance for processing complex
multimodal inputs.

def create_attention_mask(self, input_ids: torch.Tensor, sequence_length: int):
batch_size = input_ids.shape[0]
causal_mask = torch.tril(torch.ones((batch_size, 1, sequence_length,
sequence_length), dtype=torch.bool, device=input_ids.device))
image_token_mask = input_ids ==
self.tokenizer.image_token_placeholder_id
# Pad the mask to the left with 0. This is to make sure the boundary
# detection works correctly. Boundary (starting index of image patch) is
# detected when the value changes from 0 to 1.
padded_mask = nn.functional.pad(image_token_mask, (1, 0), value=0)
# Find the boundary (starting index) of the image tokens patch.
boundary = padded_mask[:, 1:] > padded_mask[:, :-1]
# Number the boundary.
# boundary:
# [[False, False, True, False, False],
# [False, True, False, True, False]]
# numbered_boundary:
# [[0, 0, 1, 1, 1],
# [0, 1, 1, 2, 2]]
numbered_boundary = torch.cumsum(boundary, dim=-1)

# image_token_mask:
# [[False, False, True, True, False],
# [True, True, False, True, True]]
# numbered_boundary:
# [[0, 0, 1, 1, 1],
# [1, 1, 1, 2, 2]]
# q_block_indices:
# [[0, 0, 1, 1, 0],
# [1, 1, 0, 2, 2]]
q_block_indices = image_token_mask * numbered_boundary
kv_block_indices = q_block_indices
# Test the equality of vertical and horizontal numbered patches
# to create the bidirectional mask.
bidirectional_mask = torch.logical_and(
kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1),
q_block_indices.unsqueeze(-1) > 0,
)
attention_mask = torch.logical_or(causal_mask,
bidirectional_mask.unsqueeze(1))
# The upper triangular matrix's diagonal is shifted by sliding window size
# before doing logical 'and' with attention mask. This is to make sure the
# local attention is within the sliding window.
local_mask = torch.logical_and(
attention_mask,
torch.triu(torch.ones((1, 1, sequence_length, sequence_length),
dtype=torch.bool, device=input_ids.device),
diagonal=-(self.config.sliding_window_size-1))
)
return attention_mask, local_mask
The generate method is the main entry point for a user. It takes a raw prompt, which
can be a complex sequence of text and images, and manages the entire end-to-end
process of generating a textual response.

●​ Input Preprocessing: The first step is to delegate the complex task of input
preparation to the gemma3_preprocessor. This function handles tokenizing the
text, processing images using the "pan-and-scan" technique, inserting special
image tokens, and padding all sequences to create uniform batches.
●​ Attention Mask Creation: The self.create_attention_mask method is
subsequently invoked to construct a sophisticated hybrid mask. This mask
facilitates causal attention for text and bidirectional attention for images.
Boolean values are then converted to floating-point values within this mask,
with disallowed positions being populated by a large negative number.
●​ KV Cache Initialization: The KV Cache is a vital optimization that significantly
enhances the efficiency of text generation. It works by pre-allocating memory
for the Key (K) and Value (V) tensors within each layer of the model. This
pre-storage eliminates the need to recompute these tensors for the entire
context at every generation step.

The Two-Phase Generation Loop: The core of the generate method is a loop that
implements autoregressive decoding. This process is split into two distinct phases for
maximum efficiency:

●​ Phase 1: Prefill: During the initial loop iteration, the model efficiently processes
the complete input prompt, including both text and images, in a single, parallel
forward pass. This action populates the KV cache with context derived from the
user's prompt and involves passing the image_batch to the model.
●​ Phase 2: Decode: After the initial prefill step, the model efficiently generates
one token at a time in subsequent iterations. This process is very fast during
the decoding phase because each step only requires processing the single
token generated previously. A crucial memory optimization involves discarding
the image_batch after the prefill step, as its information is already encoded
within the KV cache.

After the loop completes, the final sequence of token IDs is converted back into a
human-readable string by the tokenizer, and the result is returned.

def generate(
self,
prompts: Sequence[Sequence[Union[str, Image.Image]]],
device: Any,
output_len: int = 100,
temperature: Union[float, None] = 1.0,
top_p: float = 0.95,
top_k: int = 64,
) -> Sequence[str]:
"""Generates responses for given prompts using Gemma model."""
# Inference only.
processing_result = gemma3_preprocessor.tokenize_raw_input(
self.tokenizer, prompts, self.config, output_len, device
)
batch_size = processing_result["batch_size"]
user_input_token_ids = processing_result["user_input_token_ids"]
image_batch = processing_result["image_batch"]
min_prompt_len = processing_result["min_prompt_len"]
max_prompt_len = processing_result["max_prompt_len"]
total_seq_len = processing_result["max_seq_len"]
image_presence_mask = processing_result["image_presence_mask"]

# Create attention mask.


min_dtype = torch.finfo(self.dtype).min
if self.config.sliding_window_size is None:
raise ValueError('gemma 3 model requires sliding_window size')
boolean_mask, local_boolean_mask =
self.create_attention_mask(user_input_token_ids, total_seq_len)
mask_tensor = torch.where(boolean_mask, 0, torch.tensor(min_dtype,
dtype=torch.float32, device=device)).contiguous()
local_mask_tensor = torch.where(local_boolean_mask, 0,
torch.tensor(min_dtype, dtype=torch.float32, device=device)).contiguous()

kv_caches = []
for _ in range(self.config.num_hidden_layers):
size = (batch_size, total_seq_len, self.config.num_key_value_heads,
self.config.head_dim)
dtype = self.config.get_dtype()
k_cache = torch.zeros(size=size, dtype=dtype, device=device)
v_cache = torch.zeros(size=size, dtype=dtype, device=device)
kv_caches.append((k_cache, v_cache))

input_token_ids_tensor = torch.full((batch_size, min_prompt_len),


self.tokenizer.pad_id,
dtype=torch.int64,
device=device)
token_ids_tensor = user_input_token_ids.to(device)
for i in range(batch_size):
p = user_input_token_ids[i]
input_token_ids_tensor[i, :min_prompt_len] = p[:min_prompt_len]

input_positions_tensor = torch.arange(0, min_prompt_len,


dtype=torch.int64, device=device)
prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
curr_local_mask_tensor = local_mask_tensor.index_select(2,
input_positions_tensor)
output_positions_tensor = torch.LongTensor([min_prompt_len -
1]).to(device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64,
device=device)

# Prefill up to min_prompt_len tokens, then treat other prefill as
# decode and ignore output.
for i in range(total_seq_len - min_prompt_len):
next_token_ids, _ = self(
input_token_ids=input_token_ids_tensor,
image_patches=image_batch,
image_presence_mask=image_presence_mask,
input_positions=input_positions_tensor,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
local_mask=curr_local_mask_tensor,
)
curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)

input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2,

input_positions_tensor)
curr_local_mask_tensor = local_mask_tensor.index_select(
2, input_positions_tensor
) if local_mask_tensor is not None else None
output_positions_tensor = torch.tensor(0, dtype=torch.int64,
device=device)
output_index = output_index + 1
image_batch = None
image_presence_mask = None

# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
output = tokens
if self.tokenizer.eos_id in output:
eos_index = output.index(self.tokenizer.eos_id)
output = output[:eos_index]
results.append(self.tokenizer.decode(output))

return results

def load_weights(self, model_path: str):


if os.path.isfile(model_path):
self.load_state_dict(
torch.load(
model_path, mmap=True, weights_only=True,
)['model_state_dict'],
strict=False,
)
else:
index_path = os.path.join(model_path, 'pytorch_model.bin.index.json')
with open(index_path, "r", encoding="utf-8") as f:
index = json.load(f)
shard_files = list(set(index["weight_map"].values()))
for shard_file in shard_files:
shard_path = os.path.join(model_path, shard_file)
state_dict = torch.load(shard_path, map_location="cpu",
weights_only=True)
self.load_state_dict(state_dict, strict=False)
del state_dict # Save memory.
gc.collect()

gemma/gemma3_preprocessor.py
code: https://github.com/google/gemma_pytorch/blob/main/gemma/gemma3_preprocessor.py

This code prepares mixed text and image inputs for the Gemma 3 multimodal model
by transforming them into formatted tensors. The gemma3_input_preprocessor
function uses a "pan-and-scan" technique to generate multiple high-resolution crops
from the original image. Both original and cropped images are processed by a SigLIP
vision preprocessor, converting them to torch.Tensor objects and providing holistic
and detailed views with textual context. The tokenize_raw_input function orchestrates
batch processing. Text is tokenized into integer IDs. For images, special tokens (<boi>,
<image> placeholders, <eoi>) are inserted, and the tensor is stored separately. The
function pads token sequences and image lists to uniform sizes, converting them into
torch.Tensor objects and creating an image_presence_mask. The output is a
dictionary containing all necessary tensors for the Gemma 3 model's forward pass.
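
Since the preprocessor code is only summarized here, the sketch below illustrates the token-side pattern it describes: inserting <boi>/<image>/<eoi> markers for each image and padding token sequences into a uniform batch with an image_presence_mask. The token IDs, the number of placeholder tokens per image, and the helper names are assumptions for illustration, not the library's API.

import torch
from typing import List, Tuple

# Hypothetical special-token IDs and placeholder count, chosen for illustration only.
BOI, IMG, EOI, PAD = 1001, 1002, 1003, 0
TOKENS_PER_IMAGE = 4  # the real model reserves far more placeholders per image

def interleave_image_tokens(text_ids: List[int], num_images: int) -> List[int]:
    """Prepend one <boi> ... <image>*N ... <eoi> block per image to the text ids."""
    ids: List[int] = []
    for _ in range(num_images):
        ids += [BOI] + [IMG] * TOKENS_PER_IMAGE + [EOI]
    return ids + text_ids

def pad_batch(seqs: List[List[int]], images_per_seq: List[int],
              max_images: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad token sequences and build a boolean image_presence_mask."""
    max_len = max(len(s) for s in seqs)
    tokens = torch.full((len(seqs), max_len), PAD, dtype=torch.long)
    presence = torch.zeros((len(seqs), max_images), dtype=torch.bool)
    for i, s in enumerate(seqs):
        tokens[i, :len(s)] = torch.tensor(s)
        presence[i, :images_per_seq[i]] = True
    return tokens, presence

# Example: one prompt with a single image, one text-only prompt.
batch = [interleave_image_tokens([7, 8, 9], num_images=1),
         interleave_image_tokens([5, 6], num_images=0)]
token_ids, image_presence_mask = pad_batch(batch, images_per_seq=[1, 0], max_images=1)
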
gemma/model.py
code: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py

This file provides the complete implementation of the text-only Gemma models. It
defines the core modules of the transformer architecture, how they fit together, and
the high-level interface for generating text. It's the cognitive engine that powers
Gemma's language capabilities.

The Sampler class is responsible for the final token selection process in the language
model. Its forward method takes the model's final hidden states and calculates logits
by performing a matrix multiplication with the embedding matrix. It supports an
optional final_logit_softcapping to stabilize training by limiting the range of logit
values. If no temperature is provided, it performs greedy decoding by simply selecting
the token with the highest logit. Otherwise, it applies temperature scaling, top-k, and
top-p filtering to the probabilities derived from the logits. Top-k limits the sampling
pool to the k most likely tokens, while top-p filters by cumulative probability. After
filtering, the probabilities are renormalized, and the final token is chosen via
multinomial sampling.

gemma/model.py

# Copyright 2024 Google LLC


#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model implementation."""

import json
import gc
import os
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union, Mapping
from gemma import config as gemma_config
from gemma import tokenizer

class Sampler(nn.Module):

def __init__(self, vocab_size: int, config: gemma_config.GemmaConfig):


super().__init__()
self.vocab_size = vocab_size
self.config = config

@torch.no_grad()
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
output_positions: torch.Tensor,
temperatures: Union[torch.Tensor, None],
top_ps: torch.Tensor,
top_ks: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Select the last element for each sequence.
# (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
hidden_states = hidden_states.index_select(
1, output_positions).squeeze(dim=1)
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping

if temperatures is None:
return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits

# Apply temperature scaling.


logits.div_(temperatures.unsqueeze(dim=1))

# Calculate probabilities with softmax.


probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

# Apply top-p, top-k.


probs_sum = torch.cumsum(probs_sort, dim=-1)
top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
probs_sort = torch.where(top_ps_mask, 0, probs_sort)

top_ks_mask = torch.arange(probs_idx.shape[-1],
device=probs_idx.device)
top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
probs_sort = torch.where(top_ks_mask, 0, probs_sort)
# Re-normalization.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))

next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True).squeeze(dim=-1)
return next_token_ids, logits
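
As a standalone illustration of the filtering performed in Sampler.forward, the toy example below applies temperature scaling, top-p, and top-k to a single five-token probability distribution before sampling; the logits and thresholds are arbitrary.

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # one sequence, vocabulary of 5
temperature, top_p, top_k = 0.8, 0.9, 3

probs = torch.softmax(logits / temperature, dim=-1)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

# Top-p: zero out tokens once the cumulative mass before them exceeds top_p.
cumulative = torch.cumsum(probs_sort, dim=-1)
probs_sort = torch.where(cumulative - probs_sort > top_p, torch.zeros_like(probs_sort), probs_sort)

# Top-k: keep only the k highest-probability tokens.
ranks = torch.arange(probs_sort.shape[-1]).expand_as(probs_sort)
probs_sort = torch.where(ranks >= top_k, torch.zeros_like(probs_sort), probs_sort)

# Renormalize, map back to vocabulary order, and sample one token.
probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
next_token = torch.multinomial(probs, num_samples=1)
print(next_token)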

The precompute_freqs_cis function is responsible for creating the Rotary Positional
Embeddings (RoPE) table. RoPE is a method for encoding positional information into the
query and key vectors in the attention mechanism. The function calculates a set of
frequencies based on a base theta value and the embedding dimension. These frequencies
are then combined with a sequence of positions (t) to create phase angles. Finally, it
uses torch.polar to convert these angles into complex numbers (the "cis" form:
e^(iθ) = cos θ + i sin θ), which are later multiplied with the query and key vectors to
"rotate" them according to their position. An optional rope_scaling_factor can be used
to adjust the frequencies, which can help in extending the model's context length.

The apply_rotary_emb function executes the application of the pre-computed Rotary
Positional Embeddings. It takes an input tensor x (representing either queries or keys)
and the freqs_cis tensor. The function first reinterprets the real-valued input tensor as
a complex tensor. It then performs an element-wise multiplication between this
complex tensor and the freqs_cis tensor, effectively rotating the embeddings in the
complex plane. Finally, it converts the result back into a real-valued tensor and
reshapes it to match the original input dimensions, now with positional information
encoded.

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor: int = 1) -> torch.Tensor:
"""Precomputes the frequency cis."""
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() /
dim))
freqs = freqs/rope_scaling_factor
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) ->
torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
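
A quick shape check, using toy dimensions and assuming the two functions above are in scope, shows how they fit together: rows of the precomputed table are selected by position and then applied to a query tensor of shape (batch, seq, heads, head_dim).

import torch

head_dim, max_seq_len = 16, 32
freqs_cis = precompute_freqs_cis(head_dim, max_seq_len)   # (32, 8) complex64 table
positions = torch.arange(0, 4)                            # positions of the current tokens
xq = torch.randn(2, 4, 8, head_dim)                       # (batch, seq, heads, head_dim)
xq_rot = apply_rotary_emb(xq, freqs_cis.index_select(0, positions))
print(xq_rot.shape)                                       # torch.Size([2, 4, 8, 16])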

The Linear class defines a custom linear transformation layer. A key feature of this
class is its support for 8-bit quantization. If the quant flag is set to True, the weights
are stored as 8-bit integers (torch.int8), and a separate floating-point weight_scaler
tensor is also stored. During the forward pass, if quantized, the weights are
de-quantized on-the-fly by multiplying them with the scaler before the standard
F.linear operation is performed. This allows for significant memory savings with
minimal impact on performance. If not quantized, it behaves like a standard linear
layer.

The Embedding class provides a custom token embedding layer, analogous to
torch.nn.Embedding. Similar to the custom Linear class, it incorporates an optional
8-bit quantization mechanism. When quant is enabled, the embedding weight matrix
is stored as int8 values along with a corresponding floating-point scaler. In the
forward pass, these weights are de-quantized before being used in the F.embedding
lookup function. This design reduces the memory footprint of the largest parameter
matrix in most language models.

The Root Mean Square Normalization (RMSNorm) class implements a computationally
efficient alternative to Layer Normalization. The _norm helper function computes the
core normalization by scaling the input x by the inverse square root of the mean of its
squared values. The forward method applies this normalization and then multiplies the
result by a learnable weight parameter. The add_unit_offset flag controls whether the
scaling factor is (1 + weight) or simply weight, a subtle but important difference in
implementation across model versions. The comment highlights a precision difference
between Llama and Gemma2's handling of this operation.
class Linear(nn.Module):

def __init__(self, in_features: int, out_features: int, quant: bool):


super().__init__()
if quant:
self.weight = nn.Parameter(
torch.empty((out_features, in_features), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
else:
self.weight = nn.Parameter(
torch.empty((out_features, in_features)),
requires_grad=False,
)
self.quant = quant

def forward(self, x):


weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
output = F.linear(x, weight)
return output

class Embedding(nn.Module):

def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
super().__init__()
if quant:
self.weight = nn.Parameter(
torch.empty((num_embeddings, embedding_dim),
dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
else:
self.weight = nn.Parameter(
torch.empty((num_embeddings, embedding_dim)),
requires_grad=False,
)
self.quant = quant

def forward(self, x):


weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
output = F.embedding(x, weight)
return output

class RMSNorm(torch.nn.Module):

def __init__(
self,
dim: int,
eps: float = 1e-6,
add_unit_offset: bool = True,
):
super().__init__()
self.eps = eps
self.add_unit_offset = add_unit_offset
self.weight = nn.Parameter(torch.zeros(dim))

def _norm(self, x):


return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):


# Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = self._norm(x.float())
if self.add_unit_offset:
output = output * (1 + self.weight.float())
else:
output = output * self.weight.float()
return output.type_as(x)
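
To make the weight_scaler idea in the Linear and Embedding classes concrete, here is a small, illustrative round-trip: quantizing a float weight matrix to int8 with one scale per output row and de-quantizing it on the fly, in the spirit of (but not necessarily identical to) the scheme used by the quantized Gemma checkpoints.

import torch

def quantize_per_row(weight: torch.Tensor):
    """Symmetric int8 quantization with one scale per output row (illustrative)."""
    scale = weight.abs().amax(dim=1) / 127.0          # (out_features,)
    q = torch.clamp(torch.round(weight / scale.unsqueeze(-1)), -127, 127).to(torch.int8)
    return q, scale

weight = torch.randn(4, 8)                            # (out_features, in_features)
q_weight, weight_scaler = quantize_per_row(weight)

# De-quantize on the fly, as Linear.forward does with weight * weight_scaler.unsqueeze(-1).
dequantized = q_weight.float() * weight_scaler.unsqueeze(-1)
print(torch.max((weight - dequantized).abs()))        # small reconstruction error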

The GemmaMLP class constitutes the feed-forward network (FFN) block found within
each transformer decoder layer. It implements a variant of the Gated Linear Unit (GLU)
architecture. The input tensor x is passed through two separate linear layers:
gate_proj and up_proj. The output of the gate_proj is then passed through a GELU
activation function. This activated gate is then multiplied element-wise with the
output of the up_proj. Finally, the result is passed through a down_proj linear layer to
project it back to the model's hidden dimension.

The GemmaAttention class is the core of the self-attention mechanism. It initializes
with parameters for the number of attention heads, key-value heads (to support
Grouped-Query Attention), and dimension sizes. The forward pass begins by
projecting the input hidden_states into query (Q), key (K), and value (V) tensors using
a single qkv_proj layer. It supports optional query-key normalization (use_qk_norm).
Positional information is then injected by applying rotary embeddings
(apply_rotary_emb) to Q and K. The new key and value states are written to a KV
cache for efficient auto-regressive decoding. Attention scores are calculated via
scaled dot-product of Q and K, and a mask is applied to prevent attending to future or
padded tokens. The class is also aware of the attention type, allowing it to apply a
specific local_mask for sliding window attention. Finally, the scores are used to
compute a weighted sum of the values, and the result is projected back to the hidden
size via o_proj.

class GemmaMLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant: bool,
):
super().__init__()
self.gate_proj = Linear(hidden_size, intermediate_size, quant)
self.up_proj = Linear(hidden_size, intermediate_size, quant)
self.down_proj = Linear(intermediate_size, hidden_size, quant)

def forward(self, x):


gate = self.gate_proj(x)
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(x)
fuse = gate * up
outputs = self.down_proj(fuse)
return outputs

class GemmaAttention(nn.Module):

def __init__(
self,
config: gemma_config.GemmaConfig,
attn_type: gemma_config.AttentionType,
):
super().__init__()

self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads

assert self.num_heads % self.num_kv_heads == 0


self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.hidden_size = config.hidden_size
self.head_dim = config.head_dim

self.q_size = self.num_heads * self.head_dim


self.kv_size = self.num_kv_heads * self.head_dim

if config.query_pre_attn_scalar is not None:


self.scaling = config.query_pre_attn_scalar**-0.5
else:
self.scaling = self.head_dim**-0.5

self.qkv_proj = Linear(
self.hidden_size,
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
quant=config.quant)
self.o_proj = Linear(
self.num_heads * self.head_dim, self.hidden_size,
quant=config.quant
)
self.query_norm = (
RMSNorm(self.head_dim, eps=config.rms_norm_eps)
if config.use_qk_norm
else None
)
self.key_norm = (
RMSNorm(self.head_dim, eps=config.rms_norm_eps)
if config.use_qk_norm
else None
)

self.attn_type = attn_type
self.sliding_window_size = config.sliding_window_size
self.attn_logit_softcapping = config.attn_logit_softcapping

def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
local_mask: torch.Tensor = None,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3

batch_size, input_len, _ = hidden_states_shape

qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)

xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)


xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

if self.query_norm is not None and self.key_norm is not None:


xq = self.query_norm(xq)
xk = self.key_norm(xk)

# Positional embedding.
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

# Write new kv cache.


# [batch_size, input_len, n_local_kv_heads, head_dim]
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)

key = k_cache
value = v_cache
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv,
dim=2)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=2)

# [batch_size, n_local_heads, input_len, head_dim]


q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)

# [batch_size, n_local_heads, input_len, max_seq_len]


q.mul_(self.scaling)
scores = torch.matmul(q, k.transpose(2, 3))
if (
self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING
and self.sliding_window_size is not None
and local_mask is not None
):
mask = local_mask

if self.attn_logit_softcapping is not None:


scores = scores / self.attn_logit_softcapping
scores = torch.tanh(scores)
scores = scores * self.attn_logit_softcapping

scores = scores + mask


scores = F.softmax(scores.float(), dim=-1).type_as(q)

# [batch_size, n_local_heads, input_len, head_dim]


output = torch.matmul(scores, v)

# [batch_size, input_len, hidden_dim]


output = (output.transpose(1, 2).contiguous().view(
batch_size, input_len, -1))
output = self.o_proj(output)
return output

The GemmaDecoderLayer class defines the structure of a single transformer block for
the Gemma 1 architecture. It follows a standard "pre-normalization" setup. The
forward method first computes a residual, then applies RMSNorm (input_layernorm)
to the input, which is then processed by the GemmaAttention block (self_attn). The
output of the attention block is added back to the initial residual. This is followed by a
second residual connection around the MLP block: the result from the attention stage
is normalized again (post_attention_layernorm), passed through the GemmaMLP, and
finally added to its own residual.

The Gemma2DecoderLayer class is a more advanced version of the decoder layer,
used in Gemma 2 and 3. Its attention block is similar to the first version, but the MLP
block has a more complex normalization scheme. It includes optional
pre_feedforward_layernorm and post_feedforward_layernorm layers. This allows for
more flexible model configurations where normalization can be applied before, after,
or both before and after the MLP computation, in addition to the standard
normalization layers before the attention and MLP blocks.

class GemmaDecoderLayer(nn.Module):

def __init__(
self,
config: gemma_config.GemmaConfig,
):
super().__init__()
self.attn_type = gemma_config.AttentionType.GLOBAL
self.self_attn = GemmaAttention(
config=config,
attn_type=self.attn_type)
self.mlp = GemmaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant=config.quant,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

# TODO(imayank): Decouple Gemma versions into separate files.


def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
local_mask: torch.Tensor,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
mask=mask,
)
hidden_states = residual + hidden_states

# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states

class Gemma2DecoderLayer(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
attn_type: gemma_config.AttentionType,
):
super().__init__()
self.attn_type = attn_type
self.self_attn = GemmaAttention(
config=config,
attn_type=self.attn_type,
)
self.mlp = GemmaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant=config.quant,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = (
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.use_pre_ffw_norm
else None
)
self.post_feedforward_layernorm = (
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.use_post_ffw_norm
else None
)

def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
local_mask: torch.Tensor,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
mask=mask,
local_mask=local_mask,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states

# MLP
residual = hidden_states
if self.pre_feedforward_layernorm is not None:
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.post_feedforward_layernorm is not None:
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states

return hidden_states

The GemmaModel class represents the core transformer architecture, assembling the
stack of decoder layers. During initialization, it creates a ModuleList of decoder layers,
dynamically choosing between GemmaDecoderLayer for Gemma 1 and
Gemma2DecoderLayer for Gemma 2 and 3 based on the provided configuration. For
Gemma 2/3, it also assigns the attention type (local or global) to each layer. The
forward method sequentially processes the input hidden_states through each layer,
passing along the necessary RoPE frequencies, KV cache for that layer, and attention
masks. After the final layer, a concluding RMSNorm is applied to the output.

GemmaForCausalLM is the all-encompassing class that combines all the building
blocks into a complete, functional causal language model. The constructor initializes
the tokenizer, the token embedder, the GemmaModel (the transformer stack), and the
final Sampler. A crucial step during initialization is the pre-computation of the RoPE
frequencies (freqs_cis). For Gemma 3, it creates distinct local_freqs_cis and
global_freqs_cis tables for different attention types, while older architectures use a
single table. The class exposes forward for single-step computation, generate for
end-to-end text generation, and load_weights for model loading.

The private helper method, _register_freqs_cis, is used by the GemmaForCausalLM
constructor to handle the creation and registration of the RoPE frequency tables. It
calls the precompute_freqs_cis function with the specified parameters (head_dim,
max_seq_len, theta) and then uses PyTorch's register_buffer to save the resulting
tensor as part of the model's state. Registering it as a buffer ensures that the tensor is
moved to the correct device along with the model and is included in the model's
state_dict, but is not considered a trainable parameter.

The forward method of GemmaForCausalLM orchestrates a single inference step. It
starts by selecting the appropriate RoPE frequencies (freqs_cis) from the
pre-computed tables based on the input positions. The input token IDs are then
converted to dense vectors by the embedder, and these embeddings are scaled by
the square root of the hidden size. These prepared embeddings are passed to the
core GemmaModel along with the positional encodings, KV caches, and masks. The
resulting output hidden states are then passed to the sampler, which uses the
de-quantized embedder weights to compute final logits and sample the next token
IDs.

class GemmaModel(nn.Module):

def __init__(self, config: gemma_config.GemmaConfig):


super().__init__()
self.config = config
self.vocab_size = config.vocab_size

self.layers = nn.ModuleList()
for i in range(config.num_hidden_layers):
if config.architecture == gemma_config.Architecture.GEMMA_1:
self.layers.append(GemmaDecoderLayer(config))
elif config.architecture in (
gemma_config.Architecture.GEMMA_2,
gemma_config.Architecture.GEMMA_3,
):
attn_type = (
config.attn_types[i % len(config.attn_types)]
if config.attn_types is not None
else gemma_config.AttentionType.GLOBAL
)
self.layers.append(Gemma2DecoderLayer(config, attn_type))
else:
raise ValueError(f'Unknown architecture:
{config.architecture}')
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: Mapping[gemma_config.AttentionType, torch.Tensor],
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
local_mask: torch.Tensor,
) -> torch.Tensor:
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
freqs_cis=freqs_cis.get(layer.attn_type),
kv_write_indices=kv_write_indices,
kv_cache=kv_caches[i],
mask=mask,
local_mask=local_mask,
)
hidden_states = self.norm(hidden_states)
return hidden_states​

class GemmaForCausalLM(nn.Module):

def __init__(
self,
config: gemma_config.GemmaConfig,
):
super().__init__()
self.config = config
assert config.hidden_size % config.num_attention_heads == 0

max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size

self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
self.model = GemmaModel(config)
self.sampler = Sampler(vocab_size, config)

# Pre-compute rotary embedding table.


if config.architecture == gemma_config.Architecture.GEMMA_3:
if config.rope_wave_length is None:
raise ValueError('rope_wave_length must be provided for Gemma3.')

rope_lengths = config.rope_wave_length
defaults = {
gemma_config.AttentionType.LOCAL_SLIDING: 10_000,
gemma_config.AttentionType.GLOBAL: 10_000,
}

for attn_type, name in [


(gemma_config.AttentionType.LOCAL_SLIDING,
'local_freqs_cis'),
(gemma_config.AttentionType.GLOBAL, 'global_freqs_cis'),
]:
theta = rope_lengths.get(
attn_type, defaults[attn_type]
)
self._register_freqs_cis(name, head_dim, max_seq_len, theta=theta)

else:
self._register_freqs_cis('freqs_cis', head_dim, max_seq_len)

def _register_freqs_cis(
self, name: str, head_dim: int, max_seq_len: int, theta: int =
10_000
):
self.register_buffer(
name, precompute_freqs_cis(head_dim, max_seq_len * 2,
theta=theta)
)​

@torch.no_grad()
def forward(
self,
input_token_ids: torch.Tensor,
input_positions: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
output_positions: torch.Tensor,
temperatures: Union[torch.Tensor, None],
top_ps: torch.Tensor,
top_ks: torch.Tensor,
local_mask: torch.Tensor | None = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = {}

if self.config.architecture == gemma_config.Architecture.GEMMA_3:
freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
self.local_freqs_cis.index_select(0, input_positions)
)
freqs_cis[gemma_config.AttentionType.GLOBAL] = (
self.global_freqs_cis.index_select(0, input_positions)
)
else:
freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
self.freqs_cis.index_select(0, input_positions)
)
freqs_cis[gemma_config.AttentionType.GLOBAL] = (
self.freqs_cis.index_select(0, input_positions)
)

kv_write_indices = input_positions

# [batch_size, input_len, hidden_size]


hidden_states = self.embedder(input_token_ids)
# Gemma normalizes the embedding by sqrt(hidden_size).
# Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5,
dtype=hidden_states.dtype, device=hidden_states.device)
hidden_states = hidden_states * normalizer

hidden_states = self.model(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_caches=kv_caches,
mask=mask,
local_mask=local_mask,
)
embedder_weight = self.embedder.weight
if self.config.quant:
embedder_weight = (
embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
next_tokens, logits = self.sampler(
embedding=embedder_weight,
hidden_states=hidden_states,
output_positions=output_positions,
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
)
return next_tokens, logits

The generate method provides the main entry point for users to perform text
generation. It orchestrates the entire auto-regressive decoding process. First, it
tokenizes the input prompts and initializes the KV caches and attention masks. The
core of the method is a loop that runs for the desired output length. This loop
implements a two-phase generation strategy: a "prefill" phase where the input prompt
is processed in a single large batch to populate the KV cache, followed by a "decode"
phase where tokens are generated one by one. In each step of the loop, it calls the
forward method to get the next token, updates the overall sequence of tokens, and
prepares the inputs for the next iteration. Finally, after the loop finishes, it detokenizes
the completed sequences and returns the generated text.

def generate(
self,
prompts: Union[str, Sequence[str]],
device: Any,
output_len: int = 100,
temperature: Union[float, None] = 1.0,
top_p: float = 0.95,
top_k: int = 64,
) -> Union[str, Sequence[str]]:
"""Generates responses for given prompts using Gemma model."""
# If a single prompt is provided, treat it as a batch of 1.
is_str_prompt = isinstance(prompts, str)
if is_str_prompt:
prompts = [prompts]

batch_size = len(prompts)
prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
min_prompt_len = min(len(p) for p in prompt_tokens)
max_prompt_len = max(len(p) for p in prompt_tokens)
max_seq_len = max_prompt_len + output_len
assert max_seq_len <= self.config.max_position_embeddings

# build KV caches
kv_caches = []
for _ in range(self.config.num_hidden_layers):
size = (batch_size, max_seq_len, self.config.num_key_value_heads,
self.config.head_dim)
dtype = self.config.get_dtype()
k_cache = torch.zeros(size=size, dtype=dtype, device=device)
v_cache = torch.zeros(size=size, dtype=dtype, device=device)
kv_caches.append((k_cache, v_cache))

# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
self.tokenizer.pad_id,
dtype=torch.int64)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
self.tokenizer.pad_id,
dtype=torch.int64)
for i, p in enumerate(prompt_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])
token_ids_tensor = token_ids_tensor.to(device)
input_token_ids_tensor = input_token_ids_tensor.to(device)
prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
input_positions_tensor = torch.arange(0, min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
-2.3819763e38).to(torch.float)
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
local_mask_tensor = mask_tensor + torch.tril(
torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38,
device=device),
diagonal=-self.config.sliding_window_size,
) if self.config.sliding_window_size else None
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
curr_local_mask_tensor = local_mask_tensor.index_select(
2, input_positions_tensor
) if local_mask_tensor is not None else None
output_positions_tensor = torch.LongTensor([min_prompt_len -
1]).to(device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
device)

# Prefill up to min_prompt_len tokens, then treat other prefill as
# decode and ignore output.
for i in range(max_seq_len - min_prompt_len):
next_token_ids, _ = self(
input_token_ids=input_token_ids_tensor,
input_positions=input_positions_tensor,
kv_write_indices=None,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
local_mask=curr_local_mask_tensor,
)

curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)

input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
curr_local_mask_tensor = local_mask_tensor.index_select(
2, input_positions_tensor
) if local_mask_tensor is not None else None
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
device)
output_index = output_index + 1

# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
+ output_len]
if self.tokenizer.eos_id in trimmed_output:
eos_index = trimmed_output.index(self.tokenizer.eos_id)
trimmed_output = trimmed_output[:eos_index]
results.append(self.tokenizer.decode(trimmed_output))

# If a string was provided as input, return a string as output.


return results[0] if is_str_prompt else results

The utility method, load_weights, is designed to load pre-trained weights into the
model. It handles two scenarios: loading from a single checkpoint file and loading
from a sharded checkpoint, where weights are split across multiple files. It first checks
if the provided model_path is a file. If so, it loads the state dictionary directly. If the
path is a directory, it looks for an index JSON file (pytorch_model.bin.index.json) that
maps weight names to their respective shard files. It then iterates through the unique
shard files, loading each one's state dictionary into the model. To conserve memory
during this process, it deletes each state dictionary and calls the garbage collector
after loading.

def load_weights(self, model_path: str):


if os.path.isfile(model_path):
self.load_state_dict(
torch.load(
model_path, mmap=True, weights_only=True,
)['model_state_dict'],
strict=False,
)
else:
index_path = os.path.join(model_path,
'pytorch_model.bin.index.json')
with open(index_path, "r", encoding="utf-8") as f:
index = json.load(f)
shard_files = list(set(index["weight_map"].values()))
for shard_file in shard_files:
shard_path = os.path.join(model_path, shard_file)
state_dict = torch.load(shard_path, map_location="cpu",
weights_only=True)
self.load_state_dict(state_dict, strict=False)
del state_dict # Save memory.
gc.collect()

gemma/model_xla.py
Code: https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py.

This code implements distributed Gemma model execution using Tensor Parallelism,
optimized for Google's TPUs via XLA. It adapts the standard Gemma architecture to
efficiently utilize multiple accelerator devices, addressing the challenge of models too
large for a single accelerator's memory. Tensor parallelism is employed to partition, or
"shard," the model's large weight matrices across a group of devices. Each device
stores only a segment of the model's weights and performs a corresponding fraction
of the total computation. The entire Gemma model is rebuilt using specialized parallel
layers defined in xla_model_parallel.py, including ColumnParallelLinear,
RowParallelLinear, and ParallelEmbedding. These layers manage the complex
communication necessary to synchronize computations across all devices, making the
distributed nature of the model transparent to the end-user.
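
To make this sharding idea concrete, here is a minimal, single-process sketch that simulates column- and row-parallel linear layers by chunking weight matrices and combining partial results. It is illustrative only (the names and shapes are invented here) and is not the repository's implementation, which performs these steps across real devices with collective communication.

import torch

# Toy "tensor parallelism" on a single device: split weight matrices into
# world_size chunks and combine partial results, mimicking what
# ColumnParallelLinear and RowParallelLinear do across real devices.
world_size = 4
hidden, inner = 16, 32
x = torch.randn(2, hidden)           # [batch, hidden]
w_up = torch.randn(hidden, inner)    # up projection (Y = X A)
w_down = torch.randn(inner, hidden)  # down projection

# Column parallelism: each "device" owns a vertical slice of w_up and
# produces a slice of the output; concatenation plays the role of gather.
up_shards = torch.chunk(w_up, world_size, dim=1)
y = torch.cat([x @ shard for shard in up_shards], dim=-1)

# Row parallelism: each "device" owns a horizontal slice of w_down and the
# matching slice of the input; summing the partial outputs is the all-reduce.
down_shards = torch.chunk(w_down, world_size, dim=0)
y_slices = torch.chunk(y, world_size, dim=-1)
z = sum(ys @ shard for ys, shard in zip(y_slices, down_shards))

# The sharded computation matches the unsharded one.
assert torch.allclose(z, (x @ w_up) @ w_down, atol=1e-4)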

Core Parallel Components: The code re-implements the standard Gemma modules
to be aware of the distributed environment.

●​ Sampler: The sampling module is modified to work in parallel. When calculating the final logits, the matrix multiplication between the hidden states and the embedding matrix is performed in a distributed manner. Each device computes a partial logit score, and these scores are then combined across all devices using an all-reduce operation to produce the final, complete logit vector.
●​ GemmaMLP: The feed-forward network is split using a combination of column
and row parallelism.
○​ The gate_proj and up_proj layers use ColumnParallelLinear, splitting their
weight matrices vertically.
○​ The down_proj layer uses RowParallelLinear, splitting its weight matrix
horizontally and performing an all-reduce sum to combine the partial
results.
●​ GemmaAttention: The attention module is also heavily parallelized. The large
qkv_proj and o_proj layers are replaced with their parallel equivalents. The
number of attention heads (num_heads) and key-value heads (num_kv_heads)
is divided among the devices in the parallel group, so each device only handles
a subset of the total heads.

The Top-Level Parallel Model: The GemmaForCausalLM class is the main entry
point. It constructs the entire model using the parallel components.

●​ Initialization: It takes the world_size (the number of devices in the parallel group) and the rank (the ID of the current device) as arguments. It then
instantiates the ParallelEmbedding, the parallel GemmaModel, and the parallel
Sampler.
●​ Forward Pass: The forward method orchestrates the flow of data through the
distributed network. It's similar to the standard model's forward pass, but under
the hood, each layer is performing its computations and communications in
parallel across the device cluster.
●​ Weight Loading: The load_weights and _load_weights methods are particularly
important. When loading a pre-trained checkpoint, these functions are
responsible for intelligently splitting the weight tensors from the checkpoint file
and distributing the correct slice to each device. For a ColumnParallelLinear
layer, for instance, it splits the weight matrix along its columns and gives each
rank its corresponding partition. This ensures that the entire model is correctly
loaded across the distributed system; a simplified sketch of this per-rank slicing follows this list.
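
As a rough illustration of that slicing, the sketch below shows one way a checkpoint weight could be partitioned for a column- or row-parallel layer. The helper name and shapes are hypothetical and do not reproduce the actual load_weights logic.

import torch

def shard_for_rank(weight: torch.Tensor, rank: int, world_size: int,
                   dim: int) -> torch.Tensor:
    # Return the contiguous slice of a full checkpoint weight that a given
    # rank should hold. For a column-parallel layer the split runs along the
    # output dimension; for a row-parallel layer it runs along the input
    # dimension. (Hypothetical helper, not the repository's exact code.)
    assert weight.size(dim) % world_size == 0
    shard_size = weight.size(dim) // world_size
    return weight.narrow(dim, rank * shard_size, shard_size).clone()

# Example: a [out_features, in_features] weight split across 4 ranks along
# its output dimension, as a ColumnParallelLinear layer would require.
full_weight = torch.randn(1024, 256)
rank2_shard = shard_for_rank(full_weight, rank=2, world_size=4, dim=0)
print(rank2_shard.shape)  # torch.Size([256, 256])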

Let’s have a look at specific parts. The GemmaAttention class implements the
self-attention mechanism for a tensor-parallel environment. During initialization, it
calculates the number of attention heads and key-value heads per device
(num_heads, num_kv_heads). The QKV projection is handled by a single
ColumnParallelLinear layer, which shards the combined QKV weight matrix across
devices. Conversely, the output projection (o_proj) is a RowParallelLinear layer that
gathers and sums the partial results from each device's attention heads. The forward
pass applies rotary embeddings and manages the KV cache locally on each device,
while the parallel layers handle the necessary communication implicitly.

gemma/model_xla.py

# Copyright 2024 Google LLC


#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference-only Gemma model implementation."""

[...]
class GemmaAttention(nn.Module):

def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
attn_logit_softcapping: Optional[float],
query_pre_attn_scalar: Optional[int],
head_dim: int,
world_size: int,
rank: int,
quant: bool,
attn_type: gemma_config.AttentionType,
sliding_window_size: Optional[int] = None,
):
super().__init__()
self.rank = rank

def init_method(x):
return x

self.total_num_heads = num_heads
assert self.total_num_heads % world_size == 0
self.num_heads = self.total_num_heads // world_size  # heads per shard

if num_kv_heads < world_size:


assert world_size % num_kv_heads == 0
self.total_num_kv_heads = world_size
else:
assert num_kv_heads % world_size == 0
self.total_num_kv_heads = num_kv_heads
self.num_kv_heads = self.total_num_kv_heads // world_size  # kv heads per shard

assert self.num_heads % self.num_kv_heads == 0


self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.hidden_size = hidden_size
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim

if query_pre_attn_scalar is not None:


self.scaling = query_pre_attn_scalar**-0.5
else:
self.scaling = self.head_dim**-0.5

self.qkv_proj = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)

self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)

self.attn_type = attn_type
self.sliding_window_size = sliding_window_size
self.attn_logit_softcapping = attn_logit_softcapping

def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3

batch_size, input_len, _ = hidden_states_shape

qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)

xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)


xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

# Positional embedding.
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

# Write new kv cache.


# [batch_size, input_len, n_local_kv_heads, head_dim]
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)

key = k_cache
value = v_cache
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv,
dim=2)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=2)

# [batch_size, n_local_heads, input_len, head_dim]


q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)

# [batch_size, n_local_heads, input_len, max_seq_len]


q.mul_(self.scaling)
scores = torch.matmul(q, k.transpose(2, 3))
if (
self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING
and self.sliding_window_size is not None
):
all_ones = torch.ones_like(mask)
sliding_mask = torch.triu(
all_ones, -1 * self.sliding_window_size + 1
) * torch.tril(all_ones, self.sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
if self.attn_logit_softcapping is not None:
scores = scores / self.attn_logit_softcapping
scores = torch.tanh(scores)
scores = scores * self.attn_logit_softcapping
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)

# [batch_size, n_local_heads, input_len, head_dim]


output = torch.matmul(scores, v)

# [batch_size, input_len, hidden_dim]


output = (output.transpose(1, 2).contiguous().view(
batch_size, input_len, -1))
output = self.o_proj(output)
return output

The GemmaForCausalLM class is the top-level container for the XLA-parallelized Gemma model. It integrates all the necessary components for inference in a
distributed setting. The constructor initializes parallel-aware modules:
ParallelEmbedding (which shards the embedding table across devices), the parallel
GemmaModel stack, and the parallel Sampler. It also pre-computes and registers the
RoPE frequencies (freqs_cis) as a buffer. The class is responsible for orchestrating the
entire forward pass and loading weights in a parallelism-aware manner. The forward
method of the XLA GemmaForCausalLM class defines the computation for a single
generation step. It begins by selecting the relevant RoPE frequencies based on the
input positions. The input token IDs are then passed to the ParallelEmbedding layer to
get their vector representations, which are subsequently normalized by the square
root of the hidden size. These hidden states are processed by the core GemmaModel.
Finally, the output hidden states from the model are passed to the Sampler, which
computes the final logits and samples the next token, completing the distributed
forward pass.

gemma/tokenizer.py
Code: https://github.com/google/gemma_pytorch/blob/main/gemma/tokenizer.py.
This script provides the tokenizer for Gemma models, converting text to numerical IDs
and vice-versa. The Tokenizer class, a wrapper around Google's SentencePiece
library, loads a pre-trained model for tokenization. It identifies special tokens like
bos_id (Beginning of Sequence), eos_id (End of Sequence), pad_id (for padding), and
boi_id/eoi_id (Beginning/End of Image) for multimodal Gemma 3, with pad_id also
serving as image_token_placeholder_id. The encode method converts strings to token
IDs (with optional bos/eos addition), and decode converts IDs back to strings.
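
As a usage sketch, the snippet below calls the sentencepiece library directly rather than the repository's Tokenizer wrapper; the model path is illustrative and assumes a SentencePiece model file is available locally.

import sentencepiece as spm

# Load a SentencePiece model; the path is illustrative.
sp = spm.SentencePieceProcessor()
sp.Load("tokenizer/tokenizer.model")

text = "Gemma converts text to token IDs and back."
ids = sp.EncodeAsIds(text)       # text -> token IDs
ids = [sp.bos_id()] + ids        # optionally prepend BOS, as encode() can do
print(ids)
# Token IDs -> text; control tokens such as BOS are ignored on decode.
print(sp.DecodeIds(ids))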

gemma/xla_model_parallel.py
This script provides the essential toolkit for running Gemma models at scale using a
technique called Tensor Model Parallelism. It's designed for distributed computing
environments like Google's TPUs (via XLA) or multi-GPU setups (via CUDA), allowing
models with billions of parameters to be split across multiple accelerator chips.

Communication Primitives
This file's core consists of fundamental building blocks that manage how data is
communicated between different devices in a distributed group, implemented as
custom torch.autograd.Function classes that handle gradients during
backpropagation.

●​ scatter: Splits a large tensor along a dimension, giving each device a unique
slice.
●​ gather: Collects and concatenates slices from all devices to reconstruct the
original large tensor.
●​ reduce: Combines tensors from all devices into a single tensor via operations
like sum (all-reduce), synchronizing results.
●​ copy: A simple identity operation in the forward pass, but it triggers a reduce
operation in the backward pass to aggregate gradients.

These primitives abstract complex underlying communication calls (xm.all_reduce, dist.all_reduce, etc.), providing a clean API for building parallel neural network layers.
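
The sketch below illustrates the autograd pattern behind the copy primitive: an identity forward pass paired with a gradient all-reduce in the backward pass. It assumes a torch.distributed process group has already been initialized and is a simplified stand-in for the file's actual implementation.

import torch
import torch.distributed as dist

class CopyToModelParallelRegion(torch.autograd.Function):
    # Identity in the forward pass; all-reduce of gradients in the backward
    # pass. The same activations are consumed on every rank, so their
    # gradient contributions must be summed across the group.

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x  # plain identity

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
        return grad_output

def copy_to_model_parallel_region(x: torch.Tensor) -> torch.Tensor:
    return CopyToModelParallelRegion.apply(x)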

Parallel Neural Network Layers

Using the communication primitives, the script defines new versions of standard
PyTorch layers that are inherently parallel.

●​ ParallelEmbedding: This module splits the large token embedding table. The
table, which has a shape of (vocab_size, hidden_size), is partitioned along the
hidden_size dimension (column-wise). Each device holds a vertical slice of the
embedding table. When the layer receives input token IDs, each device looks
up its slice of the embedding vector, and the results are gathered from all
devices to form the complete embedding vector.
●​ ColumnParallelLinear: This layer implements a linear transformation (Y = XA +
b) where the weight matrix A is partitioned vertically (column-wise). The input X
is broadcast to all devices. Each device performs a matrix multiplication
between X and its local slice of the weight matrix A_i, producing a partial output
Y_i. The partial outputs Y_i from all devices are then gathered and
concatenated to form the final output Y. This strategy is used for layers where
the output dimension is large, such as the gate_proj and up_proj layers in the
MLP.
●​ RowParallelLinear: This layer partitions the weight matrix A horizontally
(row-wise). The input X is split, and each device receives a slice X_i. Each
device performs a matrix multiplication between its input slice X_i and its
weight slice A_i, producing a partial output. The partial outputs from all devices
are then summed together using an all-reduce operation to produce the final
output Y. This is used for layers where the input dimension is large, like the
down_proj in the MLP and the o_proj in the attention mechanism.
●​ Quantization Support: The script also includes functions for 8-bit
quantization. The quantize_tensor function can take a standard floating-point
tensor, calculate the appropriate scaling factor, and convert it into a
memory-efficient int8 tensor. All the parallel layers (ParallelEmbedding,
ColumnParallelLinear, RowParallelLinear) have built-in support for this, allowing
them to operate on quantized weights to further reduce the model's memory
footprint. (A simplified sketch of this quantization scheme follows this list.)
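
Below is a simplified sketch of symmetric per-row int8 quantization in the spirit of the bullet above; the function name and scaling details are illustrative assumptions, not the file's exact quantize_tensor implementation.

import torch

def quantize_int8(weight: torch.Tensor):
    # Symmetric per-row int8 quantization: one scale per output row keeps the
    # largest magnitude within the int8 range, so that
    # weight ≈ int8_weight * scale (the role played by a weight scaler when
    # quant=True).
    max_abs = weight.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = max_abs / 127.0
    q = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(-1)

w = torch.randn(4, 8)
q, scale = quantize_int8(w)
w_hat = q.float() * scale.unsqueeze(-1)
print((w - w_hat).abs().max())  # small quantization error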

Conclusion
The Gemma PyTorch codebase is an exceptionally well-engineered system, from its
foundational hyperparameters to its distributed model's intricate tensor operations. Its
core transformer blocks incorporate modern, efficient techniques like RMS
Normalization for computational stability and Rotary Positional Embeddings for
sophisticated sequence order understanding. Furthermore, Grouped-Query Attention
is implemented to pragmatically optimize inference speed and memory usage without
significant performance trade-offs. A significant evolutionary step is Gemma 3's
introduction of multimodality, which deeply integrates a dedicated vision model into
the language processing stream. This is achieved through a specialized preprocessor,
a "pan-and-scan" technique for detailed image analysis, and a sophisticated attention
mask that facilitates bidirectional visual reasoning, enabling a true fusion of sight and
language. This design provides a compelling blueprint for future multimodal systems.

Finally, the codebase demonstrates practical considerations for deploying models at scale by supporting tensor parallelism via XLA. The use of column- and row-parallel
layers showcases masterful distributed systems engineering, transforming a
theoretical model into a practical tool for the largest computational clusters.
Essentially, the Gemma codebase is more than just an implementation of a powerful
model; it serves as a comprehensive case study in modern AI engineering, exhibiting a
system designed for performance, scalability, and extensibility into the next
generation of artificial intelligence.
Proximal Policy Optimization (PPO): A Deeper Look
PPO's primary function within RLHF is to fine-tune the language model (the "policy")
using feedback from the trained reward model, but to do so without taking
destructively large steps that could destabilize the model's performance.

Key Technical Concepts


●​ Policy ($\pi_\theta$): This is the language model itself, parameterized by weights $\theta$. For a given prompt (state, $s$), the policy outputs a probability distribution over the next possible tokens (actions, $a$). The goal is to update $\theta$ to improve the policy.
●​ Advantage Estimate ($\hat{A}_t$): This is a crucial value. Instead of just using the raw score (reward) from the reward model, PPO uses the advantage. The advantage of an action is how much better it was than the expected or average action from that state. It is calculated as:

$\hat{A}(s_t, a_t) = R_t - V(s_t)$

where $R_t$ is the reward (from the reward model) for taking action $a_t$ in state $s_t$, and $V(s_t)$ is the value function's (also from the reward model) estimate of the average reward from state $s_t$. A positive advantage means the action taken was better than the baseline, and a negative advantage means it was worse. Using the advantage helps to reduce the variance of the updates and stabilize training.

The PPO Clipped Surrogate Objective Function


The core innovation of PPO is its objective function, which constrains how much the
policy can change in a single update.

First, we define the probability ratio, $r_t(\theta)$:

$r_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$

This ratio compares the likelihood of taking action $a_t$ under the new policy ($\pi_\theta$) versus the old policy ($\pi_{\theta_{\text{old}}}$) before the update. If $r_t(\theta) > 1$, the new policy is more likely to take that action; if $r_t(\theta) < 1$, it is less likely.

The PPO objective function then incorporates this ratio in a "clipped" way:

$L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\hat{A}_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right)\right]$

Let's break down the min and clip operations:

1.​ The First Term: $r_t(\theta)\hat{A}_t$ is the standard policy gradient objective. It encourages the policy to increase the probability of actions that have a positive advantage.
2.​ The Second Term: $\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t$ is the PPO innovation.
○​ The clip function forces the probability ratio $r_t(\theta)$ to stay within the range $[1-\epsilon, 1+\epsilon]$. The hyperparameter $\epsilon$ is a small number (e.g., 0.2), which defines the size of the "trust region."
○​ This clipped ratio is then multiplied by the advantage $\hat{A}_t$.
3.​ The min Function: PPO takes the minimum of the two terms. This acts as a pessimistic bound, or a constraint.
○​ If the advantage ($\hat{A}_t$) is positive: We want to increase the probability of this action. The objective becomes $\min(r_t(\theta)\hat{A}_t, (1+\epsilon)\hat{A}_t)$, so the potential reward for the update is capped. Even if the policy wants to make a huge change (a very large $r_t$), the benefit it gets is limited by the $1+\epsilon$ boundary. This prevents it from chasing a large reward too aggressively.
○​ If the advantage ($\hat{A}_t$) is negative: We want to decrease the probability of this action. The min function ensures the penalty is severe enough to discourage the action but prevents an overly drastic update that could destabilize the model.

By taking this minimum, PPO ensures that each update step is small and controlled,
leading to more stable and reliable training than earlier policy gradient methods.
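
As a minimal sketch, the clipped surrogate objective can be computed in a few lines of PyTorch, assuming per-token log-probabilities under the new and old policies and advantage estimates are already available; this is illustrative only, not a full PPO training loop.

import torch

def ppo_clipped_loss(logprobs_new: torch.Tensor,
                     logprobs_old: torch.Tensor,
                     advantages: torch.Tensor,
                     epsilon: float = 0.2) -> torch.Tensor:
    # Clipped surrogate objective L^CLIP, returned as a loss to minimize.
    # Probability ratio r_t(theta) = pi_new / pi_old, computed in log space.
    ratio = torch.exp(logprobs_new - logprobs_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # Pessimistic minimum of the two terms; negate because maximizing L^CLIP
    # is equivalent to minimizing its negative.
    return -torch.min(unclipped, clipped).mean()

# Toy example: log-probs of 5 sampled tokens under the old and new policies,
# plus their advantage estimates.
logp_old = torch.log(torch.tensor([0.20, 0.10, 0.40, 0.30, 0.25]))
logp_new = torch.log(torch.tensor([0.30, 0.05, 0.50, 0.30, 0.20]))
adv = torch.tensor([1.0, -0.5, 0.8, 0.0, -1.2])
print(ppo_clipped_loss(logp_new, logp_old, adv))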

Direct Preference Optimization (DPO): A Deeper Look
DPO offers a more direct route to alignment by reformulating the problem, which
allows it to bypass the explicit reward modeling and RL stages of RLHF entirely.

Key Technical Concepts


●​ Preference Data ($\mathcal{D}$): The foundation of DPO is a dataset of preference pairs, where each entry is a triplet $(x, y_w, y_l)$. Here, $x$ is the prompt, $y_w$ is the "winner" response (preferred by humans), and $y_l$ is the "loser" response.
●​ Reference Policy ($\pi_{\text{ref}}$): This is the initial supervised fine-tuned (SFT) model before alignment begins. DPO's goal is to adjust this policy to satisfy the preferences while not deviating too far from its original capabilities.
●​ The Bradley-Terry Model: DPO uses this model to link preferences to an underlying reward. The model states that the probability of preferring $y_w$ over $y_l$ is a logistic function of the difference in their hidden reward scores:

$p(y_w \succ y_l \mid x) = \sigma\big(r^*(x, y_w) - r^*(x, y_l)\big)$

where $\sigma$ is the sigmoid (logistic) function and $r^*$ is the latent optimal reward function.

The DPO Objective Function


The core insight of the DPO paper is proving that this unknown reward function $r^*$ can be expressed directly in terms of the optimal policy $\pi^*$ and the reference policy $\pi_{\text{ref}}$ (up to a term that depends only on the prompt and cancels when comparing two responses to the same prompt):

$r^*(x, y) = \beta \log \dfrac{\pi^*(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$

Here, $\beta$ is a hyperparameter that controls how much the reward function deviates from the reference policy. Substituting this back into the preference model gives us the probability of a preference pair in terms of language model policies alone.

This allows us to construct a loss function that optimizes the policy $\pi_\theta$ directly. The DPO loss is the negative log-likelihood of the preference data:

$L_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \dfrac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \dfrac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$

Let's dissect this:

1.​ The Ratios: For both the winning ($y_w$) and losing ($y_l$) responses, the model computes the log of the ratio between the probability under the current policy ($\pi_\theta$) and the probability under the frozen reference policy ($\pi_{\text{ref}}$).
2.​ The Difference: The loss function is driven by the difference between the winner's log-probability ratio and the loser's log-probability ratio.
3.​ Optimization Goal: To minimize this loss, the model must adjust its parameters $\theta$ so that it increases the relative log-probability of the winning response ($y_w$) and decreases the relative log-probability of the losing response ($y_l$).

Essentially, DPO trains the LLM on a simple binary classification task: for a given
prompt, classify which response is the "winner." By doing so, it implicitly optimizes for
the same preferences that RLHF would, but it does so in a single, stable, and
computationally efficient training phase without ever needing to fit a separate reward
model.
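
As a minimal sketch, the DPO loss reduces to a few lines of PyTorch once the summed log-probabilities of each response under the policy and the frozen reference model are available; the function and the numbers below are toy, illustrative values, not a complete training pipeline.

import torch
import torch.nn.functional as F

def dpo_loss(policy_logp_w: torch.Tensor, policy_logp_l: torch.Tensor,
             ref_logp_w: torch.Tensor, ref_logp_l: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # Negative log-likelihood of the preferences under the implicit DPO reward.
    # Log-probability ratios of policy vs. frozen reference for winner / loser.
    ratio_w = policy_logp_w - ref_logp_w
    ratio_l = policy_logp_l - ref_logp_l
    # -log sigmoid(beta * (winner ratio - loser ratio)), averaged over the batch.
    return -F.logsigmoid(beta * (ratio_w - ratio_l)).mean()

# Toy batch of 3 preference pairs (sequence-level log-probabilities).
policy_w = torch.tensor([-12.0, -9.5, -20.1])
policy_l = torch.tensor([-13.2, -9.1, -22.4])
ref_w = torch.tensor([-12.5, -9.8, -20.0])
ref_l = torch.tensor([-12.9, -9.0, -22.0])
print(dpo_loss(policy_w, policy_l, ref_w, ref_l))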
