SIUBIU commited on
Commit
281597f
·
verified ·
1 Parent(s): 0178a87

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +171 -0
  2. user_dress.py +106 -0
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from suggestion import generate_outfit_advice
3
+ from clothGen import cloth_gen
4
+ from user_dress import user_cloths
5
+ import requests
6
+ import os
7
+ from io import BytesIO
8
+ from PIL import Image
9
+ from cal_compatibility import cal_compatibility
10
+
11
+
12
+ gen_pic_num = 6
13
+
14
+
15
+ def get_select_index(evt: gr.SelectData, gallery):
16
+ print(gallery[evt.index][0])
17
+ # response = requests.get(gallery[evt.index][0])
18
+ # img = Image.open(BytesIO(response.content))
19
+ return gallery[evt.index][0]
20
+
21
+
22
+ def update_choices(dropout1, dropout2,):
23
+ if dropout1 == "男":
24
+ option = ['倒三角形', '矩形', '苹果形', '沙漏型', '胖型']
25
+ else:
26
+ option = ["梨形", "草莓形", "沙漏形", "标准", "苹果形"]
27
+ dropout2 = gr.Dropdown(choices=option)
28
+ return dropout2
29
+
30
+
31
+ with gr.Blocks(css="styles.css", theme=gr.themes.Base()) as demo:
32
+ with gr.Row():
33
+ # 左侧模块
34
+ with gr.Column(scale=1):
35
+ with gr.Row():
36
+ gr.Markdown(""
37
+ "# 用户信息"
38
+ "")
39
+ with gr.Row():
40
+ text_input1 = gr.Textbox(label="用户姓名", min_width=100)
41
+ text_input2 = gr.Textbox(label="身高/cm", min_width=100)
42
+ text_input3 = gr.Textbox(label="体重/kg", min_width=100)
43
+ text_input4 = gr.Textbox(label="腰围/cm", min_width=100)
44
+ text_input5 = gr.Textbox(label="胸围/cm", min_width=100)
45
+ text_input6 = gr.Textbox(label="臀围/cm", min_width=100)
46
+ text_input7 = gr.Textbox(label="肩宽/cm", min_width=100)
47
+ text_input8 = gr.Textbox(label="腿长/cm", min_width=100)
48
+ text_input9 = gr.Textbox(label="臂长/cm", min_width=100)
49
+ dropdown_options1 = ["男", "女"]
50
+ dropdown_input1 = gr.Dropdown(choices=dropdown_options1, label="性别", min_width=100)
51
+ dropdown_input2 = gr.Dropdown(choices=[], label="体型分类", min_width=100)
52
+ dropdown_input1.change(fn=update_choices, inputs=[dropdown_input1, dropdown_input2], outputs=dropdown_input2)
53
+ dropdown_options3 = ["浅色", "中等偏黄色", "中等偏褐色", "深色"]
54
+ dropdown_input3 = gr.Dropdown(choices=dropdown_options3, label="肤色", min_width=100)
55
+ text_input10 = gr.Textbox(label="穿衣风格偏好", min_width=1000)
56
+ text_input11 = gr.Textbox(label="生话方式和场景需求", min_width=1000)
57
+ text_input12 = gr.Textbox(label="其他特殊需求", min_width=1000)
58
+ with gr.Row():
59
+ user_pic = gr.Image(label="用户照片", value="model.jpg", height=550, width=300)
60
+
61
+ # 右侧模块
62
+ with gr.Column(scale=2):
63
+ with gr.Row():
64
+ gr.Markdown(""
65
+ "# 穿搭建议"
66
+ "")
67
+ with gr.Row():
68
+ text_output1 = gr.Textbox(label="穿搭建议", lines=12, max_lines=12, interactive=False, show_label=False,
69
+ min_width=1000)
70
+ submit_button_1 = gr.Button("AI智能分析,生成穿搭建议", min_width=1000)
71
+ image_output_1 = gr.Image(label="显示图像", value="image 209.png")
72
+
73
+ gallery_1 = gr.Gallery(
74
+ label="服装", elem_id="gallery",
75
+ value=[
76
+ # os.path.join(example_path, '上衣/_WEB_2016_09_26__2016092617451357e8ee2957aa1_TD.jpg'),
77
+ # os.path.join(example_path, '上衣/_WEB_2016_09_27__2016092717211057ea3a069c749_TD.jpg'),
78
+ # os.path.join(example_path, '上衣/_WEB_2016_09_27__2016092717391657ea3e446ce3f_TD.jpg'),
79
+ # os.path.join(example_path, '上衣/_WEB_2016_09_27__2016092717573057ea428a305bc_TD.jpg'),
80
+ # os.path.join(example_path, '上衣/_WEB_2016_09_28__2016092810150157eb27a56a631_TD.jpg'),
81
+ # os.path.join(example_path, '上衣/_WEB_2016_09_28__2016092810464557eb2f15e1df3_TD.jpg'),
82
+ ],
83
+ columns=[4], rows=[2], object_fit="contain", height=250, min_width=450)
84
+
85
+ gallery_3 = gr.Gallery(
86
+ label="配饰", elem_id="gallery",
87
+ value=[
88
+ 'downloads/access_1.jpg',
89
+ 'downloads/access_2.jpg',
90
+ 'downloads/access_3.jpg',
91
+ 'downloads/access_4.jpg',
92
+ 'downloads/access_5.jpg',
93
+ 'downloads/access_6.jpg',
94
+ ],
95
+ columns=[4], rows=[2], object_fit="contain", height=250, min_width=450)
96
+ submit_button_2 = gr.Button("AI智能分析,生成民族服饰")
97
+
98
+ with gr.Column(scale=2):
99
+ with gr.Row():
100
+ gr.Markdown(""
101
+ "# 搭配生成"
102
+ "")
103
+ with gr.Row():
104
+ gallery_4 = gr.Gallery(
105
+ label="套装", elem_id="gallery",
106
+ value=[],
107
+ columns=[3], rows=[1], object_fit="contain", height=180, min_width=450)
108
+ with gr.Row():
109
+ submit_button_3 = gr.Button("服饰及搭配兼容性排序")
110
+ with gr.Row():
111
+ image_output_5 = gr.Image(label="显示图像", show_label=False, min_width=200, height=350)
112
+ intro = gr.Textbox(label="服饰介绍", lines=14, max_lines=14)
113
+ with gr.Row():
114
+ submit_button_4 = gr.Button("虚拟试穿")
115
+ with gr.Row():
116
+ with gr.Column(scale=1):
117
+ gallery_user = gr.Gallery(
118
+ label="试穿结果",
119
+ elem_id="gallery",
120
+ value=[],
121
+ columns=[3], rows=[2],
122
+ object_fit="contain",
123
+ min_width=200,
124
+ height=350,
125
+ )
126
+ with gr.Column(scale=1):
127
+ feedback = gr.Textbox(label="反馈", placeholder="可以从款式、颜色、图案、风格倾向、文化偏好角度进行反馈", lines=1,
128
+ max_lines=1, elem_id="feedback")
129
+ submit_button_5 = gr.Button("反馈")
130
+ with gr.Row():
131
+ gr.Markdown("""
132
+ 女性体型备注:
133
+ 1. **梨形身材**:臀围比胸围**至少大5.08厘米**
134
+ 2. **草莓形身材**:臀围比胸围**至少小5.08厘米**
135
+ 3. **沙漏形身材**:胸围比腰围**至少大3.81厘米**且腰部线条明显
136
+ 4. **标准身材**:胸围比腰围**至少大3.81厘米**且腰部线条不明显
137
+ 5. **苹果形身材**:胸围比腰围**至少小3.81厘米**
138
+
139
+ 男性体型备注:
140
+ 1. **倒三角形身材**:肩宽比腰围**至少大10厘米**
141
+ 2. **矩形身材**:肩宽与腰围的差异**小于5厘米**
142
+ 3. **苹果形身材**:腰围比肩宽**至少大7.5厘米**
143
+ 4. **沙漏型身材**:腰围比肩宽或臀围**至少小10厘米**,且肩宽和臀围的差异**小于5厘米**
144
+ 5. **胖型身材**:腰围比胸围或肩宽**至少大10厘米**
145
+ """)
146
+
147
+ submit_button_1.click(fn=generate_outfit_advice,
148
+ inputs=[text_input1, text_input2, text_input3, text_input4, text_input5,
149
+ text_input6, text_input7, text_input8, text_input9, dropdown_input1,
150
+ dropdown_input2, dropdown_input3, text_input10, text_input11, text_input12,
151
+ feedback, user_pic],
152
+ outputs=text_output1)
153
+ submit_button_2.click(fn=cloth_gen,
154
+ inputs=[text_output1, dropdown_input1],
155
+ outputs=[gallery_1, image_output_5, intro])
156
+
157
+ gallery_1.select(get_select_index, gallery_1, image_output_5)
158
+ submit_button_3.click(fn=cal_compatibility,
159
+ inputs=[],
160
+ outputs=[gallery_4])
161
+ submit_button_4.click(fn=user_cloths,
162
+ inputs=[user_pic, image_output_5],
163
+ outputs=gallery_user)
164
+ submit_button_5.click(fn=generate_outfit_advice,
165
+ inputs=[text_input1, text_input2, text_input3, text_input4, text_input5,
166
+ text_input6, text_input7, text_input8, text_input9, dropdown_input1,
167
+ dropdown_input2, dropdown_input3, text_input10, text_input11, text_input12,
168
+ feedback, user_pic],
169
+ outputs=text_output1)
170
+
171
+ demo.launch(server_port=7860, share=True)
user_dress.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import fal_client
3
+ import os
4
+ from PIL import Image
5
+ import requests
6
+ import time
7
+
8
+ UPLOAD_FOLDER = 'uploads'
9
+ DOWNLOAD_FOLDER = 'downloads'
10
+
11
+
12
+ def user_dress_cat(user_pic, cloth_gen, index):
13
+ time_1 = time.time()
14
+ filename_1 = 'user_image.jpg'
15
+ filename_2 = 'cloth_image.jpg'
16
+ file_path_1 = os.path.join(UPLOAD_FOLDER, filename_1)
17
+ file_path_2 = os.path.join(UPLOAD_FOLDER, filename_2)
18
+ Image.fromarray(user_pic).save(file_path_1)
19
+ Image.fromarray(cloth_gen).save(file_path_2)
20
+ time_2 = time.time()
21
+ save_time = time_2 - time_1
22
+ print(f"save_time:{save_time}")
23
+ time_1 = time.time()
24
+ handler = fal_client.submit(
25
+ "fal-ai/cat-vton",
26
+ arguments={
27
+ "human_image_url": fal_client.upload_file(file_path_1),
28
+ "garment_image_url": fal_client.upload_file(file_path_2),
29
+ "cloth_type": "overall"
30
+ },
31
+ )
32
+ request_id = handler.request_id
33
+ result = fal_client.result("fal-ai/cat-vton", request_id)
34
+ time_2 = time.time()
35
+ cat_time = time_2 - time_1
36
+ print(f"cat_time:{cat_time}")
37
+ time_1 = time.time()
38
+ response = requests.get(result['image']['url'])
39
+ time_2 = time.time()
40
+ url_time = time_2 - time_1
41
+ print(f"url_time:{url_time}")
42
+ time_1 = time.time()
43
+ save_directory = "downloads"
44
+ if response.status_code == 200:
45
+ filename = os.path.join(save_directory, f"cat-vton_{index}.png")
46
+ with open(filename, 'wb') as f:
47
+ f.write(response.content)
48
+ else:
49
+ print(f"Failed to download image from {result['image']['url']}")
50
+ time_2 = time.time()
51
+ downloads_time = time_2 - time_1
52
+ print(f"downloads_time:{downloads_time}")
53
+ return os.path.join(save_directory, f"cat-vton_{index}.png")
54
+
55
+
56
+ def user_dress_idm(user_pic, cloth_gen, index):
57
+ time_1 = time.time()
58
+ filename_1 = 'user_image.jpg'
59
+ filename_2 = 'cloth_image.jpg'
60
+ file_path_1 = os.path.join(UPLOAD_FOLDER, filename_1)
61
+ file_path_2 = os.path.join(UPLOAD_FOLDER, filename_2)
62
+ Image.fromarray(user_pic).save(file_path_1)
63
+ Image.fromarray(cloth_gen).save(file_path_2)
64
+ time_2 = time.time()
65
+ save_time = time_2 - time_1
66
+ print(f"save_time:{save_time}")
67
+ time_1 = time.time()
68
+ handler = fal_client.submit(
69
+ "fal-ai/idm-vton",
70
+ arguments={
71
+ "human_image_url": fal_client.upload_file(file_path_1),
72
+ "garment_image_url": fal_client.upload_file(file_path_2),
73
+ "description": "long-sleeved long coat"
74
+ },
75
+ )
76
+ request_id = handler.request_id
77
+ result = fal_client.result("fal-ai/idm-vton", request_id)
78
+ time_2 = time.time()
79
+ cat_time = time_2 - time_1
80
+ print(f"idm_time:{cat_time}")
81
+ time_1 = time.time()
82
+ response = requests.get(result['image']['url'])
83
+ time_2 = time.time()
84
+ url_time = time_2 - time_1
85
+ print(f"url_time:{url_time}")
86
+ time_1 = time.time()
87
+ save_directory = "downloads"
88
+ if response.status_code == 200:
89
+ filename = os.path.join(save_directory, f"idm-vton_{index}.png")
90
+ with open(filename, 'wb') as f:
91
+ f.write(response.content)
92
+ else:
93
+ print(f"Failed to download image from {result['image']['url']}")
94
+ time_2 = time.time()
95
+ downloads_time = time_2 - time_1
96
+ print(f"downloads_time:{downloads_time}")
97
+ return os.path.join(save_directory, f"idm-vton_{index}.png")
98
+
99
+
100
+ def user_cloths(user_pic, cloth_gen):
101
+ user_cloth = []
102
+ for i in range(1, 4):
103
+ user_cloth.append(user_dress_cat(user_pic, cloth_gen, i))
104
+ for i in range(4, 6):
105
+ user_cloth.append(user_dress_idm(user_pic, cloth_gen, i))
106
+ return user_cloth