ytzfhqs commited on
Commit
2dfc001
1 Parent(s): 09fc082

Update README For Batch Usage

Browse files
Files changed (1) hide show
  1. README.md +28 -0
README.md CHANGED
@@ -53,4 +53,32 @@ logits = outputs.logits
53
  id = torch.argmax(logits, dim=-1).item()
54
  response = ID2LABEL[id]
55
  print(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  ```
 
53
  id = torch.argmax(logits, dim=-1).item()
54
  response = ID2LABEL[id]
55
  print(response)
56
+ # "非正文"
57
+ ```
58
+
59
+ # For Batch Usage
60
+ ```python
61
+ import torch
62
+ from transformers import AutoModelForSequenceClassification
63
+ from transformers import AutoTokenizer
64
+
65
+ ID2LABEL = {0: "正文", 1: "非正文"}
66
+
67
+ model_name = 'ytzfhqs/Qwen2.5-med-book-main-classification'
68
+ model = AutoModelForSequenceClassification.from_pretrained(
69
+ model_name,
70
+ torch_dtype="auto",
71
+ device_map="auto"
72
+ )
73
+ tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
74
+
75
+ text = ['下列为修订说明','阴离子间隙是一项受到广泛重视的酸碱指标。AG是一个计算值,指血浆中未测定的阴离子与未测定的阳离子的差值,正常机体血浆中的阳离子与阴离子总量相等,均为151mmol/L,从而维持电荷平衡。']
76
+ encoding = tokenizer(text, return_tensors='pt',padding=True)
77
+ encoding = {k: v.to(model.device) for k, v in encoding.items()}
78
+ outputs = model(**encoding)
79
+ logits = outputs.logits
80
+ ids = torch.argmax(logits, dim=-1).tolist()
81
+ response = [ID2LABEL[id] for id in ids]
82
+ print(response)
83
+ # ['非正文', '正文']
84
  ```