# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
from typing import List

import torch
from transformers import PreTrainedTokenizer

from dataloaders.item_encoder import ItemEncoder
from dataloaders.item_decoder import ItemDecoder
from .model import BagualuIEModel


class BagualuIEExtractModel(object):
    """Run a ``BagualuIEModel`` over a batch of items and decode its predictions.

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer handed to both the
            ``ItemEncoder`` and the ``ItemDecoder``.
        args (argparse.Namespace): runtime arguments; ``args.max_length`` is
            passed to the encoder and the whole namespace to the decoder.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, args: argparse.Namespace) -> None:
        self.encoder = ItemEncoder(tokenizer, args.max_length)
        self.decoder = ItemDecoder(tokenizer, args)

    def extract(self, batch_data: List[dict], model: BagualuIEModel, use_cuda: bool) -> List[dict]:
        """Predict entities and SPO triples for a batch of items.

        Side effects on ``model``: it is switched to eval mode, and when
        ``use_cuda`` is true it is moved to the GPU via ``.cuda()``.

        Args:
            batch_data (List[dict]): input items; each one is encoded with
                ``ItemEncoder.encode_item(..., with_label=False)`` and later
                passed to ``ItemDecoder.decode``. The caller's items are not
                modified.
            model (BagualuIEModel): model whose forward output is treated as
                span logits.
            use_cuda (bool): if true, run the model and the collated batch on
                the GPU.

        Returns:
            List[dict]: deep copies of ``batch_data``, each with an
            ``"spo_list"`` and an ``"entity_list"`` key added from the decoder.
            An empty input yields an empty list.
        """
        if use_cuda:
            model = model.cuda()
        model.eval()

        # Work on a copy so the predictions written back below do not mutate
        # the caller's items.
        batch_data = copy.deepcopy(batch_data)
        if not batch_data:
            # Nothing to encode or decode; don't hand collate()/the model an
            # empty batch.
            return batch_data

        encoded = [self.encoder.encode_item(item, with_label=False) for item in batch_data]
        batch = self.encoder.collate(encoded)
        if use_cuda:
            # NOTE(review): assumes every value returned by collate() is a
            # tensor (anything else has no .cuda()) — confirm in ItemEncoder.
            batch = {k: v.cuda() for k, v in batch.items()}

        # Inference only: disable autograd so no computation graph is built
        # and kept alive for the whole batch.
        with torch.no_grad():
            span_logits = model(**batch).detach().cpu().numpy()
        label_mask = batch["label_mask"].detach().cpu().numpy()

        # Row i of the logits / mask belongs to batch_data[i].
        for i, item in enumerate(batch_data):
            entity_list, spo_list = self.decoder.decode(item, span_logits[i], label_mask[i])
            item["spo_list"] = spo_list
            item["entity_list"] = entity_list
        return batch_data