|
import time |
|
import torch |
|
from hpc_rll.origin.rnn import get_lstm |
|
from hpc_rll.torch_utils.network.rnn import LSTM |
|
from testbase import mean_relative_error, times |
|
|
|
# The script requires a CUDA device: lstm_val() synchronizes the CUDA stream
# unconditionally.  Check explicitly instead of using `assert`, which is
# stripped when Python runs with -O and would let the script fail later with
# a less obvious error.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run the hpc_rll LSTM tests")

use_cuda = True

# Problem configuration shared by lstm_val() and lstm_perf().
seq_len = 64
batch_size = 3
input_size = 1792
hidden_size = 384
num_layers = 3
norm_type = 'LN'  # normalization type passed to both LSTM constructors
dropout = 0
|
|
|
|
|
|
|
|
|
def _flat_numpy(tensor):
    """Return ``tensor`` flattened, detached and copied to host as a NumPy array."""
    return torch.flatten(tensor).detach().cpu().numpy()


def _report_mre(label, expected, actual):
    """Print and return the mean relative error of ``actual`` against ``expected``."""
    mre = mean_relative_error(_flat_numpy(expected), _flat_numpy(actual))
    print(label + " mean_relative_error: " + str(mre))
    return mre


def lstm_val():
    """Compare the HPC LSTM against the reference LSTM on identical random inputs.

    Runs one forward/backward pass (loss = mean of the output) through each
    model. It then prints the mean relative error of the loss, the forward
    output, the input gradient, and the gradients of the input weights (wx),
    hidden weights (wh), bias, and the normalization gamma/beta parameters.
    All sizes come from the module-level problem configuration.
    """
    ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
    hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)

    ori_x = torch.randn(seq_len, batch_size, input_size)
    ori_h0 = torch.randn(num_layers, batch_size, hidden_size)
    ori_c0 = torch.randn(num_layers, batch_size, hidden_size)

    if use_cuda:
        ori_x = ori_x.cuda()
        ori_h0 = ori_h0.cuda()
        ori_c0 = ori_c0.cuda()
        ori_lstm = ori_lstm.cuda()
        hpc_lstm = hpc_lstm.cuda()

    # Reference forward/backward pass.
    ori_x.requires_grad_(True)
    ori_output, _ = ori_lstm(ori_x, [ori_h0, ori_c0])
    ori_loss = ori_output.mean()
    ori_loss.backward()

    # HPC forward/backward pass on detached copies of the same inputs, so
    # the two autograd graphs do not share leaf tensors.
    hpc_x = ori_x.clone().detach()
    hpc_h0 = ori_h0.clone().detach()
    hpc_c0 = ori_c0.clone().detach()
    hpc_x.requires_grad_(True)
    hpc_output, _ = hpc_lstm(hpc_x, [hpc_h0, hpc_c0])
    hpc_loss = hpc_output.mean()
    hpc_loss.backward()

    if use_cuda:
        torch.cuda.synchronize()

    _report_mre("lstm fp", ori_loss, hpc_loss)
    # The loss is a mean over the whole output, so elementwise errors can
    # cancel inside it; compare the full output too.
    _report_mre("lstm output", ori_output, hpc_output)
    _report_mre("lstm bp", ori_x.grad, hpc_x.grad)

    # The reference model keeps one weight tensor per layer; the HPC model
    # keeps them concatenated along dim 0.
    ori_wx_grad = torch.cat([ori_lstm.wx[layer].grad for layer in range(num_layers)])
    _report_mre("wx grad", ori_wx_grad, hpc_lstm.wx.grad)

    ori_wh_grad = torch.cat([ori_lstm.wh[layer].grad for layer in range(num_layers)])
    _report_mre("wh grad", ori_wh_grad, hpc_lstm.wh.grad)

    _report_mre("bias grad", ori_lstm.bias.grad, hpc_lstm.bias.grad)

    # NOTE(review): assumes the reference model's normalization parameters
    # start at parameters()[1] in groups of four per layer
    # (gamma_x, beta_x, gamma_h, beta_h).  This matches the previously
    # hard-coded params[1]..params[12] layout for 3 layers; confirm against
    # get_lstm if the model definition changes.
    params = list(ori_lstm.parameters())
    ori_gamma_grads = []
    ori_beta_grads = []
    for layer in range(num_layers):
        base = 1 + 4 * layer
        gamma_x, beta_x, gamma_h, beta_h = params[base:base + 4]
        ori_gamma_grads += [gamma_x.grad, gamma_h.grad]
        ori_beta_grads += [beta_x.grad, beta_h.grad]

    _report_mre("ln gamma grad", torch.cat(ori_gamma_grads), hpc_lstm.ln_gamma.grad)
    _report_mre("ln beta grad", torch.cat(ori_beta_grads), hpc_lstm.ln_beta.grad)
|
|
|
|
|
def lstm_perf():
    """Time repeated forward+backward passes of the reference and HPC LSTMs.

    For each implementation ('normal' = reference, 'hpc'), it builds random
    inputs once. It then runs ``times`` iterations of forward, mean loss and
    backward, printing the elapsed time of each iteration. All sizes come from
    the module-level problem configuration.
    """
    ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
    hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)

    lstms = {'normal': ori_lstm, 'hpc': hpc_lstm}

    for lstm_type, lstm in lstms.items():
        x = torch.rand(seq_len, batch_size, input_size)
        h0 = torch.randn(num_layers, batch_size, hidden_size)
        c0 = torch.randn(num_layers, batch_size, hidden_size)
        if use_cuda:
            x = x.cuda()
            h0 = h0.cuda()
            c0 = c0.cuda()
            lstm = lstm.cuda()
            # Make sure device work queued by the transfers above has
            # finished before the first iteration starts its timer.
            torch.cuda.synchronize()

        prev_state = [h0, c0]
        x.requires_grad_(True)

        for i in range(times):
            # perf_counter() is monotonic and high-resolution, which suits
            # interval timing better than the wall clock from time.time().
            start = time.perf_counter()
            output, _ = lstm(x, prev_state)
            loss = output.mean()
            loss.backward()
            if use_cuda:
                # CUDA kernels run asynchronously; wait for them before
                # reading the clock so the measurement covers the GPU work.
                torch.cuda.synchronize()
            print('epoch: {}, {} lstm cost time: {}'.format(i, lstm_type, time.perf_counter() - start))
|
|
|
|
|
if __name__ == '__main__':
    # Echo the problem configuration before running anything.
    problem_banner = (
        "target problem: seq_len = {}, batch_size = {}, input_size = {}, hidden_size = {}, num_layers = {}, norm_type = {}, dropout = {}"
    ).format(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)
    print(problem_banner)

    # lstm_val() is not invoked from this entry point; only the notice is printed.
    print("==============lstm has no validation test================")

    print("===============run lstm performance test=================")
    lstm_perf()
|
|