File size: 758 Bytes
e6ec6e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
046a31b
e6ec6e1
a37a1e7
e6ec6e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import gradio as gr
from model import EmotiClassifier

# Build the network first; its trained weights are restored just below.
predictor = EmotiClassifier()

# Class names, indexed by the position of the highest model score
# (classify() looks up the argmax index in this list).
labels = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']

# map_location='cpu' lets the checkpoint deserialize on hosts without a GPU
# even if it was saved from CUDA tensors; the model is never moved off the
# CPU anywhere in this file.
predictor.load_state_dict(
    torch.load('emoticlassifier-64acc-1_250loss.pth', map_location='cpu')
)

# Put the model in inference mode so that layers such as dropout or
# batch-norm (if EmotiClassifier has any) stop using training-time behaviour.
predictor.eval()

def classify(image):
    """Predict the emotion label for a single grayscale frame.

    Args:
        image: 2-D array-like of pixel intensities (height x width),
            presumably the NumPy array produced by the Gradio image input,
            or ``None`` when no frame is available.

    Returns:
        The entry of ``labels`` whose model score is highest, or an empty
        string when ``image`` is ``None``.
    """
    # A streaming/cleared image input may invoke the callback without a
    # frame; there is nothing to classify, so return an empty result
    # instead of crashing inside the tensor conversion.
    if image is None:
        return ""

    # Explicit float32 conversion instead of the legacy torch.Tensor(...)
    # constructor, so the dtype handed to the model is unambiguous.
    pixels = torch.as_tensor(image, dtype=torch.float32)

    # Add batch and channel axes -> (1, 1, H, W). reshape() also accepts
    # non-contiguous input, where view() would raise.
    batch = pixels.reshape(1, 1, pixels.shape[0], pixels.shape[1])

    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        scores = predictor(batch)

    # Single-item batch, so the flat argmax is the predicted class index.
    return labels[torch.argmax(scores).item()]

# Live webcam input: requests 48x48 frames in single-channel ('L') mode,
# streamed continuously to the prediction function.
# NOTE(review): `source`, `shape` and `tool` look like Gradio 3.x-era
# arguments — confirm against the Gradio version this app is pinned to.
webcam = gr.Image(
    source='webcam',
    streaming=True,
    shape=(48, 48),
    image_mode='L',
    interactive=True,
    tool="editor",
)

# Wire the classifier to the webcam feed; the predicted label is shown as text.
interface = gr.Interface(
    fn=classify,
    inputs=webcam,
    outputs='text',
)
interface.launch()