Mirror of https://github.com/immich-app/immich.git, synced 2025-06-14 21:38:26 +02:00.
Commit: feat(ml): better multilingual search with NLLB models (#13567)
This commit is contained in:
parent
838a8dd9a6
commit
6789c2ac19
16 changed files with 301 additions and 18 deletions
machine-learning
|
@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
|
|||
|
||||
from immich_ml.config import log
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.constants import WEBLATE_TO_FLORES200
|
||||
from immich_ml.models.transforms import clean_text, serialize_np_array
|
||||
from immich_ml.schemas import ModelSession, ModelTask, ModelType
|
||||
|
||||
|
@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
depends = []
|
||||
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
|
||||
|
||||
def _predict(self, inputs: str, **kwargs: Any) -> str:
|
||||
res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
|
||||
def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> str:
    """Encode a text query into a serialized embedding.

    ``language`` is an optional locale code forwarded to the tokenizer so
    NLLB models can prepend the matching FLORES-200 language token.
    """
    tokens = self.tokenize(inputs, language=language)
    # run() returns a list of outputs; take the first output's first row (the embedding).
    res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
    return serialize_np_array(res)
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
|
@ -28,6 +30,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
self.tokenizer = self._load_tokenizer()
|
||||
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
|
||||
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
|
||||
self.is_nllb = self.model_name.startswith("nllb")
|
||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
return session
|
||||
|
@ -37,7 +40,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
|
||||
pass
|
||||
|
||||
@property
|
||||
|
@ -92,14 +95,23 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
|||
|
||||
return tokenizer
|
||||
|
||||
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
    """Tokenize query text for the OpenCLIP textual encoder.

    For NLLB models, the cleaned text is prefixed with the FLORES-200 code
    for ``language`` so the tokenizer emits the proper language token.
    Lookup order: exact locale code, then the code with any country suffix
    stripped, then an English fallback (with a warning).
    """
    cleaned = clean_text(text, canonicalize=self.canonicalize)
    if self.is_nllb and language is not None:
        base_language = language.split("-")[0]
        # Exact match first; the values are non-empty strings, so `or` is a
        # safe "try the next lookup" chain.
        code = WEBLATE_TO_FLORES200.get(language) or WEBLATE_TO_FLORES200.get(base_language)
        if code is None:
            log.warning(f"Language '{language}' not found, defaulting to 'en'")
            code = "eng_Latn"
        cleaned = f"{code}{cleaned}"
    encoded: Encoding = self.tokenizer.encode(cleaned)
    return {"text": np.array([encoded.ids], dtype=np.int32)}
|
||||
|
||||
|
||||
class MClipTextualEncoder(OpenClipTextualEncoder):
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {
|
||||
|
|
|
@ -86,6 +86,66 @@ RKNN_SUPPORTED_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]
|
|||
RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"]
|
||||
|
||||
|
||||
# Maps UI locale codes (Weblate-style, per the constant's name) to FLORES-200
# language codes understood by NLLB tokenizers. Keys may carry a country
# suffix (e.g. "pt-BR"); callers fall back to the bare language code when an
# exact match is missing.
# NOTE(review): "mfa" and "ms" both map to "zsm_Latn" — presumably intentional
# (Malay variants); confirm against the supported-locale list.
WEBLATE_TO_FLORES200 = {
    "af": "afr_Latn",
    "ar": "arb_Arab",
    "az": "azj_Latn",
    "be": "bel_Cyrl",
    "bg": "bul_Cyrl",
    "ca": "cat_Latn",
    "cs": "ces_Latn",
    "da": "dan_Latn",
    "de": "deu_Latn",
    "el": "ell_Grek",
    "en": "eng_Latn",
    "es": "spa_Latn",
    "et": "est_Latn",
    "fa": "pes_Arab",
    "fi": "fin_Latn",
    "fr": "fra_Latn",
    "he": "heb_Hebr",
    "hi": "hin_Deva",
    "hr": "hrv_Latn",
    "hu": "hun_Latn",
    "hy": "hye_Armn",
    "id": "ind_Latn",
    "it": "ita_Latn",
    "ja": "jpn_Hira",
    "kmr": "kmr_Latn",
    "ko": "kor_Hang",
    "lb": "ltz_Latn",
    "lt": "lit_Latn",
    "lv": "lav_Latn",
    "mfa": "zsm_Latn",
    "mk": "mkd_Cyrl",
    "mn": "khk_Cyrl",
    "mr": "mar_Deva",
    "ms": "zsm_Latn",
    "nb-NO": "nob_Latn",
    "nn": "nno_Latn",
    "nl": "nld_Latn",
    "pl": "pol_Latn",
    "pt-BR": "por_Latn",
    "pt": "por_Latn",
    "ro": "ron_Latn",
    "ru": "rus_Cyrl",
    "sk": "slk_Latn",
    "sl": "slv_Latn",
    "sr-Cyrl": "srp_Cyrl",
    "sv": "swe_Latn",
    "ta": "tam_Taml",
    "te": "tel_Telu",
    "th": "tha_Thai",
    "tr": "tur_Latn",
    "uk": "ukr_Cyrl",
    "ur": "urd_Arab",
    "vi": "vie_Latn",
    "zh-CN": "zho_Hans",
    "zh-Hans": "zho_Hans",
    "zh-TW": "zho_Hant",
}
|
||||
|
||||
|
||||
def get_model_source(model_name: str) -> ModelSource | None:
|
||||
cleaned_name = clean_name(model_name)
|
||||
|
||||
|
|
|
@ -494,6 +494,88 @@ class TestCLIP:
|
|||
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||
mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_openclip_tokenizer_adds_flores_token_for_nllb(
    self,
    mocker: MockerFixture,
    clip_model_cfg: dict[str, Any],
    clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
    """An NLLB model with an exact locale match ("de") prefixes the query
    with the FLORES-200 code before tokenizing."""
    mocker.patch.object(OpenClipTextualEncoder, "download")
    mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
    mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
    mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
    mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
    # 77 ids: the standard CLIP context length for tokenized text.
    mock_ids = [randint(0, 50000) for _ in range(77)]
    mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

    clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
    clip_encoder._load()
    clip_encoder.tokenize("test search query", language="de")

    # "de" -> "deu_Latn", prepended with no separator.
    mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
|
||||
|
||||
def test_openclip_tokenizer_removes_country_code_from_language_for_nllb_if_not_found(
    self,
    mocker: MockerFixture,
    clip_model_cfg: dict[str, Any],
    clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
    """A locale with an unmapped country suffix ("de-CH") falls back to the
    bare language code ("de" -> "deu_Latn")."""
    mocker.patch.object(OpenClipTextualEncoder, "download")
    mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
    mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
    mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
    mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
    # 77 ids: the standard CLIP context length for tokenized text.
    mock_ids = [randint(0, 50000) for _ in range(77)]
    mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

    clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
    clip_encoder._load()
    clip_encoder.tokenize("test search query", language="de-CH")

    mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
|
||||
|
||||
def test_openclip_tokenizer_falls_back_to_english_for_nllb_if_language_code_not_found(
    self,
    mocker: MockerFixture,
    clip_model_cfg: dict[str, Any],
    clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
    warning: mock.Mock,  # fixture capturing log.warning calls
) -> None:
    """An unknown locale falls back to "eng_Latn" and logs a warning."""
    mocker.patch.object(OpenClipTextualEncoder, "download")
    mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
    mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
    mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
    mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
    # 77 ids: the standard CLIP context length for tokenized text.
    mock_ids = [randint(0, 50000) for _ in range(77)]
    mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

    clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
    clip_encoder._load()
    clip_encoder.tokenize("test search query", language="unknown")

    mock_tokenizer.encode.assert_called_once_with("eng_Latntest search query")
    warning.assert_called_once_with("Language 'unknown' not found, defaulting to 'en'")
|
||||
|
||||
def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
    self,
    mocker: MockerFixture,
    clip_model_cfg: dict[str, Any],
    clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
    """Non-NLLB models ignore the language argument: no FLORES-200 prefix."""
    mocker.patch.object(OpenClipTextualEncoder, "download")
    mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
    mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
    mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
    mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
    # 77 ids: the standard CLIP context length for tokenized text.
    mock_ids = [randint(0, 50000) for _ in range(77)]
    mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

    clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
    clip_encoder._load()
    clip_encoder.tokenize("test search query", language="de")

    mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_mclip_tokenizer(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue