Skip to content

Commit

Permalink
Merge branch 'main' into Feature/#900
Browse files Browse the repository at this point in the history
  • Loading branch information
bwook00 authored Oct 30, 2024
2 parents ff9e0d7 + 17c94cc commit 676b5cc
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 9 deletions.
2 changes: 1 addition & 1 deletion sample_config/rag/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ vectordb:
client_type: persistent
embedding_model: openai_embed_3_large
collection_name: openai_embed_3_large
path: ${PROJECT_DIR}/data/chroma
path: ${PROJECT_DIR}/resources/chroma
node_lines:
- node_line_name: pre_retrieve_node_line # Arbitrary node line name
nodes:
Expand Down
65 changes: 57 additions & 8 deletions tests/autorag/nodes/queryexpansion/test_query_expansion_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
import pathlib
import shutil
import tempfile
from typing import List
from unittest.mock import patch

import pandas as pd
import pytest
from llama_index.core.base.llms.types import CompletionResponse
from llama_index.llms.openai import OpenAI

from autorag import embedding_models, MockEmbeddingRandom, LazyInit
from autorag.nodes.queryexpansion import QueryDecompose, HyDE
from autorag.nodes.queryexpansion.run import evaluate_one_query_expansion_node
from autorag.nodes.queryexpansion.run import run_query_expansion_node
from autorag.nodes.retrieval import BM25
from autorag.nodes.retrieval import BM25, VectorDB, HybridCC
from autorag.nodes.retrieval.vectordb import vectordb_ingest
from autorag.schema.metricinput import MetricInput
from autorag.utils.util import load_summary_file
from autorag.utils.util import load_summary_file, get_event_loop
from autorag.vectordb import load_all_vectordb_from_yaml

root_dir = pathlib.PurePath(
os.path.dirname(os.path.realpath(__file__))
Expand All @@ -39,29 +43,40 @@
@pytest.fixture
def node_line_dir():
    """Provision a throwaway AutoRAG project and yield a node-line directory.

    Copies the sample project into a fresh temporary directory, registers a
    1536-dim mock embedding model, ingests the sample corpus into every
    vector DB declared in ``resources/vectordb.yaml``, and creates the
    ``test_trial/test_node_line`` directory that the tests write into.
    The temporary directory (and everything under it) is removed on teardown.
    """
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as project_dir:
        # Register a mock embedding with OpenAI-large dimensionality (1536) so
        # the "mock_1536" model named in vectordb.yaml resolves without network.
        embedding_models["mock_1536"] = LazyInit(MockEmbeddingRandom, embed_dim=1536)

        sample_project_dir = os.path.join(resources_dir, "sample_project")
        # copy & paste all folders and files in sample_project folder
        shutil.copytree(sample_project_dir, project_dir, dirs_exist_ok=True)

        # PROJECT_DIR must be set before loading vectordb.yaml, which expands
        # ${PROJECT_DIR} in its `path:` entries.
        os.environ["PROJECT_DIR"] = project_dir

        # Set up the vector db
        corpus_df = pd.read_parquet(os.path.join(project_dir, "data", "corpus.parquet"))
        vectordbs = load_all_vectordb_from_yaml(
            os.path.join(project_dir, "resources", "vectordb.yaml"), project_dir
        )

        # vectordb_ingest is async; drive each ingestion to completion on one loop.
        loop = get_event_loop()
        for vectordb in vectordbs:
            loop.run_until_complete(vectordb_ingest(vectordb, corpus_df))

        # NOTE(review): variable says "trail" but the directory is "test_trial" —
        # presumably a typo in the local name only; the on-disk path is what matters.
        test_trail_dir = os.path.join(project_dir, "test_trial")
        os.makedirs(test_trail_dir)
        node_line_dir = os.path.join(test_trail_dir, "test_node_line")
        os.makedirs(node_line_dir)
        yield node_line_dir


def test_evaluate_one_query_expansion_node(node_line_dir):
def base_test_evaluate_one_query_expansion_node(
node_line_dir, retrieval_funcs: List, retrieval_params: List
):
project_dir = pathlib.PurePath(node_line_dir).parent.parent
qa_path = os.path.join(project_dir, "data", "qa.parquet")
previous_result = pd.read_parquet(qa_path)
sample_previous_result = previous_result.head(2)
sample_retrieval_gt = sample_previous_result["retrieval_gt"].tolist()

retrieval_funcs = [BM25, BM25]
retrieval_params = [
{"top_k": 1, "bm25_tokenizer": "gpt2"},
{"top_k": 2, "bm25_tokenizer": "gpt2"},
]
metric_inputs = [
MetricInput(queries=queries, retrieval_gt=ret_gt)
for queries, ret_gt in zip(sample_expanded_queries, sample_retrieval_gt)
Expand All @@ -81,6 +96,40 @@ def test_evaluate_one_query_expansion_node(node_line_dir):
assert len(best_result) == len(sample_expanded_queries)


def test_evaluate_one_query_expansion_node(node_line_dir):
    """Evaluate query expansion against two BM25 retrievers (top_k 1 and 2)."""
    # Pair each retrieval module with its params, then split into the two
    # parallel lists the shared helper expects.
    configs = [
        (BM25, {"top_k": 1, "bm25_tokenizer": "gpt2"}),
        (BM25, {"top_k": 2, "bm25_tokenizer": "gpt2"}),
    ]
    funcs = [module for module, _ in configs]
    params = [kwargs for _, kwargs in configs]
    base_test_evaluate_one_query_expansion_node(node_line_dir, funcs, params)


def test_evaluate_one_query_expansion_node_vectordb(node_line_dir):
    """Evaluate query expansion against two VectorDB retrievers and a HybridCC mix."""
    # Hybrid config fuses a BM25 module with the chroma_large vector DB.
    hybrid_params = {
        "top_k": 5,
        "target_modules": ("bm25", "vectordb"),
        "target_module_params": (
            {"top_k": 3, "bm25_tokenizer": "gpt2"},
            {"top_k": 3, "vectordb": "chroma_large"},
        ),
        "weight": 0.36,
    }
    funcs = [VectorDB, VectorDB, HybridCC]
    params = [
        {"top_k": 3, "vectordb": "chroma_large"},
        {"top_k": 5, "vectordb": "chroma_small"},
        hybrid_params,
    ]
    base_test_evaluate_one_query_expansion_node(node_line_dir, funcs, params)


def base_query_expansion_test(best_result, node_line_dir):
assert os.path.exists(os.path.join(node_line_dir, "query_expansion"))
expect_columns = [
Expand Down
13 changes: 13 additions & 0 deletions tests/resources/sample_project/resources/vectordb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Vector DB definitions for the sample test project.
# Both collections use a persistent Chroma client rooted at
# ${PROJECT_DIR}/resources/chroma (PROJECT_DIR is set by the test fixture).
vectordb:
  # Collection embedded with the default "mock" model.
  - name: chroma_small
    db_type: chroma
    client_type: persistent
    collection_name: chroma_small
    embedding_model: mock
    path: ${PROJECT_DIR}/resources/chroma
  # Collection embedded with the 1536-dim mock model ("mock_1536",
  # registered at test time to mirror openai_embed_3_large dimensionality).
  - name: chroma_large
    db_type: chroma
    client_type: persistent
    collection_name: chroma_large
    embedding_model: mock_1536
    path: ${PROJECT_DIR}/resources/chroma

0 comments on commit 676b5cc

Please sign in to comment.