import gradio as gr
import cv2
import requests
import os

from ultralyticsplus import YOLO, render_result

# Title and description text passed to both Gradio interfaces below.
model_heading = "TTG's Candle Stick Scan: Pattern Recognition BETA"
description = """ #
#
📧 #
👍 #"""

# Example rows for the image tab. Each row lists, in order: image file,
# model id, image size, confidence threshold, IOU threshold.
image_path = [
    ['test/test1.jpg', 'foduucom/stockmarket-pattern-detection-yolov8', 640, 0.25, 0.45],
    ['test/test2.jpg', 'foduucom/stockmarket-pattern-detection-yolov8', 640, 0.25, 0.45],
]

# Module-level YOLO model instance, created once when the script starts.
model = YOLO('foduucom/stockmarket-pattern-detection-yolov8')

#############################################################Image Inference############################################################
def yolov8_img_inference(
    image: str = None,
    model_path: str = None,
    image_size: int = 640,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
):
    """
    Run YOLOv8 detection on a single image and return the annotated result.

    Args:
        image: Input image as delivered by the Gradio image input
            (a file path, since that input is declared with type="filepath").
        model_path: Model identifier or path handed to ``YOLO(...)``.
        image_size: Inference image size, forwarded to ``predict`` as ``imgsz``.
        conf_threshold: Confidence threshold for detections.
        iou_threshold: IOU threshold used by non-maximum suppression.
    Returns:
        The rendered image produced by ``render_result`` for the first
        prediction result.
    Raises:
        ValueError: If no image was supplied.
    """
    # Fail early with a clear message instead of handing None to the model.
    if image is None:
        raise ValueError("No input image provided.")

    detector = YOLO(model_path)
    detector.overrides['conf'] = conf_threshold
    detector.overrides['iou'] = iou_threshold
    detector.overrides['agnostic_nms'] = False  # NMS class-agnostic
    detector.overrides['max_det'] = 1000

    # The original ignored image_size entirely; pass it through so the
    # "Image Size" control actually affects inference. int() guards against
    # a float coming back from the slider widget.
    results = detector.predict(image, imgsz=int(image_size))

    return render_result(model=detector, image=image, result=results[0])

    
# Input widgets for the image tab, listed in the same order as the
# parameters of yolov8_img_inference.
# NOTE: built from the top-level Gradio components with `value=` rather than
# the deprecated `gr.inputs.*` / `gr.outputs.*` namespaces and `default=`.
inputs_image = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(
        ["foduucom/stockmarket-pattern-detection-yolov8"],
        value="foduucom/stockmarket-pattern-detection-yolov8",
        label="Model",
    ),
    gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
]

# Single image output showing the rendered detections.
outputs_image = gr.Image(type="filepath", label="Output Image")

# Image-inference tab: wires the widgets above to yolov8_img_inference and
# offers the rows in image_path as clickable examples (not pre-computed).
interface_image = gr.Interface(
    fn=yolov8_img_inference,
    inputs=inputs_image,
    outputs=outputs_image,
    title=model_heading,
    description=description,
    examples=image_path,
    cache_examples=False,
    theme='huggingface',
)

##################################################Video Inference################################################################
def show_preds_video(
    video_path: str = None,
    model_path: str = None,
    image_size: int = 640,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
):
    """
    Run YOLOv8 detection on every frame of a video and stream the results.

    This is a generator: it yields one annotated frame per decoded video
    frame so the caller can display them as they are produced. (The original
    computed each annotated frame and then discarded it, returning None.)

    Args:
        video_path: Path of the video file to read with OpenCV.
        model_path: Model identifier or path handed to ``YOLO(...)``.
        image_size: Inference image size, forwarded to ``predict`` as ``imgsz``.
        conf_threshold: Confidence threshold for detections.
        iou_threshold: IOU threshold used by non-maximum suppression.
    Yields:
        The annotated frame as an RGB image array.
    Raises:
        ValueError: If no video was supplied.
    """
    if video_path is None:
        raise ValueError("No input video provided.")

    # Load and configure the model once; nothing here changes between frames,
    # so doing it inside the loop only repeated the (expensive) model load.
    detector = YOLO(model_path)
    detector.overrides['conf'] = conf_threshold
    detector.overrides['iou'] = iou_threshold
    detector.overrides['agnostic_nms'] = False
    detector.overrides['max_det'] = 1000

    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                # End of stream or decode failure: stop processing.
                break

            results = detector.predict(frame, imgsz=int(image_size))
            annotated_frame = results[0].plot()

            # OpenCV frames use BGR channel order and, per the Ultralytics
            # docs, plot() returns BGR as well; convert to RGB for display.
            yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
    finally:
        # Always free the capture handle, even if inference raises or the
        # consumer stops iterating early. No GUI windows are ever opened, so
        # the former waitKey()/destroyAllWindows() calls were dropped.
        cap.release()


# Input widgets for the video tab, listed in the same order as the
# parameters of show_preds_video.
# NOTE: built from the top-level Gradio components with `value=` rather than
# the deprecated `gr.inputs.*` / `gr.outputs.*` namespaces and `default=`.
# The Video component is created without `type=`, which is not one of its
# parameters.
inputs_video = [
    gr.Video(label="Input Video"),
    gr.Dropdown(
        ["foduucom/stockmarket-pattern-detection-yolov8"],
        value="foduucom/stockmarket-pattern-detection-yolov8",
        label="Model",
    ),
    gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
]

# The video tab reports its result through an image component.
outputs_video = gr.Image(type="filepath", label="Output Video")

# Example row for the video tab: video file, model id, image size,
# confidence threshold, IOU threshold.
video_path = [['test/testvideo.mp4', 'foduucom/stockmarket-pattern-detection-yolov8', 640, 0.25, 0.45]]

# Video-inference tab: wires the widgets above to show_preds_video.
interface_video = gr.Interface(
    fn=show_preds_video,
    inputs=inputs_video,
    outputs=outputs_video,
    title=model_heading,
    description=description,
    examples=video_path,
    cache_examples=False,
    theme='huggingface',
)

# Combine both interfaces into one tabbed app, enable the request queue and
# start the server.
gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image inference', 'Video inference'],
).queue().launch()