# Provenance (from the hosting page): multimodalart - "Squashing commit" (4450790, verified)
from typing import TypedDict
import torch
import torchaudio
class AudioDict(TypedDict):
    """Audio payload passed between nodes (Comfy's AUDIO type).

    The nodes in this file index ``waveform`` with three dimensions, reading
    the channel count from dimension 1 and the sample count from the last
    dimension.
    """

    # Sampling frequency of ``waveform``; the cutter converts milliseconds
    # with ``sample_rate / 1000``, so this is samples per second.
    sample_rate: int
    # The audio samples themselves.
    waveform: torch.Tensor
# Either a single audio payload or a list of them (see `MtbAudio.is_stereo`).
AudioData = AudioDict | list[AudioDict]
class MtbAudio:
    """Base class for audio processing.

    Groups the helpers shared by the audio nodes: channel inspection,
    resampling, mono-to-stereo conversion and a combined preprocessing pass.
    Waveforms are indexed with the channel count on dimension 1.
    """

    @classmethod
    def is_stereo(
        cls,
        audios: AudioData,
    ) -> bool:
        """Return True if the audio, or any audio in a list, has 2 channels."""
        if isinstance(audios, list):
            return any(cls.is_stereo(audio) for audio in audios)
        return audios["waveform"].shape[1] == 2

    @staticmethod
    def resample(audio: AudioDict, common_sample_rate: int) -> AudioDict:
        """Return `audio` converted to `common_sample_rate`.

        The input dict is returned unchanged (same object) when it already
        uses the requested rate; otherwise a new dict is built and the input
        is left untouched.
        """
        if audio["sample_rate"] == common_sample_rate:
            return audio

        resampler = torchaudio.transforms.Resample(
            orig_freq=audio["sample_rate"], new_freq=common_sample_rate
        )
        return {
            "sample_rate": common_sample_rate,
            "waveform": resampler(audio["waveform"]),
        }

    @staticmethod
    def to_stereo(audio: AudioDict) -> AudioDict:
        """Return `audio` with a single channel duplicated into two.

        Audio whose channel dimension is not 1 is returned unchanged (same
        object).
        """
        if audio["waveform"].shape[1] != 1:
            return audio

        return {
            "sample_rate": audio["sample_rate"],
            "waveform": torch.cat(
                [audio["waveform"], audio["waveform"]], dim=1
            ),
        }

    @classmethod
    def preprocess_audios(
        cls, audios: list[AudioDict]
    ) -> tuple[list[AudioDict], bool, int]:
        """Bring all audios to a common sample rate and channel layout.

        Every input is resampled to the highest sample rate found; if any
        input is stereo, mono inputs are converted to stereo as well.

        Returns:
            A tuple ``(audios, is_stereo, sample_rate)`` with the processed
            audios, whether the result is stereo, and the common rate.

        Raises:
            ValueError: if `audios` is empty.
        """
        if not audios:
            raise ValueError(
                "preprocess_audios requires at least one audio input"
            )

        max_sample_rate = max(audio["sample_rate"] for audio in audios)

        # Keep working on the resampled list: returning the original inputs
        # in the mono case would hand back audios at mixed sample rates.
        processed = [cls.resample(audio, max_sample_rate) for audio in audios]

        is_stereo = cls.is_stereo(processed)
        if is_stereo:
            processed = [cls.to_stereo(audio) for audio in processed]

        return (processed, is_stereo, max_sample_rate)
class MTB_AudioCut(MtbAudio):
    """Basic audio cutter, values are in ms."""

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("cut_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "cut"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: the audio plus `length` and `offset` floats."""
        length_spec = (
            "FLOAT",
            {
                "default": 1000.0,
                "min": 0.0,
                "max": 999999.0,
                "step": 1,
            },
        )
        offset_spec = (
            "FLOAT",
            {"default": 0.0, "min": 0.0, "max": 999999.0, "step": 1},
        )
        return {
            "required": {
                "audio": ("AUDIO",),
                "length": length_spec,
                "offset": offset_spec,
            },
        }

    def cut(self, audio: AudioDict, length: float, offset: float):
        """Extract `length` ms of audio starting `offset` ms into the input.

        The end of the cut is clamped to the end of the waveform, so the
        result may be shorter than requested.
        """
        rate = audio["sample_rate"]
        samples = audio["waveform"]

        # Milliseconds -> sample indices.
        first = int(offset * rate / 1000)
        wanted = int(length * rate / 1000)
        last = min(first + wanted, samples.shape[-1])

        result = {
            "sample_rate": rate,
            "waveform": samples[:, :, first:last],
        }
        return (result,)
class MTB_AudioStack(MtbAudio):
    """Stack/Overlay audio inputs (dynamic inputs).

    - pad audios to the longest inputs.
    - resample audios to the highest sample rate in the inputs.
    - convert them all to stereo if one of the inputs is.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare no fixed inputs; audios arrive as dynamic keyword inputs."""
        return {"required": {}}

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("stacked_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "stack"

    def stack(self, **kwargs: AudioDict) -> tuple[AudioDict]:
        """Mix all audio inputs by summing them sample by sample.

        Shorter inputs are zero-padded at the end to the longest input.

        Raises:
            ValueError: if no audio input is connected (raised while
                preprocessing the empty input list).
        """
        audios, _, max_rate = self.preprocess_audios(list(kwargs.values()))

        max_length = max(audio["waveform"].shape[-1] for audio in audios)

        # Pad the waveform itself rather than concatenating a fresh zeros
        # tensor: this keeps each input's batch size, channel count, dtype
        # and device instead of assuming (1, 1|2, n) float32 on the CPU.
        padded_audios: list[torch.Tensor] = [
            torch.nn.functional.pad(
                audio["waveform"],
                (0, max_length - audio["waveform"].shape[-1]),
            )
            for audio in audios
        ]

        stacked_waveform = torch.stack(padded_audios, dim=0).sum(dim=0)

        return (
            {
                "sample_rate": max_rate,
                "waveform": stacked_waveform,
            },
        )
class MTB_AudioSequence(MtbAudio):
    """Sequence audio inputs (dynamic inputs).

    - adding silence_duration (in seconds) between each segment;
      it can also be negative to overlap the clips, safely bound
      to the input length.
    - resample audios to the highest sample rate in the inputs.
    - convert them all to stereo if one of the inputs is.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the `silence_duration` input; audios are dynamic inputs."""
        return {
            "required": {
                "silence_duration": (
                    ("FLOAT"),
                    {"default": 0.0, "min": -999.0, "max": 999, "step": 0.01},
                )
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("sequenced_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "sequence"

    def sequence(self, silence_duration: float, **kwargs: AudioDict):
        """Concatenate the audio inputs one after another.

        Args:
            silence_duration: gap between consecutive clips, in seconds.
                A negative value overlaps (sums) the end of one clip with
                the start of the next; the overlap is clamped to the length
                of both clips.
            **kwargs: the audio inputs, sequenced in the order given.

        Raises:
            ValueError: if no audio input is connected (raised while
                preprocessing the empty input list).
        """
        audios, _, max_rate = self.preprocess_audios(list(kwargs.values()))

        # Size of the gap (or of the overlap, when negative) in samples.
        gap = int(abs(silence_duration) * max_rate)

        segments: list[torch.Tensor] = []
        for i, audio in enumerate(audios):
            # Work on a local reference: the dict may be the caller's own
            # object, so it must not be modified when trimming the overlap.
            waveform = audio["waveform"]

            if i > 0 and gap > 0:
                if silence_duration > 0:
                    # Match the clip's batch size, channels, dtype and
                    # device so the final concatenation cannot mismatch.
                    segments.append(
                        waveform.new_zeros((*waveform.shape[:-1], gap))
                    )
                else:
                    previous = segments[-1]
                    overlap = min(
                        gap, previous.shape[-1], waveform.shape[-1]
                    )
                    if overlap > 0:
                        # Replace the tail of the previous segment with the
                        # sum of that tail and the head of this clip.
                        segments[-1] = previous[:, :, :-overlap]
                        segments.append(
                            previous[:, :, -overlap:]
                            + waveform[:, :, :overlap]
                        )
                        waveform = waveform[:, :, overlap:]

            segments.append(waveform)

        sequenced_waveform = torch.cat(segments, dim=-1)
        return (
            {
                "sample_rate": max_rate,
                "waveform": sequenced_waveform,
            },
        )
# Node classes exposed by this module.
# NOTE(review): presumably collected by the package's node loader - confirm.
__nodes__ = [
    MTB_AudioSequence,
    MTB_AudioStack,
    MTB_AudioCut,
]