File size: 1,253 Bytes
c165076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2ef383
c165076
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio
import torchaudio
from fastai.vision.all import *
from fastai.learner import load_learner
from torchvision.utils import save_image
from huggingface_hub import hf_hub_download


# Download the exported learner file from the Hugging Face Hub (the call
# returns a local path) and load it.
# NOTE(review): the file is a ".pkl"; loading it presumably unpickles
# arbitrary objects, so the Hub repository must be trusted — confirm.
model = load_learner(
    hf_hub_download(
        repo_id="kurianbenoy/music_genre_classification_baseline",
        filename="model.pkl",
    )
)

# Directory whose files are offered as clickable examples in the UI.
EXAMPLES_PATH = Path("./examples")

# Class vocabulary of the loaded learner; `predict` pairs it by index with
# the probabilities returned by the model.
labels = model.dls.vocab

# Keyword arguments forwarded verbatim to `gradio.Interface`.
interface_options = dict(
    title="Music Genre Classification",
    description="A simple baseline model for classifying music genres with fast.ai on [Kaggle competition data](https://www.kaggle.com/competitions/kaggle-pog-series-s01e02/data)",
    # One "<examples dir>/<file name>" string per entry found in EXAMPLES_PATH.
    examples=[f"{EXAMPLES_PATH}/{entry.name}" for entry in EXAMPLES_PATH.iterdir()],
    interpretation="default",
    layout="horizontal",
    theme="default",
)


def predict(img):
    """Run the model on one image and return a probability for every label.

    Args:
        img: The image supplied by the Gradio input component; it is passed
            through ``PILImage.create`` before prediction.

    Returns:
        A dict mapping each entry of the module-level ``labels`` to the
        ``float`` probability found at the same index of the model output.
    """
    picture = PILImage.create(img)
    # model.predict yields three values; only the last one (the per-class
    # probabilities) is needed here.
    _, _, probs = model.predict(picture)
    return {label: float(probs[idx]) for idx, label in enumerate(labels)}


# NOTE(review): `gradio.inputs`, `gradio.outputs` and the `enable_queue`
# launch flag look like the older Gradio API — confirm the pinned Gradio
# version before upgrading the dependency.

# Image input created with shape=(512, 512); label output showing at most
# the five top classes.
image_input = gradio.inputs.Image(shape=(512, 512))
label_output = gradio.outputs.Label(num_top_classes=5)

# Wire `predict` between the two components, adding the shared UI options.
demo = gradio.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    **interface_options,
)

# Keyword arguments forwarded verbatim to `demo.launch`.
launch_options = dict(enable_queue=True, share=False)

demo.launch(**launch_options)