Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/op_examples/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

lm = LM(model="gpt-4o-mini")
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
vs = FaissVS()
vs = FaissVS()

lotus.settings.configure(lm=lm, rm=rm, vs=vs)
data = {
Expand Down
2 changes: 1 addition & 1 deletion examples/op_examples/dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lotus.vector_store import FaissVS

rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
vs = FaissVS()
vs = FaissVS()
lotus.settings.configure(rm=rm, vs=vs)
data = {
"Text": [
Expand Down
2 changes: 1 addition & 1 deletion examples/op_examples/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
}

# you can optionally set extract_quotes=True to return quotes that support each output
new_df = df.sem_extract(input_cols, output_cols, extract_quotes=True)
new_df = df.sem_extract(input_cols, output_cols, extract_quotes=True)
print(new_df)

# A description can also be omitted for each output column
Expand Down
2 changes: 1 addition & 1 deletion examples/op_examples/join_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

lm = LM(model="gpt-4o-mini")
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
vs = FaissVS()
vs = FaissVS()

lotus.settings.configure(lm=lm, rm=rm, vs=vs)
data = {
Expand Down
2 changes: 1 addition & 1 deletion examples/op_examples/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

lm = LM(max_tokens=2048)
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
vs = FaissVS()
vs = FaissVS()

lotus.settings.configure(lm=lm, rm=rm, vs=vs)
data = {
Expand Down
2 changes: 1 addition & 1 deletion examples/op_examples/sim_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

lm = LM(model="gpt-4o-mini")
rm = LiteLLMRM(model="text-embedding-3-small")
vs = FaissVS()
vs = FaissVS()

lotus.settings.configure(lm=lm, rm=rm, vs=vs)
data = {
Expand Down
7 changes: 2 additions & 5 deletions lotus/dtype_extensions/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,17 @@ def copy(self) -> "ImageArray":
def _concat_same_type(cls, to_concat: Sequence["ImageArray"]) -> "ImageArray":
"""
Concatenate multiple ImageArray instances into a single one.

Args:
to_concat (Sequence[ImageArray]): A sequence of ImageArray instances to concatenate.

Returns:
ImageArray: A new ImageArray containing all elements from the input arrays.
"""
# create list of all data
combined_data = np.concatenate([arr._data for arr in to_concat])
return cls._from_sequence(combined_data)




@classmethod
def _from_sequence(cls, scalars, dtype=None, copy=False):
if copy:
Expand Down
4 changes: 3 additions & 1 deletion lotus/nl_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ def parse_cols(text: str) -> list[str]:
matches = re.findall(pattern, text)

if not matches:
raise ValueError("Language expression contains no parameterized columns. Please specify the name of the relevant data column(s) in brackets {} within your language expression.")
raise ValueError(
"Language expression contains no parameterized columns. Please specify the name of the relevant data column(s) in brackets {} within your language expression."
)
return matches


Expand Down
6 changes: 3 additions & 3 deletions lotus/sem_ops/sem_cluster_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def __call__(
Returns:
pd.DataFrame: The DataFrame with the cluster assignments.
"""
rm = lotus.settings.rm
vs = lotus.settings.vs
if rm is None or vs is None :
rm = lotus.settings.rm
vs = lotus.settings.vs
if rm is None or vs is None:
raise ValueError(
"The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()"
)
Expand Down
2 changes: 1 addition & 1 deletion lotus/sem_ops/sem_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __call__(
pd.DataFrame: The DataFrame with duplicates removed.
"""
rm = lotus.settings.rm
vs = lotus.settings.vs
vs = lotus.settings.vs
if rm is None or vs is None:
raise ValueError(
"The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()"
Expand Down
6 changes: 3 additions & 3 deletions lotus/sem_ops/sem_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __call__(
if K is not None:
# get retriever model and index
rm = lotus.settings.rm
vs = lotus.settings.vs
if rm is None or vs is None :
vs = lotus.settings.vs
if rm is None or vs is None:
raise ValueError(
"The retrieval model must be an instance of RM, and the vector store should be an instance of VS. Please configure a valid retrieval model and vector store using lotus.settings.configure()"
)
Expand All @@ -61,7 +61,7 @@ def __call__(

df_idxs = self._obj.index
cur_min = len(df_idxs)
K = min(K, cur_min)
K = min(K, cur_min)
search_K = K
while True:
query_vectors = rm.convert_query_to_query_vector(query)
Expand Down
4 changes: 2 additions & 2 deletions lotus/sem_ops/sem_sim_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def __call__(
raise ValueError("Other Series must have a name")
other = pd.DataFrame({other.name: other})

rm = lotus.settings.rm
rm = lotus.settings.rm
vs = lotus.settings.vs
if not isinstance(rm, RM) or not isinstance(vs, VS):
raise ValueError(
"The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model or vector store using lotus.settings.configure()"
)

# load query embeddings from index if they exist
if left_on in self._obj.attrs.get("index_dirs", []):
query_index_dir = self._obj.attrs["index_dirs"][left_on]
Expand Down
8 changes: 2 additions & 6 deletions lotus/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
class Settings:
# Models
lm: lotus.models.LM | None = None
rm: lotus.models.RM | None = None # supposed to only generate embeddings
rm: lotus.models.RM | None = None # supposed to only generate embeddings
helper_lm: lotus.models.LM | None = None
reranker: lotus.models.Reranker | None = None
vs: lotus.vector_store.VS | None = None

vs: lotus.vector_store.VS | None = None

# Cache settings
enable_cache: bool = False
Expand All @@ -24,16 +23,13 @@ class Settings:
parallel_groupby_max_threads: int = 8

def configure(self, **kwargs):


for key, value in kwargs.items():
if not hasattr(self, key):
raise ValueError(f"Invalid setting: {key}")
setattr(self, key, value)

def __str__(self):
return str(vars(self))



settings = Settings()
2 changes: 1 addition & 1 deletion lotus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def ret(

# get rmodel and index
rm = lotus.settings.rm
vs = lotus.settings.vs
vs = lotus.settings.vs
if rm is None or vs is None:
raise ValueError(
"The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()"
Expand Down
2 changes: 1 addition & 1 deletion lotus/vector_store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from lotus.vector_store.faiss_vs import FaissVS
from lotus.vector_store.weaviate_vs import WeaviateVS

__all__ = ["VS", "FaissVS", "WeaviateVS"]
__all__ = ["VS", "FaissVS", "WeaviateVS"]
35 changes: 20 additions & 15 deletions tests/extract_with_prefilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

lotus.settings.configure(lm=lm)

# Creating a sample DataFrame
# Creating a sample DataFrame
data = {
"title": [
"Who's Who: Large Language Models Meet Knowledge",
Expand All @@ -19,7 +19,7 @@
"From Tokens to Materials: Leveraging Language Models",
"A Recommendation Model Utilizing Separation Embeddings",
"Incorporating Group Prior into Variational Inference",
"HyQE: Ranking Contexts with Hybrid Query Expansion"
"HyQE: Ranking Contexts with Hybrid Query Expansion",
],
"authors": [
["Quang Hieu Pham", "Hoang Ngo", "Anh Tuan Lu"],
Expand All @@ -31,7 +31,7 @@
["Yuwei Wan", "Tong Xie", "Nan Wu", "Wenjie Zhou"],
["Wenyi Liu", "Rui Wang", "Yuanshuai Luo", "Jian Sun"],
["Han Xu", "Taoxing Pan", "Zhiqiang Liu", "Xia Wu"],
["Weichao Zhou", "Jiaxin Zhang", "Hilat Hasson"]
["Weichao Zhou", "Jiaxin Zhang", "Hilat Hasson"],
],
"abstract": [
"Retrieval-augmented generation (RAG) methods are becoming increasingly popular...",
Expand All @@ -43,25 +43,30 @@
"Exploring the predictive capabilities of language models in material discovery...",
"With the explosive growth of Internet data, users require personalized recommendation...",
"User behavior modeling – which aims to extract latent features...",
"In retrieval-augmented systems, ranking relevant contexts is a challenge..."
"In retrieval-augmented systems, ranking relevant contexts is a challenge...",
],
"arxiv_id": [
"2410.15737", "2410.15884", "2410.15016", "2410.15081",
"2410.15753", "2410.15272", "2410.16165", "2410.15026",
"2410.15098", "2410.15262"
]
"2410.15737",
"2410.15884",
"2410.15016",
"2410.15081",
"2410.15753",
"2410.15272",
"2410.16165",
"2410.15026",
"2410.15098",
"2410.15262",
],
}

# Create DataFrame
df = pd.DataFrame(data)
filtered_df = df.iloc[2:7]

input_cols = ['abstract']
output_cols = {
'topics': None
}
input_cols = ["abstract"]
output_cols = {"topics": None}

new_df1 =filtered_df.sem_extract(input_cols, output_cols)
new_df2 =df.sem_extract(input_cols, output_cols)
new_df1 = filtered_df.sem_extract(input_cols, output_cols)
new_df2 = df.sem_extract(input_cols, output_cols)
print(new_df1)
print(new_df2)
print(new_df2)
44 changes: 21 additions & 23 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@

@pytest.fixture
def sample_df():
return pd.DataFrame({
"Course Name": [
"Probability and Random Processes",
"Statistics and Data Analysis",
"Cooking Basics",
"Advanced Culinary Arts",
"Digital Circuit Design",
"Computer Architecture"
]
})
return pd.DataFrame(
{
"Course Name": [
"Probability and Random Processes",
"Statistics and Data Analysis",
"Cooking Basics",
"Advanced Culinary Arts",
"Digital Circuit Design",
"Computer Architecture",
]
}
)


class TestClusterBy(BaseTest):
Expand All @@ -26,7 +28,6 @@ def test_basic_clustering(self, sample_df):
assert len(result["cluster_id"].unique()) == 2
assert len(result) == len(sample_df)


# Get the two clusters
cluster_0_courses = set(result[result["cluster_id"] == 0]["Course Name"])
cluster_1_courses = set(result[result["cluster_id"] == 1]["Course Name"])
Expand All @@ -36,17 +37,14 @@ def test_basic_clustering(self, sample_df):
"Probability and Random Processes",
"Statistics and Data Analysis",
"Digital Circuit Design",
"Computer Architecture"
}
culinary_courses = {
"Cooking Basics",
"Advanced Culinary Arts"
"Computer Architecture",
}
culinary_courses = {"Cooking Basics", "Advanced Culinary Arts"}

# Check that one cluster contains tech courses and the other contains culinary courses
assert (cluster_0_courses == tech_courses and cluster_1_courses == culinary_courses) or \
(cluster_1_courses == tech_courses and cluster_0_courses == culinary_courses), \
"Clusters don't match expected course groupings"
assert (cluster_0_courses == tech_courses and cluster_1_courses == culinary_courses) or (
cluster_1_courses == tech_courses and cluster_0_courses == culinary_courses
), "Clusters don't match expected course groupings"

def test_clustering_with_more_clusters(self, sample_df):
"""Test clustering with more clusters than necessary"""
Expand Down Expand Up @@ -75,17 +73,17 @@ def test_clustering_with_empty_dataframe(self):
def test_clustering_similar_items(self, sample_df):
"""Test that similar items are clustered together"""
result = sample_df.sem_cluster_by("Course Name", 3)

# Get cluster IDs for similar courses
stats_cluster = result[result["Course Name"].str.contains("Statistics")]["cluster_id"].iloc[0]
prob_cluster = result[result["Course Name"].str.contains("Probability")]["cluster_id"].iloc[0]

# Similar courses should be in the same cluster
assert stats_cluster == prob_cluster

cooking_cluster = result[result["Course Name"].str.contains("Cooking")]["cluster_id"].iloc[0]
culinary_cluster = result[result["Course Name"].str.contains("Culinary")]["cluster_id"].iloc[0]

assert cooking_cluster == culinary_cluster

def test_clustering_with_verbose(self, sample_df):
Expand All @@ -98,7 +96,7 @@ def test_clustering_with_iterations(self, sample_df):
"""Test clustering with different iteration counts"""
result1 = sample_df.sem_cluster_by("Course Name", 2, niter=5)
result2 = sample_df.sem_cluster_by("Course Name", 2, niter=20)

# Both should produce valid clusterings
assert len(result1["cluster_id"].unique()) == 2
assert len(result2["cluster_id"].unique()) == 2
Loading
Loading