Spaces:
Running
Running
import json | |
import os | |
import random | |
import warnings | |
import gradio as gr | |
import librosa | |
import numpy as np | |
from datasets import IterableDatasetDict, load_dataset | |
from gradio_client import Client | |
from loguru import logger | |
# Silence all Python warnings for the whole process.
warnings.filterwarnings("ignore")

# Size and Hub location of the source dataset.
NUM_TAR_FILES = 115
NUM_SAMPLES = 3_746_131
HF_PATH_TO_DATASET = "litagin/Galgame_Speech_SER_16kHz"

# Token is read from the environment (None when HF_TOKEN is unset) and used to
# connect to the remote Gradio app that receives label submissions through its
# "/put_data" endpoint.
hf_token = os.environ.get("HF_TOKEN")
client = Client("litagin/ser_record", hf_token=hf_token)
# Plain English class names, indexed by the integer class id (0-9).
id2label = dict(
    enumerate(
        [
            "Angry",
            "Disgusted",
            "Embarrassed",
            "Fearful",
            "Happy",
            "Sad",
            "Surprised",
            "Neutral",
            "Sexual1",
            "Sexual2",
        ]
    )
)
id2rich_label = { | |
0: "😠 怒ってる (0)", | |
1: "😒 嫌悪・キモい・うんざり・気持ち悪い (1)", | |
2: "😳 恥ずかしい・戸惑ってる (2)", | |
3: "😨 怖がってる・おびえてる・ビクビクしてる (3)", | |
4: "😊 幸せ・楽しい・嬉しい (4)", | |
5: "😢 悲しい・切ない (5)", | |
6: "😲 驚いてる・びっくり (6)", | |
7: "😐 中立・平坦・普通 (7)", | |
8: "🥰 NSFW1 (8)", | |
9: "🍭 NSFW2 (9)", | |
} | |
current_item: dict | None = None | |
def _load_dataset(
    *,
    streaming: bool = True,
    use_local_dataset: bool = False,
    local_dataset_path: str | None = None,
    data_dir: str = "data",
) -> IterableDatasetDict:
    """Load the Galgame SER ``.tar`` shards as a dataset dict with a train split.

    Args:
        streaming: Forwarded to ``datasets.load_dataset``.
        use_local_dataset: Load from ``local_dataset_path`` instead of the Hub
            repository ``HF_PATH_TO_DATASET``.
        local_dataset_path: Dataset root; required when ``use_local_dataset``
            is true, ignored otherwise.
        data_dir: Sub-directory (relative to the dataset root) holding the
            shards.

    Returns:
        The loaded dataset with the ``__url__`` column removed and the
        ``ogg`` column renamed to ``audio``.

    Raises:
        ValueError: If ``use_local_dataset`` is true but no
            ``local_dataset_path`` is given.
    """
    # Shards are named with a six-digit zero-padded index: ...-000000.tar,
    # ...-000001.tar, ... (the same names the previous "000{:03d}" form built
    # for fewer than 1000 shards, but without breaking beyond that).
    data_files = {
        "train": [
            f"galgame-speech-ser-16kHz-train-{index:06d}.tar"
            for index in range(NUM_TAR_FILES)
        ],
    }
    if use_local_dataset:
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        if local_dataset_path is None:
            raise ValueError(
                "local_dataset_path must be provided when use_local_dataset=True"
            )
        path = local_dataset_path
    else:
        path = HF_PATH_TO_DATASET
    dataset: IterableDatasetDict = load_dataset(
        path=path, data_dir=data_dir, data_files=data_files, streaming=streaming
    )  # type: ignore
    dataset = dataset.remove_columns(["__url__"])
    dataset = dataset.rename_column("ogg", "audio")
    return dataset
# Load the streaming dataset once at import time; handlers draw items from it.
logger.info("Start loading dataset")
ds = _load_dataset(streaming=True, use_local_dataset=False)
logger.info("Dataset loaded")

# Shuffle with a freshly drawn 32-bit seed so each run presents items in a
# different order. randrange(2**32) covers exactly [0, 2**32 - 1].
seed = random.randrange(2**32)
logger.info(f"Seed: {seed}")
ds_iter = iter(ds["train"].shuffle(seed=seed))

# Number of items drawn from ds_iter; maintained by get_next_parsed_item.
counter = 0
# Keyboard shortcuts injected into the page <head>: "a"/"A" clicks the
# "label is fine" button (#btn_skip) and "0"-"9" click #btn_0 .. #btn_9.
# Keystrokes inside editable fields and modifier chords (e.g. Ctrl+A) are
# ignored so typing in a textbox or using browser shortcuts never submits
# a label by accident.
shortcut_js = """
<script>
function shortcuts(e) {
    const target = e.target;
    if (
        target &&
        (target.tagName === "INPUT" ||
            target.tagName === "TEXTAREA" ||
            target.isContentEditable)
    ) {
        return;
    }
    if (e.ctrlKey || e.metaKey || e.altKey) {
        return;
    }
    let buttonId = null;
    if (e.key === "a" || e.key === "A") {
        buttonId = "btn_skip";
    } else if (/^[0-9]$/.test(e.key)) {
        buttonId = "btn_" + e.key;
    }
    if (buttonId === null) {
        return;
    }
    const button = document.getElementById(buttonId);
    if (button) {
        button.click();
    }
}
document.addEventListener('keypress', shortcuts, false);
</script>
"""
def parse_item(item) -> dict:
    """Flatten one raw dataset row into the dict the UI handlers consume.

    Reads ``cls``, ``audio`` (its ``sampling_rate`` and ``array``),
    ``__key__`` and ``txt`` from ``item``. The returned ``"audio"`` value is
    a ``(sampling_rate, array)`` tuple, ``"label"`` is the display string
    from ``id2rich_label``, and ``"counter"`` is the current value of the
    module-level ``counter``.
    """
    cls_id = item["cls"]
    audio_info = item["audio"]
    audio_pair = (audio_info["sampling_rate"], audio_info["array"])
    return {
        "key": item["__key__"],
        "audio": audio_pair,
        "text": item["txt"],
        "label": id2rich_label[cls_id],
        "label_id": cls_id,
        "counter": counter,
    }
def get_next_parsed_item() -> dict:
    """Draw the next item from the shuffled stream and return it parsed.

    On success the module-level ``counter`` is incremented. When the
    iterator is exhausted, a new random seed is drawn, ``ds_iter`` is rebuilt
    from a fresh shuffle of ``ds["train"]``, and ``counter`` restarts at 1.
    A StopIteration from the freshly rebuilt iterator is not caught.
    """
    global counter, ds_iter
    logger.info("Getting next item")
    try:
        raw_item = next(ds_iter)
    except StopIteration:
        logger.info("StopIteration, re-initializing using new seed")
        new_seed = random.randint(0, 2**32 - 1)
        logger.info(f"New Seed: {new_seed}")
        ds_iter = iter(ds["train"].shuffle(seed=new_seed))
        raw_item = next(ds_iter)
        counter = 1
    else:
        counter += 1
    parsed = parse_item(raw_item)
    logger.info(
        f"Next item:\nkey={parsed['key']}\ntext={parsed['text']}\nlabel={parsed['label']}"
    )
    return parsed
# Japanese Markdown instructions rendered at the top of the app. The example
# label names must match the button captions in id2rich_label (the sad label
# is "😢 悲しい・切ない", so the example refers to it as `😢 悲しい`).
md = """
# 説明
- **性的な音声が含まれるため、18歳未満の方はご利用をお控えください**
- このアプリは [このゲームのセリフ音声データセット](https://huggingface.co/datasets/litagin/Galgame_Speech_SER_16kHz) の感情ラベルを修正して、大規模で高品質な感情音声データセットを作成するためのものです
- 「**何を言っているか**」ではなく「**どのように言っているか**」に注目して、感情ラベルを付与してください(例: 悲しそうに「とっても楽しいです」と言っていたら、 `😊 幸せ` ではなく `😢 悲しい` とする)
- 既存のラベルが適切であれば、そのまま「現在の感情ラベルで適切」ボタンを押してください(ショートカットキー: `A`)
- ラベルを修正する場合は、適切なボタンを押してください(ショートカットキー: `0` 〜 `9`)
# ラベル補足
- `🥰 NSFW1` は女性の性的行為中の音声(喘ぎ声等)
- `🍭 NSFW2` はキスシーンでのリップ音やフェラシーンでのしゃぶる音(チュパ音)が多く含まれている音声(セリフ+チュパ音の場合も含む、しゃぶりつつのセリフだと思われる場合はこれ)
- 感情が音声からは特に読み取れない、普通のテンションの声の場合(平坦に「今日はラーメンを食べます」等)は `😐 中立` を選択してください
- 複数の感情が含まれている場合は、最も多く含まれている感情を選択してください
"""
with gr.Blocks(head=shortcut_js) as app:
    gr.Markdown(md)
    with gr.Row():
        with gr.Column():
            btn_init = gr.Button("読み込み")
            with gr.Column(variant="panel"):
                key = gr.Textbox(label="Key")
                audio = gr.Audio(
                    show_download_button=False,
                    show_share_button=False,
                    interactive=False,
                )
                text = gr.Textbox(label="Text")
                label = gr.Textbox(label="感情ラベル")
                label_id = gr.Textbox(visible=False)
                btn_skip = gr.Button("現在の感情ラベルで適切 (A)", elem_id="btn_skip")
        with gr.Column():
            gr.Markdown("# 感情ラベルを修正する場合")
            btn_list = [
                gr.Button(id2rich_label[_id], elem_id=f"btn_{_id}") for _id in range(10)
            ]

    # Components every handler refreshes, in the order the handlers fill them.
    item_outputs = [key, audio, text, label, label_id]

    # NOTE(review): current_item is a module global, so all concurrent users
    # share one "current item"; per-session gr.State would isolate them.

    def update_current_item(data: dict | None = None) -> dict:
        """Return UI updates for the current item, fetching one if none is loaded.

        ``data`` is unused. It defaults to None because ``btn_init`` is wired
        with no inputs, so Gradio calls this handler with zero arguments; a
        required parameter would raise TypeError on every "読み込み" click.
        """
        global current_item
        if current_item is None:
            current_item = get_next_parsed_item()
        return {
            key: current_item["key"],
            audio: gr.Audio(current_item["audio"], autoplay=True),
            text: current_item["text"],
            label: current_item["label"],
            label_id: current_item["label_id"],
        }

    def set_next_item(data: dict) -> dict:
        """Replace the current item with the next dataset item and display it."""
        global current_item
        current_item = get_next_parsed_item()
        return update_current_item(data)

    def _submit_label(current_key: str, cls: int) -> None:
        """Send one {"key", "cls"} JSON record to the remote "/put_data" endpoint."""
        client.predict(
            new_data=json.dumps({"key": current_key, "cls": cls}),
            api_name="/put_data",
        )

    def put_unmodified(data: dict) -> dict:
        """Confirm the displayed label as correct, submit it, and load the next item.

        If nothing has been loaded yet (empty key/label), nothing is submitted;
        an item is loaded instead so a stray shortcut press cannot record an
        empty key or crash on ``int("")``.
        """
        logger.info("Putting unmodified")
        current_key = data[key]
        current_label_id = data[label_id]
        if not current_key or current_label_id in (None, ""):
            logger.warning("No item loaded; skipping submission")
            return update_current_item(data)
        # label_id comes back from the hidden Textbox as a string.
        _submit_label(current_key, int(current_label_id))
        logger.info("Unmodified sent")
        return set_next_item(data)

    def _make_put_label(cls_id: int):
        """Build the click handler that relabels the displayed item as ``cls_id``.

        A factory (rather than a loop-local def) binds ``cls_id`` per button,
        avoiding the late-binding closure pitfall.
        """

        def put_label(data: dict) -> dict:
            logger.info(f"Putting label: {id2rich_label[cls_id]}")
            current_key = data[key]
            if not current_key:
                # Nothing loaded yet: do not submit a record with an empty key.
                logger.warning("No item loaded; skipping submission")
                return update_current_item(data)
            _submit_label(current_key, cls_id)
            logger.info("Modified sent")
            return set_next_item(data)

        return put_label

    btn_init.click(update_current_item, outputs=item_outputs)
    # Set-valued inputs make Gradio pass a single {component: value} dict.
    btn_skip.click(put_unmodified, inputs={key, label_id}, outputs=item_outputs)
    for _id, btn in enumerate(btn_list):
        btn.click(_make_put_label(_id), inputs={key}, outputs=item_outputs)

app.launch()