fzmnm committed
Commit 49d729a · verified · 1 Parent(s): 82a85bd

Update app.py

Files changed (1):
  1. app.py +78 -66
app.py CHANGED
@@ -1,67 +1,79 @@
- import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
- import torch
- from threading import Thread
-
- model_name = "fzmnm/TinyStoriesAdv_v2_92M"
-
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name)
- model.eval()
-
- model.generation_config.pad_token_id = tokenizer.eos_token_id
-
- max_tokens = 512
-
- def build_input_str(message: str, history: 'list[list[str]]'):
-     history_str = ""
-     for entity in history:
-         if entity['role'] == 'user':
-             history_str += f"问:{entity['content']}\n\n"
-         elif entity['role'] == 'assistant':
-             history_str += f"答:{entity['content']}\n\n"
-     return history_str + f"问:{message}\n\n"
-
- def stop_criteria(input_str):
-     return input_str.endswith("\n") and len(input_str.strip()) > 0
-
- class StopOnTokens(StoppingCriteria):
-     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-         input_str = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-         return stop_criteria(input_str)
-
- def chat(message, history):
-     input_str = build_input_str(message, history)
-     input_ids = tokenizer.encode(input_str, return_tensors="pt")
-     input_ids = input_ids[:, -max_tokens:]
-     streamer = TextIteratorStreamer(
-         tokenizer,
-         timeout=10,
-         skip_prompt=True,
-         skip_special_tokens=True)
-     stopping_criteria = StoppingCriteriaList([StopOnTokens()])
-     generate_kwargs = dict(
-         input_ids=input_ids,
-         streamer=streamer,
-         stopping_criteria=stopping_criteria,
-         max_new_tokens=512,
-         top_p=0.9,
-         do_sample=True,
-         temperature=0.7
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     output_str = ""
-     for new_str in streamer:
-         output_str += new_str
-         yield output_str
-
- app = gr.ChatInterface(
-     fn=chat,
-     type='messages',
-     examples=['什么是鹦鹉?', '什么是大象?', '谁是李白?', '什么是黑洞?'],
-     title='聊天机器人',
- )
-
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ import torch
+ from threading import Thread
+
+ import os; os.chdir(os.path.dirname(__file__))
+
+ # model_name = "./92M_low_kv_dropout_v3_hf"
+ model_name = "fzmnm/TinyStoriesAdv_v2_92M"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ model.eval()
+
+ model.generation_config.pad_token_id = tokenizer.eos_token_id
+
+ max_tokens = 512
+
+ def build_input_str(message: str, history: 'list[list[str]]'):
+     history_str = ""
+     for entity in history:
+         if entity['role'] == 'user':
+             history_str += f"问:{entity['content']}\n\n"
+         elif entity['role'] == 'assistant':
+             history_str += f"答:{entity['content']}\n\n"
+     return history_str + f"问:{message}\n\n"
+
+ def stop_criteria(input_str):
+     # return input_str.endswith("\n") and len(input_str.strip()) > 0
+     input_str = input_str.replace("：", ":")
+     return input_str.endswith("问:") or input_str.endswith("meta_tag:")
+
+ def remove_ending(input_str):
+     if input_str.replace("：", ":").endswith("问:"):
+         return input_str[:-2]
+     if input_str.endswith("meta_tag:"):
+         return input_str[:-9]
+     return input_str
+
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         input_str = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+         return stop_criteria(input_str)
+
+ def chat(message, history):
+     input_str = build_input_str(message, history)
+     input_ids = tokenizer.encode(input_str, return_tensors="pt")
+     input_ids = input_ids[:, -max_tokens:]
+     streamer = TextIteratorStreamer(
+         tokenizer,
+         timeout=10,
+         skip_prompt=True,
+         skip_special_tokens=True)
+     stopping_criteria = StoppingCriteriaList([StopOnTokens()])
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         stopping_criteria=stopping_criteria,
+         max_new_tokens=512,
+         top_p=0.9,
+         do_sample=True,
+         temperature=0.7
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     output_str = ""
+     for new_str in streamer:
+         output_str += new_str
+         yield remove_ending(output_str)
+
+ app = gr.ChatInterface(
+     fn=chat,
+     type='messages',
+     examples=['什么是鹦鹉?', '什么是大象?', '谁是李白?', '什么是黑洞?'],
+     title='聊天机器人',
+ )
+
  app.launch()
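
For context on the change: the prompt alternates 问 ("Q:") and 答 ("A:") turns, so the old stop rule (stop on any newline) is replaced by one that halts generation when the model starts a new 问 turn or emits a meta_tag: marker, and the new remove_ending trims that marker before the streamed text reaches the chat window. Below is a minimal standalone sketch of that stop/trim behavior; the sample streamed string is hypothetical, chosen only to illustrate the two functions.

    def stop_criteria(input_str: str) -> bool:
        # Normalize full-width colons so "问：" and "问:" are treated alike.
        input_str = input_str.replace("：", ":")
        # Stop once the model begins a new question turn or a meta tag.
        return input_str.endswith("问:") or input_str.endswith("meta_tag:")

    def remove_ending(input_str: str) -> str:
        # Trim the stop marker so it is never shown to the user.
        if input_str.replace("：", ":").endswith("问:"):
            return input_str[:-2]   # "问:" is two characters
        if input_str.endswith("meta_tag:"):
            return input_str[:-9]   # "meta_tag:" is nine characters
        return input_str

    # Hypothetical streamed output: an answer followed by the start of a new question.
    streamed = "答:鹦鹉是一种会说话的鸟。\n\n问:"
    assert stop_criteria(streamed)
    assert remove_ending(streamed) == "答:鹦鹉是一种会说话的鸟。\n\n"

Because StopOnTokens decodes the full sequence on every step and the streamer yields incrementally, the marker can briefly appear in output_str; calling remove_ending on each yielded chunk is what keeps it out of the UI.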