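"""Validation and benchmark script comparing the original pure-PyTorch vtrace
implementation (hpc_rll.origin.vtrace) against the CUDA-accelerated VTrace
operator (hpc_rll.rl_utils.vtrace)."""
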
import time
import torch
from hpc_rll.origin.vtrace import vtrace_error_discrete_action, vtrace_data
from hpc_rll.rl_utils.vtrace import VTrace
from testbase import mean_relative_error, times

assert torch.cuda.is_available()
use_cuda = True

# Problem size: trajectory length T, batch size B, number of discrete actions N.
T = 128
B = 128
N = 128

def vtrace_val():
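    """Check that the HPC vtrace kernel matches the original implementation.

    Runs one forward/backward pass through both implementations on identical
    inputs and reports the mean relative error of the loss and the gradients.
    """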
    # Random test inputs: target/behaviour logits, sampled actions,
    # value estimates (T + 1 for bootstrapping) and rewards.
    ori_target_output = torch.randn(T, B, N)
    ori_behaviour_output = torch.randn(T, B, N)
    ori_action = torch.randint(0, N, size=(T, B))
    ori_value = torch.randn(T + 1, B)
    ori_reward = torch.randn(T, B)

    # Detached copies so both implementations see identical inputs.
    hpc_target_output = ori_target_output.clone().detach()
    hpc_behaviour_output = ori_behaviour_output.clone().detach()
    hpc_action = ori_action.clone().detach()
    hpc_value = ori_value.clone().detach()
    hpc_reward = ori_reward.clone().detach()
    hpc_vtrace = VTrace(T, B, N)

    if use_cuda:
        ori_target_output = ori_target_output.cuda()
        ori_behaviour_output = ori_behaviour_output.cuda()
        ori_action = ori_action.cuda()
        ori_value = ori_value.cuda()
        ori_reward = ori_reward.cuda()
        hpc_target_output = hpc_target_output.cuda()
        hpc_behaviour_output = hpc_behaviour_output.cuda()
        hpc_action = hpc_action.cuda()
        hpc_value = hpc_value.cuda()
        hpc_reward = hpc_reward.cuda()
        hpc_vtrace = hpc_vtrace.cuda()

    # Forward and backward through the original implementation.
    ori_target_output.requires_grad_(True)
    ori_value.requires_grad_(True)
    ori_loss = vtrace_error_discrete_action(
        vtrace_data(ori_target_output, ori_behaviour_output, ori_action, ori_value, ori_reward, None)
    )
    ori_loss = sum(ori_loss)
    ori_loss.backward()

    # Forward and backward through the HPC implementation.
    hpc_target_output.requires_grad_(True)
    hpc_value.requires_grad_(True)
    hpc_loss = hpc_vtrace(hpc_target_output, hpc_behaviour_output, hpc_action, hpc_value, hpc_reward)
    hpc_loss = sum(hpc_loss)
    hpc_loss.backward()

    # Compare the losses and the gradients of the two implementations.
    mre = mean_relative_error(
        torch.flatten(ori_loss).cpu().detach().numpy(),
        torch.flatten(hpc_loss).cpu().detach().numpy()
    )
    print("vtrace fp mean_relative_error: " + str(mre))
    mre = mean_relative_error(
        torch.flatten(ori_target_output.grad).cpu().detach().numpy(),
        torch.flatten(hpc_target_output.grad).cpu().detach().numpy()
    )
    print("vtrace bp target_output mean_relative_error: " + str(mre))
    mre = mean_relative_error(
        torch.flatten(ori_value.grad).cpu().detach().numpy(),
        torch.flatten(hpc_value.grad).cpu().detach().numpy()
    )
    print("vtrace bp value mean_relative_error: " + str(mre))

def vtrace_perf():
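    """Benchmark the original and HPC vtrace implementations.

    Runs `times` forward/backward passes through each implementation and
    prints the per-iteration wall-clock cost.
    """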
    # Random benchmark inputs, mirroring the validation setup.
    ori_target_output = torch.randn(T, B, N)
    ori_behaviour_output = torch.randn(T, B, N)
    ori_action = torch.randint(0, N, size=(T, B))
    ori_value = torch.randn(T + 1, B)
    ori_reward = torch.randn(T, B)
    hpc_target_output = ori_target_output.clone().detach()
    hpc_behaviour_output = ori_behaviour_output.clone().detach()
    hpc_action = ori_action.clone().detach()
    hpc_value = ori_value.clone().detach()
    hpc_reward = ori_reward.clone().detach()
    hpc_vtrace = VTrace(T, B, N)

    if use_cuda:
        ori_target_output = ori_target_output.cuda()
        ori_behaviour_output = ori_behaviour_output.cuda()
        ori_action = ori_action.cuda()
        ori_value = ori_value.cuda()
        ori_reward = ori_reward.cuda()
        hpc_target_output = hpc_target_output.cuda()
        hpc_behaviour_output = hpc_behaviour_output.cuda()
        hpc_action = hpc_action.cuda()
        hpc_value = hpc_value.cuda()
        hpc_reward = hpc_reward.cuda()
        hpc_vtrace = hpc_vtrace.cuda()

    # Time the original implementation; synchronize so GPU work is included.
    ori_target_output.requires_grad_(True)
    ori_value.requires_grad_(True)
    for i in range(times):
        t = time.time()
        ori_loss = vtrace_error_discrete_action(
            vtrace_data(ori_target_output, ori_behaviour_output, ori_action, ori_value, ori_reward, None)
        )
        ori_loss = sum(ori_loss)
        ori_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()
        print('epoch: {}, original vtrace cost time: {}'.format(i, time.time() - t))

    # Time the HPC implementation under the same protocol.
    hpc_target_output.requires_grad_(True)
    hpc_value.requires_grad_(True)
    for i in range(times):
        t = time.time()
        hpc_loss = hpc_vtrace(hpc_target_output, hpc_behaviour_output, hpc_action, hpc_value, hpc_reward)
        hpc_loss = sum(hpc_loss)
        hpc_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()
        print('epoch: {}, hpc vtrace cost time: {}'.format(i, time.time() - t))

if __name__ == '__main__':
print("target problem: T = {}, B = {}, N = {}".format(T, B, N))
print("================run vtrace validation test================")
vtrace_val()
print("================run vtrace performance test================")
vtrace_perf()