Plonk / models /preprocessing.py
nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
raw
history blame
1.55 kB
import torch
from torch import nn
import numpy as np
class NormGPS(nn.Module):
    """Rescale (latitude, longitude) radians into the [-1, 1] range.

    Column 0 (latitude) is divided by pi/2 and column 1 (longitude) by pi.
    Inputs in [-pi/2, pi/2] x [-pi, pi] therefore land in [-1, 1] x [-1, 1].
    With ``normalize=False`` the input tensor is stored under the output key
    unchanged and no scale buffer is registered.
    """

    def __init__(self, input_key="gps", output_key="x_0", normalize=True):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key
        self.normalize = normalize
        if normalize:
            # Per-column reciprocal scale of shape (1, 2). It broadcasts
            # against the trailing (lat, lon) axis of the input. It is
            # registered as a buffer so it is part of the module's state.
            half_ranges = torch.Tensor([np.pi * 0.5, np.pi])
            self.register_buffer("gps_normalize", (1 / half_ranges).unsqueeze(0))

    def forward(self, batch):
        """Normalize latitude/longitude radians to [-1, 1].

        Per the original author's note, this transform is not used currently.

        Args:
            batch: Mapping holding the coordinate tensor under
                ``self.input_key``. The last axis is assumed to be
                (lat, lon); TODO confirm against callers.

        Returns:
            The same ``batch`` object, with the (optionally rescaled)
            coordinates written to ``batch[self.output_key]``.
        """
        coords = batch[self.input_key]
        batch[self.output_key] = (
            coords * self.gps_normalize if self.normalize else coords
        )
        return batch
class GPStoCartesian(nn.Module):
    """Project (latitude, longitude) radians onto the 3-D unit sphere.

    Reads ``batch[input_key]``, whose last axis holds ``(latitude, longitude)``
    in that order, and writes the corresponding Cartesian ``(x, y, z)`` point
    to ``batch[output_key]``.
    """

    def __init__(self, input_key="gps", output_key="x_0"):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key

    def forward(self, batch):
        """Project latitude/longitude radians to 3D Cartesian coordinates.

        Args:
            batch: Mapping holding a tensor under ``self.input_key`` of shape
                ``(..., 2)`` laid out as ``(lat, lon)``. Any number of
                leading dimensions is accepted, including none (one point).
                Only the first two entries of the last axis are read.

        Returns:
            The same ``batch`` object, mutated in place.
            ``batch[self.output_key]`` holds a tensor of shape ``(..., 3)``
            with ``x = cos(lat) * cos(lon)``, ``y = cos(lat) * sin(lon)`` and
            ``z = sin(lat)``.
        """
        coords = batch[self.input_key]
        # Index the last axis with `...` rather than `[:, i]`. This accepts a
        # single point of shape (2,) or extra leading dims such as (B, N, 2).
        # For the (B, 2) case the result is identical to before.
        lat, lon = coords[..., 0], coords[..., 1]
        cos_lat = lat.cos()  # shared by the x and y components
        batch[self.output_key] = torch.stack(
            [cos_lat * lon.cos(), cos_lat * lon.sin(), lat.sin()], dim=-1
        )
        return batch
class PrecomputedPreconditioning:
    """Pass-through step that exposes an existing batch entry under a new key.

    The value under ``input_key`` is copied by reference to ``output_key``
    without any transformation. With the default keys (both ``"emb"``) the
    batch is left unchanged. The class name suggests the values were
    computed upstream; confirm against the data pipeline.
    """

    def __init__(self, input_key="emb", output_key="emb"):
        self.input_key = input_key
        self.output_key = output_key

    def __call__(self, batch, device=None):
        """Alias ``batch[input_key]`` under ``output_key`` and return ``batch``.

        ``device`` is accepted but ignored. Presumably it exists so this
        step shares a call signature with other preconditioners; verify.
        """
        precomputed = batch[self.input_key]
        batch[self.output_key] = precomputed
        return batch