fancyfeast commited on
Commit
5d57e40
1 Parent(s): f73cf3f

Improve handling caption tone special case. Also, derp, forgot to format the prompt string.

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -144,12 +144,20 @@ image_adapter.to("cuda")
144
  def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
145
  torch.cuda.empty_cache()
146
 
 
147
  length = None if caption_length == "any" else caption_length
 
 
 
 
 
 
148
  prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
149
  if prompt_key not in CAPTION_TYPE_MAP:
150
  raise ValueError(f"Invalid caption type: {prompt_key}")
151
 
152
- prompt_str = CAPTION_TYPE_MAP[prompt_key][0]
 
153
 
154
  # Preprocess image
155
  #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
@@ -230,6 +238,8 @@ with gr.Blocks() as demo:
230
  value="any",
231
  )
232
 
 
 
233
  run_button = gr.Button("Caption")
234
 
235
  with gr.Column():
 
144
  def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
145
  torch.cuda.empty_cache()
146
 
147
+ # 'any' means no length specified
148
  length = None if caption_length == "any" else caption_length
149
+
150
+ # 'rng-tags' and 'training_prompt' don't have formal/informal tones
151
+ if caption_type == "rng-tags" or caption_type == "training_prompt":
152
+ caption_tone = "formal"
153
+
154
+ # Build prompt
155
  prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
156
  if prompt_key not in CAPTION_TYPE_MAP:
157
  raise ValueError(f"Invalid caption type: {prompt_key}")
158
 
159
+ prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
160
+ print(f"Prompt: {prompt_str}")
161
 
162
  # Preprocess image
163
  #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
 
238
  value="any",
239
  )
240
 
241
+ gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.")
242
+
243
  run_button = gr.Button("Caption")
244
 
245
  with gr.Column():