heegyu commited on
Commit
6c1b57e
Β·
1 Parent(s): c4e6aec

add app.py

Browse files
Files changed (2) hide show
  1. app.py +61 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_chat import message
3
+
4
+ @st.cache(allow_output_mutation=True)
5
+ def get_pipe():
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
+ tokenizer = AutoTokenizer.from_pretrained("heegyu/kodialogpt-v1")
8
+ model = AutoModelForCausalLM.from_pretrained("heegyu/kodialogpt-v1")
9
+ return pipeline("text-generation", model=model, tokenizer=tokenizer)
10
+
11
+ def get_response(generator, history, max_context: int = 7, bot_id: str = '1'):
12
+ generation_args = dict(
13
+ num_beams=4,
14
+ repetition_penalty=2.0,
15
+ no_repeat_ngram_size=4,
16
+ eos_token_id=375, # \n
17
+ max_new_tokens=64,
18
+ do_sample=True,
19
+ top_k=50,
20
+ early_stopping=True
21
+ )
22
+ context = []
23
+ for i, text in enumerate(history):
24
+ context.append(f"{i % 2} : {text}\n")
25
+
26
+ if len(context) > max_context:
27
+ context = context[-max_context:]
28
+ context = "".join(context) + f"{bot_id} : "
29
+
30
+ # print(f"get_response({context})")
31
+
32
+ response = generator(
33
+ context,
34
+ **generation_args
35
+ )[0]["generated_text"]
36
+ response = response[len(context):].split("\n")[0]
37
+ return response
38
+
39
+ st.title("kodialogpt-v1 demo")
40
+
41
+ with st.spinner("loading model..."):
42
+ generator = get_pipe()
43
+
44
+ if 'message_history' not in st.session_state:
45
+ st.session_state.message_history = []
46
+ history = st.session_state.message_history
47
+
48
+ # print(st.session_state.message_history)
49
+ for i, message_ in enumerate(st.session_state.message_history):
50
+ message(message_,is_user=i % 2 == 0) # display all the previous message
51
+
52
+ # placeholder = st.empty() # placeholder for latest message
53
+ input_ = st.text_input("YOU", value="")
54
+
55
+ if input_ is not None and len(input_) > 0:
56
+ if len(history) <= 1 or history[-2] != input_:
57
+ with st.spinner("λŒ€λ‹΅μ„ μƒμ„±μ€‘μž…λ‹ˆλ‹€..."):
58
+ st.session_state.message_history.append(input_)
59
+ response = get_response(generator, history)
60
+ st.session_state.message_history.append(response)
61
+ st.experimental_rerun()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ streamlit_chat