franciszzj commited on
Commit
c81c28f
1 Parent(s): 9ed5c4d

load model before predict

Browse files
Files changed (2) hide show
  1. app.py +27 -20
  2. leffa/inference.py +1 -2
app.py CHANGED
@@ -13,6 +13,28 @@ import gradio as gr
13
  # Download checkpoints
14
  snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def leffa_predict(src_image_path, ref_image_path, control_type):
18
  assert control_type in [
@@ -27,20 +49,12 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
27
 
28
  # Mask
29
  if control_type == "virtual_tryon":
30
- automasker = AutoMasker(
31
- densepose_path="./ckpts/densepose",
32
- schp_path="./ckpts/schp",
33
- )
34
  src_image = src_image.convert("RGB")
35
- mask = automasker(src_image, "upper")["mask"]
36
  elif control_type == "pose_transfer":
37
  mask = Image.fromarray(np.ones_like(src_image_array) * 255)
38
 
39
  # DensePose
40
- densepose_predictor = DensePosePredictor(
41
- config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
42
- weights_path="./ckpts/densepose/model_final_162be9.pkl",
43
- )
44
  src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
45
  src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
46
  src_image_iuv = Image.fromarray(src_image_iuv_array)
@@ -52,17 +66,6 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
52
 
53
  # Leffa
54
  transform = LeffaTransform()
55
- if control_type == "virtual_tryon":
56
- pretrained_model_name_or_path = "./ckpts/stable-diffusion-inpainting"
57
- pretrained_model = "./ckpts/virtual_tryon.pth"
58
- elif control_type == "pose_transfer":
59
- pretrained_model_name_or_path = "./ckpts/stable-diffusion-xl-1.0-inpainting-0.1"
60
- pretrained_model = "./ckpts/pose_transfer.pth"
61
- model = LeffaModel(
62
- pretrained_model_name_or_path=pretrained_model_name_or_path,
63
- pretrained_model=pretrained_model,
64
- )
65
- inference = LeffaInference(model=model)
66
 
67
  data = {
68
  "src_image": [src_image],
@@ -71,6 +74,10 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
71
  "densepose": [densepose],
72
  }
73
  data = transform(data)
 
 
 
 
74
  output = inference(data)
75
  gen_image = output["generated_image"][0]
76
  # gen_image.save("gen_image.png")
 
13
  # Download checkpoints
14
  snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
15
 
16
+ mask_predictor = AutoMasker(
17
+ densepose_path="./ckpts/densepose",
18
+ schp_path="./ckpts/schp",
19
+ )
20
+
21
+ densepose_predictor = DensePosePredictor(
22
+ config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
23
+ weights_path="./ckpts/densepose/model_final_162be9.pkl",
24
+ )
25
+
26
+ vt_model = LeffaModel(
27
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
28
+ pretrained_model="./ckpts/virtual_tryon.pth",
29
+ )
30
+ vt_inference = LeffaInference(model=vt_model)
31
+
32
+ pt_model = LeffaModel(
33
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
34
+ pretrained_model="./ckpts/pose_transfer.pth",
35
+ )
36
+ pt_inference = LeffaInference(model=pt_model)
37
+
38
 
39
  def leffa_predict(src_image_path, ref_image_path, control_type):
40
  assert control_type in [
 
49
 
50
  # Mask
51
  if control_type == "virtual_tryon":
 
 
 
 
52
  src_image = src_image.convert("RGB")
53
+ mask = mask_predictor(src_image, "upper")["mask"]
54
  elif control_type == "pose_transfer":
55
  mask = Image.fromarray(np.ones_like(src_image_array) * 255)
56
 
57
  # DensePose
 
 
 
 
58
  src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
59
  src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
60
  src_image_iuv = Image.fromarray(src_image_iuv_array)
 
66
 
67
  # Leffa
68
  transform = LeffaTransform()
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  data = {
71
  "src_image": [src_image],
 
74
  "densepose": [densepose],
75
  }
76
  data = transform(data)
77
+ if control_type == "virtual_tryon":
78
+ inference = vt_inference
79
+ elif control_type == "pose_transfer":
80
+ inference = pt_inference
81
  output = inference(data)
82
  gen_image = output["generated_image"][0]
83
  # gen_image.save("gen_image.png")
leffa/inference.py CHANGED
@@ -17,7 +17,6 @@ class LeffaInference(object):
17
  self,
18
  model: nn.Module,
19
  ckpt_path: Optional[str] = None,
20
- repaint: bool = False,
21
  ) -> None:
22
  self.model: torch.nn.Module = model
23
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -28,7 +27,7 @@ class LeffaInference(object):
28
  self.model = self.model.to(self.device)
29
  self.model.eval()
30
 
31
- self.pipe = LeffaPipeline(model=self.model, repaint=repaint)
32
 
33
  def to_gpu(self, data: Dict[str, Any]) -> Dict[str, Any]:
34
  for k, v in data.items():
 
17
  self,
18
  model: nn.Module,
19
  ckpt_path: Optional[str] = None,
 
20
  ) -> None:
21
  self.model: torch.nn.Module = model
22
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
27
  self.model = self.model.to(self.device)
28
  self.model.eval()
29
 
30
+ self.pipe = LeffaPipeline(model=self.model)
31
 
32
  def to_gpu(self, data: Dict[str, Any]) -> Dict[str, Any]:
33
  for k, v in data.items():