glove-wiki-gigaword-50 / pretrained_word2vec.py
Iseratho's picture
add model
d4efd8e
raw
history blame
914 Bytes
from transformers import PreTrainedModel, PretrainedConfig
from torch import nn
import torch
class PretrainedWord2VecHFConfig(PretrainedConfig):
    """Configuration describing the size of a static word-embedding table.

    Attributes:
        num_words: Number of rows in the embedding table.
        vector_size: Length of each embedding vector.
        hidden_size: Copy of ``vector_size``; required for sBERT.
    """

    model_type = "glove"

    def __init__(
        self,
        num_words: int = 400001,
        vector_size: int = 50,
        **kwargs,
    ) -> None:
        """Store the table dimensions, then forward ``kwargs`` to the base config.

        Args:
            num_words: Number of rows in the embedding table.
            vector_size: Length of each embedding vector.
            **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
        """
        # The attributes are assigned before the base initialiser runs, so the
        # base class sees them already set when it processes ``kwargs``.
        self.vector_size = vector_size
        self.num_words = num_words
        # Required for sBERT: expose the embedding width as ``hidden_size``.
        self.hidden_size = vector_size
        super().__init__(**kwargs)
class PretrainedWord2VecHFModel(PreTrainedModel):
    """Model consisting of a single ``nn.Embedding`` lookup table.

    The forward pass maps token ids to their embedding vectors and returns
    the raw tensor; no pooling or further layers are applied here.
    """

    config_class = PretrainedWord2VecHFConfig

    def __init__(self, config):
        """Create an embedding table sized from ``config``.

        Args:
            config: Configuration providing ``num_words`` and ``vector_size``.
        """
        super().__init__(config)
        self.embeddings = nn.Embedding(config.num_words, config.vector_size)

    def set_embeddings(self, embeddings):
        """Replace the embedding table with pretrained vectors.

        Args:
            embeddings: 2-D array-like or ``torch.Tensor`` of shape
                ``(num_words, vector_size)``. The data is copied, so later
                changes to the caller's object do not affect the model.

        The new layer is built with ``nn.Embedding.from_pretrained`` using its
        default arguments (PyTorch freezes the weights by default).
        """
        if torch.is_tensor(embeddings):
            # torch.tensor(tensor) emits a copy-construct UserWarning; this is
            # the equivalent copy without the warning.
            weights = embeddings.detach().clone()
        else:
            weights = torch.tensor(embeddings)
        self.embeddings = nn.Embedding.from_pretrained(weights)

        # Keep the config in step with the table that is actually loaded so a
        # saved config describes the saved weights.
        num_words, vector_size = weights.shape
        self.config.num_words = num_words
        self.config.vector_size = vector_size
        self.config.hidden_size = vector_size

    def forward(self, input_ids, **kwargs):
        """Look up the embedding vector for every id in ``input_ids``.

        Args:
            input_ids: Integer token ids, as a ``torch.Tensor`` or anything
                ``torch.as_tensor`` accepts (nested lists, NumPy arrays).
            **kwargs: Accepted and ignored.

        Returns:
            Tensor with the shape of ``input_ids`` plus one trailing dimension
            holding the embedding vector.
        """
        # as_tensor avoids re-copying inputs that are already tensors and puts
        # list/array inputs on the same device as the embedding weights.
        ids = torch.as_tensor(input_ids, device=self.embeddings.weight.device)
        return self.embeddings(ids)