llama : add reranking support #9510

Merged 25 commits into master from gg/rerank on Sep 28, 2024
Commits (25):
3453e62
py : add XLMRobertaForSequenceClassification [no ci]
ggerganov Sep 16, 2024
77723ed
py : fix scalar-tensor conversion [no ci]
ggerganov Sep 17, 2024
49f90de
py : fix position embeddings chop [no ci]
ggerganov Sep 17, 2024
dc0cdd8
llama : read new cls tensors [no ci]
ggerganov Sep 17, 2024
d0a7bf9
llama : add classification head (wip) [no ci]
ggerganov Sep 18, 2024
125a067
llama : add "rank" pooling type
ggerganov Sep 19, 2024
6235c62
server : add rerank endpoint
ggerganov Sep 19, 2024
6916ed1
llama : avoid ggml_repeat during classification
ggerganov Sep 23, 2024
62a45d1
rerank : cleanup + comments
ggerganov Sep 25, 2024
7bde9a0
server : accept /rerank endpoint in addition to /v1/rerank [no ci]
ggerganov Sep 25, 2024
c62a39d
embedding : parse special tokens
ggerganov Sep 25, 2024
866c011
jina : support v1 reranker
ggerganov Sep 25, 2024
84f56f3
vocab : minor style
ggerganov Sep 25, 2024
00b3376
server : initiate tests for later
ggerganov Sep 26, 2024
877a04c
server : add docs
ggerganov Sep 26, 2024
4d45775
llama : add comment [no ci]
ggerganov Sep 26, 2024
ca99a6c
llama : fix uninitialized tensors
ggerganov Sep 26, 2024
f19554f
ci : add rerank tests
ggerganov Sep 26, 2024
f27dd69
add reranking test
ngxson Sep 26, 2024
1ae8376
change test data
ngxson Sep 26, 2024
84b0af8
Update examples/server/server.cpp
ggerganov Sep 27, 2024
0d6f6a7
add `--reranking` argument
ngxson Sep 27, 2024
a4ac45f
update server docs
ngxson Sep 27, 2024
39167b6
llama : fix comment [no ci]
ggerganov Sep 28, 2024
aeac876
Merge branch 'master' into gg/rerank
ggerganov Sep 28, 2024
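The commits above add a reranking endpoint to the server (6235c62 adds it, 7bde9a0 exposes it at both /rerank and /v1/rerank). As a rough sketch of what a client request body might look like: the field names here ("query", "documents") are assumptions modeled on Jina-style reranker APIs, not taken from this diff — the server docs added in commit 877a04c are the authoritative reference.

```python
import json

# Hypothetical /v1/rerank request body; "query"/"documents" follow the
# Jina-style reranker convention and the model name is a placeholder.
request_body = json.dumps({
    "model": "jina-reranker-v1",
    "query": "organic skincare products for sensitive skin",
    "documents": [
        "Eco-friendly kitchenware for modern homes",
        "Natural organic skincare range for sensitive skin",
    ],
})

# The server would be expected to return one relevance score per input
# document, e.g. {"results": [{"index": ..., "relevance_score": ...}, ...]}.
parsed = json.loads(request_body)
print(len(parsed["documents"]))
```

A client would POST this body to the running llama-server instance; the response order and exact score field should be checked against the PR's server documentation rather than assumed from this sketch.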
llama : add classification head (wip) [no ci]
ggerganov committed Sep 25, 2024
commit d0a7bf9382782368b57e68585b8926aa875a2f95
common/arg.cpp (2 changes: 1 addition & 1 deletion)

@@ -391,7 +391,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.verbose_prompt = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ));
     add_opt(llama_arg(
         {"--no-display-prompt"},
         format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
src/llama.cpp (14 changes: 13 additions & 1 deletion)

@@ -11455,8 +11455,20 @@ struct llm_build_context {
             inpL = cur;
         }

-        // final output
         cur = inpL;

+        // classification head
+        // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+        // TODO: become pooling layer?
+        if (model.cls) {
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);
+
+            cur = ggml_tanh(ctx0, cur);
+
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+            // TODO: cur is now a scalar - what to do?
+        }
+
         cb(cur, "result_embd", -1);

         ggml_build_forward_expand(gf, cur);
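The classification head in the diff above follows the RoBERTa sequence-classification layout linked in the comment: a dense layer with tanh activation, then a projection down to a single scalar score. A minimal Python sketch of the same arithmetic, with plain lists standing in for ggml tensors and made-up toy weights (not real model weights):

```python
import math

def classification_head(embd, cls_w, cls_b, cls_out_w, cls_out_b):
    """RoBERTa-style classification head: dense -> tanh -> dense.

    embd:      pooled embedding, length n_embd
    cls_w:     n_embd x n_embd dense weights (one row per output unit)
    cls_b:     length-n_embd bias
    cls_out_w: 1 x n_embd projection down to a single score
    cls_out_b: length-1 bias
    """
    # dense + bias + tanh (mirrors ggml_mul_mat, ggml_add, ggml_tanh)
    hidden = [math.tanh(sum(w * x for w, x in zip(row, embd)) + b)
              for row, b in zip(cls_w, cls_b)]
    # output projection: the result is a single scalar rank score,
    # matching the "cur is now a scalar" TODO in the diff
    return sum(w * h for w, h in zip(cls_out_w[0], hidden)) + cls_out_b[0]

# tiny 2-dimensional example with hand-picked weights
embd      = [1.0, -1.0]
cls_w     = [[1.0, 0.0], [0.0, 1.0]]  # identity dense layer
cls_b     = [0.0, 0.0]
cls_out_w = [[1.0, 1.0]]
cls_out_b = [0.5]

score = classification_head(embd, cls_w, cls_b, cls_out_w, cls_out_b)
# tanh(1) and tanh(-1) cancel, leaving only the output bias
print(score)
```

This also makes the open question in the diff concrete: unlike the per-token embeddings the graph normally emits, the head collapses everything to one number per sequence, which is what the later "rank" pooling type commit (125a067) formalizes.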