File size: 1,885 Bytes
89650c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import os
import sys
from transformers import PretrainedConfig, PreTrainedModel
#sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from ultra.models import Ultra
from ultra.datasets import WN18RR, CoDExSmall, FB15k237, FB15k237Inductive
from ultra.eval import test
class UltraConfig(PretrainedConfig):
    """Configuration for the ULTRA link-prediction model.

    Builds the keyword-argument dicts for ULTRA's relation model and entity
    model from a layer count and a width for each. Both sub-models share the
    same settings: ``distmult`` messages, ``sum`` aggregation, shortcut
    connections and layer normalisation. Every hidden layer has the same
    width as the input.

    Args:
        relation_model_layers: Number of hidden layers in the relation model.
        relation_model_dim: Input/hidden width of the relation model.
        entity_model_layers: Number of hidden layers in the entity model.
        entity_model_dim: Input/hidden width of the entity model.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "ultra"

    def __init__(
        self,
        relation_model_layers: int = 6,
        relation_model_dim: int = 64,
        entity_model_layers: int = 6,
        entity_model_dim: int = 64,
        **kwargs,
    ):
        def gnn_cfg(width, depth):
            # The two sub-models differ only in width and depth.
            return {
                "input_dim": width,
                "hidden_dims": [width] * depth,
                "message_func": "distmult",
                "aggregate_func": "sum",
                "short_cut": True,
                "layer_norm": True,
            }

        self.relation_model_cfg = gnn_cfg(relation_model_dim, relation_model_layers)
        self.entity_model_cfg = gnn_cfg(entity_model_dim, entity_model_layers)
        # NOTE(review): the base-class init runs last. Presumably this lets
        # saved ``*_model_cfg`` kwargs (e.g. from ``from_pretrained``)
        # overwrite the dicts built above — confirm against the installed
        # transformers version.
        super().__init__(**kwargs)
class UltraLinkPrediction(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes ULTRA through transformers.

    Wrapping ``Ultra`` this way lets a checkpoint be loaded with
    ``UltraLinkPrediction.from_pretrained(...)`` using an ``UltraConfig``.
    """

    config_class = UltraConfig

    def __init__(self, config):
        """Build the wrapped ``Ultra`` model from *config*.

        Args:
            config: An ``UltraConfig`` providing ``relation_model_cfg`` and
                ``entity_model_cfg``.
        """
        super().__init__(config)
        self.model = Ultra(
            rel_model_cfg=config.relation_model_cfg,
            entity_model_cfg=config.entity_model_cfg,
        )

    def forward(self, data, batch):
        """Run the wrapped ULTRA model on a batch of query triples.

        Args:
            data: PyG data object.
            batch: Query triples, shape ``(bs, 1 + num_negs, 3)``.

        Returns:
            Whatever ``Ultra`` returns for ``(data, batch)``.
        """
        # Call the module instead of ``.forward`` so that nn.Module.__call__
        # runs any registered forward hooks on the wrapped model.
        return self.model(data, batch)
if __name__ == "__main__":
    # Evaluate the pretrained ULTRA checkpoint on the CoDEx-Small test split.
    pretrained = UltraLinkPrediction.from_pretrained("mgalkin/ultra_50g")
    codex_small = CoDExSmall(root="./datasets/")
    test(pretrained, mode="test", dataset=codex_small, gpus=None)
    # Reference results recorded for this run:
    #   mrr: 0.497697
    #   hits@10: 0.685175