# WooWoof_AI / app.py
# Author: larry1129 · last commit 91633ba ("Update app.py") · 3.93 kB
import gradio as gr
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
# Hugging Face access token taken from the HF_API_TOKEN environment variable
# (None when the variable is not set).
hf_token = os.environ.get("HF_API_TOKEN")

# Hub repository IDs of the two models used by the pipeline.
vqa_model_name = "llava-hf/llava-v1.6-mistral-7b-hf"  # image -> text description
language_model_name = "larry1129/WooWoof_AI_Vision_merged_16bit_3b"  # text-only answerer

# Module-level caches for the loaded models/tokenizers; the load_* functions
# fill them on first use so each is loaded only once.
vqa_processor = vqa_model = None
language_tokenizer = language_model = None
# Load (once) the LLaVA-NeXT image-description model and its processor.
def load_vqa_model():
    """Return the cached ``(processor, model)`` pair for the VQA model.

    Both objects are created on the first call and stored in the module-level
    globals ``vqa_processor`` / ``vqa_model``; later calls reuse them. The
    model is loaded in float16 and moved to ``cuda:0``, so a CUDA device is
    required.

    ``hf_token`` is forwarded to *both* ``from_pretrained`` calls so the
    processor and the weights are downloaded with the same credentials.
    """
    global vqa_processor, vqa_model
    if vqa_processor is None or vqa_model is None:
        # `token=` replaces the deprecated `use_auth_token=`; it is available in
        # every transformers release that provides the LlavaNext classes.
        vqa_processor = LlavaNextProcessor.from_pretrained(vqa_model_name, token=hf_token)
        vqa_model = LlavaNextForConditionalGeneration.from_pretrained(
            vqa_model_name,
            token=hf_token,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to("cuda:0")
    return vqa_processor, vqa_model
# Load (once) the text-only language model and its tokenizer.
def load_language_model():
    """Return the cached ``(tokenizer, model)`` pair for the language model.

    Both objects are created on the first call and stored in the module-level
    globals ``language_tokenizer`` / ``language_model``; later calls reuse
    them. The model is loaded in float16 with ``device_map="auto"`` and put in
    eval mode.

    ``hf_token`` is forwarded to *both* ``from_pretrained`` calls. Previously
    only the tokenizer received it, so loading the weights of a private or
    gated repository failed even though the tokenizer download succeeded.
    """
    global language_tokenizer, language_model
    if language_tokenizer is None or language_model is None:
        # `token=` replaces the deprecated `use_auth_token=`.
        language_tokenizer = AutoTokenizer.from_pretrained(language_model_name, token=hf_token)
        language_model = AutoModelForCausalLM.from_pretrained(
            language_model_name,
            token=hf_token,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        # Reuse EOS as the padding token and register it on the model config so
        # generate() has a pad id; presumably the tokenizer ships without a
        # pad token — confirm.
        language_tokenizer.pad_token = language_tokenizer.eos_token
        language_model.config.pad_token_id = language_tokenizer.pad_token_id
        language_model.eval()
    return language_tokenizer, language_model
# Generate a free-text description of an image with the VQA model.
def generate_image_description(image):
    """Describe ``image`` with the LLaVA-NeXT model.

    Args:
        image: The image to describe (the Gradio UI supplies a PIL image).

    Returns:
        The model's answer to "What is shown in this image?" (at most 100 new
        tokens), without the echoed chat prompt and with surrounding
        whitespace stripped.
    """
    vqa_processor, vqa_model = load_vqa_model()
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {"type": "image"},
            ],
        },
    ]
    prompt = vqa_processor.apply_chat_template(conversation, add_generation_prompt=True)
    # Move inputs to wherever the model actually lives rather than a
    # hard-coded device string.
    inputs = vqa_processor(images=image, text=prompt, return_tensors="pt").to(vqa_model.device)
    with torch.no_grad():
        output = vqa_model.generate(**inputs, max_new_tokens=100)
    # generate() returns the prompt tokens followed by the continuation; decode
    # only the new tokens so the chat-template text is not echoed into the
    # description (which is later fed to the language model).
    prompt_length = inputs["input_ids"].shape[-1]
    image_description = vqa_processor.decode(output[0][prompt_length:], skip_special_tokens=True)
    return image_description.strip()
# Produce the final answer with the text-only language model.
def generate_language_response(instruction, image_description):
    """Answer ``instruction`` using ``image_description`` as context.

    The two strings are placed into a ``### Instruction`` / ``### Input`` /
    ``### Response`` prompt and the language model samples up to 128 new
    tokens (temperature 0.7, top-p 0.95).

    Args:
        instruction: The user's instruction text.
        image_description: Text describing the image, used as the input section.

    Returns:
        Only the newly generated text, with surrounding whitespace stripped.
    """
    language_tokenizer, language_model = load_language_model()
    prompt = f"""### Instruction:
{instruction}
### Input:
{image_description}
### Response:
"""
    inputs = language_tokenizer(prompt, return_tensors="pt").to(language_model.device)
    with torch.no_grad():
        outputs = language_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
        )
    # generate() returns prompt + continuation. Slice off the prompt tokens
    # instead of splitting the decoded text on "### Response:", which would
    # drop the start of the answer whenever the model emits that marker again.
    prompt_length = inputs["input_ids"].shape[-1]
    response = language_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return response.strip()
# Gradio callback: describe the image, then answer the instruction about it.
def process_image_and_text(image, instruction):
    """Run the two-stage pipeline for one Gradio request.

    Args:
        image: The uploaded image, or ``None`` when the form is submitted
            without one.
        instruction: The user's instruction text.

    Returns:
        A display string containing the image description followed by the
        final answer.

    Raises:
        gr.Error: If no image was uploaded, so the user sees a readable
            message instead of a failure deep inside the image processor.
    """
    if image is None:
        raise gr.Error("请先上传图片。")
    image_description = generate_image_description(image)
    final_response = generate_language_response(instruction, image_description)
    return f"图片描述: {image_description}\n\n最终回答: {final_response}"
# Build the Gradio UI: an image plus an instruction in, a single text box out.
iface = gr.Interface(
    fn=process_image_and_text,
    inputs=[
        gr.Image(type="pil", label="上传图片"),
        gr.Textbox(lines=2, placeholder="Instruction", label="Instruction"),
    ],
    outputs="text",
    title="WooWoof AI - 图片和文本交互",
    description="输入图片并添加指令,生成基于图片描述的回答。",
    allow_flagging="never",
)

# Start the server only when this file is run as a script, so importing the
# module (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    iface.launch()