From 4cdf1f96f2202d7ae90f4be696f6e88112306f48 Mon Sep 17 00:00:00 2001 From: kimbwook Date: Sun, 6 Oct 2024 17:40:16 +0900 Subject: [PATCH 01/11] Baseline --- autorag/nodes/passagereranker/voyageai.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 autorag/nodes/passagereranker/voyageai.py diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py new file mode 100644 index 000000000..94c256db3 --- /dev/null +++ b/autorag/nodes/passagereranker/voyageai.py @@ -0,0 +1,3 @@ +import voyageai + +vo = voyageai.Client(api_key="") From 8506c1cff45f22edbd840f1692215dc0a7044ec0 Mon Sep 17 00:00:00 2001 From: kimbwook Date: Sun, 6 Oct 2024 18:47:07 +0900 Subject: [PATCH 02/11] Add test code --- autorag/nodes/passagereranker/__init__.py | 1 + autorag/nodes/passagereranker/voyageai.py | 108 +++++++++++++++++- .../passagereranker/test_voyageai_reranker.py | 83 ++++++++++++++ 3 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 tests/autorag/nodes/passagereranker/test_voyageai_reranker.py diff --git a/autorag/nodes/passagereranker/__init__.py b/autorag/nodes/passagereranker/__init__.py index 33706e028..c284b3988 100644 --- a/autorag/nodes/passagereranker/__init__.py +++ b/autorag/nodes/passagereranker/__init__.py @@ -11,3 +11,4 @@ from .tart.tart import Tart from .time_reranker import TimeReranker from .upr import Upr +from .voyageai import VoyageAIReranker diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py index 94c256db3..beb4208cc 100644 --- a/autorag/nodes/passagereranker/voyageai.py +++ b/autorag/nodes/passagereranker/voyageai.py @@ -1,3 +1,109 @@ +import os +from typing import List, Tuple +import pandas as pd import voyageai -vo = voyageai.Client(api_key="") +from autorag.nodes.passagereranker.base import BasePassageReranker +from autorag.utils.util import result_to_dataframe + + +class VoyageAIReranker(BasePassageReranker): + def __init__(self, project_dir: 
str, *args, **kwargs): + super().__init__(project_dir) + api_key = kwargs.pop("api_key", None) + api_key = os.getenv("VOYAGE_API_KEY", None) if api_key is None else api_key + if api_key is None: + raise KeyError( + "Please set the API key for VoyageAI rerank in the environment variable VOYAGE_API_KEY " + "or directly set it on the config YAML file." + ) + + self.voyage_client = voyageai.Client() + + def __del__(self): + del self.voyage_client + super().__del__() + + @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"]) + def pure(self, previous_result: pd.DataFrame, *args, **kwargs): + queries, contents, scores, ids = self.cast_to_run(previous_result) + top_k = kwargs.pop("top_k") + batch = kwargs.pop("batch", 64) + model = kwargs.pop("model", "rerank-2") + truncation = kwargs.pop("truncation", True) + return self._pure(queries, contents, ids, top_k, batch, model, truncation) + + def _pure( + self, + queries: List[str], + contents_list: List[List[str]], + ids_list: List[List[str]], + top_k: int, + batch: int = 64, + model: str = "rerank-2", + truncation: bool = True, + ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: + """ + Rerank a list of contents with VoyageAI rerank models. + You can get the API key from https://docs.voyageai.com/docs/api-key-and-installation and set it in the environment variable VOYAGE_API_KEY. + + :param queries: The list of queries to use for reranking + :param contents_list: The list of lists of contents to rerank + :param ids_list: The list of lists of ids retrieved from the initial ranking + :param top_k: The number of passages to be retrieved + :param batch: The number of queries to be processed in a batch + :param model: The model name for Cohere rerank. + You can choose between "rerank-2" and "rerank-2-lite". + Default is "rerank-2". + :param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. 
+ :return: Tuple of lists containing the reranked contents, ids, and scores + """ + content_result, id_result, score_result = zip( + *[ + voyageai_rerank_pure( + self.voyage_client, model, query, document, ids, top_k, truncation + ) + for query, document, ids in zip(queries, contents_list, ids_list) + ] + ) + + return content_result, id_result, score_result + + +def voyageai_rerank_pure( + voyage_client: voyageai.Client, + model: str, + query: str, + documents: List[str], + ids: List[str], + top_k: int, + truncation: bool = True, +) -> Tuple[List[str], List[str], List[float]]: + """ + Rerank a list of contents with Cohere rerank models. + + :param voyage_client: The Voyage Client to use for reranking + :param model: The model name for Cohere rerank + :param query: The query to use for reranking + :param documents: The list of contents to rerank + :param ids: The list of ids corresponding to the documents + :param top_k: The number of passages to be retrieved + :param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. 
+ :return: Tuple of lists containing the reranked contents, ids, and scores + """ + rerank_results = voyage_client.rerank( + model=model, + query=query, + documents=documents, + top_k=top_k, + truncation=truncation, + ) + reranked_scores: List[float] = list( + map(lambda x: x.relevance_score, rerank_results.results) + ) + reranked_contents: List[str] = list( + map(lambda x: x.document, rerank_results.results) + ) + indices = list(map(lambda x: x.index, rerank_results.results)) + reranked_ids: List[str] = list(map(lambda i: ids[i], indices)) + return reranked_contents, reranked_ids, reranked_scores diff --git a/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py new file mode 100644 index 000000000..5ac4132a7 --- /dev/null +++ b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py @@ -0,0 +1,83 @@ +from unittest.mock import patch + +import pytest + +import autorag +from autorag.nodes.passagereranker import VoyageAIReranker +from tests.autorag.nodes.passagereranker.test_passage_reranker_base import ( + queries_example, + contents_example, + ids_example, + base_reranker_test, + project_dir, + previous_result, + base_reranker_node_test, +) + + +def mock_voyageai_reranker_pure( + voyage_client, model, query, documents, ids, top_k, truncation +): + if query == queries_example[0]: + return ( + [documents[1], documents[2], documents[0]][:top_k], + [ids[1], ids[2], ids[0]][:top_k], + [0.8, 0.2, 0.1][:top_k], + ) + elif query == queries_example[1]: + return ( + [documents[1], documents[0], documents[2]][:top_k], + [ids[1], ids[0], ids[2]][:top_k], + [0.8, 0.2, 0.1][:top_k], + ) + else: + raise ValueError(f"Unexpected query: {query}") + + +@pytest.fixture +def voyageai_reranker_instance(): + return VoyageAIReranker(project_dir, api_key="mock_api_key") + + +@patch.object( + autorag.nodes.passagereranker.voyageai, + "voyageai_rerank_pure", + mock_voyageai_reranker_pure, +) +def 
test_voyageai_reranker(voyageai_reranker_instance): + top_k = 3 + contents_result, id_result, score_result = voyageai_reranker_instance._pure( + queries_example, contents_example, ids_example, top_k + ) + base_reranker_test(contents_result, id_result, score_result, top_k) + + +@patch.object( + autorag.nodes.passagereranker.voyageai, + "voyageai_rerank_pure", + mock_voyageai_reranker_pure, +) +def test_voyageai_reranker_batch_one(voyageai_reranker_instance): + top_k = 3 + batch = 1 + contents_result, id_result, score_result = voyageai_reranker_instance._pure( + queries_example, + contents_example, + ids_example, + top_k, + batch=batch, + ) + base_reranker_test(contents_result, id_result, score_result, top_k) + + +@patch.object( + autorag.nodes.passagereranker.voyageai, + "voyageai_rerank_pure", + mock_voyageai_reranker_pure, +) +def test_voyageai_reranker_node(): + top_k = 1 + result_df = VoyageAIReranker.run_evaluator( + project_dir=project_dir, previous_result=previous_result, top_k=top_k + ) + base_reranker_node_test(result_df, top_k) From 276b08043a0dea4f39329da98827997154e86f8d Mon Sep 17 00:00:00 2001 From: kimbwook Date: Sun, 6 Oct 2024 18:48:06 +0900 Subject: [PATCH 03/11] Add mock api key at node test --- .../autorag/nodes/passagereranker/test_voyageai_reranker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py index 5ac4132a7..cf1673e8d 100644 --- a/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py +++ b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py @@ -78,6 +78,9 @@ def test_voyageai_reranker_batch_one(voyageai_reranker_instance): def test_voyageai_reranker_node(): top_k = 1 result_df = VoyageAIReranker.run_evaluator( - project_dir=project_dir, previous_result=previous_result, top_k=top_k + project_dir=project_dir, + previous_result=previous_result, + top_k=top_k, + 
api_key="mock_api_key", ) base_reranker_node_test(result_df, top_k) From 5c8bfebdfe7ef3a0bddda0c6d9740ca63eace5dd Mon Sep 17 00:00:00 2001 From: kimbwook Date: Sun, 6 Oct 2024 18:50:24 +0900 Subject: [PATCH 04/11] Add rst full.yaml support.py --- autorag/nodes/passagereranker/voyageai.py | 6 +++--- autorag/support.py | 2 ++ docs/source/api_spec/autorag.nodes.passagereranker.rst | 8 ++++++++ sample_config/rag/full.yaml | 1 + 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py index beb4208cc..5cd3ccbd6 100644 --- a/autorag/nodes/passagereranker/voyageai.py +++ b/autorag/nodes/passagereranker/voyageai.py @@ -52,7 +52,7 @@ def _pure( :param ids_list: The list of lists of ids retrieved from the initial ranking :param top_k: The number of passages to be retrieved :param batch: The number of queries to be processed in a batch - :param model: The model name for Cohere rerank. + :param model: The model name for VoyageAI rerank. You can choose between "rerank-2" and "rerank-2-lite". Default is "rerank-2". :param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. @@ -80,10 +80,10 @@ def voyageai_rerank_pure( truncation: bool = True, ) -> Tuple[List[str], List[str], List[float]]: """ - Rerank a list of contents with Cohere rerank models. + Rerank a list of contents with VoyageAI rerank models. 
:param voyage_client: The Voyage Client to use for reranking - :param model: The model name for Cohere rerank + :param model: The model name for VoyageAI rerank :param query: The query to use for reranking :param documents: The list of contents to rerank :param ids: The list of ids corresponding to the documents diff --git a/autorag/support.py b/autorag/support.py index 7b587c062..620d97776 100644 --- a/autorag/support.py +++ b/autorag/support.py @@ -117,6 +117,8 @@ def get_support_modules(module_name: str) -> Callable: ), "time_reranker": ("autorag.nodes.passagereranker", "TimeReranker"), "TimeReranker": ("autorag.nodes.passagereranker", "TimeReranker"), + "voyageai_reranker": ("autorag.nodes.passagereranker", "VoyageAIReranker"), + "VoyageAIReranker": ("autorag.nodes.passagereranker", "VoyageAIReranker"), # passage_filter "pass_passage_filter": ("autorag.nodes.passagefilter", "PassPassageFilter"), "similarity_threshold_cutoff": ( diff --git a/docs/source/api_spec/autorag.nodes.passagereranker.rst b/docs/source/api_spec/autorag.nodes.passagereranker.rst index 6745ebad8..6b4dc3898 100644 --- a/docs/source/api_spec/autorag.nodes.passagereranker.rst +++ b/docs/source/api_spec/autorag.nodes.passagereranker.rst @@ -124,6 +124,14 @@ autorag.nodes.passagereranker.upr module :undoc-members: :show-inheritance: +autorag.nodes.passagereranker.voyageai module +--------------------------------------------- + +.. 
automodule:: autorag.nodes.passagereranker.voyageai + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/sample_config/rag/full.yaml b/sample_config/rag/full.yaml index 27a2751d1..d7faacdeb 100644 --- a/sample_config/rag/full.yaml +++ b/sample_config/rag/full.yaml @@ -73,6 +73,7 @@ node_lines: - module_type: flag_embedding_reranker - module_type: flag_embedding_llm_reranker - module_type: time_reranker + - module_type: voyageai_reranker - node_type: passage_filter strategy: metrics: [ retrieval_f1, retrieval_recall, retrieval_precision ] From b06e2850780c25a5a17e88a3365d850d6280ad0d Mon Sep 17 00:00:00 2001 From: kimbwook Date: Sun, 6 Oct 2024 18:56:45 +0900 Subject: [PATCH 05/11] delete batch parameters --- autorag/nodes/passagereranker/voyageai.py | 5 +---- .../passagereranker/test_voyageai_reranker.py | 18 ------------------ 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py index 5cd3ccbd6..dc6a03424 100644 --- a/autorag/nodes/passagereranker/voyageai.py +++ b/autorag/nodes/passagereranker/voyageai.py @@ -28,10 +28,9 @@ def __del__(self): def pure(self, previous_result: pd.DataFrame, *args, **kwargs): queries, contents, scores, ids = self.cast_to_run(previous_result) top_k = kwargs.pop("top_k") - batch = kwargs.pop("batch", 64) model = kwargs.pop("model", "rerank-2") truncation = kwargs.pop("truncation", True) - return self._pure(queries, contents, ids, top_k, batch, model, truncation) + return self._pure(queries, contents, ids, top_k, model, truncation) def _pure( self, @@ -39,7 +38,6 @@ def _pure( contents_list: List[List[str]], ids_list: List[List[str]], top_k: int, - batch: int = 64, model: str = "rerank-2", truncation: bool = True, ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: @@ -51,7 +49,6 @@ def _pure( :param contents_list: The list of lists of contents to rerank :param ids_list: The list of 
lists of ids retrieved from the initial ranking :param top_k: The number of passages to be retrieved - :param batch: The number of queries to be processed in a batch :param model: The model name for VoyageAI rerank. You can choose between "rerank-2" and "rerank-2-lite". Default is "rerank-2". diff --git a/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py index cf1673e8d..70c06d360 100644 --- a/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py +++ b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py @@ -52,24 +52,6 @@ def test_voyageai_reranker(voyageai_reranker_instance): base_reranker_test(contents_result, id_result, score_result, top_k) -@patch.object( - autorag.nodes.passagereranker.voyageai, - "voyageai_rerank_pure", - mock_voyageai_reranker_pure, -) -def test_voyageai_reranker_batch_one(voyageai_reranker_instance): - top_k = 3 - batch = 1 - contents_result, id_result, score_result = voyageai_reranker_instance._pure( - queries_example, - contents_example, - ids_example, - top_k, - batch=batch, - ) - base_reranker_test(contents_result, id_result, score_result, top_k) - - @patch.object( autorag.nodes.passagereranker.voyageai, "voyageai_rerank_pure", From b533848028b86ac92455ac771db565a0ebe8e1d2 Mon Sep 17 00:00:00 2001 From: kimbwook Date: Sun, 6 Oct 2024 18:58:33 +0900 Subject: [PATCH 06/11] Add docs --- .../passage_reranker/passage_reranker.md | 1 + .../passage_reranker/voyageai_reranker.md | 52 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 docs/source/nodes/passage_reranker/voyageai_reranker.md diff --git a/docs/source/nodes/passage_reranker/passage_reranker.md b/docs/source/nodes/passage_reranker/passage_reranker.md index 7004e87ef..efebc9089 100644 --- a/docs/source/nodes/passage_reranker/passage_reranker.md +++ b/docs/source/nodes/passage_reranker/passage_reranker.md @@ -70,4 +70,5 @@ sentence_transformer_reranker.md 
flag_embedding_reranker.md flag_embedding_llm_reranker.md time_reranker.md +voyageai_reranker.md ``` diff --git a/docs/source/nodes/passage_reranker/voyageai_reranker.md b/docs/source/nodes/passage_reranker/voyageai_reranker.md new file mode 100644 index 000000000..acb3a3f34 --- /dev/null +++ b/docs/source/nodes/passage_reranker/voyageai_reranker.md @@ -0,0 +1,52 @@ +--- +myst: + html_meta: + title: AutoRAG - VoyageAI Reranker + description: Learn about the VoyageAI reranker module in AutoRAG + keywords: AutoRAG,RAG,Advanced RAG,Reranker,VoyageAI +--- +# voyageai_reranker + +The `voyageai reranker` module is a reranker from [VoyageAI](https://www.voyageai.com/). +It provides a powerful and fast reranker for passage retrieval. + +## Before Usage + +First, you need to get the VoyageAI API key from [here](https://docs.voyageai.com/docs/api-key-and-installation). + +Next, you can set your VoyageAI API key in the environment variable "VOYAGE_API_KEY". + +```bash +export VOYAGE_API_KEY=your_voyageai_api_key +``` + +Or, you can set your VoyageAI API key in the config.yaml file directly. + +```yaml +- module_type: voyageai_reranker + api_key: your_voyageai_api_key +``` + +## **Module Parameters** + +- **model** : The type of model you want to use for reranking. Default is "rerank-2" and you can change + it to "rerank-2-lite". +- **api_key** : The VoyageAI API key. +- **truncation** : Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. Default is True. + +## **Example config.yaml** + +```yaml +- module_type: voyageai_reranker + api_key: your_voyageai_api_key + model: rerank-2 +``` + +### Supported Model Names + +You can see the supported model names [here](https://docs.voyageai.com/docs/reranker). 
+ +| Model Name | +|:-------------:| +| rerank-2 | +| rerank-2-lite | From 0e41efbfa4d9996efe96337c61efc9c305ea9d3b Mon Sep 17 00:00:00 2001 From: kimbwook Date: Sun, 6 Oct 2024 19:08:50 +0900 Subject: [PATCH 07/11] Add requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 0d474fe3f..b26b22d3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ sentence-transformers # for sentence transformer reranker FlagEmbedding # for flag embedding reranker llmlingua # for longllmlingua peft +voyageai # for voyageai reranker ### LlamaIndex ### llama-index>=0.11.0 From fbc2e254e7396964d995ca09e51a90adc9d91e7f Mon Sep 17 00:00:00 2001 From: kimbwook Date: Sun, 6 Oct 2024 19:17:55 +0900 Subject: [PATCH 08/11] Set api_key parameter --- autorag/nodes/passagereranker/voyageai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py index dc6a03424..22ca96c02 100644 --- a/autorag/nodes/passagereranker/voyageai.py +++ b/autorag/nodes/passagereranker/voyageai.py @@ -18,7 +18,7 @@ def __init__(self, project_dir: str, *args, **kwargs): "or directly set it on the config YAML file." 
) - self.voyage_client = voyageai.Client() + self.voyage_client = voyageai.Client(api_key=api_key) def __del__(self): del self.voyage_client From ad48a920efb94334dd7fe94c41d718357b34b27b Mon Sep 17 00:00:00 2001 From: kimbwook Date: Mon, 7 Oct 2024 23:05:18 +0900 Subject: [PATCH 09/11] apply async --- autorag/nodes/passagereranker/voyageai.py | 33 +++++++++++++---------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py index 22ca96c02..024ef628b 100644 --- a/autorag/nodes/passagereranker/voyageai.py +++ b/autorag/nodes/passagereranker/voyageai.py @@ -4,7 +4,7 @@ import voyageai from autorag.nodes.passagereranker.base import BasePassageReranker -from autorag.utils.util import result_to_dataframe +from autorag.utils.util import result_to_dataframe, get_event_loop, process_batch class VoyageAIReranker(BasePassageReranker): @@ -18,7 +18,7 @@ def __init__(self, project_dir: str, *args, **kwargs): "or directly set it on the config YAML file." ) - self.voyage_client = voyageai.Client(api_key=api_key) + self.voyage_client = voyageai.AsyncClient(api_key=api_key) def __del__(self): del self.voyage_client @@ -28,9 +28,10 @@ def __del__(self): def pure(self, previous_result: pd.DataFrame, *args, **kwargs): queries, contents, scores, ids = self.cast_to_run(previous_result) top_k = kwargs.pop("top_k") + batch = kwargs.pop("batch", 8) model = kwargs.pop("model", "rerank-2") truncation = kwargs.pop("truncation", True) - return self._pure(queries, contents, ids, top_k, model, truncation) + return self._pure(queries, contents, ids, top_k, model, batch, truncation) def _pure( self, @@ -39,6 +40,7 @@ def _pure( ids_list: List[List[str]], top_k: int, model: str = "rerank-2", + batch: int = 8, truncation: bool = True, ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: """ @@ -52,22 +54,25 @@ def _pure( :param model: The model name for VoyageAI rerank. 
You can choose between "rerank-2" and "rerank-2-lite". Default is "rerank-2". + :param batch: The number of queries to be processed in a batch :param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. :return: Tuple of lists containing the reranked contents, ids, and scores """ - content_result, id_result, score_result = zip( - *[ - voyageai_rerank_pure( - self.voyage_client, model, query, document, ids, top_k, truncation - ) - for query, document, ids in zip(queries, contents_list, ids_list) - ] - ) + tasks = [ + voyageai_rerank_pure( + self.client, query, contents, ids, top_k=top_k, model=model + ) + for query, contents, ids in zip(queries, contents_list, ids_list) + ] + loop = get_event_loop() + results = loop.run_until_complete(process_batch(tasks, batch)) + + content_result, id_result, score_result = zip(*results) - return content_result, id_result, score_result + return list(content_result), list(id_result), list(score_result) -def voyageai_rerank_pure( +async def voyageai_rerank_pure( voyage_client: voyageai.Client, model: str, query: str, @@ -88,7 +93,7 @@ def voyageai_rerank_pure( :param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. 
:return: Tuple of lists containing the reranked contents, ids, and scores """ - rerank_results = voyage_client.rerank( + rerank_results = await voyage_client.rerank( model=model, query=query, documents=documents, From 7d992f6d59c0ffd3cd4f75e7aa1136582ae5fe91 Mon Sep 17 00:00:00 2001 From: kimbwook Date: Mon, 7 Oct 2024 23:15:50 +0900 Subject: [PATCH 10/11] apply voyage_client --- autorag/nodes/passagereranker/voyageai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py index 024ef628b..619105865 100644 --- a/autorag/nodes/passagereranker/voyageai.py +++ b/autorag/nodes/passagereranker/voyageai.py @@ -60,7 +60,7 @@ def _pure( """ tasks = [ voyageai_rerank_pure( - self.client, query, contents, ids, top_k=top_k, model=model + self.voyage_client, model, query, contents, ids, top_k, truncation ) for query, contents, ids in zip(queries, contents_list, ids_list) ] @@ -73,7 +73,7 @@ def _pure( async def voyageai_rerank_pure( - voyage_client: voyageai.Client, + voyage_client: voyageai.AsyncClient, model: str, query: str, documents: List[str], From d808fe2416129e8fa773c50ca177bededf0f4c41 Mon Sep 17 00:00:00 2001 From: kimbwook Date: Mon, 7 Oct 2024 23:53:25 +0900 Subject: [PATCH 11/11] Add mock def --- autorag/nodes/passagereranker/voyageai.py | 4 +- .../passagereranker/test_voyageai_reranker.py | 79 +++++++++++++------ 2 files changed, 54 insertions(+), 29 deletions(-) diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py index 619105865..2868189d2 100644 --- a/autorag/nodes/passagereranker/voyageai.py +++ b/autorag/nodes/passagereranker/voyageai.py @@ -103,9 +103,7 @@ async def voyageai_rerank_pure( reranked_scores: List[float] = list( map(lambda x: x.relevance_score, rerank_results.results) ) - reranked_contents: List[str] = list( - map(lambda x: x.document, rerank_results.results) - ) indices = list(map(lambda x: 
x.index, rerank_results.results)) + reranked_contents: List[str] = list(map(lambda i: documents[i], indices)) reranked_ids: List[str] = list(map(lambda i: ids[i], indices)) return reranked_contents, reranked_ids, reranked_scores diff --git a/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py index 70c06d360..cdcf90aef 100644 --- a/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py +++ b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py @@ -2,6 +2,11 @@ import pytest +from collections import namedtuple +import voyageai +from voyageai.object.reranking import RerankingObject, RerankingResult +from voyageai.api_resources import VoyageResponse + import autorag from autorag.nodes.passagereranker import VoyageAIReranker from tests.autorag.nodes.passagereranker.test_passage_reranker_base import ( @@ -15,23 +20,43 @@ ) -def mock_voyageai_reranker_pure( - voyage_client, model, query, documents, ids, top_k, truncation +async def mock_voyageai_reranker_pure( + self, + query, + documents, + model, + top_k, + truncation, ): - if query == queries_example[0]: - return ( - [documents[1], documents[2], documents[0]][:top_k], - [ids[1], ids[2], ids[0]][:top_k], - [0.8, 0.2, 0.1][:top_k], - ) - elif query == queries_example[1]: - return ( - [documents[1], documents[0], documents[2]][:top_k], - [ids[1], ids[0], ids[2]][:top_k], - [0.8, 0.2, 0.1][:top_k], - ) - else: - raise ValueError(f"Unexpected query: {query}") + mock_documents = ["Document 1 content", "Document 2 content", "Document 3 content"] + + # Mock response data + mock_response_data = [ + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.2}, + {"index": 0, "relevance_score": 0.1}, + ] + + # Mock usage data + mock_usage = {"total_tokens": 100} + + # Create a mock VoyageResponse object + mock_response = VoyageResponse() + mock_response.data = [ + namedtuple("MockData", d.keys())(*d.values()) for d in 
mock_response_data + ] + mock_response.usage = namedtuple("MockUsage", mock_usage.keys())( + *mock_usage.values() + ) + + # Create an instance of RerankingObject using the mock data + object = RerankingObject(documents=mock_documents, response=mock_response) + + if top_k == 1: + object.results = [ + RerankingResult(index=1, document="nodonggunn", relevance_score=0.8) + ] + return object @pytest.fixture @@ -39,11 +64,7 @@ def voyageai_reranker_instance(): return VoyageAIReranker(project_dir, api_key="mock_api_key") -@patch.object( - autorag.nodes.passagereranker.voyageai, - "voyageai_rerank_pure", - mock_voyageai_reranker_pure, -) +@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure) def test_voyageai_reranker(voyageai_reranker_instance): top_k = 3 contents_result, id_result, score_result = voyageai_reranker_instance._pure( @@ -52,11 +73,17 @@ def test_voyageai_reranker(voyageai_reranker_instance): base_reranker_test(contents_result, id_result, score_result, top_k) -@patch.object( - autorag.nodes.passagereranker.voyageai, - "voyageai_rerank_pure", - mock_voyageai_reranker_pure, -) +@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure) +def test_voyageai_reranker_batch_one(voyageai_reranker_instance): + top_k = 1 + batch = 1 + contents_result, id_result, score_result = voyageai_reranker_instance._pure( + queries_example, contents_example, ids_example, top_k, batch=batch + ) + base_reranker_test(contents_result, id_result, score_result, top_k) + + +@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure) def test_voyageai_reranker_node(): top_k = 1 result_df = VoyageAIReranker.run_evaluator(