# poi_Engineering / app.py
# Gradio demo for signboard tracking, scene-text spotting, and label editing.
import numpy as np
import gradio as gr
import cv2
import os
import shutil
import re
import torch
import csv
import time
from src.sts.demo.sts import handle_sts
from src.ir.ir import handle_ir
from src.ir.src.models.tc_classifier import TCClassifier
from src.tracker.signboard_track import SignboardTracker
from omegaconf import DictConfig
from hydra import compose, initialize
signboardTracker = SignboardTracker()
tracking_result_dir = ""
output_track_format = "mp4v"
output_track = ""
output_sts = ""
video_dir = ""
vd_dir = ""
labeling_dir = ""
frame_out = {}
rs = {}
results = []
# with initialize(version_base=None, config_path="src/ir/configs", job_name="ir"):
# config = compose(config_name="test")
# config: DictConfig
# model_ir = TCClassifier(config.model.train.model_name,
# config.model.train.n_classes,
# config.model.train.lr,
# config.model.train.scheduler_type,
# config.model.train.max_steps,
# config.model.train.weight_decay,
# config.model.train.classifier_dropout,
# config.model.train.mixout,
# config.model.train.freeze_encoder)
# model_ir = model_ir.load_from_checkpoint(checkpoint_path=config.ckpt_path, map_location=torch.device("cuda"))
def create_dir(list_dir_path):
for dir_path in list_dir_path:
if not os.path.isdir(dir_path):
os.makedirs(dir_path)
def get_meta_from_video(input_video):
if input_video is not None:
video_name = os.path.basename(input_video).split('.')[0]
global video_dir
video_dir = os.path.join("static/videos/", f"{video_name}")
global vd_dir
vd_dir = os.path.join(video_dir, os.path.basename(input_video))
global output_track
output_track = os.path.join(video_dir,"original")
global tracking_result_dir
tracking_result_dir = os.path.join(video_dir,"track/cropped")
global output_sts
output_sts = os.path.join(video_dir,"track/sts")
global labeling_dir
labeling_dir = os.path.join(video_dir,"track/labeling")
if os.path.isdir(video_dir):
return None
else:
create_dir([output_track, video_dir, os.path.join(video_dir, "track/segment"), output_sts, tracking_result_dir, labeling_dir])
# initialize the video stream
video_cap = cv2.VideoCapture(input_video)
# grab the width, height, and fps of the frames in the video stream.
frame_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(video_cap.get(cv2.CAP_PROP_FPS))
#tổng Fps
# total_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
# print(total_frames)
# # Tính tổng số giây trong video
# total_seconds = total_frames / video_cap.get(cv2.CAP_PROP_FPS)
# print(total_seconds)
# initialize the FourCC and a video writer object
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
output = cv2.VideoWriter(vd_dir, fourcc, fps, (frame_width, frame_height))
while True:
success, frame = video_cap.read()
# write the frame to the output file
if success == True:
output.write(frame)
else:
break
# print(fps)
# return gr.Slider(1, fps, value=4, label="FPS",step=1, info="Choose between 1 and {fps}", interactive=True)
return gr.Textbox(value=fps)
def get_signboard(evt: gr.SelectData):
name_fr = int(evt.index) + 1
ids_dir = tracking_result_dir
all_ids = os.listdir(ids_dir)
gallery=[]
for i in all_ids:
fr_id = str(name_fr)
al = re.search("[\d]*_"+fr_id+".png", i)
if al:
id_dir = os.path.join(ids_dir, i)
gallery.append(id_dir)
gallery = sorted(gallery)
return gallery, name_fr
def tracking(fps_target):
start = time.time()
fps_target = int(fps_target)
global results
results = signboardTracker.inference_signboard(fps_target, vd_dir, output_track, output_track_format, tracking_result_dir)[0]
# print("result", results)
fd = []
global frame_out
list_id = []
with open(os.path.join(video_dir, "track/label.csv"), 'w', newline='') as file:
writer = csv.writer(file)
writer.writerow(["Signboard", "Frame", "Text"])
for frame, values in results.items():
frame_dir = os.path.join(output_track, f"{frame}.jpg")
# segment = os.path.join(video_dir,"segment/" + f"{frame}.jpg")
list_boxs = []
full = []
list_id_tmp = []
# print("values", values)
for value in values:
list_boxs.append(value['box'])
list_id_tmp.append(value['id'])
_, dict_rec_sign_out = handle_sts(frame_dir, labeling_dir, list_boxs, list_id_tmp)
# predicted = handle_ir(frame_dir, dict_rec_sign_out, os.path.join(video_dir, "ir"))
# print(predicted)
# fd.append(frame_dir)
# frame_out[frame] = full
list_id.extend(list_id_tmp)
list_id = list(set(list_id))
# print(list_id)
print(time.time()-start)
return gr.Dropdown(label="signboard",choices=list_id, interactive=True)
def get_select_index(img_id, evt: gr.SelectData):
ids_dir = tracking_result_dir
# print(ids_dir)
all_ids = os.listdir(ids_dir)
gallery = []
for i in all_ids:
fr_id = str(img_id)
al = re.search("[\d]*_"+fr_id+".png", i)
if al:
id_dir = os.path.join(ids_dir, i)
gallery.append(id_dir)
gallery = sorted(gallery)
gallery_id=[]
id_name = gallery[evt.index]
id = os.path.basename(id_name).split(".")[0].split("_")[0]
for i in all_ids:
al = re.search("^" +id + "_[\d]*.png", i)
if al:
id_dir = os.path.join(ids_dir, i)
gallery_id.append(id_dir)
gallery_id = sorted(gallery_id)
return gallery_id
id_glb = None
def select_id(evt: gr.SelectData):
choice=[]
global id_glb
id_glb = evt.value
for key, values in results.items():
for value in values:
if value['id'] == evt.value:
choice.append(int(key))
return gr.Dropdown(label="frame", choices=choice, interactive=True)
import pandas as pd
frame_glb = None
def select_frame(evt: gr.SelectData):
full_img = os.path.join(output_track, str(evt.value) + ".jpg")
crop_img = os.path.join(tracking_result_dir, str(id_glb) + "_" + str(evt.value) + ".png")
global frame_glb
frame_glb = evt.value
data = pd.read_csv(os.path.join(labeling_dir, str(id_glb) + "_" + str(frame_glb) + '.csv'), header=0)
return full_img, crop_img, data
def get_data(dtfr):
print(dtfr)
# df = pd.read_csv(os.path.join(video_dir, "track/label.csv"))
# for i, row in df.iterrows():
# if str(row["Signboard"]) == str(id_tmp) and str(row["Frame"]) == str(frame_tmp):
# # print(row["Text"])
# df_new = df.replace(str(row["Text"]), str(labeling))
# print(df_new)
dtfr.to_csv(os.path.join(labeling_dir, str(id_glb) + "_" + str(frame_glb) + '.csv'), index=False, header=True)
return
def seg_track_app():
##########################################################
###################### Front-end ########################
##########################################################
with gr.Blocks(css=".gradio-container {background-color: white}") as demo:
gr.Markdown(
'''
<div style="text-align:center;">
<span style="font-size:3em; font-weight:bold;">POI Engineeing</span>
</div>
'''
)
with gr.Row():
# video input
with gr.Column(scale=0.2):
tab_video_input = gr.Row(label="Video type input")
with tab_video_input:
input_video = gr.Video(label='Input video')
tab_everything = gr.Row(label="Tracking")
with tab_everything:
with gr.Row():
seg_signboard = gr.Button(value="Tracking", interactive=True)
all_info = gr.Row(label="Information about video")
with all_info:
with gr.Row():
text = gr.Textbox(label="Fps")
check_fps = gr.Textbox(label="Choose fps for output", interactive=True)
with gr.Column(scale=1):
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
with gr.Column(scale=1):
id_drop = gr.Dropdown(label="Signboards",choices=[])
with gr.Column(scale=1):
fr_drop = gr.Dropdown(label="Frames",choices=[])
full_img = gr.Image(label="Full Image")
with gr.Column(scale=1):
crop_img = gr.Image(label="Cropped Image")
with gr.Row():
dtfr = gr.Dataframe(headers=["Tag", "Value"], datatype=["str", "str"], interactive=True)
with gr.Row():
submit = gr.Button(value="Submit", interactive=True)
##########################################################
###################### back-end #########################
##########################################################
input_video.change(
fn=get_meta_from_video,
inputs=input_video,
outputs=text
)
seg_signboard.click(
fn=tracking,
inputs=check_fps,
outputs=id_drop
)
id_drop.select(select_id, None, fr_drop)
fr_drop.select(select_frame, None, [full_img,crop_img, dtfr])
submit.click(get_data, dtfr, None)
demo.queue(concurrency_count=1)
demo.launch(debug=True, enable_queue=True, share=True)
if __name__ == "__main__":
seg_track_app()