File size: 2,651 Bytes
b213d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pickle
import torch
from torch import nn
from detectron2.utils.file_io import PathManager
from .utils import normalize_embeddings
class VertexFeatureEmbedder(nn.Module):
    """
    Class responsible for embedding vertex features. Mapping from
    feature space to the embedding space is a tensor of size [K, D], where
        K = number of dimensions in the feature space
        D = number of dimensions in the embedding space
    Vertex features is a tensor of size [N, K], where
        N = number of vertices
        K = number of dimensions in the feature space
    Vertex embeddings are computed as F * E = tensor of size [N, D]
    """

    def __init__(
        self, num_vertices: int, feature_dim: int, embed_dim: int, train_features: bool = False
    ):
        """
        Initialize embedder. Vertex features ([N, K]) and the embedding
        matrix ([K, D]) are allocated and then zero-filled by
        `reset_parameters`; real values are expected to come from `load`
        or from training.

        Args:
            num_vertices (int): number of vertices to embed
            feature_dim (int): number of dimensions in the feature space
            embed_dim (int): number of dimensions in the embedding space
            train_features (bool): determines whether vertex features should
                be trained (default: False). When False, features are
                registered as a buffer instead of a parameter, so they get
                no gradients but are still part of the module's state.
        """
        super().__init__()
        # Contents are irrelevant here: reset_parameters() overwrites them below.
        features = torch.empty(num_vertices, feature_dim)
        if train_features:
            self.features = nn.Parameter(features)
        else:
            self.register_buffer("features", features)
        self.embeddings = nn.Parameter(torch.empty(feature_dim, embed_dim))
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Fill vertex features and the embedding matrix with zeros, in place."""
        self.features.zero_()
        self.embeddings.zero_()

    def forward(self) -> torch.Tensor:
        """
        Produce vertex embeddings, a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space

        The product features @ embeddings is passed through
        `normalize_embeddings` (from `.utils`; presumably a per-vertex
        normalization over D — confirm against that helper).

        Return:
            Full vertex embeddings, a tensor of shape [N, D]
        """
        return normalize_embeddings(torch.mm(self.features, self.embeddings))

    @torch.no_grad()
    def load(self, fpath: str):
        """
        Load vertex features and/or the embedding matrix from a pickle file.

        The unpickled object is expected to be dict-like. For each of the
        keys "features" and "embeddings" that is present, the value is
        converted to a float32 tensor on the target tensor's device and
        copied in place into that tensor. Absent keys leave the current
        values unchanged.

        Args:
            fpath (str): file path to load data from
        """
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only load files from trusted sources.
        with PathManager.open(fpath, "rb") as hFile:
            data = pickle.load(hFile)
        for name in ("features", "embeddings"):
            if name not in data:
                continue
            target = getattr(self, name)
            # as_tensor avoids a redundant copy (and torch's copy-construct
            # warning) when the pickled value is already a tensor / ndarray.
            value = torch.as_tensor(data[name], dtype=torch.float32, device=target.device)
            target.copy_(value)
|