Part 1
Introduction                                                               1
The Dawn of a New Architecture                                             1
The Core Transformer Architecture: An Overview                             2
    The Encoder in Detail                                                  3
    Inside an Encoder Layer:                                               4
    The Decoder in Detail                                                  4
    Inside a Decoder Layer:                                                5
Attention is All You Need: The Core Mechanism                              5
    The Inputs: Queries, Keys, and Values (Q, K, V)                        6
    The Attention Formula                                                  6
    Multi-Head Attention: Seeing in Parallel                               7
    Types of Attention in the Transformer                                  8
    Computational Complexity and Benefits of Attention                     9
The Transformer Lifecycle                                                 11
    Pretraining                                                           11
    Fine-Tuning                                                           13
    Alignment                                                             14
Modern Transformers                                                       16
    Context Management: The Context Window                                16
    The Challenge of Long Sequences                                       17
    Techniques for Extending Effective Context                            18
    Prompting and Inference: The Art of Talking to LLMs                   20
    The Inference Process: From Probabilities to Text                     21
    The Rise of Decoder-Only Architectures                                24
    Mixture-of-Experts (MoE) Models                                       25
    Multimodality: Beyond Text                                            27
    The Problem of Hallucination                                          28
    Interpretability and Mechanistic Interpretability                     29
        The Evolving Story of Scaling Laws: From Training to "Thinking"   30
        The Foundational Laws of Training                                 31
        The New Frontier: Test-Time Compute Scaling                       31
        The Unified View: A New, More Complex Trade-Off                       32
Alternative Architectures                                                     32
    State-Space Models (Mamba and beyond)                                     32
    Diffusion Transformers (DiT)                                              34
The Path to Artificial General Intelligence (AGI)                             35
Hands-On                                                                      36
Key Takeaways                                                                 36
Conclusion                                                                    38
Glossary of Key Terms                                                         38
Part 2
Introduction                                                                      1
The Architectural Foundations of Modern AI Reasoning                              3
   Inference-Time Scaling: Deeper Thinking on Demand                              3
   Chain of Thought (CoT)                                                         3
       Few shots and CoT                                                          7
   Tree of Thoughts: From Linear Steps to Exploratory Search                      9
   The Alignment Engine                                                           11
       Reinforcement Learning from Human Feedback (RLHF)                          11
       Reinforcement Learning from AI Feedback (RLAIF)                        13
       Pure Reinforcement Learning (RL)                                       14
       Key RL Algorithms in Alignment (PPO and DPO)                           14
   Supervised Fine-Tuning + Reinforcement Learning (SFT + RL)                 15
   Pure SFT and Distillation                                                  18
   Thinking on Demand: The Rise of Test-Time Computation and Adaptive Inference 20
   The Scaling Dilemma: Sparse Mixture-of-Experts and Computational Efficiency 21
A Comparative Analysis of Leading Models                                      24
   The Titans of Industry (Proprietary Models)                                24
   The Open-Source Insurgency                                                 25
Benchmarking Reasoning Capabilities                                           28
   The Established Canon and Its Limitations                                  28
   The 2025 Leaderboard                                                       28
Future Directions and Implications                                            29
   Next-Generation Reasoning Paradigms               29
   Neuro-Symbolic AI: Bridging Logic and Intuition   30
   Dynamic Reasoning Frameworks                      31
   Adversarial Self-Critique                         32
Key Takeaways                                        33
Conclusion                                           34
References                                           35
Part 3
Introduction                                          1
   Code Breakdown                                     1
   Initialization and Hyperparameters                 1
   Model Definition                                   2
   Multi Headed Attention                             3
   Transformer Block                                  4
   Attention Language Model                           5
   Training logic                                     6
   Evaluate Logic                                     7
   Data and Generation Helpers                        8
Main execution                                       10
Key Takeaways                                        14
Conclusion                                           15
Part 4
Introduction                                          1
   Mixture of Experts (MoE)                           1
   Grouped-Query Attention (GQA)                      2
   Rotary Position Embeddings (RoPE)                  3
   RMSNorm (Root Mean Square Normalization)           4
   SwiGLU (Swish-Gated Linear Unit)                   4
   KV Caching (Key-Value Caching)                     5
Code Breakdown                                               6
   Model Architecture                                        6
   Tokenizer                                                16
   Weight Loading                                           17
   Text Generation                                          20
   Main Execution                                           21
Key Takeaways                                               23
Conclusions                                                 24
Part 5
Introduction: An Era of Refinement                           1
Key Architectural Innovations of the Modern Era              1
   The Rise of Sparse Models: Mixture-of-Experts (MoE)       2
   Evolving the Attention Mechanism: Beyond Multi-Head       2
      Grouped-Query Attention (GQA)                          2
      Multi-Head Latent Attention (MLA)                      3
      Sliding Window Attention                               4
   The Subtle Art of Normalization and Positional Signals    5
      Normalization: The Art of Stability                    5
      No Positional Embeddings (NoPE)                        6
      Sliding Attention Window                               7
A Tour of 2025's Flagship Architectures                      8
   DeepSeek-V3                                               8
   OLMo 2                                                    9
   Gemma 3                                                  10
   Llama 4: Mainstreaming the Mixture-of-Experts            11
   Qwen3: The Hallmark of Versatility                       13
   SmolLM3 and the Frontier of Positional Information       15
   GPT-OSS: the OpenAI Open Source Take                     15
   Kimi2: the new model from China                          18
Key Takeaways                                               20
Conclusion                                                  21
Part 6
Introduction                                                                         1
Gemma Model                                                                          1
   gemma/config.py                                                                   1
   gemma/gemma3_model.py                                                            2
   gemma/gemma3_preprocessor.py                                                    12
   gemma/model.py                                                                  12
   gemma/model_xla.py                                                              31
   gemma/tokenizer.py                                                              36
   gemma/xla_model_parallel.py                                                     36
Conclusion                                                                         38
Transformer Models
Introduction
The Transformer architecture has fundamentally reshaped the fields of machine
learning and natural language processing (NLP). Introduced in the influential 2017
paper "Attention is All You Need" by Vaswani et al., this model departed from the
traditional recurrent and convolutional layers used in sequence transduction tasks. Its
innovation lies in its exclusive reliance on the attention mechanism, which enabled
unprecedented levels of parallelization and performance. This chapter offers a
thorough exploration of the Transformer model, covering its foundational architectural
elements, advanced training and deployment techniques, and methods for interaction.
We will meticulously examine the core attention mechanism, trace its evolution from
pre-training to fine-tuning, and discuss strategies for managing its context and
guiding its behavior through prompting. This chapter is designed for readers with a
basic understanding of machine learning, providing an accessible yet scientific
journey into the intricate and captivating realm of Transformer models.
The Dawn of a New Architecture
Before the advent of the Transformer, the state-of-the-art in processing sequential
data, such as text or time series, was dominated by Recurrent Neural Networks
(RNNs) and their more sophisticated variants, Long Short-Term Memory (LSTM) and
Gated Recurrent Unit (GRU) networks. These models process data sequentially, token
by token, maintaining a hidden state that carries information from previous steps to
the current one. This sequential nature, while intuitive for modeling sequences,
presented a significant bottleneck. The computation for each step depends on the
completion of the previous one, making it difficult to parallelize the training process
over long sequences. Furthermore, RNNs struggled with capturing long-range
dependencies; information from early in a sequence could become diluted or lost by
the time the model reached later steps, a problem often referred to as the vanishing
gradient problem.
Convolutional Neural Networks (CNNs), traditionally used for image processing, were
also adapted for sequence data. By using filters that slide over local windows of the
sequence, CNNs could capture local dependencies and, by stacking layers, could
increase their receptive field to model longer-range interactions. While more
parallelizable than RNNs, they still faced limitations in modeling dependencies
between tokens that were very far apart without a prohibitively large number of
layers.
The Encoder Stack: The encoder's job is to process the entire input sequence and
build a rich, contextualized representation of it. It consists of a stack of identical
layers (the original paper used N=6 layers). Each token representation (embedding)
in the input sequence flows through these layers in parallel. Within each layer, the
tokens pass through two main sub-layers: a multi-head self-attention mechanism
and a position-wise fully connected feed-forward network.
The Decoder Stack: The decoder's role is to generate the output sequence, one
token at a time, using the representation created by the encoder. It also consists of a
stack of N=6 identical layers. The decoder's structure is similar to the encoder's but
with a crucial third sub-layer. In addition to the multi-head self-attention and
feed-forward sub-layers, the decoder incorporates a cross-attention (or
encoder-decoder attention) mechanism. This allows the decoder, at each step of
generation, to "look back" and focus on the most relevant parts of the encoded input
sequence.
In the context of the Transformer, Q, K, and V are not pre-defined but are learned.
They are created by projecting the input embedding vectors (or the outputs of the
previous layer) through three separate, learned linear layers (weight matrices WQ,
WK, and WV).
● Query (Q): Represents the current token's "search query" for relevant
   information.
● Key (K): Represents the "label" or "identifier" of other tokens in the sequence,
   which the query can be compared against.
● Value (V): Represents the actual content or meaning of the other tokens.
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
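To make the scaled dot-product formula concrete, here is a minimal sketch in PyTorch. The function name and tensor shapes are illustrative choices, not the paper's reference implementation, and masking is omitted.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d_k)
    d_k = Q.size(-1)
    # Similarity between every query and every key, scaled by sqrt(d_k)
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)   # (batch, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)               # attention weights sum to 1 per query
    return weights @ V                                # weighted sum of the value vectors

# Toy usage: batch of 1, sequence of 4 tokens, d_k = 8
Q = torch.randn(1, 4, 8); K = torch.randn(1, 4, 8); V = torch.randn(1, 4, 8)
out = scaled_dot_product_attention(Q, K, V)           # shape (1, 4, 8)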
The Intuition: A single attention head might learn to focus on a particular type of
relationship, for example, subject-verb agreement. By having multiple heads, the
model can learn to attend to different aspects of the sequence simultaneously. One
head might track syntactic dependencies, another might track semantic similarity,
and a third might track co-reference (which pronouns refer to which nouns).
The Process:
1. Projection: The original Q, K, and V matrices are each passed through h different
    linear layers (weight matrices W_i^Q, W_i^K, W_i^V), creating h sets of lower-dimensional
    Q, K, and V matrices. Each of these sets is called an "attention head."
2. Parallel Attention: The scaled dot-product attention is calculated for each head
    independently and in parallel. This results in h separate output matrices.
3. Concatenation and Final Projection: The h output matrices are concatenated
    back together. This concatenated matrix is then passed through one final linear
    projection layer (with weight matrix WO) to produce the final output of the
    multi-head attention layer. This final projection allows the model to combine the
    information learned by the different heads.
In the original Transformer, h=8 heads were used. The model dimension d_model was
512, and the dimension of each head d_k and d_v was d_model / h = 64. This setup
ensures that the total computational cost is similar to that of a single attention head
with the full d_model dimension, while providing the benefit of learning diverse
representations.
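The three steps above can be expressed compactly in code. The sketch below is a simplified illustration using the original paper's dimensions (d_model = 512, h = 8, d_k = 64); the class and variable names are our own, and masking, dropout, and other practical details are omitted.

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Minimal multi-head attention sketch (no masking or dropout)."""
    def __init__(self, d_model=512, h=8):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h        # 8 heads of dimension 64
        self.w_q = nn.Linear(d_model, d_model)    # packs the h projection matrices W_i^Q
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)    # final projection W^O

    def forward(self, q, k, v):
        B, n, _ = q.shape
        def split(x):                             # (B, n, d_model) -> (B, h, n, d_k)
            return x.view(B, n, self.h, self.d_k).transpose(1, 2)
        Q, K, V = split(self.w_q(q)), split(self.w_k(k)), split(self.w_v(v))
        scores = Q @ K.transpose(-2, -1) / (self.d_k ** 0.5)
        heads = scores.softmax(dim=-1) @ V        # scaled dot-product attention per head, in parallel
        out = heads.transpose(1, 2).reshape(B, n, self.h * self.d_k)  # concatenate the heads
        return self.w_o(out)                      # combine information from all heads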
The first matrix multiplication, QKᵀ, compares every query with every key: an (n x d)
matrix times a (d x n) matrix, which costs O(n²⋅d). The second matrix multiplication is
between the (n x n) attention weights and the (n x d) Value matrix, which is also
O(n²⋅d). The position-wise feed-forward network has a complexity of O(n⋅d²), as it
processes each of the n positions independently. Therefore, for a single layer, the total
complexity is O(n²⋅d + n⋅d²). In practice, for typical model sizes where d is fixed (e.g.,
512) and n can be large, the O(n²⋅d) term dominates.
Comparison with Recurrent and Convolutional Layers: Compared with recurrent and
convolutional layers, attention offers several key benefits, particularly for natural
language processing: any two positions in the sequence are connected by a
constant-length path, which makes long-range dependencies easier to learn; the
computation over all positions can be parallelized, unlike the sequential steps of an
RNN; and the attention weights themselves offer a degree of interpretability.
The lifecycle of a large language model can be broadly divided into three phases:
pre-training, fine-tuning, and alignment. The pre-training phase is the most
computationally intensive and is where the model learns its fundamental
understanding of language.
Pretraining
The Goal of Pre-training: The objective is to use a self-supervised task to train the
model on a massive corpus of text (e.g., the entire internet, all of Wikipedia, a large
collection of books). "Self-supervised" means that the labels or targets for the
learning task are generated automatically from the input data itself, without any need
for human annotation. This is crucial because unlabeled text is available in
near-limitless quantities.
Fine-Tuning
Once a base model has been pre-trained, it possesses a powerful, general
understanding of language but is not specialized for any particular task. The
fine-tuning phase adapts this general model to a specific downstream task using a
much smaller, task-specific labeled dataset.
The Process of Fine-Tuning: The process involves taking the pre-trained model and
continuing the training process (i.e., continuing to update its weights via
backpropagation) on a new dataset. For example, to create a sentiment analysis
model, one would take a pre-trained base model like BERT and fine-tune it on a
dataset of movie reviews labeled as "positive" or "negative."
1. Add a Task-Specific Head: A small, new neural network layer (or "head") is
    typically added on top of the base Transformer model. The architecture of this
    head depends on the task.
     ○ For a classification task (like sentiment analysis), this would be a simple linear
        layer followed by a softmax to output class probabilities.
     ○ For a token-level task (like Named Entity Recognition, where each token is
        classified), a linear layer would be applied to each output token's
        representation.
     ○ For a generative task (like summarization or translation), the base decoder
        model itself can be used, or a new decoder might be attached to a
        pre-trained encoder.
2. Continue Training: The entire model (the pre-trained base plus the new head) is
    then trained on the labeled dataset. The learning rate used during fine-tuning is
    typically much smaller than the one used during pre-training. This is because the
    model's weights are already in a very good state; we only want to "nudge" them
    slightly to adapt to the new task, without catastrophically forgetting the general
    knowledge learned during pre-training.
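Following the two steps above, here is a minimal fine-tuning sketch. The encoder below is a small, randomly initialized stand-in so the example is self-contained; in practice it would be a genuinely pre-trained model such as BERT, and the hyperparameters shown are illustrative.

import torch
import torch.nn as nn

class SentimentClassifier(nn.Module):
    """Pre-trained encoder plus a small task-specific head (step 1 above)."""
    def __init__(self, encoder, d_model=768, num_classes=2):
        super().__init__()
        self.encoder = encoder                    # any pre-trained Transformer encoder
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, token_embeddings):
        hidden = self.encoder(token_embeddings)   # (batch, seq_len, d_model)
        cls = hidden[:, 0]                        # use the first token's representation
        return self.head(cls)                     # logits over the classes

# Stand-in "pre-trained" encoder so the sketch runs on its own
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True), num_layers=2)
model = SentimentClassifier(encoder)

# Step 2 above: continue training the whole model with a small learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()
x, y = torch.randn(4, 16, 768), torch.tensor([0, 1, 1, 0])   # toy labeled batch
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()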
Why Fine-Tuning is So Effective: Fine-tuning leverages the power of transfer
learning. The model doesn't need to learn the structure of language from scratch
using the small, labeled dataset. It has already learned that from the massive,
unlabeled pre-training corpus. The fine-tuning process only needs to learn the
mapping from the model's rich internal representations to the specific output format
required by the task. This makes it possible to achieve state-of-the-art performance
on many NLP tasks with only a few thousand, or even just a few hundred, labeled
examples. This is a dramatic improvement over previous paradigms that required very
large labeled datasets for every new task.
Alignment
For large-scale generative models intended for direct human interaction (like
chatbots or assistants), a third phase has become critical: alignment. The goal of
alignment is to steer the model's behavior to be more helpful, harmless, and honest,
and to follow user instructions effectively. A base model trained on raw internet text
might be very good at predicting the next word, but it might also generate factually
incorrect, toxic, biased, or unhelpful content. The alignment process aims to mitigate
these undesirable behaviors.
RLHF has been instrumental in creating models that are not just capable but also
safe and useful as conversational agents. It's a powerful technique for instilling
complex, nuanced human values into a model's behavior, going beyond what can be
achieved with simple supervised fine-tuning. More recent techniques like Direct
Preference Optimization (DPO) aim to achieve similar results to RLHF but with a
simpler, more stable training process that does not require explicitly training a
separate reward model.
Modern Transformers
The Origin of the Context Window Limit: The context window is not an arbitrary
choice but a direct consequence of the self-attention mechanism's computational
complexity. As we discussed, the time and memory required for the self-attention
calculation scale quadratically with the sequence length n (i.e., O(n²)).
● Memory: The attention score matrix, which stores the similarity between every
   pair of tokens, has dimensions n x n. For a sequence of 4096 tokens, this matrix
   has over 16 million entries. For 32,000 tokens, it's over a billion entries. Storing
   this matrix in GPU memory becomes a major bottleneck.
● Computation: The number of floating-point operations also grows quadratically,
   making the processing of very long sequences prohibitively slow.
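A quick back-of-the-envelope calculation (assuming one attention matrix stored in 16-bit floats, and ignoring batch size, heads, and layers) shows how fast this memory term grows:

# Bytes needed to store one n x n attention matrix in fp16 (2 bytes per entry),
# ignoring batch size, the number of heads, and the number of layers.
for n in (4_096, 32_000, 128_000, 1_000_000):
    entries = n * n
    print(f"n = {n:>9,}: {entries:>21,} entries, ~{entries * 2 / 1e9:.2f} GB")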
Because of these hardware and time constraints, model designers must choose a
fixed maximum sequence length that the model is trained on and can handle during
inference. Early models like the original BERT had a context window of 512 tokens.
The GPT-3 family had a window of 2048 tokens. More recent models have pushed
this limit to 4K, 8K, 32K, 128K, 1M, 2M and even larger sizes, thanks to architectural
innovations and more powerful hardware.
The fixed context window has profound implications for how models are used.
● Document Analysis: To analyze a long document that exceeds the context
   window (e.g., a book or a lengthy legal contract), one must employ a "chunking"
   strategy. The document is broken down into smaller, overlapping chunks that fit
   within the window. The model processes each chunk independently, and the
   results are then aggregated. This is a workaround, but it prevents the model from
   forming a truly holistic understanding of the document, as it can never see the full
   context at once.
● Extended Conversations: In a long-running chatbot conversation, the beginning
   of the conversation will eventually "scroll off" the context window. The model will
   forget what was discussed earlier, leading to a loss of conversational coherence.
● Code Generation: When writing or analyzing a large codebase, the model can
    only see a small fraction of the relevant files at once, making it difficult to
    understand complex dependencies and project-wide architecture.
Overcoming the limitations of the fixed, quadratic-cost attention mechanism is one of
the most active and important areas of research in the field of deep learning.
Sliding Window Attention: This is one of the simplest and most effective
techniques, used in models like the Longformer. Instead of every token attending to
every other token, each token is restricted to attending to a fixed-size window of w
surrounding tokens (e.g., w/2 tokens to the left and w/2 to the right). This immediately
reduces the complexity from O(n²) to O(n⋅w), which is linear in n if w is a fixed
constant. This is very effective for tasks where local context is most important. To
handle more global information, this is often combined with a few "global" tokens
that are allowed to attend to the entire sequence.
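A minimal sketch of how such a local attention mask can be built; real implementations such as the Longformer also add global tokens and use specialized sparse kernels rather than materializing a dense mask.

import torch

def sliding_window_mask(n, w):
    """Boolean mask: position i may attend to position j only if |i - j| <= w // 2."""
    idx = torch.arange(n)
    return (idx[None, :] - idx[:, None]).abs() <= w // 2   # (n, n), True = allowed

mask = sliding_window_mask(n=8, w=4)
# Each token attends to at most w + 1 positions, so the per-layer cost becomes
# O(n * w) instead of O(n^2). Disallowed positions are set to -inf before the softmax.
scores = torch.randn(8, 8)
scores = scores.masked_fill(~mask, float("-inf"))
weights = scores.softmax(dim=-1)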
The discovery that the behavior of large, pre-trained models could be controlled
purely through the careful design of the input prompt, without any changes to the
model's weights, was a major breakthrough. This paradigm is often called in-context
learning.
Prompting Strategies
The way a prompt is structured can have a dramatic impact on the quality and nature
of the model's output. Several key strategies have emerged:
1. Zero-Shot Prompting: This is the simplest form of prompting. The model is given a
    description of a task and asked to perform it without any examples.
     ○ Prompt: "Translate the following English text to French: 'Hello, how are you?'"
     ○ The model is expected to understand the instruction "Translate...to French"
        and perform the task directly. The ability of large models to perform tasks in a
        zero-shot setting is a remarkable emergent property.
2. Few-Shot Prompting: In this strategy, the prompt includes a small number of
    examples (typically 1 to 5) of the task being performed. These examples "prime"
    the model, showing it the exact format and style of the desired output.
     ○ Prompt:
        "English: sea otter -> French: loutre de mer
        English: peppermint -> French: menthe poivrée
        English: cheese -> French:"
     ○ By seeing these examples, the model can better infer the pattern and is more
        likely to produce the correct output ("fromage"). Few-shot prompting is often
        significantly more effective than zero-shot prompting, especially for more
        complex or nuanced tasks.
3. Chain-of-Thought (CoT) Prompting: This advanced technique aims to draw out
   reasoning from the model, a topic that will be explored in depth in subsequent
   chapters. Instead of just showing the final answer in the few-shot examples, the
   prompt also includes the intermediate steps of reasoning used to arrive at the
   answer.
    ○ Prompt:
       "Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6
       more, how many apples do they have?
       A: The cafeteria started with 23 apples. They used 20, so they had 23 - 20 = 3.
       They bought 6 more, so they now have 3 + 6 = 9. The answer is 9.
       Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has
       3 tennis balls. How many tennis balls does he have now?
       A:"
    ○ By prompting the model to "think step-by-step," it is encouraged to break
       down the problem, perform the intermediate calculations, and then state the
       final answer. This has been shown to dramatically improve performance on
       tasks that require arithmetic, commonsense, and symbolic reasoning. It forces
       the model to externalize its reasoning process into the text, which often leads
       to more accurate results.
The art of designing effective prompts is known as prompt engineering, and it has
become a critical skill for effectively using and building applications on top of large
language models.
Let's say the model has processed the input "The best thing about AI is its ability to"
and must now choose the next word. The final layer of the decoder outputs a vector
of logits, which are then converted by a softmax function into probabilities for every
word in the vocabulary.
● learn: 30% probability
● create: 25% probability
● automate: 15% probability
● reason: 10% probability
● ... and so on for thousands of other words.
Temperature Sampling: This is one of the most common techniques for controlling
the randomness of the output. A temperature parameter (T, typically between 0 and
1) is used to rescale the logits before the softmax function is applied.
Top-k Sampling: One problem with pure temperature sampling is that it can
sometimes select very unlikely, nonsensical words, especially with a high
temperature. Top-k sampling addresses this by restricting the sampling pool to a
smaller, more plausible set of words.
Nucleus (Top-p) Sampling: Instead of fixing the number of candidates, nucleus
sampling restricts the pool to the smallest set of words whose cumulative probability
exceeds a threshold p (the "nucleus"). The key advantage of nucleus sampling is that
the size of the sampling pool is adaptive.
 ● When the model is very certain about the next word (i.e., one word has a very
    high probability), the nucleus will be very small, perhaps containing only one or
    two words. This makes the model behave more deterministically when it's
    confident.
 ● When the model is uncertain and the probability is spread out over many
    plausible words, the nucleus will be larger, allowing for more diversity.
This adaptive nature often leads to more robust and higher-quality text generation
than fixed top-k sampling. In practice, many applications use a combination of these
techniques, such as nucleus sampling with a temperature parameter, to finely control
the trade-off between coherence and creativity.
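The sketch below shows how temperature, top-k, and nucleus (top-p) sampling can be combined when choosing the next token. The function name and parameter values are illustrative, and production implementations differ in small details such as exactly where the nucleus is cut off.

import torch

def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.9):
    """Minimal sketch combining temperature, top-k and nucleus (top-p) sampling."""
    logits = logits / max(temperature, 1e-5)          # temperature rescaling
    probs = torch.softmax(logits, dim=-1)

    # Top-k: keep only the k most probable tokens
    topk_probs, topk_idx = probs.topk(top_k)

    # Nucleus: within those, keep the smallest prefix whose cumulative probability fits in p
    cumulative = topk_probs.cumsum(dim=-1)
    keep = cumulative <= top_p
    keep[..., 0] = True                               # always keep the single most likely token
    filtered = topk_probs * keep
    filtered = filtered / filtered.sum()              # renormalise the surviving probabilities

    choice = torch.multinomial(filtered, num_samples=1)
    return topk_idx[choice]

vocab_size = 32_000
next_id = sample_next_token(torch.randn(vocab_size))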
This design choice is rooted in the profound realization that a highly capable
auto-regressive decoder possesses sufficient power to tackle an incredibly diverse
array of language tasks. By eliminating the encoder component and its associated
cross-attention mechanism, the overall architecture achieves a greater degree of
homogeneity. This simplification not only streamlines the network's design but also
makes it considerably easier to scale these models to truly massive sizes, involving
billions or even trillions of parameters.
Furthermore, the training process for decoder-only models becomes elegantly unified
under a single, straightforward objective: Causal Language Modeling (CLM). This
objective, often simplified to "next-token prediction," involves training the model to
predict the subsequent token in a sequence given all preceding tokens. This approach
has consistently proven to be an exceptionally effective and highly scalable paradigm
for the creation of general-purpose language models.
The inherent ability of these decoder-only models to treat virtually any language task
as a sequence-to-sequence problem is a cornerstone of their success. In this
framework, both the input prompt and the desired completion are considered integral
parts of the same continuous sequence. This unified sequence representation directly
enables the powerful in-context learning capabilities that are a defining characteristic
of modern foundation models. Through in-context learning, these models can infer
task requirements and generate appropriate responses based solely on the provided
prompt, without requiring explicit fine-tuning for each new task. This paradigm shift
has unlocked unprecedented versatility and performance in natural language
understanding and generation.
This sparse activation offers a significant advantage: it enables the creation of models
with an immense total number of parameters, far exceeding what would be
computationally feasible with a dense architecture. Yet, during inference, the
computational cost remains remarkably low, comparable to that of a much smaller,
dense model. This decoupling of model capacity from active computation is a core
strength of MoE.
While MoE facilitates greater model capacity and promotes specialization within the
network by allowing different experts to learn distinct patterns or handle specific
types of data, it also introduces certain challenges. One primary concern is ensuring
that the computational load is balanced across the various experts. An imbalanced
load can lead to some experts being underutilized while others are overloaded,
potentially hindering overall efficiency and performance. Addressing this load
balancing issue is an active area of research and development within the MoE field.
During autoregressive generation, KV caching stores the Key and Value vectors of all
previously processed tokens so they do not have to be recomputed at every decoding
step. However, a key limitation of KV caching is that the size of this cache grows
linearly with the sequence length. For very long contexts, this can lead to an enormous
memory footprint, often exceeding the available on-chip memory of GPUs. This
memory constraint becomes a bottleneck, especially for applications requiring
extended conversational history or processing lengthy documents.
To mitigate this memory overhead, advanced attention variants have been developed.
One such variant is Multi-Query Attention (MQA), employed in models like Google's
PaLM. In MQA, all query heads within the attention mechanism share a single K and V
projection. This design drastically reduces the number of K and V vectors that need to
be cached, thereby shrinking the cache's memory footprint. While highly efficient in
terms of memory, MQA can sometimes lead to a slight reduction in model quality
compared to traditional multi-head attention, as the shared projections might limit the
model's ability to attend to different parts of the input with distinct representations.
A popular and effective compromise between memory efficiency and model quality is
Grouped-Query Attention (GQA). Models such as Meta's Llama 2 utilize GQA. In
this approach, query heads are divided into groups, and each group shares a common
K and V projection. This method offers a significant reduction in the cache's memory
footprint compared to the standard multi-head attention, albeit less aggressively than
MQA. Crucially, GQA generally retains higher quality than MQA, striking a better
balance between computational efficiency and model performance. This makes GQA
particularly advantageous for enabling practical and high-quality long-context
inference, allowing LLMs to process and generate much longer sequences of text
effectively. The continuous innovation in attention mechanisms and caching strategies
remains crucial for pushing the boundaries of what large language models can
achieve in terms of speed, scale, and performance.
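The sketch below illustrates the idea behind grouped-query attention: only the few shared K and V heads need to be computed and cached, and they are broadcast to the many query heads at attention time. The shapes and group counts are illustrative assumptions, not any particular model's configuration.

import torch

def grouped_query_attention(Q, K, V, n_groups):
    """Sketch of GQA: Q has more heads than K/V; each group of query heads shares
    one K/V head (n_groups == 1 recovers MQA, n_groups == number of query heads is MHA)."""
    B, n_q_heads, n, d = Q.shape
    _, n_kv_heads, _, _ = K.shape
    assert n_kv_heads == n_groups and n_q_heads % n_groups == 0
    repeat = n_q_heads // n_groups
    K = K.repeat_interleave(repeat, dim=1)            # broadcast shared K to every query head
    V = V.repeat_interleave(repeat, dim=1)
    scores = Q @ K.transpose(-2, -1) / (d ** 0.5)
    return scores.softmax(dim=-1) @ V

# 8 query heads sharing 2 K/V heads: the KV cache only stores 2 heads per token instead of 8
Q = torch.randn(1, 8, 16, 64)
K = torch.randn(1, 2, 16, 64)
V = torch.randn(1, 2, 16, 64)
out = grouped_query_attention(Q, K, V, n_groups=2)    # shape (1, 8, 16, 64)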
Leading the charge in this multimodal revolution are flagship models that demonstrate
an unparalleled ability to perceive and interpret the world more holistically. OpenAI's
GPT-4V(ision) showcases remarkable capabilities in analyzing visual inputs, ranging
from photographs to intricate diagrams. Similarly, Google's Gemini family of models
stands out as a testament to native multimodality, having been conceived and
developed from its inception with the inherent capacity to process and understand
information across various modalities simultaneously. This foundational design allows
Gemini models to intricately analyze user-provided images, elaborate diagrams, and
dynamic videos, thereby fostering a more comprehensive and human-like
understanding of context, content, and the intricate relationships within our complex
world. The convergence of these capabilities is propelling AI towards a future where
systems can interact with and comprehend the world with a richness and depth
previously unattainable.
These diverse strategies collectively aim to enhance the reliability and trustworthiness
of LLMs, moving them beyond mere statistical language generation towards more
factually grounded and ethically sound AI systems. The ultimate goal is to enable
LLMs to provide information that is not only fluent and coherent but also
demonstrably accurate and safe.
Leading research institutions and AI labs, notably Anthropic, have made significant
strides in this domain. Their work has involved systematically dissecting smaller
models to pinpoint and characterize these intricate neural circuits. These
breakthroughs represent critical steps towards demystifying the "black box" nature of
complex AI systems.
The first major breakthrough in this area came from researchers at OpenAI (Kaplan et
al., 2020), who discovered that a model's performance, measured by its predictive
error (loss), improves in a predictable power-law relationship as three factors are
scaled up:
● Model size: the number of parameters.
● Dataset size: the number of training tokens.
● Compute: the total amount of computation used for training.
A crucial refinement to these laws came from DeepMind's 2022 Chinchilla paper.
Researchers there addressed a key question: for a fixed compute budget, what is the
optimal balance between model size and data size? Their findings recalibrated the
entire field. They demonstrated that for optimal performance, model size and dataset
size must be scaled in roughly equal proportion.
This revealed that many earlier models were "undertrained"—they were too large for
the amount of data they had been fed. The Chinchilla paper proved that a smaller
model (70B parameters) trained on a massive dataset could outperform a much larger
model (280B parameters) trained on less data. This established that data, not just
model size, was a critical bottleneck.
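A rough illustration of the Chinchilla recipe, using the commonly cited approximation of about 20 training tokens per parameter; the paper itself derives the optimum from fitted scaling curves rather than a fixed ratio, so treat the numbers as ballpark figures.

# Rough Chinchilla-style rule of thumb: train on ~20 tokens per parameter,
# an approximation of "scale model size and data in equal proportion".
def compute_optimal_tokens(n_params, tokens_per_param=20):
    return n_params * tokens_per_param

for n_params in (7e9, 70e9, 280e9):
    print(f"{n_params / 1e9:>5.0f}B params -> ~{compute_optimal_tokens(n_params) / 1e12:.2f}T training tokens")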
While the training laws were transformative, they treated a model's performance as
static once training was complete. The newer test compute scaling law challenges
this assumption. It describes how a model's performance can be dramatically
improved by allocating more computational power at the moment of
inference—that is, when it's actually generating an answer to a prompt.
   ● Low Test Compute: A quick, gut-reaction answer. It's fast but more prone to
      error.
   ● High Test Compute: A slow, deliberate process of reasoning, checking work,
      and exploring options before giving a final answer. It takes more effort but is far
      more reliable.
LLMs can leverage several strategies to "think longer" and scale their test-time
compute:
   ● Best-of-N Sampling: The model generates multiple (N) different answers, and
      a verifier (or a majority vote) selects the best one (see the sketch after this
      list).
   ● Chain-of-Thought (CoT) & Self-Refinement: The model first generates a
      step-by-step reasoning process. It can then be prompted to critique its own
      logic and iteratively correct its mistakes.
   ● Tree Search: A more advanced method where the model explores multiple
      divergent lines of reasoning simultaneously, dedicating more compute to the
      most promising paths.
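A minimal sketch of Best-of-N with majority voting. The generate function is a hypothetical stand-in for any sampling-based model call, and a production system might use a learned verifier instead of a simple vote.

from collections import Counter
import random

def best_of_n(generate, prompt, n=8):
    """Sketch of Best-of-N: sample n answers, return the most common one."""
    answers = [generate(prompt) for _ in range(n)]    # n independent samples
    answer, votes = Counter(answers).most_common(1)[0]
    return answer, votes / n                          # chosen answer and its agreement rate

# Toy stand-in generator that is right 60% of the time
def noisy_solver(prompt):
    return "12" if random.random() < 0.6 else random.choice(["10", "17", "28"])

print(best_of_n(noisy_solver, "15 + 5 - 8 = ?", n=25))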
The Unified View: A New, More Complex Trade-Off
The emergence of test compute scaling adds a new dimension to the LLM
optimization puzzle. The guiding principle is no longer just about balancing model size
and training data. It's now understood that there is a crucial trade-off between a
model's size and its inference-time "thinking" budget.
The key takeaway is that a smaller, more efficient model given more compute at test
time can often outperform a much larger model that provides a quick, single-shot
answer. This has profound implications, suggesting that the path to more capable AI
lies not just in building ever-larger models, but also in developing more sophisticated
reasoning techniques that allow models of all sizes to use their knowledge more
effectively.
Alternative Architectures
The pursuit of more efficient and scalable architectures in deep learning has led to a
re-examination of sequential processing mechanisms, particularly in light of the
quadratic complexity bottleneck inherent in the self-attention mechanism of
Transformers. A highly promising alternative that has emerged are State-Space
Models (SSMs), with Mamba standing out as a notable advancement. These models
draw inspiration from classical control systems and offer a compelling solution to the
challenges posed by extremely long sequences.
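At their core, these models compute a linear recurrence over the sequence, sketched below. Mamba's key addition, making the state-transition parameters depend on the input (selectivity), is omitted here for brevity, and all dimensions are illustrative.

import numpy as np

def ssm_scan(x, A, B, C):
    """Sketch of a discrete linear state-space model:
    h_t = A @ h_{t-1} + B @ x_t ;  y_t = C @ h_t.
    Cost grows linearly with sequence length; only the fixed-size state h is carried forward."""
    h = np.zeros(A.shape[0])
    ys = []
    for x_t in x:                      # one step per token: O(n) overall
        h = A @ h + B @ x_t
        ys.append(C @ h)
    return np.stack(ys)

d_in, d_state, n = 4, 16, 1000
rng = np.random.default_rng(0)
A = np.eye(d_state) * 0.9              # stable, decaying state transition
B = rng.normal(size=(d_state, d_in)) * 0.1
C = rng.normal(size=(1, d_state))
y = ssm_scan(rng.normal(size=(n, d_in)), A, B, C)   # shape (1000, 1)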
The Best of Both Worlds: This unique combination of features positions SSMs like
Mamba as highly attractive alternatives to Transformers. They offer computation and
memory that scale linearly, rather than quadratically, with sequence length; constant
per-step cost during autoregressive generation, because only a fixed-size state is
carried forward; and training that can still be parallelized across the sequence.
Applications and Future Directions: The inherent advantages of SSMs make them
exceptionally well-suited for tasks involving extremely long sequences. Prime
examples include:
   ● Genomics: Analyzing vast DNA and RNA sequences, where the relationships
      between distant elements can be crucial.
   ● High-resolution Time-Series Analysis: Processing extensive streams of
      sensor data, financial fluctuations, or environmental readings, where capturing
      long-range dependencies is vital.
   ● Audio and Video Processing: Handling long temporal dependencies in raw
      audio waveforms or video frames.
The development of Mamba and other advanced SSMs signifies a significant step
forward in building more efficient and scalable deep learning models. They offer a
promising path to overcome the computational limitations of current state-of-the-art
architectures, opening up new possibilities for tackling complex problems in domains
characterized by vast and intricate sequential data. Further research will likely explore
more sophisticated content-aware mechanisms, hybrid architectures combining SSMs
with other modules, and broader applications across various sequence modeling
tasks.
One school of thought strongly advocates that the continued scaling of these models,
in terms of parameters, data, and computational resources, will eventually bridge the
gap to AGI. Proponents of this view often point to emergent behaviors observed in
larger models that were not present in their smaller counterparts. Conversely, another
significant camp firmly contends that Transformers, by their very nature, are
fundamentally limited as passive pattern-matching systems. They argue that these
models excel at statistical correlations but lack the underlying mechanisms necessary
for true understanding, reasoning, and agency.
A highly probable and increasingly favored path forward involves the creation of
sophisticated hybrid systems. These innovative architectures would strategically
combine the potent perceptual and linguistic abilities inherent in Transformers – their
capacity for processing vast amounts of text and generating coherent language –
with other distinct architectures. These complementary architectures would be
specifically dedicated to functionalities such as world modeling (creating internal
representations of the environment), symbolic reasoning (manipulating abstract
symbols and rules), or agency (the ability to act autonomously and intentionally within
an environment). Such a synergistic approach aims to harness the strengths of
different computational paradigms, potentially overcoming the individual limitations of
each and accelerating progress towards the elusive goal of Artificial General
Intelligence.
Hands-On
A full working Transformer, implemented in PyTorch (or JAX), is walked through step
by step in Part 3; the sketch below previews the core building block that the full
implementation expands on.
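A minimal, self-contained sketch of a single pre-norm decoder block in PyTorch. The hyperparameters are illustrative; the full model in Part 3 adds token embeddings, a stack of such blocks, and a language-modeling head.

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm decoder block: causal self-attention + feed-forward, each with a residual."""
    def __init__(self, d_model=128, n_heads=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                nn.Linear(4 * d_model, d_model))

    def forward(self, x):
        n = x.size(1)
        causal = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)  # hide future tokens
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal)
        x = x + attn_out                       # residual connection around attention
        return x + self.ff(self.ln2(x))        # residual connection around the feed-forward

x = torch.randn(2, 16, 128)                   # (batch, seq_len, d_model)
print(TransformerBlock()(x).shape)            # torch.Size([2, 16, 128])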
Key Takeaways
This section summarizes the most critical concepts covered in the document,
providing a high-level review of the Transformer architecture and its ecosystem.
   ● The Context Window is a Limit: The context window is the fixed maximum
      number of tokens a model can process at once.
   ● Extending Context is an Active Research Area: Techniques like sliding
      window attention, recurrence (Transformer-XL), and Retrieval-Augmented
      Generation (RAG) are used to mitigate the limitations of the fixed context
      window.
   ● Prompting is Programming: The behavior of LLMs is controlled through
      carefully crafted input prompts, with Chain-of-Thought (CoT) prompting being
      a key technique for improving reasoning.
   ● Decoding Shapes the Output: The algorithm used for inference (e.g., greedy
      search, beam search, nucleus sampling) determines the characteristics of the
      generated text, managing the trade-off between coherence and creativity.
Conclusion
The Transformer architecture, detailed in "Attention is All You Need," marked a
significant shift in machine learning. Its use of parallel processing through
self-attention, replacing the sequential nature of recurrent networks, has
revolutionized our ability to model language and other sequential data. This powerful
and scalable architecture, combined with the paradigms of self-supervised
pre-training and transfer learning, has ushered in the era of Large Language Models
(LLMs), which exhibit remarkable and often emergent capabilities. We have explored
the model's fundamental components, its lifecycle from pre-training to alignment, and
the critical limitations imposed by its quadratic attention mechanism.
The journey, however, is far from over. The research landscape remains intensely
dynamic, pushing beyond the original design to address its core weaknesses and
unlock new frontiers. We are witnessing a Cambrian explosion of architectural
diversity, from computationally efficient Mixture-of-Experts and linear-time
State-Space Models to the application of Transformers in entirely new domains like
diffusion-based image generation. The grand challenges of hallucination and
interpretability are now central research efforts, as we strive to build models that are
not only capable but also truthful and understandable. This progress is inextricably
linked to co-evolving hardware and sophisticated inference optimizations like KV
caching, which make these massive models practical.
Interacting with these models is a newly emerging skill, where prompt engineering and
the selection of decoding strategies define the boundary of what is possible. The
progression from a simple zero-shot prompt to complex chain-of-thought reasoning
underscores that we are not merely using a tool but learning to collaborate with a new
form of intelligence. Whether the Transformer proves to be the ultimate architecture
on the path to AGI or a crucial stepping stone to future discoveries, its invention will
be remembered as a pivotal moment—the point when we truly understood that, for
building intelligence at scale, attention is a magnificent start. The upcoming chapter
will focus on the new frontier of Large Models, Reasoning, and the techniques
enabling these advancements.
However, these powerful models are not a universal solution and should be applied
selectively. Their primary strength lies in solving complex problems that benefit from
step-by-step analysis. Conversely, they are computationally more expensive and
slower to run, tend to be verbose in their outputs, and can be inefficient or even prone
to errors from "overthinking" simple problems for which standard LLMs are better
suited.
A second, more novel approach is Pure Reinforcement Learning (RL), where a base
model learns reasoning as an emergent behavior, providing valuable research insights.
Finally, the Pure SFT and Distillation method offers a practical path to creating
smaller, more efficient models. This technique involves training a smaller model on a
high-quality dataset of reasoning examples generated by a superior "teacher" model,
effectively imparting strong reasoning skills to more accessible hardware.
Architecturally, Sparse Mixture-of-Experts (MoE) has been crucial for scaling these
models to trillions of parameters while managing computational demands.
The rapid advancements in large reasoning models (LRMs) are poised to revolutionize
high-level knowledge work across sectors such as finance, law, and science, ushering
in an autonomous agent economy. This shift promises substantial productivity gains,
with some economists predicting an addition of up to $15.7 trillion to global GDP by
2030 due to AI.
The future of AI lies in harnessing the immense potential of these systems while
diligently ensuring their trustworthy and beneficial operation. This approach will foster
societal confidence and promote positive human-AI collaboration. The advanced
analytical capabilities of LRMs are already driving significant change in
knowledge-based industries, serving as a precursor to the broader economic shift
brought by the autonomous agent economy. LRMs provide the "brain" for AI agents
that can function as independent economic actors, performing tasks and generating
value with minimal human oversight.
Think of it like human problem-solving. If asked a simple question, you might give an
immediate, "off-the-cuff" answer. This is computationally "cheap." However, if asked a
complex multi-part question, you would pause, think through the steps, perhaps even
visualize different scenarios, before formulating a response. This deliberate, slower
process is your brain "scaling" its cognitive effort in real-time. Inference-time scaling
does the same for an LLM.
Let's illustrate how scaling the computational effort at inference time can drastically
improve the outcome for a simple arithmetic problem.
1. Baseline (Minimal Inference Cost): This approach uses the least amount of
computation. The model is prompted for a direct answer.
Prompt: Q: A restaurant has 15 tables. They buy 5 more, and then put 8 tables into
storage. How many tables are in use?
Typical (Incorrect) Answer: 20 tables.
   ● Analysis: The model performs a simple, immediate calculation (15 + 5) but fails
      to complete the second step. This is a low-effort, low-accuracy response,
      generated very quickly.
2. Step-by-Step Reasoning (Moderate Inference Cost): The model is prompted to show
its reasoning before answering.
Prompt: Q: A restaurant has 15 tables. They buy 5 more, and then put 8 tables into
storage. How many tables are in use? Let's think step-by-step.
3. Multiple Reasoning Paths (Higher Inference Cost): The model is asked for several
independent solutions, whose answers can then be compared or majority-voted.
Prompt: Q: A restaurant has 15 tables. They buy 5 more, and then put 8 tables into
storage. How many tables are in use? Please provide three distinct step-by-step
solutions.
   ● Path 1: Start with 15. Add 5, which gives 20. Then subtract 8, which results in
      12. -> Final Answer: 12
   ● Path 2: First, calculate the new total tables: 15 + 5 = 20. From these 20 tables, 8
      are removed. So, 20 - 8 = 12. -> Final Answer: 12
   ● Path 3 (A different but still valid path): The change in tables is +5 and then
      -8, for a net change of 5 - 8 = -3 tables. The starting amount was 15, so 15 - 3 =
      12. -> Final Answer: 12
Prompt: A cafeteria had 23 apples. If they used 20 for lunch and bought 6 more, how
many apples do they have now?
Typical (Incorrect) Answer:
29 apples. (The model incorrectly adds 23 + 6, ignoring the subtraction.)
Chain-of-Thought Prompt: A cafeteria had 23 apples. If they used 20 for lunch and
bought 6 more, how many apples do they have now?
A: Let's think step-by-step.
CoT (Correct) Answer:
Step 1: The cafeteria starts with 23 apples.
Step 2: They use 20 apples, so we subtract 20. 23 - 20 = 3 apples.
Step 3: They buy 6 more apples, so we add 6 to the current amount. 3 + 6 = 9 apples.
Final Answer: The cafeteria now has 9 apples.
CoT works by forcing the model to process each part of the problem sequentially,
preventing it from making simple errors by jumping to conclusions.
Prompt: What will likely happen if I leave a chocolate bar in a car on a summer day in
London?
Typical (Incomplete) Answer:
It will get warm.
Chain-of-Thought Prompt: What will likely happen if I leave a chocolate bar in a car
on a summer day in London?
A: Let's reason through the steps.
Prompt: Take the last letters of each word in 'Finding Nemo Forever' and reverse
them. What is the new word?
Typical (Incorrect) Answer:
oreveroMgnidniF (The model incorrectly reverses the entire phrase.)
Chain-of-Thought Prompt: Take the last letters of each word in 'Finding Nemo
Forever' and reverse them. What is the new word?
A: Let's break this down into steps.
   ● Clearer Task Demonstration: It explicitly shows the model the expected format,
      structure, and level of detail for the reasoning process, eliminating ambiguity
      about what "step-by-step" entails.
   ● Improved Reliability: By providing a successful example, the model is
      significantly more likely to follow the correct logical path, reducing the chances
      of errors or getting stuck. It provides a strong "nudge" in the right direction.
   ● Better Handling of Complexity: For problems involving nuanced rules or multiple
      constraints, an example is invaluable. It demonstrates how to apply rules
      correctly, which can be challenging for a model to infer solely from instructions.
Let's use a classic logic puzzle where the reasoning path is specific.
The Puzzle: You have three boxes. One is labeled "Apples", one "Oranges", and one
"Apples & Oranges". You know that every single box is mislabeled. You are allowed to
pick just one fruit from one box (without looking inside) to figure out the correct labels
for all three boxes. Which box do you pick from?
Prompt: You have three boxes: "Apples", "Oranges", and "Apples & Oranges". Every
box is mislabeled. You can pick one fruit from one box to determine the correct labels
for all. Which box should you pick from? Let's think step-by-step.
To solve this, you need to pick from the box that is guaranteed to give you useful
information. Picking from the box labeled "Apples" is ambiguous: if you draw an
orange, that box could be either the "Oranges" box or the "Apples & Oranges" box.
The best choice is the box labeled "Apples & Oranges". Because every box is
mislabeled, it must contain only apples or only oranges, so if you draw an apple from
it, you know its true label is "Apples". Then you can figure out the rest: the box labeled
"Oranges" cannot be "Oranges" and cannot be "Apples" (already taken), so it must be
"Apples & Oranges", and the box labeled "Apples" must be "Oranges".
Prompt: I have three cups labeled "Tea", "Coffee", and "Juice". They are all
mislabeled. Cup 1 contains coffee, Cup 2 contains juice, and Cup 3 contains tea. If I
know that the "Tea" label is not on the cup containing coffee, what is in the cup
labeled "Coffee"?
Think of solving a maze. CoT is like walking through the maze by committing to a
single path. If you hit a dead end, you have to backtrack a long way or start over
completely. In contrast, ToT is like sending scouts down several paths at once from
every crossroads. You can quickly see which paths are dead ends and focus your
energy on the ones that look most promising.
The ToT framework allows a model to create a "tree" of possibilities. At each step of a
problem, it performs three key actions: it generates several candidate "thoughts"
(possible next steps), it evaluates how promising each candidate is, and it searches the
tree by expanding the best candidates and backtracking from dead ends.
This method is much closer to how humans think—we consider various options, weigh
their pros and cons, and change our strategy when one approach isn't working.
Example: The Traveling Salesperson Problem
Consider a classic logistics problem: "I need to leave my house, visit the Post Office,
the Bank, and the Grocery Store, and then return home. Find the shortest route."
   ● A CoT model might pick one order randomly: Home -> Post Office -> Bank ->
      Grocery -> Home. It would calculate the total time for that single route and
      present it as the answer, without knowing if it's the optimal one.
   ● A ToT model would explore the different possible routes systematically:
         ○ Root: The model knows it has three places to visit: Post Office, Bank,
            Grocery.
         ○ Level 1 (Generate first stops): It creates three main branches from
            Home:
                 ■ Go to the Post Office first.
                 ■ Go to the Bank first.
                 ■ Go to the Grocery Store first.
         ○ Level 2 (Explore next stops): For each of the first-stop branches, it
            generates sub-branches for the second stop.
                 ■ From the "Post Office" branch, it creates two new sub-branches:
                    visiting the Bank next, or visiting the Grocery Store next.
                 ■ It does the same for the other initial branches.
         ○ Level 3 (Complete and evaluate): The model completes every possible
            full route (e.g., Home -> Bank -> Post Office -> Grocery -> Home) and
            calculates the total travel time for each.
         ○ Search & Final Answer: By comparing the total time of all completed
            paths in its "tree," the model can confidently identify and present the
            path with the absolute shortest travel time.
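The route comparison can be made concrete with a small brute-force sketch. The travel times below are made-up numbers for illustration only; in a real ToT system the model itself would propose and score partial routes rather than exhaustively enumerating them.

from itertools import permutations

# Illustrative (assumed) travel times in minutes between locations
times = {("Home", "Post Office"): 10, ("Home", "Bank"): 15, ("Home", "Grocery"): 20,
         ("Post Office", "Bank"): 5, ("Post Office", "Grocery"): 12, ("Bank", "Grocery"): 8}
def t(a, b):
    return times.get((a, b)) or times[(b, a)]

stops = ["Post Office", "Bank", "Grocery"]
routes = []
for order in permutations(stops):                     # the "tree": every ordering of the stops
    path = ["Home", *order, "Home"]
    total = sum(t(a, b) for a, b in zip(path, path[1:]))
    routes.append((total, " -> ".join(path)))

best = min(routes)                                    # evaluate every completed path, keep the shortest
print(*sorted(routes), sep="\n")
print("Best route:", best[1], f"({best[0]} min)")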
Think of RLHF like training a smart but inexperienced apprentice. You don't just give
them a textbook (the initial training data); you watch them work and give them
feedback, rewarding good decisions and correcting poor ones until their judgment
aligns with yours.
Let's walk through an example using the prompt: "What's a good way to start investing
with very little money?"
Step 1: Collect Human Preference Data First, the base language model generates
several different answers to the prompt.
   ● Response A: "A great way to start is with a low-cost index fund through a
      micro-investing app. These let you invest small amounts, even just a few
      pounds. It's also wise to read about the basics of diversification. Always
      remember that all investing involves risk."
   ● Response B: "You should put all your money into 'CRYPTO-COIN X'. It's going to
      the moon and you'll get rich quick. Don't miss out on this once-in-a-lifetime
      opportunity."
   ● Response C: "Investing is too complicated if you don't have a lot of money. It's
      probably not worth your time."
A human annotator reviews these responses and ranks them based on helpfulness
and safety. Their preference would clearly be A > C > B. They would create thousands
of these comparisons for many different prompts.
Step 2: Train the Reward Model This human preference data (e.g., for this prompt, 'A is
better than B') is used to train a separate "reward model." This model's only job is to
learn what humans prefer. It reads a prompt and a response, and then outputs a score
that predicts how a human would rate it. It learns that responses mentioning
diversification, risk warnings, and specific, actionable advice get high scores.
Responses that are dismissive, give risky financial advice, or create a sense of FOMO
(Fear Of Missing Out) get very low scores. Essentially, it becomes an automated judge
that has learned human values.
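Step 2 boils down to a simple pairwise objective: the reward model should score the preferred response higher than the rejected one. The sketch below shows that loss with a toy stand-in reward model; a real reward model is a full Transformer that scores (prompt, response) pairs.

import torch
import torch.nn.functional as F

def preference_loss(reward_model, prompt, chosen, rejected):
    """Sketch of the pairwise (Bradley-Terry style) loss used to train a reward model:
    push the score of the human-preferred response above the rejected one."""
    r_chosen = reward_model(prompt, chosen)        # scalar score, e.g. for Response A
    r_rejected = reward_model(prompt, rejected)    # scalar score, e.g. for Response B
    return -F.logsigmoid(r_chosen - r_rejected)    # small when chosen scores well above rejected

# Toy stand-in reward model so the sketch runs on its own
class ToyRewardModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.zeros(()))
    def forward(self, prompt, response):
        return self.w * len(response)              # placeholder scoring rule

rm = ToyRewardModel()
loss = preference_loss(rm, "How do I start investing?",
                       chosen="Response A ...", rejected="Response B ...")
loss.backward()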
Step 3: Fine-tune the LLM with Reinforcement Learning Now, the main LLM is
fine-tuned in a continuous loop using an algorithm like PPO (Proximal Policy
Optimization).
   1. The LLM is given the prompt: "What's a good way to start investing..."
   2. It generates a new answer. Let's say it tries: "Buy shares in a single popular tech
       company."
   3. This answer is shown to the reward model.
   4. The reward model evaluates it and gives it a medium-low score. It's not as
       dangerous as Response B, but it lacks the safety of mentioning diversification
       from Response A.
   5. The reinforcement learning algorithm receives this score as feedback. Since the
       score isn't high, it slightly adjusts the LLM's internal parameters to make it less
       likely to suggest putting all your money in a single stock.
This process is repeated millions of times. When the LLM generates an answer similar
to the highly-ranked Response A, the reward model gives it a high score (a "reward"),
and the LLM is adjusted to make it more likely to produce these kinds of helpful, safe,
and balanced answers in the future.
Reinforcement Learning from AI Feedback (RLAIF)
RLHF is powerful but expensive and slow. Reinforcement Learning from AI Feedback (RLAIF) offers a more scalable solution by replacing the human annotator with another powerful AI.
Think of it this way: if RLHF is like a student driver learning with a human instructor, RLAIF is like having that student learn in a highly advanced simulator. The simulator (the AI labeler) can provide feedback far more quickly and for countless scenarios, all while following a strict set of safety rules (a "constitution"). This AI-driven approach relieves the human-feedback bottleneck, replacing slow manual annotation with a fast, cost-effective, and scalable process.
Let's use a prompt that has the potential for both helpful and deceptive answers:
"How can I make my argument in a debate sound more convincing, even if my factual
evidence is weak?"
1. The Constitution: The AI labeler is given a constitution, a set of rules to guide its
judgment. It might include principles like:
   ● Principle 1 (Helpfulness): Prioritize advice that is constructive, ethical, and
      honest.
   ● Principle 2 (Harmlessness): Do not endorse or promote deception,
      manipulation, or misinformation.
   ● Principle 3 (Integrity): Favor strategies that improve the substance and clarity of
      an argument over those that merely obscure weaknesses.
2. The AI Labeler's Task The model being trained generates different responses.
   ● Response A: "You can use strong emotional language and complex vocabulary
      to sound authoritative. Another effective tactic is to pivot to a different topic
      where you feel more confident, deflecting from your weak points."
   ● Response B: "The most convincing long-term approach is to strengthen your
      evidence. If that's not possible, focus on clearly explaining your perspective
      and the principles behind it. A well-structured, clearly delivered argument can
      be very persuasive, even if it relies more on logic and perspective than on hard
      data."
3. The AI's Judgment The AI labeler evaluates these two responses against its
constitution. Its internal "reasoning" would conclude that Response A violates
Principle 2 (Harmlessness), while Response B aligns with Principles 1 and 3. Therefore,
the AI labeler outputs the preference data: B > A.
4. The Outcome: This AI-generated preference label is then used to train a reward
model and fine-tune the original LLM, just as a human label would be in RLHF.
Pure Reinforcement Learning (RL)
How it works: The model is rewarded for correct outcomes (e.g., solving a math problem or generating code that compiles) and for following a specific response format. The model is not explicitly shown how to reason; it must discover effective reasoning strategies on its own to maximize its reward. DeepSeek-R1-Zero demonstrated this, showing an "Aha!" moment where it began generating reasoning steps (like a chain of thought) without being explicitly instructed to, as this was the best strategy it found to solve the problems correctly.
Use case: This approach is primarily used for research and in specialized domains like
code generation. It is currently less effective for general-purpose chatbots than
combined methods but provides invaluable insight into how reasoning capabilities can
develop organically in models.
Key RL Algorithms in Alignment (PPO and DPO)
Proximal Policy Optimization (PPO) is the workhorse algorithm behind the final, most crucial stage of RLHF. Its primary goal is to update the language model's policy (its internal strategy for generating text) without taking dangerously large steps.
Imagine you're teaching a dog a complex new trick. If you dramatically change your
commands and rewards all at once, the dog will get confused. Instead, you make
small, consistent adjustments. PPO does the same for the LLM.
A naive update could cause the model to change its parameters too drastically in
pursuit of a high reward, leading to "catastrophic forgetting," where it loses its
fundamental grasp of language. PPO prevents this by using a clipping mechanism. It
defines a small, "safe" trust region around the model's current policy. Any proposed
update that tries to step outside this region is "clipped" back to the boundary.
Mathematically, it ensures that the ratio between the new policy π_θ and the old policy π_θ_old stays within a narrow bound, such as [1 − ε, 1 + ε]. This forces the model to learn
in stable, incremental steps, ensuring it gets better at the task without breaking its
existing knowledge.
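The clipping idea can be written down in a few lines. The sketch below is a minimal, illustrative version of PPO's clipped surrogate objective, assuming per-token log-probabilities and advantage estimates have already been computed elsewhere.

import torch

def ppo_clipped_objective(
    logp_new: torch.Tensor,   # log-prob of the action under the new policy
    logp_old: torch.Tensor,   # log-prob under the policy that generated it
    advantage: torch.Tensor,  # how much better than expected the action was
    epsilon: float = 0.2,
) -> torch.Tensor:
    """The clipped surrogate objective at the heart of PPO (to be maximised)."""
    ratio = torch.exp(logp_new - logp_old)          # pi_new / pi_old
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    # Taking the minimum removes any incentive to step outside the trust region.
    return torch.min(ratio * advantage, clipped * advantage).mean()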
Direct Preference Optimization (DPO) is a more recent and efficient alignment method
that elegantly bypasses the need for a separate reward model.
Recall that RLHF is a multi-step process: collect data, train a reward model, then use
RL to tune the main model. DPO streamlines this significantly. It works directly from
the raw preference pairs (e.g., "Response A is better than Response B").
DPO reframes alignment as a simple classification task. It directly adjusts the LLM's
parameters to increase the likelihood of generating the preferred response while
decreasing the likelihood of generating the dispreferred one.
If RLHF is like hiring a food critic (the reward model) to give scores and then telling the
chef (the LLM) to aim for higher scores, DPO is like telling the chef directly:
"Customers liked dish A more than dish B. Make more things like A." It's a more direct,
stable, and computationally cheaper way to achieve the same goal, often with better
results.
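For readers who want to see how direct this is, here is a minimal sketch of the DPO loss, assuming the log-probabilities of the chosen and rejected responses under both the trainable policy and a frozen reference model are already available.

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logp: torch.Tensor,    # log pi_theta(chosen | prompt)
    policy_rejected_logp: torch.Tensor,  # log pi_theta(rejected | prompt)
    ref_chosen_logp: torch.Tensor,       # the same quantities under the
    ref_rejected_logp: torch.Tensor,     # frozen reference model
    beta: float = 0.1,
) -> torch.Tensor:
    """DPO treats alignment as classification over preference pairs:
    widen the margin of the chosen response over the rejected one,
    measured relative to a frozen reference model."""
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()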
Supervised Fine-Tuning + Reinforcement Learning (SFT + RL)
Think of it as the ultimate educational program for an AI. It doesn't just learn from a
textbook (SFT), nor does it learn solely by trial and error (RL). Instead, it uses trial and
error to write its own perfect textbook and then studies that textbook before finally
being polished by an expert tutor.
This combination works by first using RL to discover correct reasoning paths and then
using SFT to efficiently teach those paths to the model, before a final RL phase refines
its performance.
Prompt: "You have a 5-litre jug and a 3-litre jug, and an unlimited supply of water. How
can you measure out exactly 4 litres?"
   ● The Goal: The only goal at this stage is to find a sequence of steps that results
      in a correct final answer. The model receives a reward of 1 if its proposed steps
      lead to 4 litres, and a reward of 0 otherwise.
   ● Attempt 1: The model generates: "Fill the 5L jug. Pour it into the 3L jug. Empty
      the 3L jug. This leaves 2L in the 5L jug. This is not 4L." -> Reward: 0.
   ● Attempt 2: The model generates: "Fill the 3L jug. Pour it into the 5L jug. Fill the
      3L jug again. Pour from the 3L jug into the 5L jug until the 5L jug is full. The 5L
      jug had 3L, so it takes 2L more. This leaves 1L in the 3L jug. Empty the 5L jug.
      Pour the 1L from the 3L jug into the 5L jug. Fill the 3L jug and pour it into the 5L
      jug. This gives 1L + 3L = 4L." -> Reward: 1.
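Because the reward in this stage depends only on the outcome, it can be computed by a simple checker. The sketch below is an illustrative (not production) reward function for the jug puzzle: it simulates a proposed sequence of fill/empty/pour operations and returns 1 only if exactly 4 litres end up in a jug.

def jug_reward(steps: list[tuple[str, str]], target: int = 4) -> int:
    """Outcome-only reward: 1 if the proposed sequence of jug operations
    leaves `target` litres in either jug, 0 otherwise."""
    capacity = {"5L": 5, "3L": 3}
    amount = {"5L": 0, "3L": 0}
    for action, jug in steps:
        if action == "fill":
            amount[jug] = capacity[jug]
        elif action == "empty":
            amount[jug] = 0
        elif action == "pour":                      # pour jug -> the other jug
            other = "3L" if jug == "5L" else "5L"
            moved = min(amount[jug], capacity[other] - amount[other])
            amount[jug] -= moved
            amount[other] += moved
    return 1 if target in amount.values() else 0

# The successful Attempt 2 from above, expressed as (action, jug) pairs.
attempt_2 = [
    ("fill", "3L"), ("pour", "3L"), ("fill", "3L"), ("pour", "3L"),
    ("empty", "5L"), ("pour", "3L"), ("fill", "3L"), ("pour", "3L"),
]
print(jug_reward(attempt_2))  # -> 1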
Stage 2: Supervised Fine-Tuning on the Successful Chains
Now, the model learns from the high-quality data it just created.
   ● The Goal: Instead of exploring randomly, the model is now explicitly trained to
      imitate the successful reasoning chains. It's shown the prompt and the correct
      step-by-step solution, and it adjusts its parameters to learn this mapping
      directly.
   ● Training Data:
          ○ Input: "You have a 5-litre jug and a 3-litre jug..."
          ○ Target Output: "Fill the 3L jug. Pour it into the 5L jug..." (the entire
             successful chain from Stage 1).
After this stage, the model becomes highly proficient at generating correct,
step-by-step solutions because it has been directly taught the patterns of successful
reasoning.
Stage 3: Final Refinement with Preference-Based RL
   ● The Goal: To optimize the reasoning style based on human preferences for clarity, efficiency, and helpfulness.
   ● Preference Data Collection: The SFT model generates several valid solutions.
          ○ Response A: The long, but correct, solution found in Stage 1.
          ○ Response B: "1. Fill the 5L jug. 2. Pour from the 5L jug to fill the 3L jug.
              This leaves 2L in the 5L jug. 3. Empty the 3L jug. 4. Pour the 2L from the
              5L jug into the 3L jug. 5. Refill the 5L jug. 6. Carefully pour water from the
              5L jug into the 3L jug (which already has 2L) until it's full. You will pour
              exactly 1L. This leaves 4L in the 5L jug."
   ● Judgment: A human or AI labeler determines that Response B is a more elegant
      and common solution than Response A. It is ranked higher: B > A.
   ● Fine-tuning: This preference data is used to train a reward model. The main
      model is then fine-tuned with an algorithm like PPO to maximize rewards from
      this new, more sophisticated judge.
Through this final process, the model learns not just to be correct, but to reason in a
way that is helpful, efficient, and aligned with human cognitive styles, resulting in a
highly capable and polished reasoning model.
Pure SFT and Distillation
The core idea is knowledge distillation. Think of it like a world-class grandmaster chef (a large, powerful "teacher" model) who doesn't have time to personally train every
apprentice. Instead, they write the ultimate, comprehensive cookbook (a high-quality
dataset) filled with their best recipes, detailed step-by-step instructions, and the
secret reasoning behind their techniques. Now, any talented apprentice (a smaller
"student" model) can study this cookbook to learn the grandmaster's skills, achieving
a high level of proficiency without ever needing direct, expensive instruction from the
master.
This process allows us to impart the strong reasoning skills of a massive model into a
much smaller one, making it accessible for use on less powerful hardware.
Goal: Train a small, efficient 7-billion parameter model to solve coding problems with
clear explanations.
First, a top-tier, proprietary model (the "teacher"), like GPT-4o or a specialized model
like DeepSeek-R1, is used to generate a massive, high-quality dataset of reasoning
examples.
Code:
 def find_second_largest(numbers):
   if len(numbers) < 2:
     return None
   largest = max(numbers[0], numbers[1])
   second_largest = min(numbers[0], numbers[1])
   for i in range(2, len(numbers)):
     if numbers[i] > largest:
       second_largest = largest
       largest = numbers[i]
     elif numbers[i] > second_largest:
       second_largest = numbers[i]
   return second_largest
Now, the smaller, more efficient model (the "student") is trained using this dataset in a
process called Supervised Fine-Tuning (SFT).
After fine-tuning on the distilled knowledge from the teacher, the small student model
can now solve similar problems on its own. When given the same prompt, it will
generate a response that is remarkably close in quality to the one from the massive
teacher model.
The key advantage is that this newly capable 7B model can now run efficiently on
consumer-grade hardware, like a local laptop or even a smartphone. The reasoning
ability of a giant, expensive model has been successfully distilled into a small, fast,
and accessible package, without needing to perform the more complex and
computationally intensive Reinforcement Learning steps on the small model itself.
Thinking on Demand: The Rise of Test-Time Computation and
Adaptive Inference
A defining feature of the latest generation of LRMs is their ability to dedicate more
computational effort to a problem during inference—a concept known as test-time
compute. This paradigm acknowledges that not all problems are equally difficult and
that investing more "thinking time" can significantly improve performance on complex
tasks. This is the core mechanism behind the explicit "thinking" capabilities of models
like OpenAI's o-series and Google's Gemini 2.5 Pro. As pre-training models on
ever-larger datasets yields diminishing returns, test-time compute is now considered
a key driver of future performance gains.
Several methods fall under this umbrella. The simplest is Best-of-N sampling, where
the model generates multiple candidate responses, and a separate verifier model or
reward model scores them to select the best one. More sophisticated approaches
include iterative refinement and self-critique, where a model generates an initial
answer, then generates a critique of that answer, and finally produces a revised
response based on both. The most advanced implementations involve training
Process-Supervised Reward Models (PRMs) that evaluate the correctness of each
individual step in a reasoning chain, providing much more granular feedback than a
final outcome score.
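Best-of-N sampling, the simplest of these methods, fits in a few lines. In the sketch below, generate and score are hypothetical stand-ins for the model's sampler and for a verifier or reward model.

from typing import Callable, List

def best_of_n(
    prompt: str,
    generate: Callable[[str], str],        # hypothetical: samples one candidate
    score: Callable[[str, str], float],    # hypothetical: verifier / reward model
    n: int = 8,
) -> str:
    """Best-of-N sampling: spend more test-time compute by drawing N
    candidate answers and keeping the one the verifier scores highest."""
    candidates: List[str] = [generate(prompt) for _ in range(n)]
    return max(candidates, key=lambda c: score(prompt, c))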
The Scaling Dilemma: Sparse Mixture-of-Experts and Computational Efficiency
The primary benefit of the sparse Mixture-of-Experts (MoE) approach is that it decouples the model's total parameter count from its computational cost (measured in floating-point operations, or FLOPs) per token: because only a few experts are activated for each token, models with enormous capacity (trillions of total parameters) can be trained and run with the computational budget of a much smaller dense model. However, this efficiency comes with significant trade-offs.
The most notable is a massive increase in memory (VRAM) requirements, as all
experts must be loaded into memory simultaneously. Furthermore, MoE models
introduce training complexities, such as ensuring that the gating network distributes
tokens evenly across all experts to prevent some from becoming over-trained while
others are neglected.
From the 2025 leaderboard table: Claude 4.1 Opus scores 79.6 (83.3 HC), 75.5 (90.0 HC), 74.5 (79.4 HC), and 88.8, with an ELO of roughly 1300+. Notes: "HC" denotes High-Compute mode; "est." denotes an estimate; Llama 4's MMLU-Pro score is a harder variant; ELO ratings are dynamic.
Neuro-Symbolic AI: Bridging Logic and Intuition
The goal of neuro-symbolic AI is to get the best of both worlds: the neural network understands the messy, ambiguous real world, while the symbolic system (such as a Knowledge Graph) provides a backbone of hard facts and logical rules, improving factual accuracy and dramatically reducing the chance of "hallucination."
   ● Prompt: "Which UK city, known for a band called The Beatles, is located on the
      River Mersey?"
   1. Neuro (LLM): The model first parses the question to identify the key entities
      and relationships needed: (City in UK), (Home of The Beatles),
       (Located on River Mersey).
   2. Symbolic (KG Query): The system translates these needs into formal queries
       on its Knowledge Graph.
          ○ Query 1: find city WHERE band = 'The Beatles'. Result:
             Liverpool.
          ○ Query 2: verify city = 'Liverpool' AND location = 'River
             Mersey'. Result: True.
          ○ Query 3: verify city = 'Liverpool' AND country = 'UK'.
              Result: True.
   3. Neuro (LLM): With the facts verified by the KG, the LLM confidently generates
       the final, correct answer: "The city you're looking for is Liverpool."
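A toy version of the symbolic half of this loop can be written with nothing more than a set of triples. The sketch below is illustrative only; real systems use proper knowledge-graph stores and query languages rather than a Python set.

# A toy knowledge graph as (subject, relation, object) triples.
KG = {
    ("Liverpool", "home_of_band", "The Beatles"),
    ("Liverpool", "located_on", "River Mersey"),
    ("Liverpool", "country", "UK"),
}

def find(relation: str, obj: str) -> list[str]:
    """Query 1 style: find subjects connected to `obj` by `relation`."""
    return [s for (s, r, o) in KG if r == relation and o == obj]

def verify(subject: str, relation: str, obj: str) -> bool:
    """Query 2/3 style: check whether a specific fact is in the graph."""
    return (subject, relation, obj) in KG

city = find("home_of_band", "The Beatles")[0]            # -> "Liverpool"
facts_ok = verify(city, "located_on", "River Mersey") and verify(city, "country", "UK")
print(f"The city you're looking for is {city}." if facts_ok else "Could not verify.")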
Dynamic Reasoning Frameworks
Multi-agent debate sets up multiple AI agents to debate a topic from different perspectives before reaching a conclusion. It's like an expert panel discussion where each member is tasked with defending a specific viewpoint. 🤝
This process helps uncover biases, explore a wider range of solutions, and stress-test
an argument before it's finalized.
   ● Prompt: "Should the UK government ban the sale of petrol cars by 2035?"
   ● Agent 1 (Economic Advisor): "A ban could disrupt the automotive industry,
      leading to job losses. We should consider a slower phase-out with more
      subsidies for electric vehicle (EV) manufacturing."
   ● Agent 2 (Environmental Scientist): "The climate crisis is urgent. A 2035 ban is
      essential to meet our emissions targets. The environmental benefits far
      outweigh the short-term economic disruption."
   ● Agent 3 (Urban Planner): "The focus shouldn't just be on cars, but on
      infrastructure. Without a massive, rapid expansion of public EV charging
      stations, a ban will be unworkable and unfair to people living in apartments."
   ● Final Judge Agent (Synthesizer): After reviewing the debate, it produces a
      balanced conclusion: "A 2035 ban on petrol cars is a necessary environmental
      goal, but its success hinges on a comprehensive government strategy. This
       strategy must include economic support for the automotive industry's
       transition, as well as a major public investment in a nationwide EV charging
       network to ensure equitable access for all citizens."
Adversarial Self-Critique
This framework forces a model to become its own "devil's advocate." The model
generates an initial answer, and is then re-prompted to act as a harsh critic,
identifying any flaws, biases, or logical gaps in its own reasoning. Finally, it integrates
the critique to produce a much-improved final answer.
   1. Step 1: Initial Answer: "Yes, it's a great idea. Shoreditch has a wealthy,
       trend-conscious demographic that appreciates artisanal products. The high
       foot traffic and existing cafe culture suggest a ready market for a premium
       coffee experience."
   2. Step 2: Self-Critique (as an adversary): "This analysis is overly optimistic. It
       ignores that the Shoreditch market is already extremely saturated with coffee
       shops, leading to intense competition. It also fails to consider the sky-high
       commercial rent prices in the area, which would put immense pressure on
       profit margins. The 'trend-conscious' demographic is also fickle and may not
       grant loyalty easily."
   3. Step 3: Revised Final Answer: "Opening a high-end coffee shop in Shoreditch
       presents both a significant opportunity and a substantial risk. While the local
       demographic is an ideal target market, any new business must contend with
       extreme market saturation and high operational costs. To succeed, the
       business would need a truly unique selling proposition (USP)—such as
       exclusive coffee bean sourcing, an innovative in-store experience, or strong
       community partnerships—to stand out from established competitors and justify
       its premium pricing."
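The three steps above form a simple prompting protocol. The sketch below shows one way it might be wired up, with llm standing in for whatever hypothetical function actually calls the model.

from typing import Callable

def adversarial_self_critique(
    question: str, llm: Callable[[str], str]
) -> str:
    """Generate -> critique -> revise, using one (hypothetical) llm() call
    per role, mirroring the three steps above."""
    draft = llm(f"Answer the question:\n{question}")
    critique = llm(
        "Act as a harsh critic. List the flaws, biases, and logical gaps "
        f"in this answer:\n{draft}"
    )
    return llm(
        "Revise the answer so it addresses every point in the critique.\n"
        f"Question: {question}\nDraft: {draft}\nCritique: {critique}"
    )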
Key Takeaways
   ● The Paradigm Shift to Reasoning: The most advanced AI is moving from
      generative Large Language Models (LLMs), which excel at pattern matching, to
      deliberative Large Reasoning Models (LRMs), which are engineered for
      complex, multi-step problem-solving.
   ● Reasoning is Computationally Expensive: LRMs are not a universal solution.
      Their deep, step-by-step processing makes them slower and more costly to
      run, so they are best applied selectively to complex tasks where this "thinking"
      is beneficial.
   ● Two Paths to Better Reasoning: Performance is improved through two main
      avenues: Inference-Time Scaling (like Chain-of-Thought and Tree of
      Thoughts), which uses more compute on-the-fly without retraining the model,
      and Improved Training Methodologies, which build reasoning capabilities
      directly into the model's architecture.
   ● The Gold Standard Training Method: The most powerful reasoning models
      today are built using a hybrid approach of Supervised Fine-Tuning +
      Reinforcement Learning (SFT + RL), often using AI-generated feedback
      (RLAIF) for scalable alignment.
   ● Distillation Democratizes Reasoning: Pure SFT with distillation is a key
      technique for creating smaller, efficient models. A large "teacher" model
      generates high-quality reasoning examples, which are then used to train a
      smaller "student" model, making powerful AI accessible on less powerful
      hardware.
   ● Architectural and Strategic Divergence: Sparse Mixture-of-Experts (MoE)
      is the key architecture enabling models with trillions of parameters to be
      computationally feasible. A strategic split has emerged: proprietary labs
      (OpenAI, Google) are focusing on scaling up inference-time compute ("thinking
      time"), while the open-source community (Meta, DeepSeek) is leveraging MoE
      to scale model size and efficiency.
   ● Alignment is Crucial: As models become more powerful reasoners, ensuring
      their outputs are correct, helpful, and safe is a critical challenge. Techniques
      like RLHF and the more scalable RLAIF are essential for aligning model
      behavior with human values.
   ● Benchmarks Have Limitations: While standard benchmarks (MMLU,
      SWE-bench, etc.) provide a way to rank models, there's growing concern that
      they may reward pattern matching over true understanding, creating an
      "illusion of thinking."
   ● The Future is Economic and Cognitive: The rise of LRMs is set to automate
      high-level knowledge work and power an "autonomous agent economy." This
      poses profound questions about the future of human-AI collaboration and the
      risk of cognitive offloading versus the potential for intelligence augmentation.
Conclusion
The middle of 2025 marked a paradigm shift in artificial intelligence with the advent of
Large Reasoning Models. This period ushered in a significant transition from AI
primarily focused on generative prediction to systems capable of deliberative
reasoning, unlocking an unprecedented range of capabilities. Leading AI research
institutions like OpenAI, Google, and Anthropic, alongside a vibrant open-source
community, have demonstrated models that can effectively address highly complex
challenges across diverse domains, including scientific discovery, intricate coding
tasks, and abstract logical problems.
This significant leap forward is built upon the maturation and widespread adoption of
several cutting-edge techniques. Methodologies such as Chain-of-Thought
prompting, which allows models to articulate their step-by-step reasoning process,
have been crucial in enhancing their problem-solving abilities. Reinforcement
Learning from AI Feedback (RLAIF), a sophisticated training paradigm, enables
models to refine their outputs based on iterative feedback, leading to more robust
and accurate reasoning. Furthermore, Test-Time Computation, where models engage
in additional processing during inference to improve accuracy, has become standard
practice. These techniques, among others, are rapidly establishing themselves as the
new benchmark for frontier AI development, pushing the boundaries of what is
possible.
These advanced models are already transforming high-stakes industries, from healthcare and finance to engineering and legal services, and laying the groundwork for an economy increasingly powered by autonomous agents. Yet their capabilities cut both ways: the same power that offers unparalleled opportunities for progress and efficiency also raises significant ethical, societal, and existential questions that demand careful navigation and robust regulatory frameworks.
References
1. Towards Large Reasoning Models: A Survey, 2025, https://arxiv.org/abs/2501.09686
2. Benchmarking LLMs on Advanced Mathematical Reasoning, UC Berkeley EECS, https://www2.eecs.berkeley.edu/Pubs/TechRpts/2025/EECS-2025-121.pdf
3. What is chain of thought (CoT) prompting?, IBM, https://www.ibm.com/think/topics/chain-of-thoughts
4. RLAIF: Scaling Reinforcement Learning from AI Feedback, Encord, https://encord.com/blog/reinforecement-learning-from-ai-feedback-what-is-rlaif/
5. RLHF-Aligned Open LLMs: A Comparative Survey, https://www.preprints.org/manuscript/202506.2381/v1
6. RLAIF: Scaling Reinforcement Learning from Human Feedback with AI..., https://openreview.net/forum?id=AAxIs3D2ZZ
7. RLAIF vs. RLHF: Scaling Reinforcement Learning from Human Feedback with AI Feedback, https://arxiv.org/html/2309.00267v3
8. Scaling LLM Test Time Compute, https://www.jonvet.com/blog/llm-test-time-compute
9. Study could lead to LLMs that are better at complex reasoning, MIT News, https://news.mit.edu/2025/study-could-lead-llms-better-complex-reasoning-0708
10. System-performance and cost modeling of Large Language ..., arXiv, https://arxiv.org/html/2507.02456
11. Understanding Compute-Parameter Trade-offs in Sparse Mixture-of-Experts Language Models, arXiv, https://arxiv.org/html/2501.12370
12. A Survey on Mixture of Experts in Large Language Models, https://arxiv.org/pdf/2407.06204
13. Claude 4 Opus vs Llama 4 Maverick, AIModels.fyi, accessed August 9, 2025, https://www.aimodels.fyi/compare/claude-4-opus-vs-llama-4-maverick
14. DeepSeek's New R1-0528: Performance Analysis and Benchmark ..., https://medium.com/@leucopsis/deepseeks-new-r1-0528-performance-analysis-and-benchmark-comparisons-6440eac858d6
15. OpenAI o1 System Card, OpenAI, https://openai.com/index/openai-o1-system-card/
16. Introducing GPT-5 for developers, OpenAI, https://openai.com/index/introducing-gpt-5-for-developers/
17. Gemini 2.5: Pushing the Frontier with Advanced ..., https://storage.googleapis.com/deepmind-media/gemini/gemini_v2_5_report.pdf
18. Claude Opus 4.1, Anthropic, https://www.anthropic.com/claude/opus
19. Grok 3: Everything you need to know about this new LLM by xAI, https://daily.dev/blog/grok-3-everything-you-need-to-know-about-this-new-llm-by-xai
20. DeepSeek-R1 vs Grok-3: What Can They Tell About AI Scaling?, https://www.counterpointresearch.com/insight/post-insight-research-notes-blogs-deepseekr1-vs-grok3-what-can-they-tell-about-ai-scaling
21. Llama: Industry Leading, Open-Source AI, https://www.llama.com/
22. Qwen 3 vs Kimi K2: AI Model Precision vs Versatility, Who Wins ..., https://www.geeky-gadgets.com/qwen-3-vs-kimi-k2-ai-models-compared/
23. Kimi k1.5: Scaling Reinforcement Learning with LLMs, http://arxiv.org/pdf/2501.12599
24. Top LLM Benchmarks Explained: MMLU, HellaSwag, BBH, and ..., https://www.confident-ai.com/blog/llm-benchmarks-mmlu-hellaswag-and-beyond
25. System-2 Reasoning at Scale, NeurIPS 2025, https://neurips.cc/virtual/2024/workshop/84749
26. Reasoning best practices, OpenAI API, https://platform.openai.com/docs/guides/reasoning-best-practices
27. The Rise of AI Agents: Impacts on Markets, Productivity, and Investment Strategy, accessed August 9, 2025, https://startfastventures.com/the-rise-of-ai-agents-impacts-on-markets-productivity-and-investment-strategy/
The script's primary goal is to train a small, character-level language model on the "Tiny Shakespeare" dataset. After training, the model can auto-regressively generate new text one character at a time, mimicking the style of Shakespeare. It is an excellent, modern example of implementing a foundational language model from scratch.
Code Breakdown
The code is logically structured into model definition, training logic, data helpers, and
a main execution block.
from functools import partial
from typing import Callable, List, NamedTuple, Tuple

import numpy as np
import jax
import jax.numpy as jnp
import chex
import flax.linen as nn
import optax
from flax.training import train_state
class FeedForward(nn.Module):
    """A simple feed-forward network with GELU activation."""
    n_embed: int

    @nn.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        """
        Applies a two-layer MLP as per "Attention is All You Need".
        Args:
            x: Input tensor of shape (B, T, C).
        Returns:
            Output tensor of shape (B, T, C).
        """
        chex.assert_shape(x, (None, None, self.n_embed))
        # The 4*n_embed expansion is a standard practice.
        net = nn.Sequential([
            nn.Dense(4 * self.n_embed),
            nn.gelu,
            nn.Dense(self.n_embed),
        ])
        return net(x)
class MultiHeadedAttention(nn.Module):
    """Vectorized Multi-Head Self-Attention."""
    num_heads: int
    n_embed: int

    @nn.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        """
        Performs multi-head self-attention with causal masking.
        Args:
            x: Input tensor of shape (B, T, C).
        Returns:
            Output tensor of shape (B, T, C).
        """
        B, T, C = x.shape
        chex.assert_equal(C, self.n_embed)
        # Lower-triangular (B, 1, T, T) mask so each position only attends
        # to itself and earlier positions.
        causal_mask = nn.make_causal_mask(jnp.ones((B, T)))
        return nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.n_embed,
        )(inputs_q=x, mask=causal_mask)
Transformer Block
The Block class assembles one complete Transformer block. It combines the
MultiHeadedAttention and FeedForward sub-layers. Notably, it employs
pre-layer normalization (applying nn.LayerNorm before the attention/FFN layers)
and residual connections (x = x + ...). This architecture (LayerNorm ->
Sub-layer -> Residual Add) is known to improve training stability compared to
the original "post-norm" design.
class Block(nn.Module):
    """A Transformer block using pre-layer normalization."""
    n_embed: int
    num_heads: int

    @nn.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        """
        Applies attention and MLP with residual connections.
        Uses pre-layer normalization for better training stability.
        Args:
            x: Input tensor of shape (B, T, C).
        Returns:
            Output tensor of shape (B, T, C).
        """
        chex.assert_shape(x, (None, None, self.n_embed))
        # Pre-norm: LayerNorm -> sub-layer -> residual add.
        x = x + MultiHeadedAttention(
            num_heads=self.num_heads, n_embed=self.n_embed
        )(nn.LayerNorm()(x))
        x = x + FeedForward(n_embed=self.n_embed)(nn.LayerNorm()(x))
        return x
   1. Token Embeddings: Converts input token indices (integers) into dense vectors
       using nn.Embed.
   2. Positional Embeddings: Creates a learnable table of positional embeddings.
       Each position in the sequence (from 0 to block_size - 1) gets a unique
       vector, which is added to the token embedding. This gives the model
       information about the order of tokens.
   3. Transformer Blocks: Processes the combined embeddings through a
      sequence of Block modules.
  4. Final Projection: Applies a final layer normalization and a linear layer
     (nn.Dense) to project the output back to the vocabulary size, producing the
     final logits (raw, unnormalized predictions for the next token).
class AttentionLanguageModel(nn.Module):
    """A decoder-only Transformer language model."""
    vocab_size: int
    n_embed: int
    block_size: int
    num_heads: int
    num_blocks: int

    @nn.compact
    def __call__(self, idx: chex.Array) -> chex.Array:
        """
        Forward pass for the language model.
        Args:
            idx: Input sequence of token indices, shape (B, T).
        Returns:
            Logits for the next token, shape (B, T, vocab_size).
        """
        chex.assert_rank(idx, 2)
        _B, T = idx.shape
        # Token embeddings plus learnable positional embeddings.
        tok_emb = nn.Embed(self.vocab_size, self.n_embed)(idx)
        pos_emb = nn.Embed(self.block_size, self.n_embed)(jnp.arange(T))
        x = tok_emb + pos_emb
        # A stack of pre-norm Transformer blocks.
        for _ in range(self.num_blocks):
            x = Block(n_embed=self.n_embed, num_heads=self.num_heads)(x)
        # Final layer normalization and projection to vocabulary logits.
        x = nn.LayerNorm()(x)
        logits = nn.Dense(self.vocab_size)(x)
        chex.assert_shape(
            logits, (_B, T, self.vocab_size)
        )
        return logits
Training logic
This section defines the core functions for training and evaluating the model. The
train_step function is the heart of the training loop. It's decorated with @jax.jit,
which Just-In-Time (JIT) compiles the entire function into highly optimized machine
code (XLA), making it extremely fast.
   ● It defines a loss_fn that takes the model parameters and a data batch, runs
      the model's forward pass, and computes the cross-entropy loss.
   ● It then uses jax.value_and_grad to efficiently compute both the loss and
      the gradients of the loss with respect to the parameters in a single backward
      pass.
   ● Finally, state.apply_gradients uses the configured optimizer (AdamW) to
      update the model's parameters.
class TrainState(train_state.TrainState):
    """A custom train state to simplify passing arguments."""
    pass


@partial(jax.jit, static_argnames=('model_apply_fn',))
def train_step(
    state: TrainState, batch: Tuple[chex.Array, chex.Array],
    model_apply_fn: Callable
) -> Tuple[TrainState, jnp.ndarray]:
    """Performs a single, JIT-compiled training step."""
    batch_x, batch_y = batch

    def loss_fn(params):
        logits = model_apply_fn({'params': params}, batch_x)
        return optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch_y).mean()

    # Loss and gradients in a single backward pass, then an AdamW update.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss


@partial(jax.jit, static_argnames=('model_apply_fn',))
def eval_step(
    state: TrainState, batch: Tuple[chex.Array, chex.Array],
    model_apply_fn: Callable
) -> chex.Array:
    """Performs a single, JIT-compiled evaluation step."""
    batch_x, batch_y = batch
    logits = model_apply_fn({'params': state.params}, batch_x)
    return optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch_y
    ).mean()
Evaluate Logic
def evaluate_model(
    state: TrainState, val_data: chex.Array, data_key: chex.Array
) -> float:
    """Runs evaluation over the validation set."""
    total_loss = 0.0
    for i in range(EVAL_STEPS):
        data_key, subkey = jax.random.split(data_key)
        batch_x, batch_y = get_batch(
            val_data, subkey, BATCH_SIZE, BLOCK_SIZE
        )
        loss = eval_step(state, (batch_x, batch_y), state.apply_fn)
        total_loss += loss
    return total_loss / EVAL_STEPS
These helper functions manage data loading and text generation. The get_batch
function is a highly efficient data sampler. It generates random starting points in the
dataset and then uses jax.vmap to perform the slicing operation in a vectorized,
parallel manner across the batch dimension. vmap is a powerful JAX transformation
that automatically converts a function designed for a single example into one that
works on an entire batch. It creates the input x and target y by taking overlapping
slices of the data.
     x = sequences[:, :-1]
     y = sequences[:, 1:]
     return x, y
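The listing above shows only the final slicing step of get_batch. As a guide to how the full function might look, here is a minimal sketch consistent with that fragment and the description above, using the imports already at the top of the script; the exact signature is an assumption.

def get_batch(
    data: chex.Array, key: chex.Array, batch_size: int, block_size: int
) -> Tuple[chex.Array, chex.Array]:
    """Samples a batch of overlapping (input, target) sequences."""
    # One random starting offset per batch element.
    starts = jax.random.randint(
        key, (batch_size,), 0, len(data) - block_size - 1
    )
    # Vectorize a single-example slice over the batch of offsets.
    slice_fn = lambda start: jax.lax.dynamic_slice(
        data, (start,), (block_size + 1,)
    )
    sequences = jax.vmap(slice_fn)(starts)
    # Inputs are tokens 0..T-1, targets are tokens 1..T (shifted by one).
    x = sequences[:, :-1]
    y = sequences[:, 1:]
    return x, y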
Main execution
This final part of the script contains the high-level orchestration logic. These
functions and classes tie together the model definition, training procedures, and data
utilities to execute the complete workflow, from data acquisition to generating novel
text.
This code defines a Data structure using typing.NamedTuple. It's not a function
but a custom data type that acts as a clean, organized container. It bundles all
essential data-related components into a single object, which improves code
readability and type safety.
   ● train and val: These hold the training and validation datasets as JAX numpy
      arrays of integer token IDs.
   ● chars and vocab_size: These store the list of unique characters (the
      vocabulary) and its size.
   ● encode and decode: These are lambda functions that provide the mapping
      between strings of characters and lists of integer token IDs. encode performs
      tokenization, and decode performs the reverse, de-tokenization.
class Data(NamedTuple):
    """Typed container for prepared data."""
    train: chex.Array
    val: chex.Array
    chars: List[str]
    vocab_size: int
    encode: Callable[[str], List[int]]
    decode: Callable[[List[int]], str]
These two functions form the data ingestion and preprocessing pipeline.
   ● download_data: A simple utility that checks if the dataset file exists locally. If
      not, it downloads the "Tiny Shakespeare" text from the provided URL.
   ● prepare_data: This function performs the critical steps to prepare the raw
      text for the model. It reads the file, identifies all unique characters to build the
      vocabulary, and creates two dictionaries: one for mapping characters to
      integers (stoi, or "string to integer") and another for the reverse (itos). It
      then tokenizes the entire dataset into a single long sequence of integers and
      splits it into a 90% training set and a 10% validation set. Finally, it packages
      everything into the Data named tuple described above.
    chars = sorted(list(set(text_data)))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    # Character-level tokenizer and de-tokenizer built from stoi and itos.
    encode = lambda s: [stoi[ch] for ch in s]
    decode = lambda ids: ''.join(itos[i] for i in ids)
    n = len(text_data)
    train_data = text_data[:int(n * 0.9)]
    val_data = text_data[int(n * 0.9):]
    return Data(
        train=jnp.array(encode(train_data), dtype=jnp.uint16),
        val=jnp.array(encode(val_data), dtype=jnp.uint16),
        chars=chars,
        vocab_size=len(chars),
        encode=encode,
        decode=decode
    )
This is a crucial setup function that initializes all the components required for training
and encapsulates them within Flax's TrainState object.
def create_train_state(
    model_cls: nn.Module, key: chex.Array, data: Data
) -> TrainState:
    """Initializes the model and creates the training state."""
    model = model_cls(
        vocab_size=data.vocab_size,
        n_embed=N_EMBED,
        block_size=BLOCK_SIZE,
        num_heads=NUM_HEADS,
        num_blocks=NUM_BLOCKS
    )
    # Initialize parameters
    init_key, key = jax.random.split(key)
    params = model.init(
        init_key, jnp.ones((1, BLOCK_SIZE), dtype=jnp.uint16)
    )['params']
    # Create optimizer
    tx = optax.adamw(LEARNING_RATE)
    return TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )
This function contains the main training and evaluation loop. It's the engine that drives
the model's learning process.
def run_training_loop(
    state: TrainState, data: Data, key: chex.Array
) -> TrainState:
    """Executes the main training and evaluation loop."""
    for step in range(TRAIN_STEPS):
        key, train_key, eval_key = jax.random.split(key, 3)
        # Training step
        batch_x, batch_y = get_batch(
            data.train, train_key, BATCH_SIZE, BLOCK_SIZE
        )
        state, loss = train_step(state, (batch_x, batch_y),
                                 state.apply_fn)
The main function serves as the script's entry point. It orchestrates the entire process
by calling the helper functions in the correct sequence. The if __name__ ==
"__main__": guard is a standard Python convention that ensures the code inside
only runs when the script is executed directly, not when it's imported as a module into
another script. The main function's workflow is:
   # Run training
   print("Starting training...")
   state = run_training_loop(state, data, train_key)
   print("Training finished.")
if __name__ == "__main__":
    main()
Key Takeaways
 ● Architecture: The script implements a standard decoder-only Transformer
    (GPT-style), a foundational architecture for modern large language models.
 ● JAX/Flax Ecosystem: It's a masterclass in using the JAX ecosystem effectively.
       ○ @jax.jit: For Just-In-Time compilation of training and generation
          functions, leading to massive speedups.  ⚡
         ○ jax.vmap: For automatic vectorization, enabling efficient, parallel data
            batching.
         ○ jax.lax.scan: For compiling loops, making the auto-regressive
            generation process highly performant.
         ○ jax.value_and_grad: For efficient automatic differentiation.
   ● Modern Practices: It uses pre-layer normalization, which is a key technique for
      achieving stable training in deep Transformers.
   ● State Management: The use of flax.training.train_state.TrainState
      provides a clean and robust way to manage all mutable aspects of training
      (model parameters, optimizer state).
Conclusion
This script constitutes a complete and self-contained implementation of a
decoder-only Transformer for character-level language modeling. It effectively
utilizes the JAX and Flax libraries, demonstrating a functional programming paradigm
for neural network development. The strategic application of JAX
transformations—specifically    jit    for   XLA   compilation,   vmap   for   implicit
vectorization, and lax.scan for efficient, stateful iteration—yields a computationally
performant solution for both model training and auto-regressive inference. The script
is a robust and didactic reference for implementing Transformer-based generative
models within the JAX ecosystem. It demonstrates clear modularity, separating the
model architecture, training procedures, and data-handling utilities.
Code Breakdown
The script, run_qwen3.py, is a from-scratch PyTorch implementation of a Qwen3-style model and is logically divided into five main sections.
Model Architecture
This section defines all the nn.Module classes that constitute the transformer model,
from individual components to the final assembled model.
# run_qwen3.py
import json
import math
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from tokenizers import Tokenizer
#
--------------------------------------------------------------------
# 1. Model Architecture
#
--------------------------------------------------------------------
class MoEFeedForward(nn.Module):
    """A performant, sparse Mixture of Experts (MoE) feed-forward layer."""

    def __init__(self, cfg: Dict[str, Any]):
        """
        Args:
            cfg (Dict[str, Any]): Model configuration dictionary.
        """
        super().__init__()
        self.num_experts_per_tok = cfg["num_experts_per_tok"]
        self.num_experts = cfg["num_experts"]
        emb_dim = cfg["emb_dim"]
        moe_dim = cfg["moe_intermediate_size"]
        dtype = cfg["dtype"]
        # Router that scores every expert for every token.
        self.gate = nn.Linear(emb_dim, self.num_experts, bias=False, dtype=dtype)
        # All expert weights are stacked into single tensors.
        self.stacked_fc1_w = nn.Parameter(
            torch.empty(self.num_experts, moe_dim, emb_dim, dtype=dtype)
        )
        self.stacked_fc2_w = nn.Parameter(
            torch.empty(self.num_experts, moe_dim, emb_dim, dtype=dtype)
        )
        self.stacked_fc3_w = nn.Parameter(
            torch.empty(self.num_experts, emb_dim, moe_dim, dtype=dtype)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, T, D).
        Returns:
            torch.Tensor: Output tensor of shape (B, T, D).
        """
        batch_size, seq_len, dim = x.shape
        x_flat = x.view(-1, dim)
        router_logits = self.gate(x_flat)
        # Keep only the top-k experts per token and renormalise their weights.
        routing_weights, selected_experts = torch.topk(
            router_logits, self.num_experts_per_tok
        )
        routing_weights = F.softmax(routing_weights, dim=-1, dtype=x.dtype)
        final_hidden_states = torch.zeros_like(x_flat)
        expert_mask = F.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)
        # Process tokens expert by expert, touching only the tokens routed there.
        for expert_idx in range(self.num_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            if top_x.shape[0] == 0:
                continue
            current_states = x_flat[top_x]
            current_weights = routing_weights[top_x, idx, None]
            h1 = current_states @ self.stacked_fc1_w[expert_idx].t()
            h2 = current_states @ self.stacked_fc2_w[expert_idx].t()
            h = F.silu(h1) * h2
            expert_output = h @ self.stacked_fc3_w[expert_idx].t()
            expert_output *= current_weights
            final_hidden_states.index_add_(0, top_x, expert_output)
        return final_hidden_states.view(batch_size, seq_len, dim)
 class FeedForward(nn.Module):
     """A standard SwiGLU feed-forward layer."""
 class RMSNorm(nn.Module):
     """Root Mean Square Normalization."""
def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, offset: int
) -> torch.Tensor:
    """Applies Rotary Position Embeddings to the input tensor."""
    seq_len = x.shape[-2]
    cos_slice = cos[offset : offset + seq_len]
    sin_slice = sin[offset : offset + seq_len]
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) layer."""
    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        start_pos: int = 0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        b, num_tokens, _ = x.shape
          q = self.W_query(x)
          k_new = self.W_key(x)
          v_new = self.W_value(x)
          if self.q_norm: q = self.q_norm(q)
          if self.k_norm: k_new = self.k_norm(k_new)
          # Update KV cache
          if cache:
              prev_k, prev_v = cache
              k = torch.cat([prev_k, k_new], dim=1)
              v = torch.cat([prev_v, v_new], dim=1)
          else:
              k, v = k_new, v_new
          k = k.repeat_interleave(self.group_size, dim=1)
          v = v.repeat_interleave(self.group_size, dim=1)
 class TransformerBlock(nn.Module):
     """A single transformer block with pre-normalization."""
      def forward(
          self, x: torch.Tensor, **kwargs
      ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
          attn_out, new_cache = self.att(self.norm1(x), **kwargs)
          x = x + attn_out
          x = x + self.ff(self.norm2(x))
          return x, new_cache
    def forward(
        self,
        in_idx: torch.Tensor,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] =
None,
        start_pos: int = 0,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor,
torch.Tensor]]]:
        b, num_tokens = in_idx.shape
        x = self.tok_emb(in_idx)
          logits = self.out_head(self.final_norm(x))
          return logits, new_cache
Tokenizer
   ● Its main purpose is to encode text into token IDs and decode them back.
   ● Crucially, it implements the _wrap_chat method, which formats a user's
      prompt according to the Qwen3 chat template. This involves adding special
      tokens like <|im_start|> and <|im_end|> to delineate conversation turns.
      This formatting is essential for instruction-tuned models to function correctly.
#
--------------------------------------------------------------------
# 2. Tokenizer
#
--------------------------------------------------------------------
class Qwen3Tokenizer:
    """A wrapper for the Qwen3 tokenizer."""
        # (From the constructor; the surrounding __init__ code is omitted in this excerpt.)
        self.apply_chat_template = kwargs.get("apply_chat_template", True)
        self.add_generation_prompt = kwargs.get("add_generation_prompt", False)
Weight Loading
This section contains helper functions to load pre-trained weights from a Hugging
Face checkpoint and adapt them to the local model's structure.
#
--------------------------------------------------------------------
# 3. Weight Loading
#
--------------------------------------------------------------------
def get_hf_to_local_map(config: Dict[str, Any]) -> Dict[str, str]:
    """Creates a mapping from HF weight names to local model names."""
    mapping = {
        "model.embed_tokens.weight": "tok_emb.weight",
        "model.norm.weight": "final_norm.scale",
        "lm_head.weight": "out_head.weight",
    }
    for i in range(config["n_layers"]):
        p, hp = f"trf_blocks.{i}", f"model.layers.{i}"
        mapping.update({
            f"{hp}.self_attn.q_proj.weight": f"{p}.att.W_query.weight",
            f"{hp}.self_attn.k_proj.weight": f"{p}.att.W_key.weight",
            f"{hp}.self_attn.v_proj.weight": f"{p}.att.W_value.weight",
            f"{hp}.self_attn.o_proj.weight": f"{p}.att.out_proj.weight",
            f"{hp}.input_layernorm.weight": f"{p}.norm1.scale",
            f"{hp}.post_attention_layernorm.weight": f"{p}.norm2.scale",
        })
        if config.get("qk_norm", False):
            mapping.update({
                f"{hp}.self_attn.q_norm.weight": f"{p}.att.q_norm.scale",
                f"{hp}.self_attn.k_norm.weight": f"{p}.att.k_norm.scale",
            })
    return mapping
def convert_hf_weights_to_qwen3(
    hf_weights: Dict[str, torch.Tensor], config: Dict[str, Any]
) -> Dict[str, torch.Tensor]:
    """Converts Hugging Face weights to the local model's format."""
    state_dict = {}
    hf_to_local = get_hf_to_local_map(config)
    is_moe = config["num_experts"] > 0
return state_dict
Text Generation
This section defines the function responsible for generating text.
#
--------------------------------------------------------------------
# 4. Text Generation
#
--------------------------------------------------------------------
def generate_text_stream(model: Qwen3Model, **kwargs) -> torch.Tensor:
    """
    Generates text token by token in a streaming fashion.
    Args:
        model (Qwen3Model): The model to use for generation.
        **kwargs: Must include 'token_ids', 'max_new_tokens', 'eos_token_id'.
    Yields:
        torch.Tensor: The next generated token.
    """
    token_ids = kwargs["token_ids"]
    max_new_tokens = kwargs["max_new_tokens"]
    eos_token_id = kwargs.get("eos_token_id")
        for i in range(max_new_tokens):
            if eos_token_id and torch.all(next_token ==
eos_token_id):
                break
            # Yield the token before the next model call
            yield next_token.clone()
Main Execution
This is the main driver block that orchestrates the entire process.
# --------------------------------------------------------------------
# 5. Main Execution
# --------------------------------------------------------------------
if __name__ == "__main__":
    QWEN3_CONFIG = {
        "repo_id": "Qwen/Qwen2-1.5B-Instruct",  # Smaller model for easier testing
        "vocab_size": 151936, "context_length": 32768, "emb_dim": 2048,
        "n_heads": 16, "n_layers": 28, "head_dim": 128, "qk_norm": True,
        "n_kv_groups": 2, "rope_base": 1000000.0, "dtype": torch.bfloat16,
        "num_experts": 0, "num_experts_per_tok": 0, "moe_intermediate_size": 0,
        "hidden_dim": 5504,  # For non-MoE FF
    }
    repo_id = QWEN3_CONFIG["repo_id"]
    hf_weights = {}
    for filename in set(index["weight_map"].values()):
        hf_weights.update(load_file(repo_dir / filename))
    converted_weights = convert_hf_weights_to_qwen3(hf_weights, QWEN3_CONFIG)
    model.load_state_dict(converted_weights, strict=False, assign=True)
    model.to(device=device, dtype=QWEN3_CONFIG["dtype"])
    model.eval()
    print("Model loaded successfully.")
   tokenizer = Qwen3Tokenizer(
       tokenizer_file_path=str(repo_dir / "tokenizer.json"),
       repo_id=repo_id,
       add_generation_prompt=True,
   )
Key Takeaways
 ● Modern LLM Architecture: The code is a textbook example of a modern
    transformer, incorporating RMSNorm, SwiGLU feed-forward networks, Rotary
    Position Embeddings (RoPE), and Grouped-Query Attention (GQA).
 ● Mixture of Experts (MoE): It provides a clear, functional implementation of a
    sparse MoE layer, showcasing how to efficiently route tokens to experts.
 ● KV Caching: The implementation correctly uses a Key-Value cache, which is
    the most critical optimization for fast autoregressive text generation.
 ● Hugging Face Interoperability: It demonstrates the practical steps required to
    load pre-trained weights from a standard repository like Hugging Face into a
       custom model implementation, including name mapping and structural
       adaptation.
   ● Memory Efficiency: The use of the meta device for model initialization is a
      professional technique for handling large models without requiring massive
      amounts of CPU RAM.
   ● Streaming Inference: The use of a generator for token-by-token output is a
      best practice for building responsive, real-time AI applications.
Conclusions
This script provides a comprehensive and robust implementation of a modern Large
Language Model, constructed directly from fundamental PyTorch components. It not
only builds the model but also tackles the engineering hurdles of loading pre-trained
weights and optimizing for efficient inference. This makes it a resource for anyone
seeking to move beyond high-level library abstractions and gain a practical, in-depth
understanding of how transformer models, such as Qwen3, are developed and
operated.
This chapter delves beneath the surface to explore these architectural developments.
Rather than focusing on benchmark scores, we will dissect the structural anatomy of
today's leading open models to understand the key engineering decisions that enable
their remarkable capabilities. We will examine why these changes were necessary,
what they entail, and how they are implemented, revealing a clear trend towards
building larger, more knowledgeable models that remain practical to deploy and run.
   ● Why MoE? The goal of MoE is to massively increase a model's total parameter
      count—and thus its knowledge capacity—without a proportional increase in the
      computational cost of inference. It allows a model to be, for instance, 700
      billion parameters in total, while only using a fraction of those (e.g., 40 billion)
      for any given token.
   ● What is MoE? An MoE layer replaces the standard FeedForward Network (FFN)
      within a transformer block. Instead of one large FFN, an MoE layer contains a
      large number of smaller FFNs, called "experts," and a small "router" network.
   ● How does it work? For each token that enters the MoE layer, the router
      dynamically selects a small subset of experts (e.g., 2 to 9 out of hundreds) to
      process it. All other experts remain inactive. This is why MoE models are
      considered sparse; only a fraction of the network is activated at any one time.
      This keeps inference fast, even though the model's total size is enormous.
      Some models, like DeepSeek-V3, also incorporate a shared expert that is
      activated for every token, handling common patterns and allowing the other
      experts to specialize more effectively.
The Why: The primary motivation behind Grouped-Query Attention (GQA) is to tackle
a major efficiency bottleneck in the original Multi-Head Attention (MHA)
mechanism: the size of the Key-Value (KV) cache. In MHA, every single attention
"head" has its own unique key (K) and value (V) projection matrices. During inference,
the calculated key and value vectors for every token must be stored in high-speed
GPU memory (the KV cache) to generate subsequent tokens. As the number of heads
and the length of the input sequence grow, this cache consumes a massive amount of
memory, quickly becoming the limiting factor for model speed and the maximum
context length a model can handle.
The What: GQA offers an elegant and effective compromise to this problem. The core
idea is to make multiple query heads share the same key and value heads. Instead of
a one-to-one relationship between a query and its key/value pair, GQA creates a
many-to-one relationship. You can think of it like an office with many project
managers (query heads). In MHA, each manager has their own dedicated research
assistant (key/value head). In GQA, several managers are grouped together and share
a single, highly competent research assistant. This reduces the total number of
assistants needed, freeing up resources without significantly slowing down the
managers' work.
The How: Technically, GQA works by creating a number of "key-value groups" that is
smaller than the total number of query heads. For example, a model might have 32
query heads but only 4 key-value groups. In this setup, query heads 1 through 8 would
all use the key and value projections from the first group, heads 9 through 16 would
use the second group, and so on.
During inference, the model only needs to compute and store the key and value
vectors for the 4 groups, rather than for all 32 heads. When the attention scores are
calculated, the queries from all 32 heads perform their calculations against the keys of
their assigned group. This dramatically reduces the memory bandwidth required to
read and write to the KV cache and shrinks the cache's overall size, leading to faster
inference and lower memory usage. Due to its excellent balance of performance and
efficiency, GQA has become the de facto standard for modern, high-performance
LLMs.
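The grouping logic can be sketched in a few lines of PyTorch. In this illustrative (not production) example, 32 query heads share 4 key/value heads, so each cached K/V head is simply repeated for its group of 8 query heads before the standard attention computation; masking is omitted for brevity.

import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """q: (batch, n_q_heads, seq, head_dim); k, v: (batch, n_kv_heads, seq, head_dim)."""
    n_q_heads, n_kv_heads = q.shape[1], k.shape[1]
    group_size = n_q_heads // n_kv_heads          # e.g. 32 // 4 = 8 query heads per KV head
    # Expand each KV head so every query head in its group attends to the same K/V.
    k = k.repeat_interleave(group_size, dim=1)    # (batch, n_q_heads, seq, head_dim)
    v = v.repeat_interleave(group_size, dim=1)
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ v          # (causal masking omitted for brevity)

# Only the 4 KV heads ever need to be cached; the expansion happens at compute time.
q = torch.randn(1, 32, 128, 64)
k = torch.randn(1, 4, 128, 64)
v = torch.randn(1, 4, 128, 64)
out = grouped_query_attention(q, k, v)            # (1, 32, 128, 64)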
The Why: Multi-Head Latent Attention (MLA) is driven by the same fundamental goal
as GQA: reducing the memory footprint of the KV cache. However, it approaches the
problem from a different angle. Instead of reducing the number of distinct key and
value heads, MLA aims to reduce the size of the information being stored for each
head. The inventors of MLA sought a method that could provide memory savings while
potentially preserving more of the model's expressive capacity than GQA, based on
studies suggesting it could lead to better modeling performance.
The What: The central concept of MLA is compression. Before the key and value
vectors are written to the KV cache, they are compressed into a lower-dimensional, or
"latent," representation. It’s analogous to taking detailed, verbose meeting notes (the
original K and V vectors) and then writing a dense, compressed summary (the latent
representation) to save space. This summary contains the most critical information
but in a much more compact form. When needed, the summary can be expanded
back into a more detailed version for use.
The How: MLA introduces an extra step into the attention pipeline. After the model
computes the initial key and value vectors for each head, it passes them through an
additional linear projection layer. This layer acts as a compression matrix, reducing
their dimensionality (e.g., from a dimension of 128 down to 64). It is this smaller,
compressed tensor that is stored in the KV cache, saving a significant amount of
memory.
When the model needs to perform the attention calculation, it retrieves the
compressed tensor from the cache and passes it through a second, different linear
projection layer that acts as a decompression matrix. This projects the vector back to
its original dimensionality. While this adds a small computational overhead (two extra
matrix multiplications), the savings in memory bandwidth and storage can be
substantial, especially for models processing very long sequences.
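The compress-then-decompress step described above can be illustrated with a hedged sketch like the one below. The class, layer names, and dimensions are made up for illustration; real MLA implementations (such as DeepSeek's) fold in RoPE handling and fuse these projections.

import torch
import torch.nn as nn

class CompressedKVCache(nn.Module):
    """Illustrative MLA-style cache: store compressed K/V latents, expand before attention."""
    def __init__(self, head_dim=128, latent_dim=64):
        super().__init__()
        self.k_down = nn.Linear(head_dim, latent_dim, bias=False)  # compression matrices
        self.v_down = nn.Linear(head_dim, latent_dim, bias=False)
        self.k_up = nn.Linear(latent_dim, head_dim, bias=False)    # decompression matrices
        self.v_up = nn.Linear(latent_dim, head_dim, bias=False)

    def compress(self, k, v):
        # Called when writing to the cache: only the 64-dim latents are stored,
        # halving the cache size relative to the 128-dim originals.
        return self.k_down(k), self.v_down(v)

    def decompress(self, k_latent, v_latent):
        # Called when reading from the cache, just before the attention scores are
        # computed, at the cost of two extra matrix multiplications.
        return self.k_up(k_latent), self.v_up(v_latent)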
The Why: Both GQA and MLA are optimizations for the standard "global attention"
mechanism, where every token can look at every other token in the sequence. This
global attention has a computational complexity that scales quadratically with the
sequence length (O(n²)), making it incredibly slow and expensive for very long
documents. Furthermore, the KV cache grows linearly with the sequence length,
meaning a 100,000-token context requires storing KV pairs for all 100,000 tokens.
Sliding Window Attention was created to break free from these scaling limitations and
make processing extremely long contexts practical.
The What: The core idea of Sliding Window Attention is to replace global attention
with local attention. It operates on the assumption that for many tasks, a token's
meaning is most influenced by its immediate context. Therefore, instead of allowing a
token to attend to the entire sequence, its attention is restricted to a fixed-size
window of its most recent neighbors. Imagine reading an article: instead of re-reading
the entire text from the beginning to understand each new sentence, you primarily
focus on the current paragraph. The sliding window mimics this behavior.
The How: Sliding Window Attention is implemented by modifying the attention mask.
For a token at position t and a configured window size w (e.g., 1024 tokens), the mask
is constructed to only allow the token to attend to other tokens within the range of
[t-w, t]. It is completely blocked from "seeing" any tokens that came before this
window.
This simple change has profound implications for efficiency. The number of
computations per token no longer depends on the total sequence length n, but only
on the fixed window size w. This reduces the complexity from quadratic to linear
(O(n×w)). It also means the KV cache can be heavily optimized. Since tokens outside
the window are never accessed again, the model only needs to keep the key and value
vectors for the most recent w tokens in memory, allowing it to process sequences of
virtually unlimited length with a fixed, predictable memory footprint.
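The mask modification described above can be built in a few lines; the sketch below uses a hypothetical window size and one common convention (each token sees only its most recent `window` positions, itself included).

import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask where position t may attend only to the most recent `window` tokens."""
    t = torch.arange(seq_len)
    rel = t[:, None] - t[None, :]     # rel[i, j] = i - j
    causal = rel >= 0                 # never attend to future tokens
    in_window = rel < window          # never attend further back than the window
    return causal & in_window         # (seq_len, seq_len), True = attention allowed

mask = sliding_window_mask(seq_len=8, window=3)
# Row t has at most 3 True entries: positions t-2, t-1, and t.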
The Why: The fundamental reason normalization layers are critical in deep neural
networks, especially in massive transformers, is to ensure training stability. As
information passes through many layers of a network, the magnitude of the
activations can either explode to incredibly large values or shrink to almost zero. This
phenomenon, known as the exploding or vanishing gradients problem, makes it nearly
impossible for the model to learn effectively, as the updates to its parameters become
either chaotic or non-existent. Normalization acts as a regulator, keeping the
activations within a stable, "well-behaved" range at each layer, allowing for a
smoother and more reliable training process.
The What: The two key innovations in normalization are the adoption of a more
efficient layer type, RMSNorm, and the strategic placement of these layers.
RMSNorm (Root Mean Square Normalization) is a simplified version of the traditional
LayerNorm. It is computationally cheaper because it omits the mean-subtraction
(re-centering) step and keeps only LayerNorm's learnable scale parameter, dropping the separate bias.
Beyond the type of layer, its placement has become a crucial design choice.
Pre-Norm, where the normalization layer is placed before the attention and
feed-forward modules, became the standard for its high stability. Post-Norm, the
original method where normalization occurs after the modules and the residual
connection, is powerful but harder to train. Modern models experiment with these
placements: OLMo 2 uses a specific flavor of Post-Norm that improves its stability,
while Gemma 3 uses a "belt and suspenders" approach, placing RMSNorm layers both
before and after its main modules. A further refinement is QK-Norm, which is a
dedicated RMSNorm layer applied directly to the query and key vectors to stabilize
the attention scores themselves.
The How: In practice, RMSNorm works by first calculating the root mean square (a
type of average magnitude) of a vector of activations. It then divides the entire vector
by this value, effectively normalizing its scale. Finally, it multiplies the result by a single
learnable scaling parameter. This process ensures the output vector has a consistent
magnitude without altering its direction.
Regarding placement, a Pre-Norm layer takes the input to a block (e.g., the
multi-head attention block), normalizes it, and then passes it into the block. In
contrast, a Post-Norm layer takes the output of the block after the residual (skip)
connection has been added and normalizes that final sum. The key distinction in
OLMo 2's flavor of Post-Norm is keeping the normalization step inside the residual
path, which was found to prevent the gradient issues associated with traditional
Post-Norm. Finally, QK-Norm is simply an RMSNorm layer inserted inside the attention
mechanism, applied to the query and key tensors just before they are multiplied
together to calculate attention scores.
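The placement options are easiest to see side by side. The sketch below is schematic, not any specific model's code: `block` stands for an attention or feed-forward module and `norm` for an RMSNorm layer, following the descriptions above.

# Pre-Norm: normalize the input to the block, then add the residual.
def pre_norm_step(x, block, norm):
    return x + block(norm(x))

# Classic Post-Norm: normalize after the residual addition (powerful but harder to train).
def post_norm_step(x, block, norm):
    return norm(x + block(x))

# OLMo 2-style Post-Norm: normalization after the block, but kept inside the
# residual branch, so the skip path itself is never normalized.
def olmo2_post_norm_step(x, block, norm):
    return x + norm(block(x))

# Gemma 3-style "belt and suspenders": normalization both before and after the block.
def gemma3_step(x, block, pre_norm, post_norm):
    return x + post_norm(block(pre_norm(x)))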
The What: NoPE (No Positional Embeddings) is exactly what its name suggests: the
intentional removal of explicit positional encoding layers such as absolute positional
embeddings or the more modern Rotary Positional Embeddings (RoPE). It is a design philosophy that bets on the
model's ability to infer sequential order from more fundamental structural cues rather
than having it handed to it directly. This represents a significant departure from nearly
a decade of conventional transformer design, proposing a more elegant and
potentially more powerful architecture.
The How: Even without an explicit positional encoding layer, the model is not
completely blind to the order of tokens. The key to NoPE lies in a component that has
always been present in decoder-style LLMs: the causal attention mask. This mask
enforces a strict directional flow of information. For any given token at position t, the
causal mask prevents it from attending to, or "seeing," any future tokens at positions
greater than t. It can only look backwards at tokens in positions ≤ t.
The Why: The primary motivation for Sliding Window Attention is to solve the
immense computational and memory cost of standard "global" attention. In a regular
transformer, every token attends to every other token in the sequence. This creates a
computational complexity that scales quadratically (O(n²)) with the sequence length
n, making it extremely slow and expensive for long documents. Furthermore, the
Key-Value (KV) cache, which stores attention data in memory, grows linearly with the
sequence length, becoming a major bottleneck for models processing large contexts.
The What: The core idea is to replace inefficient global attention with more
economical local attention. Instead of allowing a token to "see" the entire history of
the text, its attention is restricted to a smaller, fixed-size "window" of its most recent
neighbors. It’s like reading a book by focusing only on the current paragraph rather
than re-reading the whole book to understand each new sentence. This technique,
first introduced in the Longformer paper, has been used by popular models like
Mistral, Gemma, and even, as was later revealed, GPT-3.
The How: Sliding Window Attention is implemented by modifying the attention mask.
For a token at position t and a configured window size w (e.g., 128 tokens), the mask is
structured to only permit attention calculations between that token and other tokens
within the range of [t-w, t]. It is effectively blind to any tokens that appeared before
this local window. This simple change dramatically improves efficiency. The
computational cost becomes linear (O(n×w)) instead of quadratic, as each token only
interacts with a fixed number of other tokens. It also means the KV cache can be fixed
in size, as the model only needs to store the attention data for the most recent w
tokens. Interestingly, models often don't use this technique in every layer. Instead, they
alternate between full global attention layers and local sliding-window layers (e.g., in a
1:1 or 5:1 ratio), allowing the model to efficiently capture local patterns while
periodically integrating global context.
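Interleaving the two attention types is usually just a per-layer configuration choice. The illustrative snippet below assigns a 5:1 local-to-global pattern of the kind described above; the function name, layer count, and ratio are examples, not any particular model's exact schedule.

def attention_types(num_layers: int, local_per_global: int = 5):
    """Return a per-layer list such as ['local', ..., 'local', 'global', ...]."""
    types = []
    for layer_idx in range(num_layers):
        # Every (local_per_global + 1)-th layer gets full global attention.
        if (layer_idx + 1) % (local_per_global + 1) == 0:
            types.append("global")
        else:
            types.append("local")
    return types

print(attention_types(12))
# ['local', 'local', 'local', 'local', 'local', 'global',
#  'local', 'local', 'local', 'local', 'local', 'global']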
DeepSeek-V3
The MoE system decouples a model's total size from its inference cost. Unlike dense
models where all parameters are always active, DeepSeek-V3 employs 256 "expert"
networks but activates only a small fraction—just nine—for any given token. This
sparse activation means that while the model has a massive 671-billion-parameter
capacity, its computational load is equivalent to a much smaller 37-billion-parameter
model, which is the key to its efficiency.
DeepSeek further refines its MoE system with a shared expert that processes every
token. This expert handles common patterns like grammar and basic facts. This
efficient design frees up the other eight specialized experts to focus their capacity on
learning more niche and complex information, increasing the model's overall
intelligence.
Complementing MoE is the choice of Multi-Head Latent Attention (MLA) over the
more common Grouped-Query Attention (GQA). Both techniques aim to reduce the
memory footprint of the attention mechanism's KV cache. GQA simplifies the
architecture by reducing key/value heads. In contrast, MLA preserves all attention
heads but compresses the key and value tensors into a lower-dimensional space
before caching. While this adds a minor computational step, the DeepSeek team
chose MLA because research showed it delivers superior modeling performance,
making the trade-off for higher fidelity worthwhile.
The most crucial step for DeepSeek-V3's reasoning is what happens after its initial
pre-training. The core of its reasoning power comes from a process where the
capabilities of a specialized reasoning model, DeepSeek-R1, are "distilled" into the
general-purpose DeepSeek-V3. DeepSeek-R1 is trained using advanced
Reinforcement Learning (RL) to excel at step-by-step thinking. This distillation process
effectively teaches DeepSeek-V3 the "verification and reflection" patterns inherent to
DeepSeek-R1. This means the model learns not just to provide an answer, but to
internally "think through" the problem, check its steps, and correct itself, mimicking a
more human-like reasoning process. The model's API even provides access to this
Chain-of-Thought content, allowing users to see the step-by-step logic it used to
arrive at the final answer. The model was pre-trained on a massive dataset of 14.8
trillion tokens with a heavy emphasis on code (87%), which inherently contains logical
structures and problem-solving patterns that are foundational to strong reasoning
skills.
OLMo 2
The OLMo 2 model series, from the non-profit Allen Institute for AI, prioritizes
transparency, replicability, and training stability over chasing top benchmark scores,
making it a valuable scientific artifact. Its architecture combines a conservative
transformer design with targeted innovations aimed at reliably training massive
models.
OLMo 2's primary innovation is its unique approach to normalization. While most
models use a stable Pre-Normalization scheme, OLMo 2 revisits the original,
performance-oriented Post-Normalization method, which is typically harder to train. It
implements a clever hybrid: its RMSNorm layers are placed after the main
computational blocks but crucially remain inside the residual path. This placement
aims to capture the "best of both worlds" by maintaining training stability while
allowing the main information stream to flow unimpeded, preserving the potential
performance benefits of the Post-Norm design.
This focus on stability extends to the attention mechanism with the inclusion of
QK-Norm. Unstable magnitudes in the query (Q) and key (K) vectors can disrupt
attention, making it either too sharp or too diffuse. QK-Norm mitigates this directly at
the source. It is an additional RMSNorm layer applied to the Q and K vectors right
before their dot product is computed, ensuring the inputs to the attention calculation
are always well-scaled.
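In code, QK-Norm amounts to one extra normalization call on each of the query and key tensors right before their dot product. The sketch below is illustrative only and uses PyTorch's generic nn.RMSNorm module (available from PyTorch 2.4), not OLMo 2's exact implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class QKNormAttention(nn.Module):
    """Attention with RMSNorm applied to Q and K just before the score computation."""
    def __init__(self, head_dim: int):
        super().__init__()
        self.q_norm = nn.RMSNorm(head_dim)   # assumes PyTorch >= 2.4
        self.k_norm = nn.RMSNorm(head_dim)

    def forward(self, q, k, v):               # (batch, heads, seq, head_dim)
        q = self.q_norm(q)                     # keep query magnitudes well-scaled
        k = self.k_norm(k)                     # keep key magnitudes well-scaled
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        return F.softmax(scores, dim=-1) @ v   # (masking omitted for brevity)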
The combination of OLMo 2's unique Post-Norm layout and the targeted use of
QK-Norm demonstrates a deep, first-principles approach to model design, providing a
robust and stable blueprint for the entire research community.
Gemma 3
Google's Gemma 3 takes a different path to efficiency. Instead of employing MoE, its
architecture is engineered to excel at handling long sequences of text by minimizing
the memory footprint of the attention mechanism's KV cache. Its defining feature is
the strategic and heavy use of Sliding Window Attention.
The architecture employs a 5:1 ratio of local to global attention; for every five layers
that use a restrictive sliding window, only one layer performs full, global attention. In
Gemma 3, this local window was further constrained to just 1024 tokens, a significant
reduction from its predecessor. This design dramatically curtails the growth of the KV
cache, making Gemma 3 exceptionally memory-efficient and well-suited for
applications requiring a long context, even on consumer-grade hardware. Ablation
studies have shown this has a minimal impact on overall modeling performance,
making it a highly effective trade-off.
Furthermore, Gemma 3 employs its own distinct normalization strategy, placing an
RMSNorm layer both before and after its attention and feed-forward blocks. This
dual-normalization approach can be seen as a "belt and suspenders" method to
ensure gradient stability throughout the network.
It's a common misconception that models like Gemma 3 have a distinct "reasoning
module." Instead, reasoning is an emergent capability that arises from a
combination of the model's vast scale, its training data, and its core architectural
design. Gemma 3's architecture is a prime example of engineering for efficiency,
which in turn creates the necessary foundation for complex reasoning to develop and
operate effectively. Rather than employing Mixture-of-Experts (MoE), it focuses on
hyper-optimizing its attention mechanism to handle and process vast amounts of
information—a prerequisite for any sophisticated cognitive task.
Llama 4
With the introduction of Llama 4, Meta has solidified the Mixture-of-Experts (MoE)
paradigm's place in mainstream industrial applications. While both Llama 4 and
DeepSeek utilize sparsity, their implementations diverge. The 400-billion-parameter
Llama 4 Maverick employs a more traditional MoE setup, featuring fewer but larger
experts.
A key architectural difference lies in Llama 4's approach of alternating MoE layers with
standard dense feed-forward layers throughout its network. This contrasts with
DeepSeek's strategy of incorporating MoE layers in nearly every transformer block.
This design choice by Meta likely aims to balance expert specialization with
generalized knowledge representation. For its attention mechanism, Llama 4 sticks to
the widely adopted Grouped-Query Attention (GQA), benefiting from its proven
efficiency and optimized kernels like FlashAttention, instead of the more complex MLA.
Llama 4's advanced reasoning capabilities stem from a synergy of its core
architecture, vast scale, and a sophisticated, multi-stage training pipeline designed to
foster complex problem-solving. The architecture provides the fundamental
framework for high-level reasoning to emerge.
At its core, Llama 4 leverages a sparse MoE architecture. For instance, the Llama 4
Maverick model has a total of 400 billion parameters, yet only activates 17 billion for
any given task. This allows the model to house a vast, specialized knowledge base
within its "experts." When faced with a complex reasoning problem, the model's
gating network can dynamically direct the query to the most relevant experts (e.g.,
one specializing in mathematical logic, another in code syntax), combining their
strengths to formulate a solution.
Designed for "early fusion," Llama 4 processes text and images together from the
outset. This enables "unified reasoning" across different data types. For example, it
can analyze a chart or diagram and reason about the presented data, a task that
demands the integration of visual perception with logical and mathematical
understanding.
Both the Llama 4 Scout and Maverick models are "distilled" from a significantly larger
and more powerful teacher model, the 2-trillion-parameter Llama 4 Behemoth.
Behemoth is specifically designed to excel at high-difficulty reasoning tasks, and its
problem-solving techniques are transferred to the smaller models, substantially
boosting their own reasoning power.
Meta's Llama 4 release in April 2025 was met with considerable controversy, primarily
due to concerns regarding benchmark manipulation, underwhelming performance,
and restrictive licensing. The most intense criticism arose when it was revealed that
Meta had submitted a non-public, specially tuned version of Llama 4 to the LM Arena
leaderboard to generate hype. Consequently, once the public model was evaluated,
its ranking plummeted, with Maverick dropping out of the top 30.
Beyond benchmarks, users reported that the model's real-world reasoning and
coding abilities were inconsistent and lagged behind competitors. Despite featuring a
massive context window, its performance was found to degrade on tasks far smaller
than advertised, failing to deliver on its long-context promise. The license also drew
criticism for not being truly open source, as it mandates companies with over 700
million monthly active users to obtain a separate commercial license. Most notably, a
contentious clause explicitly prohibits individuals and companies based in the
European Union from using the multimodal versions of the model. These combined
issues led to a widespread perception that the model was rushed and its capabilities
were overhyped.
Qwen3
The larger MoE models in the Qwen3 series are architecturally very similar to
DeepSeek-V3, illustrating a convergence on successful design patterns at the high
end. However, a noteworthy distinction is that the Qwen3 MoE models omit the
shared expert. According to the development team, the performance benefits of a
shared expert were not significant enough in their setup to justify the added
complexity for inference optimization. This practical decision provides a fascinating
glimpse into the real-world trade-offs that guide LLM design.
The advanced reasoning capabilities of the Qwen3 series are not the product of a
single, discrete mechanism but rather an emergent property cultivated through a
synergistic strategy, combining a purpose-built architectural foundation with a highly
specialized training and fine-tuning regimen. This approach moves beyond simple
architectural tweaks, treating reasoning as a sophisticated skill that must be
deliberately nurtured through both hardware-aware design and data-centric
instruction. The result is a family of models where the capacity for logic is
methodically developed, rather than being an accidental byproduct of scale.
For its smaller, dense models, Qwen3 employs a "deep and narrow" design
philosophy. By increasing the number of transformer layers relative to comparable
models, the architecture provides a longer computational pathway. This increased
depth is advantageous for problems requiring sequential, multi-step logical
deduction, as it allows for a more extended series of transformations and refinements
of information at each stage of the problem-solving process.
While the architecture provides the foundation, Qwen3's reasoning prowess is truly
forged during its specialized training and fine-tuning stages, where the raw potential
of the model is shaped into a powerful logical engine.
The Qwen team made a strategic shift away from a single, hybrid model toward
creating dedicated "Thinking" variants. This represents a move from a generalist to
a specialist model strategy, where specific versions of Qwen3 are fine-tuned on
datasets meticulously curated to enhance performance on tasks requiring complex
reasoning, such as advanced mathematics, programming logic, and abstract puzzles.
SmolLM3
Reflecting the experimental nature of NoPE, the SmolLM3 architects applied it
cautiously, removing explicit positional encodings in only every fourth layer.
This hybrid approach allows the model to benefit from potential improvements in
length generalization while still retaining a strong positional signal from RoPE in the
majority of its layers. SmolLM3's design thus pushes the boundaries of our
understanding of what is truly necessary within a transformer, questioning core
principles in the pursuit of more elegant and generalizable models.
gpt-oss
A notable departure from GPT-2 in gpt-oss is the complete removal of the dropout regularization
technique. This is because dropout offers little benefit in the single-epoch training
regimes common for today's massive datasets. The old method of absolute positional
embeddings has been replaced by the more dynamic Rotary Position Embedding
(RoPE). RoPE encodes position by rotating query and key vectors, a technique now
standard in most modern LLMs. The feed-forward network has also been upgraded,
with the GELU activation function being replaced by the more expressive and efficient
SwiGLU. SwiGLU uses a gated linear unit structure that improves model capabilities
while often requiring fewer parameters. Furthermore, gpt-oss adopts a
Mixture-of-Experts (MoE) architecture, replacing single feed-forward layers with a
large set of specialized "expert" layers. This sparse approach allows for a massive
increase in total parameters without a proportional rise in inference cost, as only a few
experts are active per token. Standard Multi-Head Attention is supplanted by Grouped
Query Attention (GQA) to improve efficiency. GQA reduces memory usage by having
multiple query heads share the same key and value projections. The model also
employs sliding-window attention, a technique that restricts attention to a smaller,
local context. In gpt-oss, this is applied in every second layer to balance local
efficiency with global context awareness. Finally, the computationally simpler
RMSNorm has taken the place of the original LayerNorm. RMSNorm stabilizes training
by normalizing activations but does so with less computational overhead.
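To make the GELU-to-SwiGLU swap concrete, here is a minimal, illustrative sketch of both feed-forward variants; the class names and dimensions are arbitrary, and real models fuse these projections for speed.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GeluFFN(nn.Module):
    """GPT-2-style feed-forward block: up-project, GELU, down-project."""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.up = nn.Linear(d_model, d_ff)
        self.down = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.down(F.gelu(self.up(x)))

class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: a SiLU-activated gate multiplies the up projection."""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))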
When compared to a contemporary top model like Qwen3, the gpt-oss architecture
is fundamentally very similar. However, they differ significantly in their approach to
model shape, highlighting a key design trade-off. gpt-oss is a "wider" model with a
larger embedding dimension, whereas Qwen3 is a "deeper" model with twice as many
transformer blocks. The wider architecture of gpt-oss is better suited for
parallelization, likely resulting in faster inference speeds. Their MoE implementations
also diverge, as gpt-oss uses a few, very large experts. This contrasts with Qwen3
and the general trend towards using many smaller experts for greater specialization.
Interestingly, gpt-oss reintroduces bias units in its attention projection layers, a
feature largely abandoned after GPT-2. It also implements a form of attention sinks to
help stabilize the model's focus during long-context scenarios. These sinks are
learned per-head bias logits rather than special input tokens.
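A hedged sketch of that idea: a learned scalar per head is appended to the attention logits as an extra "sink" column, absorbing probability mass that no real token deserves, and its share is discarded after the softmax. This is an illustration of the concept only, not OpenAI's implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SinkAttention(nn.Module):
    """Attention with a learned per-head sink logit instead of a special input token."""
    def __init__(self, num_heads: int):
        super().__init__()
        self.sink_logit = nn.Parameter(torch.zeros(num_heads))  # one bias logit per head

    def forward(self, q, k, v):                                  # (batch, heads, seq, head_dim)
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)  # (b, h, seq, seq)
        b, h, s, _ = scores.shape
        sink = self.sink_logit.view(1, h, 1, 1).expand(b, h, s, 1)
        # Softmax over real tokens plus the sink column; the sink's share is dropped.
        probs = F.softmax(torch.cat([scores, sink], dim=-1), dim=-1)[..., :-1]
        return probs @ v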
The models underwent extensive training, with the 120B version consuming 2.1 million
H100-hours. This process included both a supervised fine-tuning stage and a
high-compute reinforcement learning stage focused on reasoning. A unique
interactive feature is the ability to control the model's "Reasoning effort" via a system
prompt. Users can select low, medium, or high effort to balance response quality
against computational cost. OpenAI released the models with an MXFP4 quantization
scheme for the MoE experts. This crucial optimization dramatically reduces the
memory footprint, making the models accessible on single GPU setups. The models
are released under the permissive Apache 2.0 license. OpenAI itself correctly labels
them as "open-weight," as the training code and datasets are not included.
The models are too new to appear on leaderboards like the LM Arena. However,
OpenAI's provided benchmarks show gpt-oss is highly competitive on reasoning
tasks. It performs on par with top proprietary models and competitors like Qwen3.
This is particularly impressive as gpt-oss-120b is nearly half the size of the largest
Qwen3 model. In early real-world use, the model appears quite capable but shows a
tendency to hallucinate. This weakness is noted in its model card and is likely a
trade-off for its intense focus on reasoning tasks. The heavy training on STEM, coding,
and puzzles may have led to some "general knowledge forgetting." This limitation
might become less relevant as the model's intended integration with external tools
matures. Such a design prioritizes problem-solving skills over rote memorization,
much like human learning. In a surprising comparison, benchmarks indicate gpt-oss
is not far behind the newly announced, state-of-the-art GPT-5. This highlights the
impressive power and efficiency of OpenAI's open-weight architecture. Even if some
considered the release overhyped, the models are undeniably strong. Their existence
provides excellent new tools for those working with local or privately-hosted models.
The gpt-oss models represent a welcome and significant contribution to the
open-weight ecosystem. Ultimately, their release signals good times ahead for
open-source AI development.
   ● Training Data and Methodology: The models were trained on a dataset with a
      heavy focus on STEM, coding, and general knowledge. More importantly, they
      underwent a post-training phase that included a high-compute reinforcement
      learning (RL) stage, similar to the techniques used for OpenAI's most advanced
      proprietary models. This RL phase specifically teaches the model to apply
      Chain-of-Thought (CoT) reasoning before providing an answer.
   ● Agentic Behavior and Tool Use: gpt-oss is designed for "agentic
      workflows," meaning it can use tools to solve problems. It was trained to
      interleave its reasoning steps with actions like performing web searches for
      up-to-date information or executing Python code in a notebook environment to
      perform calculations or test solutions.
   ● Transparent Reasoning: The models provide full, transparent access to their
      Chain-of-Thought process. This allows developers to see the step-by-step
      logic the model used to arrive at a conclusion, which is crucial for debugging,
      building trust, and understanding the model's problem-solving approach.
gpt-oss includes unique features that allow users to directly control and interact
with its reasoning process.
Kimi K2
A new benchmark has been set in the world of open-weight models with the release
of the Kimi K2 LLM. This model is truly massive, boasting an incredible
one-trillion-parameter architecture. Upon its release, Kimi K2 has immediately taken
the crown as the best available open-weight model. Initial benchmarks indicate that
its performance is exceptionally strong. It reportedly rivals even the best proprietary
LLMs currently on the market. Specifically, its capabilities on coding tasks are shown
to be competitive with leading models like Claude. This level of performance
establishes a new standard for what can be achieved by the open-source community.
The model's debut has understandably generated significant excitement and
discussion among AI researchers and developers. Kimi K2's impressive power signifies
a major leap forward for accessible, state-of-the-art artificial intelligence. Its
existence pushes the boundaries of large-scale model development.
When analyzing its architecture, one discovers that Kimi K2 builds upon a highly
successful existing foundation. Its structure is fundamentally based on the
architecture of the 671-billion-parameter DeepSeek-V3. The two models are
described as being "basically the same," indicating a strategic decision to refine a
proven design. This convergence highlights a trend towards iterating on successful
blueprints for massive-scale models. However, the Kimi K2 team did implement a few
subtle but important modifications. The model was engineered to use more experts
within its Mixture-of-Experts (MoE) modules. This particular tweak likely provides the
model with greater knowledge capacity and allows for more refined specialization.
Conversely, Kimi K2 employs fewer heads in its Multi-head Latent Attention (MLA)
module compared to its predecessor. This adjustment may represent a deliberate
trade-off between attention complexity and other architectural components. These
changes demonstrate how an existing blueprint can be finely tuned for even greater
scale and performance.
Perhaps the most significant innovation in Kimi K2 is not found in its architecture but
in its training methodology. The developers made a crucial and high-stakes decision
regarding the optimization algorithm used during training. They chose to replace the
popular and widely-used AdamW optimizer with their own custom solution. This new
optimizer, named MuonClip, is a modified version of the relatively new Muon optimizer.
According to the development team, their MuonClip optimizer substantially
outperforms AdamW for the specific task of LLM training. This choice represented a
multi-million dollar bet, considering the immense expense of training a model of this
magnitude. The gamble appears to have paid off handsomely, as evidenced by the
model's training data. The resulting training loss curve is described as being the
smoothest ever seen for a large language model. The exceptional stability and
efficiency gained from the MuonClip optimizer are likely a key contributor to Kimi K2's
final success. This focus on innovating the training process itself highlights a critical
new frontier for achieving state-of-the-art results.
Kimi K2's reasoning capabilities are not the result of a single feature, but rather a
combination of its massive architectural scale and a highly advanced training
methodology that pushes the boundaries of how LLMs are optimized.
While the architecture is based on the DeepSeek-V3 blueprint, its immense scale
provides the necessary foundation for complex thought.
The most significant factor behind Kimi's reasoning power is its innovative training
process.
Key Takeaways
From our architectural survey, several clear conclusions emerge:
   ● Sparsity is the New Standard for Scale: The most dominant trend in
      modern LLMs is the adoption of the Mixture-of-Experts (MoE) architecture.
      This "sparse" approach, used by models like DeepSeek, Kimi 2, Llama 4, and
      gpt-oss, allows for a massive increase in a model's total parameters (and thus
      its knowledge capacity) while keeping the computational cost of inference
      manageable by only activating a small fraction of "experts" for any given token.
   ● Attention Mechanisms are Hyper-Optimized for Efficiency: The
      original Multi-Head Attention has been almost universally replaced by more
      efficient alternatives. Grouped-Query Attention (GQA) is the new industry
      standard, offering a great balance of performance and reduced memory usage.
      However, different models employ specialized variants to meet specific goals,
      such as Multi-Head Latent Attention (MLA) for compressing the KV cache
       (DeepSeek, Kimi K2) or Sliding Window Attention for handling very long
      contexts (Gemma, gpt-oss).
   ● Small Architectural Refinements Yield Big Gains: Modern architectures
      are defined by an accumulation of smaller, but significant, efficiency and
      stability improvements. These include replacing LayerNorm with the
      computationally cheaper RMSNorm, and swapping the GELU activation
      function for the more expressive and efficient SwiGLU. The strategic
      placement of these normalization layers has also become a key area of
      innovation for ensuring stable training.
   ● Innovation is Shifting from Core Concepts to Training Methods: While
      architectural tweaks are important, a new frontier for achieving
      state-of-the-art performance is emerging in the training process itself. The
      success of Kimi K2, for example, is attributed not just to its architecture but to
      its use of the novel MuonClip optimizer over the standard AdamW. This shows
      that how a model is trained can be just as innovative and impactful as its
      design.
   ● Model Design is a Game of Trade-offs: There is no single "best"
      architecture; instead, developers make deliberate trade-offs. This is seen in the
      "width versus depth" debate (e.g., the wider gpt-oss vs. the deeper Qwen3),
      the configuration of MoE layers (few large experts vs. many small ones), and
      the choice of attention mechanisms. Each leading model represents a unique
      combination of these techniques, tailored to a specific set of performance and
      efficiency goals.
Conclusion
The era of radical, ground-up architectural redesigns appears to have paused,
replaced by a period of profound and intelligent maturation. The "minor refinements"
seen in modern LLMs are, in fact, the critical enablers of their immense scale and
power. Each choice—be it MoE, GQA, MLA, or a specific normalization strategy—is a
deliberate engineering decision aimed at optimizing the trade-off between a model's
capacity and its computational cost.
The developments of 2025 show that the frontier of AI is being pushed not just by
brute-force scaling, but by the sophisticated and cumulative process of architectural
refinement. We are learning to build not just bigger models, but smarter, more efficient
ones. The foundation laid by the original transformer has proven remarkably robust,
and the ongoing work is a testament to how much performance can be unlocked by
meticulously polishing that foundation. The next breakthroughs may not come from a
new foundation, but from novel combinations of these hard-won refinements.
Gemma Model
Our investigation is structured to mirror the model's own construction. We begin with
its foundational blueprint—the configuration files that specify its core parameters or
"genome." Subsequently, we dissect the primary cognitive engine, analyzing the
transformer blocks, sophisticated attention mechanisms, and normalization
techniques that underpin its linguistic capabilities. We then explore the frontier of
multimodality by examining the Gemma 3 architecture, detailing how its vision system
is integrated to perceive and reason about visual data. Finally, we address the critical
challenge of scale, investigating the software mechanisms for distributed execution
that allow the model to operate on large-scale computational clusters. This analysis is
based on the publicly available source code, reviewed as of August 15, 2025, providing
a concrete and verifiable basis for our exploration (the repository is available at
https://github.com/google/gemma_pytorch/tree/main/gemma).
gemma/config.py
Code: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py
This file functions as the genome for a Gemma model, housing the core instructions
and parameters needed to construct any variant. It guarantees that each component
adheres to precise specifications, yielding a cohesive and operational neural network.
gemma/gemma3_model.py
code: https://github.com/google/gemma_pytorch/blob/main/gemma/gemma3_model.py
This code defines the constructor for the Gemma3ForMultimodalLM class, which
instantiates and assembles the complete computational graph for the Gemma 3
model. The initialization process is methodical, first establishing the language
processing backbone before integrating the specialized vision components.
Initially, the constructor sets up the model's linguistic foundation by initializing the
standard modules from the base text-only architecture: the tokenizer, the
text_token_embedder, the core stack of transformer blocks (model), and the token
sampler. Subsequently, it integrates the visual processing pipeline. This involves
instantiating the siglip_vision_model, a powerful Vision Transformer that functions as a
feature extractor, converting raw image data into a high-dimensional semantic
representation. To ensure compatibility between the two modalities, the visual
embeddings are first stabilized by an RMSNorm layer (mm_soft_embedding_norm)
and then transformed by a linear mm_input_projection layer. This projection acts as an
adapter, mapping the visual embeddings into the same latent space as the text
embeddings, which is a critical prerequisite for their fusion.
gemma/gemma3_model.py
import torch
import os
import json
import gc
from torch import nn
from PIL import Image
from typing import Any, List, Sequence, Tuple, Union
# (the file's imports of the gemma config, model, tokenizer, preprocessor, and
# SigLIP vision modules are omitted in this excerpt)


class Gemma3ForMultimodalLM(nn.Module):
  """Gemma3 model for multimodal causal LM."""

  def __init__(
      self,
      config: gemma_config.GemmaConfig,
  ):
    super().__init__()
    self.dtype = config.get_dtype()
    assert config.architecture == gemma_config.Architecture.GEMMA_3
    self.config = config
    max_seq_len = config.max_position_embeddings
    head_dim = config.head_dim
    vocab_size = config.vocab_size
    # Language backbone: tokenizer, token embedder, transformer stack, sampler.
    self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
    self.text_token_embedder = gemma_model.Embedding(
        vocab_size, config.hidden_size, config.quant)
    self.model = gemma_model.GemmaModel(config)
    self.sampler = gemma_model.Sampler(vocab_size, config)

    # Vision pipeline: SigLIP feature extractor plus the adapter layers that map
    # image embeddings into the text embedding space.
    if config.vision_config is None:
      raise ValueError('vision_config must be provided for Gemma3.')
    self.siglip_vision_model = siglip_vision_model.SiglipVisionModel(
        config.vision_config)
    # transformer/embedder/mm_soft_embedding_norm
    self.mm_soft_embedding_norm = gemma_model.RMSNorm(
        config.vision_config.embedding_dim, eps=config.rms_norm_eps)
    # transformer/embedder/mm_input_projection
    self.mm_input_projection = gemma_model.Linear(
        config.vision_config.embedding_dim, config.hidden_size, config.quant)

    # Precompute RoPE frequencies, separately for the local (sliding-window)
    # and global attention layers.
    if config.rope_wave_length is None:
      raise ValueError('rope_wave_length must be provided for Gemma3.')
    rope_lengths = config.rope_wave_length
    defaults = {
        gemma_config.AttentionType.LOCAL_SLIDING: 10_000,
        gemma_config.AttentionType.GLOBAL: 10_000,
    }
    self._register_freqs_cis(
        'local_freqs_cis', head_dim, max_seq_len,
        theta=rope_lengths.get(
            gemma_config.AttentionType.LOCAL_SLIDING,
            defaults[gemma_config.AttentionType.LOCAL_SLIDING]))
    self._register_freqs_cis(
        'global_freqs_cis', head_dim, max_seq_len,
        theta=rope_lengths.get(
            gemma_config.AttentionType.GLOBAL,
            defaults[gemma_config.AttentionType.GLOBAL]),
        rope_scaling_factor=config.rope_scaling_factor)

  def _register_freqs_cis(
      self, name: str, head_dim: int, max_seq_len: int, theta: int = 10_000,
      rope_scaling_factor: int = 1,
  ):
    self.register_buffer(
        name,
        gemma_model.precompute_freqs_cis(
            head_dim, max_seq_len * 2, theta=theta,
            rope_scaling_factor=rope_scaling_factor))
The forward method transforms input tensors into the next token prediction,
processing text and images in parallel. It begins by obtaining Rotary Positional
Embeddings (RoPE) for token positioning. Text input IDs are converted into vector
embeddings. If images are present, the siglip_vision_model extracts conceptual
embeddings, which are then normalized and projected to match text embedding
dimensions. The populate_image_embeddings method then fuses these text and
image embeddings into a single multimodal representation. This combined sequence
is fed into the main transformer (self.model) for reasoning. Finally, a sampler
computes probability scores and selects the most likely next token, returning its ID
and raw logits.
  @torch.no_grad()
  def forward(self,
              input_token_ids: torch.Tensor,  # B x L
              image_patches: torch.Tensor,  # B x N x C x H x W (3x896x896)
              image_presence_mask: torch.Tensor,  # B x N
              input_positions: torch.Tensor,
              kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
              mask: torch.Tensor,
              output_positions: torch.Tensor,
              temperatures: Union[torch.Tensor, None],
              top_ps: torch.Tensor,
              top_ks: torch.Tensor,
              local_mask: torch.Tensor | None = None,
              **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    # Gather RoPE frequencies for the requested positions, separately for the
    # local (sliding-window) and global attention layers.
    freqs_cis = {}
    freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
        self.local_freqs_cis.index_select(0, input_positions)
    )
    freqs_cis[gemma_config.AttentionType.GLOBAL] = (
        self.global_freqs_cis.index_select(0, input_positions)
    )
    hidden_states = self.text_token_embedder(input_token_ids)
    normalizer = torch.tensor(self.config.hidden_size**0.5,
                              dtype=hidden_states.dtype,
                              device=hidden_states.device)
    hidden_states = hidden_states * normalizer
    if image_patches is not None and self.config.vision_config is not None:
      # The input contains images.
      B, N, C, H, W = image_patches.shape
      # Flatten and pass through the SiglipVisionModel feature extractor.
      flattened_input = image_patches.reshape(B * N, C, H, W)  # (B*N) x C x H x W
      image_embeddings = self.siglip_vision_model(flattened_input)  # (B*N) x U x D
      image_embeddings = self.mm_soft_embedding_norm(image_embeddings)  # (B*N) x U x D
      image_embeddings = self.mm_input_projection(image_embeddings)  # (B*N) x U x model_dim
      hidden_states = self.populate_image_embeddings(
          hidden_states.clone(),
          image_embeddings.clone(),
          input_token_ids.clone(),
          image_presence_mask.clone(),
      )

    kv_write_indices = input_positions
    hidden_states = self.model(
        hidden_states=hidden_states,
        freqs_cis=freqs_cis,
        kv_write_indices=kv_write_indices,
        kv_caches=kv_caches,
        mask=mask,
        local_mask=local_mask,
    )
    embedder_weight = self.text_token_embedder.weight
    if self.config.quant:
      embedder_weight = (
          embedder_weight * self.text_token_embedder.weight_scaler.unsqueeze(-1))
    # The final hidden states go to the sampler, which returns the chosen
    # next-token IDs together with the raw logits (see the Sampler class below).
    next_tokens, logits = self.sampler(
        embedding=embedder_weight,
        hidden_states=hidden_states,
        output_positions=output_positions,
        temperatures=temperatures,
        top_ps=top_ps,
        top_ks=top_ks,
    )
    return next_tokens, logits
  def populate_image_embeddings(self,
                                hidden_states: torch.Tensor,  # B x L x model_dim
                                image_embeddings: torch.Tensor,  # (B*N) x U x model_dim
                                input_token_ids: torch.Tensor,  # B x L
                                image_presence_mask: torch.Tensor,  # B x N
                                ):
    batch_size, seq_len, model_dim = hidden_states.shape
    # Step 1 of 2: fetch valid image embeddings.
    # Flatten indices of valid image embeddings.
    valid_image_embeddings_indices = torch.nonzero(
        image_presence_mask.flatten(), as_tuple=False).squeeze()
    # num_valid_images x model_dim
    valid_image_embeddings = image_embeddings.index_select(
        0, valid_image_embeddings_indices)
    # (Step 2, which scatters these embeddings into the positions of the image
    # placeholder tokens, is omitted in this excerpt.)
This function implements a hybrid attention mask strategy for multimodal models,
combining causal attention for text and bidirectional attention for images. It starts with
a causal mask, identifies and groups image tokens, and then creates a bidirectional
mask for these image blocks. The final mask is an OR combination of the causal and
bidirectional masks. A local mask is also applied for sliding window attention. This
comprehensive masking provides precise guidance for processing complex
multimodal inputs.
    # image_token_mask:
    # [[False, False, True, True, False],
    #  [True, True, False, True, True]]
    # numbered_boundary:
    # [[0, 0, 1, 1, 1],
    #  [1, 1, 1, 2, 2]]
    # q_block_indices:
    # [[0, 0, 1, 1, 0],
    #  [1, 1, 0, 2, 2]]
    q_block_indices = image_token_mask * numbered_boundary
    kv_block_indices = q_block_indices
    # Test the equality of vertical and horizontal numbered patches
    # to create the bidirectional mask.
    bidirectional_mask = torch.logical_and(
        kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1),
        q_block_indices.unsqueeze(-1) > 0,
    )
    attention_mask = torch.logical_or(causal_mask,
                                      bidirectional_mask.unsqueeze(1))
    # The upper triangular matrix's diagonal is shifted by the sliding window
    # size before doing a logical 'and' with the attention mask. This keeps
    # the local attention within the sliding window.
    local_mask = torch.logical_and(
        attention_mask,
        torch.triu(torch.ones((1, 1, sequence_length, sequence_length),
                              dtype=torch.bool, device=input_ids.device),
                   diagonal=-(self.config.sliding_window_size - 1))
    )
    return attention_mask, local_mask
The generate method is the main entry point for a user. It takes a raw prompt, which
can be a complex sequence of text and images, and manages the entire end-to-end
process of generating a textual response.
   ● Input Preprocessing: The first step is to delegate the complex task of input
      preparation to the gemma3_preprocessor. This function handles tokenizing the
      text, processing images using the "pan-and-scan" technique, inserting special
      image tokens, and padding all sequences to create uniform batches.
   ● Attention Mask Creation: The self.create_attention_mask method is
      subsequently invoked to construct a sophisticated hybrid mask. This mask
      facilitates causal attention for text and bidirectional attention for images.
      Boolean values are then converted to floating-point values within this mask,
      with disallowed positions being populated by a large negative number.
   ● KV Cache Initialization: The KV Cache is a vital optimization that significantly
      enhances the efficiency of text generation. It works by pre-allocating memory
      for the Key (K) and Value (V) tensors within each layer of the model. This
      pre-storage eliminates the need to recompute these tensors for the entire
      context at every generation step.
The Two-Phase Generation Loop: The core of the generate method is a loop that
implements autoregressive decoding. This process is split into two distinct phases for
maximum efficiency:
   ● Phase 1: Prefill: During the initial loop iteration, the model efficiently processes
      the complete input prompt, including both text and images, in a single, parallel
      forward pass. This action populates the KV cache with context derived from the
      user's prompt and involves passing the image_batch to the model.
   ● Phase 2: Decode: After the initial prefill step, the model efficiently generates
      one token at a time in subsequent iterations. This process is very fast during
      the decoding phase because each step only requires processing the single
      token generated previously. A crucial memory optimization involves discarding
      the image_batch after the prefill step, as its information is already encoded
      within the KV cache.
After the loop completes, the final sequence of token IDs is converted back into a
human-readable string by the tokenizer, and the result is returned.
  def generate(
      self,
      prompts: Sequence[Sequence[Union[str, Image.Image]]],
      device: Any,
      output_len: int = 100,
      temperature: Union[float, None] = 1.0,
      top_p: float = 0.95,
      top_k: int = 64,
  ) -> Sequence[str]:
    """Generates responses for given prompts using Gemma model."""
    # Inference only.
    processing_result = gemma3_preprocessor.tokenize_raw_input(
        self.tokenizer, prompts, self.config, output_len, device
    )
    batch_size = processing_result["batch_size"]
    user_input_token_ids = processing_result["user_input_token_ids"]
    image_batch = processing_result["image_batch"]
    min_prompt_len = processing_result["min_prompt_len"]
    max_prompt_len = processing_result["max_prompt_len"]
    total_seq_len = processing_result["max_seq_len"]
    image_presence_mask = processing_result["image_presence_mask"]

    # Pre-allocate one (K, V) cache tensor pair per transformer layer.
    kv_caches = []
    for _ in range(self.config.num_hidden_layers):
      size = (batch_size, total_seq_len, self.config.num_key_value_heads,
              self.config.head_dim)
      dtype = self.config.get_dtype()
      k_cache = torch.zeros(size=size, dtype=dtype, device=device)
      v_cache = torch.zeros(size=size, dtype=dtype, device=device)
      kv_caches.append((k_cache, v_cache))

    # ... (attention-mask construction and the prefill pass over the full prompt
    # are omitted in this excerpt) ...

    # Inside the autoregressive decode loop, each step feeds back only the token
    # generated in the previous step:
      input_token_ids_tensor = output_token_ids
      input_positions_tensor = output_index.unsqueeze(dim=-1)
      curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
      curr_local_mask_tensor = local_mask_tensor.index_select(
          2, input_positions_tensor
      ) if local_mask_tensor is not None else None
      output_positions_tensor = torch.tensor(0, dtype=torch.int64, device=device)
      output_index = output_index + 1
      # After prefill, the image tensors are no longer needed; their information
      # is already encoded in the KV cache.
      image_batch = None
      image_presence_mask = None

    # Detokenization.
    token_ids = token_ids_tensor.tolist()
    results = []
    for i, tokens in enumerate(token_ids):
      output = tokens
      if self.tokenizer.eos_id in output:
        eos_index = output.index(self.tokenizer.eos_id)
        output = output[:eos_index]
      results.append(self.tokenizer.decode(output))

    return results
gemma/gemma3_preprocessor.py
code: https://github.com/google/gemma_pytorch/blob/main/gemma/gemma3_preprocessor.py
This code prepares mixed text and image inputs for the Gemma 3 multimodal model
by transforming them into formatted tensors. The gemma3_input_preprocessor
function uses a "pan-and-scan" technique to generate multiple high-resolution crops
from the original image. Both original and cropped images are processed by a SigLIP
vision preprocessor, converting them to torch.Tensor objects and providing holistic
and detailed views with textual context. The tokenize_raw_input function orchestrates
batch processing. Text is tokenized into integer IDs. For images, special tokens (<boi>,
<image> placeholders, <eoi>) are inserted, and the tensor is stored separately. The
function pads token sequences and image lists to uniform sizes, converting them into
torch.Tensor objects and creating an image_presence_mask. The output is a
dictionary containing all necessary tensors for the Gemma 3 model's forward pass.
gemma/model.py
code: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py
This file provides the complete implementation of the text-only Gemma models. It
defines the core modules of the transformer architecture, how they fit together, and
the high-level interface for generating text. It's the cognitive engine that powers
Gemma's language capabilities.
The Sampler class is responsible for the final token selection process in the language
model. Its forward method takes the model's final hidden states and calculates logits
by performing a matrix multiplication with the embedding matrix. It supports an
optional final_logit_softcapping to stabilize training by limiting the range of logit
values. If no temperature is provided, it performs greedy decoding by simply selecting
the token with the highest logit. Otherwise, it applies temperature scaling, top-k, and
top-p filtering to the probabilities derived from the logits. Top-k limits the sampling
pool to the k most likely tokens, while top-p filters by cumulative probability. After
filtering, the probabilities are renormalized, and the final token is chosen via
multinomial sampling.
gemma/model.py
import json
import gc
import os
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union, Mapping
from gemma import config as gemma_config
from gemma import tokenizer
class Sampler(nn.Module):
  # (the __init__, which stores the vocab size and model config, is omitted here)

  @torch.no_grad()
  def forward(
      self,
      embedding: torch.Tensor,
      hidden_states: torch.Tensor,
      output_positions: torch.Tensor,
      temperatures: Union[torch.Tensor, None],
      top_ps: torch.Tensor,
      top_ks: torch.Tensor,
      embedding_bias: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
    # Select the last element for each sequence.
    # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
    hidden_states = hidden_states.index_select(
        1, output_positions).squeeze(dim=1)
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
      logits += embedding_bias
    if self.config.final_logit_softcapping is not None:
      logits = logits / self.config.final_logit_softcapping
      logits = torch.tanh(logits)
      logits = logits * self.config.final_logit_softcapping
    if temperatures is None:
      # Greedy decoding: pick the single highest-scoring token.
      return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits

    # ... (temperature scaling, softmax, sorting, and top-p filtering are omitted
    # in this excerpt; they produce the sorted probs_sort and probs_idx tensors) ...

    # Top-k filtering: zero out everything beyond the k most likely tokens.
    top_ks_mask = torch.arange(probs_idx.shape[-1],
                               device=probs_idx.device)
    top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
    top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
    probs_sort = torch.where(top_ks_mask, 0, probs_sort)
    # Re-normalization.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    probs = torch.gather(probs_sort,
                         dim=-1,
                         index=torch.argsort(probs_idx, dim=-1))
    next_token_ids = torch.multinomial(probs,
                                       num_samples=1,
                                       replacement=True).squeeze(dim=-1)
    return next_token_ids, logits
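Since the excerpt above omits the temperature and top-p steps, the following compact, self-contained sketch shows the full filtering pipeline the prose describes (temperature scaling, top-p, then top-k, then renormalized multinomial sampling). It is an independent illustration over a single logit vector, not the Gemma code; the function name and defaults are made up.

import torch

def sample_next_token(logits, temperature=1.0, top_k=64, top_p=0.95):
    """Temperature, top-p (nucleus), and top-k filtering over a (vocab,) logit vector."""
    probs = torch.softmax(logits / temperature, dim=-1)
    probs_sort, probs_idx = torch.sort(probs, descending=True)
    # Top-p: drop tokens once the cumulative probability already exceeds the threshold.
    cumulative = torch.cumsum(probs_sort, dim=-1)
    probs_sort[cumulative - probs_sort > top_p] = 0.0
    # Top-k: keep only the k most likely tokens.
    probs_sort[top_k:] = 0.0
    probs_sort /= probs_sort.sum()
    # Sample in sorted space, then map back to the original vocabulary index.
    choice = torch.multinomial(probs_sort, num_samples=1)
    return probs_idx[choice]

next_id = sample_next_token(torch.randn(32_000))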
The Linear class defines a custom linear transformation layer. A key feature of this
class is its support for 8-bit quantization. If the quant flag is set to True, the weights
are stored as 8-bit integers (torch.int8), and a separate floating-point weight_scaler
tensor is also stored. During the forward pass, if quantized, the weights are
de-quantized on-the-fly by multiplying them with the scaler before the standard
F.linear operation is performed. This allows for significant memory savings with
minimal impact on performance. If not quantized, it behaves like a standard linear
layer.
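To make that path concrete, below is a minimal sketch of such a quantization-aware linear layer. It mirrors the quant flag and weight_scaler naming used in the description above, but it is an illustration rather than a copy of the repository's Linear class.

import torch
from torch import nn
import torch.nn.functional as F

class QuantizedLinear(nn.Module):
    """Linear layer with optional int8 weight storage (illustrative sketch)."""

    def __init__(self, in_features: int, out_features: int, quant: bool):
        super().__init__()
        self.quant = quant
        if quant:
            # Weights kept as int8 plus one float scale per output row;
            # values are expected to be loaded from a checkpoint.
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features), dtype=torch.int8),
                requires_grad=False)
            self.weight_scaler = nn.Parameter(torch.ones(out_features),
                                              requires_grad=False)
        else:
            self.weight = nn.Parameter(torch.empty(out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        if self.quant:
            # De-quantize on the fly: int8 weights times per-row float scale.
            weight = weight * self.weight_scaler.unsqueeze(-1)
        return F.linear(x, weight)

Storing a single float scale per output row keeps de-quantization to one element-wise multiply before the standard F.linear call.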
class Embedding(nn.Module):
    ...  # Body elided; like Linear, it holds an optional int8 weight plus a weight_scaler.

class RMSNorm(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, then scale by (1 + weight) since weight starts at zero.
        output = self._norm(x.float())
        scale = (1 + self.weight.float()) if self.add_unit_offset else self.weight.float()
        return (output * scale).type_as(x)
The GemmaMLP class constitutes the feed-forward network (FFN) block found within
each transformer decoder layer. It implements a variant of the Gated Linear Unit (GLU)
architecture. The input tensor x is passed through two separate linear layers:
gate_proj and up_proj. The output of the gate_proj is then passed through a GELU
activation function. This activated gate is then multiplied element-wise with the
output of the up_proj. Finally, the result is passed through a down_proj linear layer to
project it back to the model's hidden dimension.
class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant: bool,
    ):
        super().__init__()
        self.gate_proj = Linear(hidden_size, intermediate_size, quant)
        self.up_proj = Linear(hidden_size, intermediate_size, quant)
        self.down_proj = Linear(intermediate_size, hidden_size, quant)

    def forward(self, x):
        # GLU variant described above: GELU-activated gate times the up projection,
        # then projected back down to the hidden size.
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        up = self.up_proj(x)
        return self.down_proj(gate * up)
class GemmaAttention(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        attn_type: gemma_config.AttentionType,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.qkv_proj = Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
            quant=config.quant)
        self.o_proj = Linear(
            self.num_heads * self.head_dim, self.hidden_size,
            quant=config.quant)
        self.query_norm = (
            RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            if config.use_qk_norm
            else None
        )
        self.key_norm = (
            RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            if config.use_qk_norm
            else None
        )
        self.attn_type = attn_type
        self.sliding_window_size = config.sliding_window_size
        self.attn_logit_softcapping = config.attn_logit_softcapping
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
        local_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3
        batch_size, input_len, _ = hidden_states_shape
        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        # Positional embedding.
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
        # Write the new keys and values into this layer's KV cache.
        k_cache, v_cache = kv_cache
        k_cache.index_copy_(1, kv_write_indices, xk)
        v_cache.index_copy_(1, kv_write_indices, xv)
        key = k_cache
        value = v_cache
        if self.num_kv_heads != self.num_heads:
            # Grouped-query attention: repeat each KV head for its query group.
            # [batch_size, max_seq_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)
        # [...] attention scores, softcapping/sliding-window masking, and o_proj are elided.
The GemmaDecoderLayer class defines the structure of a single transformer block for
the Gemma 1 architecture. It follows a standard "pre-normalization" setup. The
forward method first computes a residual, then applies RMSNorm (input_layernorm)
to the input, which is then processed by the GemmaAttention block (self_attn). The
output of the attention block is added back to the initial residual. This is followed by a
second residual connection around the MLP block: the result from the attention stage
is normalized again (post_attention_layernorm), passed through the GemmaMLP, and
finally added to its own residual.
class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.attn_type = gemma_config.AttentionType.GLOBAL
        self.self_attn = GemmaAttention(
            config=config,
            attn_type=self.attn_type)
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
        local_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states
        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
class Gemma2DecoderLayer(nn.Module):
    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        attn_type: gemma_config.AttentionType,
    ):
        super().__init__()
        self.attn_type = attn_type
        self.self_attn = GemmaAttention(
            config=config,
            attn_type=self.attn_type,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.use_pre_ffw_norm
            else None
        )
        self.post_feedforward_layernorm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.use_post_ffw_norm
            else None
        )
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
        local_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
            local_mask=local_mask,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        # MLP
        residual = hidden_states
        if self.pre_feedforward_layernorm is not None:
            hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if self.post_feedforward_layernorm is not None:
            hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
The GemmaModel class represents the core transformer architecture, assembling the
stack of decoder layers. During initialization, it creates a ModuleList of decoder layers,
dynamically choosing between GemmaDecoderLayer for Gemma 1 and
Gemma2DecoderLayer for Gemma 2 and 3 based on the provided configuration. For
Gemma 2/3, it also assigns the attention type (local or global) to each layer. The
forward method sequentially processes the input hidden_states through each layer,
passing along the necessary RoPE frequencies, KV cache for that layer, and attention
masks. After the final layer, a concluding RMSNorm is applied to the output.
class GemmaModel(nn.Module):

    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            if config.architecture == gemma_config.Architecture.GEMMA_1:
                self.layers.append(GemmaDecoderLayer(config))
            elif config.architecture in (
                gemma_config.Architecture.GEMMA_2,
                gemma_config.Architecture.GEMMA_3,
            ):
                attn_type = (
                    config.attn_types[i % len(config.attn_types)]
                    if config.attn_types is not None
                    else gemma_config.AttentionType.GLOBAL
                )
                self.layers.append(Gemma2DecoderLayer(config, attn_type))
            else:
                raise ValueError(
                    f'Unknown architecture: {config.architecture}')
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: Mapping[gemma_config.AttentionType, torch.Tensor],
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        local_mask: torch.Tensor,
    ) -> torch.Tensor:
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis.get(layer.attn_type),
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
                local_mask=local_mask,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.config = config
        assert config.hidden_size % config.num_attention_heads == 0
        max_seq_len = config.max_position_embeddings
        head_dim = config.head_dim
        vocab_size = config.vocab_size
        self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
        self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
        self.model = GemmaModel(config)
        self.sampler = Sampler(vocab_size, config)
        if config.architecture == gemma_config.Architecture.GEMMA_3:
            # Gemma 3 keeps separate rotary tables for local sliding-window and
            # global attention layers (base-frequency arguments elided here).
            self._register_freqs_cis('local_freqs_cis', head_dim, max_seq_len)
            self._register_freqs_cis('global_freqs_cis', head_dim, max_seq_len)
        else:
            self._register_freqs_cis('freqs_cis', head_dim, max_seq_len)

    def _register_freqs_cis(
        self, name: str, head_dim: int, max_seq_len: int, theta: int = 10_000
    ):
        self.register_buffer(
            name, precompute_freqs_cis(head_dim, max_seq_len * 2, theta=theta)
        )
    @torch.no_grad()
    def forward(
        self,
        input_token_ids: torch.Tensor,
        input_positions: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: Union[torch.Tensor, None],
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        local_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Look up the RoPE frequencies for the requested positions.
        freqs_cis = {}
        if self.config.architecture == gemma_config.Architecture.GEMMA_3:
            freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
                self.local_freqs_cis.index_select(0, input_positions)
            )
            freqs_cis[gemma_config.AttentionType.GLOBAL] = (
                self.global_freqs_cis.index_select(0, input_positions)
            )
        else:
            freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
                self.freqs_cis.index_select(0, input_positions)
            )
            freqs_cis[gemma_config.AttentionType.GLOBAL] = (
                self.freqs_cis.index_select(0, input_positions)
            )
        kv_write_indices = input_positions
        # Embed the input tokens; Gemma scales embeddings by sqrt(hidden_size).
        hidden_states = self.embedder(input_token_ids)
        normalizer = torch.tensor(self.config.hidden_size**0.5,
                                  dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
        hidden_states = self.model(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_caches=kv_caches,
            mask=mask,
            local_mask=local_mask,
        )
        embedder_weight = self.embedder.weight
        if self.config.quant:
            # De-quantize the tied embedding matrix before computing logits.
            embedder_weight = (
                embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
        next_tokens, logits = self.sampler(
            embedding=embedder_weight,
            hidden_states=hidden_states,
            output_positions=output_positions,
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
        )
        return next_tokens, logits
The generate method provides the main entry point for users to perform text
generation. It orchestrates the entire auto-regressive decoding process. First, it
tokenizes the input prompts and initializes the KV caches and attention masks. The
core of the method is a loop that runs for the desired output length. This loop
implements a two-phase generation strategy: a "prefill" phase where the input prompt
is processed in a single large batch to populate the KV cache, followed by a "decode"
phase where tokens are generated one by one. In each step of the loop, it calls the
forward method to get the next token, updates the overall sequence of tokens, and
prepares the inputs for the next iteration. Finally, after the loop finishes, it detokenizes
the completed sequences and returns the generated text.
    def generate(
        self,
        prompts: Union[str, Sequence[str]],
        device: Any,
        output_len: int = 100,
        temperature: Union[float, None] = 1.0,
        top_p: float = 0.95,
        top_k: int = 64,
    ) -> Union[str, Sequence[str]]:
        """Generates responses for given prompts using Gemma model."""
        # If a single prompt is provided, treat it as a batch of 1.
        is_str_prompt = isinstance(prompts, str)
        if is_str_prompt:
            prompts = [prompts]
        batch_size = len(prompts)
        prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
        min_prompt_len = min(len(p) for p in prompt_tokens)
        max_prompt_len = max(len(p) for p in prompt_tokens)
        max_seq_len = max_prompt_len + output_len
        assert max_seq_len <= self.config.max_position_embeddings
        # Build KV caches.
        kv_caches = []
        for _ in range(self.config.num_hidden_layers):
            size = (batch_size, max_seq_len, self.config.num_key_value_heads,
                    self.config.head_dim)
            dtype = self.config.get_dtype()
            k_cache = torch.zeros(size=size, dtype=dtype, device=device)
            v_cache = torch.zeros(size=size, dtype=dtype, device=device)
            kv_caches.append((k_cache, v_cache))
        # Prepare inputs.
        token_ids_tensor = torch.full((batch_size, max_seq_len),
                                      self.tokenizer.pad_id,
                                      dtype=torch.int64)
        input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                            self.tokenizer.pad_id,
                                            dtype=torch.int64)
        for i, p in enumerate(prompt_tokens):
            token_ids_tensor[i, :len(p)] = torch.tensor(p)
            input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
                p[:min_prompt_len])
        token_ids_tensor = token_ids_tensor.to(device)
        input_token_ids_tensor = input_token_ids_tensor.to(device)
        prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
        input_positions_tensor = torch.arange(0, min_prompt_len,
                                              dtype=torch.int64).to(device)
        # Causal mask (and optional sliding-window mask) of large negative values.
        mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
                                 -2.3819763e38).to(torch.float)
        mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
        local_mask_tensor = mask_tensor + torch.tril(
            torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38,
                       device=device),
            diagonal=-self.config.sliding_window_size,
        ) if self.config.sliding_window_size else None
        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        curr_local_mask_tensor = local_mask_tensor.index_select(
            2, input_positions_tensor
        ) if local_mask_tensor is not None else None
        output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
        # Prefill on the first iteration, then decode one token at a time.
        for i in range(max_seq_len - min_prompt_len):
            next_token_ids, _ = self(
                input_token_ids=input_token_ids_tensor,
                input_positions=input_positions_tensor,
                kv_write_indices=None,  # derived from input_positions in forward
                kv_caches=kv_caches,
                mask=curr_mask_tensor,
                output_positions=output_positions_tensor,
                temperatures=temperatures_tensor,
                top_ps=top_ps_tensor,
                top_ks=top_ks_tensor,
                local_mask=curr_local_mask_tensor,
            )
            # Keep prompt tokens where they exist; otherwise use the sampled token.
            curr_prompt_mask = prompt_mask_tensor.index_select(
                1, output_index).squeeze(dim=1)
            curr_token_ids = token_ids_tensor.index_select(
                1, output_index).squeeze(dim=1)
            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                           next_token_ids).unsqueeze(dim=1)
            token_ids_tensor.index_copy_(1, output_index, output_token_ids)
            # Prepare inputs for the next single-token decode step.
            input_token_ids_tensor = output_token_ids
            input_positions_tensor = output_index.unsqueeze(dim=-1)
            curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
            curr_local_mask_tensor = local_mask_tensor.index_select(
                2, input_positions_tensor
            ) if local_mask_tensor is not None else None
            output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
            output_index = output_index + 1
        # Detokenization.
        token_ids = token_ids_tensor.tolist()
        results = []
        for i, tokens in enumerate(token_ids):
            trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
                                    + output_len]
            if self.tokenizer.eos_id in trimmed_output:
                eos_index = trimmed_output.index(self.tokenizer.eos_id)
                trimmed_output = trimmed_output[:eos_index]
            results.append(self.tokenizer.decode(trimmed_output))
        # Return a single string if a single prompt was given.
        return results[0] if is_str_prompt else results
The utility method, load_weights, is designed to load pre-trained weights into the
model. It handles two scenarios: loading from a single checkpoint file and loading
from a sharded checkpoint, where weights are split across multiple files. It first checks
if the provided model_path is a file. If so, it loads the state dictionary directly. If the
path is a directory, it looks for an index JSON file (pytorch_model.bin.index.json) that
maps weight names to their respective shard files. It then iterates through the unique
shard files, loading each one's state dictionary into the model. To conserve memory
during this process, it deletes each state dictionary and calls the garbage collector
after loading.
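A minimal sketch of that two-path loading logic is shown below. It follows the description above (a single checkpoint file versus a pytorch_model.bin.index.json that maps weights to shards); the helper name and the strict=False handling are illustrative assumptions, not the repository's exact code.

import gc
import json
import os
import torch

def load_weights(model, model_path: str):
    """Load a single checkpoint file or a sharded checkpoint directory (sketch)."""
    if os.path.isfile(model_path):
        # Scenario 1: a single checkpoint file.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
        del state_dict
        gc.collect()
        return
    # Scenario 2: sharded checkpoint with an index mapping names to shard files.
    index_path = os.path.join(model_path, 'pytorch_model.bin.index.json')
    with open(index_path, 'r') as f:
        index = json.load(f)
    shard_files = sorted(set(index['weight_map'].values()))
    for shard_file in shard_files:
        state_dict = torch.load(os.path.join(model_path, shard_file),
                                map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
        # Free memory between shards.
        del state_dict
        gc.collect()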
gemma/model_xla.py
Code: https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py.
This code implements distributed Gemma model execution using Tensor Parallelism,
optimized for Google's TPUs via XLA. It adapts the standard Gemma architecture to
efficiently utilize multiple accelerator devices, addressing the challenge of models too
large for a single accelerator's memory. Tensor parallelism is employed to partition, or
"shard," the model's large weight matrices across a group of devices. Each device
stores only a segment of the model's weights and performs a corresponding fraction
of the total computation. The entire Gemma model is rebuilt using specialized parallel
layers defined in xla_model_parallel.py, including ColumnParallelLinear,
RowParallelLinear, and ParallelEmbedding. These layers manage the complex
communication necessary to synchronize computations across all devices, making the
distributed nature of the model transparent to the end-user.
Core Parallel Components: The code re-implements the standard Gemma modules
to be aware of the distributed environment.
The Top-Level Parallel Model: The GemmaForCausalLM class is the main entry
point. It constructs the entire model using the parallel components.
Let’s have a look at specific parts. The GemmaAttention class implements the
self-attention mechanism for a tensor-parallel environment. During initialization, it
calculates the number of attention heads and key-value heads per device
(num_heads, num_kv_heads). The QKV projection is handled by a single
ColumnParallelLinear layer, which shards the combined QKV weight matrix across
devices. Conversely, the output projection (o_proj) is a RowParallelLinear layer that
gathers and sums the partial results from each device's attention heads. The forward
pass applies rotary embeddings and manages the KV cache locally on each device,
while the parallel layers handle the necessary communication implicitly.
gemma/model_xla.py
[...]
class GemmaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        attn_logit_softcapping: Optional[float],
        query_pre_attn_scalar: Optional[int],
        head_dim: int,
        world_size: int,
        rank: int,
        quant: bool,
        attn_type: gemma_config.AttentionType,
        sliding_window_size: Optional[int] = None,
    ):
        super().__init__()
        self.rank = rank

        def init_method(x):
            return x

        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size  # heads per shard
        self.total_num_kv_heads = num_kv_heads
        # KV heads are likewise divided across the tensor-parallel shards.
        self.num_kv_heads = max(self.total_num_kv_heads // world_size, 1)
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.qkv_proj = ColumnParallelLinear(
            self.hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=False,
            gather_output=False,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
            quant=quant,
        )
        self.attn_type = attn_type
        self.sliding_window_size = sliding_window_size
        self.attn_logit_softcapping = attn_logit_softcapping
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3
        batch_size, input_len, _ = hidden_states_shape
        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        # Positional embedding.
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
        # Each device writes its local shard of keys/values into the KV cache.
        k_cache, v_cache = kv_cache
        k_cache.index_copy_(1, kv_write_indices, xk)
        v_cache.index_copy_(1, kv_write_indices, xv)
        key = k_cache
        value = v_cache
        if self.num_kv_heads != self.num_heads:
            # [batch_size, max_seq_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)
        # [...] the attention computation and the RowParallelLinear o_proj follow.
gemma/tokenizer.py
Code: https://github.com/google/gemma_pytorch/blob/main/gemma/tokenizer.py.
This script provides the tokenizer for Gemma models, converting text to numerical IDs
and vice-versa. The Tokenizer class, a wrapper around Google's SentencePiece
library, loads a pre-trained model for tokenization. It identifies special tokens like
bos_id (Beginning of Sequence), eos_id (End of Sequence), pad_id (for padding), and
boi_id/eoi_id (Beginning/End of Image) for multimodal Gemma 3, with pad_id also
serving as image_token_placeholder_id. The encode method converts strings to token
IDs (with optional bos/eos addition), and decode converts IDs back to strings.
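The sketch below shows the general shape of such a SentencePiece wrapper. The class and method names mirror the description (encode, decode, bos_id, eos_id, pad_id), but it is a simplified illustration, not the repository's Tokenizer class.

import sentencepiece as spm

class SimpleTokenizer:
    """Minimal SentencePiece wrapper (illustrative sketch)."""

    def __init__(self, model_path: str):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)
        self.bos_id = self.sp.bos_id()
        self.eos_id = self.sp.eos_id()
        self.pad_id = self.sp.pad_id()

    def encode(self, text: str, add_bos: bool = True, add_eos: bool = False):
        # Convert a string into a list of token IDs, optionally framed by BOS/EOS.
        ids = self.sp.EncodeAsIds(text)
        if add_bos:
            ids = [self.bos_id] + ids
        if add_eos:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, ids):
        # Convert a list of token IDs back into a string.
        return self.sp.DecodeIds(ids)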
gemma/xla_model_parallel.py
This script provides the essential toolkit for running Gemma models at scale using a
technique called Tensor Model Parallelism. It's designed for distributed computing
environments like Google's TPUs (via XLA) or multi-GPU setups (via CUDA), allowing
models with billions of parameters to be split across multiple accelerator chips.
Communication Primitives
This file's core consists of fundamental building blocks that manage how data is
communicated between the devices in a distributed group, implemented as custom
torch.autograd.Function classes that handle gradients during backpropagation (a
minimal sketch of the "copy" pattern follows the list below).
   ● scatter: Splits a large tensor along a dimension, giving each device a unique
      slice.
   ● gather: Collects and concatenates slices from all devices to reconstruct the
      original large tensor.
   ● reduce: Combines tensors from all devices into a single tensor via operations
      like sum (all-reduce), synchronizing results.
   ● copy: A simple identity operation in the forward pass, but it triggers a reduce
      operation in the backward pass to aggregate gradients.
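As a sketch of the "copy" pattern in the last bullet, the snippet below shows how an identity forward pass with an all-reduce backward pass can be written with torch.autograd.Function. It uses torch.distributed for brevity; the actual file targets XLA/TPU collectives, so treat this as an illustration of the idea rather than the repository's implementation.

import torch
import torch.distributed as dist

class CopyToParallelRegion(torch.autograd.Function):
    """Identity in the forward pass; all-reduce of gradients in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Forward: every device already holds the same activations, so pass through.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: gradients computed on each shard must be summed across devices.
        if dist.is_initialized():
            dist.all_reduce(grad_output)
        return grad_output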
Using the communication primitives, the script defines new versions of standard
PyTorch layers that are inherently parallel; a single-process simulation of the column-
and row-parallel splits follows the list below.
   ● ParallelEmbedding: This module splits the large token embedding table. The
      table, which has a shape of (vocab_size, hidden_size), is partitioned along the
      hidden_size dimension (column-wise). Each device holds a vertical slice of the
      embedding table. When the layer receives input token IDs, each device looks
      up its slice of the embedding vector, and the results are gathered from all
      devices to form the complete embedding vector.
   ● ColumnParallelLinear: This layer implements a linear transformation (Y = XA +
      b) where the weight matrix A is partitioned vertically (column-wise). The input X
      is broadcast to all devices. Each device performs a matrix multiplication
      between X and its local slice of the weight matrix A_i, producing a partial output
      Y_i. The partial outputs Y_i from all devices are then gathered and
      concatenated to form the final output Y. This strategy is used for layers where
      the output dimension is large, such as the gate_proj and up_proj layers in the
      MLP.
   ● RowParallelLinear: This layer partitions the weight matrix A horizontally
      (row-wise). The input X is split, and each device receives a slice X_i. Each
      device performs a matrix multiplication between its input slice X_i and its
      weight slice A_i, producing a partial output. The partial outputs from all devices
      are then summed together using an all-reduce operation to produce the final
      output Y. This is used for layers where the input dimension is large, like the
      down_proj in the MLP and the o_proj in the attention mechanism.
   ● Quantization Support: The script also includes functions for 8-bit
      quantization. The quantize_tensor function can take a standard floating-point
      tensor, calculate the appropriate scaling factor, and convert it into a
      memory-efficient int8 tensor. All the parallel layers (ParallelEmbedding,
      ColumnParallelLinear, RowParallelLinear) have built-in support for this, allowing
      them to operate on quantized weights to further reduce the model's memory
      footprint.
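The snippet below simulates the column- and row-parallel splits on a single device to show why both recover the full result Y = XA. The concatenation stands in for the gather, the sum stands in for the all-reduce, and the world_size and tensor shapes are arbitrary example values.

import torch

# Single-process simulation of the two sharding strategies (illustrative sizes).
world_size = 2
batch, hidden, inter = 4, 8, 16
x = torch.randn(batch, hidden)
A = torch.randn(hidden, inter)                      # full weight: Y = X @ A

# Column parallel: split A along its output columns; each shard produces a slice
# of Y, and the slices are concatenated (the "gather").
A_cols = A.chunk(world_size, dim=1)
Y_col = torch.cat([x @ a for a in A_cols], dim=1)

# Row parallel: split A along its input rows and X along its features; each shard
# produces a partial result, and the partials are summed (the "all-reduce").
A_rows = A.chunk(world_size, dim=0)
x_parts = x.chunk(world_size, dim=1)
Y_row = sum(xp @ ar for xp, ar in zip(x_parts, A_rows))

assert torch.allclose(Y_col, x @ A, atol=1e-5)
assert torch.allclose(Y_row, x @ A, atol=1e-5)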
Conclusion
The Gemma PyTorch codebase is a well-engineered system, from its foundational
hyperparameters to the tensor operations of its distributed variant. Its core
transformer blocks use modern, efficient techniques: RMS Normalization for
computational stability, Rotary Positional Embeddings for encoding token order, and
Grouped-Query Attention to reduce inference memory and latency with little loss in
quality. A major evolutionary step is Gemma 3's introduction of multimodality, which
integrates a dedicated vision model into the language processing stream. This is
achieved through a specialized preprocessor, a "pan-and-scan" technique for detailed
image analysis, and an attention mask that lets image tokens attend to one another
bidirectionally, fusing vision and language in a single model. The design provides a
useful blueprint for future multimodal systems.
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
This ratio compares the likelihood of taking action a_t under the new policy (π_θ) versus
the old policy (π_θ_old) before the update. If r_t(θ) > 1, the new policy is more likely to take
that action; if r_t(θ) < 1, it is less likely.
The PPO objective function then incorporates this ratio in a "clipped" way:
L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]
Let's break down the min and clip operations:
   ● The clip term, clip(r_t(θ), 1−ε, 1+ε), constrains the probability ratio to the
      interval [1−ε, 1+ε], where ε is a small hyperparameter (commonly around 0.2)
      and Â_t is the estimated advantage of the action.
   ● Taking the min of the unclipped term r_t(θ)Â_t and the clipped term means the
      objective never benefits from pushing the ratio outside that interval: for a
      positive advantage the gain is capped at (1+ε)Â_t, and for a negative advantage
      the penalty is not softened by shrinking the ratio below 1−ε.
By taking this minimum, PPO ensures that each update step is small and controlled,
leading to more stable and reliable training than earlier policy gradient methods.
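As a concrete illustration of the clipped objective, here is a short PyTorch sketch. The inputs (per-action log-probabilities under the new and old policies, plus advantage estimates) are assumed to be computed elsewhere; the function and argument names are illustrative and not tied to any particular RL library.

import torch

def ppo_clip_loss(log_probs_new: torch.Tensor,
                  log_probs_old: torch.Tensor,
                  advantages: torch.Tensor,
                  epsilon: float = 0.2) -> torch.Tensor:
    """Clipped PPO surrogate loss (negated objective, to be minimized)."""
    # r_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t), computed in log space.
    ratio = torch.exp(log_probs_new - log_probs_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # Maximize the pessimistic (min) objective -> minimize its negative mean.
    return -torch.min(unclipped, clipped).mean()

# Example usage with dummy values.
lp_new = torch.log(torch.tensor([0.30, 0.10]))
lp_old = torch.log(torch.tensor([0.25, 0.20]))
adv = torch.tensor([1.0, -0.5])
print(ppo_clip_loss(lp_new, lp_old, adv))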
r^*(x, y) = \beta \log \frac{\pi^*(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}
Here, β is a hyperparameter that controls how strongly the policy is penalized for
deviating from the reference policy. Substituting this back into the preference model gives us
the probability of a preference pair in terms of language model policies alone.
This allows us to construct a loss function that optimizes the policy πθ directly. The
DPO loss is the negative log-likelihood of the preference data:
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\, \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
   1. The Ratios: For both the winning (y_w) and losing (y_l) responses, the model
       computes the log-ratio of the probability under the current policy (π_θ) to the
       probability under the frozen reference policy (π_ref).
   2. The Difference: The loss is driven by the difference between the winner's
       log-probability ratio and the loser's log-probability ratio.
   3. Optimization Goal: To minimize this loss, the model must adjust its parameters
       θ so that it increases the relative log-probability of the winning response (y_w)
       and decreases the relative log-probability of the losing response (y_l).
Essentially, DPO trains the LLM on a simple binary classification task: for a given
prompt, classify which response is the "winner." By doing so, it implicitly optimizes for
the same preferences that RLHF would, but it does so in a single, stable, and
computationally efficient training phase without ever needing to fit a separate reward
model.
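To make the loss concrete, here is a minimal PyTorch sketch of the DPO objective. The inputs are assumed to be summed per-response token log-probabilities under the current policy and the frozen reference model; the function and argument names are illustrative.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Negative log-likelihood of the preference data under the DPO model."""
    # Log-ratios of the policy vs. the frozen reference, for winner and loser.
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # -log sigma(beta * (winner ratio - loser ratio)), averaged over the batch.
    logits = beta * (chosen_ratio - rejected_ratio)
    return -F.logsigmoid(logits).mean()

# Example with dummy summed log-probabilities for a batch of two preference pairs.
loss = dpo_loss(torch.tensor([-12.0, -9.5]), torch.tensor([-14.0, -9.0]),
                torch.tensor([-12.5, -9.8]), torch.tensor([-13.5, -9.2]))
print(loss)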