import gradio as gr
import cv2

from ultralyticsplus import YOLO, render_result

# Model Heading and Description
model_heading = "TTG's Candle Stick Scan: Pattern Recognition BETA"
description = """ #
#
📧 #
👍 #"""

image_path= [['test/test1.jpg', 'foduucom/stockmarket-pattern-detection-yolov8', 640, 0.25, 0.45], ['test/test2.jpg', 'foduucom/stockmarket-pattern-detection-yolov8', 640, 0.25, 0.45]]

# Load YOLO model
model = YOLO('foduucom/stockmarket-pattern-detection-yolov8')
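# Note: the inference functions below load the model again from the selected
# `model_path`; this module-level load mostly verifies at startup that the
# weights can be fetched.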

#############################################################Image Inference############################################################
def yolov8_img_inference(
    image: str = None,
    model_path: str = None,
    image_size: int = 640,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
):
    """
    YOLOv8 inference on a single chart image.
    Args:
        image: Path to the input image
        model_path: Model ID or path to the model weights
        image_size: Inference image size
        conf_threshold: Confidence threshold
        iou_threshold: IOU threshold
    Returns:
        Image rendered with the detected candlestick patterns
    """
    model = YOLO(model_path)
    model.overrides['conf'] = conf_threshold
    model.overrides['iou'] = iou_threshold
    model.overrides['agnostic_nms'] = False  # class-agnostic NMS disabled
    model.overrides['max_det'] = 1000
    results = model.predict(image, imgsz=image_size)
    render = render_result(model=model, image=image, result=results[0])
    
    return render
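
# Standalone usage sketch (outside the Gradio UI), using the same example image
# listed in `image_path` above:
#   preview = yolov8_img_inference('test/test1.jpg',
#                                  'foduucom/stockmarket-pattern-detection-yolov8',
#                                  640, 0.25, 0.45)
#   preview.save('preview.jpg')   # render_result returns a PIL image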

    
inputs_image = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(["foduucom/stockmarket-pattern-detection-yolov8"],
                value="foduucom/stockmarket-pattern-detection-yolov8", label="Model"),
    gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
]

outputs_image = gr.Image(type="filepath", label="Output Image")
interface_image = gr.Interface(
    fn=yolov8_img_inference,
    inputs=inputs_image,
    outputs=outputs_image,
    title=model_heading,
    description=description,
    examples=image_path,
    cache_examples=False,
    theme='huggingface'
)

##################################################Video Inference################################################################
def show_preds_video(
    video_path: str = None,
    model_path: str = None,
    image_size: int = 640,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
):
    """Run YOLOv8 on each video frame and stream the annotated frames."""
    # Load the model once instead of once per frame
    model = YOLO(model_path)
    model.overrides['conf'] = conf_threshold
    model.overrides['iou'] = iou_threshold
    model.overrides['agnostic_nms'] = False
    model.overrides['max_det'] = 1000

    cap = cv2.VideoCapture(video_path)

    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break

        results = model.predict(frame, imgsz=image_size)
        annotated_frame = results[0].plot()

        # OpenCV frames are BGR; convert to RGB before handing them to Gradio.
        # Yielding streams each annotated frame to the output component via the
        # queue enabled at launch.
        yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

    cap.release()


inputs_video = [
    gr.Video(label="Input Video"),
    gr.Dropdown(["foduucom/stockmarket-pattern-detection-yolov8"],
                value="foduucom/stockmarket-pattern-detection-yolov8", label="Model"),
    gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IOU Threshold"),
]
outputs_video = gr.Image(label="Output Video")
video_path=[['test/testvideo.mp4','foduucom/stockmarket-pattern-detection-yolov8', 640, 0.25, 0.45]]
interface_video = gr.Interface(
    fn=show_preds_video,
    inputs=inputs_video,
    outputs=outputs_video,
    title=model_heading,
    description=description,
    examples=video_path,
    cache_examples=False,
    theme='huggingface'
)

gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image inference', 'Video inference']
).queue().launch()
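
# Rough dependency sketch for running this app (versions are assumptions, not
# pinned by this file): gradio, ultralyticsplus (which pulls in ultralytics),
# and opencv-python (or opencv-python-headless on a server without a display).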