refactor(ml): model sessions ()

This commit is contained in:
Mert 2024-06-25 12:00:24 -04:00 committed by GitHub
parent 6538ad8de7
commit 6356c28f64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 529 additions and 375 deletions
machine-learning/app

View file

@ -8,6 +8,8 @@ from fastapi.testclient import TestClient
from numpy.typing import NDArray
from PIL import Image
from app.config import log
from .main import app
@ -96,12 +98,77 @@ def clip_tokenizer_cfg() -> dict[str, Any]:
@pytest.fixture(scope="function")
def providers(request: pytest.FixtureRequest) -> Iterator[dict[str, Any]]:
def providers(request: pytest.FixtureRequest) -> Iterator[mock.Mock]:
marker = request.node.get_closest_marker("providers")
if marker is None:
raise ValueError("Missing marker 'providers'")
providers = marker.args[0]
with mock.patch("app.models.base.ort.get_available_providers") as mocked:
with mock.patch("app.sessions.ort.ort.get_available_providers") as mocked:
mocked.return_value = providers
yield providers
@pytest.fixture(scope="function")
def ort_pybind() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ort.ort.capi._pybind_state") as mocked:
yield mocked
@pytest.fixture(scope="function")
def ov_device_ids(request: pytest.FixtureRequest, ort_pybind: mock.Mock) -> Iterator[mock.Mock]:
marker = request.node.get_closest_marker("ov_device_ids")
if marker is None:
raise ValueError("Missing marker 'ov_device_ids'")
ort_pybind.get_available_openvino_device_ids.return_value = marker.args[0]
return ort_pybind
@pytest.fixture(scope="function")
def ort_session() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ort.ort.InferenceSession") as mocked:
yield mocked
@pytest.fixture(scope="function")
def ann_session() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ann.Ann") as mocked:
yield mocked
@pytest.fixture(scope="function")
def rmtree() -> Iterator[mock.Mock]:
with mock.patch("app.models.base.rmtree", autospec=True) as mocked:
mocked.avoids_symlink_attacks = True
yield mocked
@pytest.fixture(scope="function")
def path() -> Iterator[mock.Mock]:
path = mock.MagicMock()
path.exists.return_value = True
path.is_dir.return_value = True
path.is_file.return_value = True
path.with_suffix.return_value = path
path.return_value = path
with mock.patch("app.models.base.Path", return_value=path) as mocked:
yield mocked
@pytest.fixture(scope="function")
def info() -> Iterator[mock.Mock]:
with mock.patch.object(log, "info") as mocked:
yield mocked
@pytest.fixture(scope="function")
def warning() -> Iterator[mock.Mock]:
with mock.patch.object(log, "warning") as mocked:
yield mocked
@pytest.fixture(scope="function")
def snapshot_download() -> Iterator[mock.Mock]:
with mock.patch("app.models.base.snapshot_download") as mocked:
yield mocked