add passage dependency filter at data creation (#751)

* add passage dependency feature
* add docs about passage dependent filter
* add api docs of passage dependency filter

Co-authored-by: jeffrey <vkefhdl1@gmail.com>
Showing 5 changed files with 314 additions and 1 deletion.
autorag/data/qa/filter/passage_dependency.py (new file, 88 additions)
from typing import Dict, List

from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse
from llama_index.llms.openai.utils import to_openai_message_dicts
from openai import AsyncClient
from pydantic import BaseModel

from autorag.data.qa.filter.prompt import FILTER_PROMPT


class Response(BaseModel):
    is_passage_dependent: bool


async def passage_dependency_filter_openai(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-mini-2024-07-18",
    lang: str = "en",
) -> bool:
    """
    This will drop passage-dependent question rows.
    Passage-dependent questions are questions whose answer changes depending on which passage you choose.
    Such questions are not good for RAG evaluation, because no retrieval system can find the right passage from the question alone.
    For example, when someone asks "What is the highest score according to the table?", the answer differs by table.
    And which table? A retrieval system cannot identify the right passage from this question.
    You can use this filter with the `batch_filter` function of the `QA` class.

    :param row: The row dict from the QA dataset.
    :param client: The OpenAI client.
    :param model_name: The model name.
        You have to use gpt-4o-2024-08-06 or gpt-4o-mini-2024-07-18,
        since structured outputs are required.
    :param lang: The supported language, en or ko.
    :return: False if the row's question is passage-dependent (i.e., the row is to be filtered out).
    """
    assert "query" in row.keys(), "query column is not in the row."
    system_prompt: List[ChatMessage] = FILTER_PROMPT["passage_dependency"][lang]
    query = row["query"]
    # Structured outputs: the model must answer with JSON matching the Response schema.
    completion = await client.beta.chat.completions.parse(
        model=model_name,
        messages=to_openai_message_dicts(
            system_prompt
            + [
                ChatMessage(
                    role=MessageRole.USER,
                    content=f"Question: {query}\nIs this the question passage dependent?",
                )
            ]
        ),
        response_format=Response,
    )
    # batch_filter keeps rows for which the filter returns True, so invert the flag.
    return not completion.choices[0].message.parsed.is_passage_dependent


async def passage_dependency_filter_llama_index(
    row: Dict,
    llm: BaseLLM,
    lang: str = "en",
) -> bool:
    """
    This will drop passage-dependent question rows.
    Passage-dependent questions are questions whose answer changes depending on which passage you choose.
    Such questions are not good for RAG evaluation, because no retrieval system can find the right passage from the question alone.
    For example, when someone asks "What is the highest score according to the table?", the answer differs by table.
    And which table? A retrieval system cannot identify the right passage from this question.
    You can use this filter with the `batch_filter` function of the `QA` class.

    :param row: The row dict from the QA dataset.
    :param llm: The LlamaIndex LLM instance.
        It is best to set max tokens low to save tokens, since only "True" or "False" is needed.
    :param lang: The supported language, en or ko.
    :return: False if the row's question is passage-dependent (i.e., the row is to be filtered out).
    """
    assert "query" in row.keys(), "query column is not in the row."
    system_prompt: List[ChatMessage] = FILTER_PROMPT["passage_dependency"][lang]
    query = row["query"]
    response: ChatResponse = await llm.achat(
        messages=system_prompt
        + [
            ChatMessage(
                role=MessageRole.USER,
                content=f"Question: {query}\nIs this the question passage dependent?",
            )
        ]
    )
    # Plain-text reply: treat any answer containing "true" as passage-dependent.
    result_str = response.message.content
    return "true" not in result_str.lower().strip()
(The diffs of the other three changed files, the documentation, are not shown here.)
tests/autorag/data/qa/filter/test_passage_dependency.py (new file, 145 additions)
import time
from unittest.mock import patch

import pandas as pd
from llama_index.llms.openai import OpenAI
from openai import AsyncOpenAI
from openai.types.chat import (
    ParsedChatCompletion,
    ParsedChoice,
    ParsedChatCompletionMessage,
)
import openai.resources.beta.chat

from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole
from autorag.data.qa.filter.passage_dependency import (
    passage_dependency_filter_openai,
    passage_dependency_filter_llama_index,
    Response,
)
from autorag.data.qa.schema import QA

en_qa_df = pd.DataFrame(
    {
        "query": [
            "What is the most significant discovery mentioned in the research paper?",
            "What was the ruling in the case described in this legal brief?",
            "What is the passage reranker role in the Advanced RAG system?",
            "Who is the latest 30 homerun 30 steal record owner in the KBO league?",
        ]
    }
)

# Korean counterparts of the English queries above.
ko_qa_df = pd.DataFrame(
    {
        "query": [
            "연구 논문에서 언급된 가장 중요한 발견은 무엇입니까?",
            "이 판결문에 기술된 사건의 판결은 무엇이었습니까?",
            "Advanced RAG 시스템에서 리랭커 역할은 무엇입니까?",
            "KBO 리그에서 가장 최근 30홈런 30도루를 기록한 선수는 누구입니까?",
        ]
    }
)

# Only the passage-independent questions should survive the filter.
expected_df_en = pd.DataFrame(
    {
        "query": [
            "What is the passage reranker role in the Advanced RAG system?",
            "Who is the latest 30 homerun 30 steal record owner in the KBO league?",
        ]
    }
)

expected_df_ko = pd.DataFrame(
    {
        "query": [
            "Advanced RAG 시스템에서 리랭커 역할은 무엇입니까?",
            "KBO 리그에서 가장 최근 30홈런 30도루를 기록한 선수는 누구입니까?",
        ]
    }
)

# Questions the mocks flag as passage-dependent (their answers hinge on
# an unnamed paper, legal brief, etc.).
passage_dependent_response = [
    "What is the most significant discovery mentioned in the research paper?",
    "What was the ruling in the case described in this legal brief?",
    "연구 논문에서 언급된 가장 중요한 발견은 무엇입니까?",
    "이 판결문에 기술된 사건의 판결은 무엇이었습니까?",
]


async def mock_openai_response(*args, **kwargs) -> ParsedChatCompletion[Response]:
    # Recover the question from the "Question: ..." line of the user prompt and
    # mark it passage-dependent iff it appears in the list above.
    user_prompt = kwargs["messages"][1]["content"]
    return ParsedChatCompletion(
        id="test_id",
        choices=[
            ParsedChoice(
                finish_reason="stop",
                index=0,
                message=ParsedChatCompletionMessage(
                    parsed=Response(
                        is_passage_dependent=user_prompt.split("\n")[0]
                        .split(":")[1]
                        .strip()
                        in passage_dependent_response
                    ),
                    role="assistant",
                ),
            )
        ],
        created=int(time.time()),
        model="gpt-4o-mini-2024-07-18",
        object="chat.completion",
    )


async def mock_llama_index_response(*args, **kwargs) -> ChatResponse:
    # Same lookup as above, but returned as the plain "True"/"False" text
    # that passage_dependency_filter_llama_index parses.
    user_prompt = kwargs["messages"][1].content
    return ChatResponse(
        message=ChatMessage(
            role=MessageRole.ASSISTANT,
            content=str(
                user_prompt.split("\n")[0].split(":")[1].strip()
                in passage_dependent_response
            ),
        )
    )


@patch.object(
    openai.resources.beta.chat.completions.AsyncCompletions,
    "parse",
    mock_openai_response,
)
def test_passage_dependency_filter_openai():
    client = AsyncOpenAI()
    en_qa = QA(en_qa_df)
    result_en_qa = en_qa.batch_filter(
        passage_dependency_filter_openai, client=client, lang="en"
    ).map(lambda df: df.reset_index(drop=True))
    pd.testing.assert_frame_equal(result_en_qa.data, expected_df_en)

    ko_qa = QA(ko_qa_df)
    result_ko_qa = ko_qa.batch_filter(
        passage_dependency_filter_openai, client=client, lang="ko"
    ).map(lambda df: df.reset_index(drop=True))
    pd.testing.assert_frame_equal(result_ko_qa.data, expected_df_ko)


@patch.object(
    OpenAI,
    "achat",
    mock_llama_index_response,
)
def test_passage_dependency_filter_llama_index():
    llm = OpenAI(temperature=0, model="gpt-4o-mini")
    en_qa = QA(en_qa_df)
    result_en_qa = en_qa.batch_filter(
        passage_dependency_filter_llama_index, llm=llm, lang="en"
    ).map(lambda df: df.reset_index(drop=True))
    pd.testing.assert_frame_equal(result_en_qa.data, expected_df_en)

    ko_qa = QA(ko_qa_df)
    result_ko_qa = ko_qa.batch_filter(
        passage_dependency_filter_llama_index, llm=llm, lang="ko"
    ).map(lambda df: df.reset_index(drop=True))
    pd.testing.assert_frame_equal(result_ko_qa.data, expected_df_ko)
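
For the LlamaIndex variant, the call shape is the same outside the mocks. A hedged sketch against a real model (not part of the commit; it assumes an OpenAI key is configured, reuses ko_qa_df from the tests above, and the model name and max_tokens value are illustrative):

from llama_index.llms.openai import OpenAI

from autorag.data.qa.schema import QA
from autorag.data.qa.filter.passage_dependency import passage_dependency_filter_llama_index

# The filter docstring advises a low max_tokens, since only "True"/"False"
# is parsed out of the reply.
llm = OpenAI(temperature=0, model="gpt-4o-mini", max_tokens=8)

filtered_ko = QA(ko_qa_df).batch_filter(
    passage_dependency_filter_llama_index, llm=llm, lang="ko"
)
print(filtered_ko.data)  # the two passage-dependent Korean questions should be dropped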