lm committed on
Commit
67a39d6
1 Parent(s): 8ae8eb8
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Demo script: answer a Chinese legal question with the law-glm-10b model.

Loads the tokenizer and seq2seq model from the Hugging Face Hub, builds a
GLM-style fill-in-the-blank prompt, runs generation on GPU, and decodes the
answer into `prediction`. Requires a CUDA device and the `transformers`
package (with `trust_remote_code` enabled, since GLM ships custom model code).
"""

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

hf_model = "law-llm/law-glm-10b"
# Fix: `model_cache_dir` was used below but never defined, which raised
# NameError on the first from_pretrained() call. None makes transformers
# fall back to its default cache location (~/.cache/huggingface).
model_cache_dir = None
max_question_length = 64      # token budget for the encoded question
max_generation_length = 490   # token budget for the generated answer

tokenizer = AutoTokenizer.from_pretrained(
    hf_model,
    cache_dir=model_cache_dir,
    use_fast=True,
    trust_remote_code=True,
)

model = AutoModelForSeq2SeqLM.from_pretrained(
    hf_model,
    cache_dir=model_cache_dir,
    trust_remote_code=True,
)

model = model.to('cuda')
model.eval()  # disable dropout etc. for deterministic inference

# GLM fill-in prompt: the model generates text in place of [gMASK].
# (Question: "What is the sentence for committing theft?")
model_inputs = "提问: 犯了盗窃罪怎么判刑? 回答: [gMASK]"

model_inputs = tokenizer(
    model_inputs,
    max_length=max_question_length,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

# GLM-specific helper (provided by the remote code) that appends the
# generation slots for the [gMASK] span to the encoded inputs.
model_inputs = tokenizer.build_inputs_for_generation(
    model_inputs,
    targets=None,
    max_gen_length=max_generation_length,
    padding=True,
)

inputs = model_inputs.to('cuda')

# eop_token_id is GLM's end-of-piece token, used here as the stop token.
outputs = model.generate(
    **inputs,
    max_length=max_generation_length,
    eos_token_id=tokenizer.eop_token_id,
)
prediction = tokenizer.decode(outputs[0].tolist())