File size: 5,868 Bytes
baab6b5
114aae4
 
 
 
 
 
 
baab6b5
114aae4
 
 
baab6b5
114aae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
baab6b5
114aae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
baab6b5
114aae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
baab6b5
 
99e9497
114aae4
 
 
baab6b5
114aae4
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import io
import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
import json

# Hugging Face identifier of the checkpoint handed to `from_pretrained`.
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

# Query CUDA once and derive both the target device and the compute dtype.
_CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = 'cuda' if _CUDA_AVAILABLE else 'cpu'

# bfloat16 is selected only when a CUDA device reports a major compute
# capability of 8 or higher; every other configuration uses float16.
if _CUDA_AVAILABLE and torch.cuda.get_device_capability()[0] >= 8:
    TORCH_TYPE = torch.bfloat16
else:
    TORCH_TYPE = torch.float16

# Candidate delay explanations per manufacturing step, keyed "step1".."step8".
# Steps 2 through 7 share the same wording; each still gets its own list
# object so that mutating one entry never affects another.
_COMMON_STEP_REASONS = ("Person repatching the tire", "Lack of raw material")

DELAY_REASONS = {
    "step1": {"reasons": ["No raw material available", "Person repatching the tire"]},
    **{
        f"step{number}": {"reasons": list(_COMMON_STEP_REASONS)}
        for number in range(2, 8)
    },
    "step8": {"reasons": ["No person available to collect tire", "Person repatching the tire"]},
}

# Import-time side effect: the table is written to the current working
# directory. Nothing in the visible code reads the file back.
with open('delay_reasons.json', 'w') as reasons_file:
    reasons_file.write(json.dumps(DELAY_REASONS, indent=4))

def load_video(video_data, strategy='chat'):
    """Decode an in-memory video and sample about one frame per second.

    For every whole second from 0 up to the rounded end of the clip, the
    frame whose timestamp lies closest to that second is selected. At most
    24 frames are returned, so only the first 24 seconds of a longer clip
    are covered.

    Args:
        video_data: Raw bytes of the encoded video file.
        strategy: Accepted for interface compatibility. Only the
            per-second sampling described above is implemented, so the
            value does not currently influence the result.

    Returns:
        The sampled frames as returned by ``decord`` (torch bridge) with
        the last axis moved to the front via ``permute(3, 0, 1, 2)``.
        Presumably this turns (frames, H, W, C) into (C, frames, H, W);
        confirm against the decord output layout.

    Raises:
        ValueError: If the video contains no frames.
    """
    bridge.set_bridge('torch')
    num_frames = 24  # upper bound on the number of sampled frames
    decord_vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))

    total_frames = len(decord_vr)
    if total_frames == 0:
        raise ValueError("Video contains no frames to sample.")

    # Each timestamp entry is indexable; element 0 is the value the sampling
    # is based on (presumably the frame's start time in seconds).
    raw_timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
    timestamps = np.asarray(
        [entry[0] for entry in raw_timestamps], dtype=np.float64
    )
    max_second = int(round(float(timestamps.max()))) + 1

    # np.argmin returns the first index of the minimum, which matches the
    # tie-breaking of a linear "closest value, then first index" search
    # while avoiding a Python-level scan of all frames for every second.
    frame_id_list = [
        int(np.argmin(np.abs(timestamps - second)))
        for second in range(min(max_second, num_frames))
    ]

    frames = decord_vr.get_batch(frame_id_list)
    return frames.permute(3, 0, 1, 2)

# Holds the (model, tokenizer) pair after the first successful load so that
# repeated calls do not reload the checkpoint.
_MODEL_CACHE = None


def load_model():
    """Return the 4-bit quantized model and its tokenizer, loading them once.

    The first call loads ``MODEL_PATH`` with an NF4, double-quantized
    BitsAndBytes configuration whose compute dtype is ``TORCH_TYPE``, and
    puts the model in eval mode. Later calls return the same objects, so a
    caller that invokes this per request does not pay the load cost again.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is not None:
        return _MODEL_CACHE

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=TORCH_TYPE,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map="auto"
    ).eval()

    # Cache only after both objects loaded, so a failed load is retried on
    # the next call instead of leaving a half-initialised entry behind.
    _MODEL_CACHE = (model, tokenizer)
    return _MODEL_CACHE

def predict(prompt, video_data, temperature, model, tokenizer):
    """Generate the model's text answer for a prompt about a video.

    Args:
        prompt: Text query passed to the model's conversation builder.
        video_data: Raw video bytes; decoded via ``load_video``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        model: Model exposing ``build_conversation_input_ids`` and
            ``generate``.
        tokenizer: Tokenizer used to build the inputs and decode the output.

    Returns:
        The decoded continuation, with the prompt tokens and special tokens
        removed.
    """
    template = 'chat'
    frames = load_video(video_data, strategy=template)

    encoded = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[frames],
        history=[],
        template_version=template
    )

    # Add a leading batch axis to each text tensor and move it to DEVICE.
    model_inputs = {
        field: encoded[field].unsqueeze(0).to(DEVICE)
        for field in ('input_ids', 'token_type_ids', 'attention_mask')
    }
    # The video tensor is additionally cast to TORCH_TYPE and wrapped in a
    # nested list.
    model_inputs['images'] = [[encoded['images'][0].to(DEVICE).to(TORCH_TYPE)]]

    # NOTE(review): with do_sample=False decoding is presumably greedy, in
    # which case top_k/top_p/temperature have no effect -- confirm.
    generation_options = dict(
        max_new_tokens=2048,
        pad_token_id=128002,
        top_k=1,
        do_sample=False,
        top_p=0.1,
        temperature=temperature,
    )

    prompt_length = model_inputs['input_ids'].shape[1]
    with torch.no_grad():
        generated = model.generate(**model_inputs, **generation_options)

    # Drop the echoed prompt tokens so only newly generated ones are decoded.
    continuation = generated[:, prompt_length:]
    return tokenizer.decode(continuation[0], skip_special_tokens=True)

def get_base_prompt():
    """Return the fixed instruction text that prefixes every analysis prompt.

    The text describes the tire-manufacturing delay-analysis task: the goal,
    the context, and the expected output.
    """
    # One entry per prompt line. The leading indentation and trailing spaces
    # are part of the emitted text and are reproduced exactly.
    prompt_lines = (
        "You are an expert AI model trained to analyze and interpret manufacturing processes. ",
        "    The task is to evaluate video footage of specific steps in a tire manufacturing process. ",
        "    The process has 8 total steps, but only delayed steps are provided for analysis.",
        "    ",
        "    **Your Goal:**",
        "    1. Analyze the provided video.",
        "    2. Identify possible reasons for the delay in the manufacturing step shown in the video.",
        "    3. Provide a clear explanation of the delay based on observed factors.",
        "    ",
        "    **Context:** ",
        "    Tire manufacturing involves 8 steps, and delays may occur due to machinery faults, ",
        "    raw material availability, labor efficiency, or unexpected disruptions.",
        "    ",
        "    **Output:**",
        "    Explain why the delay occurred in this step. Include specific observations ",
        "    and their connection to the delay.",
    )
    return "\n".join(prompt_lines)

def _read_video_bytes(video):
    """Return the raw bytes of *video*, given a path or a binary file object."""
    if hasattr(video, "read"):
        return video.read()
    with open(video, "rb") as stream:
        return stream.read()


def inference(video, step_number, selected_reason):
    """Run the delay analysis for an uploaded video and return the answer.

    Args:
        video: The uploaded video. The Gradio video input supplies a file
            path string; a binary file-like object is accepted as well.
        step_number: Selected step, either a bare number or a label such as
            ``"Step 3"``.
        selected_reason: The delay reason chosen in the UI; appended to the
            prompt as a hint.

    Returns:
        The model's textual analysis, or a short instruction message when no
        video was provided.
    """
    if not video:
        return "Please upload a video first."

    # Read the file before touching the model so a bad upload fails fast.
    video_data = _read_video_bytes(video)
    model, tokenizer = load_model()

    # The dropdown value already reads "Step N"; strip that prefix so the
    # prompt says "Analyzing Step N" rather than "Analyzing Step Step N".
    step_label = str(step_number).strip()
    if step_label.lower().startswith("step"):
        step_label = step_label[len("step"):].strip()

    base_prompt = get_base_prompt()
    full_prompt = f"{base_prompt}\n\nAnalyzing Step {step_label}\nPossible reason: {selected_reason}"
    temperature = 0.8
    return predict(full_prompt, video_data, temperature, model, tokenizer)

# UI layout: inputs (video, step, reason, button) on the left, the analysis
# text on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            video = gr.Video(label="Video Input", sources=["upload"])
            step_number = gr.Dropdown(choices=[f"Step {i}" for i in range(1, 9)], label="Manufacturing Step", value="Step 1")
            reason = gr.Dropdown(choices=DELAY_REASONS["step1"]["reasons"], label="Possible Delay Reason", value=DELAY_REASONS["step1"]["reasons"][0])
            analyze_btn = gr.Button("Analyze Delay", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Analysis Result")

    def update_reasons(step):
        """Return a dropdown update listing the reasons for the chosen step."""
        # "Step 3" -> "step3", the key format used by DELAY_REASONS.
        step_num = step.lower().replace(" ", "")
        reasons = DELAY_REASONS[step_num]["reasons"]
        # Reset the selection along with the choices; otherwise the reason
        # picked for the previous step stays selected even when it is not
        # among the new step's options.
        return gr.Dropdown(choices=reasons, value=reasons[0])

    # Keep the reason list in sync with the selected step.
    step_number.change(fn=update_reasons, inputs=[step_number], outputs=[reason])
    analyze_btn.click(fn=inference, inputs=[video, step_number, reason], outputs=[output])

if __name__ == "__main__":
    demo.launch()