sounar commited on
Commit
acfc179
1 Parent(s): 2bdc9ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -82
app.py CHANGED
@@ -1,91 +1,33 @@
1
- import os
2
  import torch
3
- from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
4
- import gradio as gr
5
- from PIL import Image
6
- from torchvision.transforms import ToTensor
7
 
8
- # Get API token from environment variable
9
- api_token = os.getenv("HF_TOKEN").strip()
 
 
10
 
11
- # Quantization configuration
12
- bnb_config = BitsAndBytesConfig(
13
- load_in_4bit=True,
14
- bnb_4bit_quant_type="nf4",
15
- bnb_4bit_use_double_quant=True,
16
- bnb_4bit_compute_dtype=torch.float16
17
- )
18
 
19
- # Initialize model and tokenizer
20
- model = AutoModel.from_pretrained(
21
- "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
22
- quantization_config=bnb_config,
23
- device_map="auto",
24
- torch_dtype=torch.float16,
25
- trust_remote_code=True,
26
- attn_implementation="flash_attention_2",
27
- token=api_token
28
- )
29
 
30
- tokenizer = AutoTokenizer.from_pretrained(
31
- "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
32
- trust_remote_code=True,
33
- token=api_token
34
- )
35
 
36
- def analyze_input(image, question):
37
- try:
38
- if image is not None:
39
- # Convert to RGB if image is provided
40
- image = image.convert('RGB')
41
-
42
- # Prepare messages in the format expected by the model
43
- msgs = [{'role': 'user', 'content': [image, question]}]
44
-
45
- # Generate response using the chat method
46
- response_stream = model.chat(
47
- image=image,
48
- msgs=msgs,
49
- tokenizer=tokenizer,
50
- sampling=True,
51
- temperature=0.95,
52
- stream=True
53
- )
54
-
55
- # Collect the streamed response
56
- generated_text = ""
57
- for new_text in response_stream:
58
- generated_text += new_text
59
- print(new_text, flush=True, end='')
60
-
61
- return {"status": "success", "response": generated_text}
62
-
63
- except Exception as e:
64
- import traceback
65
- error_trace = traceback.format_exc()
66
- print(f"Error occurred: {error_trace}")
67
- return {"status": "error", "message": str(e)}
68
 
69
- # Create Gradio interface
70
- demo = gr.Interface(
71
- fn=analyze_input,
72
- inputs=[
73
- gr.Image(type="pil", label="Upload Medical Image"),
74
- gr.Textbox(
75
- label="Medical Question",
76
- placeholder="Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?",
77
- value="Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?"
78
- )
79
- ],
80
- outputs=gr.JSON(label="Analysis"),
81
- title="Medical Image Analysis Assistant",
82
- description="Upload a medical image and ask questions about it. The AI will analyze the image and provide detailed responses."
83
- )
84
 
85
- # Launch the Gradio app
86
  if __name__ == "__main__":
87
- demo.launch(
88
- share=True,
89
- server_name="0.0.0.0",
90
- server_port=7860
91
- )
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
 
 
 
 
3
 
4
+ # Load the model
5
+ model_name = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
+ def generate_response(input_text):
10
+ # Tokenize input text
11
+ inputs = tokenizer(input_text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
12
+ # Generate response
13
+ outputs = model.generate(inputs["input_ids"], max_length=150, temperature=0.7)
14
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
+ return response
16
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ from flask import Flask, request, jsonify
19
+ from predict import generate_response # import from the predict file
 
 
 
20
 
21
+ app = Flask(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ @app.route("/predict", methods=["POST"])
24
+ def predict():
25
+ data = request.get_json()
26
+ input_text = data.get("text")
27
+ if not input_text:
28
+ return jsonify({"error": "No input text provided"}), 400
29
+ response = generate_response(input_text)
30
+ return jsonify({"response": response})
 
 
 
 
 
 
 
31
 
 
32
  if __name__ == "__main__":
33
+ app.run(port=5000)