spoorthibhat committed
Commit 36cca9a · verified · 1 Parent(s): e693582

Update app.py

Files changed (1)
  1. app.py +34 -47
app.py CHANGED
@@ -11,26 +11,17 @@ print(torch.cuda.is_available())
 
 print(os.system('python -m bitsandbytes'))
 
-import os
-import torch
-import warnings
-warnings.filterwarnings('ignore')
-
+import gradio as gr
 import io
 from contextlib import redirect_stdout
-import gradio as gr
-from transformers import AutoTokenizer
-from llava.model.builder import load_pretrained_model
-from llava.mm_utils import get_model_name_from_path
+import openai
+import torch
+from transformers import AutoTokenizer, BitsAndBytesConfig
+from llava.model import LlavaMistralForCausalLM
 from llava.eval.run_llava import eval_model
 
-# Check CUDA availability with error handling
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using device: {device}")
-
-# Define the model path
+# LLaVa-Med model setup
 model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
-
 kwargs = {"device_map": "auto"}
 kwargs['load_in_4bit'] = True
 kwargs['quantization_config'] = BitsAndBytesConfig(
@@ -42,48 +33,44 @@ kwargs['quantization_config'] = BitsAndBytesConfig(
 model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
 tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
 
-# Define the inference function
-def run_inference(image, question):
-    if model is None:
-        return "Model failed to load. Please check the logs."
-
-    args = type('Args', (), {
-        "model_path": model_path,
-        "model_base": None,
-        "image_file": image,
-        "query": question,
-        "conv_mode": None,
-        "sep": ",",
-        "temperature": 0,
-        "top_p": None,
-        "num_beams": 1,
-        "max_new_tokens": 256
-    })()
-
-    # Capture the printed output of eval_model
-    f = io.StringIO()
-    with redirect_stdout(f):
-        eval_model(args)
-    output = f.getvalue()
-    return output
-
-# Create the Gradio interface
 with gr.Blocks(theme=gr.themes.Monochrome()) as app:
     with gr.Column(scale=1):
-        gr.Markdown("<center><h1>LLaVA-Med</h1></center>")
+        gr.Markdown("<center><h1>LLaVa-Med</h1></center>")
 
         with gr.Row():
             image = gr.Image(type="filepath", scale=2)
-            question = gr.Textbox(placeholder="Enter a question", scale=3)
+            question = gr.Textbox(placeholder="Enter a question", label="Question", scale=3)
 
         with gr.Row():
-            answer = gr.Textbox(placeholder="Answer pops up here", scale=1)
+            answer = gr.Textbox(placeholder="Answer pops up here", label="Answer", scale=1)
+
+        def run_inference(image, question):
+            # Arguments for the model
+            args = type('Args', (), {
+                "model_path": model_path,
+                "model_base": None,
+                "image_file": image,
+                "query": question,
+                "conv_mode": None,
+                "sep": ",",
+                "temperature": 0,
+                "top_p": None,
+                "num_beams": 1,
+                "max_new_tokens": 512
+            })()
+
+            # Capture the printed output of eval_model
+            f = io.StringIO()
+            with redirect_stdout(f):
+                eval_model(args)
+            llava_med_result = f.getvalue()
+            print(llava_med_result)
+
+            return llava_med_result
 
         with gr.Row():
             btn = gr.Button("Run Inference", scale=1)
 
         btn.click(fn=run_inference, inputs=[image, question], outputs=answer)
 
-# Launch the app
-if __name__ == "__main__":
-    app.queue().launch(debug=True)
+app.launch(debug=True, height=800, width="100%")
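Note: the body of the BitsAndBytesConfig(...) call (new-file lines 28-32) falls between the two hunks and is not shown in this diff. For orientation only, a typical 4-bit configuration passed to from_pretrained looks like the sketch below; every field value here is an illustrative assumption, not the commit's actual settings.

import torch
from transformers import BitsAndBytesConfig

# Illustrative 4-bit settings -- NOT the elided values from this commit
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights in 4 bits
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
)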
 
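Note: the rewritten run_inference keeps the original stdout-capture trick, because eval_model prints its answer rather than returning it (the type('Args', (), {...})() call simply builds a throwaway namespace object that mimics parsed CLI arguments). A minimal, self-contained sketch of the capture pattern, with noisy_model as a hypothetical stand-in for eval_model:

import io
from contextlib import redirect_stdout

def noisy_model():
    # Hypothetical stand-in for eval_model, which prints instead of returning
    print("model answer")

buf = io.StringIO()
with redirect_stdout(buf):  # temporarily route stdout into the buffer
    noisy_model()
captured = buf.getvalue()   # "model answer\n"
print(repr(captured))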