kodialogpt-v1 / app.py
heegyu's picture
add app.py
6c1b57e
raw
history blame
2.06 kB
import streamlit as st
from streamlit_chat import message
@st.cache(allow_output_mutation=True)
def get_pipe():
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
tokenizer = AutoTokenizer.from_pretrained("heegyu/kodialogpt-v1")
model = AutoModelForCausalLM.from_pretrained("heegyu/kodialogpt-v1")
return pipeline("text-generation", model=model, tokenizer=tokenizer)
def get_response(generator, history, max_context: int = 7, bot_id: str = '1'):
generation_args = dict(
num_beams=4,
repetition_penalty=2.0,
no_repeat_ngram_size=4,
eos_token_id=375, # \n
max_new_tokens=64,
do_sample=True,
top_k=50,
early_stopping=True
)
context = []
for i, text in enumerate(history):
context.append(f"{i % 2} : {text}\n")
if len(context) > max_context:
context = context[-max_context:]
context = "".join(context) + f"{bot_id} : "
# print(f"get_response({context})")
response = generator(
context,
**generation_args
)[0]["generated_text"]
response = response[len(context):].split("\n")[0]
return response
st.title("kodialogpt-v1 demo")
with st.spinner("loading model..."):
generator = get_pipe()
if 'message_history' not in st.session_state:
st.session_state.message_history = []
history = st.session_state.message_history
# print(st.session_state.message_history)
for i, message_ in enumerate(st.session_state.message_history):
message(message_,is_user=i % 2 == 0) # display all the previous message
# placeholder = st.empty() # placeholder for latest message
input_ = st.text_input("YOU", value="")
if input_ is not None and len(input_) > 0:
if len(history) <= 1 or history[-2] != input_:
with st.spinner("λŒ€λ‹΅μ„ μƒμ„±μ€‘μž…λ‹ˆλ‹€..."):
st.session_state.message_history.append(input_)
response = get_response(generator, history)
st.session_state.message_history.append(response)
st.experimental_rerun()