Skip to content

Commit

Permalink
Add various option for tokenize BM25 (Marker-Inc-Korea#409)
Browse files Browse the repository at this point in the history
* pass tokenizer name at bm25_ingest

* make tokenize_ko_kiwi and porter stemmer

* working new methods of bm25 tokenize

* fix parameter name to bm25_tokenizer and enable new pkl filename rule at evaluator.py

* use bm25_tokenizer at evaluator

* test using various bm25 tokenizers at evaluator

* update docs

* add space tokenization method

* enable auto-testing

* resolve error

* fix test error at test_load_node_line

---------

Co-authored-by: jeffrey <vkefhdl1@gmail.com>
  • Loading branch information
vkehfdl1 and jeffrey authored May 1, 2024
1 parent 7cce048 commit dd7c8ba
Show file tree
Hide file tree
Showing 17 changed files with 196 additions and 52 deletions.
18 changes: 11 additions & 7 deletions autorag/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from autorag import embedding_models
from autorag.node_line import run_node_line
from autorag.nodes.retrieval.base import get_bm25_pkl_name
from autorag.nodes.retrieval.bm25 import bm25_ingest
from autorag.nodes.retrieval.vectordb import vectordb_ingest
from autorag.schema import Node
Expand Down Expand Up @@ -104,13 +105,16 @@ def __embed(self, node_lines: Dict[str, List[Node]]):
if any(list(map(lambda nodes: module_type_exists(nodes, 'bm25'), node_lines.values()))):
# ingest BM25 corpus
logger.info('Embedding BM25 corpus...')
bm25_dir = os.path.join(self.project_dir, 'resources', 'bm25.pkl')
if not os.path.exists(os.path.dirname(bm25_dir)):
os.makedirs(os.path.dirname(bm25_dir))
if os.path.exists(bm25_dir):
logger.debug('BM25 corpus already exists.')
else:
bm25_ingest(bm25_dir, self.corpus_data)
bm25_tokenizer_list = list(chain.from_iterable(
map(lambda nodes: extract_values_from_nodes(nodes, 'bm25_tokenizer'), node_lines.values())))
if len(bm25_tokenizer_list) == 0:
bm25_tokenizer_list = ['porter_stemmer']
for bm25_tokenizer in bm25_tokenizer_list:
bm25_dir = os.path.join(self.project_dir, 'resources', get_bm25_pkl_name(bm25_tokenizer))
if not os.path.exists(os.path.dirname(bm25_dir)):
os.makedirs(os.path.dirname(bm25_dir))
# ingest because bm25 supports update new corpus data
bm25_ingest(bm25_dir, self.corpus_data, bm25_tokenizer=bm25_tokenizer)
logger.info('BM25 corpus embedding complete.')
if any(list(map(lambda nodes: module_type_exists(nodes, 'vectordb'), node_lines.values()))):
# load embedding_models in nodes
Expand Down
11 changes: 7 additions & 4 deletions autorag/nodes/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def wrapper(

if func.__name__ == "bm25":
# check if bm25_path and file exists
bm25_path = os.path.join(resources_dir, 'bm25.pkl')
bm25_tokenizer = kwargs.get('bm25_tokenizer', None)
if bm25_tokenizer is None:
bm25_tokenizer = "porter_stemmer"
bm25_path = os.path.join(resources_dir, get_bm25_pkl_name(bm25_tokenizer))
assert bm25_path is not None, "bm25_path must be specified for using bm25 retrieval."
assert os.path.exists(bm25_path), f"bm25_path {bm25_path} does not exist. Please ingest first."
elif func.__name__ == "vectordb":
Expand Down Expand Up @@ -142,6 +145,6 @@ def evenly_distribute_passages(ids: List[List[str]], scores: List[List[float]],
return new_ids, new_scores


def run_retrieval_modules(project_dir: str, previous_result: pd.DataFrame,
                          module_name: str, module_params: Dict) -> pd.DataFrame:
    """
    Placeholder for running retrieval modules; it has no implementation.

    All arguments are ignored. Despite the ``pd.DataFrame`` return annotation,
    every call returns ``None``.
    """
    return None
def get_bm25_pkl_name(bm25_tokenizer: str):
    """
    Build the pickle file name under which a BM25 corpus is stored for a tokenizer.

    Every '/' is removed from the tokenizer name so that Hugging Face repo ids
    such as ``org/model`` produce a flat file name rather than a nested path.
    NOTE(review): names that differ only in slash placement (e.g. 'a/bc' and
    'ab/c') map to the same file name.

    :param bm25_tokenizer: Built-in tokenizer name or Hugging Face repo id.
    :return: A file name of the form ``bm25_<name-without-slashes>.pkl``.
    """
    flat_name = bm25_tokenizer.translate({ord('/'): None})
    return f'bm25_{flat_name}.pkl'
90 changes: 72 additions & 18 deletions autorag/nodes/retrieval/bm25.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,58 @@
import asyncio
import os
import pickle
from typing import List, Dict, Tuple
import re
from typing import List, Dict, Tuple, Callable, Union, Iterable

import numpy as np
import pandas as pd
from kiwipiepy import Kiwi, Token
from llama_index.core.indices.keyword_table.utils import simple_extract_keywords
from nltk import PorterStemmer
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from autorag.nodes.retrieval.base import retrieval_node, evenly_distribute_passages
from autorag.utils import validate_corpus_dataset
from autorag.utils.util import normalize_string


# Lazily created Kiwi analyzer shared by every call to tokenize_ko_kiwi.
_KIWI_INSTANCE = None


def _get_kiwi() -> "Kiwi":
    """Return the shared Kiwi instance, constructing it on first use."""
    global _KIWI_INSTANCE
    if _KIWI_INSTANCE is None:
        _KIWI_INSTANCE = Kiwi()
    return _KIWI_INSTANCE


def tokenize_ko_kiwi(texts: List[str]) -> List[List[str]]:
    """
    Tokenize Korean texts with the Kiwi morphological analyzer.

    Each text is stripped and lower-cased, then tokenized by Kiwi. The ``form``
    of every resulting token is kept.

    :param texts: The texts to tokenize.
    :return: One list of token forms per input text, in input order.
    """
    # Reuse a single Kiwi instance instead of constructing one per call:
    # bm25_pure invokes the tokenizer once for every query list. Kiwi construction
    # presumably loads its model data, so repeating it is wasted work.
    cleaned = [text.strip().lower() for text in texts]
    tokenized_list: Iterable[List[Token]] = _get_kiwi().tokenize(cleaned)
    return [[token.form for token in token_list] for token_list in tokenized_list]


def tokenize_porter_stemmer(texts: List[str]) -> List[List[str]]:
    """
    Tokenize English texts into Porter-stemmed keywords.

    Each text is lower-cased and reduced to keywords with llama_index's
    ``simple_extract_keywords``. Every keyword is then stemmed with NLTK's
    ``PorterStemmer``.

    :param texts: The texts to tokenize.
    :return: One list of stemmed keywords per input text, in input order.
    """
    # NOTE(review): the list() wrap suggests simple_extract_keywords returns a set.
    # If so, repeated words collapse to a single token and BM25 loses term
    # frequency — confirm against llama_index.
    porter = PorterStemmer()

    def _stem_keywords(sentence: str) -> List[str]:
        keywords = list(simple_extract_keywords(sentence.lower()))
        return [porter.stem(keyword) for keyword in keywords]

    return [_stem_keywords(sentence) for sentence in texts]


def tokenize_space(texts: List[str]) -> List[List[str]]:
    """
    Tokenize texts by splitting on whitespace after normalization.

    Each text is passed through ``normalize_string`` and then split on runs of
    whitespace, ignoring leading and trailing whitespace. A text that is empty
    (or whitespace-only) after normalization yields an empty token list, not
    ``['']`` as ``re.split(r'\\s+', '')`` would produce.

    :param texts: The texts to tokenize.
    :return: One list of whitespace-separated tokens per input text, in input order.
    """
    # str.split() with no separator uses the same Unicode whitespace definition as
    # re's \s and str.strip(), so non-empty results match the regex split exactly.
    return [normalize_string(text).split() for text in texts]


# Registry of built-in BM25 tokenizers, keyed by the name accepted in the
# ``bm25_tokenizer`` module parameter. Each value maps a list of texts to one
# token list per text. select_bm25_tokenizer treats any name not listed here
# as a Hugging Face tokenizer repo id.
BM25_TOKENIZER: Dict[str, Callable[[List[str]], List[List[str]]]] = dict(
    porter_stemmer=tokenize_porter_stemmer,
    ko_kiwi=tokenize_ko_kiwi,
    space=tokenize_space,
)


@retrieval_node
def bm25(queries: List[List[str]], top_k: int, bm25_corpus: Dict) -> Tuple[List[List[str
]], List[List[float]]]:
def bm25(queries: List[List[str]], top_k: int, bm25_corpus: Dict, bm25_tokenizer: str = 'porter_stemmer') -> \
Tuple[List[List[str]], List[List[float]]]:
"""
BM25 retrieval function.
You have to load a pickle file that is already ingested.
Expand All @@ -32,13 +70,20 @@ def bm25(queries: List[List[str]], top_k: int, bm25_corpus: Dict) -> Tuple[List[
"passage_id": [], # 2d list of passage_id.
}
:param bm25_tokenizer: The tokenizer name that uses to the BM25.
It supports 'porter_stemmer', 'ko_kiwi', and huggingface `AutoTokenizer`.
You can pass huggingface tokenizer name.
Default is porter_stemmer.
:return: The 2-d list contains a list of passage ids that retrieved from bm25 and 2-d list of its scores.
It will be a length of queries. And each element has a length of top_k.
"""
# check if bm25_corpus is valid
assert ("tokens" and "passage_id" in list(bm25_corpus.keys())), \
"bm25_corpus must contain tokens and passage_id. Please check you ingested bm25 corpus correctly."
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
tokenizer = select_bm25_tokenizer(bm25_tokenizer)
assert bm25_corpus['tokenizer_name'] == bm25_tokenizer, \
(f"The bm25 corpus tokenizer is {bm25_corpus['tokenizer_name']}, but your input is {bm25_tokenizer}. "
f"You need to ingest again. Delete bm25 pkl file and re-ingest it.")
bm25_instance = BM25Okapi(bm25_corpus["tokens"])
# run async bm25_pure function
tasks = [bm25_pure(input_queries, top_k, tokenizer, bm25_instance, bm25_corpus) for input_queries in queries]
Expand All @@ -49,8 +94,8 @@ def bm25(queries: List[List[str]], top_k: int, bm25_corpus: Dict) -> Tuple[List[
return id_result, score_result


async def bm25_pure(queries: List[str], top_k: int, tokenizer, bm25_api: BM25Okapi, bm25_corpus: Dict) -> Tuple[
List[str], List[float]]:
async def bm25_pure(queries: List[str], top_k: int, tokenizer, bm25_api: BM25Okapi, bm25_corpus: Dict) -> \
Tuple[List[str], List[float]]:
"""
Async BM25 retrieval function.
Its usage is for async retrieval of bm25 row by row.
Expand All @@ -71,7 +116,10 @@ async def bm25_pure(queries: List[str], top_k: int, tokenizer, bm25_api: BM25Oka
:return: The tuple contains a list of passage ids that retrieved from bm25 and its scores.
"""
# I don't make queries operation to async, because queries length might be small, so it will occur overhead.
tokenized_queries = tokenizer(queries).input_ids
if isinstance(tokenizer, PreTrainedTokenizerBase):
tokenized_queries = tokenizer(queries).input_ids
else:
tokenized_queries = tokenizer(queries)
id_result = []
score_result = []
for query in tokenized_queries:
Expand All @@ -91,7 +139,7 @@ async def bm25_pure(queries: List[str], top_k: int, tokenizer, bm25_api: BM25Oka
return list(id_result), list(score_result)


def bm25_ingest(corpus_path: str, corpus_data: pd.DataFrame):
def bm25_ingest(corpus_path: str, corpus_data: pd.DataFrame, bm25_tokenizer: str = 'porter_stemmer'):
if not corpus_path.endswith('.pkl'):
raise ValueError(f"Corpus path {corpus_path} is not a pickle file.")
validate_corpus_dataset(corpus_data)
Expand All @@ -111,13 +159,14 @@ def bm25_ingest(corpus_path: str, corpus_data: pd.DataFrame):
new_passage = corpus_data

if not new_passage.empty:
tokenizer = AutoTokenizer.from_pretrained('gpt2', use_fast=False)
results = new_passage.swifter.apply(lambda x: bm25_tokenize(x['contents'], x['doc_id'], tokenizer), axis=1)
tokenized_corpus, passage_ids = zip(*results)

tokenizer = select_bm25_tokenizer(bm25_tokenizer)
if isinstance(tokenizer, PreTrainedTokenizerBase):
tokenized_corpus = tokenizer(new_passage['contents'].tolist()).input_ids
else:
tokenized_corpus = tokenizer(new_passage['contents'].tolist())
new_bm25_corpus = pd.DataFrame({
'tokens': list(tokenized_corpus),
'passage_id': list(passage_ids),
'tokens': tokenized_corpus,
'passage_id': new_passage['doc_id'].tolist(),
})

if not bm25_corpus.empty:
Expand All @@ -126,10 +175,15 @@ def bm25_ingest(corpus_path: str, corpus_data: pd.DataFrame):
else:
bm25_dict = new_bm25_corpus.to_dict('list')

# add tokenizer name to bm25_dict
bm25_dict['tokenizer_name'] = bm25_tokenizer

with open(corpus_path, 'wb') as w:
pickle.dump(bm25_dict, w)


def bm25_tokenize(queries: List[str], passage_id: str, tokenizer) -> Tuple[List[int], str]:
    """
    Tokenize text with a Hugging Face-style tokenizer and pair the result with its id.

    The ingest path applies this row by row with each row's ``contents`` and
    ``doc_id``, so ``queries`` is in practice a single string there.

    :param queries: Text handed to ``tokenizer``.
    :param passage_id: Identifier returned unchanged alongside the tokens.
    :param tokenizer: Callable whose result exposes ``input_ids``.
    :return: ``(input_ids, passage_id)``.
    """
    encoding = tokenizer(queries)
    return encoding.input_ids, passage_id
def select_bm25_tokenizer(
        bm25_tokenizer: str,
) -> Union[Callable[[List[str]], List[List[str]]], "PreTrainedTokenizerBase"]:
    """
    Resolve a tokenizer name to the tokenizer used for BM25.

    :param bm25_tokenizer: A key of ``BM25_TOKENIZER`` ('porter_stemmer', 'ko_kiwi',
        'space'), or otherwise a Hugging Face repo id for ``AutoTokenizer``.
    :return: Either the built-in tokenize function, which maps a list of texts to one
        token list per text, or a Hugging Face tokenizer loaded with
        ``use_fast=False`` (callers read ``.input_ids`` from its output).
    """
    # Direct dict lookup instead of building a key list on every call.
    builtin_tokenizer = BM25_TOKENIZER.get(bm25_tokenizer)
    if builtin_tokenizer is not None:
        return builtin_tokenizer

    return AutoTokenizer.from_pretrained(bm25_tokenizer, use_fast=False)
28 changes: 27 additions & 1 deletion docs/source/nodes/retrieval/bm25.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,36 @@
The `BM25` is the most popular TF-IDF-style method for retrieval; it reflects how important a word is to a document. It is often called sparse retrieval. It differs from dense retrieval, which uses an embedding model and similarity search: dense retrieval finds passages by semantic similarity, whereas sparse retrieval relies on word counts. If you use documents in specific domains, `BM25` can be more useful than `VectorDB`. It uses the BM25Okapi algorithm for scoring and ranking the passages.

## **Module Parameters**
- **Parameter**: `None`

- **bm25_tokenizer**: Selects the tokenization method used for BM25.
The default method is 'porter_stemmer'.
You can also choose 'ko_kiwi', 'space', or any Hugging Face AutoTokenizer name.

### porter_stemmer

The default tokenization method, optimized for English. It splits sentences into words and reduces each word to its stem.

For example, the stemmer maps both 'studying' and 'studies' to the same stem ('studi').

### ko_kiwi

It uses the Kiwi tokenizer for Korean.
We highly recommend it for Korean documents.
You can find more information about Kiwi [here](https://github.com/bab2min/Kiwi).

### space

A simple method that splits text into words on whitespace.
Despite its simplicity, it can be a great choice for multilingual documents.

### Huggingface AutoTokenizer

You can use any Hugging Face `AutoTokenizer`, such as `gpt2` or `mistralai/Mistral-7B-Instruct-v0.2`.
Just enter the Hugging Face repository path to use its tokenizer.

## **Example config.yaml**
```yaml
modules:
- module_type: bm25
bm25_tokenizer: [ porter_stemmer, ko_kiwi, space, gpt2 ]
```
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ sentence-transformers # for sentence transformer reranker
FlagEmbedding # for flag embedding reranker
ragas # evaluation data generation & evaluation
ray # for parallel processing
kiwipiepy # for BM25 Korean tokenizer

### LlamaIndex ###
llama-index>=0.10.1
Expand All @@ -39,6 +40,8 @@ llama-index-embeddings-huggingface
llama-index-llms-openai
llama-index-llms-huggingface
llama-index-llms-openai-like
# Retriever
llama-index-retrievers-bm25

# WebUI
streamlit
Expand Down
2 changes: 2 additions & 0 deletions sample_config/config_korean.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ node_lines:
modules:
- module_type: vectordb
embedding_model: openai
- module_type: bm25
bm25_tokenizer: ko_kiwi
top_k: 3
strategy:
metrics:
Expand Down
1 change: 1 addition & 0 deletions sample_config/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ node_lines:
top_k: 10
retrieval_modules:
- module_type: bm25
bm25_tokenizer: [ porter_stemmer, ko_kiwi, space, gpt2 ]
- module_type: vectordb
embedding_model: openai
modules:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def node_line_dir():
yield node_line_dir


def test_evaluate_one_prompt_maker_node(node_line_dir):
def test_evaluate_one_query_expansion_node(node_line_dir):
project_dir = pathlib.PurePath(node_line_dir).parent.parent
qa_path = os.path.join(project_dir, "data", "qa.parquet")
previous_result = pd.read_parquet(qa_path)
sample_previous_result = previous_result.head(2)
sample_retrieval_gt = sample_previous_result['retrieval_gt'].tolist()

retrieval_funcs = [bm25, bm25]
retrieval_params = [{'top_k': 1}, {'top_k': 2}]
retrieval_params = [{'top_k': 1, 'bm25_tokenizer': 'gpt2'}, {'top_k': 2, 'bm25_tokenizer': 'gpt2'}]
best_result = evaluate_one_query_expansion_node(retrieval_funcs, retrieval_params,
sample_expanded_queries, sample_retrieval_gt,
metrics, project_dir, sample_previous_result)
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_run_query_expansion_node(node_line_dir):
'metrics': metrics,
'speed_threshold': 5,
'top_k': 4,
'retrieval_modules': [{'module_type': 'bm25'}],
'retrieval_modules': [{'module_type': 'bm25', 'bm25_tokenizer': 'gpt2'}],
}
best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)
base_query_expansion_test(best_result, node_line_dir)
Expand Down
Loading

0 comments on commit dd7c8ba

Please sign in to comment.