diff --git a/autorag/evaluate/metric/__init__.py b/autorag/evaluate/metric/__init__.py
index eea0a44da..4d52144a1 100644
--- a/autorag/evaluate/metric/__init__.py
+++ b/autorag/evaluate/metric/__init__.py
@@ -1,3 +1,3 @@
 from .generation import bleu, meteor, rouge, sem_score, g_eval, bert_score
-from .retrieval import retrieval_f1, retrieval_recall, retrieval_precision, retrieval_mrr, retrieval_ndcg
+from .retrieval import retrieval_f1, retrieval_recall, retrieval_precision, retrieval_mrr, retrieval_ndcg, retrieval_map
 from .retrieval_contents import retrieval_token_f1, retrieval_token_precision, retrieval_token_recall
diff --git a/autorag/evaluate/metric/retrieval.py b/autorag/evaluate/metric/retrieval.py
index 6ca460585..fa3b4798c 100644
--- a/autorag/evaluate/metric/retrieval.py
+++ b/autorag/evaluate/metric/retrieval.py
@@ -87,3 +87,20 @@ def retrieval_mrr(gt: List[List[str]], pred: List[str]) -> float:
                 rr_list.append(1.0 / (i + 1))
                 break
     return sum(rr_list) / len(gt_sets) if rr_list else 0.0
+
+
+@retrieval_metric
+def retrieval_map(gt: List[List[str]], pred: List[str]) -> float:
+    """
+    Mean Average Precision (MAP) is the mean of Average Precision (AP) for all queries.
+    """
+    gt_sets = [frozenset(g) for g in gt]
+
+    ap_list = []
+
+    for gt_set in gt_sets:
+        pred_hits = [1 if pred_id in gt_set else 0 for pred_id in pred]
+        precision_list = [sum(pred_hits[:i + 1]) / (i + 1) for i, hit in enumerate(pred_hits) if hit == 1]
+        ap_list.append(sum(precision_list) / len(precision_list) if precision_list else 0.0)
+
+    return sum(ap_list) / len(gt_sets) if ap_list else 0.0
diff --git a/autorag/evaluate/retrieval.py b/autorag/evaluate/retrieval.py
index 6a9fab205..98874e522 100644
--- a/autorag/evaluate/retrieval.py
+++ b/autorag/evaluate/retrieval.py
@@ -4,7 +4,8 @@
 
 import pandas as pd
 
-from autorag.evaluate.metric import retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg, retrieval_mrr
+from autorag.evaluate.metric import (retrieval_recall, retrieval_precision, retrieval_f1, retrieval_ndcg, retrieval_mrr,
+                                     retrieval_map)
 
 
 def evaluate_retrieval(retrieval_gt: List[List[List[str]]], metrics: List[str]):
@@ -28,6 +29,7 @@ def wrapper(*args, **kwargs) -> pd.DataFrame:
                 retrieval_f1.__name__: retrieval_f1,
                 retrieval_ndcg.__name__: retrieval_ndcg,
                 retrieval_mrr.__name__: retrieval_mrr,
+                retrieval_map.__name__: retrieval_map,
             }
 
             metric_scores = {}
diff --git a/tests/autorag/evaluate/metric/test_retrieval_metric.py b/tests/autorag/evaluate/metric/test_retrieval_metric.py
index 7e3a246c5..5f3dc14f9 100644
--- a/tests/autorag/evaluate/metric/test_retrieval_metric.py
+++ b/tests/autorag/evaluate/metric/test_retrieval_metric.py
@@ -1,6 +1,7 @@
 import pytest
 
-from autorag.evaluate.metric import retrieval_f1, retrieval_precision, retrieval_recall, retrieval_ndcg, retrieval_mrr
+from autorag.evaluate.metric import (retrieval_f1, retrieval_precision, retrieval_recall, retrieval_ndcg, retrieval_mrr,
+                                     retrieval_map)
 
 retrieval_gt = [
     [['test-1', 'test-2'], ['test-3']],
@@ -58,3 +59,10 @@ def test_retrieval_mrr():
     result = retrieval_mrr(retrieval_gt=retrieval_gt, pred_ids=pred)
     for gt, res in zip(solution, result):
         assert gt == pytest.approx(res, rel=1e-4)
+
+
+def test_retrieval_map():
+    solution = [5 / 12, 1 / 3, 1, 1 / 2, 1, None, None, 1 / 3]
+    result = retrieval_map(retrieval_gt=retrieval_gt, pred_ids=pred)
+    for gt, res in zip(solution, result):
+        assert gt == pytest.approx(res, rel=1e-4)
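
A quick sanity check for reviewers: the sketch below replays the per-query Average Precision logic that `retrieval_map` adds, using a hypothetical ranking (the `pred` fixture is not fully visible in this diff, so the ids here are illustrative). The ranking is chosen so the result matches the 5 / 12 expected for the first query in `test_retrieval_map`.

```python
# Standalone sketch of the per-query logic inside retrieval_map.
# gt mirrors the first test fixture; pred is a hypothetical ranking where
# 'test-1' is ranked 1st, 'test-2' 3rd, and 'test-3' is not retrieved.
gt = [['test-1', 'test-2'], ['test-3']]          # ground-truth groups for one query
pred = ['test-1', 'pred-1', 'test-2', 'pred-3']  # retrieved ids in rank order (hypothetical)

gt_sets = [frozenset(g) for g in gt]
ap_list = []
for gt_set in gt_sets:
    pred_hits = [1 if pred_id in gt_set else 0 for pred_id in pred]
    # precision at every rank where a relevant id appears
    precision_list = [sum(pred_hits[:i + 1]) / (i + 1)
                      for i, hit in enumerate(pred_hits) if hit == 1]
    ap_list.append(sum(precision_list) / len(precision_list) if precision_list else 0.0)

# group {'test-1', 'test-2'}: hits at ranks 1 and 3 -> (1/1 + 2/3) / 2 = 5/6
# group {'test-3'}: no hits -> 0.0
print(sum(ap_list) / len(gt_sets))               # (5/6 + 0) / 2 = 5/12 ≈ 0.4167
```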