Process examples when loaded

app.py CHANGED
@@ -297,23 +297,11 @@ class ImageConductor:
         guidance_scale,
         num_inference_steps,
         personalized,
-        examples_type,
     ):
         print("Run!")
-        if examples_type != "":
-            ### for adapting high version gradio
-            tracking_points = gr.State([])
-            first_frame_path = IMAGE_PATH[examples_type]
-            points = json.load(open(POINTS[examples_type]))
-            tracking_points.value.extend(points)
-            print("example first_frame_path", first_frame_path)
-            print("example tracking_points", tracking_points.value)

         original_width, original_height = 384, 256
-        if isinstance(tracking_points, list):
-            input_all_points = tracking_points
-        else:
-            input_all_points = tracking_points.value
+        input_all_points = tracking_points

         print("input_all_points", input_all_points)
         resized_all_points = [
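Note on the hunk above: the old example branch mutated a fresh `gr.State` inside `run()` and then needed the trailing type juggling to cope with `tracking_points` sometimes being a state wrapper and sometimes a plain list. After the change, Gradio always hands `run()` the unwrapped value of the `tracking_points` state, i.e. a plain list of tracks, so a single assignment suffices. For intuition, here is a minimal sketch (not the app's code; the sizes and helper name are assumptions) of the kind of rescaling the `resized_all_points` expression that follows performs:

    # Hypothetical sketch: map tracks recorded on a display-sized frame onto
    # the model's 384x256 canvas. `tracks` is the plain list-of-lists that
    # Gradio passes in place of the gr.State wrapper.
    def rescale_tracks(tracks, src_size=(512, 320), dst_size=(384, 256)):
        sx = dst_size[0] / src_size[0]
        sy = dst_size[1] / src_size[1]
        return [[(int(x * sx), int(y * sy)) for x, y in track] for track in tracks]

    print(rescale_tracks([[(100, 50), (200, 80)]]))  # [[(75, 40), (150, 64)]]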
@@ -415,7 +403,7 @@ class ImageConductor:
         # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
         # save_videos_grid(sample[0][None], outputs_path)
         print("Done!")
-        return
+        return visualized_drag, outputs_path


 def reset_states(first_frame_path, tracking_points):
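The `return` fix above is load-bearing: Gradio maps a handler's returned tuple positionally onto its `outputs` list, so the old bare `return` sent `None` to both `output_image` and `output_video` (`outputs_path` is presumably assigned by the live save path earlier in the function; only the `.gif` variant is commented out). A self-contained sketch of that contract, with hypothetical component names and stub paths:

    import gradio as gr

    def run_stub():
        # Stand-ins for visualized_drag / outputs_path.
        return "drag_vis.png", "sample.mp4"

    with gr.Blocks() as demo:
        btn = gr.Button("Run")
        img = gr.Image()
        vid = gr.Video()
        # The two returned values land on img and vid, in order.
        btn.click(fn=run_stub, inputs=[], outputs=[img, vid])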
@@ -487,6 +475,54 @@ def add_tracking_points(
     return {tracking_points_var: tracking_points, input_image: trajectory_map}


+def preprocess_example_image(image_path, tracking_points, drag_mode):
+    image_pil = image2pil(image_path)
+    raw_w, raw_h = image_pil.size
+    resize_ratio = max(384 / raw_w, 256 / raw_h)
+    image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
+    image_pil = transforms.CenterCrop((256, 384))(image_pil.convert("RGB"))
+    id = str(uuid.uuid4())[:4]
+    first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
+    image_pil.save(first_frame_path, quality=95)
+
+    if drag_mode == "object":
+        color = (255, 0, 0, 255)
+    elif drag_mode == "camera":
+        color = (0, 0, 255, 255)
+
+    transparent_background = Image.open(first_frame_path).convert("RGBA")
+    w, h = transparent_background.size
+    transparent_layer = np.zeros((h, w, 4))
+
+    for track in tracking_points:
+        if len(track) > 1:
+            for i in range(len(track) - 1):
+                start_point = track[i]
+                end_point = track[i + 1]
+                vx = end_point[0] - start_point[0]
+                vy = end_point[1] - start_point[1]
+                arrow_length = np.sqrt(vx**2 + vy**2)
+                if i == len(track) - 2:
+                    cv2.arrowedLine(
+                        transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
+                    )
+                else:
+                    cv2.line(
+                        transparent_layer,
+                        tuple(start_point),
+                        tuple(end_point),
+                        color,
+                        2,
+                    )
+        else:
+            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
+
+    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+    return trajectory_map, first_frame_path
+
+
 def add_drag(tracking_points):
     if not tracking_points or tracking_points[-1]:
         tracking_points.append([])
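Two notes on the new `preprocess_example_image`: the `max`-ratio resize guarantees the image covers the 384x256 crop window in both dimensions, so `CenterCrop((256, 384))` only crops and never pads; and `color` would be unbound if `drag_mode` were ever anything other than "object" or "camera", which the UI presumably prevents. A worked sketch of the cover-resize arithmetic (the helper name is ours, not the app's):

    # Hypothetical helper mirroring the resize math above.
    def cover_size(raw_w, raw_h, target_w=384, target_h=256):
        ratio = max(target_w / raw_w, target_h / raw_h)
        return int(raw_w * ratio), int(raw_h * ratio)

    print(cover_size(1280, 720))  # (455, 256): width is cropped 455 -> 384
    print(cover_size(600, 600))   # (384, 384): height is cropped 384 -> 256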
@@ -571,6 +607,15 @@ def delete_last_step(tracking_points, first_frame_path, drag_mode):
     return {tracking_points_var: tracking_points, input_image: trajectory_map}


+def load_example(drag_mode, examples_type):
+    example_image_path = IMAGE_PATH[examples_type]
+    with open(POINTS[examples_type]) as f:
+        tracking_points = json.load(f)
+    tracking_points = np.round(tracking_points).astype(int).tolist()
+    trajectory_map, first_frame_path = preprocess_example_image(example_image_path, tracking_points, drag_mode)
+    return {input_image: trajectory_map, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
+
+
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 ImageConductor_net = ImageConductor(
     device=device,
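`load_example` returns a dict keyed by output components instead of a positional tuple; Gradio supports both, and the dict form keeps the component-to-value mapping explicit, matching the style of the existing handlers (`return {tracking_points_var: ..., input_image: ...}`). A minimal sketch of the pattern with hypothetical components:

    import gradio as gr

    def pick(example_name):
        points = {"dog": [[10, 20], [30, 40]], "car": [[5, 5]]}[example_name]
        # Dict return: each output listed below is updated by key, not position.
        return {points_state: points, status: f"loaded {len(points)} tracks"}

    with gr.Blocks() as demo:
        choice = gr.Dropdown(["dog", "car"], label="example")
        points_state = gr.State([])
        status = gr.Textbox()
        choice.change(fn=pick, inputs=[choice], outputs=[points_state, status])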
@@ -725,9 +770,16 @@ with block:
             guidance_scale,
             num_inference_steps,
             personalized,
-            examples_type,
         ],
         [output_image, output_video],
     )

+    examples_type.change(
+        fn=load_example,
+        inputs=[drag_mode, examples_type],
+        outputs=[input_image, first_frame_path_var, tracking_points_var],
+        api_name=False,
+        queue=False,
+    )
+
 block.queue().launch()
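On the new `.change()` wiring: `queue=False` lets this lightweight UI-sync event bypass the queue created by `block.queue()`, and `api_name=False` keeps it out of the auto-generated API; both are reasonable for a handler that only repaints the preview. A tiny isolated sketch of those two flags (components are hypothetical):

    import gradio as gr

    with gr.Blocks() as block:
        name = gr.Textbox(label="name")
        echo = gr.Textbox(label="echo")
        # Runs outside the queue and is hidden from the auto-generated API.
        name.change(fn=lambda s: s, inputs=[name], outputs=[echo],
                    queue=False, api_name=False)

    block.queue().launch()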