rmayormartins commited on
Commit
77c34b5
1 Parent(s): a3414e2

Subindo arquivos7

Browse files
Files changed (2) hide show
  1. app.py +15 -12
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,12 +1,15 @@
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
- from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
 
5
 
6
- #
7
  model_name = "results"
8
  processor = Wav2Vec2Processor.from_pretrained(model_name)
9
- model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, from_tf=False, from_flax=False, from_safetensors=True)
 
 
10
 
11
  def classify_accent(audio):
12
  if audio is None:
@@ -19,40 +22,40 @@ def classify_accent(audio):
19
  print(f"Entrada de audio recibida: {audio}")
20
 
21
  try:
22
- audio_array = audio[1] #
23
- sample_rate = audio[0] #
24
 
25
  print(f"Forma del audio: {audio_array.shape}, Frecuencia de muestreo: {sample_rate}")
26
 
27
- #
28
  audio_array = audio_array.astype(np.float32)
29
 
30
- #
31
  if sample_rate != 16000:
32
  import librosa
33
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
34
 
35
  input_values = processor(audio_array, return_tensors="pt", sampling_rate=16000).input_values
36
 
37
- # Infer
38
  with torch.no_grad():
39
  logits = model(input_values).logits
40
  predicted_ids = torch.argmax(logits, dim=-1).item()
41
 
42
- #
43
  labels = ["Español", "Otro"]
44
  return labels[predicted_ids]
45
 
46
  except Exception as e:
47
  return f"Error al procesar el audio: {str(e)}"
48
 
49
- #
50
  description_html = """
51
  <p>Prueba con grabación o cargando un archivo de audio. Para probar, recomiendo una palabra.</p>
52
- <p>Ramon Mayor Martins, Ph.D.: <a href="https://rmayormartins.github.io/" target="_blank">Website</a> | <a href="https://huggingface.co/rmayormartins" target="_blank">Spaces</a></p>
53
  """
54
 
55
- #
56
  interface = gr.Interface(
57
  fn=classify_accent,
58
  inputs=gr.Audio(type="numpy"),
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
+ from transformers import Wav2Vec2Processor
5
+ from safetensors.torch import load_file
6
 
7
+ # Carregar o modelo e o processador salvos
8
  model_name = "results"
9
  processor = Wav2Vec2Processor.from_pretrained(model_name)
10
+
11
+ # Carregar o modelo do arquivo safetensors
12
+ model = load_file("results/model.safetensors")
13
 
14
  def classify_accent(audio):
15
  if audio is None:
 
22
  print(f"Entrada de audio recibida: {audio}")
23
 
24
  try:
25
+ audio_array = audio[1] # O áudio da tupla
26
+ sample_rate = audio[0] # A taxa de amostragem da tupla
27
 
28
  print(f"Forma del audio: {audio_array.shape}, Frecuencia de muestreo: {sample_rate}")
29
 
30
+ # Converter o áudio para float32
31
  audio_array = audio_array.astype(np.float32)
32
 
33
+ # Resample para 16kHz, se necessário
34
  if sample_rate != 16000:
35
  import librosa
36
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
37
 
38
  input_values = processor(audio_array, return_tensors="pt", sampling_rate=16000).input_values
39
 
40
+ # Inferência
41
  with torch.no_grad():
42
  logits = model(input_values).logits
43
  predicted_ids = torch.argmax(logits, dim=-1).item()
44
 
45
+ # IDs de sotaque
46
  labels = ["Español", "Otro"]
47
  return labels[predicted_ids]
48
 
49
  except Exception as e:
50
  return f"Error al procesar el audio: {str(e)}"
51
 
52
+ # Interface do Gradio
53
  description_html = """
54
  <p>Prueba con grabación o cargando un archivo de audio. Para probar, recomiendo una palabra.</p>
55
+ <p>Ramon Mayor Martins: <a href="https://rmayormartins.github.io/" target="_blank">Website</a> | <a href="https://huggingface.co/rmayormartins" target="_blank">Spaces</a></p>
56
  """
57
 
58
+ # Interface do Gradio
59
  interface = gr.Interface(
60
  fn=classify_accent,
61
  inputs=gr.Audio(type="numpy"),
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
- gradio==4.29.0
2
  torch==2.0.1
3
  numpy==1.23.5
4
  transformers==4.24.0
5
  librosa==0.9.2
 
 
1
+ gradio==4.12.0
2
  torch==2.0.1
3
  numpy==1.23.5
4
  transformers==4.24.0
5
  librosa==0.9.2
6
+ safetensors==0.2.9