refactor existing metric python files with input schema #667

Merged · 52 commits · Sep 11, 2024
Changes from 1 commit
Commits (52)
fa666c6
add Payload data schema
Eastsidegunn Aug 26, 2024
e2f85ad
edit Payload data schema
Eastsidegunn Sep 1, 2024
c9cc00a
edit test code with payload schema
Eastsidegunn Sep 1, 2024
096da36
edit generation.py with payload schema
Eastsidegunn Sep 1, 2024
54156f2
edit retrieval.py with payload schema
Eastsidegunn Sep 1, 2024
e34ddc9
edit retrieval_contents.py with payload schema
Eastsidegunn Sep 1, 2024
f9898a7
edit generator/run.py with payload schema
Eastsidegunn Sep 1, 2024
804e788
edit queryexpansion/run.py with payload schema
Eastsidegunn Sep 1, 2024
8a904ab
edit retrieval/run.py with payload schema
Eastsidegunn Sep 1, 2024
f534b0c
edit passagereranker/run.py with payload schema
Eastsidegunn Sep 1, 2024
16a3bdc
edit passagefilter/run.py with payload schema
Eastsidegunn Sep 1, 2024
223abd8
edit passagepromptmaker/run.py with payload schema
Eastsidegunn Sep 1, 2024
88222c1
Merge branch 'main' into Feature/#655
Eastsidegunn Sep 2, 2024
c046118
edit query_expansion with Payload
Eastsidegunn Sep 2, 2024
3d38e22
edit promptmaker with Payload
Eastsidegunn Sep 2, 2024
91ec088
edit passageaugmenter with Payload
Eastsidegunn Sep 2, 2024
ae6d3d6
rename Payload to MetricInput
Eastsidegunn Sep 2, 2024
ac1e56e
Merge branch 'main' into Feature/#655
vkehfdl1 Sep 4, 2024
abaf150
rename Payload to MetricInput
Eastsidegunn Sep 7, 2024
513a56a
edit test.py with MetricInput
Eastsidegunn Sep 7, 2024
f172034
add is_fileds_notnone MetricInput method
Eastsidegunn Sep 7, 2024
127da1c
edit run, evaluation code for pre-retrieval, retrieval process
Eastsidegunn Sep 7, 2024
d5921f6
edit run, evaluation code for post process
Eastsidegunn Sep 7, 2024
a16eff1
Merge branch 'refs/heads/main' into Feature/#655
Eastsidegunn Sep 7, 2024
349566f
Merge remote-tracking branch 'origin/Feature/#655' into Feature/#655
Eastsidegunn Sep 7, 2024
bd9884c
edit is_fields_notnone to provide flexibility in input data types.
Eastsidegunn Sep 7, 2024
206028a
add metric_input from_dataframe
Eastsidegunn Sep 7, 2024
b85a1d0
add metric_input from_dataframe annotation
Eastsidegunn Sep 7, 2024
f007e19
make check_list method react np.ndarray data format
Eastsidegunn Sep 8, 2024
5bdb764
modify run.py with MetricInput
Eastsidegunn Sep 8, 2024
9353425
Merge branch 'main' into Feature/#655
vkehfdl1 Sep 8, 2024
8b28c02
Renamed MetricInput class variable from ('gt_contents', 'retireval_i…
Eastsidegunn Sep 8, 2024
f3e3dc0
Renamed first parameter of is_fields_notnone method included in Metr…
Eastsidegunn Sep 8, 2024
c4ce922
make _check_list method of MetricInput schema to staticmethod
Eastsidegunn Sep 8, 2024
bfdb91d
rename autorag_metric_loop
Eastsidegunn Sep 8, 2024
d8fdf54
modify metric fields_to_check params with pre-commit: MetricInput var…
Eastsidegunn Sep 8, 2024
4ae5989
modify setattr function to more intutive code style. (set attribution…
Eastsidegunn Sep 8, 2024
872050e
modify code to use autorag.utils.util to_list func
Eastsidegunn Sep 8, 2024
58a0b6e
Merge remote-tracking branch 'origin/Feature/#655' into Feature/#655
Eastsidegunn Sep 8, 2024
a5cb92b
rename variance name 'payloads' to metric_inputs in whole code
Eastsidegunn Sep 8, 2024
832ca97
Renamed MetricInput class variable from ('gt_contents', 'retireval_i…
Eastsidegunn Sep 8, 2024
f28a47d
refactor MetricInput from_dataframe method
Eastsidegunn Sep 9, 2024
42583f6
Refactor: Extract type_checks dictionary from MetricInput class
Eastsidegunn Sep 9, 2024
8eb65bf
Merge branch 'main' into Feature/#655
vkehfdl1 Sep 10, 2024
82e2f8b
Modify the type hint of retrieval_gt_contents from List[str] to List[…
Eastsidegunn Sep 11, 2024
084c06b
Merge branch 'refs/heads/main' into Feature/#655
Eastsidegunn Sep 11, 2024
1023f26
Merge remote-tracking branch 'origin/Feature/#655' into Feature/#655
Eastsidegunn Sep 11, 2024
98aa386
Merge branch 'main' into Feature/#655
vkehfdl1 Sep 11, 2024
ec61c83
- Fixed test_single_token_f1 bug introduced in last commit(changing r…
Eastsidegunn Sep 11, 2024
b4664bb
Merge remote-tracking branch 'origin/Feature/#655' into Feature/#655
Eastsidegunn Sep 11, 2024
5a61095
- Fixed test_single_token_f1 bug introduced in last commit(changing r…
Eastsidegunn Sep 11, 2024
e608857
Merge branch 'main' into Feature/#655
vkehfdl1 Sep 11, 2024
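For context, the end state of this PR is a single MetricInput schema that every metric consumes. Below is a minimal sketch of that schema as it can be inferred from the commit messages and the diffs in this commit: the field names match the tests, but the defaults, type hints, and method bodies are assumptions rather than the repository's exact implementation, and a later commit (8b28c02) renames some of these fields.

```python
# Sketch only: field names taken from the tests in this commit; defaults,
# type hints, and method bodies are assumptions, not AutoRAG's exact code.
from dataclasses import dataclass, fields
from typing import Any, List, Optional

import numpy as np
import pandas as pd


@dataclass
class MetricInput:
    generated_texts: Optional[str] = None
    generation_gt: Optional[List[str]] = None
    gt_contents: Optional[List[str]] = None  # renamed in a later commit (8b28c02)
    retrieval_contents: Optional[List[str]] = None
    retrieval_gt: Optional[List[List[str]]] = None
    retrieval_ids: Optional[List[str]] = None

    @staticmethod
    def _check_list(value: Any) -> bool:
        # Commit f007e19 makes this accept np.ndarray as well as list.
        if isinstance(value, np.ndarray):
            value = value.tolist()
        return isinstance(value, list) and len(value) > 0

    def is_fields_notnone(self, fields_to_check: List[str]) -> bool:
        # A metric runs on an input only when every field it needs is usable;
        # otherwise the wrapper emits None for that row (see the tests below).
        for name in fields_to_check:
            value = getattr(self, name)
            if value is None:
                return False
            if isinstance(value, (list, np.ndarray)) and not self._check_list(value):
                return False
        return True

    @classmethod
    def from_dataframe(cls, df: pd.DataFrame) -> List["MetricInput"]:
        # One MetricInput per row, keeping only columns that match a field.
        known = {f.name for f in fields(cls)}
        return [
            cls(**{k: v for k, v in row.items() if k in known})
            for row in df.to_dict(orient="records")
        ]
```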
Viewing a single commit: edit test.py with MetricInput
Eastsidegunn committed Sep 7, 2024
commit 513a56a31d95ec4ee51c6349dc06d006edb8d257
tests/autorag/evaluate/metric/test_generation_metric.py (13 changes: 8 additions & 5 deletions)

@@ -4,6 +4,7 @@
 from llama_index.embeddings.openai import OpenAIEmbedding
 
 from autorag.evaluation.metric import bleu, meteor, rouge, sem_score, g_eval, bert_score
+from autorag.schema.metricinput import MetricInput
 from tests.delete_tests import is_github_action
 from tests.mock import mock_get_text_embedding_batch
 
@@ -38,9 +39,12 @@
     "요즘 세상에서는 예술가가 되려면, AI를 이겨야 한다.",
 ]
 
-
+metric_inputs = [MetricInput(generated_texts=gen, generation_gt=gen_gt) for gen, gen_gt in
+                 zip(generations, generation_gts)]
+ko_metric_inputs = [MetricInput(generated_texts=gen, generation_gt=gen_gt) for gen, gen_gt in
+                    zip(ko_generations, ko_generation_gts)]
 def base_test_generation_metrics(func, solution, **kwargs):
-    scores = func(generation_gt=generation_gts, generations=generations, **kwargs)
+    scores = func(metric_inputs, **kwargs)
     assert len(scores) == len(generation_gts)
     assert all(isinstance(score, float) for score in scores)
     assert all(
@@ -49,7 +53,7 @@ def base_test_generation_metrics(func, solution, **kwargs):
 
 
 def ko_base_test_generation_metrics(func, solution, **kwargs):
-    scores = func(generation_gt=ko_generation_gts, generations=ko_generations, **kwargs)
+    scores = func(ko_metric_inputs, **kwargs)
     assert len(scores) == len(ko_generation_gts)
     assert all(isinstance(score, float) for score in scores)
     assert all(
@@ -86,8 +90,7 @@ def test_sem_score():
 )
 def test_sem_score_other_model():
     scores = sem_score(
-        generation_gt=generation_gts,
-        generations=generations,
+        metric_inputs=metric_inputs,
         embedding_model=OpenAIEmbedding(),
     )
     assert len(scores) == len(generation_gts)
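The call-site change in this file is representative of the whole PR: instead of passing parallel keyword lists (generation_gt=..., generations=...), callers build one MetricInput per example and hand the list over as a single argument. A short sketch of the pattern; the bleu function and module paths are taken from the diff above, while the sample data is made up.

```python
from autorag.evaluation.metric import bleu
from autorag.schema.metricinput import MetricInput

generations = ["The capital of France is Paris."]
generation_gts = [["Paris is the capital of France."]]

# Old style, removed by this PR:
#   scores = bleu(generation_gt=generation_gts, generations=generations)

# New style: one MetricInput per example, passed as a single list.
metric_inputs = [
    MetricInput(generated_texts=gen, generation_gt=gen_gt)
    for gen, gen_gt in zip(generations, generation_gts)
]
scores = bleu(metric_inputs)  # one float per input, same as before
```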
tests/autorag/evaluate/metric/test_retrieval_contents_metric.py (12 changes: 7 additions & 5 deletions)

@@ -6,6 +6,7 @@
     retrieval_token_precision,
     retrieval_token_recall,
 )
+from autorag.schema.metricinput import MetricInput
 
 gt = [
     ["Enough for drinking water", "Just looking for a water bottle"],
@@ -23,6 +24,7 @@
     ["Who is son? He is great player in the world"],
     ["i love havertz", "i love kai havertz"],
 ]
+metric_inputs = [MetricInput(gt_contents=g, retrieval_contents=p) for g, p in zip(gt, pred)]
 
 
 def test_single_token_f1():
@@ -38,23 +40,23 @@ def test_single_token_f1():
 
 
 def test_retrieval_token_f1():
-    f1 = retrieval_token_f1.__wrapped__(gt[0], pred[0])
+    f1 = retrieval_token_f1.__wrapped__(MetricInput(gt_contents=gt[0], retrieval_contents=pred[0]))
     assert f1 == pytest.approx(0.38333, rel=0.001)
 
-    f1 = retrieval_token_f1.__wrapped__(gt[1], pred[1])
+    f1 = retrieval_token_f1.__wrapped__(MetricInput(gt_contents=gt[1], retrieval_contents=pred[1]))
     assert f1 == pytest.approx(0.797979, rel=0.001)
 
-    result_f1 = retrieval_token_f1(gt_contents=gt, pred_contents=pred)
+    result_f1 = retrieval_token_f1(metric_inputs=metric_inputs)
     assert result_f1 == pytest.approx([0.38333, 0.797979, None, None], rel=0.001)
 
 
 def test_retrieval_token_precision():
-    result_precision = retrieval_token_precision(gt_contents=gt, pred_contents=pred)
+    result_precision = retrieval_token_precision(metric_inputs=metric_inputs)
     assert result_precision == pytest.approx(
         [0.383333, 0.8222222, None, None], rel=0.001
     )
 
 
 def test_retrieval_token_recall():
-    result_recall = retrieval_token_recall(gt_contents=gt, pred_contents=pred)
+    result_recall = retrieval_token_recall(metric_inputs=metric_inputs)
     assert result_recall == pytest.approx([0.383333, 0.777777, None, None], rel=0.001)
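The .__wrapped__ calls above imply that each metric is a single-input function wrapped by a looping decorator (commit bfdb91d names it autorag_metric_loop), with functools.wraps exposing the undecorated function for unit tests. Below is a hedged sketch of that mechanism with a toy metric body: only the decorator name and the fields_to_check idea come from the PR; everything else is illustrative, not AutoRAG's implementation.

```python
import functools
from typing import Callable, List, Optional

from autorag.schema.metricinput import MetricInput


def autorag_metric_loop(fields_to_check: List[str]):
    # Sketch: the real decorator in AutoRAG may differ in its details.
    def decorator(func: Callable[..., float]):
        @functools.wraps(func)  # makes func reachable as wrapper.__wrapped__
        def wrapper(metric_inputs: List[MetricInput], **kwargs) -> List[Optional[float]]:
            # Score each input; emit None when a required field is missing,
            # matching the None entries asserted in the tests above.
            return [
                func(mi, **kwargs) if mi.is_fields_notnone(fields_to_check) else None
                for mi in metric_inputs
            ]

        return wrapper

    return decorator


@autorag_metric_loop(fields_to_check=["gt_contents", "retrieval_contents"])
def toy_token_f1(metric_input: MetricInput) -> float:
    # Illustrative set-based token F1, not AutoRAG's token F1.
    gt = set(" ".join(metric_input.gt_contents).lower().split())
    pred = set(" ".join(metric_input.retrieval_contents).lower().split())
    overlap = len(gt & pred)
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(gt)
    return 2 * precision * recall / (precision + recall)
```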
tests/autorag/evaluate/metric/test_retrieval_metric.py (19 changes: 11 additions & 8 deletions)

@@ -9,6 +9,7 @@
     retrieval_mrr,
     retrieval_map,
 )
+from autorag.schema.metricinput import MetricInput
 
 retrieval_gt = [
     [["test-1", "test-2"], ["test-3"]],
@@ -36,11 +37,11 @@
     ["pred-14"],  # retrieval_gt is empty so not counted
     ["pred-15", "pred-16", "test-15"],  # recall:1, precision: 1/3, f1: 0.5
 ]
-
+metric_inputs = [MetricInput(retrieval_gt=ret_gt, retrieval_ids=pr) for ret_gt, pr in zip(retrieval_gt, pred)]
 
 def test_retrieval_f1():
     solution = [0.5, 2 / 7, 2 / 5, 4 / 7, 2 / 3, None, None, 0.5]
-    result = retrieval_f1(retrieval_gt=retrieval_gt, pred_ids=pred)
+    result = retrieval_f1(metric_inputs=metric_inputs)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
@@ -49,21 +50,23 @@ def test_numpy_retrieval_metric():
     retrieval_gt_np = [[np.array(["test-1", "test-4"])], np.array([["test-2"]])]
     pred_np = np.array([["test-2", "test-3", "test-1"], ["test-5", "test-6", "test-8"]])
     solution = [1.0, 0.0]
-    result = retrieval_recall(retrieval_gt=retrieval_gt_np, pred_ids=pred_np)
+    metric_inputs_np = [MetricInput(retrieval_gt=ret_gt_np, retrieval_ids=pr_np) for ret_gt_np, pr_np in
+                        zip(retrieval_gt_np, pred_np)]
+    result = retrieval_recall(metric_inputs=metric_inputs_np)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_recall():
     solution = [0.5, 1 / 3, 1, 2 / 3, 1, None, None, 1]
-    result = retrieval_recall(retrieval_gt=retrieval_gt, pred_ids=pred)
+    result = retrieval_recall(metric_inputs=metric_inputs)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_precision():
     solution = [0.5, 0.25, 0.25, 0.5, 0.5, None, None, 1 / 3]
-    result = retrieval_precision(retrieval_gt=retrieval_gt, pred_ids=pred)
+    result = retrieval_precision(metric_inputs=metric_inputs)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
@@ -79,20 +82,20 @@ def test_retrieval_ndcg():
         None,
         0.5,
     ]
-    result = retrieval_ndcg(retrieval_gt=retrieval_gt, pred_ids=pred)
+    result = retrieval_ndcg(metric_inputs=metric_inputs)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_mrr():
     solution = [1 / 2, 1 / 3, 1, 1 / 2, 1, None, None, 1 / 3]
-    result = retrieval_mrr(retrieval_gt=retrieval_gt, pred_ids=pred)
+    result = retrieval_mrr(metric_inputs=metric_inputs)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
 
 
 def test_retrieval_map():
     solution = [5 / 12, 1 / 3, 1, 1 / 2, 1, None, None, 1 / 3]
-    result = retrieval_map(retrieval_gt=retrieval_gt, pred_ids=pred)
+    result = retrieval_map(metric_inputs=metric_inputs)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
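The numpy test above exists because evaluation data typically arrives via a pandas DataFrame, where list-valued columns round-trip as np.ndarray; commit 872050e routes this through the to_list helper in autorag.utils.util. A stand-in version of that normalization, for illustration only:

```python
import numpy as np


def to_list(value):
    # Stand-in for autorag.utils.util.to_list: recursively turn numpy
    # arrays (including nested ones) into plain Python lists.
    if isinstance(value, np.ndarray):
        return [to_list(v) for v in value.tolist()]
    if isinstance(value, (list, tuple)):
        return [to_list(v) for v in value]
    return value


retrieval_gt_np = [np.array(["test-1", "test-4"]), np.array([["test-2"]])]
print(to_list(retrieval_gt_np))  # [['test-1', 'test-4'], [['test-2']]]
```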