import json import os import random import warnings import gradio as gr from datasets import IterableDatasetDict, load_dataset from gradio_client import Client from loguru import logger warnings.filterwarnings("ignore") NUM_TAR_FILES = 115 HF_PATH_TO_DATASET = "litagin/Galgame_Speech_SER_16kHz" hf_token = os.getenv("HF_TOKEN") client = Client("litagin/ser_record", hf_token=hf_token) id2label = { 0: "Angry", 1: "Disgusted", 2: "Embarrassed", 3: "Fearful", 4: "Happy", 5: "Sad", 6: "Surprised", 7: "Neutral", 8: "Sexual1", 9: "Sexual2", } id2rich_label = { 0: "😠 怒ってる (0)", 1: "😒 嫌悪・キモい・うんざり・気持ち悪い (1)", 2: "😳 恥ずかしい・戸惑ってる (2)", 3: "😨 怖がってる・おびえてる・ビクビクしてる (3)", 4: "😊 幸せ・楽しい・嬉しい (4)", 5: "😢 悲しい・切ない (5)", 6: "😲 驚いてる・びっくり (6)", 7: "😐 中立・平坦・普通 (7)", 8: "🥰 NSFW1 (8)", 9: "🍭 NSFW2 (9)", } current_item: dict | None = None def _load_dataset( *, streaming: bool = True, use_local_dataset: bool = False, local_dataset_path: str | None = None, data_dir: str = "data", ) -> IterableDatasetDict: data_files = { "train": [ f"galgame-speech-ser-16kHz-train-000{index:03d}.tar" for index in range(0, NUM_TAR_FILES) ], } if use_local_dataset: assert local_dataset_path is not None path = local_dataset_path else: path = HF_PATH_TO_DATASET dataset: IterableDatasetDict = load_dataset( path=path, data_dir=data_dir, data_files=data_files, streaming=streaming ) # type: ignore dataset = dataset.remove_columns(["__url__"]) dataset = dataset.rename_column("ogg", "audio") return dataset logger.info("Start loading dataset") ds = _load_dataset(streaming=True, use_local_dataset=False) logger.info("Dataset loaded") seed = random.randint(0, 2**32 - 1) logger.info(f"Seed: {seed}") ds_iter = iter(ds["train"].shuffle(seed=seed)) # ds_iter = iter(ds["train"]) counter = 0 shortcut_js = """ """ def parse_item(item) -> dict: global counter label_id = item["cls"] sampling_rate = item["audio"]["sampling_rate"] array = item["audio"]["array"] return { "key": item["__key__"], "audio": (sampling_rate, array), "text": item["txt"], "label": id2rich_label[label_id], "label_id": label_id, "counter": counter, } def get_next_parsed_item() -> dict: global counter, ds_iter logger.info("Getting next item") try: next_item = next(ds_iter) counter += 1 except StopIteration: logger.info("StopIteration, re-initializing using new seed") seed = random.randint(0, 2**32 - 1) logger.info(f"New Seed: {seed}") ds_iter = iter(ds["train"].shuffle(seed=seed)) next_item = next(ds_iter) counter = 1 parsed = parse_item(next_item) logger.info( f"Next item:\nkey={parsed['key']}\ntext={parsed['text']}\nlabel={parsed['label']}" ) return parsed md = """ # 説明 - **性的な音声が含まれるため、18歳未満の方はご利用をお控えください** - このアプリは [このゲームのセリフ音声データセット](https://huggingface.co/datasets/litagin/Galgame_Speech_SER_16kHz) の感情ラベルを修正して、大規模で高品質な感情音声データセットを作成するためのものです - 「**何を言っているか**」ではなく「**どのように言っているか**」に注目して、感情ラベルを付与してください(例: 悲しそうに「とっても楽しいです」と言っていたら、 `😊 幸せ` ではなく `😢 悲しみ` とする) - 既存のラベルが適切であれば、そのまま「現在の感情ラベルで適切」ボタンを押してください(ショートカットキー: `A`) - ラベルを修正する場合は、適切なボタンを押してください(ショートカットキー: `0` 〜 `9`) # ラベル補足 - `🥰 NSFW1` は女性の性的行為中の音声(喘ぎ声等) - `🍭 NSFW2` はキスシーンでのリップ音やフェラシーンでのしゃぶる音(チュパ音)が多く含まれている音声(セリフ+チュパ音の場合も含む、しゃぶりつつのセリフだと思われる場合はこれ) - 感情が音声からは特に読み取れない、普通のテンションの声の場合(平坦に「今日はラーメンを食べます」等)は `😐 中立` を選択してください - 複数の感情が含まれている場合は、最も多く含まれている感情を選択してください """ with gr.Blocks(head=shortcut_js) as app: gr.Markdown(md) with gr.Row(): with gr.Column(): btn_init = gr.Button("読み込み") with gr.Column(variant="panel"): key = gr.Textbox(label="Key") audio = gr.Audio( show_download_button=False, show_share_button=False, interactive=False, ) text = gr.Textbox(label="Text") label = gr.Textbox(label="感情ラベル") label_id = gr.Textbox(visible=False) btn_skip = gr.Button("現在の感情ラベルで適切 (A)", elem_id="btn_skip") with gr.Column(): gr.Markdown("# 感情ラベルを修正する場合") btn_list = [ gr.Button(id2rich_label[_id], elem_id=f"btn_{_id}") for _id in range(10) ] def update_current_item(data: dict) -> dict: global current_item if current_item is None: current_item = get_next_parsed_item() return { key: current_item["key"], audio: gr.Audio(current_item["audio"], autoplay=True), text: current_item["text"], label: current_item["label"], label_id: current_item["label_id"], } def set_next_item(data: dict) -> dict: global current_item current_item = get_next_parsed_item() return update_current_item(data) def put_unmodified(data: dict) -> dict: logger.info("Putting unmodified") current_key = data[key] current_label_id = data[label_id] result = client.predict( new_data=json.dumps( { "key": current_key, "cls": int(current_label_id), } ), api_name="/put_data", ) logger.info(f"Unmodified sent, result: {result}") return set_next_item(data) btn_init.click( update_current_item, outputs=[key, audio, text, label, label_id], ) btn_skip.click( put_unmodified, inputs={key, label_id}, outputs=[key, audio, text, label, label_id], ) functions_list = [] for _id in range(10): def put_label(data: dict, _id=_id) -> dict: logger.info(f"Putting label: {id2rich_label[_id]}") current_key = data[key] result = client.predict( new_data=json.dumps( { "key": current_key, "cls": _id, } ), api_name="/put_data", ) logger.info(f"Modified sent, result: {result}") return set_next_item(data) functions_list.append(put_label) for _id in range(10): btn_list[_id].click( functions_list[_id], inputs={key}, outputs=[key, audio, text, label, label_id], ) app.launch()