[Speculative decoding] feat: add EAGLE3 speculative decoding support #18039
Conversation
EAGLE3 is an encoder-decoder based speculative decoding method:
- Extracts features from target model at specific layers
- Uses feature fusion layer to compress target features
- Generates draft tokens with single-layer decoder
- Maps draft vocabulary to target vocabulary via d2t tensor

Key changes:
- Add LLM_ARCH_EAGLE3 architecture
- Add EAGLE3 encoder/decoder graph (src/models/eagle3.cpp)
- Add feature extraction from target model layers
- Add g_embeddings handling for decoder input
- Add GGML_TENSOR_FLAG_SYNC for GPU synchronization
- Add --eagle3 flag for speculative-simple example
- Add EAGLE3 model conversion in convert_hf_to_gguf.py
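To make the d2t mapping concrete, here is a simplified, hedged C++ sketch of the idea (function and variable names are illustrative; the actual implementation is quoted later in this thread from llama-context.cpp):

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Map the draft model's logits (over its smaller draft vocabulary) onto the
// target vocabulary. d2t[j] stores the offset from draft token id j to the
// corresponding target token id, i.e. target_id = j + d2t[j].
void map_draft_logits_to_target(const std::vector<float>   & draft_logits,
                                const std::vector<int64_t> & d2t,
                                std::vector<float>         & target_logits) {
    // Target-vocab positions not covered by the draft vocabulary can never be
    // proposed, so they are set to -inf.
    std::fill(target_logits.begin(), target_logits.end(),
              -std::numeric_limits<float>::infinity());
    for (size_t j = 0; j < draft_logits.size(); ++j) {
        const int64_t target_id = (int64_t) j + d2t[j];
        target_logits[(size_t) target_id] = draft_logits[j];
    }
}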
src/models/eagle3.cpp
    // Force a sync point between the two parallel RMS_NORM paths
    // This prevents buffer reuse issues on GPU (EAGLE3 GPU fix)
    ggml_set_sync(input_embeds_normed);
It is very strange that you need to do this explicitly.
The ggml_concat operator (like every other ggml op) tracks the input tensors on which it depends. So it should not be possible to get a buffer reuse when the data in the buffer is still pending a computation.
I think this sync should not be necessary and if removing it causes some data corruption, the cause is something else which we should investigate in detail.
Can you confirm that removing this call still causes problems?
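For reference, a minimal sketch of the dependency tracking being described (tensor shapes and names are illustrative):

#include "ggml.h"

// Every ggml op records its inputs in the result tensor's src[] array; the
// graph allocator consults these dependencies before reusing any buffer.
static struct ggml_tensor * concat_example(struct ggml_context * ctx,
                                           struct ggml_tensor  * a,
                                           struct ggml_tensor  * b) {
    struct ggml_tensor * c = ggml_concat(ctx, a, b, 0); // concatenate along dim 0
    // c->src[0] == a and c->src[1] == b, so a's and b's buffers should stay
    // valid until c has been computed.
    return c;
}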
I just revalidated this: without calling ggml_set_sync, the buffer data gets overwritten, causing the acceptance rate to drop to nearly 3-4%. The issue only occurs on the GPU side; when running the draft model on the CPU, the acceptance rate remains stable and ggml_set_sync is not required.
The result buffers of the two RMS_NORM operations appear to conflict, with one being overwritten by invalid (garbage) values. ggml_set_sync is used to enforce synchronization between the two RMS_NORM operations on the GPU side.
I also tried using ggml_set_output on the two RMS_NORM results to avoid the buffer overwriting. However, once I set those, the buffer for the concatenated result got overwritten instead. I then tried setting that as well, but the subsequent Q, K, and V attention result buffers were still being overwritten. It seems there is an issue with buffer allocation in the scheduler when handling parallel inputs on the GPU, so I came up with this method to resolve the issue.
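For clarity, the attempted workaround roughly looked like this (a fragment sketch using the tensor names from eagle3.cpp):

// Mark both normalized tensors as graph outputs so the allocator keeps their
// buffers alive for the whole graph.
ggml_set_output(input_embeds_normed);
ggml_set_output(g_embeddings_normed);
// As described above, this alone was not enough: the concat result and the
// subsequent Q/K/V buffers were still being overwritten on the GPU.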
Ok, I am able to reproduce the issue. Looking into this.
The problem is that here you are using the synchronous backend buffer call ggml_backend_tensor_get to get the output logits:
llama.cpp/src/llama-context.cpp
Lines 1247 to 1276 in 8fac4b1
    // EAGLE3: Map draft vocab to target vocab
    if (model.arch == LLM_ARCH_EAGLE3 && model.d2t) {
        static thread_local std::vector<int64_t> eagle3_d2t_map;
        static thread_local std::vector<float> eagle3_draft_logits;

        const int64_t draft_vocab_size = t_logits->ne[0];
        const uint32_t last_idx = n_outputs - 1;

        // Load d2t mapping once (on first call)
        if (eagle3_d2t_map.empty()) {
            eagle3_d2t_map.resize(model.d2t->ne[0]);
            ggml_backend_tensor_get(model.d2t, eagle3_d2t_map.data(), 0, eagle3_d2t_map.size() * sizeof(int64_t));
        }

        // Read only the last token's draft logits
        eagle3_draft_logits.resize(draft_vocab_size);
        const size_t last_offset = last_idx * draft_vocab_size * sizeof(float);
        ggml_backend_tensor_get(t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float));

        // Map only the last token's draft logits to target vocab
        float * last_logits_out = logits_out + last_idx * n_vocab;
        std::fill(last_logits_out, last_logits_out + n_vocab, -std::numeric_limits<float>::infinity());
        for (int64_t j = 0; j < draft_vocab_size; j++) {
            const int64_t target_id = j + eagle3_d2t_map[j];
            GGML_ASSERT(target_id >= 0 && target_id < n_vocab);
            last_logits_out[target_id] = eagle3_draft_logits[j];
        }
    } else {
This is incorrect because the call will get queued in a different stream compared to where the computation runs, so effectively it will not wait for the computation to finish before extracting the result.
To fix this, use the backend async call like this for now:
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ea6dfaea3..3506edd92 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1261,7 +1261,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
// Read only the last token's draft logits
eagle3_draft_logits.resize(draft_vocab_size);
const size_t last_offset = last_idx * draft_vocab_size * sizeof(float);
- ggml_backend_tensor_get(t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float));
+ ggml_backend_tensor_get_async(backend_res, t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float));
+ synchronize();
// Map only the last token's draft logits to target vocab
diff --git a/src/models/eagle3.cpp b/src/models/eagle3.cpp
index 8987a0c58..43d7a331d 100644
--- a/src/models/eagle3.cpp
+++ b/src/models/eagle3.cpp
@@ -65,7 +65,7 @@ llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, cons
// Force a sync point between the two parallel RMS_NORM paths
// This prevents buffer reuse issues on GPU (EAGLE3 GPU fix)
- ggml_set_sync(input_embeds_normed);
+ //ggml_set_sync(input_embeds_normed);
// Apply hidden_norm to g_embeddings
ggml_tensor * g_embeddings_normed = build_norm(g_embeddings,

Please confirm that with this patch, you don't need the ggml_set_sync stuff.
Is it actually required to use get_async, or is there just a missing synchronize() after the async graph_compute call?
It's not required - synchronize() before the tensor_get() should also work. It's just that I expect that this synchronization will eventually be moved up the stack, similar to how we don't synchronize when extracting the regular logits data below, and this would have to become tensor_get_async either way.
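In other words, the alternative would look roughly like this (same variables as in the quoted llama-context.cpp snippet):

// Wait for the queued graph compute to finish first, then the blocking read
// is safe regardless of which stream it lands in.
synchronize();
ggml_backend_tensor_get(t_logits, eagle3_draft_logits.data(), last_offset,
                        draft_vocab_size * sizeof(float));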
Thanks @ggerganov for pointing this out! I just updated this PR to fix the bug and remove the ggml_set_sync API. Rebuilt and tested, everything works well.
Great.
Btw, do you mind if I push in the branch directly? I want to do a cleanup pass over the implementation and it would be easier for me to push directly instead of creating PRs to your branch.
Sure, please go ahead. If there’s anything I could help with, just let me know.
Judging by the description of this PR, I believe many models with multi-token prediction also have the same strategy of reusing hidden features from the main model. It could be quite interesting to generalize this feature to support other models. I would expect some kind of sub-
I will definitely be looking at refactoring the implementation to become more generic before merging it. The initial performance results are really great, but we'll need to work on cleaning up the code and reducing the special-casing in several places. I'll try to provide insights on how to do that in the coming days.
Thanks @ggerganov @ngxson for your input. Definitely looking forward to hearing your feedback and improving this PR.
As discussed in #15902, Eagle3 represents the current SOTA in speculative decoding and is widely adopted across the industry. Integrating Eagle3 into llama.cpp significantly improves inference performance, achieving a 2–3× speedup, and strengthens llama.cpp's competitiveness among leading inference frameworks.
This enhancement is the result of close collaboration between the NVIDIA and GGML teams, showcasing a strong technical partnership.
The following provides a brief overview of this PR:
EAGLE3 is an encoder-decoder based speculative decoding method:
- Extracts features from target model at specific layers
- Uses feature fusion layer to compress target features
- Generates draft tokens with single-layer decoder
- Maps draft vocabulary to target vocabulary via d2t tensor

Key changes:
- Add LLM_ARCH_EAGLE3 architecture
- Add EAGLE3 encoder/decoder graph (src/models/eagle3.cpp)
- Add feature extraction from target model layers
- Add g_embeddings handling for decoder input
- Add GGML_TENSOR_FLAG_SYNC for GPU synchronization
- Add --eagle3 flag for speculative-simple example
- Add EAGLE3 model conversion in convert_hf_to_gguf.py

EAGLE3 Architecture Overview:
How to run EAGLE3 in llama.cpp
Requirements
This PR currently only supports two EAGLE3 models:
Step 1: Convert Models to GGUF Format
Step 2: Compile llama.cpp
Step 3: Run EAGLE3 Speculative Decoding
Performance Evaluation (RTX A6000 48GB)
Note: Using the chat_template for each model version can improve acceptance rates. Always apply the model’s corresponding chat_template when constructing prompts.
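A minimal sketch of doing this programmatically, assuming the current llama.h chat-template API (the buffer size and the single-turn message are illustrative):

#include "llama.h"
#include <string>
#include <vector>

// Render a single user turn through the model's built-in chat template so the
// prompt matches what the target model (and its EAGLE3 draft) expects.
static std::string build_prompt(const llama_model * model, const std::string & user_msg) {
    const char * tmpl = llama_model_chat_template(model, /*name=*/nullptr);
    llama_chat_message msg = { "user", user_msg.c_str() };

    std::vector<char> buf(8192);
    int32_t n = llama_chat_apply_template(tmpl, &msg, 1, /*add_ass=*/true,
                                          buf.data(), (int32_t) buf.size());
    if (n < 0) {
        return user_msg; // no usable template: fall back to the raw prompt
    }
    if (n > (int32_t) buf.size()) {
        buf.resize(n);
        n = llama_chat_apply_template(tmpl, &msg, 1, true, buf.data(), (int32_t) buf.size());
    }
    return std::string(buf.data(), (size_t) n);
}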
Each target model was quantized with Q4_K_M, its Eagle3 draft with Q4_K_M.

Details of GGML backend modifications (fixed, no longer needed):
- In the Eagle3 decoder, two parallel inputs are processed.
- When both RMS_NORM operations run in the same GPU split, a lack of synchronization causes buffer contention and race conditions (CPU execution is fine as it auto-syncs between subgraphs).
- Solution: use ggml_set_sync() to add a synchronization point after the first RMS_NORM, forcing the scheduler to create a split boundary and synchronize before continuing.
- This ensures correct execution and can be applied to any parallel path that needs synchronization, not just Eagle3.

Example results
Future Steps