# codemix_hate/models.py
import torch
from torch import nn
from transformers import BertTokenizer, BertModel
#CustomMBERTModel= torch.load("/data2/Akash_for_interface/model_mbert_1416.pt")
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
sentences = "you are good person."
max_len = len(sentences)  # character length of the sample sentence, used here as the token limit
encoding = tokenizer.encode_plus(sentences, add_special_tokens=True, max_length=max_len,
                                 padding='max_length', truncation=True, return_tensors='pt')
input_ids = encoding['input_ids'].flatten()
attention_mask = encoding['attention_mask'].flatten()
#print(input_ids[0])
labels = ["Non Hateful", "Hateful"]
device = 'cpu'
class CustomMBERTModel(nn.Module):
    def __init__(self, num_labels):
        super(CustomMBERTModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')
        # Freeze all layers except the top 2
        for param in self.bert.parameters():
            param.requires_grad = False
        # Unfreeze the parameters of the top 2 encoder layers
        for param in self.bert.encoder.layer[-2:].parameters():
            param.requires_grad = True
        # Linear classification head on top of the pooled representation
        self.linear = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        # Use the [CLS] token representation as the pooled sentence embedding
        pooled_output = hidden_states[:, 0, :]
        pooled_output = torch.squeeze(pooled_output, dim=1)
        #print('p-shape:', pooled_output.shape)
        logits = self.linear(pooled_output)
        #print('l-shape:', logits.shape)
        return logits
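

# --- Usage sketch (not part of the uploaded file; a minimal illustration only) ---
# Shows how the classifier defined above could be run on the encoding built at
# module level. The model here has a randomly initialised classification head;
# in practice the fine-tuned checkpoint referenced in the commented-out
# torch.load line would be loaded instead.
if __name__ == "__main__":
    model = CustomMBERTModel(num_labels=len(labels)).to(device)
    model.eval()
    with torch.no_grad():
        # Add a batch dimension, since the tensors above were flattened
        logits = model(input_ids.unsqueeze(0).to(device),
                       attention_mask.unsqueeze(0).to(device))
        prediction = labels[int(torch.argmax(logits, dim=1))]
    print(prediction)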