Add Passage Filter node and its first module 'similarity_threshold_cutoff' (#290)

* make the first module of passage_filter, threshold_cutoff
* add to passage filter __init__.py
* return at least one
* add docs annotation of similarity_threshold_cutoff
* original function test done
* input the embedding model name, not the model itself
* passage filter node test
* add test cases for retrieved result counts that are not consistent
* make passage filter run.py
* make run.py for passage filter
* test new node passage filter
* add passage filter node to full.yaml
* add passage filter docs
* add passage filter at index.md and README.md
* add passage filter rst docs
* fix test code for fixed simple.yaml

Co-authored-by: jeffrey <vkefhdl1@gmail.com>
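Since the commit registers the new node in the sample full.yaml, here is a minimal sketch of what such an entry could look like. This is an assumption based on AutoRAG's node-line config layout, not an excerpt from the commit; the node line name and threshold value are illustrative:

node_lines:
  - node_line_name: post_retrieve_node_line  # hypothetical name
    nodes:
      - node_type: passage_filter
        strategy:
          metrics: [retrieval_f1, retrieval_recall, retrieval_precision]
        modules:
          - module_type: similarity_threshold_cutoff
            threshold: 0.85  # illustrative cutoff value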
Showing 22 changed files with 493 additions and 17 deletions.
autorag/nodes/passagefilter/__init__.py
@@ -0,0 +1 @@
from .threshold_cutoff import similarity_threshold_cutoff
autorag/nodes/passagefilter/base.py
@@ -0,0 +1,41 @@
import functools
from pathlib import Path
from typing import Union, Tuple, List

import pandas as pd

from autorag.utils import result_to_dataframe, validate_qa_dataset


# Decorator that turns a passage filter function into a runnable node.
def passage_filter_node(func):
    @functools.wraps(func)
    @result_to_dataframe(['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])
    def wrapper(
            project_dir: Union[str, Path],
            previous_result: pd.DataFrame,
            *args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        validate_qa_dataset(previous_result)

        # find query column
        assert "query" in previous_result.columns, "previous_result must have query column."
        queries = previous_result["query"].tolist()

        # find retrieved_contents column
        assert "retrieved_contents" in previous_result.columns, "previous_result must have retrieved_contents column."
        contents = previous_result["retrieved_contents"].tolist()

        # find retrieve_scores column
        assert "retrieve_scores" in previous_result.columns, "previous_result must have retrieve_scores column."
        scores = previous_result["retrieve_scores"].tolist()

        # find retrieved_ids column
        assert "retrieved_ids" in previous_result.columns, "previous_result must have retrieved_ids column."
        ids = previous_result["retrieved_ids"].tolist()

        filtered_contents, filtered_ids, filtered_scores = func(queries=queries, contents_list=contents,
                                                                scores_list=scores, ids_list=ids, *args, **kwargs)

        return filtered_contents, filtered_ids, filtered_scores

    return wrapper
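As an illustration of how the decorator is consumed, below is a minimal sketch of a hypothetical module built on passage_filter_node. The keep_top_one filter is invented for this example; a real previous_result must also pass validate_qa_dataset, i.e. come from an actual AutoRAG QA dataset:

import pandas as pd

from autorag.nodes.passagefilter.base import passage_filter_node


@passage_filter_node
def keep_top_one(queries, contents_list, scores_list, ids_list):
    # Keep only the single highest-scoring passage per query.
    top = [scores.index(max(scores)) for scores in scores_list]
    contents = [[c[i]] for c, i in zip(contents_list, top)]
    ids = [[x[i]] for x, i in zip(ids_list, top)]
    scores = [[s[i]] for s, i in zip(scores_list, top)]
    return contents, ids, scores


# The decorated module is called with the node interface and returns a dataframe
# with 'retrieved_contents', 'retrieved_ids', and 'retrieve_scores' columns:
# result_df = keep_top_one(project_dir="./project", previous_result=qa_result_df)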
autorag/nodes/passagefilter/run.py
@@ -0,0 +1,77 @@
import os
import pathlib
from typing import List, Callable, Dict

import pandas as pd

from autorag.nodes.retrieval.run import evaluate_retrieval_node
from autorag.strategy import measure_speed, filter_by_threshold, select_best_average


def run_passage_filter_node(modules: List[Callable],
                            module_params: List[Dict],
                            previous_result: pd.DataFrame,
                            node_line_dir: str,
                            strategies: Dict) -> pd.DataFrame:
    """
    Run evaluation and select the best module among passage filter node results.

    :param modules: Passage filter modules to run.
    :param module_params: Passage filter module parameters.
    :param previous_result: Previous result dataframe.
        It could be the result of a retrieval, reranker, or passage filter module,
        so it must contain 'query', 'retrieved_contents', 'retrieved_ids', and 'retrieve_scores' columns.
    :param node_line_dir: This node line's directory.
    :param strategies: Strategies for the passage filter node.
        In this node, we use 'retrieval_f1', 'retrieval_recall' and 'retrieval_precision'.
        You can skip evaluation when you use only one module with one set of parameters.
    :return: The best result dataframe with the previous result columns.
    """
    if not os.path.exists(node_line_dir):
        os.makedirs(node_line_dir)
    project_dir = pathlib.PurePath(node_line_dir).parent.parent
    retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()

    results, execution_times = zip(*map(lambda task: measure_speed(
        task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))
    average_times = list(map(lambda x: x / len(results[0]), execution_times))

    # run metrics before filtering
    if strategies.get('metrics') is None:
        raise ValueError("You must provide at least one metric for passage_filter evaluation.")
    results = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, strategies.get('metrics')), results))

    # save results to folder
    save_dir = os.path.join(node_line_dir, "passage_filter")  # node name
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    filepaths = list(map(lambda x: os.path.join(save_dir, f'{x}.parquet'), range(len(modules))))
    list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths)))  # execute save to parquet
    filenames = list(map(lambda x: os.path.basename(x), filepaths))

    summary_df = pd.DataFrame({
        'filename': filenames,
        'module_name': list(map(lambda module: module.__name__, modules)),
        'module_params': module_params,
        'execution_time': average_times,
        **{f'passage_filter_{metric}': list(map(lambda result: result[metric].mean(), results)) for metric in
           strategies.get('metrics')},
    })

    # filter by strategies
    if strategies.get('speed_threshold') is not None:
        results, filenames = filter_by_threshold(results, average_times, strategies['speed_threshold'], filenames)
    selected_result, selected_filename = select_best_average(results, strategies.get('metrics'), filenames)
    selected_result = selected_result.rename(columns={
        metric_name: f'passage_filter_{metric_name}' for metric_name in strategies['metrics']})
    previous_result = previous_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])
    best_result = pd.concat([previous_result, selected_result], axis=1)

    # add 'is_best' column to summary file
    summary_df['is_best'] = summary_df['filename'] == selected_filename

    # save files
    summary_df.to_csv(os.path.join(save_dir, "summary.csv"), index=False)
    best_result.to_parquet(os.path.join(save_dir, f'best_{os.path.splitext(selected_filename)[0]}.parquet'),
                           index=False)
    return best_result
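A rough usage sketch of the runner (the dataframe, paths, and parameter values are hypothetical; node_line_dir must sit two directory levels below a project directory that contains data/qa.parquet):

from autorag.nodes.passagefilter import similarity_threshold_cutoff
from autorag.nodes.passagefilter.run import run_passage_filter_node

best_df = run_passage_filter_node(
    modules=[similarity_threshold_cutoff],
    module_params=[{'threshold': 0.85}],
    previous_result=retrieval_result_df,  # output dataframe of the previous retrieval/reranker node
    node_line_dir='./project/0/post_retrieve_node_line',  # hypothetical: parent.parent is the project dir
    strategies={'metrics': ['retrieval_f1', 'retrieval_recall', 'retrieval_precision'],
                'speed_threshold': 10},
)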
autorag/nodes/passagefilter/threshold_cutoff.py
@@ -0,0 +1,82 @@
import itertools
from typing import List, Tuple, Optional

import numpy as np
import torch.cuda

from autorag import embedding_models
from autorag.evaluate.metric.util import calculate_cosine_similarity
from autorag.nodes.passagefilter.base import passage_filter_node
from autorag.utils.util import reconstruct_list


@passage_filter_node
def similarity_threshold_cutoff(queries: List[str], contents_list: List[List[str]],
                                scores_list: List[List[float]], ids_list: List[List[str]],
                                threshold: float, embedding_model: Optional[str] = None,
                                batch: int = 128,
                                ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
    """
    Re-calculate each content's similarity with the query and filter out the contents that are below the threshold.
    If all contents are filtered out, keep only the single content with the highest similarity.
    This is a filter and does not override scores: the returned scores are the original retrieval scores,
    not the recomputed query-content similarities.

    :param queries: The list of queries to use for filtering.
    :param contents_list: The list of lists of contents to filter.
    :param scores_list: The list of lists of retrieved scores.
    :param ids_list: The list of lists of retrieved ids.
    :param threshold: The similarity threshold to cut off at.
    :param embedding_model: The embedding model to use for calculating similarity.
        Default is OpenAIEmbedding.
    :param batch: The number of queries to be processed in a batch.
        Default is 128.
    :return: Tuple of lists containing the filtered contents, ids, and scores.
    """
    if embedding_model is None:
        embedding_model = embedding_models['openai']
    else:
        embedding_model = embedding_models[embedding_model]

    # Embed queries and contents in batches
    embedding_model.embed_batch_size = batch
    query_embeddings = embedding_model.get_text_embedding_batch(queries)

    content_lengths = list(map(len, contents_list))
    content_embeddings_flatten = embedding_model.get_text_embedding_batch(list(
        itertools.chain.from_iterable(contents_list)))
    content_embeddings = reconstruct_list(content_embeddings_flatten, content_lengths)

    remain_indices = list(map(lambda x: similarity_threshold_cutoff_pure(x[0], x[1], threshold),
                              zip(query_embeddings, content_embeddings)))

    remain_content_list = list(map(lambda c, idx: [c[i] for i in idx], contents_list, remain_indices))
    remain_scores_list = list(map(lambda s, idx: [s[i] for i in idx], scores_list, remain_indices))
    remain_ids_list = list(map(lambda _id, idx: [_id[i] for i in idx], ids_list, remain_indices))

    # Free the embedding model and any GPU memory it held
    del embedding_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return remain_content_list, remain_ids_list, remain_scores_list


def similarity_threshold_cutoff_pure(query_embedding: List[float],
                                     content_embeddings: List[List[float]],
                                     threshold: float) -> List[int]:
    """
    Return the indices of the contents to keep.
    Return at least one index even if nothing clears the threshold.

    :param query_embedding: The query embedding.
    :param content_embeddings: Each content's embedding.
    :param threshold: The similarity threshold to cut off at.
    :return: Indices of the contents to keep.
    """
    similarities = np.array(list(map(lambda x: calculate_cosine_similarity(query_embedding, x),
                                     content_embeddings)))
    result = np.where(similarities >= threshold)[0].tolist()
    if len(result) > 0:
        return result
    # Fallback: keep the single most similar content
    return [np.argmax(similarities)]
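A quick sanity check of the pure cutoff logic with hand-made 2-d embeddings (toy values chosen so the cosine similarities are easy to verify by hand):

from autorag.nodes.passagefilter.threshold_cutoff import similarity_threshold_cutoff_pure

query_embedding = [1.0, 0.0]
content_embeddings = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
# cosine similarities against the query: 1.0, 0.0, ~0.707

print(similarity_threshold_cutoff_pure(query_embedding, content_embeddings, threshold=0.5))   # indices [0, 2]
print(similarity_threshold_cutoff_pure(query_embedding, content_embeddings, threshold=0.99))  # index [0] only
print(similarity_threshold_cutoff_pure(query_embedding, content_embeddings, threshold=1.5))   # nothing passes; argmax fallback keeps index 0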