# Spaces: Sleeping / Sleeping  (page-header residue captured with the file; not part of the module)
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoModel, AutoTokenizer

from utils import read_yaml
class BanglaHSDataset(Dataset):
    """Convert a raw text string into tensor inputs on demand.

    Unlike a conventional map-style dataset, items are looked up by the text
    itself (``ds['some sentence']``) rather than by an integer position, and
    ``len(ds)`` is always 0.  Iterating it through a length-based sampler
    would therefore yield nothing; it is meant to be indexed directly.
    """

    def __init__(self, tokenizer, max_length):
        """Store the tokenizer and the fixed sequence length.

        Args:
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True,
                return_offsets_mapping=False)``.  It must return a mutable
                mapping whose values ``torch.tensor`` can convert to int64.
                Presumably a Hugging Face tokenizer -- confirm with caller.
            max_length: Value forwarded to the tokenizer as ``max_length``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return 0; no samples are stored, they are built from given text."""
        return 0

    def __getitem__(self, text):
        """Tokenize ``text`` and return ``(inputs, label)``.

        Args:
            text: The input passed straight to the tokenizer.

        Returns:
            A tuple ``(inputs, label)`` where ``inputs`` is the tokenizer's
            own output mapping with every value replaced by an int64 tensor
            carrying an extra leading dimension of size 1, and ``label`` is a
            constant float32 scalar tensor equal to 0 (a placeholder; no real
            label is available here).
        """
        inputs = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=False,
        )
        # Snapshot the items first so values are reassigned on the mapping
        # without iterating a live view of it.
        for key, value in list(inputs.items()):
            # unsqueeze adds a leading axis of size 1 (presumably the batch
            # axis expected by the model -- confirm against the model's forward).
            inputs[key] = torch.tensor(value, dtype=torch.long).unsqueeze(dim=0)
        label = torch.tensor(0, dtype=torch.float)
        return inputs, label
def get_class(index):
    """Return the category name stored at position ``index``.

    The table holds four names, in this order: 'Geopolitical', 'Personal',
    'Political', 'Religious'.

    Args:
        index: Position in the table.  Standard Python list indexing applies,
            so negative values count from the end.

    Returns:
        The category name as a string.

    Raises:
        IndexError: If ``index`` falls outside the four-entry table.
    """
    ind2cat = [
        'Geopolitical',
        'Personal',
        'Political',
        'Religious',
    ]
    return ind2cat[index]
if __name__ == '__main__':
    # Load the configuration from a path relative to the working directory.
    cfg = read_yaml('./baseline.yaml')

    # Disabled smoke test, kept for reference.
    # NOTE(review): before re-enabling, reconcile it with the visible code:
    #   - BanglaHS_Model is not defined or imported in the visible part of
    #     this file.
    #   - BanglaHSDataset.__init__ takes (tokenizer, max_length), but the call
    #     below passes (cfg.Dataset, model).
    #   - target_size is set to 6 while get_class() lists four categories;
    #     confirm which is intended.
    # cfg.Model.target_size = 6
    # model = BanglaHS_Model(cfg.Model)
    # #model.load_state_dict(torch.load('./model_fold-0_best.pt', map_location=torch.device('cpu')))
    # model.eval()
    # ds = BanglaHSDataset(cfg.Dataset, model)
    # x = ds['Hello hi'][0]
    # with torch.no_grad():
    #     y = model(x)
    # print('y:', y)