import json
import os
import shutil
import sys
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

from utils import compute_average_wer, download_dataset
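# NOTE: compute_average_wer and download_dataset are assumed to be local
# helpers from this repo's utils module, not a published package.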


def main():
    """
    Main function to orchestrate the multilingual data generation process.

    This function performs the following steps:
    1. Downloads multilingual evaluation data if requested.
    2. Processes multilingual evaluation files.
    3. Calculates and saves results, including Word Error Rate (WER) and
       language detection confusion matrices.
    """
    source_repo = "argmaxinc/whisperkit-evals-multilingual"
    source_subfolder = "WhisperKit"
    source_directory = f"{source_repo}/{source_subfolder}"

    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree(source_repo)
        except FileNotFoundError:
            print("Nothing to remove.")
        download_dataset(source_repo, source_repo, source_subfolder)

    results = defaultdict(
        lambda: {
            "average_wer": [],
            "language_wer": defaultdict(list),
            "language_detection": [],
        }
    )
    confusion_matrices = {}
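    # Each entry of `results` is keyed by "<model>/<forced|not_forced>" and
    # accumulates (reference, prediction) pairs overall and per reference
    # language, plus (reference_language, detected_language) tuples.
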
    for subdir, _, files in os.walk(source_directory):
        for filename in files:
            if not filename.endswith(".json") or "summary" in filename:
                continue
            file_path = os.path.join(subdir, filename)
            with open(file_path, "r") as f:
                data = json.load(f)

            subdir_components = subdir.split(os.path.sep)
            is_forced = "forced" in subdir_components
            model = subdir_components[-3] if not is_forced else subdir_components[-4]
            key = f"{model}/{'forced' if is_forced else 'not_forced'}"
            for item in data["results"]:
                if "reference_language" not in item:
                    continue
                reference_language = item["reference_language"]
                detected_language = item["predicted_language"]
                result = {
                    "reference": item["reference"],
                    "prediction": item["prediction"],
                }
                results[key]["average_wer"].append(result)
                results[key]["language_wer"][reference_language].append(result)
                results[key]["language_detection"].append(
                    (reference_language, detected_language)
                )

    calculate_and_save_results(results, confusion_matrices)


def calculate_and_save_results(results, confusion_matrices):
    """
    Calculates final multilingual metrics and saves them to CSV and JSON files.

    This function processes the raw multilingual data, calculates average
    metrics, creates confusion matrices for language detection, and saves the
    results to:
    1. A CSV file with WER data for each model and language.
    2. A JSON file with confusion matrices for language detection.

    :param results: Dictionary containing raw multilingual evaluation data.
    :param confusion_matrices: Dictionary to store confusion matrices for
        language detection.
    """
    wer_data = []
    for key, data in results.items():
        model, forced = key.rsplit("/", 1)
        model = model.replace("_", "/")
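        # Directory names appear to encode "/" in model names as "_"
        # (e.g. "openai_whisper-large-v3" -> "openai/whisper-large-v3").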
        row = {
            "Model": model,
            "Forced Tokens": forced == "forced",
            "Average WER": compute_average_wer(data["average_wer"]),
        }
        for lang, wers in data["language_wer"].items():
            row[f"WER_{lang}"] = compute_average_wer(wers)
        wer_data.append(row)

        true_languages, detected_languages = zip(*data["language_detection"])
        unique_languages = sorted(set(true_languages))
        cm = confusion_matrix(
            true_languages, detected_languages, labels=unique_languages
        )
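        # Labels come from the reference languages only, so any detected
        # language outside that set is dropped by confusion_matrix. Rows are
        # then normalized to per-language detection rates, skipping empty
        # rows to avoid division by zero.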
        row_sums = cm.sum(axis=1)
        cm_normalized = np.zeros_like(cm, dtype=float)
        non_zero_rows = row_sums != 0
        cm_normalized[non_zero_rows] = (
            cm[non_zero_rows] / row_sums[non_zero_rows, np.newaxis]
        )

        if model not in confusion_matrices:
            confusion_matrices[model] = {}
        confusion_matrices[model][forced] = {
            "matrix": cm_normalized.tolist(),
            "labels": unique_languages,
        }

    # Ensure the output directory exists before writing results.
    os.makedirs("dashboard_data", exist_ok=True)
    df = pd.DataFrame(wer_data)
    df.to_csv("dashboard_data/multilingual_results.csv", index=False)
    with open("dashboard_data/multilingual_confusion_matrices.json", "w") as f:
        json.dump(confusion_matrices, f, indent=2)
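
# Resulting output shapes, as built above:
#   dashboard_data/multilingual_results.csv columns:
#     Model, Forced Tokens, Average WER, WER_<lang>, ...
#   dashboard_data/multilingual_confusion_matrices.json:
#     {"<model>": {"forced"|"not_forced": {"matrix": [[...]], "labels": [...]}}}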


if __name__ == "__main__":
    main()