Spaces:
Sleeping
Sleeping
import streamlit as st | |
import time | |
from better_transformer import * | |
def _render_intro():
    """Render the page title and the static notes about the data, model, and prompting."""
    # NOTE(review): the emoji in this title looks mis-encoded in this copy of the
    # file; confirm the intended character before shipping.
    st.title("TinyStories Transformer Demo π€")
    st.subheader("Data and Training")
    st.markdown("""We used the dataset from Microsoft Research's [TinyStories Paper](https://arxiv.org/pdf/2305.07759.pdf) (Eldan and Li),
    which consists of 2.1 million synthetic short children's stories generated by GPT-4, to train a PyTorch Transformer LLM.""")
    st.markdown("""Our model uses EleutherAI's [gpt-neo-1.3B tokenizer](https://huggingface.co/EleutherAI/gpt-neo-1.3B) (vocab size 50,257) and consists of 8 transformer blocks,
    16 attention heads, and an embedding dimension of 768, for a total of ~56M non-embedding parameters. The model was trained overnight on 8 H100 GPUs, achieving a lower cross-entropy
    validation loss than any of the models in the TinyStories paper (likely due to a larger vocab size).""")
    st.markdown("""Despite the simple themes and limited vocabulary present in the training data, the model is
    quite effective at generating new short stories. **Try it out below!**""")

    st.subheader("How Do I Prompt?")
    st.markdown(
        """
        Instead of generating a new story from scratch, you can "prompt" the model by writing the first few sentences of a story, and let it finish from there. It can even jump in mid-sentence!
        The model can struggle with some prompts, especially those outside of its limited domain. If a response isn't satisfactory, try repeating the generation, or make the following modifications:
        """
    )
    st.markdown(
        """
        - **Use simple vocabulary and syntax** - words, structures, and themes you'd see in a children's story.
        - Use common first names only - the model can struggle with longer or uncommon names.
        `SAMPLE PROMPT: Once upon a time, there was a little girl named Lily. She loved to play at the park with her friend Timmy. One day, Lily and Timmy`
        """
    )


def _read_generation_settings():
    """Collect sampling settings, using defaults unless the advanced panel is enabled.

    Returns:
        dict: keyword arguments for ``generate`` (``method``, ``k``, ``p_nucleus``,
        ``temp``, ``max_new_tokens``, ``deterministic``).
    """
    # Defaults used while "Show Advanced Settings" is unchecked.
    user_seed = None  # Remove if we're not rigging the "random" demo
    generation_method = "top-k"
    specified_k = 5
    specified_nucleus = 0.5
    specified_temperature = 0.4
    max_tokens = 1000

    if st.checkbox("Show Advanced Settings"):
        user_seed = st.number_input("Randomness Seed:", value=None, step=1,
                                    placeholder="Use to replicate response", min_value=1)
        generation_method = st.selectbox(
            "Method of Generation:",
            ("top-k", "nucleus", "temperature", "multinomial", "greedy"),
            index=0,
        ).strip()
        # Only the control for the selected method is shown; the others keep their defaults.
        if generation_method == "top-k":
            # k must be at least 1: top-k over zero or a negative number of candidates is meaningless.
            specified_k = st.number_input("Value for k:", value=5, step=1, min_value=1)
        elif generation_method == "nucleus":
            specified_nucleus = st.number_input("Nucleus Cutoff:", value=0.5, step=0.05,
                                                min_value=0.0, max_value=1.0)
        elif generation_method == "temperature":
            # NOTE(review): 0.0 is accepted here; confirm generate() handles it
            # (dividing logits by a zero temperature would not work).
            specified_temperature = st.number_input("Value for temperature:", value=0.4, step=0.05,
                                                    min_value=0.0, max_value=1.0)
        max_tokens = st.slider('Max Tokens Generated:', 50, 1000, 1000)

    return {
        "method": generation_method,
        "k": specified_k,
        "p_nucleus": specified_nucleus,
        "temp": specified_temperature,
        "max_new_tokens": max_tokens,
        "deterministic": user_seed,
    }


def _stream_story(placeholder, user_input, result):
    """Animate the story into ``placeholder``: the prompt in bold, then the continuation.

    Args:
        placeholder: the ``st.empty()`` slot that is redrawn on every step.
        user_input: the stripped prompt; "" means an unconditional story.
        result: the text returned by ``generate``.
    """
    if user_input:
        # Reveal the prompt one word at a time, in bold.
        shown_words = []
        for word in user_input.split(' '):
            shown_words.append(word)
            with placeholder.container():
                st.markdown(f"**{' '.join(shown_words)}**")
            time.sleep(0.1)
        # NOTE(review): assumes generate() returns the prompt plus 3 extra characters
        # ahead of the continuation; confirm this offset against generate().
        result = result[len(user_input) + 3:]
        streamed_result = f"**{user_input}**"
        time.sleep(1)  # short pause between the prompt and the model's continuation
    else:
        streamed_result = ""

    # Reveal the generated text one word at a time after the (bold) prompt.
    for word in result.split(' '):
        streamed_result += word + ' '
        with placeholder.container():
            st.markdown(streamed_result)
        time.sleep(0.1)


def main():
    """Build the TinyStories demo page and, when requested, generate and stream a story."""
    # Use CUDA when available, otherwise fall back to CPU; the model follows this choice.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer, _ = load_tokenizer(device)

    _render_intro()

    st.subheader("Let's make some stories! π")  # NOTE(review): emoji looks mis-encoded; confirm.
    user_input = st.text_input(
        "Enter your prompt:",
        placeholder="Write a prompt to make a story of your own, or leave it empty for a random story!",
    ).strip()

    settings = _read_generation_settings()

    model = load_big_model(tokenizer, device)
    # Move to the detected device instead of assuming CUDA, so CPU-only hosts still work.
    model.to(device)

    if st.button('Write my story!'):
        placeholder = st.empty()
        with st.spinner(""):
            result = generate(model, tokenizer, device, cond=user_input, **settings)
        for _ in range(4):
            st.markdown("\n")
        _stream_story(placeholder, user_input, result)
        # Clicking this button triggers a Streamlit rerun. On that rerun
        # 'Write my story!' reports False, so the story is not redrawn, which clears it.
        st.button('Clear Output')

    st.markdown('####')
    st.caption('UCLA DSU Project Fall 2023: Daniel Mendelevitch, Terry Ming, Casey Tattersall, Sean Tjoa')
    st.caption(r'Data Attribution: Tinystories (License: CDLA-Sharing-1.0) https://arxiv.org/abs/2305.07759')
if __name__ == "__main__":
    # Build the page only when this file is executed directly, not when it is imported.
    main()