mirror of
https://github.com/immich-app/immich.git
synced 2025-07-13 20:38:46 +02:00
fix(ml): armnn not being used (#10929)
* fix armnn not being used, move fallback handling to main, add tests * formatting
This commit is contained in:
parent
59aa347912
commit
f43721ec92
7 changed files with 111 additions and 44 deletions
machine-learning/app
|
@ -43,7 +43,7 @@ class TestBase:
|
|||
|
||||
assert encoder.cache_dir == cache_dir
|
||||
|
||||
def test_sets_default_preferred_format(self, mocker: MockerFixture) -> None:
|
||||
def test_sets_default_model_format(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(settings, "ann", True)
|
||||
mocker.patch("ann.ann.is_available", False)
|
||||
|
||||
|
@ -51,7 +51,7 @@ class TestBase:
|
|||
|
||||
assert encoder.model_format == ModelFormat.ONNX
|
||||
|
||||
def test_sets_default_preferred_format_to_armnn_if_available(self, path: mock.Mock, mocker: MockerFixture) -> None:
|
||||
def test_sets_default_model_format_to_armnn_if_available(self, path: mock.Mock, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(settings, "ann", True)
|
||||
mocker.patch("ann.ann.is_available", True)
|
||||
path.suffix = ".armnn"
|
||||
|
@ -60,11 +60,11 @@ class TestBase:
|
|||
|
||||
assert encoder.model_format == ModelFormat.ARMNN
|
||||
|
||||
def test_sets_preferred_format_kwarg(self, mocker: MockerFixture) -> None:
|
||||
def test_sets_model_format_kwarg(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(settings, "ann", False)
|
||||
mocker.patch("ann.ann.is_available", False)
|
||||
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", preferred_format=ModelFormat.ARMNN)
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", model_format=ModelFormat.ARMNN)
|
||||
|
||||
assert encoder.model_format == ModelFormat.ARMNN
|
||||
|
||||
|
@ -129,7 +129,7 @@ class TestBase:
|
|||
)
|
||||
|
||||
def test_download_downloads_armnn_if_preferred_format(self, snapshot_download: mock.Mock) -> None:
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", preferred_format=ModelFormat.ARMNN)
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", model_format=ModelFormat.ARMNN)
|
||||
encoder.download()
|
||||
|
||||
snapshot_download.assert_called_once_with(
|
||||
|
@ -140,6 +140,19 @@ class TestBase:
|
|||
ignore_patterns=[],
|
||||
)
|
||||
|
||||
def test_throws_exception_if_model_path_does_not_exist(
|
||||
self, snapshot_download: mock.Mock, ort_session: mock.Mock, path: mock.Mock
|
||||
) -> None:
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.is_file.return_value = False
|
||||
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=path)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
encoder.load()
|
||||
|
||||
snapshot_download.assert_called_once()
|
||||
ort_session.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("ort_session")
|
||||
class TestOrtSession:
|
||||
|
@ -467,16 +480,18 @@ class TestFaceRecognition:
|
|||
assert isinstance(call_args[0][0], np.ndarray)
|
||||
assert call_args[0][0].shape == (112, 112, 3)
|
||||
|
||||
def test_recognition_adds_batch_axis_for_ort(self, ort_session: mock.Mock, mocker: MockerFixture) -> None:
|
||||
def test_recognition_adds_batch_axis_for_ort(
|
||||
self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||
) -> None:
|
||||
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
||||
update_dims = mocker.patch(
|
||||
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
||||
)
|
||||
mocker.patch("app.models.base.InferenceModel.download")
|
||||
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
||||
|
||||
ort_session.return_value.get_inputs.return_value = [SimpleNamespace(name="input.1", shape=(1, 3, 224, 224))]
|
||||
ort_session.return_value.get_outputs.return_value = [SimpleNamespace(name="output.1", shape=(1, 800))]
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||
|
||||
proto = mock.Mock()
|
||||
|
||||
|
@ -492,27 +507,30 @@ class TestFaceRecognition:
|
|||
|
||||
onnx.load.return_value = proto
|
||||
|
||||
face_recognizer = FaceRecognizer("buffalo_s")
|
||||
face_recognizer = FaceRecognizer("buffalo_s", cache_dir=path)
|
||||
face_recognizer.load()
|
||||
|
||||
assert face_recognizer.batch is True
|
||||
update_dims.assert_called_once_with(proto, {"input.1": ["batch", 3, 224, 224]}, {"output.1": ["batch", 800]})
|
||||
onnx.save.assert_called_once_with(update_dims.return_value, face_recognizer.model_path)
|
||||
|
||||
def test_recognition_does_not_add_batch_axis_if_exists(self, ort_session: mock.Mock, mocker: MockerFixture) -> None:
|
||||
def test_recognition_does_not_add_batch_axis_if_exists(
|
||||
self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||
) -> None:
|
||||
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
||||
update_dims = mocker.patch(
|
||||
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
||||
)
|
||||
mocker.patch("app.models.base.InferenceModel.download")
|
||||
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||
|
||||
inputs = [SimpleNamespace(name="input.1", shape=("batch", 3, 224, 224))]
|
||||
outputs = [SimpleNamespace(name="output.1", shape=("batch", 800))]
|
||||
ort_session.return_value.get_inputs.return_value = inputs
|
||||
ort_session.return_value.get_outputs.return_value = outputs
|
||||
|
||||
face_recognizer = FaceRecognizer("buffalo_s")
|
||||
face_recognizer = FaceRecognizer("buffalo_s", cache_dir=path)
|
||||
face_recognizer.load()
|
||||
|
||||
assert face_recognizer.batch is True
|
||||
|
@ -520,6 +538,30 @@ class TestFaceRecognition:
|
|||
onnx.load.assert_not_called()
|
||||
onnx.save.assert_not_called()
|
||||
|
||||
def test_recognition_does_not_add_batch_axis_for_armnn(
|
||||
self, ann_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||
) -> None:
|
||||
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
||||
update_dims = mocker.patch(
|
||||
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
||||
)
|
||||
mocker.patch("app.models.base.InferenceModel.download")
|
||||
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".armnn"
|
||||
|
||||
inputs = [SimpleNamespace(name="input.1", shape=("batch", 3, 224, 224))]
|
||||
outputs = [SimpleNamespace(name="output.1", shape=("batch", 800))]
|
||||
ann_session.return_value.get_inputs.return_value = inputs
|
||||
ann_session.return_value.get_outputs.return_value = outputs
|
||||
|
||||
face_recognizer = FaceRecognizer("buffalo_s", model_format=ModelFormat.ARMNN, cache_dir=path)
|
||||
face_recognizer.load()
|
||||
|
||||
assert face_recognizer.batch is False
|
||||
update_dims.assert_not_called()
|
||||
onnx.load.assert_not_called()
|
||||
onnx.save.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCache:
|
||||
|
@ -693,7 +735,7 @@ class TestLoad:
|
|||
mock_model.clear_cache.assert_called_once()
|
||||
assert mock_model.load.call_count == 2
|
||||
|
||||
async def test_load_clears_cache_and_raises_if_os_error_and_already_retried(self) -> None:
|
||||
async def test_load_raises_if_os_error_and_already_retried(self) -> None:
|
||||
mock_model = mock.Mock(spec=InferenceModel)
|
||||
mock_model.model_name = "test_model_name"
|
||||
mock_model.model_type = ModelType.VISUAL
|
||||
|
@ -707,6 +749,27 @@ class TestLoad:
|
|||
mock_model.clear_cache.assert_not_called()
|
||||
mock_model.load.assert_not_called()
|
||||
|
||||
async def test_falls_back_to_onnx_if_other_format_does_not_exist(
|
||||
self, exception: mock.Mock, warning: mock.Mock
|
||||
) -> None:
|
||||
mock_model = mock.Mock(spec=InferenceModel)
|
||||
mock_model.model_name = "test_model_name"
|
||||
mock_model.model_type = ModelType.VISUAL
|
||||
mock_model.model_task = ModelTask.SEARCH
|
||||
mock_model.model_format = ModelFormat.ARMNN
|
||||
mock_model.loaded = False
|
||||
mock_model.load_attempts = 0
|
||||
error = FileNotFoundError()
|
||||
mock_model.load.side_effect = [error, None]
|
||||
|
||||
await load(mock_model)
|
||||
|
||||
mock_model.clear_cache.assert_not_called()
|
||||
assert mock_model.load.call_count == 2
|
||||
exception.assert_called_once_with(error)
|
||||
warning.assert_called_once_with("ARMNN is available, but model 'test_model_name' does not support it.")
|
||||
mock_model.model_format = ModelFormat.ONNX
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not settings.test_full,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue