|
import pytest |
|
|
|
from ding.model.template.language_transformer import LanguageTransformer |
|
|
|
|
|
@pytest.mark.unittest
class TestNLPPretrainedModel:
    """Shape checks for ``LanguageTransformer`` when scoring contexts against candidates."""

    # NOTE(review): under pytest's default naming rules (``test*``) a method
    # called ``check_model`` is not collected automatically -- confirm whether
    # that is intentional (e.g. to keep a pretrained-model load out of CI).
    def check_model(self):
        """Score one context against three candidates, with and without ``add_linear``.

        For each ``add_linear`` setting a fresh ``LanguageTransformer`` is built
        with ``model_name="bert-base-uncased"`` and ``embedding_size=256``; the
        scores it returns must have shape ``(1, 3)``, matching the one context
        and three candidates passed in.
        """
        sentences = [
            "This is problem 0",
            "This is the first question",
            "Second problem is here",
            "Another problem",
            "This is the last problem",
        ]
        context_ids = [1]
        candidate_ids = [0, 2, 4]
        contexts = [sentences[idx] for idx in context_ids]
        candidates = [sentences[idx] for idx in candidate_ids]

        # Same order as before: the variant with the linear layer first, then without.
        for use_linear in (True, False):
            net = LanguageTransformer(model_name="bert-base-uncased", add_linear=use_linear, embedding_size=256)
            result = net(contexts, candidates)
            assert result.shape == (1, 3)
|
|