Commit ad5370b by gaunernst (1 parent: 3058f72)

Update README.md

Files changed (1):
  1. README.md +10 -20
README.md CHANGED
@@ -23,34 +23,24 @@ A Vision Transformer (ViT) for audio. Pretrained on AudioSet-2M with Self-Superv
 ### Audio Classification and Embeddings
 
 ```python
-from urllib.request import urlopen
 import timm
-from torchaudio.compliance import kaldi
 import torch
-
-# TODO: change this to audio
-img = Image.open(urlopen(
-    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
-))
+import torch.nn.functional as F
+from torchaudio.compliance import kaldi
 
 # NOTE: for timm<0.9.11, you also need to pass `global_pool='avg'`
 # if only embeddings are needed, pass `num_classes=0`
 model = timm.create_model("hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k", pretrained=True)
 model = model.eval()
 
-# TODO: HF preprocessor (AST)
-audio = torch.randn(1, 10 * 16_000)
-melspec = kaldi.fbank(
-    audio,
-    htk_compat=True,
-    sample_frequency=16_000,
-    use_energy=False,
-    window_type='hanning',
-    num_mel_bins=128,
-    dither=0.0,
-    frame_shift=10,
-) # shape (n_frames, 128)
-melspec = melspec[:1024] # AudioMAE only accepts 1024-frame input
+audio = torch.randn(1, 10 * 16_000) # make sure input is 16kHz
+melspec = kaldi.fbank(audio, htk_compat=True, window_type="hanning", num_mel_bins=128) # shape (n_frames, 128)
+
+# AudioMAE only accepts 1024-frame input
+if melspec.shape[0] < 1024:
+    melspec = F.pad(melspec, (0, 0, 0, 1024 - melspec.shape[0]))
+else:
+    melspec = melspec[:1024]
 
 melspec = melspec.view(1, 1, 1024, 128) # add batch dim and channel dim
 output = model(melspec)
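For reference, below is a minimal end-to-end sketch (not part of the commit) showing the two things the updated snippet's comments only hint at: loading a real recording instead of `torch.randn` noise, and passing `num_classes=0` to get embeddings rather than AudioSet logits. `audio.wav` is a hypothetical path, and resampling is assumed to be done with `torchaudio.functional.resample`.

```python
import timm
import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.compliance import kaldi

# Load a real recording ("audio.wav" is a placeholder path), downmix to mono,
# and resample to the 16kHz the model was trained on.
audio, sr = torchaudio.load("audio.wav")  # shape (n_channels, n_samples)
audio = audio.mean(0, keepdim=True)       # downmix to mono
audio = torchaudio.functional.resample(audio, sr, 16_000)

# num_classes=0 removes the classifier head, so the forward pass
# returns the pooled embedding instead of logits.
model = timm.create_model(
    "hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k",
    pretrained=True,
    num_classes=0,
).eval()

melspec = kaldi.fbank(audio, htk_compat=True, window_type="hanning", num_mel_bins=128)
if melspec.shape[0] < 1024:  # pad or truncate to the 1024 frames AudioMAE expects
    melspec = F.pad(melspec, (0, 0, 0, 1024 - melspec.shape[0]))
else:
    melspec = melspec[:1024]

with torch.inference_mode():
    emb = model(melspec.view(1, 1, 1024, 128))  # shape (1, 768) for this ViT-B
```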