# Hosting-page status banner captured together with the source ("Spaces: Runtime error"); not part of the script.
import sys
import os
import datetime
import warnings
import argparse

# TensorFlow's native logger picks up TF_CPP_MIN_LOG_LEVEL the first time it
# emits a message, which can already happen during `import tensorflow`.
# Export it before that import so import-time INFO/WARNING lines are dropped too.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
import numpy as np

from model import ukws
from dataset import google_infe202405

# Keep only ERROR-level messages from TensorFlow's Python-side logger.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Silence every Python warning for the whole run.
warnings.filterwarnings('ignore')

# `np.warnings` (an accidental alias of the stdlib module) no longer exists in
# recent NumPy releases, so only the stdlib `warnings` module is used here.
# VisibleDeprecationWarning lives on the top-level namespace in older NumPy and
# under `np.exceptions` in newer ones; look in both places.
_visible_deprecation = getattr(np, 'VisibleDeprecationWarning', None)
if _visible_deprecation is None:
    _visible_deprecation = getattr(
        getattr(np, 'exceptions', None), 'VisibleDeprecationWarning', None
    )
if _visible_deprecation is not None:
    warnings.filterwarnings('ignore', category=_visible_deprecation)

warnings.simplefilter("ignore")

# Fixed seeds for TensorFlow and NumPy so repeated runs draw the same randoms.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
# Command-line interface. Only --load_checkpoint_path is mandatory.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--text_input', type=str, default='g2p_embed',
    help="Text feature type; passed to the dataloader as `features` and to the model as `text_input`.",
)
parser.add_argument(
    '--audio_input', type=str, default='both',
    help="Audio input mode forwarded to the model as `audio_input`.",
)
parser.add_argument(
    '--load_checkpoint_path', required=True, type=str,
    help="Directory searched for the latest checkpoint to restore.",
)
parser.add_argument(
    '--google_pkl', type=str, default='/home/DB/data/google_test_all.pkl',
    help="Pickle file handed to GoogleCommandsDataloader as `pkl`.",
)
parser.add_argument(
    '--stack_extractor', action='store_true',
    help="Forwarded to the model as `stack_extractor` (off unless the flag is given).",
)
args = parser.parse_args()
# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Best effort: report the problem and carry on with default allocation.
        print(e)
strategy = tf.distribute.MirroredStrategy()

# Batch size per replica; kept as an int (it was previously derived by true
# division from the global size, which produced the float 1000.0).
BATCH_SIZE_PER_REPLICA = 1000
# Global batch size: one per-replica batch for each replica in sync.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
# Build the evaluation input pipeline.
text_input = args.text_input
audio_input = args.audio_input
load_checkpoint_path = args.load_checkpoint_path

# Unshuffled loader over the pickle given by --google_pkl, batched at the
# global batch size and then converted by the dataset module's helper.
test_google_dataset = google_infe202405.convert_sequence_to_dataset(
    google_infe202405.GoogleCommandsDataloader(
        batch_size=GLOBAL_BATCH_SIZE,
        features=text_input,
        shuffle=False,
        pkl=args.google_pkl,
    )
)
# Let the strategy distribute the dataset across its replicas.
test_google_dist_dataset = strategy.experimental_distribute_dataset(test_google_dataset)
# Symbol table handed to the model: a padding token first, then the phoneme
# symbols, then a single space as the final entry.
phonemes = [
    "<pad>",
    'AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',
    'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH',
    'D', 'DH', 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1',
    'EY2', 'F', 'G', 'HH', 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2',
    'JH', 'K', 'L', 'M', 'N', 'NG', 'OW0', 'OW1', 'OW2', 'OY0',
    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH0', 'UH1',
    'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH',
    ' ',
]
# Vocabulary size: counts every entry above, including "<pad>" and the space.
vocab = len(phonemes)
# Keyword arguments for the model constructor.
kwargs = {
    'vocab': vocab,
    'text_input': text_input,
    'audio_input': audio_input,
    'frame_length': 400,
    'hop_length': 160,
    'num_mel': 40,
    'sample_rate': 16000,
    'log_mel': False,
    'stack_extractor': args.stack_extractor,
}

# Parameter dict kept under a second name (originally labelled as the
# tensorboard dict). This is the same object as `kwargs`, not a copy.
param = kwargs
# Build the model under the distribution strategy and load the newest checkpoint.
# NOTE(review): the original indentation was lost; the restore logic is placed
# inside the strategy scope here -- confirm this matches the intended nesting.
with strategy.scope():
    model = ukws.BaseUKWS(**kwargs)

    if args.load_checkpoint_path:
        checkpoint_dir = args.load_checkpoint_path
        checkpoint = tf.train.Checkpoint(model=model)
        checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint:
            checkpoint.restore(latest_checkpoint)
            print("Checkpoint restored!")
        else:
            # Previously this case was silent, so the reported accuracy would
            # come from whatever weights the model was constructed with.
            print(
                "WARNING: no checkpoint found in", checkpoint_dir,
                "- evaluating the model without restored weights.",
                file=sys.stderr,
            )
# Left undecorated (no @tf.function): the body calls `.numpy()`, which is only
# available when running eagerly.
def test_step_metric_only(inputs, group_size=20):
    """Score one batch and count how many groups were predicted correctly.

    The model output and the labels are both reshaped to
    ``[batch // group_size, group_size]``; a group counts as correct when the
    argmax of its scores equals the argmax of its labels.
    NOTE(review): presumably each group is one test item scored against
    ``group_size`` candidates -- confirm against the dataloader.

    Args:
        inputs: Indexable batch; ``inputs[0]`` is the speech input,
            ``inputs[1]`` the text input and ``inputs[2]`` the labels.
        group_size: Number of consecutive rows forming one group. Defaults to
            20, the value that used to be hard-coded. The reshape requires the
            batch to hold a whole number of groups.

    Returns:
        Tuple ``(true_count, num_testdata)``: the number of correct groups as a
        NumPy float scalar, and the number of groups in the batch as an int.
    """
    clean_speech, text, labels = inputs[0], inputs[1], inputs[2]

    # Only the first element of the model's output is used for scoring.
    prob = model(clean_speech, text, training=False)[0]

    num_groups = labels.shape[0] // group_size
    prob = tf.reshape(prob, [num_groups, group_size])
    labels = tf.reshape(labels, [num_groups, group_size])

    predictions = tf.math.argmax(prob, axis=1)
    actuals = tf.math.argmax(labels, axis=1)
    hits = tf.cast(tf.math.equal(predictions, actuals), tf.float32)
    true_count = tf.reduce_sum(hits).numpy()

    return true_count, num_groups
def distributed_test_step_metric_only(dataset_inputs):
    """Run `test_step_metric_only` on every replica and add up the results.

    With more than one replica `strategy.run` hands back per-replica
    containers rather than plain numbers, which the caller cannot accumulate
    with `+=`. The local results are therefore unpacked and summed here, so
    the caller always receives two scalars.

    Args:
        dataset_inputs: One element drawn from the distributed dataset.

    Returns:
        Tuple ``(true_count, num_testdata)`` summed over all local replicas.
    """
    per_replica_true, per_replica_total = strategy.run(
        test_step_metric_only, args=(dataset_inputs,)
    )
    # For a value that is not distributed this yields a one-element tuple, so
    # the single-replica result is unchanged.
    true_count = sum(strategy.experimental_local_results(per_replica_true))
    num_testdata = sum(strategy.experimental_local_results(per_replica_total))
    return true_count, num_testdata
# Accumulate correct / total counts over the whole distributed test set.
total_true_count = 0
total_num_testdata = 0
for x in test_google_dist_dataset:
    true_count, num_testdata = distributed_test_step_metric_only(x)
    total_true_count += true_count
    total_num_testdata += num_testdata

# An empty dataset used to end in a bare ZeroDivisionError; stop with a clear
# message instead.
if total_num_testdata == 0:
    sys.exit("No test samples were evaluated; cannot compute accuracy.")

# Accuracy as a percentage. The printed label is the Chinese word for "accuracy".
accuracy = total_true_count / total_num_testdata * 100.0
print("準確率:", accuracy, "%")