bagualu-ie / models /extract_model.py
han liu
init
ff78ef7
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
from typing import List

import torch
from transformers import PreTrainedTokenizer

from dataloaders.item_encoder import ItemEncoder
from dataloaders.item_decoder import ItemDecoder
from .model import BagualuIEModel
class BagualuIEExtractModel(object):
    """ BagualuIEExtractModel

    Inference helper: encodes raw items, runs a BagualuIEModel on them and
    stores the decoded ``entity_list`` / ``spo_list`` on a copy of each item.

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer handed to both the item
            encoder and the item decoder
        args (argparse.Namespace): arguments; ``args.max_length`` is passed
            to the encoder and the whole namespace to the decoder
    """

    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 args: argparse.Namespace) -> None:
        self.encoder = ItemEncoder(tokenizer, args.max_length)
        self.decoder = ItemDecoder(tokenizer, args)

    def extract(self,
                batch_data: List[dict],
                model: BagualuIEModel,
                use_cuda: bool) -> List[dict]:
        """ extract

        Args:
            batch_data (List[dict]): batch of data; the caller's list and
                dicts are left untouched (a deep copy is annotated instead)
            model (BagualuIEModel): model; it is switched to eval mode and,
                when ``use_cuda`` is set, ``model.cuda()`` is called on it
            use_cuda (bool): run the model and the batch tensors on the GPU

        Returns:
            List[dict]: deep copy of ``batch_data`` in which every item has
                the keys ``"spo_list"`` and ``"entity_list"`` set to the
                decoder output. An empty input yields an empty list.
        """
        if use_cuda:
            model = model.cuda()
        model.eval()

        # Work on a copy so the decoded results never leak into caller data.
        batch_data = copy.deepcopy(batch_data)
        if not batch_data:
            # Nothing to encode; skip collate/forward, which have no items
            # to work on.
            return batch_data

        encoded_items = [self.encoder.encode_item(item, with_label=False)
                         for item in batch_data]
        batch = self.encoder.collate(encoded_items)
        if use_cuda:
            batch = {k: v.cuda() for k, v in batch.items()}

        # Pure inference: disable autograd so no graph/activations are kept
        # alive for a backward pass that never happens.
        with torch.no_grad():
            span_logits = model(**batch).detach().cpu().numpy()
        label_mask = batch["label_mask"].detach().cpu().numpy()

        # Row i of the model output / mask belongs to item i of the batch.
        for i, item in enumerate(batch_data):
            entity_list, spo_list = self.decoder.decode(item,
                                                        span_logits[i],
                                                        label_mask[i])
            item["spo_list"] = spo_list
            item["entity_list"] = entity_list

        return batch_data