csmx / app.py
wixcs's picture
Update app.py
80645ea verified
raw
history blame contribute delete
901 Bytes
import functools

import gradio as gr
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
# Load the Qwen2-VL-7B tokenizer and model once at import time so every
# request handled by the app reuses the same weights.
MODEL_NAME = "Qwen/Qwen2-VL-7B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# torch_dtype="auto" keeps the dtype recorded for the checkpoint instead of
# upcasting every weight to float32. If the checkpoint is 16-bit (presumably
# bf16 for Qwen2-VL -- confirm against the model config), float32 would
# roughly double memory use for a 7B-parameter model.
model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, torch_dtype="auto")
@functools.lru_cache(maxsize=1)
def _get_processor():
    """Return the AutoProcessor for MODEL_NAME, loading it on first use."""
    return AutoProcessor.from_pretrained(MODEL_NAME)


def generate_response(image_path, text_prompt, max_new_tokens=128):
    """Generate a text reply from the vision-language model for an image + prompt.

    Args:
        image_path: Path to the image file. Gradio's ``gr.Image(type="filepath")``
            passes a path here, and may pass ``None`` when nothing was uploaded.
        text_prompt: The user's instruction or question about the image.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded newly generated text (the prompt is not echoed back), or a
        short message asking for an image when ``image_path`` is ``None``.
    """
    if image_path is None:
        return "Please upload an image."

    processor = _get_processor()

    # Image.open is lazy and holds the file handle; materialize an RGB copy
    # inside `with` so the handle is released before generation.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Qwen2-VL expects an image placeholder in the text; the processor expands
    # <|image_pad|> to the number of visual tokens computed for the image, and
    # image + text must be encoded in the same processor call.
    prompt = "<|vision_start|><|image_pad|><|vision_end|>" + (text_prompt or "")
    inputs = processor(
        text=[prompt], images=[image], return_tensors="pt"
    ).to(model.device)

    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt tokens followed by the continuation; keep only
    # the newly generated part.
    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0]
# Gradio η•Œι’
iface = gr.Interface(
fn=generate_response,
inputs=[gr.Image(type="filepath"), gr.Textbox(label="Text Prompt")],
outputs="text",
title="Qwen2-VL-7B Image + Text Generator"
)
if __name__ == "__main__":
    # Start the Gradio app only when this file is run directly, not on import.
    iface.launch()