Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, LlavaNextProcessor, LlavaNextForConditionalGeneration | |
from PIL import Image | |
# Hugging Face access token, read from the HF_API_TOKEN environment variable (None when unset).
hf_token = os.getenv("HF_API_TOKEN")

# Hub repository ids of the two models the app chains together.
vqa_model_name = "llava-hf/llava-v1.6-mistral-7b-hf"  # vision-language model: image -> description
language_model_name = "larry1129/WooWoof_AI_Vision_merged_16bit_3b"  # text-only model: final answer

# Module-level caches, left empty here and populated on first use by the model-loading functions.
vqa_processor = vqa_model = None
language_tokenizer = language_model = None
def load_vqa_model():
    """Load the image-description (LLaVA-NeXT) processor and model once, then reuse them.

    The loaded objects are stored in the module-level globals ``vqa_processor`` and
    ``vqa_model``; subsequent calls return the cached instances without reloading.

    Returns:
        A ``(processor, model)`` pair; the model is loaded in float16 and moved to ``cuda:0``.
    """
    global vqa_processor, vqa_model
    if vqa_processor is None or vqa_model is None:
        # The token must be passed to BOTH loads: previously only the processor received it,
        # so a gated/private repository would fail when the model weights were downloaded.
        # `token` replaces the deprecated `use_auth_token` keyword.
        vqa_processor = LlavaNextProcessor.from_pretrained(vqa_model_name, token=hf_token)
        vqa_model = LlavaNextForConditionalGeneration.from_pretrained(
            vqa_model_name,
            token=hf_token,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to("cuda:0")
    return vqa_processor, vqa_model
def load_language_model():
    """Load the text-only answer model and its tokenizer once, then reuse them.

    The loaded objects are stored in the module-level globals ``language_tokenizer`` and
    ``language_model``; subsequent calls return the cached instances without reloading.

    Returns:
        A ``(tokenizer, model)`` pair; the model is loaded in float16 with
        ``device_map="auto"`` and put in eval mode.
    """
    global language_tokenizer, language_model
    if language_tokenizer is None or language_model is None:
        # The token must be passed to BOTH loads: previously only the tokenizer received it,
        # so a gated/private repository would fail when the model weights were downloaded.
        # `token` replaces the deprecated `use_auth_token` keyword.
        language_tokenizer = AutoTokenizer.from_pretrained(language_model_name, token=hf_token)
        language_model = AutoModelForCausalLM.from_pretrained(
            language_model_name,
            token=hf_token,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        # Use EOS as the padding token and mirror it in the model config so generate()
        # has a pad id (presumably the tokenizer ships without one — TODO confirm).
        language_tokenizer.pad_token = language_tokenizer.eos_token
        language_model.config.pad_token_id = language_tokenizer.pad_token_id
        language_model.eval()
    return language_tokenizer, language_model
def generate_image_description(image, *, max_new_tokens=100):
    """Ask the vision-language model what the image shows and return its answer.

    Args:
        image: The picture to describe, as delivered by the ``gr.Image(type="pil")`` input.
        max_new_tokens: Maximum number of tokens to generate for the description.

    Returns:
        The generated description only (the echoed prompt is removed), whitespace-stripped.
    """
    vqa_processor, vqa_model = load_vqa_model()
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {"type": "image"},
            ],
        },
    ]
    prompt = vqa_processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = vqa_processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")
    with torch.no_grad():
        output = vqa_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns the prompt ids followed by the new tokens. Decoding the whole
    # sequence would put the chat-template prompt text into the "description", so only
    # the tokens after the prompt are decoded.
    prompt_length = inputs["input_ids"].shape[-1]
    image_description = vqa_processor.decode(
        output[0][prompt_length:], skip_special_tokens=True
    )
    return image_description.strip()
def generate_language_response(
    instruction,
    image_description,
    *,
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.95,
):
    """Produce the final answer with the text-only model.

    The user's instruction and the image description are placed into an
    Instruction/Input/Response prompt, and the model's sampled continuation is returned.

    Args:
        instruction: Free-text instruction typed by the user.
        image_description: Text produced by ``generate_image_description``.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (sampling is enabled, so output is not deterministic).
        top_p: Nucleus-sampling probability mass.

    Returns:
        Only the newly generated text (the prompt is excluded), whitespace-stripped.
    """
    language_tokenizer, language_model = load_language_model()
    # NOTE(review): this layout must match the template the model was fine-tuned with —
    # confirm (e.g. whether blank lines originally separated the sections).
    prompt = (
        "### Instruction:\n"
        f"{instruction}\n"
        "### Input:\n"
        f"{image_description}\n"
        "### Response:\n"
    )
    inputs = language_tokenizer(prompt, return_tensors="pt").to(language_model.device)
    with torch.no_grad():
        outputs = language_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
    # Decode only the tokens generated after the prompt. Splitting the decoded text on
    # "### Response:" would discard the real answer whenever the model emits that marker again.
    prompt_length = inputs["input_ids"].shape[-1]
    response = language_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return response.strip()
def process_image_and_text(image, instruction):
    """Gradio callback: describe the uploaded image, then answer the instruction about it.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input; None when nothing was uploaded.
        instruction: Text from the instruction textbox.

    Returns:
        A display string containing both the image description and the final answer.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    # gr.Image passes None when the user submits without uploading; fail with a clear
    # message instead of an opaque error from inside the image processor.
    if image is None:
        raise gr.Error("请先上传图片。")
    image_description = generate_image_description(image)
    final_response = generate_language_response(instruction, image_description)
    return f"图片描述: {image_description}\n\n最终回答: {final_response}"
# Gradio UI: an image upload and an instruction textbox in, one text answer out.
iface = gr.Interface(
    fn=process_image_and_text,
    inputs=[
        gr.Image(type="pil", label="上传图片"),
        gr.Textbox(lines=2, placeholder="Instruction", label="Instruction"),
    ],
    outputs="text",
    title="WooWoof AI - 图片和文本交互",
    description="输入图片并添加指令,生成基于图片描述的回答。",
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`; kept as-is
    # because the installed Gradio version is not visible here.
    allow_flagging="never",
)

# Start the web server only when executed as a script, so importing this module
# (e.g. from tests or tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    iface.launch()