53 changes: 50 additions & 3 deletions MaxText/maxtext_utils.py
@@ -158,7 +158,7 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
Calculate training TFLOP for Gemma2 as in Gemma2 we combine [local_attention, global_attention] into one decoder
layer and we use sliding window attention in local_attention
"""
attention_flops = (
noncausal_attention_flops = (
# global attention
4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
+
@@ -170,7 +170,8 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
* config.num_query_heads
* config.head_dim
)
attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12
causal_attention_flops = noncausal_attention_flops / 2
attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

# multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer
learnable_weight_tflops = (
@@ -180,6 +181,48 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
return attention_tflops, learnable_weight_tflops


def calculate_gemma3_tflops_training_per_device(config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops):
"""
Calculate training TFLOPs for Gemma3, which has an alternating pattern of
5 local attention layers and 1 global attention layer.
"""
num_layers = config.num_decoder_layers

num_global_layers = num_layers // 6
num_local_layers = num_layers - num_global_layers

# FLOPs for a single global attention layer (full attention)
# Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim
global_attention_flops_per_layer = (
4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
)

# FLOPs for a single local attention layer (sliding window)
# Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim
local_attention_flops_per_layer = (
4
* config.per_device_batch_size
* config.max_target_length
* min(config.sliding_window_size, config.max_target_length)
* config.num_query_heads
* config.head_dim
)

# Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local)
noncausal_attention_flops = (
num_global_layers * global_attention_flops_per_layer + num_local_layers * local_attention_flops_per_layer
)
causal_attention_flops = noncausal_attention_flops / 2

# Convert to TFLOPs and multiply by 3 for fwd/bwd pass
attention_tflops = causal_attention_flops * 3 / 10**12

# Learnable weights (FFN, QKV, Projections) are present in every layer.
learnable_weight_tflops = ((total_ffn_flops + qkv_flops + projection_flops) * num_layers + embedding_flops) * 3 / 10**12

return attention_tflops, learnable_weight_tflops


def calculate_mla_tflops_per_device(config):
"""Calculate Multi-Head Latent Attention TFLOP"""
batch_len = config.per_device_batch_size * config.max_target_length
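
A minimal standalone sketch of the new Gemma3 attention-FLOP estimate above, for anyone who wants to sanity-check the arithmetic. All config values below are hypothetical, not taken from any real Gemma3 config; the same causal halving also applies to the Gemma2 change earlier in this file:

```python
# Sketch of the formula in calculate_gemma3_tflops_training_per_device,
# with made-up config values for illustration only.
per_device_batch_size = 4
max_target_length = 8192
num_query_heads = 16
head_dim = 256
sliding_window_size = 1024
num_decoder_layers = 48

# Alternating pattern: 1 global (full-attention) layer per 6 layers, the rest local.
num_global_layers = num_decoder_layers // 6                 # 8
num_local_layers = num_decoder_layers - num_global_layers   # 40

global_flops_per_layer = (
    4 * per_device_batch_size * max_target_length**2 * num_query_heads * head_dim
)
local_flops_per_layer = (
    4 * per_device_batch_size * max_target_length
    * min(sliding_window_size, max_target_length)
    * num_query_heads * head_dim
)

noncausal = num_global_layers * global_flops_per_layer + num_local_layers * local_flops_per_layer
causal = noncausal / 2                   # causal masking attends to roughly half the query-key pairs
attention_tflops = causal * 3 / 10**12   # x3 accounts for forward plus backward passes

print(f"attention TFLOPs per device per step: {attention_tflops:.2f}")
```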
@@ -304,6 +347,10 @@ def calculate_tflops_training_per_device(config, log=True):
attention_tflops, learnable_weight_tflops = calculate_gemma2_tflops_training_per_device(
config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
)
elif config.decoder_block == DecoderBlockType.GEMMA3:
attention_tflops, learnable_weight_tflops = calculate_gemma3_tflops_training_per_device(
config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
)
elif config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4):
learnable_weight_tflops = (
(total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
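
To illustrate why Gemma3 gets its own branch here rather than falling through to a full-attention estimate, a rough comparison using the same hypothetical numbers as the sketch above:

```python
# Hypothetical comparison: cost of treating every layer as global full attention
# versus the mixed local/global split used by calculate_gemma3_tflops_training_per_device.
seq, window, layers = 8192, 1024, 48
num_global = layers // 6            # 8
num_local = layers - num_global     # 40

all_global = layers * seq * seq     # every layer quadratic in sequence length
mixed = num_global * seq * seq + num_local * seq * window

print(f"overestimate factor without the Gemma3 path: {all_global / mixed:.1f}x")  # ~3.7x
```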
@@ -1080,7 +1127,7 @@ def get_formatted_sharding_annotations(params, mesh=None):
spec_parts = []
for item in p_leaf.sharding.spec:
# Represent None as "Replicated" to make it explicit.
spec_parts.append(str(item) if item is not None else "Relicated")
spec_parts.append(str(item) if item is not None else "Replicated")
sharding_desc = f"PartitionSpec({', '.join(spec_parts)})"
# Case 2: The parameter is explicitly marked as fully replicated.
elif hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is None:
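
The typo fix above only changes the rendered annotation string. A minimal illustration of the corrected output, assuming jax.sharding.PartitionSpec and a made-up spec:

```python
from jax.sharding import PartitionSpec

# Hypothetical spec: the second axis is unsharded, so it should render as "Replicated".
spec = PartitionSpec("data", None, "model")
spec_parts = [str(item) if item is not None else "Replicated" for item in spec]
print(f"PartitionSpec({', '.join(spec_parts)})")  # PartitionSpec(data, Replicated, model)
```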