Pranav-khetarpal committed on
Commit
ee45c78
·
verified ·
1 Parent(s): dc169a4

Updated app.py to handle custom API requests

Browse files
Files changed (1) hide show
  1. app.py +100 -22
app.py CHANGED
@@ -6,6 +6,7 @@ import argparse
6
  import time
7
  import subprocess
8
 
 
9
  import llava.serve.gradio_web_server as gws
10
 
11
  # Execute the pip install command with additional options
@@ -24,7 +25,7 @@ def start_controller():
24
  "--port",
25
  "10000",
26
  ]
27
- print(controller_command)
28
  return subprocess.Popen(controller_command)
29
 
30
 
@@ -51,28 +52,104 @@ def start_worker(model_path: str, bits=16):
51
  "liuhaotian/llava-1.5-7b",
52
  "--use-flash-attn",
53
  ]
54
- # if bits != 16:
55
- # worker_command += [f"--load-{bits}bit"]
56
- print(worker_command)
57
  return subprocess.Popen(worker_command)
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  if __name__ == "__main__":
61
- parser = argparse.ArgumentParser()
62
- parser.add_argument("--host", type=str, default="0.0.0.0")
63
- parser.add_argument("--port", type=int)
64
- parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
65
- parser.add_argument("--concurrency-count", type=int, default=5)
66
- parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
67
- parser.add_argument("--share", action="store_false")
68
- parser.add_argument("--moderate", action="store_true")
69
- parser.add_argument("--embed", action="store_true")
70
- gws.args = parser.parse_args()
 
71
  gws.models = []
72
 
73
  gws.title_markdown += """ AstroLLaVA """
74
 
75
- print(f"astro args: {gws.args}")
76
 
77
  model_path = os.getenv("model", "universeTBD/AstroLLaVA_v2")
78
  bits = int(os.getenv("bits", 4))
@@ -82,24 +159,25 @@ if __name__ == "__main__":
82
  worker_proc = start_worker(model_path, bits=bits)
83
 
84
  # Wait for worker and controller to start
85
- print("Waiting for worker and controller to start")
86
  time.sleep(30)
87
 
88
  exit_status = 0
89
  try:
90
- demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
91
- print("Launching gradio")
 
92
  demo.queue(
93
  status_update_rate=10,
94
  api_open=False
95
  ).launch(
96
- server_name=gws.args.host,
97
- server_port=gws.args.port,
98
- share=gws.args.share
99
  )
100
 
101
  except Exception as e:
102
- print(e)
103
  exit_status = 1
104
  finally:
105
  worker_proc.kill()
 
6
  import time
7
  import subprocess
8
 
9
+ import gradio as gr
10
  import llava.serve.gradio_web_server as gws
11
 
12
  # Execute the pip install command with additional options
 
25
  "--port",
26
  "10000",
27
  ]
28
+ print("Controller Command:", controller_command)
29
  return subprocess.Popen(controller_command)
30
 
31
 
 
52
  "liuhaotian/llava-1.5-7b",
53
  "--use-flash-attn",
54
  ]
55
+ print("Worker Command:", worker_command)
 
 
56
  return subprocess.Popen(worker_command)
57
 
58
 
59
def handle_text_prompt(text, temperature=0.2, top_p=0.7, max_new_tokens=512):
    """Custom API endpoint that answers a plain-text prompt.

    Placeholder implementation: it logs the request and echoes the
    parameters back as a string. Swap in real model inference here.

    Args:
        text: The user's prompt.
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling cutoff forwarded to the model.
        max_new_tokens: Generation length cap forwarded to the model.

    Returns:
        A string describing the request (stand-in for the model reply).
    """
    print(f"Received prompt: {text}")
    print(
        "Parameters - Temperature: {0}, Top P: {1}, "
        "Max New Tokens: {2}".format(temperature, top_p, max_new_tokens)
    )

    # TODO: replace this echo with a real call into the model worker.
    return (
        f"Model response to '{text}' with temperature={temperature}, "
        f"top_p={top_p}, max_new_tokens={max_new_tokens}"
    )
71
+
72
+
73
def add_text_with_image(text, image, mode):
    """Custom API endpoint pairing a text snippet with an uploaded image.

    Placeholder implementation: it logs the inputs and returns a summary
    string. Swap in real text+image processing here.

    Args:
        text: The caption / text to add.
        image: Path (or handle) of the uploaded image.
        mode: Image processing mode selected in the UI.

    Returns:
        A string summarizing the request (stand-in for real output).
    """
    for label, value in (
        ("Adding text", text),
        ("Image path", image),
        ("Image processing mode", mode),
    ):
        print(f"{label}: {value}")

    # TODO: replace with actual image/text handling.
    return f"Added text '{text}' with image at '{image}' using mode '{mode}'."
86
+
87
+
88
def build_custom_demo(embed_mode=False, cur_dir='./', concurrency_count=5):
    """Assemble the Gradio Blocks UI and register the custom API endpoints.

    The parameters mirror ``gws.build_demo``'s signature so this builder
    is a drop-in replacement; they are currently unused by this layout.
    Two named endpoints are exposed: ``prompt_model`` and
    ``add_text_with_image``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AstroLLaVA")
        gr.Markdown("Welcome to the AstroLLaVA interface. Use the API endpoints to interact with the model.")

        # --- Text-prompt panel -------------------------------------------
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Prompt the Model")
                prompt_box = gr.Textbox(label="Enter your text prompt", placeholder="Type your prompt here...")
                temp_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature")
                top_p_ctl = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Top P")
                tokens_ctl = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max New Tokens")
                prompt_btn = gr.Button("Submit Prompt")
            with gr.Column():
                reply_box = gr.Textbox(label="Model Response", interactive=False)

        # Registers the /prompt_model custom API endpoint.
        prompt_btn.click(
            fn=handle_text_prompt,
            inputs=[prompt_box, temp_slider, top_p_ctl, tokens_ctl],
            outputs=reply_box,
            api_name="prompt_model",
        )

        # --- Text + image panel ------------------------------------------
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Add Text with Image")
                caption_box = gr.Textbox(label="Add Text", placeholder="Enter text to add...")
                upload_ctl = gr.Image(label="Upload Image")
                mode_radio = gr.Radio(choices=["Crop", "Resize", "Pad", "Default"], value="Default", label="Image Process Mode")
                caption_btn = gr.Button("Add Text with Image")
            with gr.Column():
                caption_out = gr.Textbox(label="Add Text Response", interactive=False)

        # Registers the /add_text_with_image custom API endpoint.
        caption_btn.click(
            fn=add_text_with_image,
            inputs=[caption_box, upload_ctl, mode_radio],
            outputs=caption_out,
            api_name="add_text_with_image",
        )

        # Additional API endpoints can be added here following the same structure.

    return demo
134
+
135
+
136
  if __name__ == "__main__":
137
+ parser = argparse.ArgumentParser(description="AstroLLaVA Gradio App")
138
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Hostname to listen on")
139
+ parser.add_argument("--port", type=int, default=7860, help="Port number")
140
+ parser.add_argument("--controller-url", type=str, default="http://localhost:10000", help="Controller URL")
141
+ parser.add_argument("--concurrency-count", type=int, default=5, help="Number of concurrent requests")
142
+ parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"], help="Model list mode")
143
+ parser.add_argument("--share", action="store_true", help="Share the Gradio app publicly")
144
+ parser.add_argument("--moderate", action="store_true", help="Enable moderation")
145
+ parser.add_argument("--embed", action="store_true", help="Enable embed mode")
146
+ args = parser.parse_args()
147
+ gws.args = args
148
  gws.models = []
149
 
150
  gws.title_markdown += """ AstroLLaVA """
151
 
152
+ print(f"AstroLLaVA arguments: {gws.args}")
153
 
154
  model_path = os.getenv("model", "universeTBD/AstroLLaVA_v2")
155
  bits = int(os.getenv("bits", 4))
 
159
  worker_proc = start_worker(model_path, bits=bits)
160
 
161
  # Wait for worker and controller to start
162
+ print("Waiting for worker and controller to start...")
163
  time.sleep(30)
164
 
165
  exit_status = 0
166
  try:
167
+ # Build the custom Gradio demo with additional API endpoints
168
+ demo = build_custom_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
169
+ print("Launching Gradio with custom API endpoints...")
170
  demo.queue(
171
  status_update_rate=10,
172
  api_open=False
173
  ).launch(
174
+ server_name=args.host,
175
+ server_port=args.port,
176
+ share=args.share
177
  )
178
 
179
  except Exception as e:
180
+ print(f"An error occurred: {e}")
181
  exit_status = 1
182
  finally:
183
  worker_proc.kill()