mirror of
https://github.com/immich-app/immich.git
synced 2025-07-09 09:12:57 +02:00
refactor(ml): model downloading (#3545)
* download facial recognition models * download hf models * simplified logic * updated `predict` for facial recognition * ensure download method is called * fixed repo_id for clip * fixed download destination * use st's own `snapshot_download` * conditional download * fixed predict method * check if loaded * minor fixes * updated mypy overrides * added pytest-mock * updated tests * updated lock
This commit is contained in:
parent
2f26a7edae
commit
c73832bd9c
10 changed files with 350 additions and 274 deletions
machine-learning/app
|
@ -1,5 +1,4 @@
|
|||
from types import SimpleNamespace
|
||||
from typing import Any, Iterator, TypeAlias
|
||||
from typing import Iterator, TypeAlias
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
|
@ -22,91 +21,6 @@ def cv_image(pil_image: Image.Image) -> ndarray:
|
|||
return np.asarray(pil_image)[:, :, ::-1] # PIL uses RGB while cv2 uses BGR
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_classifier_pipeline() -> Iterator[mock.Mock]:
|
||||
with mock.patch("app.models.image_classification.pipeline") as model:
|
||||
classifier_preds = [
|
||||
{"label": "that's an image alright", "score": 0.8},
|
||||
{"label": "well it ends with .jpg", "score": 0.1},
|
||||
{"label": "idk, im just seeing bytes", "score": 0.05},
|
||||
{"label": "not sure", "score": 0.04},
|
||||
{"label": "probably a virus", "score": 0.01},
|
||||
]
|
||||
|
||||
def forward(
|
||||
inputs: Image.Image | list[Image.Image], **kwargs: Any
|
||||
) -> list[dict[str, Any]] | list[list[dict[str, Any]]]:
|
||||
if isinstance(inputs, list) and not all([isinstance(img, Image.Image) for img in inputs]):
|
||||
raise TypeError
|
||||
elif not isinstance(inputs, Image.Image):
|
||||
raise TypeError
|
||||
|
||||
if isinstance(inputs, list):
|
||||
return [classifier_preds] * len(inputs)
|
||||
|
||||
return classifier_preds
|
||||
|
||||
model.return_value = forward
|
||||
yield model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_st() -> Iterator[mock.Mock]:
|
||||
with mock.patch("app.models.clip.SentenceTransformer") as model:
|
||||
embedding = np.random.rand(512).astype(np.float32)
|
||||
|
||||
def encode(inputs: Image.Image | list[Image.Image], **kwargs: Any) -> ndarray | list[ndarray]:
|
||||
# mypy complains unless isinstance(inputs, list) is used explicitly
|
||||
img_batch = isinstance(inputs, list) and all([isinstance(inst, Image.Image) for inst in inputs])
|
||||
text_batch = isinstance(inputs, list) and all([isinstance(inst, str) for inst in inputs])
|
||||
if isinstance(inputs, list) and not any([img_batch, text_batch]):
|
||||
raise TypeError
|
||||
|
||||
if isinstance(inputs, list):
|
||||
return np.stack([embedding] * len(inputs))
|
||||
|
||||
return embedding
|
||||
|
||||
mocked = mock.Mock()
|
||||
mocked.encode = encode
|
||||
model.return_value = mocked
|
||||
yield model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faceanalysis() -> Iterator[mock.Mock]:
|
||||
with mock.patch("app.models.facial_recognition.FaceAnalysis") as model:
|
||||
face_preds = [
|
||||
SimpleNamespace( # this is so these fields can be accessed through dot notation
|
||||
**{
|
||||
"bbox": np.random.rand(4).astype(np.float32),
|
||||
"kps": np.random.rand(5, 2).astype(np.float32),
|
||||
"det_score": np.array([0.67]).astype(np.float32),
|
||||
"normed_embedding": np.random.rand(512).astype(np.float32),
|
||||
}
|
||||
),
|
||||
SimpleNamespace(
|
||||
**{
|
||||
"bbox": np.random.rand(4).astype(np.float32),
|
||||
"kps": np.random.rand(5, 2).astype(np.float32),
|
||||
"det_score": np.array([0.4]).astype(np.float32),
|
||||
"normed_embedding": np.random.rand(512).astype(np.float32),
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def get(image: np.ndarray[int, np.dtype[np.float32]], **kwargs: Any) -> list[SimpleNamespace]:
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise TypeError
|
||||
|
||||
return face_preds
|
||||
|
||||
mocked = mock.Mock()
|
||||
mocked.get = get
|
||||
model.return_value = mocked
|
||||
yield model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_model() -> Iterator[mock.Mock]:
|
||||
with mock.patch("app.models.cache.InferenceModel.from_model_type", autospec=True) as mocked:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue