|
from transformers import PreTrainedModel, modeling_outputs |
|
from torch import nn |
|
import torch |
|
from .configuration_word2vec import PretrainedWord2VecHFConfig |
|
|
|
class PretrainedWord2VecHFModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapping a single word2vec lookup table.

    The model owns one ``nn.Embedding``; ``forward`` looks up the vectors for
    the given indices and returns them as ``last_hidden_state`` of a
    ``BaseModelOutput``.
    """

    config_class = PretrainedWord2VecHFConfig

    def __init__(self, config):
        """Create an embedding table sized from ``config``.

        Args:
            config: A ``PretrainedWord2VecHFConfig``; ``config.num_words`` is the
                number of rows and ``config.vector_size`` the vector width.
        """
        super().__init__(config)
        self.embeddings = nn.Embedding(config.num_words, config.vector_size)
        # NOTE(review): recent ``transformers`` releases expect ``self.post_init()``
        # at the end of ``__init__``; confirm against the pinned version.

    def set_embeddings(self, embeddings, freeze=True):
        """Replace the embedding table with pretrained vectors.

        Args:
            embeddings: 2-D array-like of shape ``(num_words, vector_size)``
                (nested lists, a NumPy array, or a ``torch.Tensor``). The data
                is copied, so later changes to ``embeddings`` do not affect
                the model.
            freeze: Passed to ``nn.Embedding.from_pretrained``; when True (its
                default) the new weights receive no gradient updates.
        """
        device = self.embeddings.weight.device
        if isinstance(embeddings, torch.Tensor):
            # torch.tensor() on an existing tensor warns; an explicit detached
            # copy keeps the same "independent copy" semantics without it.
            weights = embeddings.detach().to(device=device, copy=True)
        else:
            weights = torch.tensor(embeddings, device=device)
        # Build on the model's current device so a model already moved with
        # ``.to(...)`` does not end up with a CPU-only table.
        self.embeddings = nn.Embedding.from_pretrained(weights, freeze=freeze)
        # Keep the config consistent with the real table shape so that a
        # save_pretrained / from_pretrained round trip rebuilds a matching
        # nn.Embedding in __init__.
        num_words, vector_size = self.embeddings.weight.shape
        self.config.num_words = num_words
        self.config.vector_size = vector_size

    def forward(self, input_ids, **kwargs):
        """Look up the embedding vectors for ``input_ids``.

        Args:
            input_ids: Integer indices as a ``torch.Tensor``, or anything
                ``torch.as_tensor`` accepts (e.g. a list or nested lists).
            **kwargs: Accepted and ignored (e.g. an ``attention_mask`` coming
                from tokenizer output).

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` has shape
            ``input_ids.shape + (vector_size,)``.
        """
        # ``torch.tensor`` is a factory function, not a type: the type check
        # must use ``torch.Tensor``, otherwise tensors get needlessly copied.
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.as_tensor(
                input_ids, device=self.embeddings.weight.device
            )
        hidden = self.embeddings(input_ids)
        return modeling_outputs.BaseModelOutput(last_hidden_state=hidden)
|
|