diff --git a/docs/docs/features/img/moblie-smart-serach.webp b/docs/docs/features/img/mobile-smart-search.webp similarity index 100% rename from docs/docs/features/img/moblie-smart-serach.webp rename to docs/docs/features/img/mobile-smart-search.webp diff --git a/docs/docs/features/searching.md b/docs/docs/features/searching.md index 15f83949f2..7c7e387218 100644 --- a/docs/docs/features/searching.md +++ b/docs/docs/features/searching.md @@ -45,7 +45,7 @@ Some search examples: </TabItem> <TabItem value="Mobile" label="Mobile"> -<img src={require('./img/moblie-smart-serach.webp').default} width="30%" title='Smart search on mobile' /> +<img src={require('./img/mobile-smart-search.webp').default} width="30%" title='Smart search on mobile' /> </TabItem> </Tabs> @@ -56,7 +56,20 @@ Navigating to `Administration > Settings > Machine Learning Settings > Smart Sea ### CLIP models -More powerful models can be used for more accurate search results, but are slower and can require more server resources. Check the dropdowns below to see how they compare in memory usage, speed and quality by language. +The default search model is fast, but there are many other options that can provide better search results. The tradeoff of using these models is that they're slower and/or use more memory (both when indexing images with background Smart Search jobs and when searching). + +The first step in choosing the right model is to know which languages your users will search in. + +If your users will only search in English, then the [CLIP][huggingface-clip] section is the first place to look. This is a curated list of the models that generally perform the best for their size class. The models here are ordered from higher to lower quality. This means that the top models will generally rank the most relevant results higher and have a higher capacity to understand descriptive, detailed, and/or niche queries. The models are also generally ordered from larger to smaller, so consider the impact on memory usage, job processing and search speed when deciding on one. The smaller models in this list are not too different in quality and are many times faster. + +[Multilingual models][huggingface-multilingual-clip] are also available so users can search in their native language. Use these models if you expect non-English searches to be common. They can be separated into two search patterns: + +- `nllb` models expect the search query to be in the language specified in the user settings +- `xlm` and `siglip2` models understand search text regardless of the current language setting + +`nllb` models tend to perform the best and are recommended when users primarily search in their native, non-English language. `xlm` and `siglip2` models are more flexible and are recommended for mixed-language search, where the same user might search in different languages at different times. + +For more details, check the tables below to see how they compare in memory usage, speed and quality by language. 
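To make the `nllb` behaviour concrete before reading the implementation below, here is a simplified sketch of how the user's interface language is resolved to a FLORES-200 code and prepended to the query before tokenization. The mapping shown is only a tiny excerpt for illustration; the full table and the real logic live in `machine-learning/immich_ml/models/constants.py` and `textual.py` in this change.

```python
# Simplified sketch of the nllb language handling (illustration only).
# The real WEBLATE_TO_FLORES200 table in constants.py covers ~50 languages.
WEBLATE_TO_FLORES200 = {"de": "deu_Latn", "pt-BR": "por_Latn", "en": "eng_Latn"}


def resolve_flores_code(language: str) -> str:
    """Map a UI language code to a FLORES-200 code, falling back to English."""
    code = WEBLATE_TO_FLORES200.get(language)
    if code is None:
        # Retry without the country suffix, e.g. "de-CH" -> "de".
        code = WEBLATE_TO_FLORES200.get(language.split("-")[0])
    return code or "eng_Latn"


# The resolved code is prepended to the raw query text before tokenization:
query = "ein Hund am Strand"
print(resolve_flores_code("de-CH") + query)  # deu_Latnein Hund am Strand
```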
Once you've chosen a model, follow these steps: diff --git a/machine-learning/immich_ml/models/clip/textual.py b/machine-learning/immich_ml/models/clip/textual.py index 603cd29400..c1b3a9eba4 100644 --- a/machine-learning/immich_ml/models/clip/textual.py +++ b/machine-learning/immich_ml/models/clip/textual.py @@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer from immich_ml.config import log from immich_ml.models.base import InferenceModel +from immich_ml.models.constants import WEBLATE_TO_FLORES200 from immich_ml.models.transforms import clean_text, serialize_np_array from immich_ml.schemas import ModelSession, ModelTask, ModelType @@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel): depends = [] identity = (ModelType.TEXTUAL, ModelTask.SEARCH) - def _predict(self, inputs: str, **kwargs: Any) -> str: - res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0] + def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> str: + tokens = self.tokenize(inputs, language=language) + res: NDArray[np.float32] = self.session.run(None, tokens)[0][0] return serialize_np_array(res) def _load(self) -> ModelSession: @@ -28,6 +30,7 @@ class BaseCLIPTextualEncoder(InferenceModel): self.tokenizer = self._load_tokenizer() tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs") self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize" + self.is_nllb = self.model_name.startswith("nllb") log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'") return session @@ -37,7 +40,7 @@ class BaseCLIPTextualEncoder(InferenceModel): pass @abstractmethod - def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: + def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]: pass @property @@ -92,14 +95,23 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder): return tokenizer - def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: + def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]: text = clean_text(text, canonicalize=self.canonicalize) + if self.is_nllb and language is not None: + flores_code = WEBLATE_TO_FLORES200.get(language) + if flores_code is None: + no_country = language.split("-")[0] + flores_code = WEBLATE_TO_FLORES200.get(no_country) + if flores_code is None: + log.warning(f"Language '{language}' not found, defaulting to 'en'") + flores_code = "eng_Latn" + text = f"{flores_code}{text}" tokens: Encoding = self.tokenizer.encode(text) return {"text": np.array([tokens.ids], dtype=np.int32)} class MClipTextualEncoder(OpenClipTextualEncoder): - def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: + def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]: text = clean_text(text, canonicalize=self.canonicalize) tokens: Encoding = self.tokenizer.encode(text) return { diff --git a/machine-learning/immich_ml/models/constants.py b/machine-learning/immich_ml/models/constants.py index 85b5b53991..41b0990f71 100644 --- a/machine-learning/immich_ml/models/constants.py +++ b/machine-learning/immich_ml/models/constants.py @@ -86,6 +86,66 @@ RKNN_SUPPORTED_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"] RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"] +WEBLATE_TO_FLORES200 = { + "af": "afr_Latn", + "ar": "arb_Arab", + "az": "azj_Latn", + "be": "bel_Cyrl", + "bg": "bul_Cyrl", + "ca": "cat_Latn", + "cs": "ces_Latn", + "da": "dan_Latn", + "de": "deu_Latn", + 
"el": "ell_Grek", + "en": "eng_Latn", + "es": "spa_Latn", + "et": "est_Latn", + "fa": "pes_Arab", + "fi": "fin_Latn", + "fr": "fra_Latn", + "he": "heb_Hebr", + "hi": "hin_Deva", + "hr": "hrv_Latn", + "hu": "hun_Latn", + "hy": "hye_Armn", + "id": "ind_Latn", + "it": "ita_Latn", + "ja": "jpn_Hira", + "kmr": "kmr_Latn", + "ko": "kor_Hang", + "lb": "ltz_Latn", + "lt": "lit_Latn", + "lv": "lav_Latn", + "mfa": "zsm_Latn", + "mk": "mkd_Cyrl", + "mn": "khk_Cyrl", + "mr": "mar_Deva", + "ms": "zsm_Latn", + "nb-NO": "nob_Latn", + "nn": "nno_Latn", + "nl": "nld_Latn", + "pl": "pol_Latn", + "pt-BR": "por_Latn", + "pt": "por_Latn", + "ro": "ron_Latn", + "ru": "rus_Cyrl", + "sk": "slk_Latn", + "sl": "slv_Latn", + "sr-Cyrl": "srp_Cyrl", + "sv": "swe_Latn", + "ta": "tam_Taml", + "te": "tel_Telu", + "th": "tha_Thai", + "tr": "tur_Latn", + "uk": "ukr_Cyrl", + "ur": "urd_Arab", + "vi": "vie_Latn", + "zh-CN": "zho_Hans", + "zh-Hans": "zho_Hans", + "zh-TW": "zho_Hant", +} + + def get_model_source(model_name: str) -> ModelSource | None: cleaned_name = clean_name(model_name) diff --git a/machine-learning/test_main.py b/machine-learning/test_main.py index 4a3696f320..a19ec65c5f 100644 --- a/machine-learning/test_main.py +++ b/machine-learning/test_main.py @@ -494,6 +494,88 @@ class TestCLIP: assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0) mock_tokenizer.encode.assert_called_once_with("test search query") + def test_openclip_tokenizer_adds_flores_token_for_nllb( + self, + mocker: MockerFixture, + clip_model_cfg: dict[str, Any], + clip_tokenizer_cfg: Callable[[Path], dict[str, Any]], + ) -> None: + mocker.patch.object(OpenClipTextualEncoder, "download") + mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg) + mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg) + mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value + mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value + mock_ids = [randint(0, 50000) for _ in range(77)] + mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids) + + clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache") + clip_encoder._load() + clip_encoder.tokenize("test search query", language="de") + + mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query") + + def test_openclip_tokenizer_removes_country_code_from_language_for_nllb_if_not_found( + self, + mocker: MockerFixture, + clip_model_cfg: dict[str, Any], + clip_tokenizer_cfg: Callable[[Path], dict[str, Any]], + ) -> None: + mocker.patch.object(OpenClipTextualEncoder, "download") + mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg) + mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg) + mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value + mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value + mock_ids = [randint(0, 50000) for _ in range(77)] + mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids) + + clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache") + clip_encoder._load() + clip_encoder.tokenize("test search query", language="de-CH") + + mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query") + + def test_openclip_tokenizer_falls_back_to_english_for_nllb_if_language_code_not_found( + self, + mocker: 
MockerFixture, + clip_model_cfg: dict[str, Any], + clip_tokenizer_cfg: Callable[[Path], dict[str, Any]], + warning: mock.Mock, + ) -> None: + mocker.patch.object(OpenClipTextualEncoder, "download") + mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg) + mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg) + mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value + mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value + mock_ids = [randint(0, 50000) for _ in range(77)] + mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids) + + clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache") + clip_encoder._load() + clip_encoder.tokenize("test search query", language="unknown") + + mock_tokenizer.encode.assert_called_once_with("eng_Latntest search query") + warning.assert_called_once_with("Language 'unknown' not found, defaulting to 'en'") + + def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model( + self, + mocker: MockerFixture, + clip_model_cfg: dict[str, Any], + clip_tokenizer_cfg: Callable[[Path], dict[str, Any]], + ) -> None: + mocker.patch.object(OpenClipTextualEncoder, "download") + mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg) + mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg) + mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value + mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value + mock_ids = [randint(0, 50000) for _ in range(77)] + mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids) + + clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache") + clip_encoder._load() + clip_encoder.tokenize("test search query", language="de") + + mock_tokenizer.encode.assert_called_once_with("test search query") + def test_mclip_tokenizer( self, mocker: MockerFixture, diff --git a/mobile/lib/models/search/search_filter.model.dart b/mobile/lib/models/search/search_filter.model.dart index 87e7b24e34..598b71ef4e 100644 --- a/mobile/lib/models/search/search_filter.model.dart +++ b/mobile/lib/models/search/search_filter.model.dart @@ -236,6 +236,7 @@ class SearchFilter { String? context; String? filename; String? description; + String? language; Set<Person> people; SearchLocationFilter location; SearchCameraFilter camera; @@ -249,6 +250,7 @@ class SearchFilter { this.context, this.filename, this.description, + this.language, required this.people, required this.location, required this.camera, @@ -279,6 +281,7 @@ class SearchFilter { String? context, String? filename, String? description, + String? language, Set<Person>? people, SearchLocationFilter? location, SearchCameraFilter? camera, @@ -290,6 +293,7 @@ class SearchFilter { context: context ?? this.context, filename: filename ?? this.filename, description: description ?? this.description, + language: language ?? this.language, people: people ?? this.people, location: location ?? this.location, camera: camera ?? 
this.camera, @@ -301,7 +305,7 @@ class SearchFilter { @override String toString() { - return 'SearchFilter(context: $context, filename: $filename, description: $description, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)'; + return 'SearchFilter(context: $context, filename: $filename, description: $description, language: $language, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)'; } @override @@ -311,6 +315,7 @@ class SearchFilter { return other.context == context && other.filename == filename && other.description == description && + other.language == language && other.people == people && other.location == location && other.camera == camera && @@ -324,6 +329,7 @@ class SearchFilter { return context.hashCode ^ filename.hashCode ^ description.hashCode ^ + language.hashCode ^ people.hashCode ^ location.hashCode ^ camera.hashCode ^ diff --git a/mobile/lib/pages/search/search.page.dart b/mobile/lib/pages/search/search.page.dart index 9ff8caff1d..b2bed73c6a 100644 --- a/mobile/lib/pages/search/search.page.dart +++ b/mobile/lib/pages/search/search.page.dart @@ -48,6 +48,8 @@ class SearchPage extends HookConsumerWidget { isFavorite: false, ), mediaType: prefilter?.mediaType ?? AssetType.other, + language: + "${context.locale.languageCode}-${context.locale.countryCode}", ), ); diff --git a/mobile/lib/services/search.service.dart b/mobile/lib/services/search.service.dart index 4c6c80abf3..44ace78852 100644 --- a/mobile/lib/services/search.service.dart +++ b/mobile/lib/services/search.service.dart @@ -60,6 +60,7 @@ class SearchService { response = await _apiService.searchApi.searchSmart( SmartSearchDto( query: filter.context!, + language: filter.language, country: filter.location.country, state: filter.location.state, city: filter.location.city, diff --git a/mobile/openapi/lib/model/smart_search_dto.dart b/mobile/openapi/lib/model/smart_search_dto.dart index f377c23f22..47c800ff09 100644 --- a/mobile/openapi/lib/model/smart_search_dto.dart +++ b/mobile/openapi/lib/model/smart_search_dto.dart @@ -25,6 +25,7 @@ class SmartSearchDto { this.isNotInAlbum, this.isOffline, this.isVisible, + this.language, this.lensModel, this.libraryId, this.make, @@ -132,6 +133,14 @@ class SmartSearchDto { /// bool? isVisible; + /// + /// Please note: This property should have been non-nullable! Since the specification file + /// does not include a default value (using the "default:" property), however, the generated + /// source code must fall back to having a nullable type. + /// Consider adding a "default:" property in the specification file to hide this note. + /// + String? language; + String? lensModel; String? libraryId; @@ -271,6 +280,7 @@ class SmartSearchDto { other.isNotInAlbum == isNotInAlbum && other.isOffline == isOffline && other.isVisible == isVisible && + other.language == language && other.lensModel == lensModel && other.libraryId == libraryId && other.make == make && @@ -308,6 +318,7 @@ class SmartSearchDto { (isNotInAlbum == null ? 0 : isNotInAlbum!.hashCode) + (isOffline == null ? 0 : isOffline!.hashCode) + (isVisible == null ? 0 : isVisible!.hashCode) + + (language == null ? 0 : language!.hashCode) + (lensModel == null ? 0 : lensModel!.hashCode) + (libraryId == null ? 0 : libraryId!.hashCode) + (make == null ? 0 : make!.hashCode) + @@ -331,7 +342,7 @@ class SmartSearchDto { (withExif == null ? 
0 : withExif!.hashCode); @override - String toString() => 'SmartSearchDto[city=$city, country=$country, createdAfter=$createdAfter, createdBefore=$createdBefore, deviceId=$deviceId, isArchived=$isArchived, isEncoded=$isEncoded, isFavorite=$isFavorite, isMotion=$isMotion, isNotInAlbum=$isNotInAlbum, isOffline=$isOffline, isVisible=$isVisible, lensModel=$lensModel, libraryId=$libraryId, make=$make, model=$model, page=$page, personIds=$personIds, query=$query, rating=$rating, size=$size, state=$state, tagIds=$tagIds, takenAfter=$takenAfter, takenBefore=$takenBefore, trashedAfter=$trashedAfter, trashedBefore=$trashedBefore, type=$type, updatedAfter=$updatedAfter, updatedBefore=$updatedBefore, withArchived=$withArchived, withDeleted=$withDeleted, withExif=$withExif]'; + String toString() => 'SmartSearchDto[city=$city, country=$country, createdAfter=$createdAfter, createdBefore=$createdBefore, deviceId=$deviceId, isArchived=$isArchived, isEncoded=$isEncoded, isFavorite=$isFavorite, isMotion=$isMotion, isNotInAlbum=$isNotInAlbum, isOffline=$isOffline, isVisible=$isVisible, language=$language, lensModel=$lensModel, libraryId=$libraryId, make=$make, model=$model, page=$page, personIds=$personIds, query=$query, rating=$rating, size=$size, state=$state, tagIds=$tagIds, takenAfter=$takenAfter, takenBefore=$takenBefore, trashedAfter=$trashedAfter, trashedBefore=$trashedBefore, type=$type, updatedAfter=$updatedAfter, updatedBefore=$updatedBefore, withArchived=$withArchived, withDeleted=$withDeleted, withExif=$withExif]'; Map<String, dynamic> toJson() { final json = <String, dynamic>{}; @@ -395,6 +406,11 @@ class SmartSearchDto { } else { // json[r'isVisible'] = null; } + if (this.language != null) { + json[r'language'] = this.language; + } else { + // json[r'language'] = null; + } if (this.lensModel != null) { json[r'lensModel'] = this.lensModel; } else { @@ -508,6 +524,7 @@ class SmartSearchDto { isNotInAlbum: mapValueOfType<bool>(json, r'isNotInAlbum'), isOffline: mapValueOfType<bool>(json, r'isOffline'), isVisible: mapValueOfType<bool>(json, r'isVisible'), + language: mapValueOfType<String>(json, r'language'), lensModel: mapValueOfType<String>(json, r'lensModel'), libraryId: mapValueOfType<String>(json, r'libraryId'), make: mapValueOfType<String>(json, r'make'), diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json index 5ba08ab80b..b948ef0386 100644 --- a/open-api/immich-openapi-specs.json +++ b/open-api/immich-openapi-specs.json @@ -11853,6 +11853,9 @@ "isVisible": { "type": "boolean" }, + "language": { + "type": "string" + }, "lensModel": { "nullable": true, "type": "string" diff --git a/open-api/typescript-sdk/src/fetch-client.ts b/open-api/typescript-sdk/src/fetch-client.ts index 252ce9bc69..26929ba4e6 100644 --- a/open-api/typescript-sdk/src/fetch-client.ts +++ b/open-api/typescript-sdk/src/fetch-client.ts @@ -924,6 +924,7 @@ export type SmartSearchDto = { isNotInAlbum?: boolean; isOffline?: boolean; isVisible?: boolean; + language?: string; lensModel?: string | null; libraryId?: string | null; make?: string; diff --git a/server/src/dtos/search.dto.ts b/server/src/dtos/search.dto.ts index 3589331c78..e0b5c9b779 100644 --- a/server/src/dtos/search.dto.ts +++ b/server/src/dtos/search.dto.ts @@ -191,6 +191,11 @@ export class SmartSearchDto extends BaseSearchDto { @IsNotEmpty() query!: string; + @IsString() + @IsNotEmpty() + @Optional() + language?: string; + @IsInt() @Min(1) @Type(() => Number) diff --git a/server/src/repositories/machine-learning.repository.ts 
b/server/src/repositories/machine-learning.repository.ts index 95aa4cff1e..a52bc58bc3 100644 --- a/server/src/repositories/machine-learning.repository.ts +++ b/server/src/repositories/machine-learning.repository.ts @@ -53,6 +53,7 @@ export interface Face { export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse; export type DetectedFaces = { faces: Face[] } & VisualResponse; export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest; +export type TextEncodingOptions = ModelOptions & { language?: string }; @Injectable() export class MachineLearningRepository { @@ -170,8 +171,8 @@ export class MachineLearningRepository { return response[ModelTask.SEARCH]; } - async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) { - const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } }; + async encodeText(urls: string[], text: string, { language, modelName }: TextEncodingOptions) { + const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, options: { language } } } }; const response = await this.predict<ClipTextualResponse>(urls, { text }, request); return response[ModelTask.SEARCH]; } diff --git a/server/src/services/search.service.spec.ts b/server/src/services/search.service.spec.ts index 79f3a77ebe..51c6b55e11 100644 --- a/server/src/services/search.service.spec.ts +++ b/server/src/services/search.service.spec.ts @@ -1,3 +1,4 @@ +import { BadRequestException } from '@nestjs/common'; import { mapAsset } from 'src/dtos/asset-response.dto'; import { SearchSuggestionType } from 'src/dtos/search.dto'; import { SearchService } from 'src/services/search.service'; @@ -15,6 +16,7 @@ describe(SearchService.name, () => { beforeEach(() => { ({ sut, mocks } = newTestService(SearchService)); + mocks.partner.getAll.mockResolvedValue([]); }); it('should work', () => { @@ -155,4 +157,83 @@ describe(SearchService.name, () => { expect(mocks.search.getCameraModels).toHaveBeenCalledWith([authStub.user1.user.id], expect.anything()); }); }); + + describe('searchSmart', () => { + beforeEach(() => { + mocks.search.searchSmart.mockResolvedValue({ hasNextPage: false, items: [] }); + mocks.machineLearning.encodeText.mockResolvedValue('[1, 2, 3]'); + }); + + it('should raise a BadRequestException if machine learning is disabled', async () => { + mocks.systemMetadata.get.mockResolvedValue({ + machineLearning: { enabled: false }, + }); + + await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError( + new BadRequestException('Smart search is not enabled'), + ); + }); + + it('should raise a BadRequestException if smart search is disabled', async () => { + mocks.systemMetadata.get.mockResolvedValue({ + machineLearning: { clip: { enabled: false } }, + }); + + await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError( + new BadRequestException('Smart search is not enabled'), + ); + }); + + it('should work', async () => { + await sut.searchSmart(authStub.user1, { query: 'test' }); + + expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith( + [expect.any(String)], + 'test', + expect.objectContaining({ modelName: expect.any(String) }), + ); + expect(mocks.search.searchSmart).toHaveBeenCalledWith( + { page: 1, size: 100 }, + { query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] }, + ); + }); + + it('should consider page and size parameters', async () => { + await sut.searchSmart(authStub.user1, { query: 'test', 
page: 2, size: 50 }); + + expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith( + [expect.any(String)], + 'test', + expect.objectContaining({ modelName: expect.any(String) }), + ); + expect(mocks.search.searchSmart).toHaveBeenCalledWith( + { page: 2, size: 50 }, + expect.objectContaining({ query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] }), + ); + }); + + it('should use clip model specified in config', async () => { + mocks.systemMetadata.get.mockResolvedValue({ + machineLearning: { clip: { modelName: 'ViT-B-16-SigLIP__webli' } }, + }); + + await sut.searchSmart(authStub.user1, { query: 'test' }); + + expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith( + [expect.any(String)], + 'test', + expect.objectContaining({ modelName: 'ViT-B-16-SigLIP__webli' }), + ); + }); + + it('should use language specified in request', async () => { + await sut.searchSmart(authStub.user1, { query: 'test', language: 'de' }); + + expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith( + [expect.any(String)], + 'test', + expect.objectContaining({ language: 'de' }), + ); + }); + }); }); diff --git a/server/src/services/search.service.ts b/server/src/services/search.service.ts index e2ad9e7f99..1c0c0ad490 100644 --- a/server/src/services/search.service.ts +++ b/server/src/services/search.service.ts @@ -78,12 +78,10 @@ export class SearchService extends BaseService { } const userIds = await this.getUserIdsToSearch(auth); - - const embedding = await this.machineLearningRepository.encodeText( - machineLearning.urls, - dto.query, - machineLearning.clip, - ); + const embedding = await this.machineLearningRepository.encodeText(machineLearning.urls, dto.query, { + modelName: machineLearning.clip.modelName, + language: dto.language, + }); const page = dto.page ?? 1; const size = dto.size || 100; const { hasNextPage, items } = await this.searchRepository.searchSmart( diff --git a/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte b/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte index e5e336521c..c750f02aed 100644 --- a/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte +++ b/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte @@ -33,7 +33,7 @@ } from '@immich/sdk'; import { mdiArrowLeft, mdiDotsVertical, mdiImageOffOutline, mdiPlus, mdiSelectAll } from '@mdi/js'; import type { Viewport } from '$lib/stores/assets-store.svelte'; - import { locale } from '$lib/stores/preferences.store'; + import { lang, locale } from '$lib/stores/preferences.store'; import LoadingSpinner from '$lib/components/shared-components/loading-spinner.svelte'; import { handlePromiseError } from '$lib/utils'; import { parseUtcDate } from '$lib/utils/date-time'; @@ -153,6 +153,7 @@ page: nextPage, withExif: true, isVisible: true, + language: $lang, ...terms, };
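Taken together, the clients now send the user's locale as `language` with each smart-search request, the server forwards it through `encodeText`, and the ML service uses it to pick the FLORES-200 prefix for `nllb` models. As a hedged illustration of exercising the new field directly, here is a minimal sketch; the `/api/search/smart` route, the `x-api-key` header, and the response shape are assumed from Immich's existing API conventions rather than introduced by this diff.

```python
import requests

# Hypothetical values; substitute your own server URL and API key.
BASE_URL = "https://immich.example.com/api"
API_KEY = "your-api-key"

# SmartSearchDto now accepts an optional "language" field (see the OpenAPI change above).
# With an nllb CLIP model configured, the ML service prepends the FLORES-200 code for
# "de" to the query before tokenizing it.
payload = {"query": "ein Hund am Strand", "language": "de", "withExif": True}

response = requests.post(
    f"{BASE_URL}/search/smart",
    json=payload,
    headers={"x-api-key": API_KEY, "Accept": "application/json"},
)
response.raise_for_status()
# Assumed response shape: matching assets are listed under "assets" -> "items".
print(len(response.json()["assets"]["items"]))
```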