2 changes: 1 addition & 1 deletion libs/infinity_emb/Makefile
@@ -63,7 +63,7 @@ benchmark_embed: tests/data/benchmark/benchmark_embed.json
# sudo apt-get install apache2-utils

benchmark_embed_vision: tests/data/benchmark/benchmark_embed_image.json
-ab -n 10000 -c 10 -l -s 480 \
+ab -n 100 -c 50 -l -s 480 \
-T 'application/json' \
-p $< \
http://127.0.0.1:7997/embeddings
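For reference, the `ab` (ApacheBench) flags used here: `-n` is the total number of requests, `-c` the concurrency level, `-l` accepts variable response lengths, `-s` the per-request timeout in seconds, and `-T`/`-p` set the content type and POST body file. The vision benchmark now issues far fewer requests (100 instead of 10000) at higher concurrency (50 instead of 10).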
18 changes: 10 additions & 8 deletions libs/infinity_emb/infinity_emb/infinity_server.py
@@ -661,14 +661,16 @@ def typer_option_resolve(*args):
import typer
import uvicorn

-try:
-    import uvloop
-
-    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-    loopname = "uvloop"
-except ImportError:
-    # Windows does not support uvloop
-    loopname = "auto"
+loopname = "auto"
+if sys.version_info < (3, 12):
+    try:
+        import uvloop
+
+        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+        loopname = "uvloop"
+    except ImportError:
+        # Windows does not support uvloop
+        pass

tp = typer.Typer()

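For context, `loopname` is presumably handed to uvicorn afterwards (the hunk only shows how the name is chosen); a minimal sketch of that assumed usage, with a placeholder ASGI app:

```python
import sys

import uvicorn


async def app(scope, receive, send):
    """Placeholder ASGI app so the snippet is self-contained."""
    ...


# simplified selection (the hunk above additionally falls back to "auto" on ImportError)
loopname = "uvloop" if sys.version_info < (3, 12) else "auto"
# uvicorn resolves the loop implementation by name: "auto", "asyncio" or "uvloop";
# "uvloop" requires the uvloop package to be installed
uvicorn.run(app, host="127.0.0.1", port=7997, loop=loopname)
```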
@@ -70,6 +70,14 @@ def __init__(self, *, engine_args=EngineArgs):
        # to be able to count the tokens in another thread
        # without corrupting the original.
        fm = self._first_module()
+
+        self.normalize_embeddings = True
+
+        self.mode_colbert = False
+        if "colbert" in fm.auto_model.config.architectures[0].lower():
+            self.mode_colbert = True
+            self.normalize_embeddings = False
+
        self._infinity_tokenizer = copy.deepcopy(fm.tokenizer)
        self.eval()
        self.engine_args = engine_args
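The detection keys off the checkpoint's `config.json`. A minimal standalone sketch of the same idea, using the tiny test model this PR registers in conftest (the `architectures` value is whatever the checkpoint declares, e.g. `HF_ColBERT` for ColBERT checkpoints):

```python
from transformers import AutoConfig

# the tiny ColBERT test model added in this PR's conftest
config = AutoConfig.from_pretrained("michaelfeil/colbert-tiny-random")
architectures = config.architectures or [""]
# same check as above: the first declared architecture decides the mode
mode_colbert = "colbert" in architectures[0].lower()
normalize_embeddings = not mode_colbert  # ColBERT keeps per-token vectors unnormalized
print(architectures, mode_colbert, normalize_embeddings)
```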
@@ -102,20 +110,38 @@ def encode_core(self, features: dict[str, "Tensor"]) -> "Tensor":

        with torch.no_grad():
            features = util.batch_to_device(features, self.device)  # type: ignore
-           out_features: "Tensor" = self.forward(features)["sentence_embedding"]
+           out: dict[str, "Tensor"] = self.forward(features)
+           if not self.mode_colbert:
+               out_features = out["sentence_embedding"].detach().cpu()
+           else:
+               out_features = {  # type: ignore # noqa
+                   "token_embeddings": out["token_embeddings"].detach().cpu(),
+                   "attention_mask": out["attention_mask"].detach().cpu(),
+               }

-       return out_features.detach().cpu()
+       return out_features

    @quant_embedding_decorator()
    def encode_post(
-       self, out_features: "Tensor", normalize_embeddings: bool = True
+       self,
+       out_features: "Tensor",
    ) -> "EmbeddingReturnType":
        with torch.inference_mode():
-           embeddings: "Tensor" = out_features.to(torch.float32)
-           if normalize_embeddings:
-               embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-           embeddings_np: np.ndarray = embeddings.numpy()
+           if not self.mode_colbert:
+               embeddings: "Tensor" = out_features.to(torch.float32)
+               if self.normalize_embeddings:
+                   embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+               embeddings_np: np.ndarray = embeddings.numpy()
+           else:
+               # drop padded positions via the attention mask, e.g. for two inputs
+               # with 5 and 3 tokens the mask is [[1,1,1,1,1],[1,1,1,0,0]],
+               # and convert the kept token embeddings to a list of numpy arrays
+               embeddings_np = [  # type: ignore # noqa
+                   z[m].numpy()
+                   for z, m in zip(
+                       out_features["token_embeddings"].to(torch.float32),  # type: ignore
+                       out_features["attention_mask"].bool(),  # type: ignore
+                   )
+               ]

        return embeddings_np

Review comment (Contributor), on the `if self.normalize_embeddings:` check: logic: redundant check for `not self.mode_colbert` since it's already in an `if not self.mode_colbert:` block (ColBERT mode sets `normalize_embeddings = False` in `__init__`, so the flag is always `True` here).

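The ColBERT branch of `encode_post` above removes padding with boolean mask indexing; a toy sketch with fabricated tensors (shapes only, not real model output):

```python
import torch

# two "documents", padded to 5 tokens, embedding dim 4 (fabricated values)
token_embeddings = torch.randn(2, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])

# boolean indexing with the mask keeps only the real (unpadded) token rows
per_doc = [
    emb[mask].numpy()
    for emb, mask in zip(token_embeddings.to(torch.float32), attention_mask.bool())
]
print(per_doc[0].shape, per_doc[1].shape)  # (5, 4) (3, 4)
```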
1 change: 1 addition & 0 deletions libs/infinity_emb/tests/conftest.py
@@ -13,6 +13,7 @@
pytest.DEFAULT_AUDIO_MODEL = "laion/clap-htsat-unfused"
pytest.DEFAULT_IMAGE_MODEL = "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
pytest.DEFAULT_IMAGE_COLPALI_MODEL = "michaelfeil/colpali-v12-random-testing"
+pytest.DEFAULT_COLBERT_MODEL = "michaelfeil/colbert-tiny-random"

pytest.IMAGE_SAMPLE_URL = "https://github.com/michaelfeil/infinity/raw/06fd1f4d8f0a869f4482fc1c78b62a75ccbb66a1/docs/assets/cats_coco_sample.jpg"
pytest.AUDIO_SAMPLE_URL = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"
@@ -0,0 +1,75 @@
import numpy as np
import pytest
import torch
from asgi_lifespan import LifespanManager
from httpx import AsyncClient

from infinity_emb import create_server
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import Device, InferenceEngine

PREFIX = "/v1_sentence_transformers_colbert"
MODEL: str = pytest.DEFAULT_COLBERT_MODEL  # type: ignore[assignment]
batch_size = 64 if torch.cuda.is_available() else 8

app = create_server(
    url_prefix=PREFIX,
    engine_args_list=[
        EngineArgs(
            model_name_or_path=MODEL,
            batch_size=batch_size,
            engine=InferenceEngine.torch,
            device=Device.auto if not torch.backends.mps.is_available() else Device.cpu,
        )
    ],
)


# @pytest.fixture
# def model_base() -> SentenceTransformer:
#     model = SentenceTransformer(MODEL)
#     if model.device == "cuda":
#         model = model.to(torch.float16)
#     return model


@pytest.fixture()
async def client():
    async with AsyncClient(app=app, base_url="http://test", timeout=20) as client, LifespanManager(
        app
    ):
        yield client


# def test_load_model(model_base):
#     # this makes sure that the error below is not based on a slow download
#     # or internal pytorch errors
#     model_base.encode(["This is a test sentence."])


@pytest.mark.anyio
async def test_model_route(client):
    response = await client.get(f"{PREFIX}/models")
    assert response.status_code == 200
    rdata = response.json()
    assert "data" in rdata
    assert rdata["data"][0].get("id", "") == MODEL
    assert isinstance(rdata["data"][0].get("stats"), dict)


@pytest.mark.anyio
async def test_embedding(client):
    response = await client.post(
        f"{PREFIX}/embeddings", json=dict(input=["This is a test", "hi", "hi"], model=MODEL)
    )
    assert response.status_code == 200
    rdata = response.json()
    assert "data" in rdata
    assert len(rdata["data"]) == 3
    # TODO: Check if start and end tokens should be embedded
    # TODO: Check if normalization is applied or should be applied?
    assert len(rdata["data"][0]["embedding"]) == 6  # "This is a test" -> 6 tokens incl. special tokens
    assert len(rdata["data"][1]["embedding"]) == 3  # "hi" -> 3 tokens incl. special tokens
    np.testing.assert_allclose(
        rdata["data"][1]["embedding"], rdata["data"][2]["embedding"], atol=5e-3
    )
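For context on how per-token ColBERT embeddings like these are typically consumed downstream (late-interaction MaxSim scoring; illustrative only, not part of this PR's API):

```python
import numpy as np


def maxsim(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """ColBERT late interaction: for each query token, take the max cosine
    similarity over all document tokens, then sum over query tokens."""
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    d = doc_emb / np.linalg.norm(doc_emb, axis=1, keepdims=True)
    return float((q @ d.T).max(axis=1).sum())


# hypothetical usage with the embeddings returned by the endpoint above:
# score = maxsim(np.array(rdata["data"][0]["embedding"]),
#                np.array(rdata["data"][1]["embedding"]))
```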