# Hugging Face Space (status: paused)
import os
import subprocess
import sys


def _run_command(command):
    """Run *command* (an argument list) and return its exit code.

    A non-zero exit code is reported on stderr instead of raising, so a failed
    step is visible in the logs without preventing the app from starting.
    """
    completed = subprocess.run(command, check=False)
    if completed.returncode != 0:
        print(
            f"WARNING: {' '.join(command)!r} exited with code "
            f"{completed.returncode}",
            file=sys.stderr,
        )
    return completed.returncode


# Invoke pip through the running interpreter so packages land in the same
# environment this script executes in.
_PIP = [sys.executable, "-m", "pip"]

# Install the project in the current directory in editable mode.
_run_command([*_PIP, "install", "-q", "-e", "."])
# Replace any installed bitsandbytes with the wheel shipped next to this script.
# "-y" is required: without it pip asks for confirmation, which cannot be
# answered in a non-interactive environment.
_run_command([*_PIP, "uninstall", "-y", "bitsandbytes"])
_run_command(
    [*_PIP, "install", "bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl"]
)

# Set before importing torch so the CUDA caching allocator sees the setting.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

# Startup diagnostics: CUDA visibility and the bitsandbytes self-check exit code.
print(torch.cuda.is_available())
print(_run_command([sys.executable, "-m", "bitsandbytes"]))
import gradio as gr | |
import io | |
from contextlib import redirect_stdout | |
import openai | |
import torch | |
from transformers import AutoTokenizer, BitsAndBytesConfig | |
from llava.model import LlavaMistralForCausalLM | |
from llava.eval.run_llava import eval_model | |
# LLaVa-Med model setup
model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"

# 4-bit NF4 quantization with double quantization and float16 compute.
# Only `quantization_config` is passed: it already requests 4-bit loading, and
# supplying `load_in_4bit=True` alongside it is redundant (newer transformers
# releases reject the combination).
kwargs = {
    "device_map": "auto",
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    ),
}

# NOTE(review): `model` and `tokenizer` are not referenced by the UI code
# visible in this file (inference goes through `eval_model` with `model_path`);
# confirm whether this eager load is still needed.
model = LlavaMistralForCausalLM.from_pretrained(
    model_path, low_cpu_mem_usage=True, **kwargs
)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
def run_inference(image, question):
    """Answer *question* about *image* using `eval_model`.

    Args:
        image: Path of the uploaded image file (the image component is
            configured with ``type="filepath"``); ``None`` when nothing
            has been uploaded.
        question: The question text entered by the user.

    Returns:
        The text `eval_model` printed to stdout, with surrounding whitespace
        removed.

    Raises:
        gr.Error: If no image was provided or the question is blank.
    """
    from types import SimpleNamespace

    # Fail early with a user-facing message instead of an opaque error from
    # inside the model code.
    if not image:
        raise gr.Error("Please upload an image before running inference.")
    if not question or not question.strip():
        raise gr.Error("Please enter a question before running inference.")

    # Arguments for the model
    args = SimpleNamespace(
        model_path=model_path,
        model_base=None,
        image_file=image,
        query=question,
        conv_mode=None,
        sep=",",
        temperature=0,
        top_p=None,
        num_beams=1,
        max_new_tokens=512,
    )

    # eval_model reports its result by printing, so capture stdout.
    captured = io.StringIO()
    with redirect_stdout(captured):
        eval_model(args)
    llava_med_result = captured.getvalue().strip()

    # Echo to the real stdout so the answer also shows up in the logs.
    print(llava_med_result)
    return llava_med_result


# UI layout: title, image + question inputs, answer output, and a run button.
with gr.Blocks(theme=gr.themes.Monochrome()) as app:
    with gr.Column(scale=1):
        gr.Markdown("<center><h1>LLaVa-Med</h1></center>")
        with gr.Row():
            image = gr.Image(type="filepath", scale=2)
            question = gr.Textbox(
                placeholder="Enter a question", label="Question", scale=3
            )
        with gr.Row():
            answer = gr.Textbox(
                placeholder="Answer pops up here", label="Answer", scale=1
            )
        with gr.Row():
            btn = gr.Button("Run Inference", scale=1)
            btn.click(fn=run_inference, inputs=[image, question], outputs=answer)

app.launch(debug=True, height=800, width="100%")