import os
from unittest.mock import patch

import gradio as gr
import numpy as np
import supervision as sv
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor

from utils.imports import fixed_get_imports
from utils.models import (
    run_captioning,
    CAPTIONING_TASK,
    run_caption_to_phrase_grounding
)
from utils.video import (
    create_directory,
    remove_files_older_than,
    generate_file_name,
    calculate_end_frame_index
)

MARKDOWN = """
# Florence-2 for Videos 🎬
<div>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-florence-2-on-detection-dataset.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/florence-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://arxiv.org/abs/2311.06242">
<img src="https://img.shields.io/badge/arXiv-2311.06242-b31b1b.svg" alt="arXiv" style="display:inline-block;">
</a>
</div>
"""
RESULTS = "results"
CHECKPOINT = "microsoft/Florence-2-base-ft"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
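# patch transformers' dynamic import resolution while Florence-2's remote code
# loads; fixed_get_imports (utils/imports.py) presumably drops optional
# dependencies such as flash_attn so the model also loads on CPU-only machines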
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    MODEL = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT, trust_remote_code=True).to(DEVICE)
    PROCESSOR = AutoProcessor.from_pretrained(
        CHECKPOINT, trust_remote_code=True)

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.TRACK)
LABEL_ANNOTATOR = sv.LabelAnnotator(color_lookup=sv.ColorLookup.TRACK)
TRACKER = sv.ByteTrack()
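# ColorLookup.TRACK keys annotation colors to ByteTrack's tracker IDs, so each
# tracked object keeps a consistent color across frames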
# creating video results directory
create_directory(directory_path=RESULTS)


def annotate_image(
    input_image: np.ndarray,
    detections: sv.Detections
) -> np.ndarray:
    """Draw tracked bounding boxes and labels on a copy of the frame."""
    output_image = input_image.copy()
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
    return output_image


def process_video(
    input_video: str,
    progress=gr.Progress(track_tqdm=True)
) -> str:
    """Caption, ground, track and annotate every frame of the input video."""
    # clean up old video files
    remove_files_older_than(RESULTS, 30)

    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )
    result_file_name = generate_file_name(extension="mp4")
    result_file_path = os.path.join(RESULTS, result_file_name)

    TRACKER.reset()
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
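        # two Florence-2 passes per frame: generate a caption, then ground the
        # caption's phrases as bounding box detections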
        for _ in tqdm(range(total), desc="Processing video..."):
            frame = next(frame_generator)
            caption = run_captioning(
                model=MODEL,
                processor=PROCESSOR,
                image=frame,
                device=DEVICE
            )[CAPTIONING_TASK]
            detections = run_caption_to_phrase_grounding(
                model=MODEL,
                processor=PROCESSOR,
                caption=caption,
                image=frame,
                device=DEVICE
            )
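            # assign persistent tracker IDs so annotation colors stay stable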
            detections = TRACKER.update_with_detections(detections)
            frame = annotate_image(
                input_image=frame,
                detections=detections
            )
            sink.write_frame(frame)
    return result_file_path
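

# Gradio UI: upload a video, run the pipeline, preview the annotated result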
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        input_video_component = gr.Video(
            label='Input Video'
        )
        output_video_component = gr.Video(
            label='Output Video'
        )
    with gr.Row():
        submit_button_component = gr.Button(
            value='Submit',
            scale=1,
            variant='primary'
        )

    submit_button_component.click(
        fn=process_video,
        inputs=[
            input_video_component,
        ],
        outputs=output_video_component
    )

demo.launch(debug=False, show_error=True, max_threads=1)