|
from transformers import PreTrainedModel, PretrainedConfig |
|
from torch import nn |
|
import torch |
|
|
|
class PretrainedWord2VecHFConfig(PretrainedConfig):
    """Configuration describing the shape of a static word-embedding table.

    Attributes:
        num_words: Number of rows (vocabulary entries) in the table.
        vector_size: Dimensionality of each word vector.
        hidden_size: Set to the same value as ``vector_size`` unless a
            ``hidden_size`` keyword argument is forwarded to the base class.
    """

    model_type = "glove"

    def __init__(self, num_words=400001, vector_size=50, **kwargs):
        """Store the table dimensions and forward everything else to the base.

        Args:
            num_words: Vocabulary size of the embedding table.
            vector_size: Length of every embedding vector.
            **kwargs: Passed through unchanged to ``PretrainedConfig``.
        """
        self.num_words = num_words
        # ``hidden_size`` mirrors ``vector_size``; both are assigned before
        # the base initializer runs, matching the original statement order.
        self.vector_size = self.hidden_size = vector_size
        super().__init__(**kwargs)
|
|
|
class PretrainedWord2VecHFModel(PreTrainedModel):
    """Model that maps token ids to vectors through one embedding table.

    The table is created from the sizes in the config and can later be
    replaced with pre-computed vectors via :meth:`set_embeddings`.
    """

    config_class = PretrainedWord2VecHFConfig

    def __init__(self, config):
        """Build an embedding table of ``config.num_words`` x ``config.vector_size``.

        Args:
            config: Configuration providing ``num_words`` and ``vector_size``.
        """
        super().__init__(config)
        self.embeddings = nn.Embedding(config.num_words, config.vector_size)

    def set_embeddings(self, embeddings):
        """Replace the embedding table with the given pre-computed vectors.

        Args:
            embeddings: 2-D collection of vectors (tensor, array or nested
                sequence) with one row per word. The data is copied, so the
                caller's object is never shared with the model.

        Note:
            ``nn.Embedding.from_pretrained`` is called with its defaults,
            which (per the PyTorch documentation) freeze the new weights.
        """
        if isinstance(embeddings, torch.Tensor):
            # torch.tensor(tensor) emits a copy-construct UserWarning; this
            # is the documented equivalent without the warning.
            weights = embeddings.detach().clone()
        else:
            weights = torch.tensor(embeddings)
        self.embeddings = nn.Embedding.from_pretrained(weights)

        # Keep the config in sync with the new table so that a saved config
        # describes the same shape as the saved weights.
        self.config.num_words = self.embeddings.num_embeddings
        self.config.vector_size = self.embeddings.embedding_dim
        self.config.hidden_size = self.embeddings.embedding_dim

    def forward(self, input_ids, **kwargs):
        """Look up the embedding vector for every id in ``input_ids``.

        Args:
            input_ids: Integer token ids as a tensor, array or nested
                sequence.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of embeddings with one trailing dimension of size
            ``embedding_dim`` appended to the shape of ``input_ids``.
        """
        # as_tensor avoids the copy (and UserWarning) that torch.tensor()
        # produces for tensor input, and places ids built from lists/arrays
        # on the same device as the embedding weights.
        ids = torch.as_tensor(input_ids, device=self.embeddings.weight.device)
        return self.embeddings(ids)
|
|