# Source: Flux9665's Hugging Face Space, commit ddcb369 ("Update app.py").
import json
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from ArticulatoryTextFrontend import ArticulatoryTextFrontend
def visualize_one_hot_encoded_sequence(tensor, sentence, col_labels, cmap='BuGn'):
    """
    Visualize a 2D one-hot encoded articulatory-feature tensor as a heatmap.

    Args:
        tensor: 2D torch tensor of shape (phones, features). Values are
            clamped into [0, 1] before plotting.
        sentence: the original input text, shown as the plot title.
        col_labels: one label per phone (column); a falsy value skips
            column labeling.
        cmap: matplotlib colormap name for the heatmap.

    Returns:
        The matplotlib Figure containing the heatmap.

    Raises:
        ValueError: if the tensor is not 2D, or if the number of labels
            does not match the corresponding tensor dimension.
    """
    # Validate BEFORE transposing: the original order made a 1D input fail
    # inside transpose(0, 1) with an opaque IndexError instead of this
    # intended, readable ValueError.
    if tensor.ndim != 2:
        raise ValueError("Input tensor must be a 2D array")
    # (phones, features) -> (features, phones) so features become rows.
    tensor = torch.clamp(tensor, min=0, max=1).transpose(0, 1).cpu().numpy()
    # Fixed feature inventory of the articulatory frontend; one entry per row
    # of the transposed tensor.
    row_labels = ["stressed", "very-high-tone", "high-tone", "mid-tone", "low-tone", "very-low-tone", "rising-tone", "falling-tone", "peaking-tone", "dipping-tone", "lengthened", "half-length", "shortened", "consonant", "vowel", "phoneme", "silence", "end of sentence", "questionmark", "exclamationmark", "fullstop", "word-boundary", "dental", "postalveolar",
                  "velar", "palatal", "glottal", "uvular", "labiodental", "labial-velar", "alveolar", "bilabial", "alveolopalatal", "retroflex", "pharyngal", "epiglottal", "central", "back", "front_central", "front", "central_back", "mid", "close-mid", "close", "open-mid", "close_close-mid", "open-mid_open", "open", "rounded", "unrounded", "plosive",
                  "nasal", "approximant", "trill", "flap", "fricative", "lateral-approximant", "implosive", "vibrant", "click", "ejective", "aspirated", "unvoiced", "voiced"]
    if row_labels and len(row_labels) != tensor.shape[0]:
        raise ValueError("Number of row labels must match the number of rows in the tensor")
    if col_labels and len(col_labels) != tensor.shape[1]:
        raise ValueError("Number of column labels must match the number of columns in the tensor")
    fig, ax = plt.subplots(figsize=(16, 16))
    # Draw the heatmap; 'auto' aspect lets the cells stretch to fill the figure.
    ax.imshow(tensor, cmap=cmap, aspect='auto')
    # Major ticks carry the feature / phone labels.
    if row_labels:
        ax.set_yticks(np.arange(tensor.shape[0]), row_labels)
    if col_labels:
        ax.set_xticks(np.arange(tensor.shape[1]), col_labels, rotation=0)
    ax.grid(False)
    ax.set_xlabel('Phones')
    ax.set_ylabel('Features')
    # Minor ticks sit at the half-integer cell boundaries so the minor grid
    # draws lines BETWEEN cells rather than through them.
    ax.set_xticks(np.arange(-0.5, tensor.shape[1], 1), minor=True)
    ax.set_yticks(np.arange(-0.5, tensor.shape[0], 1), minor=True)
    ax.grid(which='minor', color='darkgrey', linestyle='-', linewidth=1)
    ax.tick_params(which='minor', size=0)  # hide the minor tick marks themselves
    ax.set_title(f"»{sentence}«")
    return fig
def vis_wrapper(sentence, language):
    """Gradio callback: render a sentence as an articulatory-feature heatmap."""
    # Dropdown entries look like "English (eng)" — the ISO code is the
    # parenthesized part of the last whitespace-separated token.
    last_token = language.split(" ")[-1]
    iso_code = last_token.split("(")[1].split(")")[0]
    frontend = ArticulatoryTextFrontend(language=iso_code)
    feature_matrix = frontend.string_to_tensor(sentence)
    phone_string = frontend.get_phone_string(sentence)
    return visualize_one_hot_encoded_sequence(tensor=feature_matrix,
                                              sentence=sentence,
                                              col_labels=phone_string)
def load_json_from_path(path):
    """
    Read a UTF-8 encoded JSON file and return the parsed object.

    Args:
        path: path to the JSON file.

    Returns:
        The deserialized Python object (typically a dict).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(path, "r", encoding="utf8") as f:
        # json.load parses straight from the file object instead of the
        # non-idiomatic json.loads(f.read()) round trip through a string.
        return json.load(f)
# Mapping of ISO language codes to human-readable names, shipped alongside the app.
iso_to_name = load_json_from_path("iso_to_fullname.json")
# Dropdown entries of the form "English (eng)"; vis_wrapper parses the code back out.
text_selection = [f"{iso_to_name[iso_code]} ({iso_code})" for iso_code in iso_to_name]
# Build the Gradio UI: a sentence text box plus a searchable language dropdown
# feed vis_wrapper, whose matplotlib Figure is rendered as a PNG plot.
iface = gr.Interface(fn=vis_wrapper,
                     inputs=[gr.Textbox(lines=2,
                                        placeholder="write the sentence you want to visualize here...",
                                        value="What I cannot create, I do not understand.",
                                        label="Text input"),
                             # type="value" passes the selected label string (not its
                             # index) through to vis_wrapper, which parses the ISO code.
                             gr.Dropdown(text_selection,
                                         type="value",
                                         value='English (eng)',
                                         label="Select the Language of the Text (type on your keyboard to find it quickly)")],
                     outputs=[gr.Plot(label="", show_label=False, format="png", container=True)],
                     allow_flagging="never",
                     description="<br><br>This demo converts any sentence into a sequence of articulatory features and then displays a visualization of them. This can be useful for phonetic applications, as well as text-to-speech, since this representation is language agnostic. The only major bottleneck is the conversion from graphemes to phonemes and their modifiers. While there are more than 7000 languages supported, the correctness and completeness of the produced phoneme sequences with their modifiers varies a lot across languages. To use this in a project, have a look at https://github.com/Flux9665/ArticulatoryTextFrontend <br><br>")
# NOTE(review): allow_flagging is deprecated in newer Gradio releases in favor of
# flagging_mode — confirm against the Space's pinned Gradio version before changing.
iface.launch()