Add Passage Filter node and its first module 'similarity_threshold_cutoff' (#290)

* make the first module of passage_filter, threshold_cutoff

* add to passage filter __init__.py

* return at least one

* add docs annotation of similarity_threshold_cutoff

* original function test done

* input embedding model name, not the embedding model instance

* passage filter node test

* add test cases for when retrieved result counts are not consistent

* make passage filter run.py

* make run.py for passage filter

* test new node passage filter

* add passage filter node to full.yaml

* add passage filter docs

* add passage filter at index.md and README.md

* add passage filter rst docs

* fix test code for fixed simple.yaml

---------

Co-authored-by: jeffrey <vkefhdl1@gmail.com>
vkehfdl1 and jeffrey authored Apr 10, 2024
1 parent 2651c74 commit 6490afa
Showing 22 changed files with 493 additions and 17 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -242,8 +242,9 @@ node_lines:
| [Query_Expansion](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/query_expansion.html) | [Query_Decompose](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/query_decompose.html)<br/>[HyDE](https://marker-inc-korea.github.io/AutoRAG/nodes/query_expansion/hyde.html) |
| [Retrieval](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/retrieval.html) | [BM25](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/bm25.html)<br/>[VectorDB (choose embedding model)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/vectordb.html)<br/>[Hybrid with rrf (reciprocal rank fusion)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/hybrid_rrf.html)<br/>[Hybrid with cc (convex combination)](https://marker-inc-korea.github.io/AutoRAG/nodes/retrieval/hybrid_cc.html) |
| [Passage_Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/passage_reranker.html) | [UPR](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/upr.html)<br/>[Tart](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/tart.html)<br/>[MonoT5](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/monot5.html)<br/>[Ko-reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/koreranker.html)<br/>[cohere_reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/cohere.html)<br/>[RankGPT](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/rankgpt.html)<br/>[Jina Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/jina_reranker.html)<br/>[Sentence Transformer Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/setence_transformer_reranker.html)<br/>[Colbert Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/colbert.html)<br/>[Flag Embedding Reranker](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_reranker/flag_embedding_reranker.html) |
-| [Passage_Compressor](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/passage_compressor.html) | [Tree Summarize](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/tree_summarize.html)<br/>[Long Context Reorder](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/long_context_reorder.html) |
-| [Prompt Maker](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/prompt_maker.html) | [Default Prompt Maker (f-string)](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/fstring.html) |
+| [Passage Filter](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_filter/passage_filter.html) | [similarity_threshold_cutoff](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_filter/similarity_threshold_cutoff.html) |
+| [Passage_Compressor](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/passage_compressor.html) | [Tree Summarize](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_compressor/tree_summarize.html) |
+| [Prompt Maker](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/prompt_maker.html) | [Default Prompt Maker (f-string)](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/fstring.html)<br/>[Long Context Reorder](https://marker-inc-korea.github.io/AutoRAG/nodes/prompt_maker/long_context_reorder.html) |
| [Generator](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/generator.html) | [llama_index llm](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/llama_index_llm.html)<br/>[vllm](https://marker-inc-korea.github.io/AutoRAG/nodes/generator/vllm.html) |

# 🛣Roadmap
1 change: 1 addition & 0 deletions autorag/nodes/passagefilter/__init__.py
@@ -0,0 +1 @@
from .threshold_cutoff import similarity_threshold_cutoff
41 changes: 41 additions & 0 deletions autorag/nodes/passagefilter/base.py
@@ -0,0 +1,41 @@
import functools
from pathlib import Path
from typing import Union, Tuple, List

import pandas as pd

from autorag.utils import result_to_dataframe, validate_qa_dataset


# For now, this decorator is the same as the passage reranker node's.
def passage_filter_node(func):
@functools.wraps(func)
@result_to_dataframe(['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])
def wrapper(
project_dir: Union[str, Path],
previous_result: pd.DataFrame,
*args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
validate_qa_dataset(previous_result)

# find queries columns
assert "query" in previous_result.columns, "previous_result must have query column."
queries = previous_result["query"].tolist()

# find contents_list columns
assert "retrieved_contents" in previous_result.columns, "previous_result must have retrieved_contents column."
contents = previous_result["retrieved_contents"].tolist()

# find scores columns
assert "retrieve_scores" in previous_result.columns, "previous_result must have retrieve_scores column."
scores = previous_result["retrieve_scores"].tolist()

# find ids columns
assert "retrieved_ids" in previous_result.columns, "previous_result must have retrieved_ids column."
ids = previous_result["retrieved_ids"].tolist()

filtered_contents, filtered_ids, filtered_scores = func(queries=queries, contents_list=contents,
scores_list=scores, ids_list=ids, *args, **kwargs)

return filtered_contents, filtered_ids, filtered_scores

return wrapper
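
For context, a minimal sketch of a module that would plug into this decorator; the module name and filtering logic below are hypothetical, not part of this commit. The decorator passes queries, contents_list, scores_list, and ids_list as keyword arguments and expects the filtered (contents, ids, scores) triple in return.

# Hypothetical example module, shown only to illustrate the decorator contract.
from typing import List, Tuple

from autorag.nodes.passagefilter.base import passage_filter_node


@passage_filter_node
def keep_top_half(queries: List[str], contents_list: List[List[str]],
                  scores_list: List[List[float]], ids_list: List[List[str]],
                  ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
    # queries is unused here, but the decorator always passes it.
    def top_half(contents, ids, scores):
        # Keep the better-scored half of the list, but always at least one item.
        k = max(1, len(scores) // 2)
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        return ([contents[i] for i in order], [ids[i] for i in order],
                [scores[i] for i in order])

    triples = [top_half(c, i, s) for c, i, s in zip(contents_list, ids_list, scores_list)]
    contents, ids, scores = map(list, zip(*triples))
    return contents, ids, scores

Once decorated, such a module is called as keep_top_half(project_dir=..., previous_result=df) and returns a dataframe with the three filtered columns.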
77 changes: 77 additions & 0 deletions autorag/nodes/passagefilter/run.py
@@ -0,0 +1,77 @@
import os
import pathlib
from typing import List, Callable, Dict

import pandas as pd

from autorag.nodes.retrieval.run import evaluate_retrieval_node
from autorag.strategy import measure_speed, filter_by_threshold, select_best_average


def run_passage_filter_node(modules: List[Callable],
module_params: List[Dict],
previous_result: pd.DataFrame,
node_line_dir: str,
strategies: Dict) -> pd.DataFrame:
"""
Run evaluation and select the best module among passage filter node results.
:param modules: Passage filter modules to run.
:param module_params: Passage filter module parameters.
:param previous_result: Previous result dataframe.
Could be retrieval, reranker, passage filter modules result.
It means it must contain 'query', 'retrieved_contents', 'retrieved_ids', 'retrieve_scores' columns.
:param node_line_dir: This node line's directory.
:param strategies: Strategies for the passage filter node.
In this node, we use 'retrieval_f1', 'retrieval_recall' and 'retrieval_precision'.
You can skip evaluation when you use only one module with a single set of parameters.
:return: The best result dataframe with previous result columns.
"""
if not os.path.exists(node_line_dir):
os.makedirs(node_line_dir)
project_dir = pathlib.PurePath(node_line_dir).parent.parent
retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()

results, execution_times = zip(*map(lambda task: measure_speed(
task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))
average_times = list(map(lambda x: x / len(results[0]), execution_times))

# run metrics before filtering
if strategies.get('metrics') is None:
raise ValueError("You must at least one metrics for passage_reranker evaluation.")
results = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, strategies.get('metrics')), results))

# save results to folder
save_dir = os.path.join(node_line_dir, "passage_filter") # node name
if not os.path.exists(save_dir):
os.makedirs(save_dir)
filepaths = list(map(lambda x: os.path.join(save_dir, f'{x}.parquet'), range(len(modules))))
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet
filenames = list(map(lambda x: os.path.basename(x), filepaths))

summary_df = pd.DataFrame({
'filename': filenames,
'module_name': list(map(lambda module: module.__name__, modules)),
'module_params': module_params,
'execution_time': average_times,
**{f'passage_filter_{metric}': list(map(lambda result: result[metric].mean(), results)) for metric in
strategies.get('metrics')},
})

# filter by strategies
if strategies.get('speed_threshold') is not None:
results, filenames = filter_by_threshold(results, average_times, strategies['speed_threshold'], filenames)
selected_result, selected_filename = select_best_average(results, strategies.get('metrics'), filenames)
selected_result = selected_result.rename(columns={
metric_name: f'passage_filter_{metric_name}' for metric_name in strategies['metrics']})
previous_result = previous_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])
best_result = pd.concat([previous_result, selected_result], axis=1)

# add 'is_best' column to summary file
summary_df['is_best'] = summary_df['filename'] == selected_filename

# save files
summary_df.to_csv(os.path.join(save_dir, "summary.csv"), index=False)
best_result.to_parquet(os.path.join(save_dir, f'best_{os.path.splitext(selected_filename)[0]}.parquet'),
index=False)
return best_result
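
A minimal sketch of how this runner might be invoked; the directory layout, threshold, and strategy values below are assumptions for illustration only.

# Hypothetical invocation; paths and parameter values are assumed, not from this commit.
import pandas as pd

from autorag.nodes.passagefilter import similarity_threshold_cutoff
from autorag.nodes.passagefilter.run import run_passage_filter_node

# The runner resolves project_dir as node_line_dir/../..,
# so node_line_dir must sit two levels below the project directory.
previous_result = pd.read_parquet("my_project/trial_0/retrieve_node_line/retrieval/best_0.parquet")
best = run_passage_filter_node(
    modules=[similarity_threshold_cutoff],
    module_params=[{"threshold": 0.85}],
    previous_result=previous_result,
    node_line_dir="my_project/trial_0/post_retrieve_node_line",
    strategies={"metrics": ["retrieval_f1", "retrieval_recall"]},
)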
82 changes: 82 additions & 0 deletions autorag/nodes/passagefilter/threshold_cutoff.py
@@ -0,0 +1,82 @@
import itertools
from typing import List, Tuple, Optional

import numpy as np
import torch.cuda

from autorag import embedding_models
from autorag.evaluate.metric.util import calculate_cosine_similarity
from autorag.nodes.passagefilter.base import passage_filter_node
from autorag.utils.util import reconstruct_list


@passage_filter_node
def similarity_threshold_cutoff(queries: List[str], contents_list: List[List[str]],
scores_list: List[List[float]], ids_list: List[List[str]],
threshold: float, embedding_model: Optional[str] = None,
batch: int = 128,
) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
"""
Re-calculate each content's similarity with the query and filter out the contents that are below the threshold.
If all contents are filtered out, keep only the single content with the highest similarity.
This is a filter, so it does not overwrite the scores:
the returned scores are the original retrieval scores, not the newly computed query-content similarities.
:param queries: The list of queries to use for filtering
:param contents_list: The list of lists of contents to filter
:param scores_list: The list of lists of scores retrieved
:param ids_list: The list of lists of ids retrieved
:param threshold: The threshold to cut off
:param embedding_model: The name of the embedding model to use for calculating similarity
Default is OpenAIEmbedding.
:param batch: The number of queries to be processed in a batch
Default is 128.
:return: Tuple of lists containing the filtered contents, ids, and scores
"""
if embedding_model is None:
embedding_model = embedding_models['openai']
else:
embedding_model = embedding_models[embedding_model]

# Embedding using batch
embedding_model.embed_batch_size = batch
query_embeddings = embedding_model.get_text_embedding_batch(queries)

content_lengths = list(map(len, contents_list))
content_embeddings_flatten = embedding_model.get_text_embedding_batch(list(
itertools.chain.from_iterable(contents_list)))
content_embeddings = reconstruct_list(content_embeddings_flatten, content_lengths)

remain_indices = list(map(lambda x: similarity_threshold_cutoff_pure(x[0], x[1], threshold),
zip(query_embeddings, content_embeddings)))

remain_content_list = list(map(lambda c, idx: [c[i] for i in idx], contents_list, remain_indices))
remain_scores_list = list(map(lambda s, idx: [s[i] for i in idx], scores_list, remain_indices))
remain_ids_list = list(map(lambda _id, idx: [_id[i] for i in idx], ids_list, remain_indices))

del embedding_model
if torch.cuda.is_available():
torch.cuda.empty_cache()

return remain_content_list, remain_ids_list, remain_scores_list


def similarity_threshold_cutoff_pure(query_embedding: List[float],
content_embeddings: List[List[float]],
threshold: float) -> List[int]:
"""
Return the indices of the contents to keep.
Always return at least one index, even if nothing passes the threshold.
:param query_embedding: Query embedding
:param content_embeddings: Each content embedding
:param threshold: The threshold to cut off
:return: Indices of the contents to keep
"""

similarities = np.array(list(map(lambda x: calculate_cosine_similarity(query_embedding, x),
content_embeddings)))
result = np.where(similarities >= threshold)[0].tolist()
if len(result) > 0:
return result
return [np.argmax(similarities)]
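
As a quick sanity check of the pure function, a toy example with made-up embeddings and thresholds, assuming calculate_cosine_similarity computes standard cosine similarity:

# Toy example; the embeddings and thresholds are made up for illustration.
query_emb = [1.0, 0.0]
content_embs = [
    [1.0, 0.0],  # cosine similarity 1.0 -> passes a 0.7 threshold
    [0.0, 1.0],  # cosine similarity 0.0 -> filtered out
    [0.6, 0.8],  # cosine similarity 0.6 -> filtered out at 0.7
]
print(similarity_threshold_cutoff_pure(query_emb, content_embs, 0.7))  # [0]
# Even with an unreachable threshold, the highest-similarity index is kept.
print(similarity_threshold_cutoff_pure(query_emb, content_embs, 1.5))  # still index 0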