Commit
add passage dependency filter at data creation (#751)
* add passage dependency feature

* add docs about passage dependent filter

* add api docs of passage dependency filter

---------

Co-authored-by: jeffrey <vkefhdl1@gmail.com>
vkehfdl1 and jeffrey authored Sep 27, 2024
1 parent c0b2f94 commit 9312b5b
Showing 5 changed files with 314 additions and 1 deletion.
88 changes: 88 additions & 0 deletions autorag/data/qa/filter/passage_dependency.py
@@ -0,0 +1,88 @@
from typing import Dict, List

from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse
from llama_index.llms.openai.utils import to_openai_message_dicts
from openai import AsyncClient
from pydantic import BaseModel

from autorag.data.qa.filter.prompt import FILTER_PROMPT


class Response(BaseModel):
is_passage_dependent: bool


async def passage_dependency_filter_openai(
row: Dict,
client: AsyncClient,
model_name: str = "gpt-4o-mini-2024-07-18",
lang: str = "en",
) -> bool:
"""
This will drop passage-dependent question rows.
Passage-dependent questions are questions that the answer will change depending on what passage you choose.
The passage-dependent questions will not be good for RAG evaluation, because any retrieval system can't find the right passage with passage-dependent question.
For example, when someone asks "What is the highest score according to the table?" the answer will be different depending on the table.
And what is the table? The retrieval system can't find the right passage with this question.
You can use this filter with the ` batch_filter ` function at `QA` class.
:param row: The row dict from QA dataset.
:param client: The OpenAI client.
:param model_name: The model name.
You have to use gpt-4o-2024-08-06 or gpt-4o-mini-2024-07-18.
:param lang: The supported language is en or ko.
:return: False if the row question is a passage-dependent question (to be filtered).
"""
assert "query" in row.keys(), "query column is not in the row."
system_prompt: List[ChatMessage] = FILTER_PROMPT["passage_dependency"][lang]
query = row["query"]
completion = await client.beta.chat.completions.parse(
model=model_name,
messages=to_openai_message_dicts(
system_prompt
+ [
ChatMessage(
role=MessageRole.USER,
content=f"Question: {query}\nIs this the question passage dependent?",
)
]
),
response_format=Response,
)
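# Keep the row (return True) only when the model judged the question NOT passage dependent.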
return not completion.choices[0].message.parsed.is_passage_dependent


async def passage_dependency_filter_llama_index(
row: Dict,
llm: BaseLLM,
lang: str = "en",
) -> bool:
"""
This will drop passage-dependent question rows.
Passage-dependent questions are questions that the answer will change depending on what passage you choose.
The passage-dependent questions will not be good for RAG evaluation, because any retrieval system can't find the right passage with passage-dependent question.
For example, when someone asks "What is the highest score according to the table?" the answer will be different depending on the table.
And what is the table? The retrieval system can't find the right passage with this question.
You can use this filter with the ` batch_filter ` function at `QA` class.
:param row: The row dict from QA dataset.
:param llm: The Llama index llm instance.
It will be good if you set max tokens to low for saving tokens.
:param lang: The supported language is en or ko.
:return: False if the row question is a passage-dependent question (to be filtered).
"""
assert "query" in row.keys(), "query column is not in the row."
system_prompt: List[ChatMessage] = FILTER_PROMPT["passage_dependency"][lang]
query = row["query"]
response: ChatResponse = await llm.achat(
messages=system_prompt
+ [
ChatMessage(
role=MessageRole.USER,
content=f"Question: {query}\nIs this the question passage dependent?",
)
]
)
result_str = response.message.content
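# The LLM answers with plain "True"/"False" text; keep the row only when "true" is absent,
# i.e. the question is not passage dependent.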
return "true" not in result_str.lower().strip()
34 changes: 33 additions & 1 deletion autorag/data/qa/filter/prompt.py
@@ -16,5 +16,37 @@
만약 해당 문장이 '모른다고' 답한 것이라면, True를 반환하세요. 그렇지 않다면 False를 반환하세요.""",
)
],
}
},
"passage_dependency": {
"en": [
ChatMessage(
role=MessageRole.SYSTEM,
content="""You are a classifier that recognize 'passage dependent' questions.
The 'passage dependent' is the question that the answer will be change depending on what passage you choose.
For example) 'What is the highest score according to the table?'
This sentence is the passage dependent question because the answer will be different depending on the table.
In contrast, the following sentence is not passage dependant.
'What is the highest score of the KBO baseball history in one game?'
'What is the capital of France?'
These sentences will have the same answer regardless of the passage.
Please return True if the input question is passage dependent. Else return False.""",
)
],
"ko": [
ChatMessage(
role=MessageRole.SYSTEM,
content="""당신은 '단락 의존' 질문을 인식하는 분류기입니다.
'단락 의존'이란 어떤 단락이 선택 되는지 따라 답이 달라지는 질문을 의미합니다.
예를 들어, '주어진 표에 따르면 가장 높은 점수는 무엇인가요?'라는 질문은 단락 의존 질문입니다. 왜냐하면 표가 어떤 것인지에 따라 그 답이 달라지기 때문입니다.
반면에, 다음 문장들은 단락 의존적이지 않습니다.
'KBO 야구 역사상 한 경기에서 가장 높은 점수는 무엇인가요?' 또는 '프랑스의 수도는 무엇인가요?'
이러한 문장은 단락에 관계 없이 동일한 답을 가집니다.
입력된 질문이 단락 의존적이라면 True를 반환하고, 그렇지 않으면 False를 반환하세요.""",
)
],
},
}
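As an aside, a hedged sketch of how an additional language could be registered under the same `FILTER_PROMPT` structure (the "ja" entry is hypothetical and not part of this commit):

```python
from llama_index.core.base.llms.types import ChatMessage, MessageRole

from autorag.data.qa.filter.prompt import FILTER_PROMPT

# Each filter name maps to a dict of language code -> list of system ChatMessages.
FILTER_PROMPT["passage_dependency"]["ja"] = [
    ChatMessage(
        role=MessageRole.SYSTEM,
        content="...",  # a Japanese version of the passage-dependency instructions would go here
    )
]
```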
8 changes: 8 additions & 0 deletions docs/source/api_spec/autorag.data.qa.filter.rst
@@ -12,6 +12,14 @@ autorag.data.qa.filter.dontknow module
:undoc-members:
:show-inheritance:

autorag.data.qa.filter.passage\_dependency module
-------------------------------------------------

.. automodule:: autorag.data.qa.filter.passage_dependency
:members:
:undoc-members:
:show-inheritance:

autorag.data.qa.filter.prompt module
------------------------------------

40 changes: 40 additions & 0 deletions docs/source/data_creation/qa_creation/filter.md
@@ -72,3 +72,43 @@ filtered_qa = qa.batch_filter(dontknow_filter_llama_index, llm=llm, lang="en").m
lambda df: df.reset_index(drop=True) # reset index
)
```


## 2. Passage Dependent Filtering

Passage-dependent questions are those whose answer varies depending on the passage or context selected.
Even with the best retrieval system, the exact passage cannot be found from a passage-dependent question alone.

Since it is nearly impossible to assign a ground-truth passage to a passage-dependent question,
keeping such questions decreases the discriminative power of the evaluation dataset.

So it is good to filter out passage-dependent questions after generating the QA dataset.
We use an LLM as the filtering model.

- OpenAI

```python
from openai import AsyncOpenAI
from autorag.data.qa.schema import QA
from autorag.data.qa.filter.passage_dependency import passage_dependency_filter_openai

client = AsyncOpenAI()
en_qa = QA(en_qa_df)
result_en_qa = en_qa.batch_filter(
passage_dependency_filter_openai, client=client, lang="en"
).map(lambda df: df.reset_index(drop=True))
```

- LlamaIndex

```python
from autorag.data.qa.schema import QA
from llama_index.llms.openai import OpenAI
from autorag.data.qa.filter.passage_dependency import passage_dependency_filter_llama_index

llm = OpenAI(temperature=0, model="gpt-4o-mini")
en_qa = QA(en_qa_df)
result_en_qa = en_qa.batch_filter(
passage_dependency_filter_llama_index, llm=llm, lang="en"
).map(lambda df: df.reset_index(drop=True))
```
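
Note that `passage_dependency_filter_openai` uses OpenAI structured outputs (`client.beta.chat.completions.parse`), so the model must support them; the docstring recommends gpt-4o-2024-08-06 or gpt-4o-mini-2024-07-18 (the default).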
145 changes: 145 additions & 0 deletions tests/autorag/data/qa/filter/test_passage_dependency.py
@@ -0,0 +1,145 @@
import time
from unittest.mock import patch

import pandas as pd
from llama_index.llms.openai import OpenAI
from openai import AsyncOpenAI
from openai.types.chat import (
ParsedChatCompletion,
ParsedChoice,
ParsedChatCompletionMessage,
)
import openai.resources.beta.chat

from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole
from autorag.data.qa.filter.passage_dependency import (
passage_dependency_filter_openai,
passage_dependency_filter_llama_index,
Response,
)
from autorag.data.qa.schema import QA
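
# The first two queries in each dataframe below are passage dependent and are expected to be
# filtered out; the last two are expected to be kept.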

en_qa_df = pd.DataFrame(
{
"query": [
"What is the most significant discovery mentioned in the research paper?",
"What was the ruling in the case described in this legal brief?",
"What is the passage reranker role in the Advanced RAG system?",
"Who is the latest 30 homerun 30 steal record owner in the KBO league?",
]
}
)

ko_qa_df = pd.DataFrame(
{
"query": [
"연구 논문에서 언급된 가장 중요한 발견은 무엇입니까?",
"이 판결문에 기술된 사건의 판결은 무엇이었습니까?",
"Advanced RAG 시스템에서 리랭커 역할은 무엇입니까?",
"KBO 리그에서 가장 최근 30홈런 30도루를 기록한 선수는 누구입니까?",
]
}
)

expected_df_en = pd.DataFrame(
{
"query": [
"What is the passage reranker role in the Advanced RAG system?",
"Who is the latest 30 homerun 30 steal record owner in the KBO league?",
]
}
)

expected_df_ko = pd.DataFrame(
{
"query": [
"Advanced RAG 시스템에서 리랭커 역할은 무엇입니까?",
"KBO 리그에서 가장 최근 30홈런 30도루를 기록한 선수는 누구입니까?",
]
}
)

passage_dependent_response = [
"What is the most significant discovery mentioned in the research paper?",
"What was the ruling in the case described in this legal brief?",
"연구 논문에서 언급된 가장 중요한 발견은 무엇입니까?",
"이 판결문에 기술된 사건의 판결은 무엇이었습니까?",
]


async def mock_openai_response(*args, **kwargs) -> ParsedChatCompletion[Response]:
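# Mimic the OpenAI structured-output response: flag the row as passage dependent
# iff the query parsed from the user prompt is in passage_dependent_response.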
user_prompt = kwargs["messages"][1]["content"]
return ParsedChatCompletion(
id="test_id",
choices=[
ParsedChoice(
finish_reason="stop",
index=0,
message=ParsedChatCompletionMessage(
parsed=Response(
is_passage_dependent=user_prompt.split("\n")[0]
.split(":")[1]
.strip()
in passage_dependent_response
),
role="assistant",
),
)
],
created=int(time.time()),
model="gpt-4o-mini-2024-07-18",
object="chat.completion",
)


async def mock_llama_index_response(*args, **kwargs) -> ChatResponse:
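# Mimic the LlamaIndex chat response: answer "True"/"False" as a string depending on
# whether the query parsed from the user prompt is in passage_dependent_response.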
user_prompt = kwargs["messages"][1].content
return ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
content=str(
user_prompt.split("\n")[0].split(":")[1].strip()
in passage_dependent_response
),
)
)


@patch.object(
openai.resources.beta.chat.completions.AsyncCompletions,
"parse",
mock_openai_response,
)
def test_passage_dependency_filter_openai():
client = AsyncOpenAI()
en_qa = QA(en_qa_df)
result_en_qa = en_qa.batch_filter(
passage_dependency_filter_openai, client=client, lang="en"
).map(lambda df: df.reset_index(drop=True))
pd.testing.assert_frame_equal(result_en_qa.data, expected_df_en)

ko_qa = QA(ko_qa_df)
result_ko_qa = ko_qa.batch_filter(
passage_dependency_filter_openai, client=client, lang="ko"
).map(lambda df: df.reset_index(drop=True))
pd.testing.assert_frame_equal(result_ko_qa.data, expected_df_ko)


@patch.object(
OpenAI,
"achat",
mock_llama_index_response,
)
def test_passage_dependency_filter_llama_index():
llm = OpenAI(temperature=0, model="gpt-4o-mini")
en_qa = QA(en_qa_df)
result_en_qa = en_qa.batch_filter(
passage_dependency_filter_llama_index, llm=llm, lang="en"
).map(lambda df: df.reset_index(drop=True))
pd.testing.assert_frame_equal(result_en_qa.data, expected_df_en)

ko_qa = QA(ko_qa_df)
result_ko_qa = ko_qa.batch_filter(
passage_dependency_filter_llama_index, llm=llm, lang="ko"
).map(lambda df: df.reset_index(drop=True))
pd.testing.assert_frame_equal(result_ko_qa.data, expected_df_ko)
