# thewhole's picture
# Upload 245 files
# 2fa4776
# raw
# history blame
# 26 kB
import bisect
import math
import random
from dataclasses import dataclass, field
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
import threestudio
from threestudio import register
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_device
from threestudio.utils.ops import (
get_mvp_matrix,
get_projection_matrix,
get_ray_directions,
get_rays,
)
from threestudio.utils.typing import *
import os
import numpy as np
def safe_normalize(x, eps=1e-20):
    """Scale ``x`` to unit length along its last dimension.

    The squared norm is clamped from below by ``eps`` before the square root,
    so an all-zero vector maps to zeros instead of NaNs.
    """
    squared_norm = (x * x).sum(dim=-1, keepdim=True)
    return x / squared_norm.clamp(min=eps).sqrt()
def convert_camera_to_world_transform(transform):
    """Convert a right-handed transform matrix to a left-handed one (author's intent).

    Returns a copy of ``transform`` with its third column negated and its first
    and third rows swapped. The input is not modified.

    NOTE(review): the original comment described flipping only the third
    translation element, but the code negates the entire third column
    (``[:, 2]``); confirm which of the two is intended.
    """
    result = transform.clone()
    # Reverse the viewing direction: negate every entry of column 2.
    result[:, 2].mul_(-1)
    # Swap rows 0 and 2. The right-hand side uses advanced indexing, which
    # makes a copy, so the swap never reads already-overwritten values.
    result[[0, 2]] = result[[2, 0]]
    return result
def circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0])):
    """Build look-at poses for cameras placed on a sphere around the origin.

    Args:
        device: device on which all intermediate tensors and the result live.
        radius: camera distance(s) from the origin, shape ``[B]``.
        theta: polar angle(s) in degrees, measured from the +y axis.
        phi: azimuth angle(s) in degrees; phi=0 points along +z, phi=90 along +x.

    Returns:
        float32 tensor of shape ``[B, 4, 4]`` on ``device``. The rotation
        columns are (right, up, forward), where ``forward`` is the normalized
        camera center (pointing away from the origin), and column 3 holds the
        camera center.

    NOTE(review): when the camera sits on the y axis (theta = 0 or 180),
    forward is parallel to the world up vector and the basis degenerates
    to zeros (safe_normalize of a zero cross product).
    """
    theta = theta / 180 * np.pi
    phi = phi / 180 * np.pi
    centers = torch.stack([
        radius * torch.sin(theta) * torch.sin(phi),
        radius * torch.cos(theta),
        radius * torch.sin(theta) * torch.cos(phi),
    ], dim=-1)  # [B, 3]
    # Move to the target device/dtype so torch.cross below never mixes
    # devices when the inputs live elsewhere (e.g. CPU defaults with a CUDA device).
    centers = centers.to(device=device, dtype=torch.float)

    # look-at basis
    forward_vector = safe_normalize(centers)
    up_vector = torch.tensor([0.0, 1.0, 0.0], device=device).unsqueeze(0).repeat(len(centers), 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(len(centers), 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers
    return poses
def trans_t(t):
    """Return a float32 4x4 homogeneous translation by ``t`` along the z axis."""
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ]).float()


def rot_phi(phi):
    """Return a float32 4x4 homogeneous rotation by ``phi`` radians about the x axis."""
    cos_p = np.cos(phi)
    sin_p = np.sin(phi)
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, cos_p, -sin_p, 0],
        [0, sin_p, cos_p, 0],
        [0, 0, 0, 1],
    ]).float()


def rot_theta(th):
    """Return a float32 4x4 homogeneous rotation about the y axis.

    The 3x3 block is [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]], i.e. the
    transpose of the textbook R_y(th) (a rotation by -th).
    """
    cos_t = np.cos(th)
    sin_t = np.sin(th)
    return torch.Tensor([
        [cos_t, 0, -sin_t, 0],
        [0, 1, 0, 0],
        [sin_t, 0, cos_t, 0],
        [0, 0, 0, 1],
    ]).float()
def rodrigues_mat_to_rot(R):
    """Convert a 3x3 rotation matrix to an axis-angle (Rodrigues) vector.

    Args:
        R: 3x3 rotation matrix (array-like).

    Returns:
        numpy array ``omega`` of shape (3,) whose norm is the rotation angle
        theta in [0, pi] and whose direction is the rotation axis. For
        theta == pi the axis sign is ambiguous; either sign is a valid answer.
    """
    eps = 1e-16
    R = np.asarray(R, dtype=float)
    # cos(theta) = (trace(R) - 1) / 2. Clip so that round-off (e.g. a trace
    # slightly above 3) cannot push arccos outside its domain and yield NaN.
    cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    # Skew-symmetric part of R; equals 2 * sin(theta) * axis.
    s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    if (1 - cos_theta * cos_theta) >= eps:
        # Generic case: sin(theta) is safely away from zero.
        return theta / (2 * np.sin(theta)) * s
    if cos_theta > 0:
        # theta ~ 0: theta / (2 sin theta) ~ 0.5 / (1 - theta^2 / 6).
        return 0.5 / (1 - theta * theta / 6) * s
    # theta ~ pi: the skew part vanishes, so recover the axis from the
    # symmetric part instead, using (R + I) / 2 ~ axis axis^T.
    sym = (R + np.eye(3)) / 2.0
    k = int(np.argmax(np.diag(sym)))
    axis = sym[:, k] / np.sqrt(sym[k, k])
    return theta * axis
def rodrigues_rot_to_mat(r):
    """Convert an axis-angle (Rodrigues) vector to a 3x3 rotation matrix.

    Implements R = cos(theta) I + (1 - cos(theta)) / theta^2 * r r^T
    + sin(theta) / theta * [r]_x, where theta = |r| and [r]_x is the
    cross-product matrix of r.

    Args:
        r: length-3 sequence (wx, wy, wz); its norm is the angle in radians.

    Returns:
        numpy array of shape (3, 3). A zero vector yields the identity.
    """
    wx, wy, wz = r
    theta = np.sqrt(wx * wx + wy * wy + wz * wz)
    a = np.cos(theta)
    if theta < 1e-8:
        # Small-angle Taylor expansions: the closed forms below are 0/0 (NaN)
        # at theta == 0.
        theta2 = theta * theta
        b = 0.5 - theta2 / 24.0
        c = 1.0 - theta2 / 6.0
    else:
        b = (1 - np.cos(theta)) / (theta * theta)
        c = np.sin(theta) / theta
    R = np.zeros([3, 3])
    R[0, 0] = a + b * (wx * wx)
    R[0, 1] = b * wx * wy - c * wz
    R[0, 2] = b * wx * wz + c * wy
    R[1, 0] = b * wx * wy + c * wz
    R[1, 1] = a + b * (wy * wy)
    R[1, 2] = b * wy * wz - c * wx
    R[2, 0] = b * wx * wz - c * wy
    R[2, 1] = b * wz * wy + c * wx
    R[2, 2] = a + b * (wz * wz)
    return R
def pose_spherical(theta, phi, radius):
    """Build a 4x4 camera pose on a sphere from angles given in degrees.

    The pose is built in four steps:
    1. translate by ``radius`` along +z;
    2. rotate by ``phi`` about x;
    3. rotate about y using ``rot_theta``'s convention;
    4. apply a fixed matrix that negates x and swaps the y and z axes.
    """
    def _to_radians(degrees):
        return degrees / 180. * np.pi

    swap_axes = torch.Tensor(
        np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    )
    translated = trans_t(radius)
    tilted = rot_phi(_to_radians(phi)) @ translated
    spun = rot_theta(_to_radians(theta)) @ tilted
    return swap_axes @ spun
def convert_camera_pose(camera_pose):
    """Flip the x and y axes of a batch of camera poses.

    Negates rows 0 and 1 of each pose's upper 3x4 block. This covers both the
    rotation rows and the matching translation entries, and is equivalent to
    left-multiplying each pose by diag(-1, -1, 1[, 1]).

    Args:
        camera_pose: batched pose tensor indexed as ``[:, row, col]``; it needs
            at least 3 rows and 4 columns (e.g. ``[B, 4, 4]`` or ``[B, 3, 4]``).

    Returns:
        A new tensor of the same shape; the input is not modified.
    """
    # Work on a copy so the caller's tensor is left untouched.
    colmap_pose = camera_pose.clone()
    # Rows 0-1, columns 0-3: rotation rows 0-1 plus translation entries 0-1.
    colmap_pose[:, :2, :4] *= -1
    return colmap_pose
@dataclass
class RandomCameraDataModuleConfig:
    """Settings for the random-camera data module.

    ``height``, ``width`` and ``batch_size`` may each be a single int or a list
    with one entry per resolution stage. They are typed ``Any`` because
    OmegaConf does not support a Union of containers.
    """

    # --- training resolution schedule ---
    height: Any = 512
    width: Any = 512
    batch_size: Any = 1
    # global steps at which the next (height, width, batch_size) stage starts
    resolution_milestones: List[int] = field(default_factory=list)

    # --- evaluation (val / test) cameras ---
    eval_height: int = 512
    eval_width: int = 512
    eval_batch_size: int = 1
    n_val_views: int = 1
    n_test_views: int = 120

    # --- training camera sampling ranges ---
    elevation_range: Tuple[float, float] = (-10, 60)
    azimuth_range: Tuple[float, float] = (-180, 180)
    camera_distance_range: Tuple[float, float] = (4.0, 6.0)
    # in degrees, in vertical direction (along height)
    fovy_range: Tuple[float, float] = (40, 70)
    camera_perturb: float = 0.0
    center_perturb: float = 0.0
    up_perturb: float = 0.0

    # --- light sampling ---
    light_position_perturb: float = 1.0
    light_distance_range: Tuple[float, float] = (0.8, 1.5)

    # --- fixed evaluation camera ---
    eval_elevation_deg: float = 15.0
    eval_camera_distance: float = 6.0
    eval_fovy_deg: float = 70.0

    # --- misc ---
    light_sample_strategy: str = "dreamfusion"
    batch_uniform_azimuth: bool = True
    # progressive ranges for elevation, azimuth, r, fovy
    progressive_until: int = 0
class RandomCameraIterableDataset(IterableDataset, Updateable):
    """Infinite training source of randomly sampled cameras and lights.

    ``__iter__`` only yields empty placeholders. The actual sampling happens in
    ``collate``, which ``RandomCameraDataModule.train_dataloader`` passes to the
    DataLoader as ``collate_fn``. ``update_step`` switches the active
    resolution and batch size according to ``resolution_milestones``.
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__()
        self.cfg: RandomCameraDataModuleConfig = cfg
        # height / width / batch_size may be a single int or one value per stage.
        self.heights: List[int] = (
            [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
        )
        self.widths: List[int] = (
            [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
        )
        self.batch_sizes: List[int] = (
            [self.cfg.batch_size]
            if isinstance(self.cfg.batch_size, int)
            else self.cfg.batch_size
        )
        assert len(self.heights) == len(self.widths) == len(self.batch_sizes)
        self.resolution_milestones: List[int]
        if (
            len(self.heights) == 1
            and len(self.widths) == 1
            and len(self.batch_sizes) == 1
        ):
            if len(self.cfg.resolution_milestones) > 0:
                threestudio.warn(
                    "Ignoring resolution_milestones since height and width are not changing"
                )
            self.resolution_milestones = [-1]
        else:
            # Stage 0 starts at step -1 (immediately); stage i starts at milestone i-1.
            assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
            self.resolution_milestones = [-1] + self.cfg.resolution_milestones

        # Unit-focal ray directions, one grid per resolution stage.
        self.directions_unit_focals = [
            get_ray_directions(H=height, W=width, focal=1.0)
            for (height, width) in zip(self.heights, self.widths)
        ]
        self.height: int = self.heights[0]
        self.width: int = self.widths[0]
        self.batch_size: int = self.batch_sizes[0]
        self.directions_unit_focal = self.directions_unit_focals[0]
        self.elevation_range = self.cfg.elevation_range
        self.azimuth_range = self.cfg.azimuth_range
        self.camera_distance_range = self.cfg.camera_distance_range
        self.fovy_range = self.cfg.fovy_range

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        """Activate the resolution stage whose milestone ``global_step`` has reached."""
        size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
        self.height = self.heights[size_ind]
        self.width = self.widths[size_ind]
        self.batch_size = self.batch_sizes[size_ind]
        self.directions_unit_focal = self.directions_unit_focals[size_ind]
        threestudio.debug(
            f"Training height: {self.height}, width: {self.width}, batch_size: {self.batch_size}"
        )
        # progressive view
        self.progressive_view(global_step)

    def __iter__(self):
        # Endless placeholders; real samples are produced by ``collate``.
        while True:
            yield {}

    def progressive_view(self, global_step):
        """No-op hook for progressive view ranges.

        The progressive schedule is currently disabled. That schedule would
        interpolate elevation/azimuth/distance/fovy ranges from the eval_*
        settings toward the configured ranges until ``progressive_until``.
        Because it is disabled, ``progressive_until`` has no effect.
        """
        pass

    def _sample_elevation(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(elevation_deg, elevation_rad)``, each of shape ``[batch_size]``."""
        if random.random() < 0.5:
            # With probability 0.5 sample uniformly in degrees (biased towards poles).
            elevation_deg = (
                torch.rand(self.batch_size)
                * (self.elevation_range[1] - self.elevation_range[0])
                + self.elevation_range[0]
            )
            elevation = elevation_deg * math.pi / 180
        else:
            # Otherwise sample uniformly on the sphere within elevation_range.
            # For area-uniform sampling sin(elevation) must be uniform, so
            # draw it uniformly between the sines of the range bounds
            # (inverse transform sampling).
            sin_lo = math.sin(self.elevation_range[0] * math.pi / 180)
            sin_hi = math.sin(self.elevation_range[1] * math.pi / 180)
            elevation = torch.asin(
                (torch.rand(self.batch_size) * (sin_hi - sin_lo) + sin_lo).clamp(
                    -1.0, 1.0
                )
            )
            elevation_deg = elevation / math.pi * 180.0
        return elevation_deg, elevation

    def _sample_azimuth_deg(self) -> torch.Tensor:
        """Sample ``batch_size`` azimuth angles (degrees) within azimuth_range."""
        span = self.azimuth_range[1] - self.azimuth_range[0]
        if self.cfg.batch_uniform_azimuth:
            # Stratified: one sample per equal-width bin, so the batch covers the whole range.
            return (
                torch.rand(self.batch_size) + torch.arange(self.batch_size)
            ) / self.batch_size * span + self.azimuth_range[0]
        # simple random sampling
        return torch.rand(self.batch_size) * span + self.azimuth_range[0]

    def _sample_light_positions(self, camera_positions, light_distances):
        """Sample one light position per camera according to light_sample_strategy."""
        strategy = self.cfg.light_sample_strategy
        if strategy in ("dreamfusion", "dreamfusion3dgs"):
            # Direction ~ normalize(camera_position + N(0, light_position_perturb^2)).
            light_direction = F.normalize(
                camera_positions
                + torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb,
                dim=-1,
            )
            # Light position = light direction scaled by light distance.
            return light_direction * light_distances[:, None]
        if strategy == "magic3d":
            # Local frame with z along the camera direction; the light's
            # elevation in that frame is restricted to [pi/6, pi/2].
            local_z = F.normalize(camera_positions, dim=-1)
            local_x = F.normalize(
                torch.stack(
                    [local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])],
                    dim=-1,
                ),
                dim=-1,
            )
            local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1)
            rot = torch.stack([local_x, local_y, local_z], dim=-1)
            light_azimuth = (
                torch.rand(self.batch_size) * math.pi * 2 - math.pi
            )  # [-pi, pi]
            light_elevation = (
                torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6
            )  # [pi/6, pi/2]
            light_positions_local = torch.stack(
                [
                    light_distances
                    * torch.cos(light_elevation)
                    * torch.cos(light_azimuth),
                    light_distances
                    * torch.cos(light_elevation)
                    * torch.sin(light_azimuth),
                    light_distances * torch.sin(light_elevation),
                ],
                dim=-1,
            )
            return (rot @ light_positions_local[:, :, None])[:, :, 0]
        raise ValueError(
            f"Unknown light sample strategy: {self.cfg.light_sample_strategy}"
        )

    @staticmethod
    def _look_at_c2w(camera_positions, center, up):
        """Camera-to-world matrices ``[B, 4, 4]`` with columns (right, up, -lookat, position).

        The camera looks along its local -z axis towards ``center``.
        """
        lookat = F.normalize(center - camera_positions, dim=-1)
        # dim=-1 is required: without it torch.cross uses the first dimension
        # of size 3, which is the batch dimension whenever B == 3.
        right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4 = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
        c2w[:, 3, 3] = 1.0
        return c2w

    @staticmethod
    def _c2w_3dgs(azimuth_deg, elevation_deg, camera_distances):
        """Per-view 4x4 matrices for the Gaussian-splatting renderer.

        For each view this takes three steps:
        1. build a ``pose_spherical`` pose (azimuth + 180 deg, negated elevation);
        2. invert it;
        3. assemble [R | T] from the negated transposed rotation (with its
           first column flipped back) and the negated inverse translation.
        """
        poses = []
        for azi, ele, dist in zip(azimuth_deg, elevation_deg, camera_distances):
            render_pose = pose_spherical(azi + 180.0, -ele, dist)
            inv_pose = torch.linalg.inv(render_pose)
            rot = -inv_pose[:3, :3].transpose(0, 1)
            rot[:, 0] = -rot[:, 0]
            pose = torch.eye(4, dtype=inv_pose.dtype)
            pose[:3, :3] = rot
            pose[:3, 3] = -inv_pose[:3, 3]
            poses.append(pose)
        return torch.stack(poses, 0)

    def collate(self, batch) -> Dict[str, Any]:
        """Sample one batch of ``self.batch_size`` random cameras and lights.

        ``batch`` (the placeholder from ``__iter__``) is ignored. In the
        returned dict "elevation"/"azimuth" are in degrees and "fovy" is in
        radians.
        """
        elevation_deg, elevation = self._sample_elevation()
        azimuth_deg = self._sample_azimuth_deg()
        azimuth = azimuth_deg * math.pi / 180

        # sample distances uniformly within camera_distance_range
        camera_distances: Float[Tensor, "B"] = (
            torch.rand(self.batch_size)
            * (self.camera_distance_range[1] - self.camera_distance_range[0])
            + self.camera_distance_range[0]
        )

        # Spherical -> cartesian: right-handed, x back, y right, z up;
        # elevation in (-90, 90), azimuth from +x towards +y in (-180, 180).
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )
        # default scene center at the origin, default up direction +z
        center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
        up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None, :
        ].repeat(self.batch_size, 1)

        # Perturbations:
        # - camera ~ U[-camera_perturb, camera_perturb]
        # - center ~ N(0, center_perturb^2)
        # - up ~ N(0, up_perturb^2)
        camera_positions = camera_positions + (
            torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb
            - self.cfg.camera_perturb
        )
        center = center + torch.randn(self.batch_size, 3) * self.cfg.center_perturb
        up = up + torch.randn(self.batch_size, 3) * self.cfg.up_perturb

        # sample fovy uniformly within fovy_range (degrees)
        fovy_deg: Float[Tensor, "B"] = (
            torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0])
            + self.fovy_range[0]
        )
        fovy = fovy_deg * math.pi / 180

        # sample light distances uniformly within light_distance_range
        light_distances: Float[Tensor, "B"] = (
            torch.rand(self.batch_size)
            * (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0])
            + self.cfg.light_distance_range[0]
        )
        light_positions = self._sample_light_positions(camera_positions, light_distances)

        c2w = self._look_at_c2w(camera_positions, center, up)
        proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
            fovy, self.width / self.height, 0.1, 1000.0
        )  # FIXME: hard-coded near and far
        mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, proj_mtx)
        c2w_3dgs = self._c2w_3dgs(azimuth_deg, elevation_deg, camera_distances)

        return {
            "mvp_mtx": mvp_mtx,
            "camera_positions": camera_positions,
            "c2w": c2w,
            "c2w_3dgs": c2w_3dgs,
            "light_positions": light_positions,
            "elevation": elevation_deg,
            "azimuth": azimuth_deg,
            "camera_distances": camera_distances,
            "height": self.height,
            "width": self.width,
            "fovy": fovy,
        }
class RandomCameraDataset(Dataset):
    """Fixed orbit of evaluation cameras for the "val" and "test" splits.

    Every view uses eval_elevation_deg, eval_camera_distance and eval_fovy_deg,
    and views are spaced evenly in azimuth over [-180, 180] degrees. All
    tensors are precomputed in ``__init__``; ``__getitem__`` only indexes them.
    """

    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.cfg: RandomCameraDataModuleConfig = cfg
        self.split = split
        if split == "val":
            self.n_views = self.cfg.n_val_views
        else:
            self.n_views = self.cfg.n_test_views

        azimuth_deg: Float[Tensor, "B"]
        if self.split == "val":
            # make sure the first and last view are not the same
            azimuth_deg = torch.linspace(-180.0, 180.0, self.n_views + 1)[: self.n_views]
        else:
            # both endpoints are included, so the first and last views coincide
            azimuth_deg = torch.linspace(-180.0, 180.0, self.n_views)
        elevation_deg: Float[Tensor, "B"] = torch.full_like(
            azimuth_deg, self.cfg.eval_elevation_deg
        )
        camera_distances: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_camera_distance
        )
        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180

        # Spherical -> cartesian: right-handed, x back, y right, z up;
        # elevation in (-90, 90), azimuth from +x towards +y in (-180, 180).
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )
        # default scene center at the origin
        center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
        # Default up direction +z, one row per view so the shape matches
        # camera_positions. Sizing this by eval_batch_size would only work
        # through broadcasting when that value is 1.
        up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None, :
        ].repeat(self.n_views, 1)
        fovy_deg: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_fovy_deg
        )
        fovy = fovy_deg * math.pi / 180
        # lights are placed at the camera positions
        light_positions: Float[Tensor, "B 3"] = camera_positions

        lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        # dim=-1 is required: without it torch.cross uses the first dimension
        # of size 3, which is the view dimension whenever n_views == 3.
        right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w: Float[Tensor, "B 4 4"] = torch.cat(
            [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
        )
        c2w[:, 3, 3] = 1.0

        proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
            fovy, self.cfg.eval_width / self.cfg.eval_height, 0.1, 1000.0
        )  # FIXME: hard-coded near and far
        mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, proj_mtx)

        # Per-view matrices for the Gaussian-splatting renderer: invert a
        # pose_spherical pose (azimuth + 180 deg, negated elevation), then
        # assemble [R | T] from the negated transposed rotation (first column
        # flipped back) and the negated inverse translation.
        c2w_3dgs = []
        for azi, ele, dist in zip(azimuth_deg, elevation_deg, camera_distances):
            render_pose = pose_spherical(azi + 180.0, -ele, dist)
            inv_pose = torch.linalg.inv(render_pose)
            rot = -inv_pose[:3, :3].transpose(0, 1)
            rot[:, 0] = -rot[:, 0]
            pose = torch.eye(4, dtype=inv_pose.dtype)
            pose[:3, :3] = rot
            pose[:3, 3] = -inv_pose[:3, 3]
            c2w_3dgs.append(pose)
        c2w_3dgs = torch.stack(c2w_3dgs, 0)

        self.mvp_mtx = mvp_mtx
        self.c2w = c2w
        self.c2w_3dgs = c2w_3dgs
        self.camera_positions = camera_positions
        self.light_positions = light_positions
        self.elevation, self.azimuth = elevation, azimuth
        self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
        self.camera_distances = camera_distances
        self.fovy = fovy

    def __len__(self):
        return self.n_views

    def __getitem__(self, index):
        """Return the precomputed camera data for view ``index`` (angles in degrees, fovy in radians)."""
        return {
            "index": index,
            "mvp_mtx": self.mvp_mtx[index],
            "c2w": self.c2w[index],
            "c2w_3dgs": self.c2w_3dgs[index],
            "camera_positions": self.camera_positions[index],
            "light_positions": self.light_positions[index],
            "elevation": self.elevation_deg[index],
            "azimuth": self.azimuth_deg[index],
            "camera_distances": self.camera_distances[index],
            "height": self.cfg.eval_height,
            "width": self.cfg.eval_width,
            "fovy": self.fovy[index],
        }

    def collate(self, batch):
        """Default-collate ``batch``, then restore plain-int height/width."""
        batch = torch.utils.data.default_collate(batch)
        batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
        return batch
@register("random-camera-datamodule")
class RandomCameraDataModule(pl.LightningDataModule):
    """Lightning data module that serves the random-camera datasets.

    Training uses ``RandomCameraIterableDataset`` with ``batch_size=None``,
    because its ``collate`` produces whole batches. Validation, test and
    predict use ``RandomCameraDataset`` with a batch size of 1.
    """

    cfg: RandomCameraDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(RandomCameraDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        """Create the datasets needed for ``stage`` (all of them when ``stage`` is None)."""
        needs_train = stage in (None, "fit")
        needs_val = stage in (None, "fit", "validate")
        needs_test = stage in (None, "test", "predict")
        if needs_train:
            self.train_dataset = RandomCameraIterableDataset(self.cfg)
        if needs_val:
            self.val_dataset = RandomCameraDataset(self.cfg, "val")
        if needs_test:
            self.test_dataset = RandomCameraDataset(self.cfg, "test")

    def prepare_data(self):
        # nothing to download or preprocess
        pass

    def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
        """Wrap ``dataset`` in a single-process DataLoader."""
        # Multi-processing must stay disabled so that dataset attributes changed
        # at runtime (e.g. width/height set in update_step) take effect.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            num_workers=0,  # type: ignore
        )

    def _eval_loader(self, dataset) -> DataLoader:
        # One view per batch, collated by the dataset's own collate().
        return self.general_loader(dataset, batch_size=1, collate_fn=dataset.collate)

    def train_dataloader(self) -> DataLoader:
        train_set = self.train_dataset
        return self.general_loader(train_set, batch_size=None, collate_fn=train_set.collate)

    def val_dataloader(self) -> DataLoader:
        return self._eval_loader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        return self._eval_loader(self.test_dataset)

    def predict_dataloader(self) -> DataLoader:
        return self._eval_loader(self.test_dataset)