Add basic dataset schema for new 'beta' version of data creation (#663)
* Add corpus, QA, and dataset schema
* Change QA creation and corpus to the legacy module
* Revert "change qa creation and corpus to the legacy" (reverts commit e88b57f)
* Add the first v2 data creation function: answer generation
* Rename to llama_index_gen_gt.py to prevent a duplicate module name
* Add OpenAI structured-output answer generation
* Mock the parsed response in tests

Co-authored-by: jeffrey <vkefhdl1@gmail.com>
Showing 9 changed files with 314 additions and 0 deletions.
@@ -0,0 +1,3 @@
# This is the v2 version, the next version of data creation.
# The legacy (v1) version will be deprecated in AutoRAG version 0.3.
# The legacy (v1) and the new (v2) data creation are not compatible with each other.
Empty file.
@@ -0,0 +1,38 @@ | ||
import itertools | ||
from typing import Dict | ||
|
||
|
||
from llama_index.core.base.llms.base import BaseLLM | ||
from llama_index.core.base.llms.types import MessageRole, ChatMessage | ||
|
||
from autorag.data.beta.generation_gt.prompt import ( | ||
concise_answer_system_prompt, | ||
basic_answer_system_prompt, | ||
) | ||
|
||
|
||
async def make_gen_gt_llama_index(row: Dict, llm: BaseLLM, system_prompt: str): | ||
retrieval_gt_contents = list( | ||
itertools.chain.from_iterable(row["retrieval_gt_contents"]) | ||
) | ||
query = row["query"] | ||
passage_str = "\n".join(retrieval_gt_contents) | ||
user_prompt = f"Text:\n<|text_start|>\n{passage_str}\n<|text_end|>\n\nQuestion:\n{query}\n\nAnswer:" | ||
|
||
response = await llm.achat( | ||
messages=[ | ||
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt), | ||
ChatMessage(role=MessageRole.USER, content=user_prompt), | ||
], | ||
temperature=0.0, | ||
) | ||
row["generation_gt"] = response.message.content | ||
return row | ||
|
||
|
||
async def make_concise_gen_gt(row: Dict, llm: BaseLLM) -> Dict: | ||
return await make_gen_gt_llama_index(row, llm, concise_answer_system_prompt) | ||
|
||
|
||
async def make_basic_gen_gt(row: Dict, llm: BaseLLM) -> Dict: | ||
return await make_gen_gt_llama_index(row, llm, basic_answer_system_prompt) |
autorag/data/beta/generation_gt/openai_gen_gt.py (path inferred from the test imports below)

@@ -0,0 +1,75 @@
import itertools
from typing import Dict

from openai import AsyncClient
from pydantic import BaseModel

from autorag.data.beta.generation_gt.prompt import (
    basic_answer_system_prompt,
    concise_answer_system_prompt,
)


class Response(BaseModel):
    answer: str


async def make_gen_gt_openai(
    row: Dict,
    client: AsyncClient,
    system_prompt: str,
    model_name: str = "gpt-4o-2024-08-06",
):
    retrieval_gt_contents = list(
        itertools.chain.from_iterable(row["retrieval_gt_contents"])
    )
    query = row["query"]
    passage_str = "\n".join(retrieval_gt_contents)
    user_prompt = f"Text:\n<|text_start|>\n{passage_str}\n<|text_end|>\n\nQuestion:\n{query}\n\nAnswer:"

    completion = await client.beta.chat.completions.parse(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.0,
        response_format=Response,
    )
    response: Response = completion.choices[0].message.parsed
    row["generation_gt"] = response.answer
    return row


async def make_concise_gen_gt(
    row: Dict, client: AsyncClient, model_name: str = "gpt-4o-2024-08-06"
):
    """
    Generate a concise generation_gt using OpenAI Structured Outputs to prevent parsing errors.
    The answer is concise, so it is generally a single word or a short phrase.

    :param row: The input row of the QA dataframe.
    :param client: The OpenAI async client.
    :param model_name: A model name that supports structured output.
        It has to be "gpt-4o-2024-08-06" or "gpt-4o-mini-2024-07-18".
    :return: The output row of the QA dataframe with "generation_gt" added.
    """
    return await make_gen_gt_openai(
        row, client, concise_answer_system_prompt, model_name
    )


async def make_basic_gen_gt(
    row: Dict, client: AsyncClient, model_name: str = "gpt-4o-2024-08-06"
):
    """
    Generate a basic generation_gt using OpenAI Structured Outputs to prevent parsing errors.
    It generates a "basic" answer from a simple prompt.

    :param row: The input row of the QA dataframe.
    :param client: The OpenAI async client.
    :param model_name: A model name that supports structured output.
        It has to be "gpt-4o-2024-08-06" or "gpt-4o-mini-2024-07-18".
    :return: The output row of the QA dataframe with "generation_gt" added.
    """
    return await make_gen_gt_openai(row, client, basic_answer_system_prompt, model_name)
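A hedged usage sketch (not part of this diff): the OpenAI path plugs into QA.batch_apply the same way. It assumes OPENAI_API_KEY is set in the environment and that qa_df is a frame shaped like the test fixture defined further below.

from openai import AsyncOpenAI

from autorag.schema.data import QA
from autorag.data.beta.generation_gt.openai_gen_gt import make_concise_gen_gt

client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
qa = QA(qa_df)  # qa_df: any frame with "retrieval_gt_contents" and "query" columns
result = qa.batch_apply(
    lambda row: make_concise_gen_gt(row, client, model_name="gpt-4o-mini-2024-07-18")
)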
autorag/data/beta/generation_gt/prompt.py (path inferred from the imports above)

@@ -0,0 +1,9 @@
concise_answer_system_prompt = """You are an AI assistant that answers the given question from the provided evidence text.
Find the evidence about the question in the given text, and write a proper answer to the question.
Your answer has to be concise and relevant to the question.
Do not write a verbose answer; make it super clear.
It does not have to be a full sentence. The answer can be a single word or a short phrase.
"""
basic_answer_system_prompt = """You are an AI assistant that answers the given question from the provided evidence text.
Find the evidence about the question in the given text, and write a proper answer to the question.
"""
autorag/schema/data.py (path inferred from the test imports below)

@@ -0,0 +1,65 @@
import logging
from typing import Callable, Optional, Dict, Awaitable

import pandas as pd

from autorag.utils.util import process_batch, get_event_loop

logger = logging.getLogger("AutoRAG")


class Corpus:
    def __init__(self, corpus_df: Optional[pd.DataFrame] = None):
        self.data = corpus_df

    def batch_apply(
        self, fn: Callable[[Dict], Awaitable[Dict]], batch_size: int = 32
    ) -> "Corpus":
        corpus_dicts = self.data.to_dict(orient="records")
        loop = get_event_loop()
        tasks = [fn(corpus_dict) for corpus_dict in corpus_dicts]
        results = loop.run_until_complete(process_batch(tasks, batch_size))
        return Corpus(pd.DataFrame(results))

    def map(self, fn: Callable[[pd.DataFrame], pd.DataFrame]) -> "Corpus":
        return Corpus(fn(self.data))


class QA:
    def __init__(self, qa_df: Optional[pd.DataFrame] = None):
        self.data = qa_df

    def batch_apply(
        self, fn: Callable[[Dict], Awaitable[Dict]], batch_size: int = 32
    ) -> "QA":
        qa_dicts = self.data.to_dict(orient="records")
        loop = get_event_loop()
        tasks = [fn(qa_dict) for qa_dict in qa_dicts]
        results = loop.run_until_complete(process_batch(tasks, batch_size))
        return QA(pd.DataFrame(results))

    def map(self, fn: Callable[[pd.DataFrame], pd.DataFrame]) -> "QA":
        return QA(fn(self.data))


class Dataset:
    def __init__(self, qa: QA = None, corpus: Corpus = None):
        self.qa = qa
        self.corpus = corpus

    def map(self, fn: Callable) -> "Dataset":
        qa_df, corpus_df = fn(self.qa, self.corpus)
        return Dataset(qa_df, corpus_df)

    def flatmap(self, fn: Callable) -> "Dataset":
        dataset = fn(self.qa, self.corpus)
        if not isinstance(dataset, Dataset):
            logger.warning(f"Expected Dataset, got {type(dataset)}")
            return Dataset(None, None)
        return dataset

    def qa_map(self, fn: Callable) -> "Dataset":
        qa_df = fn(self.qa)
        return Dataset(qa_df, self.corpus)

    def corpus_map(self, fn: Callable) -> "Dataset":
        corpus_df = fn(self.corpus)
        return Dataset(self.qa, corpus_df)
tests/autorag/data/beta/generation_gt/base_test_generation_gt.py (31 additions)
@@ -0,0 +1,31 @@
import pandas as pd

from autorag.schema.data import QA

passage = """NewJeans (뉴진스) is a 5-member girl group under ADOR and HYBE Labels.
The members consist of Minji, Hanni, Danielle, Haerin, and Hyein.
They released their debut single “Attention” on July 22, 2022,
followed by their debut extended play, New Jeans, which was released on August 1, 2022."""
question = "How many members are in NewJeans?"
qa_df = pd.DataFrame(
    {
        "qid": ["jax"],
        "retrieval_gt": [[["havertz"]]],
        "retrieval_gt_contents": [[[passage]]],
        "query": [question],
    }
)


def check_generation_gt(result_qa: QA):
    assert isinstance(result_qa, QA)
    assert isinstance(result_qa.data, pd.DataFrame)
    assert set(result_qa.data.columns) == {
        "qid",
        "retrieval_gt",
        "retrieval_gt_contents",
        "query",
        "generation_gt",
    }
    assert len(result_qa.data) == len(qa_df)
    assert result_qa.data["generation_gt"].iloc[0]
tests/autorag/data/beta/generation_gt/test_llama_index_gen_gt.py (25 additions)
@@ -0,0 +1,25 @@
from llama_index.core.llms import MockLLM

from autorag.data.beta.generation_gt.llama_index_gen_gt import (
    make_concise_gen_gt,
    make_basic_gen_gt,
)
from autorag.schema.data import QA
from tests.autorag.data.beta.generation_gt.base_test_generation_gt import (
    qa_df,
    check_generation_gt,
)

llm = MockLLM()


def test_make_concise_gen_gt():
    qa = QA(qa_df)
    result_qa = qa.batch_apply(lambda row: make_concise_gen_gt(row, llm))
    check_generation_gt(result_qa)


def test_make_basic_gen_gt():
    qa = QA(qa_df)
    result_qa = qa.batch_apply(lambda row: make_basic_gen_gt(row, llm))
    check_generation_gt(result_qa)
tests/autorag/data/beta/generation_gt/test_openai_gen_gt.py (68 additions)
@@ -0,0 +1,68 @@
import time
from unittest.mock import patch

import openai.resources.beta.chat
from openai import AsyncOpenAI
from openai.types.chat import (
    ParsedChatCompletion,
    ParsedChoice,
    ParsedChatCompletionMessage,
)

from autorag.data.beta.generation_gt.openai_gen_gt import (
    make_concise_gen_gt,
    make_basic_gen_gt,
    Response,
)
from autorag.schema.data import QA
from tests.autorag.data.beta.generation_gt.base_test_generation_gt import (
    qa_df,
    check_generation_gt,
)

client = AsyncOpenAI()


async def mock_gen_gt_response(*args, **kwargs) -> ParsedChatCompletion[Response]:
    return ParsedChatCompletion(
        id="test_id",
        choices=[
            ParsedChoice(
                finish_reason="stop",
                index=0,
                message=ParsedChatCompletionMessage(
                    parsed=Response(answer="mock answer"),
                    role="assistant",
                ),
            )
        ],
        created=int(time.time()),
        model="gpt-4o-mini-2024-07-18",
        object="chat.completion",
    )


@patch.object(
    openai.resources.beta.chat.completions.AsyncCompletions,
    "parse",
    mock_gen_gt_response,
)
def test_make_concise_gen_gt():
    qa = QA(qa_df)
    result_qa = qa.batch_apply(
        lambda row: make_concise_gen_gt(
            row, client, model_name="gpt-4o-mini-2024-07-18"
        )
    )
    check_generation_gt(result_qa)


@patch.object(
    openai.resources.beta.chat.completions.AsyncCompletions,
    "parse",
    mock_gen_gt_response,
)
def test_make_basic_gen_gt():
    qa = QA(qa_df)
    result_qa = qa.batch_apply(lambda row: make_basic_gen_gt(row, client))
    check_generation_gt(result_qa)