zeroMN-SHMT / app.py
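"""Gradio demo app: wraps several pretrained Hugging Face models in a single
nn.Module and routes text-generation and code-generation requests to the
matching sub-model and tokenizer."""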
import torch
import torch.nn as nn
import gradio as gr
from transformers import (
    BartForConditionalGeneration,
    AutoModelForCausalLM,
    BertModel,
    Wav2Vec2Model,
    CLIPModel,
    AutoTokenizer,
    AutoProcessor,
)
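# All sub-models below are downloaded and loaded eagerly when the module is
# instantiated, so the first start can take a while.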
class MultiModalModel(nn.Module):
    def __init__(self):
        super(MultiModalModel, self).__init__()
        # Initialize the sub-models.
        self.text_generator = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
        self.code_generator = AutoModelForCausalLM.from_pretrained('gpt2')
        self.nlp_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.speech_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
        self.vision_encoder = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
        # Initialize the tokenizers and processors. The audio and image models
        # need AutoProcessor (feature extraction), not a plain text tokenizer.
        self.text_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
        self.code_tokenizer = AutoTokenizer.from_pretrained('gpt2')
        self.nlp_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.speech_processor = AutoProcessor.from_pretrained('facebook/wav2vec2-base-960h')
        self.vision_processor = AutoProcessor.from_pretrained('openai/clip-vit-base-patch32')
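        # Note: the speech/vision encoders are loaded above, but forward()
        # below currently only routes text_generation and code_generation.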
    def forward(self, task, inputs):
        if task == 'text_generation':
            attention_mask = inputs.get('attention_mask')
            outputs = self.text_generator.generate(
                inputs['input_ids'],
                max_new_tokens=100,
                pad_token_id=self.text_tokenizer.eos_token_id,
                attention_mask=attention_mask,
                top_p=0.9,
                top_k=50,
                temperature=0.8,
                do_sample=True
            )
            return self.text_tokenizer.decode(outputs[0], skip_special_tokens=True)
        elif task == 'code_generation':
            attention_mask = inputs.get('attention_mask')
            outputs = self.code_generator.generate(
                inputs['input_ids'],
                max_new_tokens=50,
                # GPT-2 has no pad token, so reuse EOS for padding.
                pad_token_id=self.code_tokenizer.eos_token_id,
                attention_mask=attention_mask,
                top_p=0.95,
                top_k=50,
                temperature=1.2,
                do_sample=True
            )
            return self.code_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Logic for further tasks (speech, vision, NLP encoding) would go here.
        raise ValueError(f"Unsupported task: {task}")
# Inference function for the Gradio interface.
def gradio_inference(task, input_text):
    # Pick the tokenizer that matches the selected task.
    if task == "text_generation":
        tokenizer = model.text_tokenizer
    elif task == "code_generation":
        tokenizer = model.code_tokenizer
    else:
        raise ValueError(f"Unknown task: {task}")
    # The tokenizer already returns an attention_mask alongside input_ids.
    inputs = tokenizer(input_text, return_tensors='pt')
    with torch.no_grad():
        result = model(task, inputs)
    return result
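# Hypothetical quick smoke test, bypassing the UI (would have to run after
# the model is instantiated below):
#   print(gradio_inference("code_generation", "def fibonacci(n):"))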
# Instantiate the model (downloads all pretrained weights on first run).
model = MultiModalModel()
# Build the Gradio interface.
interface = gr.Interface(
    fn=gradio_inference,
    inputs=[
        gr.Dropdown(choices=["text_generation", "code_generation"], label="Task type"),
        gr.Textbox(lines=2, placeholder="Enter text...")
    ],
    outputs="text",
    title="Multimodal Model Inference",
    description="Select a task type and enter text to run inference."
)
# Launch the Gradio app.
interface.launch()
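# Tip: interface.launch(share=True) would additionally expose a temporary
# public URL (a standard Gradio option).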