Rex Cheng committed
Commit 164c335
1 Parent(s): 627e0b8

speed up inference

app.py CHANGED
@@ -48,7 +48,8 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
                                   synchformer_ckpt=model.synchformer_ckpt,
                                   enable_conditions=True,
                                   mode=model.mode,
-                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path)
+                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                  need_vae_encoder=False)
     feature_utils = feature_utils.to(device, dtype).eval()
 
     return net, feature_utils, seq_cfg
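
The demo only ever decodes generated latents into audio, so the VAE encoder that FeaturesUtils would normally keep around is dead weight; with need_vae_encoder=False it is deleted right after the checkpoint is loaded (see autoencoder.py below) and never moved to the GPU. A hypothetical sanity check, assuming the updated get_model() above and the attribute names used in this commit (tod, vae, encoder):

# Not part of the commit: a quick way to confirm the encoder is gone.
net, feature_utils, seq_cfg = get_model()
assert not hasattr(feature_utils.tod.vae, 'encoder')           # encoder submodule was deleted
print(sum(p.numel() for p in feature_utils.tod.parameters()))  # decoder/vocoder parameters only
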
mmaudio/eval_utils.py CHANGED
@@ -76,29 +76,37 @@ all_model_cfg: dict[str, ModelConfig] = {
 }
 
 
-def generate(clip_video: Optional[torch.Tensor],
-             sync_video: Optional[torch.Tensor],
-             text: Optional[list[str]],
-             *,
-             negative_text: Optional[list[str]] = None,
-             feature_utils: FeaturesUtils,
-             net: MMAudio,
-             fm: FlowMatching,
-             rng: torch.Generator,
-             cfg_strength: float):
+def generate(
+    clip_video: Optional[torch.Tensor],
+    sync_video: Optional[torch.Tensor],
+    text: Optional[list[str]],
+    *,
+    negative_text: Optional[list[str]] = None,
+    feature_utils: FeaturesUtils,
+    net: MMAudio,
+    fm: FlowMatching,
+    rng: torch.Generator,
+    cfg_strength: float,
+    clip_batch_size_multiplier: int = 40,
+    sync_batch_size_multiplier: int = 40,
+) -> torch.Tensor:
     device = feature_utils.device
     dtype = feature_utils.dtype
 
     bs = len(text)
     if clip_video is not None:
         clip_video = clip_video.to(device, dtype, non_blocking=True)
-        clip_features = feature_utils.encode_video_with_clip(clip_video, batch_size=bs)
+        clip_features = feature_utils.encode_video_with_clip(clip_video,
+                                                             batch_size=bs *
+                                                             clip_batch_size_multiplier)
     else:
         clip_features = net.get_empty_clip_sequence(bs)
 
     if sync_video is not None:
         sync_video = sync_video.to(device, dtype, non_blocking=True)
-        sync_features = feature_utils.encode_video_with_sync(sync_video, batch_size=bs)
+        sync_features = feature_utils.encode_video_with_sync(sync_video,
+                                                             batch_size=bs *
+                                                             sync_batch_size_multiplier)
     else:
         sync_features = net.get_empty_sync_sequence(bs)
 
@@ -185,14 +193,9 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor
     data_chunk = reader.pop_chunks()
     clip_chunk = data_chunk[0]
     sync_chunk = data_chunk[1]
-    print('clip', clip_chunk.shape, clip_chunk.dtype, clip_chunk.max())
-    print('sync', sync_chunk.shape, sync_chunk.dtype, sync_chunk.max())
     assert clip_chunk is not None
     assert sync_chunk is not None
 
-    for i in range(reader.num_out_streams):
-        print(reader.get_out_stream_info(i))
-
     clip_frames = clip_transform(clip_chunk)
     sync_frames = sync_transform(sync_chunk)
 
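
The generate() change is where most of the speed-up comes from: the CLIP and Synchformer encoders previously saw only bs items per forward pass, and now see bs * multiplier frames/segments per pass (the Synchformer chunking is shown in features_utils.py below; encode_video_with_clip is assumed to chunk the same way). The load_video() hunk just drops leftover debugging prints. A sketch of a call site with the new knobs; the setup of net, feature_utils, fm, rng and the frame tensors follows the repo's demo code and is assumed here, and prompt, negative_prompt and cfg_strength are placeholder values:

audios = generate(clip_frames,
                  sync_frames,
                  [prompt],
                  negative_text=[negative_prompt],
                  feature_utils=feature_utils,
                  net=net,
                  fm=fm,
                  rng=rng,
                  cfg_strength=4.5,
                  clip_batch_size_multiplier=40,   # new defaults; lower if GPU memory is tight
                  sync_batch_size_multiplier=40)
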
 
mmaudio/ext/autoencoder/autoencoder.py CHANGED
@@ -15,7 +15,8 @@ class AutoEncoderModule(nn.Module):
                  *,
                  vae_ckpt_path,
                  vocoder_ckpt_path: Optional[str] = None,
-                 mode: Literal['16k', '44k']):
+                 mode: Literal['16k', '44k'],
+                 need_vae_encoder: bool = True):
         super().__init__()
         self.vae: VAE = get_my_vae(mode).eval()
         vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
@@ -35,6 +36,9 @@ class AutoEncoderModule(nn.Module):
         for param in self.parameters():
             param.requires_grad = False
 
+        if not need_vae_encoder:
+            del self.vae.encoder
+
     @torch.inference_mode()
     def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
         return self.vae.encode(x)
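
The new need_vae_encoder flag deletes the VAE's encoder submodule once the checkpoint has been loaded: inference only calls decode(), so the encoder's weights don't need to stay resident. A minimal, self-contained illustration of the same pattern (toy module, not MMAudio code):

import torch
from torch import nn


class TinyAE(nn.Module):

    def __init__(self, need_encoder: bool = True):
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        self.decoder = nn.Linear(4, 8)
        if not need_encoder:
            del self.encoder  # drops the submodule and its parameters


ae = TinyAE(need_encoder=False)
print(sum(p.numel() for p in ae.parameters()))  # 40: only the decoder's weights remain
print(ae.decoder(torch.randn(1, 4)).shape)      # decoding still works: torch.Size([1, 8])

Calling encode() on such a module would of course fail, which is fine here since the app never encodes audio.
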
mmaudio/model/utils/features_utils.py CHANGED
@@ -41,6 +41,7 @@ class FeaturesUtils(nn.Module):
                  synchformer_ckpt: Optional[str] = None,
                  enable_conditions: bool = True,
                  mode=Literal['16k', '44k'],
+                 need_vae_encoder: bool = True,
                  ):
         super().__init__()
 
@@ -64,19 +65,18 @@ class FeaturesUtils(nn.Module):
         if tod_vae_ckpt is not None:
             self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                          vocoder_ckpt_path=bigvgan_vocoder_ckpt,
-                                         mode=mode)
+                                         mode=mode,
+                                         need_vae_encoder=need_vae_encoder)
         else:
             self.tod = None
         self.mel_converter = MelConverter()
 
     def compile(self):
         if self.clip_model is not None:
-            self.encode_video_with_clip = torch.compile(self.encode_video_with_clip)
             self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
             self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
         if self.synchformer is not None:
             self.synchformer = torch.compile(self.synchformer)
-        self.tod.encode = torch.compile(self.tod.encode)
         self.decode = torch.compile(self.decode)
         self.vocode = torch.compile(self.vocode)
 
@@ -121,9 +121,11 @@ class FeaturesUtils(nn.Module):
         outputs = []
         if batch_size < 0:
             batch_size = b
-        for i in range(0, b, batch_size):
+        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
+        for i in range(0, b * num_segments, batch_size):
             outputs.append(self.synchformer(x[i:i + batch_size]))
-        x = torch.cat(outputs, dim=0).flatten(start_dim=1, end_dim=2)
+        x = torch.cat(outputs, dim=0)
+        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
         return x
 
     @torch.inference_mode()
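
The encode_video_with_sync rewrite flattens the segments of every clip into one batch with rearrange, runs the Synchformer over chunks of batch_size segments (which generate() now sets to bs * sync_batch_size_multiplier), and folds the result back into a per-clip token sequence. A standalone illustration with dummy shapes, where torch.randn stands in for the Synchformer call and the segment/token counts are made up:

import torch
from einops import rearrange

b, s, t, c, h, w = 2, 8, 16, 3, 16, 16  # clips, segments, frames, channels, H, W (dummy sizes)
d = 768                                 # feature dimension (illustrative)
x = torch.randn(b, s, t, c, h, w)

x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')  # all segments of all clips in one batch
outputs = []
batch_size = 6                                      # stands in for bs * sync_batch_size_multiplier
for i in range(0, b * s, batch_size):
    chunk = x[i:i + batch_size]
    outputs.append(torch.randn(chunk.shape[0], 1, 8, d))  # stand-in for self.synchformer(chunk)
x = torch.cat(outputs, dim=0)
x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)   # back to one token sequence per clip
print(x.shape)                                      # torch.Size([2, 64, 768])

Compared with the old loop over range(0, b, batch_size), which sliced along the clip dimension, this lets any number of segments share one forward pass instead of tying the chunk size to the number of clips.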