Spaces:
Runtime error
Runtime error
File size: 482 Bytes
7900c16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch.nn as nn
class SegEmbedding(nn.Module):
    """
    BERT Segment Embedding.

    Maps every integer segment id to a learned vector. The lookup table has
    exactly three rows, so valid segment ids are 0, 1 and 2.

    NOTE(review): presumably 0 = padding, 1 = first segment and
    2 = second segment (the usual BERT convention) — confirm against the
    data pipeline that produces ``seg``.
    """

    def __init__(self, args, _):
        """
        Args:
            args: configuration object; only ``args.emb_size`` is read. It
                sets the width of each segment vector.
            _: ignored. Presumably accepted so that all embedding modules
                share one constructor signature — verify with callers.
        """
        super().__init__()
        segment_count = 3  # number of distinct segment ids (rows in the table)
        self.embedding = nn.Embedding(segment_count, args.emb_size)

    def forward(self, _, seg):
        """
        Args:
            _: ignored (presumably the token ids supplied by the caller —
                TODO confirm).
            seg: integer tensor of segment ids in {0, 1, 2}, typically
                [batch_size x seq_length].
        Returns:
            Tensor of shape ``seg.shape + (emb_size,)``, e.g.
            [batch_size x seq_length x emb_size].
        """
        return self.embedding(seg)
|