File size: 5,134 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import time
import torch
from hpc_rll.origin.td import q_nstep_td_error_with_rescale, q_nstep_td_data
from hpc_rll.rl_utils.td import QNStepTDRescale
from testbase import mean_relative_error, times
# This script compares GPU implementations, so stop at import time when no CUDA device is usable.
assert torch.cuda.is_available()
use_cuda = True  # when true, every tensor and the HPC module are moved to the GPU before running
T = 1024  # leading dimension of the reward tensor; also handed to both TD implementations (presumably the n-step length)
B = 64  # leading dimension of the q / action / done / weight tensors (presumably the batch size)
N = 64  # second dimension of the q tensors and exclusive upper bound of the sampled action indices
gamma = 0.95  # forwarded unchanged to both implementations (presumably the discount factor)
def qntd_rescale_val():
    """Check the HPC n-step TD (rescale) module against the reference function.

    Random inputs are generated once, duplicated for the HPC path, and (when
    ``use_cuda`` is set) moved to the GPU. Both implementations then run one
    forward and one backward pass. Two mean relative errors are printed:
    one between the mean losses and one between the gradients w.r.t. ``q``.
    """
    # Order of the random draws is q, next_n_q, action, next_n_action, reward, done, weight.
    ref_inputs = [
        torch.randn(B, N),
        torch.randn(B, N),
        torch.randint(0, N, size=(B, )),
        torch.randint(0, N, size=(B, )),
        torch.randn(T, B),
        torch.randn(B),
        torch.randn(B),
    ]
    # Independent copies so the two implementations never share storage.
    fast_inputs = [tensor.clone().detach() for tensor in ref_inputs]
    fast_module = QNStepTDRescale(T, B, N)

    if use_cuda:
        ref_inputs = [tensor.cuda() for tensor in ref_inputs]
        fast_inputs = [tensor.cuda() for tensor in fast_inputs]
        fast_module = fast_module.cuda()

    def _flat_numpy(tensor):
        # Flatten and bring the values back to host memory as a numpy array.
        return torch.flatten(tensor).cpu().detach().numpy()

    # Reference path: gradients are tracked on the (possibly GPU-resident) q tensor.
    ref_q = ref_inputs[0]
    ref_q.requires_grad_(True)
    ref_loss, _ = q_nstep_td_error_with_rescale(q_nstep_td_data(*ref_inputs), gamma, T)
    ref_loss = ref_loss.mean()
    ref_loss.backward()
    if use_cuda:
        torch.cuda.synchronize()

    # HPC path: same inputs, same reduction.
    fast_q = fast_inputs[0]
    fast_q.requires_grad_(True)
    fast_loss, _ = fast_module(*fast_inputs, gamma)
    fast_loss = fast_loss.mean()
    fast_loss.backward()
    if use_cuda:
        torch.cuda.synchronize()

    # Forward check: compares the two mean losses (scalars after the reduction above).
    forward_err = mean_relative_error(_flat_numpy(ref_loss), _flat_numpy(fast_loss))
    print("qntd rescale fp mean_relative_error: {!s}".format(forward_err))

    # Backward check: compares the gradients accumulated on the q tensors.
    backward_err = mean_relative_error(_flat_numpy(ref_q.grad), _flat_numpy(fast_q.grad))
    print("qntd rescale bp mean_relative_error: {!s}".format(backward_err))
def qntd_rescale_perf():
    """Time forward + backward of the reference and HPC n-step TD (rescale) code.

    Random inputs are generated once, duplicated for the HPC path, and (when
    ``use_cuda`` is set) moved to the GPU. Each implementation is then run
    ``times`` times; every iteration performs a forward pass, a mean
    reduction, a backward pass and a device synchronize, and prints the
    elapsed time for that iteration.

    Intervals are measured with ``time.perf_counter()``, which is monotonic
    and high-resolution, rather than the wall clock.
    """
    ori_q = torch.randn(B, N)
    ori_next_n_q = torch.randn(B, N)
    ori_action = torch.randint(0, N, size=(B, ))
    ori_next_n_action = torch.randint(0, N, size=(B, ))
    ori_reward = torch.randn(T, B)
    ori_done = torch.randn(B)
    ori_weight = torch.randn(B)

    # Independent copies so the two implementations never share storage.
    hpc_q = ori_q.clone().detach()
    hpc_next_n_q = ori_next_n_q.clone().detach()
    hpc_action = ori_action.clone().detach()
    hpc_next_n_action = ori_next_n_action.clone().detach()
    hpc_reward = ori_reward.clone().detach()
    hpc_done = ori_done.clone().detach()
    hpc_weight = ori_weight.clone().detach()
    hpc_qntd_rescale = QNStepTDRescale(T, B, N)

    if use_cuda:
        ori_q = ori_q.cuda()
        ori_next_n_q = ori_next_n_q.cuda()
        ori_action = ori_action.cuda()
        ori_next_n_action = ori_next_n_action.cuda()
        ori_reward = ori_reward.cuda()
        ori_done = ori_done.cuda()
        ori_weight = ori_weight.cuda()

        hpc_q = hpc_q.cuda()
        hpc_next_n_q = hpc_next_n_q.cuda()
        hpc_action = hpc_action.cuda()
        hpc_next_n_action = hpc_next_n_action.cuda()
        hpc_reward = hpc_reward.cuda()
        hpc_done = hpc_done.cuda()
        hpc_weight = hpc_weight.cuda()
        hpc_qntd_rescale = hpc_qntd_rescale.cuda()

        # Drain any pending setup work so it is not billed to the first timed iteration.
        torch.cuda.synchronize()

    ori_q.requires_grad_(True)
    for i in range(times):
        start = time.perf_counter()
        ori_loss, _ = q_nstep_td_error_with_rescale(
            q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight),
            gamma, T
        )
        ori_loss = ori_loss.mean()
        # Gradients accumulate on ori_q.grad across iterations; only the elapsed time is reported.
        ori_loss.backward()
        if use_cuda:
            # Wait for the queued GPU work so the interval covers the real execution time.
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        print('epoch: {}, original qntd rescale cost time: {}'.format(i, elapsed))

    hpc_q.requires_grad_(True)
    for i in range(times):
        start = time.perf_counter()
        hpc_loss, _ = hpc_qntd_rescale(
            hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma
        )
        hpc_loss = hpc_loss.mean()
        hpc_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        print('epoch: {}, hpc qntd rescale cost time: {}'.format(i, elapsed))
if __name__ == '__main__':
    # Report the problem size, then run the validation and performance stages in order.
    print(f"target problem: T = {T}, B = {B}, N = {N}, gamma = {gamma}")
    stages = (
        ("================run qntd rescale validation test================", qntd_rescale_val),
        ("================run qntd rescale performance test================", qntd_rescale_perf),
    )
    for banner, stage in stages:
        print(banner)
        stage()
|