ベースモデル
https://huggingface.co/globis-university/deberta-v3-japanese-large
特徴;
- 根拠文と生成文を入れることで、生成文が根拠文に基づいているかをチェックします。
- jsquadのようなQA, jnliのような含意、jfldのような論理推論に対応しています。
- 日本語の演繹推論コーパス JFLD https://github.com/hitachi-nlp/FLD-corp us/blob/main/README.JFLD.md
- max_tokenは2048です。
テスト結果:
F1 specificity recall
総合 88.96 89.07 88.92
jsquad 97.02 95.46 98.42
jnli 78.04 96.89 75.29
jdolly 97.94 98.46 97.54
jaquad 95.58 93.07 98.03
jfldD1 79.28 80.96 75.94
jfldD3 78.13 74.03 77.9
学習用データセットは以下のものをベースにし、独自にnegativeデータを追加しています。
- jsquad,jnli https://github.com/yahoojapan/JGLUE
- jdolly https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja
- jaquad https://github.com/SkelterLabsInc/JaQuAD
- jfld https://github.com/hitachi-nlp/FLD-corpus/blob/main/README.JFLD.md
スニペット:
class NLIPredictor:
def __init__(self, model_name, model_path, device="cpu", max_length=2048):
device = "cuda"
self.device = torch.device(device)
self.max_length = max_length
self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(
self.device
)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def predict(self, premise, hypothesis):
inputs = self.tokenizer(
premise,
hypothesis,
return_tensors="pt",
max_length=self.max_length,
truncation=True,
padding=True,
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
predicted_class = torch.argmax(probabilities, dim=1).item()
return predicted_class
- Downloads last month
- 0