feat(ml): ML on Rockchip NPUs ()

This commit is contained in:
Yoni Yang 2025-03-18 00:04:08 +08:00 committed by GitHub
parent 1e184a70f1
commit 14c3b99c0f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 2417 additions and 4726 deletions
machine-learning/app

View file

@ -25,6 +25,7 @@ from app.models.facial_recognition.detection import FaceDetector
from app.models.facial_recognition.recognition import FaceRecognizer
from app.sessions.ann import AnnSession
from app.sessions.ort import OrtSession
from app.sessions.rknn import RknnSession, run_inference
from .config import Settings, settings
from .models.base import InferenceModel
@ -69,6 +70,14 @@ class TestBase:
assert encoder.model_format == ModelFormat.ARMNN
def test_sets_default_model_format_to_rknn_if_available(self, mocker: MockerFixture) -> None:
mocker.patch.object(settings, "rknn", True)
mocker.patch("app.sessions.rknn.is_available", True)
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
assert encoder.model_format == ModelFormat.RKNN
def test_casts_cache_dir_string_to_path(self) -> None:
cache_dir = "/test_cache"
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=cache_dir)
@ -125,7 +134,7 @@ class TestBase:
"immich-app/ViT-B-32__openai",
cache_dir=encoder.cache_dir,
local_dir=encoder.cache_dir,
ignore_patterns=["*.armnn"],
ignore_patterns=["*.armnn", "*.rknn"],
)
def test_download_downloads_armnn_if_preferred_format(self, snapshot_download: mock.Mock) -> None:
@ -136,7 +145,18 @@ class TestBase:
"immich-app/ViT-B-32__openai",
cache_dir=encoder.cache_dir,
local_dir=encoder.cache_dir,
ignore_patterns=[],
ignore_patterns=["*.rknn"],
)
def test_download_downloads_rknn_if_preferred_format(self, snapshot_download: mock.Mock) -> None:
encoder = OpenClipTextualEncoder("ViT-B-32__openai", model_format=ModelFormat.RKNN)
encoder.download()
snapshot_download.assert_called_once_with(
"immich-app/ViT-B-32__openai",
cache_dir=encoder.cache_dir,
local_dir=encoder.cache_dir,
ignore_patterns=["*.armnn"],
)
def test_throws_exception_if_model_path_does_not_exist(
@ -328,6 +348,33 @@ class TestAnnSession:
np_spy.assert_has_calls([mock.call(input1), mock.call(input2)])
class TestRknnSession:
def test_creates_rknn_session(self, rknn_session: mock.Mock, info: mock.Mock, mocker: MockerFixture) -> None:
model_path = mock.MagicMock(spec=Path)
tpe = 1
mocker.patch("app.sessions.rknn.soc_name", "rk3566")
mocker.patch("app.sessions.rknn.is_available", True)
RknnSession(model_path)
rknn_session.assert_called_once_with(model_path=model_path.as_posix(), tpes=tpe, func=run_inference)
info.assert_has_calls([mock.call(f"Loaded RKNN model from {model_path} with {tpe} threads.")])
def test_run_rknn(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None:
rknn_session.return_value.load.return_value = 123
np_spy = mocker.spy(np, "ascontiguousarray")
mocker.patch("app.sessions.rknn.soc_name", "rk3566")
session = RknnSession(Path("ViT-B-32__openai"))
[input1, input2] = [np.random.rand(1, 3, 224, 224).astype(np.float32) for _ in range(2)]
input_feed = {"input.1": input1, "input.2": input2}
session.run(None, input_feed)
rknn_session.return_value.put.assert_called_once_with([input1, input2])
np_spy.call_count == 2
np_spy.assert_has_calls([mock.call(input1), mock.call(input2)])
class TestCLIP:
embedding = np.random.rand(512).astype(np.float32)
cache_dir = Path("test_cache")
@ -829,9 +876,7 @@ class TestLoad:
mock_model.clear_cache.assert_not_called()
mock_model.load.assert_not_called()
async def test_falls_back_to_onnx_if_other_format_does_not_exist(
self, exception: mock.Mock, warning: mock.Mock
) -> None:
async def test_falls_back_to_onnx_if_other_format_does_not_exist(self, warning: mock.Mock) -> None:
mock_model = mock.Mock(spec=InferenceModel)
mock_model.model_name = "test_model_name"
mock_model.model_type = ModelType.VISUAL
@ -846,8 +891,9 @@ class TestLoad:
mock_model.clear_cache.assert_not_called()
assert mock_model.load.call_count == 2
exception.assert_called_once_with(error)
warning.assert_called_once_with("ARMNN is available, but model 'test_model_name' does not support it.")
warning.assert_called_once_with(
"ARMNN is available, but model 'test_model_name' does not support it.", exc_info=error
)
mock_model.model_format = ModelFormat.ONNX