From 17c94cc1a269a975f3e22196bfb0d18be2b99118 Mon Sep 17 00:00:00 2001 From: "Jeffrey (Dongkyu) Kim" Date: Wed, 30 Oct 2024 14:32:15 +0900 Subject: [PATCH] Add test code for query expansion with vectordb (#902) * edit typo at full.yaml * add test code for vectordb use in query expansion (also hybrid retrieval) --------- Co-authored-by: jeffrey --- sample_config/rag/full.yaml | 2 +- .../test_query_expansion_run.py | 65 ++++++++++++++++--- .../sample_project/resources/vectordb.yaml | 13 ++++ 3 files changed, 71 insertions(+), 9 deletions(-) create mode 100644 tests/resources/sample_project/resources/vectordb.yaml diff --git a/sample_config/rag/full.yaml b/sample_config/rag/full.yaml index fcc25cf71..92fd748ad 100644 --- a/sample_config/rag/full.yaml +++ b/sample_config/rag/full.yaml @@ -4,7 +4,7 @@ vectordb: client_type: persistent embedding_model: openai_embed_3_large collection_name: openai_embed_3_large - path: ${PROJECT_DIR}/data/chroma + path: ${PROJECT_DIR}/resources/chroma node_lines: - node_line_name: pre_retrieve_node_line # Arbitrary node line name nodes: diff --git a/tests/autorag/nodes/queryexpansion/test_query_expansion_run.py b/tests/autorag/nodes/queryexpansion/test_query_expansion_run.py index a0dc0b13e..6ea5802b3 100644 --- a/tests/autorag/nodes/queryexpansion/test_query_expansion_run.py +++ b/tests/autorag/nodes/queryexpansion/test_query_expansion_run.py @@ -2,6 +2,7 @@ import pathlib import shutil import tempfile +from typing import List from unittest.mock import patch import pandas as pd @@ -9,12 +10,15 @@ from llama_index.core.base.llms.types import CompletionResponse from llama_index.llms.openai import OpenAI +from autorag import embedding_models, MockEmbeddingRandom, LazyInit from autorag.nodes.queryexpansion import QueryDecompose, HyDE from autorag.nodes.queryexpansion.run import evaluate_one_query_expansion_node from autorag.nodes.queryexpansion.run import run_query_expansion_node -from autorag.nodes.retrieval import BM25 +from autorag.nodes.retrieval import BM25, VectorDB, HybridCC +from autorag.nodes.retrieval.vectordb import vectordb_ingest from autorag.schema.metricinput import MetricInput -from autorag.utils.util import load_summary_file +from autorag.utils.util import load_summary_file, get_event_loop +from autorag.vectordb import load_all_vectordb_from_yaml root_dir = pathlib.PurePath( os.path.dirname(os.path.realpath(__file__)) @@ -39,10 +43,24 @@ @pytest.fixture def node_line_dir(): with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as project_dir: + embedding_models["mock_1536"] = LazyInit(MockEmbeddingRandom, embed_dim=1536) + sample_project_dir = os.path.join(resources_dir, "sample_project") # copy & paste all folders and files in sample_project folder shutil.copytree(sample_project_dir, project_dir, dirs_exist_ok=True) + os.environ["PROJECT_DIR"] = project_dir + + # Set up the vector db + corpus_df = pd.read_parquet(os.path.join(project_dir, "data", "corpus.parquet")) + vectordbs = load_all_vectordb_from_yaml( + os.path.join(project_dir, "resources", "vectordb.yaml"), project_dir + ) + + loop = get_event_loop() + for vectordb in vectordbs: + loop.run_until_complete(vectordb_ingest(vectordb, corpus_df)) + test_trail_dir = os.path.join(project_dir, "test_trial") os.makedirs(test_trail_dir) node_line_dir = os.path.join(test_trail_dir, "test_node_line") @@ -50,18 +68,15 @@ def node_line_dir(): yield node_line_dir -def test_evaluate_one_query_expansion_node(node_line_dir): +def base_test_evaluate_one_query_expansion_node( + node_line_dir, retrieval_funcs: List, retrieval_params: List +): project_dir = pathlib.PurePath(node_line_dir).parent.parent qa_path = os.path.join(project_dir, "data", "qa.parquet") previous_result = pd.read_parquet(qa_path) sample_previous_result = previous_result.head(2) sample_retrieval_gt = sample_previous_result["retrieval_gt"].tolist() - retrieval_funcs = [BM25, BM25] - retrieval_params = [ - {"top_k": 1, "bm25_tokenizer": "gpt2"}, - {"top_k": 2, "bm25_tokenizer": "gpt2"}, - ] metric_inputs = [ MetricInput(queries=queries, retrieval_gt=ret_gt) for queries, ret_gt in zip(sample_expanded_queries, sample_retrieval_gt) @@ -81,6 +96,40 @@ def test_evaluate_one_query_expansion_node(node_line_dir): assert len(best_result) == len(sample_expanded_queries) +def test_evaluate_one_query_expansion_node(node_line_dir): + retrieval_funcs = [BM25, BM25] + retrieval_params = [ + {"top_k": 1, "bm25_tokenizer": "gpt2"}, + {"top_k": 2, "bm25_tokenizer": "gpt2"}, + ] + base_test_evaluate_one_query_expansion_node( + node_line_dir, retrieval_funcs, retrieval_params + ) + + +def test_evaluate_one_query_expansion_node_vectordb(node_line_dir): + retrieval_funcs = [VectorDB, VectorDB, HybridCC] + retrieval_params = [ + {"top_k": 3, "vectordb": "chroma_large"}, + {"top_k": 5, "vectordb": "chroma_small"}, + { + "top_k": 5, + "target_modules": ("bm25", "vectordb"), + "target_module_params": ( + {"top_k": 3, "bm25_tokenizer": "gpt2"}, + { + "top_k": 3, + "vectordb": "chroma_large", + }, + ), + "weight": 0.36, + }, + ] + base_test_evaluate_one_query_expansion_node( + node_line_dir, retrieval_funcs, retrieval_params + ) + + def base_query_expansion_test(best_result, node_line_dir): assert os.path.exists(os.path.join(node_line_dir, "query_expansion")) expect_columns = [ diff --git a/tests/resources/sample_project/resources/vectordb.yaml b/tests/resources/sample_project/resources/vectordb.yaml new file mode 100644 index 000000000..8e0020f37 --- /dev/null +++ b/tests/resources/sample_project/resources/vectordb.yaml @@ -0,0 +1,13 @@ +vectordb: + - name: chroma_small + db_type: chroma + client_type: persistent + collection_name: chroma_small + embedding_model: mock + path: ${PROJECT_DIR}/resources/chroma + - name: chroma_large + db_type: chroma + client_type: persistent + collection_name: chroma_large + embedding_model: mock_1536 + path: ${PROJECT_DIR}/resources/chroma