File size: 1,099 Bytes
bd0a6b5
 
 
89b2326
bd0a6b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32




import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Two-stage text classification script.
#
# The first model acts as a gate: only when it predicts class 1 is the second
# model consulted. Otherwise the text is reported as unrelated to Environment.


def _predict_label(tokenizer, model, text):
    """Return the index of the highest-scoring class that ``model`` gives ``text``.

    Args:
        tokenizer: Tokenizer matching ``model``; called with
            ``return_tensors="pt"`` and ``truncation=True``.
        model: Sequence-classification model whose output exposes ``.logits``.
        text: The single input string to classify.

    Returns:
        The argmax over the class dimension of the logits, as a Python int.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only: disabling autograd avoids building a graph and holding
    # on to intermediate activations that would never be used.
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.argmax(logits, dim=1).item()


# Load the first model (the gate).
tokenizer1 = AutoTokenizer.from_pretrained("Emma0123/fine_tuned_model")
model1 = AutoModelForSequenceClassification.from_pretrained("Emma0123/fine_tuned_model")

# Load the second model (only consulted when the gate predicts class 1).
tokenizer2 = AutoTokenizer.from_pretrained("jonas/roberta-base-finetuned-sdg")
model2 = AutoModelForSequenceClassification.from_pretrained("jonas/roberta-base-finetuned-sdg")

# Read the text to classify from standard input.
input_text = input()

# Run inference with the first model.
predictions = _predict_label(tokenizer1, model1, input_text)

# Branch on the first model's output.
if predictions == 1:
    # Class 1 from the gate: classify the same text with the second model.
    predictions2 = _predict_label(tokenizer2, model2, input_text)
    print("Second model prediction:", predictions2)
else:
    print("This content is unrelated to Environment.")