fix(ml): limit load retries ()

This commit is contained in:
Mert 2024-06-20 14:13:18 -04:00 committed by GitHub
commit a42af06889
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 26 additions and 11 deletions
machine-learning/app

View file

@ -11,6 +11,7 @@ import cv2
import numpy as np
import onnxruntime as ort
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from PIL import Image
from pytest import MonkeyPatch
@ -627,6 +628,7 @@ class TestLoad:
async def test_load(self) -> None:
mock_model = mock.Mock(spec=InferenceModel)
mock_model.loaded = False
mock_model.load_attempts = 0
res = await load(mock_model)
@ -650,6 +652,7 @@ class TestLoad:
mock_model.model_task = ModelTask.SEARCH
mock_model.load.side_effect = [OSError, None]
mock_model.loaded = False
mock_model.load_attempts = 0
res = await load(mock_model)
@ -657,6 +660,20 @@ class TestLoad:
mock_model.clear_cache.assert_called_once()
assert mock_model.load.call_count == 2
async def test_load_clears_cache_and_raises_if_os_error_and_already_retried(self) -> None:
mock_model = mock.Mock(spec=InferenceModel)
mock_model.model_name = "test_model_name"
mock_model.model_type = ModelType.VISUAL
mock_model.model_task = ModelTask.SEARCH
mock_model.loaded = False
mock_model.load_attempts = 2
with pytest.raises(HTTPException):
await load(mock_model)
mock_model.clear_cache.assert_not_called()
mock_model.load.assert_not_called()
@pytest.mark.skipif(
not settings.test_full,