|
import time |
|
import torch |
|
from hpc_rll.origin.upgo import upgo_loss |
|
from hpc_rll.rl_utils.upgo import UPGO |
|
from testbase import mean_relative_error, times |
|
|
|
# Both tests below compare against a CUDA implementation, so refuse to run
# at all when no CUDA device is visible to torch.
assert torch.cuda.is_available()

# Gate for every `.cuda()` move and `torch.cuda.synchronize()` call below.
use_cuda = True

# Problem size: handed to UPGO(T, B, N) and used as the dimensions of the
# random tensors; N is also the exclusive upper bound of the sampled actions.
# NOTE(review): presumably T = trajectory length, B = batch size,
# N = number of actions -- confirm against the UPGO implementation.
T, B, N = 256, 256, 256
|
|
|
|
|
def upgo_val():
    """Check the HPC UPGO module against the reference ``upgo_loss``.

    Random inputs are generated once and cloned so that both implementations
    see identical data. Each implementation computes its loss, the mean of
    that loss is back-propagated, and two mean relative errors are printed:
    one for the (mean) loss value and one for the gradient with respect to
    the target output.
    """

    def _to_numpy(tensor):
        # Flatten so a 0-dim loss and a multi-dim gradient are compared the
        # same way, then bring the data to host memory as a numpy array.
        return torch.flatten(tensor).cpu().detach().numpy()

    # Argument order expected by both implementations:
    # target_output, rhos, action, rewards, bootstrap_values.
    ori_inputs = [
        torch.randn(T, B, N),
        torch.randn(T, B),
        torch.randint(0, N, size=(T, B)),
        torch.randn(T, B),
        torch.randn(T + 1, B),
    ]
    # Independent copies, so gradients of the two runs never share storage.
    hpc_inputs = [tensor.clone().detach() for tensor in ori_inputs]
    hpc_upgo = UPGO(T, B, N)

    if use_cuda:
        ori_inputs = [tensor.cuda() for tensor in ori_inputs]
        hpc_inputs = [tensor.cuda() for tensor in hpc_inputs]
        hpc_upgo = hpc_upgo.cuda()

    # Reference implementation: forward, reduce to a scalar, backward.
    ori_target_output = ori_inputs[0]
    ori_target_output.requires_grad_(True)
    ori_loss = upgo_loss(*ori_inputs)
    ori_loss = ori_loss.mean()
    ori_loss.backward()
    if use_cuda:
        torch.cuda.synchronize()

    # HPC implementation: same sequence on the cloned inputs.
    hpc_target_output = hpc_inputs[0]
    hpc_target_output.requires_grad_(True)
    hpc_loss = hpc_upgo(*hpc_inputs)
    hpc_loss = hpc_loss.mean()
    hpc_loss.backward()
    if use_cuda:
        torch.cuda.synchronize()

    # Forward pass: compare the scalar losses.
    mre = mean_relative_error(_to_numpy(ori_loss), _to_numpy(hpc_loss))
    print("upgo fp mean_relative_error: " + str(mre))

    # Backward pass: compare gradients w.r.t. the target output.
    mre = mean_relative_error(
        _to_numpy(ori_target_output.grad),
        _to_numpy(hpc_target_output.grad),
    )
    print("upgo bp mean_relative_error: " + str(mre))
|
|
|
|
|
def upgo_perf():
    """Time forward + backward passes of the reference and HPC UPGO losses.

    Random inputs are generated once and cloned so both implementations work
    on identical data. Each implementation is then run ``times`` times
    (``times`` comes from ``testbase``); every iteration computes the loss,
    reduces it with ``mean()``, back-propagates, and prints the elapsed time
    for that iteration.

    Note: ``.grad`` of the target output is not reset between iterations, so
    gradients accumulate across epochs. This affects both implementations
    equally and the values are never read here.
    """

    def _time_epochs(message, loss_fn, inputs):
        # Run `times` timed forward+backward iterations of `loss_fn`.
        for i in range(times):
            # perf_counter() is monotonic and uses the highest-resolution
            # clock available, unlike time.time(), which can jump when the
            # system clock is adjusted and may be coarse on some platforms.
            start = time.perf_counter()
            loss = loss_fn(*inputs)
            loss = loss.mean()
            loss.backward()
            if use_cuda:
                # CUDA work is queued asynchronously; wait for it to finish
                # so the measured interval covers the actual computation.
                torch.cuda.synchronize()
            print(message.format(i, time.perf_counter() - start))

    # Argument order expected by both implementations:
    # target_output, rhos, action, rewards, bootstrap_values.
    ori_inputs = [
        torch.randn(T, B, N),
        torch.randn(T, B),
        torch.randint(0, N, size=(T, B)),
        torch.randn(T, B),
        torch.randn(T + 1, B),
    ]
    hpc_inputs = [tensor.clone().detach() for tensor in ori_inputs]
    hpc_upgo = UPGO(T, B, N)

    if use_cuda:
        ori_inputs = [tensor.cuda() for tensor in ori_inputs]
        hpc_inputs = [tensor.cuda() for tensor in hpc_inputs]
        hpc_upgo = hpc_upgo.cuda()

    # Only the target output (first argument) needs a gradient.
    ori_inputs[0].requires_grad_(True)
    _time_epochs('epoch: {}, original upgo cost time: {}', upgo_loss, ori_inputs)

    hpc_inputs[0].requires_grad_(True)
    _time_epochs('epoch: {}, hpc upgo cost time: {}', hpc_upgo, hpc_inputs)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: report the problem size, then run the correctness
    # check followed by the timing comparison.
    print("target problem: T = {}, B = {}, N = {}".format(T, B, N))

    stages = (
        ("================run upgo validation test================", upgo_val),
        ("================run upgo performance test================", upgo_perf),
    )
    for banner, run_stage in stages:
        print(banner)
        run_stage()
|
|