sounar committed
Commit
4f5fa66
1 Parent(s): 2bf9d03

Update app.py

Files changed (1)
  1. app.py +28 -35
app.py CHANGED
@@ -1,16 +1,12 @@
-import os
+import gradio as gr
 import torch
-from flask import Flask, request, jsonify
 from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
 from PIL import Image
-import io
-import base64

-# Get API token from environment variable
+# Get API token from environment variables
 api_token = os.getenv("HF_TOKEN").strip()

-app = Flask(__name__)
-
 # Quantization configuration
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
@@ -19,7 +15,7 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float16
 )

-# Load model without Flash Attention
+# Load the model and tokenizer
 model = AutoModel.from_pretrained(
     "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
     quantization_config=bnb_config,
@@ -35,45 +31,42 @@ tokenizer = AutoTokenizer.from_pretrained(
     token=api_token
 )

-def decode_base64_image(base64_string):
-    # Decode base64 image
-    image_data = base64.b64decode(base64_string)
-    image = Image.open(io.BytesIO(image_data)).convert('RGB')
-    return image
-
-@app.route('/analyze', methods=['POST'])
-def analyze_input():
-    data = request.json
-    question = data.get('question', '')
-    base64_image = data.get('image', None)
-
+# Function to handle inputs
+def process_query(image, question):
     try:
-        # Process with image if provided
-        if base64_image:
-            image = decode_base64_image(base64_image)
+        if image:
+            # Process image and text
+            image = image.convert('RGB')
             inputs = model.prepare_inputs_for_generation(
                 input_ids=tokenizer(question, return_tensors="pt").input_ids,
                 images=[image]
             )
             outputs = model.generate(**inputs, max_new_tokens=256)
         else:
-            # Text-only processing
+            # Process text-only
             inputs = tokenizer(question, return_tensors="pt")
             outputs = model.generate(**inputs, max_new_tokens=256)

         # Decode response
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        return jsonify({
-            'status': 'success',
-            'response': response
-        })
+        return response

     except Exception as e:
-        return jsonify({
-            'status': 'error',
-            'message': str(e)
-        }), 500
+        return f"Error: {str(e)}"
+
+# Define Gradio interface
+interface = gr.Interface(
+    fn=process_query,
+    inputs=[
+        gr.Image(type="pil", label="Upload an Image (Optional)"),
+        gr.Textbox(label="Enter a Question")
+    ],
+    outputs="text",
+    title="ContactDoctor Multimodal Medical Assistant",
+    description="Provide an image and/or question to get AI-powered medical advice.",
+    enable_api=True  # Enable API for external calls
+)

-if __name__ == '__main__':
-    app.run(debug=True)
+# Launch the app
+if __name__ == "__main__":
+    interface.launch()
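
For reference, below is a minimal client-side sketch of how the updated Space could be queried once it is running. It is not part of the commit: the Space id is a placeholder, it assumes a recent gradio_client release (which provides Client and handle_file), and it relies on gr.Interface exposing its default /predict endpoint and accepting None for the optional image input.

# Hypothetical usage sketch (not part of this commit).
# The Space id below is a placeholder; gr.Interface is assumed to expose
# its default /predict endpoint and to accept None for the optional image.
from gradio_client import Client, handle_file

client = Client("sounar/bio-medical-multimodal")  # placeholder Space id

# Image + question
result = client.predict(
    handle_file("chest_xray.png"),                # maps to the gr.Image input
    "What abnormality is visible in this scan?",  # maps to the gr.Textbox input
    api_name="/predict",
)
print(result)

# Text-only question (optional image left empty)
result = client.predict(
    None,
    "What are the common side effects of metformin?",
    api_name="/predict",
)
print(result)

Switching from the hand-rolled Flask /analyze endpoint to gr.Interface is what makes this kind of programmatic access possible through Gradio's auto-generated API, in addition to the web UI.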