File size: 898 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import pytest
from ding.model.template.language_transformer import LanguageTransformer
@pytest.mark.unittest
class TestNLPPretrainedModel:
    """Shape checks for ``LanguageTransformer`` when scoring contexts against candidates."""

    def check_model(self):
        """Build the model with and without ``add_linear`` and check the score shape.

        One context sentence (``test_pids``) is scored against three candidate
        sentences (``cand_pids``), so every call must produce a result whose
        ``shape`` equals ``(1, 3)``.

        Raises:
            AssertionError: If either configuration returns a different shape.
        """
        test_pids = [1]
        cand_pids = [0, 2, 4]
        problems = [
            "This is problem 0", "This is the first question", "Second problem is here", "Another problem",
            "This is the last problem"
        ]
        ctxt_list = [problems[pid] for pid in test_pids]
        cands_list = [problems[pid] for pid in cand_pids]
        # Derived from the inputs so the expectation tracks the fixture: (1, 3) here.
        expected_shape = (len(ctxt_list), len(cands_list))

        # NOTE(review): "bert-base-uncased" presumably resolves to pretrained
        # weights, which may need network or a local cache -- confirm for CI.
        for add_linear in (True, False):
            model = LanguageTransformer(model_name="bert-base-uncased", add_linear=add_linear, embedding_size=256)
            scores = model(ctxt_list, cands_list)
            assert scores.shape == expected_shape, (
                f"add_linear={add_linear}: got shape {tuple(scores.shape)}, expected {expected_shape}"
            )

    def test_language_transformer(self):
        """Run ``check_model`` under a name that pytest collects.

        With pytest's default naming rules only ``test``-prefixed methods are
        collected, so ``check_model`` alone would never be executed.
        """
        self.check_model()
|