kailashahirwar commited on
Commit
475fa42
1 Parent(s): 3cdcc65

First Commit; code for demo app

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +321 -0
  3. requirements.txt +2 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .idea
2
+ .token
app.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ import gradio as gr
4
+ import json
5
+ import requests
6
+ import time
7
+ from gradio_modal import Modal
8
+ from io import BytesIO
9
+
10
+ TRYON_SERVER_HOST = "https://prod.server.tryonlabs.ai"
11
+ TRYON_SERVER_PORT = "80"
12
+ if TRYON_SERVER_PORT == "80":
13
+ TRYON_SERVER_URL = f"{TRYON_SERVER_HOST}"
14
+ else:
15
+ TRYON_SERVER_URL = f"{TRYON_SERVER_HOST}:{TRYON_SERVER_PORT}"
16
+
17
+ TRYON_SERVER_API_URL = f"{TRYON_SERVER_URL}/api/v1/"
18
+
19
+
20
+ def start_model_swap(input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps):
21
+ # make a request to TryOn Server
22
+ # 1. create an experiment image
23
+ print("inputs:", input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps)
24
+
25
+ if input_image is None:
26
+ raise gr.Error("Select an image!")
27
+
28
+ if prompt is None or prompt == "":
29
+ raise gr.Error("Enter a prompt!")
30
+
31
+ token = load_token()
32
+ if token is None or token == "":
33
+ raise gr.Error("You need to login first!")
34
+ else:
35
+ login(token)
36
+
37
+ byte_io = BytesIO()
38
+ input_image.save(byte_io, 'png')
39
+ byte_io.seek(0)
40
+
41
+ r = requests.post(f"{TRYON_SERVER_API_URL}experiment_image/",
42
+ files={"image": (
43
+ 'ei_image.png',
44
+ byte_io,
45
+ 'image/png'
46
+ )},
47
+ data={
48
+ "type": "model",
49
+ "preprocess": "false"},
50
+ headers={
51
+ "Authorization": f"Bearer {token}"
52
+ })
53
+ # print(r.json())
54
+ if r.status_code == 200 or r.status_code == 201:
55
+ print("Experiment image created successfully", r.json())
56
+ res = r.json()
57
+ # 2 create an experiment
58
+ r2 = requests.post(f"{TRYON_SERVER_API_URL}experiment/",
59
+ data={
60
+ "model_id": res['id'],
61
+ "action": "model_swap",
62
+ "params": json.dumps({"prompt": prompt,
63
+ "guidance_scale": guidance_scale,
64
+ "strength": strength,
65
+ "num_inference_steps": inference_steps,
66
+ "seed": seed,
67
+ "garment_class": f"{cls} garment",
68
+ "negative_prompt": "(hands:1.15), disfigured, ugly, bad, immature"
69
+ ", cartoon, anime, 3d, painting, b&w, (ugly),"
70
+ " (pixelated), watermark, glossy, smooth, "
71
+ "earrings, necklace",
72
+ "num_results": num_results})
73
+ },
74
+ headers={
75
+ "Authorization": f"Bearer {token}"
76
+ })
77
+ if r2.status_code == 200 or r2.status_code == 201:
78
+ # 3. keep checking the status of the experiment
79
+ res2 = r2.json()
80
+ print("Experiment created successfully", res2)
81
+ time.sleep(10)
82
+
83
+ experiment = res2['experiment']
84
+ status = fetch_experiment_status(experiment_id=experiment['id'], token=token)
85
+ status_status = status['status']
86
+ while status_status == "running":
87
+ time.sleep(10)
88
+ status = fetch_experiment_status(experiment_id=experiment['id'], token=token)
89
+ status_status = status['status']
90
+ print(f"Current status: {status_status}")
91
+
92
+ if status['status'] == "success":
93
+ print("Experiment successful")
94
+ print(f"Results:{status['result_images']}")
95
+ return status['result_images']
96
+ elif status['status'] == "failed":
97
+ print("Experiment failed")
98
+ raise gr.Error("Experiment failed")
99
+ else:
100
+ print(f"Error: {r2.text}")
101
+ raise gr.Error(f"Failure: {r2.text}")
102
+ else:
103
+ print(f"Error: {r.text}")
104
+ raise gr.Error(f"Failure: {r.text}")
105
+
106
+
107
+ def fetch_experiment_status(experiment_id, token):
108
+ print(f"experiment id:{experiment_id}")
109
+
110
+ r3 = requests.get(f"{TRYON_SERVER_API_URL}experiment/{experiment_id}/",
111
+ headers={
112
+ "Authorization": f"Bearer {token}"
113
+ })
114
+ if r3.status_code == 200:
115
+ res = r3.json()
116
+ if res['status'] == "running":
117
+ return {"status": "running"}
118
+ elif res['status'] == "success":
119
+ experiment = r3.json()['experiment']
120
+ result_images = [f"{TRYON_SERVER_URL}/{experiment['result']['image_url']}"]
121
+ if len(experiment['results']) > 0:
122
+ for result in experiment['results']:
123
+ result_images.append(f"{TRYON_SERVER_URL}/{result['image_url']}")
124
+ return {"status": "success", "result_images": result_images}
125
+ elif res['status'] == "failed":
126
+ return {"status": "failed"}
127
+ else:
128
+ print(f"Error: {r3.text}")
129
+ return {"status": "failed"}
130
+
131
+
132
+ def get_user_credits(token):
133
+ if token == "":
134
+ return None
135
+
136
+ r = requests.get(f"{TRYON_SERVER_API_URL}user/get/", headers={
137
+ "Authorization": f"Bearer {token}"
138
+ })
139
+ if r.status_code == 200:
140
+ res = r.json()
141
+ return res['credits']
142
+ else:
143
+ print(f"Error: {r.text}")
144
+ return None
145
+
146
+
147
+ def load_token():
148
+ if os.path.exists(".token"):
149
+ with open(".token", "r") as f:
150
+ return json.load(f)['token']
151
+ else:
152
+ return None
153
+
154
+
155
+ def save_token(access_token):
156
+ if access_token != "":
157
+ with open(".token", "w") as f:
158
+ json.dump({"token": access_token}, f)
159
+ else:
160
+ raise gr.Error("No token provided!")
161
+
162
+
163
+ def is_logged_in():
164
+ loaded_token = load_token()
165
+ if loaded_token is None or loaded_token == "":
166
+ return False
167
+ else:
168
+ return True
169
+
170
+
171
+ def login(token):
172
+ print("logging in...")
173
+ # validate token
174
+ r = requests.post(f"{TRYON_SERVER_URL}/api/token/verify/", data={"token": token})
175
+ if r.status_code == 200:
176
+ save_token(token)
177
+ return True
178
+ else:
179
+ raise gr.Error("Login failed")
180
+
181
+
182
+ def logout():
183
+ print("logged out")
184
+ with open(".token", "w") as f:
185
+ json.dump({"token": ""}, f)
186
+ return [False, ""]
187
+
188
+
189
+ css = """
190
+ #col-container {
191
+ margin: 0 auto;
192
+ max-width: 1024px;
193
+ }
194
+ #credits-col-container{
195
+ display:flex;
196
+ justify-content: right;
197
+ align-items: center;
198
+ font-size: 24px;
199
+ margin-right: 1rem;
200
+ }
201
+ #login-modal{
202
+ max-width: 728px;
203
+ margin: 0 auto;
204
+ margin-top: 1rem;
205
+ margin-bottom: 1rem;
206
+ }
207
+ #login-logout-btn{
208
+ display:inline;
209
+ max-width: 124px;
210
+ }
211
+ """
212
+
213
+ with gr.Blocks(css=css) as demo:
214
+ print("is logged in:", is_logged_in())
215
+ logged_in = gr.State(is_logged_in())
216
+ if os.path.exists(".token"):
217
+ with open(".token", "r") as f:
218
+ user_token = gr.State(json.load(f)["token"])
219
+ else:
220
+ user_token = gr.State("")
221
+
222
+ with Modal(visible=False) as modal:
223
+ @gr.render(inputs=user_token)
224
+ def rerender1(user_token1):
225
+ with gr.Column(elem_id="login-modal"):
226
+ access_token = gr.Textbox(
227
+ label="Token",
228
+ lines=1,
229
+ value=user_token1,
230
+ type="password",
231
+ placeholder="Enter your access token here!",
232
+ info="Visit https://playground.tryonlabs.ai to retrieve your access token."
233
+ )
234
+
235
+ login_submit_btn = gr.Button("Login", scale=1, variant='primary')
236
+ login_submit_btn.click(
237
+ fn=lambda access_token: (login(access_token), Modal(visible=False), access_token),
238
+ inputs=[access_token], outputs=[logged_in, modal, user_token],
239
+ concurrency_limit=1)
240
+
241
+ with gr.Row(elem_id="col-container"):
242
+ with gr.Column():
243
+ gr.Markdown(f"""
244
+ # Model Swap AI
245
+ ## by TryOn Labs (https://www.tryonlabs.ai)
246
+ Swap a human model with a artificial model generated by Artificial Model while keeping the garment intact.
247
+ """)
248
+
249
+
250
+ @gr.render(inputs=logged_in)
251
+ def rerender(is_logged_in):
252
+ with gr.Column():
253
+ if not is_logged_in:
254
+ with gr.Row(elem_id="credits-col-container"):
255
+ login_btn = gr.Button(value="Login", variant='primary', elem_id="login-logout-btn", size="sm")
256
+ login_btn.click(lambda: Modal(visible=True), None, modal)
257
+ else:
258
+ user_credits = get_user_credits(load_token())
259
+ print("user_credits", user_credits)
260
+ gr.HTML(f"""<div><p id="credits-col-container">Your Credits:
261
+ {user_credits if user_credits is not None else "0"}</p>
262
+ <p style="text-align: right;">Visit <a href="https://playground.tryonlabs.ai">
263
+ TryOn AI Playground</a> to acquire more credits</p></div>""")
264
+ with gr.Row(elem_id="credits-col-container"):
265
+ logout_btn = gr.Button(value="Logout", scale=1, variant='primary', size="sm",
266
+ elem_id="login-logout-btn")
267
+ logout_btn.click(fn=logout, inputs=None, outputs=[logged_in, user_token], concurrency_limit=1)
268
+
269
+ with gr.Column(elem_id="col-container"):
270
+ with gr.Row():
271
+ with gr.Column():
272
+ input_image = gr.Image(label="Original image", type='pil', height="400px", show_label=True)
273
+ prompt = gr.Textbox(
274
+ label="Prompt",
275
+ lines=3,
276
+ placeholder="Enter your prompt here!",
277
+ )
278
+ dropdown = gr.Dropdown(["upper", "lower", "dress"], value="upper", label="Retain garment",
279
+ info="Select the garment type you want to retain in the generated image!")
280
+
281
+ gallery = gr.Gallery(
282
+ label="Generated images", show_label=True, elem_id="gallery"
283
+ , columns=[3], rows=[1], object_fit="contain", height="auto")
284
+
285
+ # output_image = gr.Image(label="Swapped model", type='pil', height="400px", show_label=True,
286
+ # show_download_button=True)
287
+
288
+ with gr.Accordion("Advanced Settings", open=False):
289
+ with gr.Row():
290
+ seed = gr.Number(label="Seed", value=-1, interactive=True, minimum=-1)
291
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.5, interactive=True, minimum=0.0,
292
+ maximum=10.0,
293
+ step=0.1)
294
+ num_results = gr.Number(label="Number of results", value=2, minimum=1, maximum=5)
295
+
296
+ with gr.Row():
297
+ strength = gr.Slider(0.00, 1.00, value=0.99, label="Strength",
298
+ info="Choose between 0.00 and 1.00", step=0.01, interactive=True)
299
+ inference_steps = gr.Number(label="Inference Steps", value=20, interactive=True, minimum=1, step=1)
300
+
301
+ with gr.Row():
302
+ submit_button = gr.Button("Submit", variant='primary', scale=1)
303
+ reset_button = gr.ClearButton(value="Reset", scale=1)
304
+
305
+ gr.on(
306
+ triggers=[submit_button.click],
307
+ fn=start_model_swap,
308
+ inputs=[input_image, prompt, dropdown, seed, guidance_scale, num_results, strength, inference_steps],
309
+ outputs=[gallery]
310
+ )
311
+
312
+ reset_button.click(
313
+ fn=lambda: (None, None, "upper", None, -1, 7.5, 2, 0.99, 20),
314
+ inputs=[],
315
+ outputs=[input_image, prompt, dropdown, gallery, seed, guidance_scale,
316
+ num_results, strength, inference_steps],
317
+ concurrency_limit=1,
318
+ )
319
+
320
+ if __name__ == '__main__':
321
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ gradio_modal