import gradio as gr
import io
import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import json
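
# Gradio Space: analyze tire-manufacturing process videos with the
# CogVLM2-LLaMA3 video-caption model and explain why a given step was delayed.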
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
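
# Candidate delay reasons for each of the eight manufacturing steps; these
# populate the "Possible Delay Reason" dropdown in the UI.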
DELAY_REASONS = {
    "step1": {"reasons": ["No raw material available", "Person repatching the tire"]},
    "step2": {"reasons": ["Person repatching the tire", "Lack of raw material"]},
    "step3": {"reasons": ["Person repatching the tire", "Lack of raw material"]},
    "step4": {"reasons": ["Person repatching the tire", "Lack of raw material"]},
    "step5": {"reasons": ["Person repatching the tire", "Lack of raw material"]},
    "step6": {"reasons": ["Person repatching the tire", "Lack of raw material"]},
    "step7": {"reasons": ["Person repatching the tire", "Lack of raw material"]},
    "step8": {"reasons": ["No person available to collect tire", "Person repatching the tire"]}
}

with open('delay_reasons.json', 'w') as f:
    json.dump(DELAY_REASONS, f, indent=4)
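
# Decode the uploaded MP4 in memory and sample roughly one frame per second,
# capped at 24 frames; returns a tensor shaped for the model's vision input.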
def load_video(video_data, strategy='chat'):
    bridge.set_bridge('torch')
    mp4_stream = video_data
    num_frames = 24
    decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
    frame_id_list = []
    total_frames = len(decord_vr)
    timestamps = [i[0] for i in decord_vr.get_frame_timestamp(np.arange(total_frames))]
    max_second = round(max(timestamps)) + 1
    # Pick the frame whose timestamp is closest to each whole second.
    for second in range(max_second):
        closest_num = min(timestamps, key=lambda x: abs(x - second))
        index = timestamps.index(closest_num)
        frame_id_list.append(index)
        if len(frame_id_list) >= num_frames:
            break
    video_data = decord_vr.get_batch(frame_id_list)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data
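
# Load the tokenizer and the caption model with 4-bit NF4 quantization to
# keep GPU memory usage low.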
def load_model():
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=TORCH_TYPE,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map="auto"
    ).eval()
    return model, tokenizer
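
# Build a single-turn conversation around the sampled frames and generate a
# response. Decoding is effectively greedy here: with do_sample=False,
# transformers ignores the temperature and top_p settings.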
def predict(prompt, video_data, temperature, model, tokenizer):
    strategy = 'chat'
    video = load_video(video_data, strategy=strategy)
    history = []
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[video],
        history=history,
        template_version=strategy
    )
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,
        "top_k": 1,
        "do_sample": False,
        "top_p": 0.1,
        "temperature": temperature,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]  # drop the prompt tokens
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
def get_base_prompt():
    return """You are an expert AI model trained to analyze and interpret manufacturing processes.
The task is to evaluate video footage of specific steps in a tire manufacturing process.
The process has 8 total steps, but only delayed steps are provided for analysis.

**Your Goal:**
1. Analyze the provided video.
2. Identify possible reasons for the delay in the manufacturing step shown in the video.
3. Provide a clear explanation of the delay based on observed factors.

**Context:**
Tire manufacturing involves 8 steps, and delays may occur due to machinery faults,
raw material availability, labor efficiency, or unexpected disruptions.

**Output:**
Explain why the delay occurred in this step. Include specific observations
and their connection to the delay."""
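
# Gradio callback: load the model, read the uploaded video from disk, and ask
# the model to explain the delay for the selected step and reason. The model
# is reloaded on every click, which is simple but slow for repeated runs.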
def inference(video, step_number, selected_reason):
    if not video:
        return "Please upload a video first."
    model, tokenizer = load_model()
    # gr.Video supplies a filepath string, so read the raw MP4 bytes from disk.
    with open(video, "rb") as f:
        video_data = f.read()
    base_prompt = get_base_prompt()
    full_prompt = f"{base_prompt}\n\nAnalyzing Step {step_number}\nPossible reason: {selected_reason}"
    temperature = 0.8
    response = predict(full_prompt, video_data, temperature, model, tokenizer)
    return response
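
# UI: a video upload plus step/reason dropdowns; the reason list is refreshed
# whenever a different manufacturing step is selected.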
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            video = gr.Video(label="Video Input", sources=["upload"])
            step_number = gr.Dropdown(
                choices=[f"Step {i}" for i in range(1, 9)],
                label="Manufacturing Step",
                value="Step 1"
            )
            reason = gr.Dropdown(
                choices=DELAY_REASONS["step1"]["reasons"],
                label="Possible Delay Reason",
                value=DELAY_REASONS["step1"]["reasons"][0]
            )
            analyze_btn = gr.Button("Analyze Delay", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Analysis Result")

    def update_reasons(step):
        # Map "Step 1" -> "step1", swap in that step's reason list, and reset
        # the selection to the first entry so a stale choice is never submitted.
        step_num = step.lower().replace(" ", "")
        reasons = DELAY_REASONS[step_num]["reasons"]
        return gr.Dropdown(choices=reasons, value=reasons[0])

    step_number.change(fn=update_reasons, inputs=[step_number], outputs=[reason])
    analyze_btn.click(fn=inference, inputs=[video, step_number, reason], outputs=[output])

if __name__ == "__main__":
    demo.launch()