CL-KWS_202408_v1 / inference.py
Francis0917's picture
Upload folder using huggingface_hub
2045faa verified
raw
history blame
4.74 kB
import sys, os, datetime, warnings, argparse
import tensorflow as tf
import numpy as np
from model import ukws
from dataset import google_infe202405
# Hide TensorFlow's native (C++) log messages below ERROR level.
# NOTE(review): this runs after `import tensorflow`; TensorFlow's native logger
# likely reads the variable only once, so anything emitted during the import
# may not be suppressed. Consider setting it before the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Python-side TensorFlow logging: errors only.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Ignore every Python warning for the rest of the run. This blanket filter
# already covers NumPy's VisibleDeprecationWarning, so no category-specific
# filter is needed. The former ones raised AttributeError on current NumPy:
# `np.warnings` is gone since NumPy 1.24, and `np.VisibleDeprecationWarning`
# moved to `np.exceptions` in NumPy 2.0.
warnings.simplefilter('ignore')

# Fixed seed so random ops in TensorFlow and NumPy are reproducible.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
# Command-line options for this evaluation run.
parser = argparse.ArgumentParser()
add_option = parser.add_argument
# Text-side input representation; passed to the dataloader (`features=`) and the model.
add_option('--text_input', type=str, default='g2p_embed', required=False)
# Audio-side input mode forwarded to the model.
add_option('--audio_input', type=str, default='both', required=False)
# Directory searched for the most recent checkpoint (mandatory).
add_option('--load_checkpoint_path', type=str, required=True)
# Pickle handed to the test dataloader as `pkl=`.
add_option('--google_pkl', type=str, default='/home/DB/data/google_test_all.pkl', required=False)
# Boolean switch forwarded to the model as `stack_extractor`.
add_option('--stack_extractor', action='store_true')
args = parser.parse_args()
# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front. This has to happen before the GPUs are initialized; if it is too late
# TensorFlow raises RuntimeError, which is reported and otherwise ignored.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# Replicate the model on every visible GPU.
strategy = tf.distribute.MirroredStrategy()

# Batch size per GPU (replica). The dataloader receives the global batch size;
# experimental_distribute_dataset then divides each batch among the replicas.
# (Previously derived with true division, which made it the float 1000.0.)
BATCH_SIZE_PER_REPLICA = 1000
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Make Dataloader
text_input = args.text_input
audio_input = args.audio_input
load_checkpoint_path = args.load_checkpoint_path
test_google_dataset = google_infe202405.GoogleCommandsDataloader(
    batch_size=GLOBAL_BATCH_SIZE, features=text_input, shuffle=False, pkl=args.google_pkl)
test_google_dataset = google_infe202405.convert_sequence_to_dataset(test_google_dataset)
test_google_dist_dataset = strategy.experimental_distribute_dataset(test_google_dataset)
# Text vocabulary. Index 0 is padding, then ARPAbet-style phone symbols
# (vowels carry a 0/1/2 stress digit), then a single space.
# NOTE(review): token ids are positional, so this order must match the one the
# checkpoint was trained with. Do not re-sort, dedupe or "fix" entries such as
# the stress-less 'UW'.
_phone_symbols = [
    'AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2',
    'AO0', 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH',
    'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2',
    'F', 'G', 'HH',
    'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2',
    'JH', 'K', 'L', 'M', 'N', 'NG',
    'OW0', 'OW1', 'OW2', 'OY0', 'OY1', 'OY2',
    'P', 'R', 'S', 'SH', 'T', 'TH',
    'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2',
    'V', 'W', 'Y', 'Z', 'ZH',
]
phonemes = ['<pad>', *_phone_symbols, ' ']

# Vocabulary size, counting the padding and space entries.
vocab = len(phonemes)

# Constructor arguments for ukws.BaseUKWS.
# Audio front-end: frame/hop lengths are presumably in samples, i.e. 25 ms /
# 10 ms at 16 kHz; 40 mel bands; log_mel disabled. Confirm against the model.
kwargs = dict(
    vocab=vocab,
    text_input=text_input,
    audio_input=audio_input,
    frame_length=400,
    hop_length=160,
    num_mel=40,
    sample_rate=16000,
    log_mel=False,
    stack_extractor=args.stack_extractor,
)
# Same dict object, labelled as the TensorBoard parameter dict; not used further
# in the visible code.
param = kwargs
# Build the model under the distribution strategy so its variables are
# mirrored across replicas.
with strategy.scope():
    model = ukws.BaseUKWS(**kwargs)

# Restore the newest checkpoint from the given directory. Evaluating an
# untrained model would print a meaningless accuracy, so a missing checkpoint
# is an error rather than a silent fall-through.
if args.load_checkpoint_path:
    checkpoint_dir = args.load_checkpoint_path
    checkpoint = tf.train.Checkpoint(model=model)
    # NOTE(review): the manager is not used further in this script.
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if not latest_checkpoint:
        raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}")
    # Variables that do not exist yet are restored when the model first
    # creates them (TensorFlow's deferred restoration).
    checkpoint.restore(latest_checkpoint)
    print("Checkpoint restored!")
# Runs eagerly: `.numpy()` and the Python-level shape check below need concrete
# tensors, so this function cannot simply be wrapped in @tf.function.
def test_step_metric_only(inputs, num_candidates=20):
    """Count correct top-1 picks in one (per-replica) batch.

    The flat batch is split into consecutive groups of ``num_candidates``
    examples. Within each group, the example with the highest model score is
    the prediction and the example with the highest label is the target. A
    group counts as correct when the two indices agree.
    NOTE(review): presumably the dataloader emits one positive and
    ``num_candidates - 1`` negatives per group; confirm.

    Args:
        inputs: ``(clean_speech, text, labels)``. The first two are fed to the
            model; ``labels`` has one entry per example.
        num_candidates: group size. The batch size must be a multiple of it.

    Returns:
        ``(true_count, num_testdata)``: the number of correct groups (a NumPy
        float32 scalar) and the number of groups evaluated (int).

    Raises:
        ValueError: if the batch size is not a multiple of ``num_candidates``.
    """
    clean_speech, text, labels = inputs[0], inputs[1], inputs[2]
    batch_size = labels.shape[0]
    if batch_size % num_candidates:
        raise ValueError(
            f"Batch of {batch_size} examples cannot be split into groups of "
            f"{num_candidates} candidates.")
    num_groups = batch_size // num_candidates

    # The first model output must hold exactly one score per example for the
    # reshape into (groups, candidates) to succeed.
    prob = model(clean_speech, text, training=False)[0]
    prob = tf.reshape(prob, [num_groups, num_candidates])
    labels = tf.reshape(labels, [num_groups, num_candidates])
    predictions = tf.math.argmax(prob, axis=1)
    actuals = tf.math.argmax(labels, axis=1)
    true_count = tf.reduce_sum(
        tf.cast(tf.math.equal(predictions, actuals), tf.float32)).numpy()
    return true_count, num_groups
def distributed_test_step_metric_only(dataset_inputs):
    """Run ``test_step_metric_only`` on every replica and total the results.

    Args:
        dataset_inputs: one element of the distributed test dataset.

    Returns:
        ``(true_count, num_testdata)`` summed over all local replicas, as a
        Python float and a Python int.
    """
    def _replica_step(inputs):
        true_count, num_testdata = test_step_metric_only(inputs)
        # Return fresh tensors so every replica's result is a distinct object
        # that the strategy wraps per replica. Equal small Python ints from
        # different replicas can be the same cached object and may otherwise
        # be collapsed into one unwrapped value, which would undercount.
        return (tf.convert_to_tensor(true_count, dtype=tf.float32),
                tf.convert_to_tensor(num_testdata, dtype=tf.int64))

    per_replica_true, per_replica_num = strategy.run(_replica_step, args=(dataset_inputs,))
    # With several replicas, strategy.run returns PerReplica values that the
    # caller cannot add to a running total. experimental_local_results gives
    # one value per local replica (a 1-tuple on a single device).
    true_count = sum(float(v) for v in strategy.experimental_local_results(per_replica_true))
    num_testdata = sum(int(v) for v in strategy.experimental_local_results(per_replica_num))
    return true_count, num_testdata
# Accumulate correct groups and evaluated groups over the whole test set.
total_true_count = 0
total_num_testdata = 0
for x in test_google_dist_dataset:
    true_count, num_testdata = distributed_test_step_metric_only(x)
    total_true_count += true_count
    total_num_testdata += num_testdata

if total_num_testdata == 0:
    # Nothing was evaluated (e.g. an empty test set); avoid ZeroDivisionError.
    sys.exit("No test data was evaluated; check --google_pkl.")
accuracy = total_true_count / total_num_testdata * 100.0
# Prints the accuracy percentage (the label is Chinese for "accuracy").
print("準確率:", accuracy, "%")