-
-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
148 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import asyncio | ||
from typing import List, Tuple | ||
|
||
import torch | ||
from FlagEmbedding import FlagReranker | ||
|
||
from autorag.nodes.passagereranker.base import passage_reranker_node | ||
from autorag.utils.util import process_batch | ||
|
||
|
||
@passage_reranker_node
def flag_embedding_reranker(queries: List[str], contents_list: List[List[str]],
                            scores_list: List[List[float]], ids_list: List[List[str]],
                            top_k: int, batch: int = 64, use_fp16: bool = False,
                            model_name: str = "BAAI/bge-reranker-large",
                            ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
    """
    Rerank passages for each query with a BAAI FlagEmbedding cross-encoder reranker.

    :param queries: The list of queries to use for reranking
    :param contents_list: The list of lists of contents to rerank
    :param scores_list: The list of lists of scores retrieved from the initial ranking
    :param ids_list: The list of lists of ids retrieved from the initial ranking
    :param top_k: The number of passages to be retrieved
    :param batch: The number of queries to be processed in a batch
    :param use_fp16: Whether to use fp16 for inference
    :param model_name: The name of the BAAI Reranker model name.
        Default is "BAAI/bge-reranker-large"
    :return: tuple of lists containing the reranked contents, ids, and scores
    """
    chosen_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): confirm the installed FlagEmbedding version accepts a
    # `device` keyword — some releases auto-detect the device instead.
    reranker = FlagReranker(
        model_name, use_fp16=use_fp16, device=chosen_device
    )

    # One coroutine per query; process_batch throttles concurrency to `batch`.
    coroutines = [
        flag_embedding_reranker_pure(query, contents, scores, top_k, ids, reranker)
        for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)
    ]
    event_loop = asyncio.get_event_loop()
    outcomes = event_loop.run_until_complete(process_batch(coroutines, batch_size=batch))

    content_result = [contents for contents, _, _ in outcomes]
    id_result = [ids for _, ids, _ in outcomes]
    score_result = [scores for _, _, scores in outcomes]

    # Free model weights (and GPU memory, if any) before returning.
    del reranker
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return content_result, id_result, score_result
|
||
|
||
async def flag_embedding_reranker_pure(query: str, contents: List[str], scores: List[float], top_k: int,
                                       ids: List[str], model) -> Tuple[List[str], List[str], List[float]]:
    """
    Rerank a list of contents based on their relevance to a query using BAAI Reranker model.

    :param query: The query to use for reranking
    :param contents: The list of contents to rerank
    :param scores: The list of scores retrieved from the initial ranking
    :param ids: The list of ids retrieved from the initial ranking
    :param top_k: The number of passages to be retrieved
    :param model: BAAI Reranker model (anything exposing ``compute_score(sentence_pairs=...)``).
    :return: tuple of lists containing the reranked contents, ids, and scores
    """
    # Nothing to rerank: return empty results instead of crashing when
    # zip(*[]) below would yield zero values to unpack.
    if not contents:
        return [], [], []

    input_texts = [(query, content) for content in contents]
    with torch.no_grad():
        pred_scores = model.compute_score(sentence_pairs=input_texts)

    # FlagReranker.compute_score returns a bare float for a single sentence
    # pair; normalize to a list so the zip below always pairs correctly.
    if not isinstance(pred_scores, (list, tuple)):
        pred_scores = [pred_scores]

    content_ids_probs = list(zip(contents, ids, pred_scores))

    # Sort the (content, id, score) triples by relevance score in descending order
    sorted_content_ids_probs = sorted(content_ids_probs, key=lambda x: x[2], reverse=True)

    # Crop with top_k; slicing already clamps to the available length.
    sorted_content_ids_probs = sorted_content_ids_probs[:top_k]

    content_result, id_result, score_result = zip(*sorted_content_ids_probs)

    return list(content_result), list(id_result), list(score_result)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.