Spaces:
Paused
Paused
File size: 2,608 Bytes
cf32bd5 7138209 cf32bd5 327e3b5 cf32bd5 327e3b5 cf32bd5 b2c1876 cf32bd5 327e3b5 cf32bd5 b2c1876 cf32bd5 b2c1876 cf32bd5 327e3b5 cf32bd5 bbb642a cf32bd5 eb6c3f3 cf32bd5 bbb642a b5778d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# Gaepago model V1 (CPU Test)
# import package
from transformers import AutoModelForAudioClassification
from transformers import AutoFeatureExtractor
from transformers import pipeline
from datasets import Dataset, Audio
import gradio as gr
import torch
from utils.postprocess import text_mapping
import json
# Identifiers of the Hugging Face model repo, the training dataset and the
# JSON file that holds the label -> text table.
MODEL_NAME = "Gae8J/gaepago-20"
DATASET_NAME = "Gae8J/modeling_v1"  # not referenced by the code visible in this file
TEXT_LABEL = "text_label.json"
# TorchScript archive that is served instead of the full Hugging Face model
# (the file name suggests an int8-quantized export).
MODEL_PATH = "./model/gaepago-20-lite/model_quant_int8.pt"

# Load the model config and the feature extractor from the hub.
# The TorchScript module replaces
# AutoModelForAudioClassification.from_pretrained(MODEL_NAME); because it has
# no `.config`, the id2label table is taken from the separately loaded config.
from transformers import AutoConfig
config = AutoConfig.from_pretrained(MODEL_NAME)
# map_location="cpu" remaps every stored tensor to the CPU while loading, so the
# archive also loads on hosts without the device it was saved from.
model = torch.jit.load(MODEL_PATH, map_location="cpu")
# The module is only used for inference.
model.eval()
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)

# Load the text labels used by the post-processing step.
with open(TEXT_LABEL, "r", encoding="utf-8") as f:
    text_label = json.load(f)
# Gaepago Inference Model function
def gaepago_fn(tmp_audio_dir):
    """Classify one audio recording and return the text mapped to its label.

    Args:
        tmp_audio_dir: Path of the audio file to classify (the Gradio audio
            component is configured with ``type="filepath"``). ``None`` or an
            empty string means that nothing was recorded.

    Returns:
        The value produced by ``text_mapping`` for the predicted label, or an
        empty string when no audio was supplied.
    """
    print(tmp_audio_dir)
    # Gradio passes None when the button is pressed without a recording;
    # return an empty result instead of failing while decoding.
    if not tmp_audio_dir:
        return ""

    audio_dataset = Dataset.from_dict({"audio": [tmp_audio_dir]}).cast_column(
        "audio", Audio(sampling_rate=16000)
    )
    # Index the row once: every row access decodes and resamples the file,
    # so reusing the decoded sample avoids doing that work twice.
    sample = audio_dataset[0]["audio"]
    inputs = feature_extractor(
        sample["array"],
        sampling_rate=sample["sampling_rate"],
        return_tensors="pt",
    )

    with torch.no_grad():
        # The TorchScript module's output is indexed by key, not attribute.
        logits = model(**inputs)["logits"]

    # A single clip is scored per call, so the arg-max over the class axis
    # yields exactly one class id.
    predicted_class_ids = torch.argmax(logits, dim=-1).item()
    predicted_label = config.id2label[predicted_class_ids]

    # Post-processing: map the predicted label to its display text.
    output = text_mapping(predicted_label, text_label)
    return output
# Main
# Sample recordings offered as clickable examples in the UI.
example_list = [
    "./sample/bark_sample.wav",
    "./sample/growling_sample.wav",
    "./sample/howl_sample.wav",
    "./sample/panting_sample.wav",
    "./sample/whimper_sample.wav",
]

main_api = gr.Blocks()

# NOTE(review): the nesting below is reconstructed (the original indentation was
# lost); event wiring and examples must be created inside the Blocks context.
with main_api:
    gr.Markdown("## 8J Gaepago Demo(with CPU)")
    with gr.Row():
        # Label: "Press the record button and let us hear what Choco is saying"
        audio = gr.Audio(
            source="microphone",
            type="filepath",
            label='녹음버튼을 눌러 초코가 하는 말을 들려주세요',
        )
        # Label: "What Choco is saying right now is..."
        transcription = gr.Textbox(label='지금 초코가 하는 말은...')
    # Button text: "Translate dog language!"
    b1 = gr.Button("강아지 언어 번역!")

    # Run the classifier on the recorded file and show the mapped text.
    b1.click(gaepago_fn, inputs=audio, outputs=transcription)
    examples = gr.Examples(examples=example_list, inputs=[audio])

main_api.launch()