File size: 898 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import pytest

from ding.model.template.language_transformer import LanguageTransformer


@pytest.mark.unittest
class TestNLPPretrainedModel:
    """Output-shape checks for ``LanguageTransformer`` called with context and candidate sentences."""

    def check_model(self):
        """Call the model with one context sentence and three candidates and check the score shape.

        The same check runs twice, first with ``add_linear=True`` and then with
        ``add_linear=False``; in both cases the returned scores must have
        shape ``(1, 3)``.

        NOTE(review): this method's name does not start with ``test``, so
        pytest's default collection rules will not run it -- confirm that this
        is intentional.
        """
        problems = [
            "This is problem 0", "This is the first question", "Second problem is here", "Another problem",
            "This is the last problem"
        ]
        # Index 1 supplies the single context; indices 0, 2 and 4 supply the candidates.
        context_ids = [1]
        candidate_ids = [0, 2, 4]
        contexts = [problems[idx] for idx in context_ids]
        candidates = [problems[idx] for idx in candidate_ids]

        # Same order as before: the variant with the linear layer first, then without.
        for use_linear in (True, False):
            model = LanguageTransformer(model_name="bert-base-uncased", add_linear=use_linear, embedding_size=256)
            scores = model(contexts, candidates)
            assert scores.shape == (1, 3)