Beat Detection TCN

A Temporal Convolutional Network (TCN) for detecting beats in music audio, designed for video editing apps in which users cut video to the beat of a song.

Quick Start

import torch
import torchaudio.transforms as T
import numpy as np
import soundfile as sf
from huggingface_hub import hf_hub_download

# Download model
model_path = hf_hub_download("finnvoorhees/beat-detection-tcn", "model.pt")
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
config = checkpoint["config"]

# Load audio
audio, sr = sf.read("song.wav")  # or use librosa, torchaudio, etc.
if audio.ndim > 1:
    audio = audio.mean(axis=1)

# Resample to 22050 Hz if needed
if sr != config["sample_rate"]:
    import librosa
    audio = librosa.resample(audio, orig_sr=sr, target_sr=config["sample_rate"])

# Compute mel spectrogram
waveform = torch.from_numpy(audio).float().unsqueeze(0)
mel_transform = T.MelSpectrogram(
    sample_rate=config["sample_rate"], n_fft=config["n_fft"],
    hop_length=config["hop_length"], n_mels=config["n_mels"],
    f_min=config["fmin"], f_max=config["fmax"], power=2.0
)
mel = torch.log1p(mel_transform(waveform))  # (1, n_mels, n_frames)

# Load model (see train_beat_detector.py for BeatTCN class definition)
from train_beat_detector import BeatTCN, Config
cfg = Config()
for k, v in config.items():
    setattr(cfg, k, v)
model = BeatTCN(cfg)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Predict
with torch.no_grad():
    logits = model(mel.unsqueeze(0) if mel.dim() == 2 else mel)
    activations = torch.sigmoid(logits).squeeze().numpy()

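# Peak-pick beats: at each threshold crossing, take the activation maximum
# within the next min_interval seconds, then skip ahead by min_interval
fps = config["sample_rate"] / config["hop_length"]  # activation frames per second
threshold, min_interval = 0.3, 0.2  # activation threshold; minimum beat spacing (s)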
beats = []
i = 0
while i < len(activations):
    if activations[i] >= threshold:
        window_end = min(i + int(min_interval * fps), len(activations))
        peak_idx = i + np.argmax(activations[i:window_end])
        beats.append(peak_idx / fps)
        i = peak_idx + int(min_interval * fps)
    else:
        i += 1

print(f"Detected {len(beats)} beats: {beats[:10]}")
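
The while-loop above is a greedy peak picker. A similar result can be sketched with SciPy's `scipy.signal.find_peaks`, using `height` as the activation threshold and `distance` (in frames) as the minimum beat spacing. This is a swapped-in alternative, not the method used in the author's scripts, and the two can differ slightly around threshold crossings. The function name `pick_beats` is hypothetical.

```python
import numpy as np
from scipy.signal import find_peaks


def pick_beats(activations: np.ndarray, fps: float,
               threshold: float = 0.3, min_interval: float = 0.2) -> np.ndarray:
    """Return beat times (seconds) from frame-level beat activations.

    Keeps local maxima whose height is at least `threshold` and that are
    at least `min_interval` seconds apart (hypothetical helper).
    """
    distance = max(1, int(min_interval * fps))  # find_peaks requires distance >= 1
    peaks, _ = find_peaks(activations, height=threshold, distance=distance)
    return peaks / fps


# Synthetic activations: one sharp peak every 50 frames at 100 fps (120 BPM)
fps = 100.0
acts = np.zeros(500)
acts[25::50] = 0.9
print(pick_beats(acts, fps))  # [0.25 0.75 1.25 1.75 2.25 2.75 3.25 3.75 4.25 4.75]
```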

Model Details

| Property | Value |
|---|---|
| Architecture | Temporal Convolutional Network (TCN) with dilated convolutions |
| Parameters | 268,817 (~269K; lightweight, suitable for edge/mobile deployment) |
| Input | Log-mel spectrogram (81 mel bands, 100 fps) |
| Output | Frame-level beat activation probabilities |
| Sample rate | 22050 Hz |
| Inference speed | 6.9 ms for 30 s of audio (4,374x real time on a T4 GPU) |
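
The card describes the architecture only at this level; the actual BeatTCN class is defined in train_beat_detector.py. As a rough illustration, assuming each TCN layer applies two dilated, non-causal 1-D convolutions in sequence (dilation 2^l, then 2^(l+1), as in the dilation pattern of the training recipe) with a residual connection, one layer might look like the sketch below. The class `DilatedResidualBlock`, the ELU activation, the sequential ordering of the two convolutions and the `receptive_field` helper are all hypothetical, not the author's code.

```python
import torch
import torch.nn as nn


class DilatedResidualBlock(nn.Module):
    """Hypothetical TCN layer: two dilated, non-causal 1-D convs plus a residual path."""

    def __init__(self, channels: int = 48, kernel_size: int = 5,
                 dilation: int = 1, dropout: float = 0.15):
        super().__init__()
        # "Same" padding keeps the frame count unchanged (kernel_size must be odd)
        pad1 = dilation * (kernel_size - 1) // 2
        pad2 = 2 * dilation * (kernel_size - 1) // 2
        self.conv1 = nn.Conv1d(channels, channels, kernel_size,
                               padding=pad1, dilation=dilation)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size,
                               padding=pad2, dilation=2 * dilation)
        self.act = nn.ELU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, frames)
        y = self.drop(self.act(self.conv1(x)))
        y = self.drop(self.act(self.conv2(y)))
        return x + y  # residual connection


def receptive_field(n_layers: int = 11, kernel_size: int = 5) -> int:
    """Receptive field (frames) of stacked blocks; layer l uses dilations 2**l and 2**(l+1)."""
    rf = 1
    for layer in range(n_layers):
        rf += (kernel_size - 1) * (2 ** layer + 2 ** (layer + 1))
    return rf


blocks = nn.Sequential(*[DilatedResidualBlock(dilation=2 ** layer) for layer in range(11)])
x = torch.randn(1, 48, 800)  # 8 s excerpt at 100 frames/s
print(blocks(x).shape)        # torch.Size([1, 48, 800])
print(receptive_field())      # 24565 frames, under these assumptions
```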

Evaluation Results

Evaluated on synthetic test tracks spanning diverse tempos and noise conditions. The metric is the F-measure with a ±70 ms tolerance window, the standard beat-tracking evaluation metric in music information retrieval (MIR).

| Test case | F-measure | Predicted beats | Reference beats |
|---|---|---|---|
| Slow (60 BPM) | 0.966 | 14 | 15 |
| Moderate (80 BPM) | 0.974 | 19 | 20 |
| Medium (100 BPM) | 0.941 | 26 | 25 |
| Standard (120 BPM) | 0.949 | 29 | 30 |
| Fast (140 BPM) | 0.955 | 32 | 35 |
| Very fast (160 BPM) | 0.933 | 35 | 40 |
| Extreme (180 BPM) | 0.816 | 31 | 45 |
| Maximum (200 BPM) | 0.795 | 33 | 50 |
| 120 BPM + heavy noise | 0.889 | 33 | 30 |
| 120 BPM + extreme noise | 0.517 | 59 | 30 |
| 90 BPM + moderate noise | 0.933 | 22 | 23 |
| 75 BPM (hip-hop) | 0.974 | 20 | 19 |
| 128 BPM (EDM) | 0.951 | 29 | 32 |
| 150 BPM (drum & bass) | 0.959 | 35 | 38 |
| Average | 0.897 | | |
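
The F-measure counts an estimated beat as a hit when it lies within ±70 ms of a not-yet-matched reference beat, and then F = 2·TP / (predicted + reference). A minimal NumPy sketch follows, using greedy one-to-one matching; `mir_eval.beat.f_measure` is a widely used reference implementation of this metric, and this sketch is not that code and may differ from it in edge cases. The function name `beat_f_measure` is hypothetical. The usage line reproduces the arithmetic of the "Slow (60 BPM)" row, where 14 predictions all hit 15 reference beats: 2·14 / (14 + 15) ≈ 0.966.

```python
import numpy as np


def beat_f_measure(reference, estimated, tol: float = 0.07) -> float:
    """F-measure for beat times (seconds) with a +/- tol window and one-to-one matching."""
    reference = np.sort(np.asarray(reference, dtype=float))
    estimated = np.sort(np.asarray(estimated, dtype=float))
    if reference.size == 0 or estimated.size == 0:
        return 0.0
    used = np.zeros(reference.size, dtype=bool)
    tp = 0
    for t in estimated:
        # Nearest reference beat that has not been matched yet
        dists = np.abs(reference - t)
        dists[used] = np.inf
        j = int(np.argmin(dists))
        if dists[j] <= tol:
            used[j] = True
            tp += 1
    return 2.0 * tp / (reference.size + estimated.size)


ref = np.arange(15, dtype=float)           # 60 BPM: one beat per second for 15 s
est = ref[:14] + 0.03                      # 14 predictions, each 30 ms late
print(round(beat_f_measure(ref, est), 3))  # 0.966
```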

BPM Estimation Accuracy

From a 120 BPM test track, the model estimated 120.3 BPM (0.25% error) via median beat interval.
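
For scale: at 120 BPM the beat period is 0.5 s, and a median interval of about 0.4988 s gives 60 / 0.4988 ≈ 120.3 BPM. A slightly more defensive version of the median-interval estimate guards against fewer than two beats; the function name `estimate_bpm` is hypothetical. Using the median rather than the mean keeps one missed or spurious beat from skewing the result.

```python
import numpy as np


def estimate_bpm(beat_times) -> float:
    """Tempo in BPM from the median inter-beat interval; NaN if fewer than two beats."""
    beats = np.sort(np.asarray(beat_times, dtype=float))
    if beats.size < 2:
        return float("nan")
    return 60.0 / float(np.median(np.diff(beats)))


beats = np.arange(0, 15, 0.5)  # ideal 120 BPM grid over 15 s
print(estimate_bpm(beats))     # 120.0
```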

Training Recipe

Based on published beat tracking research:

| Hyperparameter | Value |
|---|---|
| Optimizer | RAdam |
| Learning rate | 5e-4 |
| Weight decay | 2.5e-4 |
| LR schedule | ReduceLROnPlateau (factor=0.3, patience=8) |
| Loss | Weighted BCE (pos_weight=20) |
| TCN layers | 11 |
| TCN channels | 48 |
| Kernel size | 5 |
| Dilation pattern | 2^l for layer l; the second convolution in each layer uses twice that, 2^(l+1) |
| Dropout | 0.15 |
| Batch size | 16 |
| Excerpt length | 8 seconds |
| Augmentation | Time stretch, random gain, noise injection |
| Early stopping | Patience of 15 epochs |
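
In PyTorch, the optimizer, schedule and loss rows above map onto standard `torch.optim` and `torch.nn` classes roughly as follows. This is a sketch with a placeholder model standing in for BeatTCN. The `mode="max"` setting assumes the scheduler monitors validation F-measure, which the table does not state.

```python
import torch
import torch.nn as nn

# Placeholder for BeatTCN (hypothetical; the real class is in train_beat_detector.py)
model = nn.Conv1d(81, 1, kernel_size=1)

optimizer = torch.optim.RAdam(model.parameters(), lr=5e-4, weight_decay=2.5e-4)
# Assumption: the monitored metric is validation F-measure (higher is better)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.3, patience=8)
# pos_weight=20 up-weights the rare beat frames relative to non-beat frames
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0))

# One training step on a random batch of 16 eight-second excerpts (81 mels, 800 frames)
mel = torch.randn(16, 81, 800)
targets = (torch.rand(16, 1, 800) < 0.02).float()  # sparse random beat labels
optimizer.zero_grad()
loss = criterion(model(mel), targets)
loss.backward()
optimizer.step()

# After each validation pass, step the scheduler with the monitored metric
val_f_measure = 0.9  # placeholder value
scheduler.step(val_f_measure)
print(optimizer.param_groups[0]["lr"])  # 0.0005 until the metric plateaus
```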

Training Data: 2,000 synthetic tracks with diverse patterns:

  • 8 pattern types: click, kick, kick+snare, full drums, bass, chord stabs, layered production, sparse
  • Tempo range: 60-200 BPM with humanized timing
  • Background textures: pink noise, hum, pad, silence
  • Various noise levels and dynamics
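
The card does not show the data generator itself. As a rough illustration only, a humanized click track (the simplest pattern type above) and its frame-level beat targets might be produced as follows. The functions `make_click_track` and `beat_targets`, the 5 ms timing jitter, the decaying 1 kHz click, and the hop length of 220 samples (about 100 fps at 22050 Hz) are all assumptions, not the author's settings.

```python
import numpy as np


def make_click_track(bpm: float, duration: float, sr: int = 22050,
                     jitter_ms: float = 5.0, seed: int = 0):
    """Hypothetical generator: decaying 1 kHz clicks on a humanized beat grid.

    Returns (audio, beat_times) with beat_times in seconds.
    """
    rng = np.random.default_rng(seed)
    grid = np.arange(0.0, duration, 60.0 / bpm)
    # Humanized timing: small Gaussian offsets around the ideal grid
    beat_times = np.clip(grid + rng.normal(0.0, jitter_ms / 1000.0, grid.size),
                         0.0, duration - 0.05)
    audio = np.zeros(int(duration * sr), dtype=np.float32)
    t = np.arange(int(0.03 * sr)) / sr  # 30 ms click
    click = np.sin(2 * np.pi * 1000.0 * t) * np.exp(-t / 0.005)
    for bt in beat_times:
        start = int(round(bt * sr))
        end = min(start + click.size, audio.size)
        audio[start:end] += click[: end - start]
    return audio, beat_times


def beat_targets(beat_times, n_frames: int, sr: int = 22050,
                 hop_length: int = 220) -> np.ndarray:
    """Frame-level 0/1 targets: 1 at the frame nearest each beat."""
    frames = np.round(np.asarray(beat_times) * sr / hop_length).astype(int)
    frames = frames[(frames >= 0) & (frames < n_frames)]
    y = np.zeros(n_frames, dtype=np.float32)
    y[frames] = 1.0
    return y


audio, beats = make_click_track(bpm=120, duration=8.0)
n_frames = audio.size // 220 + 1  # frame count of a centered STFT
y = beat_targets(beats, n_frames)
print(len(beats), int(y.sum()))  # 16 16
```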

Training stopped early at epoch 25, with a best validation F-measure of 0.9999.

Usage for Video Apps

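import numpy as np
from beat_detector import detect_beats  # assumption: beat_detector.py (see Files) defines detect_beats()

# Detect beats for video cutting
beats = detect_beats("song.mp3", threshold=0.3)

# Strong beats only (for dramatic cuts): a higher threshold keeps fewer, more confident beats
beats_strong = detect_beats("song.mp3", threshold=0.5)

# Beats at least 0.5 s apart (slower cuts)
beats_slow = detect_beats("song.mp3", min_interval=0.5)

# Estimate BPM from the median inter-beat interval (requires at least two beats)
intervals = np.diff(beats)
bpm = 60.0 / np.median(intervals)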
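
A common next step in an editor is to move each user-placed cut to the nearest detected beat. A minimal sketch, assuming `beats` is a list of beat times in seconds (as the BPM calculation above implies); `snap_to_beats` and its `max_shift` parameter are hypothetical names, not part of this repository. Cuts with no beat within `max_shift` seconds stay where the user put them, so sparse sections of a song do not drag cuts far away.

```python
import numpy as np


def snap_to_beats(cut_times, beats, max_shift: float = 0.25) -> np.ndarray:
    """Move each cut to its nearest beat if that beat is within max_shift seconds."""
    cuts = np.asarray(cut_times, dtype=float)
    beats = np.asarray(beats, dtype=float)
    if cuts.size == 0 or beats.size == 0:
        return cuts.copy()
    # Distance from every cut to every beat; pick the closest beat per cut
    nearest = beats[np.argmin(np.abs(cuts[:, None] - beats[None, :]), axis=1)]
    return np.where(np.abs(nearest - cuts) <= max_shift, nearest, cuts)


beats = [0.5, 1.0, 1.5, 2.0, 2.5]
print(snap_to_beats([0.9, 1.62, 3.4], beats).tolist())  # [1.0, 1.5, 3.4]
```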

Files

  • model.pt: PyTorch model checkpoint (config + weights)
  • train_beat_detector.py: full training script (reproducible)
  • beat_detector.py: minimal inference script

License

MIT
