|
""" |
|
© Battelle Memorial Institute 2023 |
|
Made available under the GNU General Public License v 2.0 |
|
|
|
BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY |
|
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN |
|
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES |
|
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED |
|
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF |
|
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS |
|
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE |
|
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, |
|
REPAIR OR CORRECTION. |
|
""" |
|
|
|
import torch |
|
from transformers import PreTrainedModel |
|
|
|
from .fup_bert_config import FupBERTConfig |
|
from .fup_bert_model import FupBERTModel |
|
|
|
|
|
class FupBERT(PreTrainedModel):
    """Hugging Face wrapper around ``FupBERTModel``.

    The actual network lives in ``self.model``; this class only adapts it to
    the ``transformers`` ``PreTrainedModel`` interface so it can be built from
    a ``FupBERTConfig``.
    """

    config_class = FupBERTConfig

    def __init__(self, config):
        """Build the wrapped ``FupBERTModel`` from the fields of *config*.

        Args:
            config: a ``FupBERTConfig``; each hyper-parameter below is read
                from it by name and forwarded unchanged.
        """
        super().__init__(config)
        self.model = FupBERTModel(ntoken=config.ntoken,
                                  ninp=config.ninp,
                                  nhead=config.nhead,
                                  nhid=config.nhid,
                                  nlayers=config.nlayers,
                                  token_reduction=config.token_reduction,
                                  padding_idx=config.padding_idx,
                                  cls_idx=config.cls_idx,
                                  edge_idx=config.edge_idx,
                                  num_out=config.num_out,
                                  dropout=config.dropout,
                                  )

    def forward(self, src):
        """Run the wrapped model on *src* and return its output unchanged.

        NOTE(review): *src* is presumably a tensor of token ids — confirm the
        expected shape/dtype against ``FupBERTModel.forward``.
        """
        return self.model(src)

    def load_params(self, pt_file, map_location="cpu", strict=True):
        """Load a ``torch.save``-d state dict into the wrapped model.

        The state dict is applied to ``self.model`` (not to ``self``), so its
        keys must match the inner ``FupBERTModel``'s keys, without a
        ``model.`` prefix.

        Args:
            pt_file: path or file-like object accepted by ``torch.load``.
            map_location: where ``torch.load`` materializes the tensors.
                Defaults to ``"cpu"`` so that checkpoints saved from GPU
                tensors also load on CPU-only hosts. ``load_state_dict``
                copies the values into the existing parameters on their own
                device, so the loaded weights are the same either way.
            strict: forwarded to ``torch.nn.Module.load_state_dict``.

        Returns:
            The ``load_state_dict`` result (``missing_keys`` /
            ``unexpected_keys``).
        """
        # SECURITY: torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources. On torch >= 1.13, consider passing
        # weights_only=True for pure tensor state dicts.
        state_dict = torch.load(pt_file, map_location=map_location)
        return self.model.load_state_dict(state_dict, strict=strict)
|
|