SIUBIU commited on
Commit
cf9150b
·
verified ·
1 Parent(s): de57da4

Update clothGen.py

Browse files
Files changed (1) hide show
  1. clothGen.py +125 -18
clothGen.py CHANGED
@@ -3,6 +3,11 @@ import pandas as pd
3
  from prompt_gen import prompt_gen
4
  import requests
5
  import os
 
 
 
 
 
6
 
7
  nv_prompt_file = pd.read_excel('汉服-女词库.xlsx')
8
  na_prompt_file = pd.read_excel('汉服-男词库.xlsx')
@@ -11,41 +16,51 @@ na_prompt = na_prompt_file.to_string(index=False)
11
  save_directory = "downloads"
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def pro_gen(advice, gender, index):
15
  prompt = prompt_gen(advice, gender)
16
  start_index = prompt.find("Begin")
17
  if start_index == -1:
18
  start_index = prompt.find("begin")
19
- intro_index = prompt.find("服饰风格介绍")
20
- cloth_intro = ""
21
  prompt__gen = ""
22
  if start_index != -1:
23
  start_index += len("Begin\n")
24
  end_index = prompt.find("End")
25
  if end_index != -1:
26
  prompt__gen = prompt[start_index:end_index]
 
 
27
  filename = os.path.join(save_directory, f"prompt_{index}.txt")
28
  with open(filename, "w") as file:
29
  file.write(prompt__gen)
30
- print(prompt__gen)
31
  else:
32
  print("No 'promptEnd' found after 'prompt'.")
33
  else:
34
  print("No 'prompt' found in the text.")
35
- if intro_index != -1:
36
- intro_index += len("服饰风格介绍\n")
37
- cloth_intro = ("汉服,是汉民族的传统服饰。又称衣冠、衣裳、汉装。汉服是中国“衣冠上国”“礼仪之邦”“锦绣中华”的体现,承载了中国的染织绣等杰出"
38
- "工艺和美学,传承了30多项中国非物质文化遗产以及受保护的中国工艺美术。\n") + prompt[intro_index:]
39
- filename = os.path.join(save_directory, f"cloth_intro_{index}.txt")
40
- with open(filename, "w") as file:
41
- file.write(cloth_intro)
42
- print(cloth_intro)
43
- else:
44
- print("No '服饰风格介绍' found.")
45
  return prompt__gen
46
 
47
 
48
  def generate(lora_path, prompt__gen, index):
 
49
  handler = fal_client.submit(
50
  "fal-ai/fast-sdxl",
51
  arguments={
@@ -56,7 +71,7 @@ def generate(lora_path, prompt__gen, index):
56
  "image_size": "portrait_4_3",
57
  "num_inference_steps": 28,
58
  "guidance_scale": 7.5,
59
- "num_images": 2,
60
  "loras": [{"path": lora_path, "scale": 0.7}],
61
  "embeddings": [],
62
  "safety_checker_version": "v1",
@@ -70,28 +85,120 @@ def generate(lora_path, prompt__gen, index):
70
  for image in result['images']:
71
  response = requests.get(image['url'])
72
  if response.status_code == 200:
73
- filename = os.path.join(save_directory, f"gen_cloth_{image_index}.jpeg")
74
  with open(filename, 'wb') as f:
75
  f.write(response.content)
76
  image_index += 1
77
  else:
78
  print(f"Failed to download image from {image['url']}")
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  def cloth_gen(gender):
 
82
  lora_path = "https://huggingface.co/PPSharks/PPSharksModels/resolve/main/NV.safetensors"
 
 
83
  if gender == "男":
84
  lora_path = "https://huggingface.co/PPSharks/PPSharksModels/resolve/main/NA.safetensors"
85
- else:
 
 
 
86
  lora_path = "https://huggingface.co/PPSharks/PPSharksModels/resolve/main/NV.safetensors"
 
 
 
87
 
88
  cloth_image = []
89
  for i in range(1, 4):
90
  with open(os.path.join(save_directory, f"prompt_{i}.txt"), "r") as file:
91
  prompt__gen = file.read()
92
  generate(lora_path, prompt__gen, i)
93
- cloth_image.append(os.path.join(save_directory, f"gen_cloth_{i*2-1}.jpeg"))
94
- cloth_image.append(os.path.join(save_directory, f"gen_cloth_{i*2}.jpeg"))
 
 
95
  with open(os.path.join(save_directory, f"cloth_intro_1.txt"), "r") as file:
96
  cloth_intro = file.read()
97
  return cloth_image, cloth_image[0], cloth_intro
 
3
  from prompt_gen import prompt_gen
4
  import requests
5
  import os
6
+ from openai import OpenAI
7
+ import random
8
+ import shutil
9
+ from pathlib import Path
10
+ from PIL import Image
11
 
12
  nv_prompt_file = pd.read_excel('汉服-女词库.xlsx')
13
  na_prompt_file = pd.read_excel('汉服-男词库.xlsx')
 
16
  save_directory = "downloads"
17
 
18
 
19
+ def prompt_nan(prompt):
20
+ client = OpenAI()
21
+ completion = client.chat.completions.create(
22
+ model="gpt-4o",
23
+ messages=[
24
+ {"role": "system",
25
+ "content": "You are a helpful assistant.", },
26
+ {"role": "user",
27
+ "content": "Please showcase the overall appearance of this Hanfu robe against a contrasting white "
28
+ "background. Highlight its intricate details and unique design elements, including" + prompt,
29
+ }
30
+ ]
31
+
32
+ )
33
+ print("change prompt: ")
34
+ print(completion.choices[0].message.content)
35
+ return completion.choices[0].message.content
36
+
37
+
38
  def pro_gen(advice, gender, index):
39
  prompt = prompt_gen(advice, gender)
40
  start_index = prompt.find("Begin")
41
  if start_index == -1:
42
  start_index = prompt.find("begin")
 
 
43
  prompt__gen = ""
44
  if start_index != -1:
45
  start_index += len("Begin\n")
46
  end_index = prompt.find("End")
47
  if end_index != -1:
48
  prompt__gen = prompt[start_index:end_index]
49
+ # if gender == "男":
50
+ # prompt__gen = prompt_nan(prompt__gen)
51
  filename = os.path.join(save_directory, f"prompt_{index}.txt")
52
  with open(filename, "w") as file:
53
  file.write(prompt__gen)
54
+ # print(prompt__gen)
55
  else:
56
  print("No 'promptEnd' found after 'prompt'.")
57
  else:
58
  print("No 'prompt' found in the text.")
 
 
 
 
 
 
 
 
 
 
59
  return prompt__gen
60
 
61
 
62
  def generate(lora_path, prompt__gen, index):
63
+ # print(prompt__gen)
64
  handler = fal_client.submit(
65
  "fal-ai/fast-sdxl",
66
  arguments={
 
71
  "image_size": "portrait_4_3",
72
  "num_inference_steps": 28,
73
  "guidance_scale": 7.5,
74
+ "num_images": 1,
75
  "loras": [{"path": lora_path, "scale": 0.7}],
76
  "embeddings": [],
77
  "safety_checker_version": "v1",
 
85
  for image in result['images']:
86
  response = requests.get(image['url'])
87
  if response.status_code == 200:
88
+ filename = os.path.join(save_directory, f"cloth_{image_index}.jpeg")
89
  with open(filename, 'wb') as f:
90
  f.write(response.content)
91
  image_index += 1
92
  else:
93
  print(f"Failed to download image from {image['url']}")
94
 
95
+ client = OpenAI()
96
+ completion = client.chat.completions.create(
97
+ model="gpt-4o",
98
+ messages=[
99
+ {"role": "system",
100
+ "content": "You are a helpful assistant.", },
101
+ {"role": "user",
102
+ "content": prompt__gen + "以上是一段对于一套汉服的描述,请根据描述内容对该套汉服进行介绍。要求以介绍的口吻输出内容",
103
+ }
104
+ ]
105
+
106
+ )
107
+ cloth_intro = ("汉服,是汉民族的传统服饰。又称衣冠、衣裳、汉装。汉服是中国“衣冠上国”“礼仪之邦”“锦绣中华”的体现,承载了中国的染织绣等杰出"
108
+ "工艺和美学,传承了30多项中国非物质文化遗产以及受保护的中国工艺美术。\n") + completion.choices[
109
+ 0].message.content
110
+ filename = os.path.join(save_directory, f"cloth_intro_{index * 2 - 1}.txt")
111
+ with open(filename, "w") as file:
112
+ file.write(cloth_intro)
113
+
114
+
115
+ def convert_image_to_jpeg(input_path, output_path):
116
+ try:
117
+ image = Image.open(input_path)
118
+ if image.mode in ('RGBA', 'LA'):
119
+ image = image.convert('RGB')
120
+ image.save(output_path, 'JPEG')
121
+ except Exception as e:
122
+ print(f"转换图像时出错: {e}")
123
+
124
+
125
+ def pic_match(prompt__gen, cates, folder_path, intro_path, index):
126
+ client = OpenAI()
127
+ completion = client.chat.completions.create(
128
+ model="gpt-4o",
129
+ messages=[
130
+ {"role": "system",
131
+ "content": "You are a helpful assistant.", },
132
+ {"role": "user",
133
+ "content": prompt__gen + "以上是关于一套汉族服饰的描述,请根据描述内容从以下几种颜色中选择最符合描述的一种,可选颜色包括:" + cates
134
+ + ". 仅需输出一种颜色名称,不要带任何符号",
135
+ }
136
+ ]
137
+
138
+ )
139
+ print(f"Selected color: {completion.choices[0].message.content}")
140
+
141
+ folder_path = os.path.join(folder_path, completion.choices[0].message.content)
142
+ files = os.listdir(folder_path)
143
+ random_file = random.choice(files)
144
+ source_file_path = os.path.join(folder_path, random_file)
145
+ file_prefix, file_ext = os.path.splitext(random_file)
146
+ target_file_path = os.path.join(save_directory, f"cloth_{index * 2}.jpeg")
147
+ convert_image_to_jpeg(source_file_path, target_file_path)
148
+
149
+ file_extension = ".txt"
150
+ search_path = Path(intro_path)
151
+ for file in search_path.glob(f"{file_prefix}*{file_extension}"):
152
+ if file.is_file():
153
+ with open(file, "r") as f:
154
+ content = f.read()
155
+ client = OpenAI()
156
+ completion = client.chat.completions.create(
157
+ model="gpt-4o",
158
+ messages=[
159
+ {"role": "system",
160
+ "content": "You are a helpful assistant.", },
161
+ {"role": "user",
162
+ "content": content + "以上是一段对于一套汉服的描述,请根据描述内容对该套汉服进行介绍。要求以介绍的口吻输出内容",
163
+ }
164
+ ]
165
+
166
+ )
167
+ cloth_intro = ("汉服,是汉民族的传统服饰。又称衣冠、衣裳、汉装。汉服是中国“衣冠上国”“礼仪之邦”“锦绣中华”的体现,承载了中国的染织绣等杰出"
168
+ "工艺和美学,传承了30多项中国非物质文化遗产以及受保护的中国工艺美术。\n") + \
169
+ completion.choices[0].message.content
170
+ filename = os.path.join(save_directory, f"cloth_intro_{index * 2}.txt")
171
+ with open(filename, "w") as file:
172
+ file.write(cloth_intro)
173
+
174
+ return target_file_path
175
+
176
 
177
  def cloth_gen(gender):
178
+ cates = "Black, Blue, Green, Orange, Pink, Red, Violet, White, Yellow"
179
  lora_path = "https://huggingface.co/PPSharks/PPSharksModels/resolve/main/NV.safetensors"
180
+ folder_path = "database/female"
181
+ intro_path = "database/female_intro"
182
  if gender == "男":
183
  lora_path = "https://huggingface.co/PPSharks/PPSharksModels/resolve/main/NA.safetensors"
184
+ cates = "Black, Blue, Green, Brown, Red, Violet"
185
+ folder_path = "database/male"
186
+ intro_path = "database/male_intro"
187
+ elif gender == "女":
188
  lora_path = "https://huggingface.co/PPSharks/PPSharksModels/resolve/main/NV.safetensors"
189
+ cates = "Black, Blue, Green, Orange, Pink, Red, Violet, White, Yellow"
190
+ folder_path = "database/female"
191
+ intro_path = "database/female_intro"
192
 
193
  cloth_image = []
194
  for i in range(1, 4):
195
  with open(os.path.join(save_directory, f"prompt_{i}.txt"), "r") as file:
196
  prompt__gen = file.read()
197
  generate(lora_path, prompt__gen, i)
198
+ cloth_image.append(os.path.join(save_directory, f"cloth_{i*2-1}.jpeg"))
199
+ pic_path = pic_match(prompt__gen, cates, folder_path, intro_path, i)
200
+ cloth_image.append(pic_path)
201
+
202
  with open(os.path.join(save_directory, f"cloth_intro_1.txt"), "r") as file:
203
  cloth_intro = file.read()
204
  return cloth_image, cloth_image[0], cloth_intro