# apollo_genai / app2.py
# Author: arjunanand13 (avatar: "arjunanand13's picture")
# Commit bd5e320 (verified): "Rename app.py to app2.py"
import functools
import io
import json

import gradio as gr
import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
# Hugging Face Hub id of the video-captioning model loaded by load_model().
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

_HAS_CUDA = torch.cuda.is_available()

# Device that input tensors are moved to before generation.
DEVICE = 'cuda' if _HAS_CUDA else 'cpu'

# bfloat16 is only chosen on CUDA devices reporting compute capability >= 8;
# every other configuration (including CPU) falls back to float16.
if _HAS_CUDA and torch.cuda.get_device_capability()[0] >= 8:
    TORCH_TYPE = torch.bfloat16
else:
    TORCH_TYPE = torch.float16
# Candidate delay reasons offered in the UI, keyed "step1" .. "step8".
# Steps 2-7 share the same two reasons; each step still gets its own list.
DELAY_REASONS = {
    "step1": {"reasons": ["No raw material available", "Person repatching the tire"]},
}
DELAY_REASONS.update(
    (f"step{n}", {"reasons": ["Person repatching the tire", "Lack of raw material"]})
    for n in range(2, 8)
)
DELAY_REASONS["step8"] = {
    "reasons": ["No person available to collect tire", "Person repatching the tire"]
}

# Persist the table next to the app at import time (overwrites any existing file).
with open('delay_reasons.json', 'w') as f:
    json.dump(DELAY_REASONS, f, indent=4)
def load_video(video_data, strategy='chat', num_frames=24):
    """Decode an in-memory video and sample about one frame per second.

    For each whole second ``0, 1, 2, ...`` of the clip, the frame whose
    timestamp is closest to that second is selected. At most ``num_frames``
    frames are returned.

    Args:
        video_data: Raw bytes of an encoded video file; wrapped in
            ``io.BytesIO`` for decord.
        strategy: Accepted for API compatibility; not used by the sampler.
        num_frames: Maximum number of frames to return (default 24).

    Returns:
        The tensor from decord's ``get_batch`` permuted with ``(3, 0, 1, 2)``.
        Presumably decord yields ``(T, H, W, C)``, giving ``(C, T, H, W)``;
        TODO confirm against the model's expected layout.

    Raises:
        ValueError: If ``num_frames`` < 1 or the video has no frames.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    # Make decord return torch tensors so ``permute`` below is available.
    bridge.set_bridge('torch')
    decord_vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
    total_frames = len(decord_vr)
    if total_frames == 0:
        # Without this guard, .max() on an empty array raises an opaque error.
        raise ValueError("Video contains no decodable frames.")
    # Column 0 of get_frame_timestamp is used as each frame's timestamp.
    timestamps = np.asarray(
        decord_vr.get_frame_timestamp(np.arange(total_frames))
    )[:, 0]
    max_second = round(float(timestamps.max())) + 1
    frame_id_list = []
    for second in range(min(max_second, num_frames)):
        # argmin returns the first index on ties, which matches the
        # min(...) + list.index(...) selection used previously.
        frame_id_list.append(int(np.argmin(np.abs(timestamps - second))))
    video = decord_vr.get_batch(frame_id_list)
    return video.permute(3, 0, 1, 2)
@functools.lru_cache(maxsize=1)
def load_model():
    """Load (once) the 4-bit quantized captioning model and its tokenizer.

    The result is memoized so repeated Gradio requests reuse the same model
    instead of reloading the weights on every click. If loading raises, nothing
    is cached and the next call retries.

    NOTE(review): the BitsAndBytes 4-bit config is applied unconditionally;
    loading on a CPU-only host is likely to fail. Confirm the deployment
    always has a GPU.

    Returns:
        tuple: ``(model, tokenizer)``. The model is in eval mode.
    """
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=TORCH_TYPE,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    # trust_remote_code: the checkpoint ships custom modeling code
    # (e.g. build_conversation_input_ids used by predict()).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map="auto",
    ).eval()
    return model, tokenizer
def predict(prompt, video_data, temperature, model, tokenizer, *,
            do_sample=False, max_new_tokens=2048):
    """Generate the model's answer to ``prompt`` about an encoded video.

    Args:
        prompt: Text query passed to ``model.build_conversation_input_ids``.
        video_data: Raw encoded video bytes, decoded by ``load_video``.
        temperature: Sampling temperature. It is only forwarded to
            ``generate`` when ``do_sample`` is true. Greedy decoding (the
            default) does not use it, so it is not sent; this avoids
            transformers' "only used in sample-based generation" warnings.
        model: Model exposing ``build_conversation_input_ids`` and
            ``generate`` (as returned by ``load_model``).
        tokenizer: Tokenizer matching ``model``.
        do_sample: Use sampling instead of greedy decoding (default False,
            same decoding as before).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        str: The generated continuation, with special tokens stripped.
    """
    strategy = 'chat'
    video = load_video(video_data, strategy=strategy)
    conversation = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[video],
        history=[],
        template_version=strategy,
    )
    # Add a batch dimension of 1 and move everything to the target device.
    inputs = {
        'input_ids': conversation['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': conversation['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': conversation['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[conversation['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        # Presumably a reserved Llama-3 special token used as padding;
        # value taken from the original code. Confirm against the checkpoint.
        "pad_token_id": 128002,
        "do_sample": do_sample,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    # Drop the prompt tokens; keep only the newly generated ones.
    outputs = outputs[:, inputs['input_ids'].shape[1]:]
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def get_base_prompt():
    """Return the fixed instruction text prepended to every analysis query.

    The text frames the model as a tire-manufacturing analyst. It states that
    the process has 8 steps and only delayed steps are shown, and it asks for
    an explanation of the delay grounded in what the video shows.
    """
    prompt_lines = (
        "You are an expert AI model trained to analyze and interpret manufacturing processes.",
        "The task is to evaluate video footage of specific steps in a tire manufacturing process.",
        "The process has 8 total steps, but only delayed steps are provided for analysis.",
        "**Your Goal:**",
        "1. Analyze the provided video.",
        "2. Identify possible reasons for the delay in the manufacturing step shown in the video.",
        "3. Provide a clear explanation of the delay based on observed factors.",
        "**Context:**",
        "Tire manufacturing involves 8 steps, and delays may occur due to machinery faults,",
        "raw material availability, labor efficiency, or unexpected disruptions.",
        "**Output:**",
        "Explain why the delay occurred in this step. Include specific observations",
        "and their connection to the delay.",
    )
    return "\n".join(prompt_lines)
def inference(video, step_number, selected_reason):
    """Gradio click handler: analyze an uploaded video for a delayed step.

    Args:
        video: Value of the ``gr.Video`` input. Gradio passes a file-path
            string for video inputs; file-like objects with ``read()`` are
            also accepted. Falsy values mean nothing was uploaded.
        step_number: Step label from the dropdown (e.g. ``"Step 1"``); a bare
            number is also accepted.
        selected_reason: Reason chosen in the reason dropdown; omitted from
            the prompt when empty or None.

    Returns:
        str: The model's analysis, or an upload reminder when no video is given.
    """
    if not video:
        return "Please upload a video first."
    # gr.Video hands the handler a path, not an open file, so read it in binary mode.
    if hasattr(video, "read"):
        video_data = video.read()
    else:
        with open(video, "rb") as video_file:
            video_data = video_file.read()
    model, tokenizer = load_model()
    # The dropdown already yields "Step N"; avoid rendering "Step Step N".
    step_label = str(step_number).strip()
    if not step_label.lower().startswith("step"):
        step_label = f"Step {step_label}"
    full_prompt = f"{get_base_prompt()}\n\nAnalyzing {step_label}"
    if selected_reason:
        full_prompt += f"\nPossible reason: {selected_reason}"
    temperature = 0.8
    return predict(full_prompt, video_data, temperature, model, tokenizer)
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            video = gr.Video(label="Video Input", sources=["upload"])
            step_number = gr.Dropdown(
                choices=[f"Step {i}" for i in range(1, 9)],
                label="Manufacturing Step",
                value="Step 1",
            )
            reason = gr.Dropdown(
                choices=DELAY_REASONS["step1"]["reasons"],
                label="Possible Delay Reason",
                value=DELAY_REASONS["step1"]["reasons"][0],
            )
            analyze_btn = gr.Button("Analyze Delay", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Analysis Result")

    def update_reasons(step):
        """Return a dropdown update listing the delay reasons for ``step``.

        ``"Step 3"`` maps to the ``"step3"`` key of DELAY_REASONS. An empty or
        unknown step yields an empty choice list instead of raising.
        """
        step_key = (step or "").lower().replace(" ", "")
        reasons = DELAY_REASONS.get(step_key, {}).get("reasons", [])
        # Also reset the selected value. Otherwise the previous step's reason
        # stays selected even when it is not one of the new choices.
        return gr.Dropdown(choices=reasons, value=reasons[0] if reasons else None)

    step_number.change(fn=update_reasons, inputs=[step_number], outputs=[reason])
    analyze_btn.click(fn=inference, inputs=[video, step_number, reason], outputs=[output])
def _launch():
    """Start the Gradio server that serves the ``demo`` UI."""
    demo.launch()


# Only serve the UI when run as a script, not when imported.
if __name__ == "__main__":
    _launch()