from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any, ClassVar

from huggingface_hub import snapshot_download

import ann.ann
import app.sessions.rknn as rknn
from app.sessions.ort import OrtSession

from ..config import clean_name, log, settings
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
from ..sessions.ann import AnnSession


class InferenceModel(ABC):
    depends: ClassVar[list[ModelIdentity]]
    identity: ClassVar[ModelIdentity]

    def __init__(
        self,
        model_name: str,
        cache_dir: Path | str | None = None,
        model_format: ModelFormat | None = None,
        session: ModelSession | None = None,
        **model_kwargs: Any,
    ) -> None:
        self.loaded = session is not None
        self.load_attempts = 0
        self.model_name = clean_name(model_name)
        self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
        self.model_format = model_format if model_format is not None else self._model_format_default
        if session is not None:
            self.session = session

    def download(self) -> None:
        if not self.cached:
            log.info(
                f"Downloading {self.model_type.replace('-', ' ')} model '{self.model_name}'. This may take a while."
            )
            self._download()

    def load(self) -> None:
        if self.loaded:
            return
        self.load_attempts += 1

        self.download()
        attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
        log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
        self.session = self._load()
        self.loaded = True

    def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
        self.load()
        if model_kwargs:
            self.configure(**model_kwargs)
        return self._predict(*inputs, **model_kwargs)

    @abstractmethod
    def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...

    def configure(self, **kwargs: Any) -> None:
        pass

    def _download(self) -> None:
        ignored_patterns: dict[ModelFormat, list[str]] = {
            ModelFormat.ONNX: ["*.armnn", "*.rknn"],
            ModelFormat.ARMNN: ["*.rknn"],
            ModelFormat.RKNN: ["*.armnn"],
        }

        snapshot_download(
            f"immich-app/{clean_name(self.model_name)}",
            cache_dir=self.cache_dir,
            local_dir=self.cache_dir,
            ignore_patterns=ignored_patterns.get(self.model_format, []),
        )

    def _load(self) -> ModelSession:
        return self._make_session(self.model_path)

    def clear_cache(self) -> None:
        if not self.cache_dir.exists():
            log.warning(
                f"Attempted to clear cache for model '{self.model_name}', but cache directory does not exist",
            )
            return
        if not rmtree.avoids_symlink_attacks:
            raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform")

        if self.cache_dir.is_dir():
            log.info(f"Cleared cache directory for model '{self.model_name}'.")
            rmtree(self.cache_dir)
        else:
            log.warning(
                (
                    f"Encountered file instead of directory at cache path "
                    f"for '{self.model_name}'. Removing file and replacing with a directory."
                ),
            )
            self.cache_dir.unlink()
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _make_session(self, model_path: Path) -> ModelSession:
        if not model_path.is_file():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        match model_path.suffix:
            case ".armnn":
                session: ModelSession = AnnSession(model_path)
            case ".onnx":
                session = OrtSession(model_path)
            case ".rknn":
                session = rknn.RknnSession(model_path)
            case _:
                raise ValueError(f"Unsupported model file type: {model_path.suffix}")
        return session

    def model_path_for_format(self, model_format: ModelFormat) -> Path:
        model_path_prefix = rknn.model_prefix if model_format == ModelFormat.RKNN else None
        if model_path_prefix:
            return self.model_dir / model_path_prefix / f"model.{model_format}"
        return self.model_dir / f"model.{model_format}"

    @property
    def model_dir(self) -> Path:
        return self.cache_dir / self.model_type.value

    @property
    def model_path(self) -> Path:
        return self.model_path_for_format(self.model_format)

    @property
    def model_task(self) -> ModelTask:
        return self.identity[1]

    @property
    def model_type(self) -> ModelType:
        return self.identity[0]

    @property
    def cache_dir(self) -> Path:
        return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, cache_dir: Path) -> None:
        self._cache_dir = cache_dir

    @property
    def _cache_dir_default(self) -> Path:
        return settings.cache_folder / self.model_task.value / self.model_name

    @property
    def cached(self) -> bool:
        return self.model_path.is_file()

    @property
    def model_format(self) -> ModelFormat:
        return self._model_format

    @model_format.setter
    def model_format(self, model_format: ModelFormat) -> None:
        log.debug(f"Setting model format to {model_format}")
        self._model_format = model_format

    @property
    def _model_format_default(self) -> ModelFormat:
        if rknn.is_available:
            return ModelFormat.RKNN
        elif ann.ann.is_available and settings.ann:
            return ModelFormat.ARMNN
        else:
            return ModelFormat.ONNX