from typing import Any, List, Optional, Set, Type

from pydantic import ValidationError

from inference.core.entities.requests.inference import InferenceRequestImage
from inference.enterprise.workflows.entities.base import GraphNone
from inference.enterprise.workflows.errors import (
    InvalidStepInputDetected,
    VariableTypeError,
)

# Step type names that `validate_selector_holds_image` accepts as the origin
# of an image selector; any other step type is rejected there.
STEPS_WITH_IMAGE = {
    "InferenceImage",
    "Crop",
    "AbsoluteStaticCrop",
    "RelativeStaticCrop",
}


def validate_image_is_valid_selector(value: Any, field_name: str = "image") -> None:
    """Ensure `value` is a selector, or a list holding selectors only.

    Args:
        value: Single candidate or a list of candidates.
        field_name: Name of the validated field, used in the error message.

    Raises:
        ValueError: If `value` (or any element of it, when it is a list) is
            not recognised by `is_selector`.
    """
    # Wrap a lone value so both shapes go through the same check.
    candidates = value if issubclass(type(value), list) else [value]
    for candidate in candidates:
        if not is_selector(selector_or_value=candidate):
            raise ValueError(f"`{field_name}` field can only contain selector values")


def validate_field_is_in_range_zero_one_or_empty_or_selector(
    value: Any, field_name: str = "confidence"
) -> None:
    """Accept a selector, `None`, or a number within the closed range [0, 1].

    Selectors and `None` pass without further checks; anything else is
    delegated to `validate_value_is_empty_or_number_in_range_zero_one`.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.
    """
    if not is_selector(selector_or_value=value) and value is not None:
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value, field_name=field_name
        )


def validate_value_is_empty_or_number_in_range_zero_one(
    value: Any, field_name: str = "confidence", error: Type[Exception] = ValueError
) -> None:
    """Accept `None` or an int/float lying within the closed range [0, 1].

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If the type check performed by `validate_field_has_given_type`
            fails, or the number falls outside [0, 1].
    """
    # Type gate first, so the comparison below only ever sees None or a number.
    validate_field_has_given_type(
        value=value,
        field_name=field_name,
        allowed_types=[type(None), int, float],
        error=error,
    )
    if value is None:
        return None
    # Chained comparison is kept deliberately: it is False for NaN as well.
    within_unit_interval = 0 <= value <= 1
    if not within_unit_interval:
        raise error(f"Parameter `{field_name}` must be in range [0.0, 1.0]")


def validate_value_is_empty_or_selector_or_positive_number(
    value: Any, field_name: str
) -> None:
    """Accept a selector, or whatever `validate_value_is_empty_or_positive_number` accepts.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.
    """
    # Selectors cannot be checked numerically here, so they are let through.
    if not is_selector(selector_or_value=value):
        validate_value_is_empty_or_positive_number(value=value, field_name=field_name)


def validate_value_is_empty_or_positive_number(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Accept `None` or an int/float strictly greater than zero.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If the type check performed by `validate_field_has_given_type`
            fails, or the number is zero or negative.
    """
    numeric_or_empty = [type(None), int, float]
    validate_field_has_given_type(
        value=value,
        field_name=field_name,
        allowed_types=numeric_or_empty,
        error=error,
    )
    if value is not None and value <= 0:
        raise error(f"Parameter `{field_name}` must be positive (> 0)")


def validate_field_is_list_of_selectors(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Ensure `value` is a list whose every element is a selector.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If `value` is not a list, or any element is rejected by
            `is_selector`.
    """
    value_is_list = issubclass(type(value), list)
    if not value_is_list:
        raise error(f"`{field_name}` field must be list")
    if not all(is_selector(selector_or_value=item) for item in value):
        raise error(f"Parameter `{field_name}` must be a list of selectors")


def validate_field_is_empty_or_selector_or_list_of_string(
    value: Any, field_name: str
) -> None:
    """Accept `None`, a selector, or a list of strings.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.

    Returns:
        Always `None`, in line with the declared return type; the outcome is
        signalled only by the presence or absence of an exception.

    Raises:
        ValueError: Raised by `validate_field_is_list_of_string` when `value`
            is neither `None`, a selector, nor a list of strings.
    """
    # Cheap identity check first; a selector needs no further validation.
    if value is None or is_selector(selector_or_value=value):
        return None
    validate_field_is_list_of_string(value=value, field_name=field_name)
    return None


def validate_field_is_list_of_string(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Ensure `value` is a list containing only `str` instances.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If `value` is not a list, or any element is not a string.
    """
    value_is_list = issubclass(type(value), list)
    if not value_is_list:
        raise error(f"`{field_name}` field must be list")
    if not all(issubclass(type(item), str) for item in value):
        raise error(f"Parameter `{field_name}` must be a list of string")


def validate_field_is_selector_or_one_of_values(
    value: Any, field_name: str, selected_values: set
) -> None:
    """Accept `None`, a selector, or a member of `selected_values`.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.
        selected_values: Collection of literal values that are permitted.

    Returns:
        Always `None`, in line with the declared return type; the outcome is
        signalled only by the presence or absence of an exception.

    Raises:
        ValueError: Raised by `validate_field_is_one_of_selected_values` when
            `value` is none of the accepted alternatives.
    """
    # Note: despite the function name, `None` is accepted as well.
    if value is None or is_selector(selector_or_value=value):
        return None
    validate_field_is_one_of_selected_values(
        value=value, field_name=field_name, selected_values=selected_values
    )
    return None


def validate_field_is_one_of_selected_values(
    value: Any,
    field_name: str,
    selected_values: set,
    error: Type[Exception] = ValueError,
) -> None:
    """Ensure `value` is a member of `selected_values`.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in the error message.
        selected_values: Collection of permitted values.
        error: Exception class raised on failure.

    Raises:
        error: If `value` is not contained in `selected_values`.
    """
    if value in selected_values:
        return None
    raise error(
        f"Value of field `{field_name}` must be in {selected_values}. Found: {value}"
    )


def validate_field_is_selector_or_has_given_type(
    value: Any, field_name: str, allowed_types: List[type]
) -> None:
    """Accept a selector, or a value whose type is among `allowed_types`.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in error messages.
        allowed_types: Types the value may be an instance of.

    Raises:
        ValueError: Raised by `validate_field_has_given_type` when `value` is
            not a selector and matches none of `allowed_types`.
    """
    # Selectors bypass the type check entirely.
    if not is_selector(selector_or_value=value):
        validate_field_has_given_type(
            value=value, field_name=field_name, allowed_types=allowed_types
        )


def validate_field_has_given_type(
    value: Any,
    field_name: str,
    allowed_types: List[type],
    error: Type[Exception] = ValueError,
) -> None:
    """Ensure the type of `value` is (a subclass of) one of `allowed_types`.

    Args:
        value: Value to validate.
        field_name: Name of the validated field, used in the error message.
        allowed_types: Acceptable types; an empty list rejects every value.
        error: Exception class raised on failure.

    Raises:
        error: If the type of `value` matches none of `allowed_types`.
    """
    actual_type = type(value)
    matches_any = any(issubclass(actual_type, candidate) for candidate in allowed_types)
    if not matches_any:
        raise error(
            f"`{field_name}` field type must be one of {allowed_types}. Detected: {value}"
        )


def validate_image_biding(value: Any, field_name: str = "image") -> None:
    """Check that `value` can be validated as `InferenceRequestImage`.

    A single value and a list of values are both supported; every element is
    passed to `InferenceRequestImage.model_validate`.

    Args:
        value: Single image payload or a list of payloads.
        field_name: Name of the validated parameter, used in the error message.

    Raises:
        VariableTypeError: If validation of any element raises `ValueError`
            or pydantic's `ValidationError`; the original error is chained.
    """
    try:
        payloads = value if issubclass(type(value), list) else [value]
        for payload in payloads:
            InferenceRequestImage.model_validate(payload)
    except (ValueError, ValidationError) as cause:
        raise VariableTypeError(
            f"Parameter `{field_name}` must be compatible with `InferenceRequestImage`"
        ) from cause


def validate_selector_is_inference_parameter(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Set[str],
) -> None:
    """Ensure a selector used in `field_name` points at an `InferenceParameter`.

    Fields not listed in `applicable_fields` are ignored.

    Args:
        step_type: Type of the step owning the field, used in the error message.
        field_name: Name of the field that holds the selector.
        input_step: Graph node the selector refers to; must expose `get_type()`.
        applicable_fields: Field names for which the check is enforced.

    Raises:
        InvalidStepInputDetected: If the referenced node type is not
            `InferenceParameter`.
    """
    if field_name in applicable_fields:
        accepted_types = {"InferenceParameter"}
        input_step_type = input_step.get_type()
        if input_step_type not in accepted_types:
            raise InvalidStepInputDetected(
                f"Field {field_name} of step {step_type} comes from invalid input type: {input_step_type}. "
                f"Expected: `InferenceParameter`"
            )
    return None


def validate_selector_holds_image(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Ensure a selector used in `field_name` points at an image-producing step.

    Fields not listed in `applicable_fields` are ignored.

    Args:
        step_type: Type of the step owning the field, used in the error message.
        field_name: Name of the field that holds the selector.
        input_step: Graph node the selector refers to; must expose `get_type()`.
        applicable_fields: Field names for which the check is enforced;
            defaults to `{"image"}` when not given.

    Raises:
        InvalidStepInputDetected: If the referenced node type is not listed
            in `STEPS_WITH_IMAGE`.
    """
    checked_fields = {"image"} if applicable_fields is None else applicable_fields
    if field_name in checked_fields and input_step.get_type() not in STEPS_WITH_IMAGE:
        raise InvalidStepInputDetected(
            f"Field {field_name} of step {step_type} comes from invalid input type: {input_step.get_type()}. "
            f"Expected: {STEPS_WITH_IMAGE}"
        )
    return None


def validate_selector_holds_detections(
    step_name: str,
    image_selector: Optional[str],
    detections_selector: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Ensure a detections selector points at a compatible detection output.

    Three conditions are verified, in order, for fields listed in
    `applicable_fields` (other fields are ignored):

    1. the referenced step type is one of the detection-producing types,
    2. the selector's last chunk is `predictions`,
    3. when both the referenced step exposes an `image` attribute and
       `image_selector` is given, the two image references are equal.

    Args:
        step_name: Name of the step being validated, used in error messages.
        image_selector: Image reference of the validated step, or `None` to
            skip the image-consistency check.
        detections_selector: Selector pointing at the detections output.
        field_name: Name of the field that holds the selector.
        input_step: Graph node the selector refers to; must expose `get_type()`.
        applicable_fields: Field names for which the check is enforced;
            defaults to `{"detections"}` when not given.

    Raises:
        InvalidStepInputDetected: If any of the three conditions fails.
    """
    if applicable_fields is None:
        applicable_fields = {"detections"}
    if field_name not in applicable_fields:
        return None
    # Resolve the type once; it is needed both for the check and the message.
    input_step_type = input_step.get_type()
    if input_step_type not in {
        "ObjectDetectionModel",
        "KeypointsDetectionModel",
        "InstanceSegmentationModel",
        "DetectionFilter",
        "DetectionsConsensus",
        "DetectionOffset",
        "YoloWorld",
    }:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} cannot take as an input predictions from {input_step_type}. "
            f"Step requires detection-based output."
        )
    if get_last_selector_chunk(detections_selector) != "predictions":
        raise InvalidStepInputDetected(
            f"Step with name {step_name} must take as input step output of name `predictions`"
        )
    if image_selector is None or not hasattr(input_step, "image"):
        # Some referenced steps (e.g. filters) hold no image reference, and the
        # caller may not supply one either; nothing to compare in those cases.
        return None
    input_step_image_reference = input_step.image
    if image_selector != input_step_image_reference:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} was given detections reference that is bound to different image: "
            f"step.image: {image_selector}, detections step image: {input_step_image_reference}"
        )
    return None


def is_selector(selector_or_value: Any) -> bool:
    """Tell whether the argument is a selector, i.e. a string starting with `$`.

    Args:
        selector_or_value: Any value; non-string values are never selectors.

    Returns:
        `True` only for `str` instances whose first character is `$`.
    """
    is_text = issubclass(type(selector_or_value), str)
    return is_text and selector_or_value.startswith("$")


def get_last_selector_chunk(selector: str) -> str:
    """Return the part of `selector` that follows its last `.` separator.

    When the selector contains no dot, the whole string is returned.
    """
    _, _, last_chunk = selector.rpartition(".")
    return last_chunk