All-In-One Metrical And Functional Structure Analysis With Neighborhood Attentions on Demixed Audio
Paper • 2307.16425 • Published
A Temporal Convolutional Network (TCN) for detecting beats in music audio. Designed for use in video editing apps where users cut video to music beats.
import torch
import torchaudio.transforms as T
import numpy as np
import soundfile as sf
from huggingface_hub import hf_hub_download
# Download model
model_path = hf_hub_download("finnvoorhees/beat-detection-tcn", "model.pt")
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
config = checkpoint["config"]
# Load audio
audio, sr = sf.read("song.wav") # or use librosa, torchaudio, etc.
if audio.ndim > 1:
    audio = audio.mean(axis=1)  # downmix multi-channel audio to mono
# Resample to the model's sample rate (22050 Hz) if needed
if sr != config["sample_rate"]:
    import librosa
    audio = librosa.resample(audio, orig_sr=sr, target_sr=config["sample_rate"])
# Compute mel spectrogram
waveform = torch.from_numpy(audio).float().unsqueeze(0)
mel_transform = T.MelSpectrogram(
    sample_rate=config["sample_rate"], n_fft=config["n_fft"],
    hop_length=config["hop_length"], n_mels=config["n_mels"],
    f_min=config["fmin"], f_max=config["fmax"], power=2.0
)
mel = torch.log1p(mel_transform(waveform)) # (1, n_mels, n_frames)
# Load model (see train_beat_detector.py for BeatTCN class definition)
from train_beat_detector import BeatTCN, Config
cfg = Config()
for k, v in config.items():
    setattr(cfg, k, v)
model = BeatTCN(cfg)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
# Predict
with torch.no_grad():
    logits = model(mel.unsqueeze(0) if mel.dim() == 2 else mel)
activations = torch.sigmoid(logits).squeeze().numpy()
# Peak pick beats
fps = config["sample_rate"] / config["hop_length"]
threshold, min_interval = 0.3, 0.2
beats = []
i = 0
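while i < len(activations):
    if activations[i] >= threshold:
        # Take the strongest frame within min_interval of the first crossing
        window_end = min(i + int(min_interval * fps), len(activations))
        peak_idx = i + np.argmax(activations[i:window_end])
        beats.append(peak_idx / fps)
        # Skip ahead so consecutive beats are at least min_interval apart
        i = peak_idx + int(min_interval * fps)
    else:
        i += 1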
print(f"Detected {len(beats)} beats: {beats[:10]}")
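The peak-picking loop above can be wrapped in a function so it can be checked on synthetic activations without loading the model. This is a minimal sketch: the name `pick_beats` is hypothetical and not part of the repository, and the `max(1, ...)` guard is an addition that keeps the window non-empty at very low frame rates.

```python
import numpy as np

def pick_beats(activations, fps, threshold=0.3, min_interval=0.2):
    """Return beat times in seconds (hypothetical helper mirroring the loop above)."""
    activations = np.asarray(activations, dtype=float)
    # Minimum beat spacing in frames; at least 1 so the loop always advances (added guard)
    min_frames = max(1, int(min_interval * fps))
    beats = []
    i = 0
    while i < len(activations):
        if activations[i] >= threshold:
            # Strongest frame within min_interval of the first threshold crossing
            window_end = min(i + min_frames, len(activations))
            peak_idx = i + int(np.argmax(activations[i:window_end]))
            beats.append(peak_idx / fps)
            i = peak_idx + min_frames
        else:
            i += 1
    return beats

# Synthetic check: activation bumps every 0.5 s at 100 fps
fps = 100.0
act = np.zeros(300)
for center in (50, 100, 150, 200, 250):
    act[center - 1:center + 2] = [0.4, 0.9, 0.4]
print(pick_beats(act, fps))  # [0.5, 1.0, 1.5, 2.0, 2.5]
```

Raising `threshold` drops weaker peaks, and raising `min_interval` enforces wider spacing between reported beats.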
| Property | Value |
|---|---|
| Architecture | Temporal Convolutional Network (TCN) with dilated convolutions |
| Parameters | 268,817 (~269K; lightweight, suitable for edge/mobile deployment) |
| Input | Log-mel spectrogram (81 mel bands, 100 fps) |
| Output | Frame-level beat activation probabilities |
| Sample Rate | 22050 Hz |
| Inference Speed | 6.9 ms for 30 s of audio (4,374x real-time on a T4 GPU) |
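To make the table's figures concrete, the sketch below derives the approximate input size for a 30 s clip and the real-time factor from the reported latency. Variable names are illustrative. Because 6.9 ms is a rounded value, the factor computed here (about 4,348x) differs slightly from the reported 4,374x, and the exact frame count depends on the spectrogram's padding settings.

```python
# Figures taken from the table above
fps = 100                 # frames per second of the log-mel input
n_mels = 81               # mel bands
audio_seconds = 30.0      # clip length used for the timing
latency_seconds = 6.9e-3  # reported (rounded) inference time

# Approximate input size; padding may add a frame or so in practice
n_frames = int(audio_seconds * fps)
input_shape = (1, n_mels, n_frames)  # (batch, n_mels, n_frames)

# Real-time factor: seconds of audio processed per second of compute
real_time_factor = audio_seconds / latency_seconds

print(input_shape)
print(f"{real_time_factor:.0f}x real-time")
```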
Evaluated on synthetic test tracks spanning diverse tempos and noise conditions. The metric is F-measure with a 70 ms tolerance window, the standard MIR evaluation metric for beat tracking.
| Test Case | F-measure | Predicted beats | Reference beats |
|---|---|---|---|
| Slow (60 BPM) | 0.966 | 14 | 15 |
| Moderate (80 BPM) | 0.974 | 19 | 20 |
| Medium (100 BPM) | 0.941 | 26 | 25 |
| Standard (120 BPM) | 0.949 | 29 | 30 |
| Fast (140 BPM) | 0.955 | 32 | 35 |
| Very fast (160 BPM) | 0.933 | 35 | 40 |
| Extreme (180 BPM) | 0.816 | 31 | 45 |
| Maximum (200 BPM) | 0.795 | 33 | 50 |
| 120 BPM + heavy noise | 0.889 | 33 | 30 |
| 120 BPM + extreme noise | 0.517 | 59 | 30 |
| 90 BPM + moderate noise | 0.933 | 22 | 23 |
| 75 BPM (hip-hop) | 0.974 | 20 | 19 |
| 128 BPM (EDM) | 0.951 | 29 | 32 |
| 150 BPM (drum & bass) | 0.959 | 35 | 38 |
| Average | 0.897 | | |
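For readers unfamiliar with the metric, the following is a simplified sketch of a beat F-measure with a 70 ms tolerance. It uses greedy one-to-one matching and the function name `beat_f_measure` is hypothetical; it is not the evaluation code used for the table, and reference implementations such as `mir_eval.beat.f_measure` match beats more carefully.

```python
import numpy as np

def beat_f_measure(reference, estimated, tolerance=0.07):
    """F-measure with greedy one-to-one matching within +/- tolerance seconds (simplified)."""
    reference = np.sort(np.asarray(reference, dtype=float))
    estimated = np.sort(np.asarray(estimated, dtype=float))
    if len(reference) == 0 or len(estimated) == 0:
        return 0.0
    used = np.zeros(len(estimated), dtype=bool)
    hits = 0
    for ref_time in reference:
        # Closest estimate that has not already been matched
        dist = np.abs(estimated - ref_time)
        dist[used] = np.inf
        j = int(np.argmin(dist))
        if dist[j] <= tolerance:
            used[j] = True
            hits += 1
    if hits == 0:
        return 0.0
    precision = hits / len(estimated)
    recall = hits / len(reference)
    return 2 * precision * recall / (precision + recall)

# Synthetic example: 30 reference beats at 120 BPM, 29 detected with a 20 ms offset
ref = np.arange(30) * 0.5
est = ref[:29] + 0.02
print(round(beat_f_measure(ref, est), 3))  # 0.983
```

A missed beat lowers recall and a spurious beat lowers precision, which is why rows where the predicted count is far from the reference count score lower.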
On a 120 BPM test track, the detected beats gave a tempo estimate of 120.3 BPM (0.25% error), computed from the median inter-beat interval.
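The median is a reasonable choice here because a single missed or extra beat changes only one or two intervals and leaves the median untouched. A minimal sketch on synthetic beat times follows; the helper name `estimate_bpm` is hypothetical.

```python
import numpy as np

def estimate_bpm(beat_times):
    """Tempo in BPM from the median inter-beat interval (hypothetical helper)."""
    intervals = np.diff(np.asarray(beat_times, dtype=float))
    if len(intervals) == 0:
        raise ValueError("need at least two beats to estimate tempo")
    return 60.0 / float(np.median(intervals))

# Synthetic 120 BPM beats with one missed beat (the 1.0 s gap after 1.5 s)
beats = [0.0, 0.5, 1.0, 1.5, 2.5, 3.0, 3.5]
bpm = estimate_bpm(beats)
print(bpm)  # 120.0

# Relative error of the reported estimate against the 120 BPM ground truth
reported_error_pct = abs(120.3 - 120.0) / 120.0 * 100.0
print(round(reported_error_pct, 2))  # 0.25
```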
Training hyperparameters, based on published beat tracking research:
| Hyperparameter | Value |
|---|---|
| Optimizer | RAdam |
| Learning rate | 5e-4 |
| Weight decay | 2.5e-4 |
| LR Schedule | ReduceLROnPlateau (factor=0.3, patience=8) |
| Loss | Weighted BCE (pos_weight=20) |
| TCN layers | 11 |
| TCN channels | 48 |
| Kernel size | 5 |
| Dilation pattern | 2^l (layer l), doubled for second conv |
| Dropout | 0.15 |
| Batch size | 16 |
| Excerpt length | 8 seconds |
| Augmentation | Time stretch, random gain, noise injection |
| Early stopping | Patience 15 epochs |
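The positive weight in the loss compensates for how sparse beat frames are: at 100 fps and 120 BPM, only about 2 in every 100 frames carry a beat label. The sketch below shows a weighted binary cross-entropy on logits in NumPy, assuming the same definition of `pos_weight` as PyTorch's `BCEWithLogitsLoss`; the training script may implement the loss differently, and the function name is hypothetical.

```python
import numpy as np

def weighted_bce_with_logits(logits, targets, pos_weight=20.0):
    """Mean BCE on logits with positive (beat) frames up-weighted (illustrative sketch)."""
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets, dtype=float)
    # Numerically stable log(sigmoid(x)) and log(1 - sigmoid(x))
    log_sig = -np.logaddexp(0.0, -logits)
    log_one_minus_sig = -np.logaddexp(0.0, logits)
    loss = -(pos_weight * targets * log_sig + (1.0 - targets) * log_one_minus_sig)
    return float(loss.mean())

# One second at 100 fps with beats at frames 0 and 50 (120 BPM), uninformative logits
targets = np.zeros(100)
targets[[0, 50]] = 1.0
logits = np.zeros(100)

weighted = weighted_bce_with_logits(logits, targets, pos_weight=20.0)
unweighted = weighted_bce_with_logits(logits, targets, pos_weight=1.0)
print(weighted, unweighted)
```

With `pos_weight=20`, the two beat frames contribute 40 units of weight against 98 for the non-beat frames, so a model that always predicts "no beat" is penalized far more heavily than under the unweighted loss.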
Training data: 2,000 synthetic tracks with diverse patterns.
Training stopped at epoch 25 (early stopping) with a best validation F-measure of 0.9999.
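As an illustration of how patience-based early stopping behaves, here is a minimal sketch with a patience of 15 epochs. The `EarlyStopping` class and the validation curve are hypothetical; the curve is synthetic and is not the model's actual training history.

```python
class EarlyStopping:
    """Stop when the monitored score has not improved for `patience` epochs (sketch)."""

    def __init__(self, patience=15):
        self.patience = patience
        self.best = float("-inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, score, epoch):
        """Record one epoch's score; return True if training should stop."""
        if score > self.best:
            self.best = score
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Synthetic validation curve: improves for 10 epochs, then plateaus below the best
stopper = EarlyStopping(patience=15)
scores = [0.90 + 0.01 * e for e in range(10)] + [0.95] * 30
for epoch, score in enumerate(scores, start=1):
    if stopper.step(score, epoch):
        break
print(epoch, stopper.best_epoch)  # 25 10
```

In a real training loop, the checkpoint saved at `best_epoch` would be the one kept, since later epochs did not improve on it.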
# Detect beats for video cutting
beats = detect_beats("song.mp3", threshold=0.3)
# Filter to strong beats only (for dramatic cuts)
beats_strong = detect_beats("song.mp3", threshold=0.5)
# Get beats at minimum 0.5s intervals (slower cuts)
beats_slow = detect_beats("song.mp3", min_interval=0.5)
# Estimate BPM from detected beats
intervals = np.diff(beats)
bpm = 60.0 / np.median(intervals)
- model.pt: PyTorch model checkpoint (config + weights)
- train_beat_detector.py: Full training script (reproducible)
- beat_detector.py: Minimal inference script

License: MIT