File size: 1,984 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f360f3
5769ee4
2f360f3
 
5769ee4
2f360f3
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from datasets import load_dataset
import datasets
import json
from mmcv import Config
import numpy 
import torch

from risk_biased.utils.waymo_dataloader import WaymoDataloaders


# Reference side: build the sample dataloader from the project's Waymo config and
# collate its *entire* dataset into one batch, so every sample takes part in the check.
config_path = "risk_biased/config/waymo_config.py"
cfg = Config.fromfile(config_path)
dataloaders = WaymoDataloaders(cfg)
sample_dataloader = dataloaders.sample_dataloader()
reference_batch = sample_dataloader.collate_fn(sample_dataloader.dataset)
# The collated batch must hold exactly these ten entries; unpacking fails loudly otherwise.
x, mask_x, y, mask_y, mask_loss, map_data, mask_map, offset, x_ego, y_ego = reference_batch

def _column_as_tensor(table, column, dtype):
    """Convert one column of a Hugging Face dataset into a torch tensor.

    Args:
        table: Object indexable by column name (here the loaded ``Dataset``).
        column: Name of the column to read.
        dtype: NumPy dtype the column is cast to before ``torch.from_numpy``.

    Returns:
        A tensor holding the column's values with the requested dtype.
    """
    return torch.from_numpy(numpy.array(table[column]).astype(dtype))


# Candidate side: the same samples as published on the Hugging Face Hub.
# NOTE(review): without ``revision=...`` this pulls whatever the Hub currently
# serves; pin a revision if this check must be reproducible.
dataset = load_dataset("jmercat/risk_biased_dataset", split="test")

# Trajectories/maps/offsets are float32; masks are booleans.
x_c = _column_as_tensor(dataset, "x", numpy.float32)
mask_x_c = _column_as_tensor(dataset, "mask_x", numpy.bool_)
y_c = _column_as_tensor(dataset, "y", numpy.float32)
mask_y_c = _column_as_tensor(dataset, "mask_y", numpy.bool_)
mask_loss_c = _column_as_tensor(dataset, "mask_loss", numpy.bool_)
map_data_c = _column_as_tensor(dataset, "map_data", numpy.float32)
mask_map_c = _column_as_tensor(dataset, "mask_map", numpy.bool_)
offset_c = _column_as_tensor(dataset, "offset", numpy.float32)
x_ego_c = _column_as_tensor(dataset, "x_ego", numpy.float32)
y_ego_c = _column_as_tensor(dataset, "y_ego", numpy.float32)

def _mismatch_reason(reference, candidate):
    """Explain how *candidate* differs from *reference*, or return None if it matches.

    Shape and dtype must match exactly. ``torch.allclose`` broadcasts its
    arguments, so on its own it would accept e.g. a ``(1, N)`` tensor as equal
    to an ``(M, N)`` one. Boolean tensors are compared exactly with
    ``torch.equal``; every other dtype uses ``torch.allclose`` with its default
    tolerances.

    Args:
        reference: Tensor produced by the local dataloader.
        candidate: Tensor rebuilt from the Hub dataset.

    Returns:
        A short human-readable reason string, or ``None`` when the tensors match.
    """
    if reference.shape != candidate.shape:
        return f"shape {tuple(candidate.shape)} != expected {tuple(reference.shape)}"
    if reference.dtype != candidate.dtype:
        return f"dtype {candidate.dtype} != expected {reference.dtype}"
    if reference.dtype == torch.bool:
        if torch.equal(reference, candidate):
            return None
        return f"{int((reference != candidate).sum())} boolean entries differ"
    if torch.allclose(reference, candidate):
        return None
    # Widen to float64 so the reported gap is meaningful for any numeric dtype.
    max_gap = (reference.double() - candidate.double()).abs().max().item()
    return f"values differ (max abs diff {max_gap:.3g})"


# Raise explicitly rather than using bare ``assert``: asserts are stripped under
# ``python -O``, which would let this script report success without checking.
# Every mismatching field is collected so one run reports all of them.
comparisons = {
    "x": (x, x_c),
    "mask_x": (mask_x, mask_x_c),
    "y": (y, y_c),
    "mask_y": (mask_y, mask_y_c),
    "mask_loss": (mask_loss, mask_loss_c),
    "map_data": (map_data, map_data_c),
    "mask_map": (mask_map, mask_map_c),
    "offset": (offset, offset_c),
    "x_ego": (x_ego, x_ego_c),
    "y_ego": (y_ego, y_ego_c),
}
failures = []
for name, (reference, candidate) in comparisons.items():
    reason = _mismatch_reason(reference, candidate)
    if reason is not None:
        failures.append(f"{name}: {reason}")
if failures:
    raise AssertionError(
        "Hub dataset does not match the local dataloader output:\n  "
        + "\n  ".join(failures)
    )

print("All good!")