mirror of
https://github.com/immich-app/immich.git
synced 2025-06-16 21:38:28 +02:00
feat(ml): better multilingual search with nllb models (#13567)
This commit is contained in:
parent
838a8dd9a6
commit
6789c2ac19
16 changed files with 301 additions and 18 deletions
machine-learning
|
@ -494,6 +494,88 @@ class TestCLIP:
|
|||
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||
mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_openclip_tokenizer_adds_flores_token_for_nllb(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
clip_encoder.tokenize("test search query", language="de")
|
||||
|
||||
mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
|
||||
|
||||
def test_openclip_tokenizer_removes_country_code_from_language_for_nllb_if_not_found(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
clip_encoder.tokenize("test search query", language="de-CH")
|
||||
|
||||
mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
|
||||
|
||||
def test_openclip_tokenizer_falls_back_to_english_for_nllb_if_language_code_not_found(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
warning: mock.Mock,
|
||||
) -> None:
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
clip_encoder.tokenize("test search query", language="unknown")
|
||||
|
||||
mock_tokenizer.encode.assert_called_once_with("eng_Latntest search query")
|
||||
warning.assert_called_once_with("Language 'unknown' not found, defaulting to 'en'")
|
||||
|
||||
def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
clip_encoder.tokenize("test search query", language="de")
|
||||
|
||||
mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_mclip_tokenizer(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue