Spaces:
Build error
Build error
File size: 6,203 Bytes
45ee559 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
# visualisation tools for mimic2
import argparse
import csv
import os
import random
from statistics import StatisticsError, mean, median, mode, stdev
import matplotlib.pyplot as plt
import seaborn as sns
from text.cmudict import CMUDict
def get_audio_seconds(frames, frame_shift_ms=12.5):
    """Convert a frame count to a duration in seconds.

    Args:
        frames: Number of frames.
        frame_shift_ms: Milliseconds advanced per frame (default 12.5, the
            value this script has always assumed).

    Returns:
        Duration in seconds as a float.
    """
    return (frames * frame_shift_ms) / 1000
def append_data_statistics(meta_data):
    """Attach duration statistics to every character-count bucket, in place.

    For each bucket, the values ``d["audio_len"]`` of its ``data`` entries are
    summarised as mean, median, mode (over values rounded to 2 decimals) and
    sample standard deviation.

    Args:
        meta_data: dict mapping character count -> {"data": [{"audio_len": ...}, ...]}.
            Every ``data`` list must be non-empty; ``mean`` raises
            ``StatisticsError`` otherwise.

    Returns:
        The same ``meta_data`` dict, with "mean", "median", "mode" and "std"
        keys added to each bucket.
    """
    for char_cnt in meta_data:
        bucket = meta_data[char_cnt]
        audio_len_list = [d["audio_len"] for d in bucket["data"]]
        mean_audio_len = mean(audio_len_list)
        median_audio_len = median(audio_len_list)

        # Round to 2 decimals so near-identical durations count as the same value.
        rounded = [round(a, 2) for a in audio_len_list]
        try:
            mode_audio_len = mode(rounded)
        except StatisticsError:
            # Python < 3.8 raises when there is no unique mode. Fall back to the
            # first *rounded* value (the old code returned an unrounded one),
            # which is what 3.8+ returns when all values are distinct.
            mode_audio_len = rounded[0]

        try:
            std = stdev(audio_len_list)
        except StatisticsError:
            # stdev needs at least two samples; a single-sample bucket gets 0.
            std = 0

        bucket["mean"] = mean_audio_len
        bucket["median"] = median_audio_len
        bucket["mode"] = mode_audio_len
        bucket["std"] = std
    return meta_data
def process_meta_data(path):
    """Load a ``|``-delimited training manifest and group its rows by text length.

    Each non-blank row must have at least four fields: the third is parsed as
    an integer frame count and the fourth is the utterance text. Rows are
    bucketed by ``len(utterance)``, and duration statistics are then added to
    every bucket via ``append_data_statistics``.

    Args:
        path: Path to the manifest (e.g. the ``train.txt`` the preprocessing
            script writes).

    Returns:
        dict mapping character count -> bucket dict. ``bucket["data"]`` is a
        list of {"utt", "frames", "audio_len", "row"} entries, where "row" is
        the first four fields re-joined with ``|``; the bucket also carries
        "mean", "median", "mode" and "std".
    """
    meta_data = {}
    with open(path, "r", encoding="utf-8") as f:
        # QUOTE_NONE: the manifest is presumably written with a plain "|".join
        # (no CSV quoting) -- TODO confirm against the writer. With the default
        # quoting, a transcript starting with '"' would be parsed as a quoted
        # field, corrupting it or swallowing following lines.
        reader = csv.reader(f, delimiter="|", quoting=csv.QUOTE_NONE)
        for row in reader:
            if not row:
                # Blank line (e.g. stray empty line in the file): skip it
                # instead of failing with IndexError below.
                continue
            frames = int(row[2])
            utt = row[3]
            bucket = meta_data.setdefault(len(utt), {"data": []})
            bucket["data"].append(
                {
                    "utt": utt,
                    "frames": frames,
                    "audio_len": get_audio_seconds(frames),
                    "row": "{}|{}|{}|{}".format(row[0], row[1], row[2], row[3]),
                }
            )
    return append_data_statistics(meta_data)
def get_data_points(meta_data):
    """Flatten per-bucket statistics into parallel lists for plotting.

    Args:
        meta_data: dict mapping character count -> bucket dict with "data",
            "mean", "mode", "median" and "std" keys.

    Returns:
        dict with "x" (list of character counts) and the lists "y_avg",
        "y_mode", "y_median", "y_std" and "y_num_samples". All lists follow
        the dict's iteration order, so index i refers to the same bucket in
        every list.
    """
    buckets = list(meta_data.values())
    return {
        # Must be a real list: the old code returned the dict itself, which
        # matplotlib's plot() cannot use as a sequence of x values.
        "x": list(meta_data),
        "y_avg": [b["mean"] for b in buckets],
        "y_mode": [b["mode"] for b in buckets],
        "y_median": [b["median"] for b in buckets],
        "y_std": [b["std"] for b in buckets],
        "y_num_samples": [len(b["data"]) for b in buckets],
    }
def save_training(file_path, meta_data):
    """Write every stored manifest row to ``file_path`` in random order.

    Rows are collected bucket by bucket, shuffled in place with the global
    ``random`` generator (seed it beforehand for reproducible output), and
    written one per line, replacing any existing file.
    """
    lines = [
        entry["row"] + "\n"
        for bucket in meta_data.values()
        for entry in bucket["data"]
    ]
    random.shuffle(lines)
    with open(file_path, "w+", encoding="utf-8") as out:
        out.writelines(lines)
def plot(meta_data, save_path=None):
    """Draw one red-dot scatter chart per statistic against character length.

    Five figures are created, in this order: mean, mode, median, standard
    deviation and sample count versus character count. When ``save_path`` is
    truthy, each figure is also saved into that directory under a fixed name
    without an extension. Figures are left open, so the caller can display
    them (e.g. ``plt.show()``).
    """
    points = get_data_points(meta_data)
    # (series key, y-axis label, output file name), in drawing order.
    charts = (
        ("y_avg", "avg seconds", "char_len_vs_avg_secs"),
        ("y_mode", "mode seconds", "char_len_vs_mode_secs"),
        ("y_median", "median seconds", "char_len_vs_med_secs"),
        ("y_std", "standard deviation", "char_len_vs_std"),
        ("y_num_samples", "number of samples", "char_len_vs_num_samples"),
    )
    for series, y_label, file_name in charts:
        plt.figure()
        plt.plot(points["x"], points[series], "ro")
        plt.xlabel("character lengths", fontsize=30)
        plt.ylabel(y_label, fontsize=30)
        if save_path:
            plt.savefig(os.path.join(save_path, file_name))
def plot_phonemes(train_path, cmu_dict_path, save_path):
    """Bar-plot how often each CMUdict phoneme occurs in the training transcripts.

    Every whitespace-separated token of the 4th ``|``-delimited field is
    passed to ``CMUDict.lookup``. On a hit, the phonemes in the first returned
    entry (presumably the first pronunciation) are counted. On a miss, the
    word is counted under "None".

    Args:
        train_path: ``|``-delimited manifest whose 4th field is the utterance.
        cmu_dict_path: Path given to ``CMUDict`` (e.g. cmudict-0.7b).
        save_path: Directory to save ``phoneme_dist`` into; falsy to skip saving.
    """
    cmudict = CMUDict(cmu_dict_path)
    # Created first so the "None" (unknown word) bar is always the first bar.
    phonemes = {"None": 0}
    with open(train_path, "r", encoding="utf-8") as f:
        # QUOTE_NONE: the manifest is presumably a plain "|".join with no CSV
        # quoting (TODO confirm); default quoting would misparse transcripts
        # that start with '"'.
        for row in csv.reader(f, delimiter="|", quoting=csv.QUOTE_NONE):
            if not row:
                # Skip blank lines rather than failing with IndexError.
                continue
            for word in row[3].split():
                pho = cmudict.lookup(word)
                if not pho:
                    phonemes["None"] += 1
                    continue
                for ph in pho[0].split():
                    phonemes[ph] = phonemes.get(ph, 0) + 1

    # Size the figure when creating it. The old code set rcParams *after*
    # plt.figure(), which did not affect this figure and leaked global state.
    plt.figure(figsize=(50, 20))
    barplot = sns.barplot(x=list(phonemes), y=list(phonemes.values()))
    if save_path:
        barplot.get_figure().savefig(os.path.join(save_path, "phoneme_dist"))
def _build_arg_parser():
    """Return the command-line parser for this analysis script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_file_path",
        required=True,
        help="this is the path to the train.txt file that the preprocess.py script creates",
    )
    parser.add_argument("--save_to", help="path to save charts of data to")
    parser.add_argument("--cmu_dict_path", help="give cmudict-0.7b to see phoneme distribution")
    return parser


def main():
    """Parse CLI arguments, draw the duration charts and, optionally, the phoneme chart.

    All figures are shown at the end with ``plt.show()``.
    """
    args = _build_arg_parser().parse_args()
    stats = process_meta_data(args.train_file_path)

    # Each chart family uses its own default figure size.
    plt.rcParams["figure.figsize"] = (10, 5)
    plot(stats, save_path=args.save_to)

    if args.cmu_dict_path:
        plt.rcParams["figure.figsize"] = (30, 10)
        plot_phonemes(args.train_file_path, args.cmu_dict_path, args.save_to)

    plt.show()


if __name__ == "__main__":
    main()
|