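"""Gradio demo: speech emotion recognition with a Wav2Vec2 model
(Dpngtm/wav2vec2-emotion-recognition).

Takes a recorded or uploaded audio clip of up to 60 seconds, resamples it to
16 kHz mono, and displays softmax confidences for eight emotion classes.
"""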
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
import torchaudio

# Define emotion labels and corresponding icons
emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]
emotion_icons = {
    "angry": "😠", "calm": "😌", "disgust": "🀒", "fearful": "😨",
    "happy": "😊", "neutral": "😐", "sad": "😒", "surprised": "😲"
}

# Load model and processor (num_labels belongs to the model config, not the processor)
model_name = "Dpngtm/wav2vec2-emotion-recognition"
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, num_labels=len(emotion_labels))
processor = Wav2Vec2Processor.from_pretrained(model_name)

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

def recognize_emotion(audio):
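    """Classify the emotion in a recorded or uploaded audio clip.

    `audio` is the filepath handed over by the Gradio Audio component
    (type="filepath"). Returns a dict mapping "<emotion> <icon>" labels to
    softmax confidences, sorted in descending order. Missing input, clips
    longer than 60 seconds, and processing errors return zeroed scores
    (with an error label where applicable).
    """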
    try:
        # Handle case where no audio is provided
        if audio is None:
            return {f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}
        
        # Load and preprocess the audio
        audio_path = audio if isinstance(audio, str) else audio.name
        speech_array, sampling_rate = torchaudio.load(audio_path)
        
        # Limit audio length to 1 minute (60 seconds)
        duration = speech_array.shape[1] / sampling_rate
        if duration > 60:
            # gr.Label expects numeric confidences, so report the error as a label key
            return {
                "Error: audio too long (max 1 minute)": 1.0,
                **{f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}
            }
        
        # Resample audio if not at 16kHz
        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
            speech_array = resampler(speech_array)
        
        # Convert stereo to mono if necessary
        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)
        
        # Normalize audio to [-1, 1]; guard against division by zero on silent clips
        peak = torch.max(torch.abs(speech_array))
        if peak > 0:
            speech_array = speech_array / peak
        speech_array = speech_array.squeeze().numpy()
        
        # Process audio with the model
        inputs = processor(speech_array, sampling_rate=16000, return_tensors='pt', padding=True)
        input_values = inputs.input_values.to(device)
        
        with torch.no_grad():
            outputs = model(input_values)
            logits = outputs.logits
            probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
            
            # Prepare the confidence scores without converting to percentages
            confidence_scores = {
                f"{emotion} {emotion_icons[emotion]}": prob
                for emotion, prob in zip(emotion_labels, probs)
            }
            
            # Sort scores in descending order
            sorted_scores = dict(sorted(confidence_scores.items(), key=lambda x: x[1], reverse=True))
            return sorted_scores
    
    except Exception as e:
        # Report the error as a label key (gr.Label needs numeric confidence values)
        return {
            f"Error: {e}": 1.0,
            **{f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}
        }

# Supported emotions for display
supported_emotions = " | ".join([f"{emotion_icons[emotion]} {emotion}" for emotion in emotion_labels])

# Gradio Interface setup
interface = gr.Interface(
    fn=recognize_emotion,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Record or Upload Audio"
    ),
    outputs=gr.Label(
        num_top_classes=len(emotion_labels),
        label="Detected Emotion"
    ),
    title="Speech Emotion Recognition",
    description=f"""
    ### Supported Emotions:
    {supported_emotions}
    
    Maximum audio length: 1 minute""",
    theme=gr.themes.Soft(
        primary_hue="orange",
        secondary_hue="blue"
    ),
    css="""
        .gradio-container {max-width: 800px}
        .label {font-size: 18px}
    """
)

if __name__ == "__main__":
    interface.launch(
        share=True,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860
    )