gastonduault commited on
Commit
aada626
1 Parent(s): a15daf1

update predict example

Browse files
Files changed (2) hide show
  1. README.md +31 -7
  2. predict-example.py +15 -30
README.md CHANGED
@@ -28,18 +28,42 @@ You can find a **GitHub** repository with an interface hosted by a Flask API to
28
  ## Example Usage
29
  ```python
30
  from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
 
31
  import torch
32
 
33
- # Load model and feature extractor
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  model = Wav2Vec2ForSequenceClassification.from_pretrained("gastonduault/music-classifier")
35
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")
36
 
37
- # Process audio file
38
- audio_path = "path/to/audio.wav"
39
- audio_input = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
 
 
 
 
 
 
 
40
 
41
  # Predict
42
  with torch.no_grad():
43
- logits = model(audio_input["input_values"])
44
- predicted_class = torch.argmax(logits.logits, dim=-1)
45
- print(predicted_class)
 
 
 
 
28
  ## Example Usage
29
  ```python
30
  from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
31
+ import librosa
32
  import torch
33
 
34
+ # Genre mapping corrected to a dictionary
35
+ genre_mapping = {
36
+ 0: "Electronic",
37
+ 1: "Rock",
38
+ 2: "Punk",
39
+ 3: "Experimental",
40
+ 4: "Hip-Hop",
41
+ 5: "Folk",
42
+ 6: "Chiptune / Glitch",
43
+ 7: "Instrumental",
44
+ 8: "Pop",
45
+ 9: "International",
46
+ }
47
+
48
  model = Wav2Vec2ForSequenceClassification.from_pretrained("gastonduault/music-classifier")
49
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")
50
 
51
+ # Function for preprocessing audio for prediction
52
+ def preprocess_audio(audio_path):
53
+ audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
54
+ return feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
55
+
56
+ # Path to your audio file
57
+ audio_path = "./Nirvana - Come As You Are.wav"
58
+
59
+ # Preprocess audio
60
+ inputs = preprocess_audio(audio_path)
61
 
62
  # Predict
63
  with torch.no_grad():
64
+ logits = model(**inputs).logits
65
+ predicted_class = torch.argmax(logits, dim=-1).item()
66
+
67
+ # Output the result
68
+ print(f"song analized:{audio_path}")
69
+ print(f"Predicted genre: {genre_mapping[predicted_class]}")
predict-example.py CHANGED
@@ -1,47 +1,32 @@
1
  from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
2
- from datasets import load_dataset
3
- import numpy as np
4
  import librosa
5
  import torch
6
 
7
- # Paths
8
- MODEL_DIR = "./wav2vec_trained_model"
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Load the dataset
11
- dataset = load_dataset("lewtun/music_genres_small")
12
-
13
- # Retrieve the label names
14
- genre_mapping = {}
15
- for example in dataset["train"]:
16
- genre_id = example["genre_id"]
17
- genre = example["genre"]
18
- if genre_id not in genre_mapping:
19
- genre_mapping[genre_id] = genre
20
- if len(genre_mapping) == 9:
21
- break
22
-
23
- print(f"Loading model from {MODEL_DIR}...\n")
24
  model = Wav2Vec2ForSequenceClassification.from_pretrained("gastonduault/music-classifier")
25
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")
26
 
27
  # Function for preprocessing audio for prediction
28
- def preprocess_audio(audio_path, target_length=16000 * 180): # 30 seconds at 16kHz
29
  audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
30
-
31
- if len(audio_array) > target_length:
32
- audio_array = audio_array[:target_length]
33
- else:
34
- padding = target_length - len(audio_array)
35
- audio_array = np.pad(audio_array, (0, padding), "constant")
36
-
37
- inputs = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
38
- return inputs
39
-
40
 
41
  # Path to your audio file
42
  audio_path = "./Nirvana - Come As You Are.wav"
43
 
44
-
45
  # Preprocess audio
46
  inputs = preprocess_audio(audio_path)
47
 
 
1
  from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
 
 
2
  import librosa
3
  import torch
4
 
5
+ # Genre mapping corrected to a dictionary
6
+ genre_mapping = {
7
+ 0: "Electronic",
8
+ 1: "Rock",
9
+ 2: "Punk",
10
+ 3: "Experimental",
11
+ 4: "Hip-Hop",
12
+ 5: "Folk",
13
+ 6: "Chiptune / Glitch",
14
+ 7: "Instrumental",
15
+ 8: "Pop",
16
+ 9: "International",
17
+ }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  model = Wav2Vec2ForSequenceClassification.from_pretrained("gastonduault/music-classifier")
20
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")
21
 
22
  # Function for preprocessing audio for prediction
23
+ def preprocess_audio(audio_path):
24
  audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
25
+ return feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
 
 
 
 
 
 
 
 
 
26
 
27
  # Path to your audio file
28
  audio_path = "./Nirvana - Come As You Are.wav"
29
 
 
30
  # Preprocess audio
31
  inputs = preprocess_audio(audio_path)
32