Mairaaa committed
Commit ea12b33 · verified · 1 Parent(s): 290f7fe

Update src/eval.py

Files changed (1)
  1. src/eval.py +74 -71
src/eval.py CHANGED
@@ -1,7 +1,9 @@
 import os
 
-# External libraries
+# external libraries
 import torch
+import torch.utils.checkpoint
+import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from diffusers import AutoencoderKL, DDIMScheduler
@@ -9,15 +11,16 @@ from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
 from transformers import CLIPTextModel, CLIPTokenizer
 
-# Custom imports
+# custom imports
 from src.datasets.dresscode import DressCodeDataset
 from src.datasets.vitonhd import VitonHDDataset
 from src.mgd_pipelines.mgd_pipe import MGDPipe
 from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
+from src.utils.arg_parser import eval_parse_args
 from src.utils.image_from_pipe import generate_images_from_mgd_pipe
 from src.utils.set_seeds import set_seed
 
-# Ensure the minimum version of diffusers is installed
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")
 
 logger = get_logger(__name__, log_level="INFO")
@@ -25,139 +28,139 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 os.environ["WANDB_START_METHOD"] = "thread"
 
 
-def main(args):
-    # Initialize Accelerator
-    accelerator = Accelerator(mixed_precision=args.get("mixed_precision", "fp16"))
+def main() -> None:
+    args = eval_parse_args()
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+    )
     device = accelerator.device
 
     # Set the training seed
-    if args.get("seed") is not None:
-        set_seed(args["seed"])
+    if args.seed is not None:
+        set_seed(args.seed)
 
     # Load scheduler, tokenizer, and models
-    val_scheduler = DDIMScheduler.from_pretrained(args["pretrained_model_name_or_path"], subfolder="scheduler")
+    val_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     val_scheduler.set_timesteps(50, device=device)
 
     tokenizer = CLIPTokenizer.from_pretrained(
-        args["pretrained_model_name_or_path"], subfolder="tokenizer", revision=args.get("revision", None)
+        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
     text_encoder = CLIPTextModel.from_pretrained(
-        args["pretrained_model_name_or_path"], subfolder="text_encoder", revision=args.get("revision", None)
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(args["pretrained_model_name_or_path"], subfolder="vae", revision=args.get("revision", None))
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
 
-    # Load UNet
+    # Load unet
     unet = torch.hub.load(
+        dataset=args.dataset,
        repo_or_dir="aimagelab/multimodal-garment-designer",
         source="github",
         model="mgd",
         pretrained=True,
     )
 
-    # Freeze models
+    # Freeze vae and text_encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
 
     # Enable memory efficient attention if requested
-    if args.get("enable_xformers_memory_efficient_attention", False):
+    if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
-            raise ValueError("xformers is not available. Install it to enable memory-efficient attention.")
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
 
-    # Set dataset category
-    category = [args.get("category", "dresses")]
+    # Set the dataset category
+    category = [args.category] if args.category else ["dresses", "upper_body", "lower_body"]
 
-    # Load dataset
-    if args["dataset"] == "dresscode":
+    # Load the appropriate dataset
+    if args.dataset == "dresscode":
         test_dataset = DressCodeDataset(
-            dataroot_path=args["dataset_path"],
+            dataroot_path=args.dataset_path,
             phase="test",
-            order=args.get("test_order", 0),
+            order=args.test_order,
             radius=5,
             sketch_threshold_range=(20, 20),
             tokenizer=tokenizer,
             category=category,
             size=(512, 384),
         )
-    elif args["dataset"] == "vitonhd":
+    elif args.dataset == "vitonhd":
         test_dataset = VitonHDDataset(
-            dataroot_path=args["dataset_path"],
+            dataroot_path=args.dataset_path,
             phase="test",
-            order=args.get("test_order", 0),
+            order=args.test_order,
             sketch_threshold_range=(20, 20),
             radius=5,
             tokenizer=tokenizer,
             size=(512, 384),
         )
     else:
-        raise NotImplementedError(f"Dataset {args['dataset']} is not supported.")
+        raise NotImplementedError(f"Dataset {args.dataset} is not supported.")
 
-    # Prepare dataloader
+    # Prepare the dataloader
     test_dataloader = torch.utils.data.DataLoader(
         test_dataset,
         shuffle=False,
-        batch_size=args.get("batch_size", 1),
-        num_workers=args.get("num_workers_test", 4),
+        batch_size=args.batch_size,
+        num_workers=args.num_workers_test,
    )
 
-    # Cast models to appropriate precision
-    weight_dtype = torch.float32 if args.get("mixed_precision") != "fp16" else torch.float16
+    # Cast text_encoder and vae to half-precision for mixed precision training
+    weight_dtype = torch.float32 if args.mixed_precision != "fp16" else torch.float16
     text_encoder.to(device, dtype=weight_dtype)
     vae.to(device, dtype=weight_dtype)
+
+    # Ensure unet is in eval mode
     unet.eval()
 
-    # Select pipeline
+    # Select the appropriate pipeline
     with torch.inference_mode():
-        pipeline_class = MGDPipeDisentangled if args.get("disentagle", False) else MGDPipe
-        val_pipe = pipeline_class(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet.to(vae.dtype),
-            tokenizer=tokenizer,
-            scheduler=val_scheduler,
-        ).to(device)
-
+        if args.disentagle:
+            val_pipe = MGDPipeDisentangled(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet.to(vae.dtype),
+                tokenizer=tokenizer,
+                scheduler=val_scheduler,
+            ).to(device)
+        else:
+            val_pipe = MGDPipe(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet.to(vae.dtype),
+                tokenizer=tokenizer,
+                scheduler=val_scheduler,
+            ).to(device)
+
+        # Debugging: Ensure val_pipe is callable
+        assert callable(val_pipe), "The pipeline object (val_pipe) is not callable. Check MGDPipe implementation."
+
+        # Enable attention slicing for memory efficiency
         val_pipe.enable_attention_slicing()
 
         # Prepare dataloader with accelerator
         test_dataloader = accelerator.prepare(test_dataloader)
 
-        # Generate images
-        output_path = os.path.join(args["output_dir"], args.get("save_name", "generated_image.png"))
+        # Call the image generation function
         generate_images_from_mgd_pipe(
-            test_order=args.get("test_order", 0),
+            test_order=args.test_order,
            pipe=val_pipe,
             test_dataloader=test_dataloader,
-            save_name=args.get("save_name", "generated_image"),
-            dataset=args["dataset"],
-            output_dir=args["output_dir"],
-            guidance_scale=args.get("guidance_scale", 7.5),
-            guidance_scale_pose=args.get("guidance_scale_pose", 0.5),
-            guidance_scale_sketch=args.get("guidance_scale_sketch", 7.5),
-            sketch_cond_rate=args.get("sketch_cond_rate", 1.0),
-            start_cond_rate=args.get("start_cond_rate", 0.0),
+            save_name=args.save_name,
+            dataset=args.dataset,
+            output_dir=args.output_dir,
+            guidance_scale=args.guidance_scale,
+            guidance_scale_pose=args.guidance_scale_pose,
+            guidance_scale_sketch=args.guidance_scale_sketch,
+            sketch_cond_rate=args.sketch_cond_rate,
+            start_cond_rate=args.start_cond_rate,
             no_pose=False,
-            disentagle=args.get("disentagle", False),
-            seed=args.get("seed", None),
+            disentagle=args.disentagle,
+            seed=args.seed,
         )
 
-        # Return the output image path for verification
-        return output_path
-
 
 if __name__ == "__main__":
-    # Example usage for debugging
-    example_args = {
-        "pretrained_model_name_or_path": "./models",
-        "dataset": "dresscode",
-        "dataset_path": "./datasets/dresscode",
-        "output_dir": "./outputs",
-        "guidance_scale": 7.5,
-        "guidance_scale_sketch": 7.5,
-        "mixed_precision": "fp16",
-        "batch_size": 1,
-        "seed": 42,
-    }
-    output_image = main(example_args)
-    print(f"Image generated at: {output_image}")
+    main()
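
Since `main()` now takes no parameters and reads its configuration from `eval_parse_args()`, the script is driven from the command line rather than through the hard-coded `example_args` dict that this commit removes. A minimal sketch of such an invocation, reusing the values from the removed dict; the flag names are assumptions inferred from the `args` attributes used above, so check `src/utils/arg_parser.py` for the actual ones:

    python src/eval.py \
        --pretrained_model_name_or_path ./models \
        --dataset dresscode \
        --dataset_path ./datasets/dresscode \
        --output_dir ./outputs \
        --mixed_precision fp16 \
        --batch_size 1 \
        --seed 42

Note that the added `dataset=args.dataset` keyword in `torch.hub.load(...)` is not consumed by `torch.hub.load` itself: like `pretrained=True`, it is forwarded to the `mgd` entrypoint in the aimagelab/multimodal-garment-designer hub repository, which is how the checkpoint matching the chosen dataset gets selected.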