import streamlit as st
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed

import meta
from normalizer import normalize
from utils import load_json, local_css

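# Each example entry is assumed to provide the fields read in main():
# "title", "context", "question", and a list of "answers".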
|
EXAMPLES = load_json("examples.json")

# Delimiters used to assemble prompts of the form "<context> Q: <question> A:".
CK = ""
QK = "Q:"
AK = "A:"


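# Wraps the tokenizer/model pair: the app builds a "<context> Q: <question> A:"
# prompt and this class returns whatever the model generates after the answer key.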
|
class TextGeneration:
    def __init__(self):
        self.debug = False
        self.dummy_output = "Destiny's Child"  # canned answer returned in debug mode
        self.tokenizer = None
        self.model = None
        self.model_name_or_path = "m3hrdadfi/gpt2-QA"
        self.length_margin = 100  # extra tokens the model may generate beyond the prompt
        set_seed(42)

|
    def load(self):
        # In debug mode the model is never downloaded; generate() returns the dummy output instead.
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self.model = GPT2LMHeadModel.from_pretrained(self.model_name_or_path)

|
    def generate(self, prompt, generation_kwargs):
        if not self.debug:
            input_ids = self.tokenizer([prompt], return_tensors="pt")["input_ids"]
            # Cap output at the prompt length plus a margin, within GPT-2's 1024-token window.
            max_length = min(len(input_ids[0]) + self.length_margin, 1024)
            generation_kwargs["max_length"] = max_length

            generated = self.model.generate(
                input_ids,
                **generation_kwargs,
            )[0]

            answer = self.tokenizer.decode(generated, skip_special_tokens=True)
            found = answer.find(AK)
            if found == -1:
                # The answer key never appeared in the generated text.
                return ""

            # Keep only the first non-empty segment after the answer key.
            answer = [a.strip() for a in answer[found:].split(AK) if a.strip()]
            return answer[0] if answer else ""

        return self.dummy_output


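# st.cache(allow_output_mutation=True) keeps a single tokenizer/model instance
# alive across Streamlit reruns; newer Streamlit versions provide roughly the
# same behaviour via st.cache_resource.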
|
@st.cache(allow_output_mutation=True)
def load_text_generator():
    generator = TextGeneration()
    generator.load()
    return generator


|
def main():
    st.set_page_config(
        page_title="GPT2 QA",
        page_icon="⁉️",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    local_css("assets/style.css")
    generator = load_text_generator()

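    # Sidebar controls for the decoding hyper-parameters passed to model.generate().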
|
    st.sidebar.markdown(meta.SIDEBAR_INFO)
    num_beams = st.sidebar.slider(
        label='Number of Beams',
        help="Number of beams for beam search",
        min_value=4,
        max_value=15,
        value=5,
        step=1
    )
    repetition_penalty = st.sidebar.slider(
        label='Repetition Penalty',
        help="The parameter for repetition penalty",
        min_value=1.0,
        max_value=3.0,
        value=1.0,
        step=0.1
    )
    length_penalty = st.sidebar.slider(
        label='Length Penalty',
        help="Exponential penalty applied to the output length",
        min_value=0.0,
        max_value=2.0,
        value=1.0,
        step=0.1
    )
    early_stopping = st.sidebar.selectbox(
        label='Early Stopping?',
        options=(True, False),
        help="Whether to stop the beam search once at least num_beams finished candidates are available per batch",
    )
    generation_kwargs = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
        "repetition_penalty": repetition_penalty,
        "length_penalty": length_penalty,
    }
    # max_length is filled in later by TextGeneration.generate from the prompt length.

|
    st.markdown(meta.HEADER_INFO)
    prompts = [e["title"] for e in EXAMPLES] + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)

|
    if prompt == "Custom":
        prompt_box = {
            "context": meta.C_PROMPT_BOX,
            "question": meta.Q_PROMPT_BOX,
            "answers": [meta.A_PROMPT_BOX],
        }
    else:
        prompt_box = next(e for e in EXAMPLES if e["title"] == prompt)

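    # prompt_box now holds context/question/answers, either from the selected
    # example or from the editable placeholders defined in meta.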
|
    context = st.text_area("Enter context", prompt_box["context"], height=200)
    question = st.text_area("Enter question", prompt_box["question"], height=100)
    answer = "Ground Truth Answers: " + \
        "".join([f"<span class='ground-truth'>{a}</span>" for a in prompt_box["answers"]])
    st.markdown(
        f'<p>'
        f'{answer}'
        f'</p>',
        unsafe_allow_html=True
    )
    generation_kwargs_ph = st.empty()

|
    if st.button("Find the answer 🔎"):
        with st.spinner(text="Searching ..."):
            generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
            context = normalize(context)
            question = normalize(question)

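            # Build the "<context> Q: <question> A:" prompt only when both fields are non-empty.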
|
            if context and question:
                text = f"{context} {QK} {question} {AK}"
                generated_answer = generator.generate(text, generation_kwargs)
                generated_answer = f"{AK} {generated_answer}".strip()
                context = f"{CK} {context}".strip()
                question = f"{QK} {question}".strip()

|
                st.markdown(
                    f'<p>'
                    f'<span class="result-text">{context}</span><br/><br/>'
                    f'<span class="result-text">{question}</span><br/><br/>'
                    f'<span class="result-text generated-text">{generated_answer}</span>'
                    f'</p>',
                    unsafe_allow_html=True
                )


if __name__ == '__main__':
    main()