Add MAP (Mean Average Precision) retrieval metric (Marker-Inc-Korea#401)
* add mrr (incomplete; checkpoint commit)

* add retrieval_ndcg_metric

* add retrieval_mrr_metric

* use hits in ndcg

* add next function to mrr

* edit ndcg solution

* change ndcg logic

* change mrr logic

* add metric funcs

* add metric funcs

* checkpoint commit for merge

* add retrieval map

---------

Co-authored-by: Jeffrey (Dongkyu) Kim <vkehfdl1@gmail.com>
bwook00 and vkehfdl1 authored Apr 30, 2024
1 parent 951bbcf commit fcfd5f2
Showing 4 changed files with 30 additions and 3 deletions.
2 changes: 1 addition & 1 deletion autorag/evaluate/metric/__init__.py
@@ -1,3 +1,3 @@
 from .generation import bleu, meteor, rouge, sem_score, g_eval, bert_score
-from .retrieval import retrieval_f1, retrieval_recall, retrieval_precision, retrieval_mrr, retrieval_ndcg
+from .retrieval import retrieval_f1, retrieval_recall, retrieval_precision, retrieval_mrr, retrieval_ndcg, retrieval_map
 from .retrieval_contents import retrieval_token_f1, retrieval_token_precision, retrieval_token_recall
17 changes: 17 additions & 0 deletions autorag/evaluate/metric/retrieval.py
@@ -87,3 +87,20 @@ def retrieval_mrr(gt: List[List[str]], pred: List[str]) -> float:
                 rr_list.append(1.0 / (i + 1))
                 break
     return sum(rr_list) / len(gt_sets) if rr_list else 0.0
+
+
+@retrieval_metric
+def retrieval_map(gt: List[List[str]], pred: List[str]) -> float:
+    """
+    Mean Average Precision (MAP) is the mean of Average Precision (AP) for all queries.
+    """
+    gt_sets = [frozenset(g) for g in gt]
+
+    ap_list = []
+
+    for gt_set in gt_sets:
+        pred_hits = [1 if pred_id in gt_set else 0 for pred_id in pred]
+        precision_list = [sum(pred_hits[:i + 1]) / (i + 1) for i, hit in enumerate(pred_hits) if hit == 1]
+        ap_list.append(sum(precision_list) / len(precision_list) if precision_list else 0.0)
+
+    return sum(ap_list) / len(gt_sets) if ap_list else 0.0
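
For reference, the following standalone snippet (hypothetical document ids, not taken from this repository) re-derives the per-group Average Precision exactly as the new retrieval_map body computes it:

gt_groups = [frozenset({'doc-1', 'doc-2'}), frozenset({'doc-3'})]  # hypothetical ground-truth groups
pred = ['doc-1', 'doc-4', 'doc-3']                                 # hypothetical ranked predictions

ap_list = []
for gt_set in gt_groups:
    # 1 where the predicted id belongs to this ground-truth group, else 0
    pred_hits = [1 if pred_id in gt_set else 0 for pred_id in pred]
    # precision at each rank that produced a hit
    precision_list = [sum(pred_hits[:i + 1]) / (i + 1) for i, hit in enumerate(pred_hits) if hit == 1]
    ap_list.append(sum(precision_list) / len(precision_list) if precision_list else 0.0)

print(ap_list)                      # [1.0, 0.333...]: a hit at rank 1, and a single hit at rank 3
print(sum(ap_list) / len(ap_list))  # 0.666..., the score retrieval_map would return for this query

Each group's AP averages the precision values at the ranks where a relevant id appears, and the query score is the mean of the AP values over all ground-truth groups.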
4 changes: 3 additions & 1 deletion autorag/evaluate/retrieval.py
@@ -4,7 +4,8 @@

 import pandas as pd
 
-from autorag.evaluate.metric import retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg, retrieval_mrr
+from autorag.evaluate.metric import (retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg, retrieval_mrr,
+                                     retrieval_map)
 
 
 def evaluate_retrieval(retrieval_gt: List[List[List[str]]], metrics: List[str]):
@@ -28,6 +29,7 @@ def wrapper(*args, **kwargs) -> pd.DataFrame:
                 retrieval_f1.__name__: retrieval_f1,
                 retrieval_ndcg.__name__: retrieval_ndcg,
                 retrieval_mrr.__name__: retrieval_mrr,
+                retrieval_map.__name__: retrieval_map,
             }
 
             metric_scores = {}
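
The metric_funcs dictionary above registers each metric under its function __name__, so a metric requested by the string 'retrieval_map' now resolves to the new function. A minimal, self-contained sketch of that lookup pattern (standalone, not the project's actual wrapper code):

from autorag.evaluate.metric import retrieval_mrr, retrieval_map

# Metrics are registered under their function names ...
metric_funcs = {
    retrieval_mrr.__name__: retrieval_mrr,
    retrieval_map.__name__: retrieval_map,
}
# ... so the string 'retrieval_map' resolves to the newly added metric.
map_func = metric_funcs['retrieval_map']
assert map_func is retrieval_map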
10 changes: 9 additions & 1 deletion tests/autorag/evaluate/metric/test_retrieval_metric.py
@@ -1,6 +1,7 @@
 import pytest
 
-from autorag.evaluate.metric import retrieval_f1, retrieval_precision, retrieval_recall, retrieval_ndcg, retrieval_mrr
+from autorag.evaluate.metric import (retrieval_f1, retrieval_precision, retrieval_recall, retrieval_ndcg, retrieval_mrr,
+                                     retrieval_map)
 
 retrieval_gt = [
     [['test-1', 'test-2'], ['test-3']],
@@ -58,3 +59,10 @@ def test_retrieval_mrr():
     result = retrieval_mrr(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
+
+
+def test_retrieval_map():
+    solution = [5 / 12, 1 / 3, 1, 1 / 2, 1, None, None, 1 / 3]
+    result = retrieval_map(retrieval_gt=retrieval_gt, pred_ids=pred)
+    for gt, res in zip(solution, result):
+        assert gt == pytest.approx(res, rel=1e-4)
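
As a further illustration, the same calling convention can be exercised on a made-up single-query fixture. The retrieval_gt and pred fixtures used above are collapsed in this view, so the ids below are hypothetical and this sketch is not part of the commit:

def test_retrieval_map_single_query():
    # Hypothetical fixture: one query with two ground-truth groups.
    single_gt = [[['doc-1', 'doc-2'], ['doc-3']]]
    single_pred = [['doc-1', 'doc-4', 'doc-3']]
    # AP is 1/1 for the first group and 1/3 for the second, so the query's MAP is 2/3.
    solution = [2 / 3]
    result = retrieval_map(retrieval_gt=single_gt, pred_ids=single_pred)
    for gt, res in zip(solution, result):
        assert gt == pytest.approx(res, rel=1e-4)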
