fix(ml): tokenization for webli models ()

This commit is contained in:
Mert 2024-08-18 11:05:10 -04:00 committed by GitHub
parent 5ab92f346a
commit 036676d501
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 48 additions and 3 deletions
machine-learning/app

View file

@ -379,13 +379,40 @@ class TestCLIP:
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
clip_encoder._load()
tokens = clip_encoder.tokenize("test search query")
tokens = clip_encoder.tokenize("test search query")
assert "text" in tokens
assert isinstance(tokens["text"], np.ndarray)
assert tokens["text"].shape == (1, 77)
assert tokens["text"].dtype == np.int32
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
mock_tokenizer.encode.assert_called_once_with("test search query")
def test_openclip_tokenizer_canonicalizes_text(
self,
mocker: MockerFixture,
clip_model_cfg: dict[str, Any],
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
clip_model_cfg["text_cfg"]["tokenizer_kwargs"] = {"clean": "canonicalize"}
mocker.patch.object(OpenClipTextualEncoder, "download")
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
mock_ids = [randint(0, 50000) for _ in range(77)]
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
clip_encoder._load()
tokens = clip_encoder.tokenize("Test Search Query!")
assert "text" in tokens
assert isinstance(tokens["text"], np.ndarray)
assert tokens["text"].shape == (1, 77)
assert tokens["text"].dtype == np.int32
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
mock_tokenizer.encode.assert_called_once_with("test search query")
def test_mclip_tokenizer(
self,