Implement flag_embedding_reranker
bwook00 committed Apr 8, 2024
1 parent da7f3f5 commit 8550b33
Showing 10 changed files with 148 additions and 16 deletions.
16 changes: 8 additions & 8 deletions README.md
@@ -237,14 +237,14 @@ node_lines:

# ❗Supporting Nodes & modules

| Nodes | Modules |
|:-----------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| [Query_Expansion](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/query_expansion.html) | [Query_Decompose](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/query_decompose.html)<br/>[HyDE](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/hyde.html) |
| [Retrieval](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/retrieval.html) | [BM25](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/bm25.html)<br/>[VectorDB (choose embedding model)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/vectordb.html)<br/>[Hybrid with rrf (reciprocal rank fusion)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/hybrid_rrf.html)<br/>[Hybrid with cc (convex combination)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/hybrid_cc.html) |
| [Passage_Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/passage_reranker.html) | [UPR](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/upr.html)<br/>[Tart](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/tart.html)<br/>[MonoT5](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/monot5.html)<br/>[Ko-reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/koreranker.html)<br/>[cohere_reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/cohere.html)<br/>[RankGPT](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/rankgpt.html)<br/>[Jina Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/jina_reranker.html)<br/>[Sentence Transformer Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/setence_transformer_reranker.html) |
| [Passage_Compressor](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/passage_compressor.html) | [Tree Summarize](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/tree_summarize.html) |
| [Prompt Maker](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/prompt_maker.html) | [Default Prompt Maker (f-string)](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/fstring.html) |
| [Generator](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/generator.html) | [llama_index llm](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/llama_index_llm.html)<br/>[vllm](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/vllm.html) |
| Nodes | Modules |
|:-----------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| [Query_Expansion](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/query_expansion.html) | [Query_Decompose](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/query_decompose.html)<br/>[HyDE](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/hyde.html) |
| [Retrieval](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/retrieval.html) | [BM25](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/bm25.html)<br/>[VectorDB (choose embedding model)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/vectordb.html)<br/>[Hybrid with rrf (reciprocal rank fusion)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/hybrid_rrf.html)<br/>[Hybrid with cc (convex combination)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/hybrid_cc.html) |
| [Passage_Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/passage_reranker.html) | [UPR](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/upr.html)<br/>[Tart](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/tart.html)<br/>[MonoT5](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/monot5.html)<br/>[Ko-reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/koreranker.html)<br/>[cohere_reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/cohere.html)<br/>[RankGPT](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/rankgpt.html)<br/>[Jina Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/jina_reranker.html)<br/>[Sentence Transformer Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/setence_transformer_reranker.html)<br/>[Flag Embedding Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/flag_embedding_reranker.html) |
| [Passage_Compressor](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/passage_compressor.html) | [Tree Summarize](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/tree_summarize.html) |
| [Prompt Maker](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/prompt_maker.html) | [Default Prompt Maker (f-string)](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/fstring.html) |
| [Generator](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/generator.html) | [llama_index llm](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/llama_index_llm.html)<br/>[vllm](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/vllm.html) |

# 🛣Roadmap

1 change: 1 addition & 0 deletions autorag/nodes/passagereranker/__init__.py
@@ -1,4 +1,5 @@
from .cohere import cohere_reranker
from .flag_embedding import flag_embedding_reranker
from .jina import jina_reranker
from .koreranker import koreranker
from .monot5 import monot5
79 changes: 79 additions & 0 deletions autorag/nodes/passagereranker/flag_embedding.py
@@ -0,0 +1,79 @@
import asyncio
from typing import List, Tuple

import torch
from FlagEmbedding import FlagReranker

from autorag.nodes.passagereranker.base import passage_reranker_node
from autorag.utils.util import process_batch


@passage_reranker_node
def flag_embedding_reranker(queries: List[str], contents_list: List[List[str]],
                            scores_list: List[List[float]], ids_list: List[List[str]],
                            top_k: int, batch: int = 64, use_fp16: bool = False,
                            model_name: str = "BAAI/bge-reranker-large",
                            ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
    """
    Rerank a list of contents based on their relevance to a query using BAAI Reranker model.
    :param queries: The list of queries to use for reranking
    :param contents_list: The list of lists of contents to rerank
    :param scores_list: The list of lists of scores retrieved from the initial ranking
    :param ids_list: The list of lists of ids retrieved from the initial ranking
    :param top_k: The number of passages to be retrieved
    :param batch: The number of queries to be processed in a batch
    :param use_fp16: Whether to use fp16 for inference
    :param model_name: The name of the BAAI Reranker model name.
        Default is "BAAI/bge-reranker-large"
    :return: tuple of lists containing the reranked contents, ids, and scores
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = FlagReranker(
        model_name, use_fp16=use_fp16, device=device
    )
    tasks = [flag_embedding_reranker_pure(query, contents, scores, top_k, ids, model)
             for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)]
    loop = asyncio.get_event_loop()
    results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
    content_result = list(map(lambda x: x[0], results))
    id_result = list(map(lambda x: x[1], results))
    score_result = list(map(lambda x: x[2], results))

    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return content_result, id_result, score_result


async def flag_embedding_reranker_pure(query: str, contents: List[str], scores: List[float], top_k: int,
                                       ids: List[str], model) -> Tuple[List[str], List[str], List[float]]:
    """
    Rerank a list of contents based on their relevance to a query using BAAI Reranker model.
    :param query: The query to use for reranking
    :param contents: The list of contents to rerank
    :param scores: The list of scores retrieved from the initial ranking
    :param ids: The list of ids retrieved from the initial ranking
    :param top_k: The number of passages to be retrieved
    :param model: BAAI Reranker model.
    :return: tuple of lists containing the reranked contents, ids, and scores
    """
    input_texts = [(query, content) for content in contents]
    with torch.no_grad():
        pred_scores = model.compute_score(sentence_pairs=input_texts)

    content_ids_probs = list(zip(contents, ids, pred_scores))

    # Sort the list of pairs based on the relevance score in descending order
    sorted_content_ids_probs = sorted(content_ids_probs, key=lambda x: x[2], reverse=True)

    # crop with top_k
    if len(contents) < top_k:
        top_k = len(contents)
    sorted_content_ids_probs = sorted_content_ids_probs[:top_k]

    content_result, id_result, score_result = zip(*sorted_content_ids_probs)

    return list(content_result), list(id_result), list(score_result)
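
The per-query logic in `flag_embedding_reranker_pure` comes down to three steps: pair the query with each passage, score the pairs, then sort descending and crop to `top_k`. The sketch below walks through that flow with the model swapped for a hypothetical word-overlap scorer (`overlap_score`), so it runs without FlagEmbedding or a GPU. The names `rerank_top_k` and `score_fn`, and the early return for empty input, are illustrative assumptions and not part of this commit.

```python
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[str, str]


def rerank_top_k(
    query: str,
    contents: Sequence[str],
    ids: Sequence[str],
    top_k: int,
    score_fn: Callable[[List[Pair]], List[float]],
) -> Tuple[List[str], List[str], List[float]]:
    """Score (query, passage) pairs, sort descending, keep at most top_k."""
    # Assumption: return empty lists instead of failing when there is nothing to rank.
    if not contents or top_k <= 0:
        return [], [], []

    pairs = [(query, content) for content in contents]
    scores = score_fn(pairs)

    # sorted() is stable, so passages with equal scores keep their retrieval order.
    ranked = sorted(zip(contents, ids, scores), key=lambda item: item[2], reverse=True)
    ranked = ranked[:top_k]  # slicing already handles top_k > len(contents)

    content_result, id_result, score_result = zip(*ranked)
    return list(content_result), list(id_result), list(score_result)


def overlap_score(pairs: List[Pair]) -> List[float]:
    """Hypothetical stand-in scorer: number of words shared by query and passage."""
    return [
        float(len(set(q.lower().split()) & set(p.lower().split())))
        for q, p in pairs
    ]
```

In the commit itself the scores come from `model.compute_score(sentence_pairs=input_texts)`; any callable with the same list-in, list-out shape can be passed as `score_fn` here. The early return matters because unpacking `zip(*ranked)` into three names raises `ValueError` when `ranked` is empty.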
1 change: 1 addition & 0 deletions autorag/support.py
@@ -33,6 +33,7 @@ def get_support_modules(module_name: str) -> Callable:
        'rankgpt': ('autorag.nodes.passagereranker', 'rankgpt'),
        'jina_reranker': ('autorag.nodes.passagereranker', 'jina_reranker'),
        'sentence_transformer_reranker': ('autorag.nodes.passagereranker', 'sentence_transformer_reranker'),
        'flag_embedding_reranker': ('autorag.nodes.passagereranker', 'flag_embedding_reranker'),
        # passage_compressor
        'tree_summarize': ('autorag.nodes.passagecompressor', 'tree_summarize'),
        'pass_compressor': ('autorag.nodes.passagecompressor', 'pass_compressor'),
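
The hunk shows only the registry entries, not the body of `get_support_modules`. One common way to resolve a `name -> (module path, attribute)` mapping like this is a lazy import through `importlib`. The sketch below assumes that approach and uses standard-library entries in place of AutoRAG modules so that it runs standalone. `SUPPORT_MODULES` and `get_support` are hypothetical names, not the project's API.

```python
import importlib
from typing import Callable, Dict, Tuple

# Hypothetical registry: stdlib targets stand in for entries such as
# 'flag_embedding_reranker': ('autorag.nodes.passagereranker', 'flag_embedding_reranker').
SUPPORT_MODULES: Dict[str, Tuple[str, str]] = {
    "sqrt": ("math", "sqrt"),
    "dumps": ("json", "dumps"),
}


def get_support(module_name: str,
                registry: Dict[str, Tuple[str, str]] = SUPPORT_MODULES) -> Callable:
    """Look up a registered name and import its target only when requested."""
    try:
        module_path, attr_name = registry[module_name]
    except KeyError:
        raise KeyError(f"Unsupported module: {module_name!r}") from None
    # The import happens here, so heavy dependencies load only for modules in use.
    return getattr(importlib.import_module(module_path), attr_name)
```

With a registry of this shape, adding a module takes one new dictionary entry plus the export in the package's `__init__.py`, which matches the two small hunks in this commit.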