# A100 Zero GPU
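# Gradio demo for Meteor: a Mamba rationale model (Meteor-Mamba) feeds
# "traversal of rationale" features into a 4-bit quantized multimodal LM
# (Meteor-MLM), which answers questions about a single input image.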
import spaces
import time
import torch
import gradio as gr
from PIL import Image
from utils.utils import *
from threading import Thread
import torch.nn.functional as F
from accelerate import Accelerator
from meteor.load_mmamba import load_mmamba
from meteor.load_meteor import load_meteor
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor
# flash attention
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
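# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA extension at
# install time, the usual way to install flash-attn quickly on Spaces.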
# accel
accel = Accelerator()
# loading meteor model
mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=4)
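# bits=4 loads the language model with 4-bit quantization to cut GPU memory use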
# freeze model
freeze_model(mmamba)
freeze_model(meteor)
# previous length
previous_length = 0
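# Generation is two-stage: the Mamba model first encodes the prompt (plus CLIP
# image features) into rationale ("tor") features, which the Meteor LM then
# consumes alongside the same image features while streaming its answer.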
def threading_function(inputs, image_token_number, streamer, device, temperature, new_max_token, top_p):
    # Meteor Mamba: encode the prompt (and image) into rationale features
    mmamba_inputs = mmamba.eval_process(inputs=inputs, tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
    if 'image' in mmamba_inputs.keys():
        clip_features = meteor.clip_features(mmamba_inputs['image'])
        mmamba_inputs.update({"image_features": clip_features})
    mmamba_outputs = mmamba(**mmamba_inputs)
    # Meteor: generate conditioned on the rationale (tor) features
    meteor_inputs = meteor.eval_process(inputs=inputs, data='demo', tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
    if 'image' in mmamba_inputs.keys():
        meteor_inputs.update({"image_features": clip_features})
    meteor_inputs.update({"tor_features": mmamba_outputs.tor_features})
    # sampling configuration
    generation_kwargs = meteor_inputs
    generation_kwargs.update({'streamer': streamer,
                              'do_sample': True,
                              'max_new_tokens': new_max_token,
                              'top_p': top_p,
                              'temperature': temperature,
                              'use_cache': True})
    return meteor.generate(**generation_kwargs)
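# Gradio chat callback: moves both models to the scheduled device, preprocesses
# an optional single image, generates on a background thread, and yields the
# growing reply so the UI streams it.
# NOTE: assuming this runs as a ZeroGPU Space (hence the `spaces` import),
# the GPU-using callback needs the @spaces.GPU decorator.
@spaces.GPU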
def bot_streaming(message, history, temperature, new_max_token, top_p):
    try:
        # move model parameters onto the scheduled GPU
        for param in mmamba.parameters():
            param.data = param.to(accel.device)
        for param in meteor.parameters():
            param.data = param.to(accel.device)
        # prompt type -> input prompt
        image_token_number = int((490/14)**2)
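        # a 490x490 image in 14-pixel patches: (490/14)^2 = 35^2 = 1225 visual tokens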
        if len(message['files']) != 0:
            # load the single input image and resize it to 490x490 (bicubic)
            image = F.interpolate(pil_to_tensor(Image.open(message['files'][0]).convert("RGB")).unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
            inputs = [{'image': image, 'question': message['text']}]
        else:
            inputs = [{'question': message['text']}]
        # Meteor generation
        with torch.inference_mode():
            # streamer yields decoded tokens as they are produced
            streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
            # run generation on a background thread so the stream can be consumed here
            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                   image_token_number=image_token_number,
                                                                   streamer=streamer,
                                                                   device=accel.device,
                                                                   temperature=temperature,
                                                                   new_max_token=new_max_token,
                                                                   top_p=top_p))
            thread.start()
            # accumulate the streamed text
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
        # keep only the assistant's reply and drop trailing special-token text
        response = generated_text.split('assistant\n')[-1].split('[U')[0].strip()
    except Exception:
        response = "Unsupported input format (e.g., pdf, video, audio). This version supports a single image only."
    # private log print
    text = message['text']
    files = message['files']
    print(f'Text: {text}')
    print(f'MM Files: {files}')
    # re-emit the response character by character for a typewriter effect
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.015)
        yield buffer
demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="☄️Meteor",
                        description="Meteor is an efficient 7B-scale Large Language and Vision Model built with the help of traversal of rationale.\n"
                                    "Its inference speed depends heavily on being assigned a non-scheduled GPU, so while all GPUs are busy, inference may hang indefinitely.",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()