File size: 3,180 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import time
import torch
from hpc_rll.origin.td import td_lambda_error, td_lambda_data
from hpc_rll.rl_utils.td import TDLambda
from testbase import mean_relative_error, times
# The HPC implementation under test is exercised on the GPU (see `use_cuda`
# below). Check explicitly instead of using `assert`, which is removed when
# Python runs with -O and would let the script fail later with a less clear error.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run the hpc td test")
use_cuda = True
# Sizes of the random test inputs: T is used for dim 0 (value gets T + 1 rows),
# B for dim 1.
T = 1024
B = 64
def td_val():
    """Validate the HPC ``TDLambda`` module against the reference ``td_lambda_error``.

    Both implementations get identical random inputs (value of shape
    ``(T + 1, B)``, reward and weight of shape ``(T, B)``). One forward and
    backward pass is run through each. The function then prints the mean
    relative error between the two mean losses ("fp") and between the two
    gradients w.r.t. the value tensor ("bp").
    """
    # Reference inputs (tuple elements are evaluated left to right, so the RNG
    # draw order is value, reward, weight).
    ori_value, ori_reward, ori_weight = torch.randn(T + 1, B), torch.randn(T, B), torch.randn(T, B)
    # Detached copies for the HPC side so the two autograd graphs are independent.
    hpc_value, hpc_reward, hpc_weight = (
        src.clone().detach() for src in (ori_value, ori_reward, ori_weight)
    )
    hpc_td = TDLambda(T, B)

    if use_cuda:
        (ori_value, ori_reward, ori_weight,
         hpc_value, hpc_reward, hpc_weight) = (
            src.cuda()
            for src in (ori_value, ori_reward, ori_weight, hpc_value, hpc_reward, hpc_weight)
        )
        hpc_td = hpc_td.cuda()

    # Reference forward + backward.
    ori_value.requires_grad_(True)
    ori_loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight)).mean()
    ori_loss.backward()
    if use_cuda:
        torch.cuda.synchronize()

    # HPC forward + backward.
    hpc_value.requires_grad_(True)
    hpc_loss = hpc_td(hpc_value, hpc_reward, hpc_weight).mean()
    hpc_loss.backward()
    if use_cuda:
        torch.cuda.synchronize()

    def _as_flat_numpy(tensor):
        # Flatten, move to host, and drop autograd tracking before comparing.
        return torch.flatten(tensor).cpu().detach().numpy()

    comparisons = (
        ("fp", ori_loss, hpc_loss),
        ("bp", ori_value.grad, hpc_value.grad),
    )
    for tag, reference, candidate in comparisons:
        mre = mean_relative_error(_as_flat_numpy(reference), _as_flat_numpy(candidate))
        print("td " + tag + " mean_relative_error: " + str(mre))
def _time_epochs(label, step):
    """Call ``step`` ``times`` times and print the elapsed time of each call.

    Args:
        label: Implementation name inserted in each printed line
            (``'original'`` or ``'hpc'``).
        step: Zero-argument callable performing one forward + backward pass.
    """
    for i in range(times):
        # perf_counter is monotonic and high-resolution, so it suits interval
        # timing better than the wall clock (time.time), which can be adjusted.
        start = time.perf_counter()
        step()
        # CUDA work is launched asynchronously; wait for it so the measured
        # interval includes the GPU execution, not just the kernel launches.
        if use_cuda:
            torch.cuda.synchronize()
        print('epoch: {}, {} td cost time: {}'.format(i, label, time.perf_counter() - start))


def td_perf():
    """Benchmark the reference TD(lambda) loss against the HPC ``TDLambda`` module.

    Builds identical random inputs for both implementations (value of shape
    ``(T + 1, B)``, reward and weight of shape ``(T, B)``). It then times
    ``times`` forward + backward passes of each and prints one line per epoch.
    Gradients are not zeroed between epochs, so ``.grad`` accumulates. This
    does not affect the timing comparison.
    """
    ori_value = torch.randn(T + 1, B)
    ori_reward = torch.randn(T, B)
    ori_weight = torch.randn(T, B)
    # Detached copies keep the two autograd graphs independent.
    hpc_value = ori_value.clone().detach()
    hpc_reward = ori_reward.clone().detach()
    hpc_weight = ori_weight.clone().detach()
    hpc_td = TDLambda(T, B)
    if use_cuda:
        ori_value = ori_value.cuda()
        ori_reward = ori_reward.cuda()
        ori_weight = ori_weight.cuda()
        hpc_value = hpc_value.cuda()
        hpc_reward = hpc_reward.cuda()
        hpc_weight = hpc_weight.cuda()
        hpc_td = hpc_td.cuda()

    ori_value.requires_grad_(True)

    def ori_step():
        ori_loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight))
        ori_loss.mean().backward()

    _time_epochs('original', ori_step)

    hpc_value.requires_grad_(True)

    def hpc_step():
        hpc_loss = hpc_td(hpc_value, hpc_reward, hpc_weight)
        hpc_loss.mean().backward()

    _time_epochs('hpc', hpc_step)
if __name__ == '__main__':
    # Report the problem size, then run the correctness check followed by the benchmark.
    print(f"target problem: T = {T}, B = {B}")
    stages = (
        ("================run td validation test================", td_val),
        ("================run td performance test================", td_perf),
    )
    for banner, stage in stages:
        print(banner)
        stage()
|