from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple

import numpy as np
from numpy.typing import NDArray

from immich_ml.config import log, settings
from immich_ml.schemas import SessionNode

from .rknnpool import RknnPoolExecutor, is_available, soc_name

# RKNN is usable only when the runtime probe succeeded AND the user enabled it.
is_available = is_available and settings.rknn
# Per-SoC model subdirectory (e.g. "rknpu/<soc>"); None when RKNN is unusable
# or the SoC could not be identified. Presumably joined onto a model base path
# by callers — confirm against the model loader.
model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None


def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
    """Execute one inference call on an RKNNLite instance.

    Forwards *input* (a list of float32 arrays in NCHW layout) to
    ``rknn_lite.inference`` and returns its output list unchanged.
    """
    # NOTE: the parameter is named `input` (shadowing the builtin) to keep the
    # callable signature stable for the pool executor that invokes this.
    return rknn_lite.inference(inputs=input, data_format="nchw")


# Static input/output tensor specs (name -> shape) per model type. RknnSession
# answers get_inputs()/get_outputs() from this table rather than querying the
# runtime, so these shapes must match the converted models.
# NOTE(review): the nine detection outputs look like score/bbox/landmark
# tensors at three scales — confirm against the exported model before relying
# on that interpretation.
input_output_mapping: dict[str, dict[str, Any]] = {
    "detection": {
        "input": {"norm_tensor:0": (1, 3, 640, 640)},
        "output": {
            "norm_tensor:1": (12800, 1),
            "norm_tensor:2": (3200, 1),
            "norm_tensor:3": (800, 1),
            "norm_tensor:4": (12800, 4),
            "norm_tensor:5": (3200, 4),
            "norm_tensor:6": (800, 4),
            "norm_tensor:7": (12800, 10),
            "norm_tensor:8": (3200, 10),
            "norm_tensor:9": (800, 10),
        },
    },
    "recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}},
}


class RknnSession:
    """Inference session backed by a pool of RKNN runtimes.

    Exposes the subset of the onnxruntime ``InferenceSession`` interface the
    rest of the code relies on: ``get_inputs``, ``get_outputs`` and ``run``.
    """

    def __init__(self, model_path: Path) -> None:
        # The model type is inferred from the path components; any path that
        # does not contain a "detection" segment is treated as recognition.
        if "detection" in model_path.parts:
            self.model_type = "detection"
        else:
            self.model_type = "recognition"
        self.tpe = settings.rknn_threads

        log.info(f"Loading RKNN model from {model_path} with {self.tpe} threads.")
        self.rknnpool = RknnPoolExecutor(model_path=model_path.as_posix(), tpes=self.tpe, func=run_inference)
        log.info(f"Loaded RKNN model from {model_path} with {self.tpe} threads.")

    def get_inputs(self) -> list[SessionNode]:
        """Return the static input specs for this model type."""
        spec = input_output_mapping[self.model_type]["input"]
        return [RknnNode(name=name, shape=shape) for name, shape in spec.items()]

    def get_outputs(self) -> list[SessionNode]:
        """Return the static output specs for this model type."""
        spec = input_output_mapping[self.model_type]["output"]
        return [RknnNode(name=name, shape=shape) for name, shape in spec.items()]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """Run inference on the pool and return the output tensors.

        ``output_names`` and ``run_options`` are accepted only for interface
        compatibility and are ignored.

        Raises:
            RuntimeError: if the pool returns no result.
        """
        # RKNN requires contiguous buffers; copy each feed tensor if needed.
        tensors: list[NDArray[np.float32]] = [np.ascontiguousarray(arr) for arr in input_feed.values()]
        self.rknnpool.put(tensors)
        outputs = self.rknnpool.get()
        if outputs is None:
            raise RuntimeError("RKNN inference failed!")
        return outputs


class RknnNode(NamedTuple):
    """Lightweight tensor spec returned by ``RknnSession.get_inputs`` /
    ``get_outputs``; typed as ``SessionNode`` for interface compatibility."""

    # Tensor name (e.g. "norm_tensor:0"); Optional per the SessionNode type.
    name: str | None
    # Static tensor shape, e.g. (1, 3, 640, 640).
    shape: tuple[int, ...]


# Explicit public API of this module.
__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]