Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ def main():
|
|
9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
tokenizer, EMPTY_TOKENS = load_tokenizer(device)
|
11 |
|
12 |
-
st.title("TinyStories Transformer
|
13 |
|
14 |
st.subheader("Data and Training")
|
15 |
|
@@ -24,7 +24,7 @@ def main():
|
|
24 |
st.subheader("How Do I Prompt?")
|
25 |
st.markdown(
|
26 |
"""
|
27 |
-
Instead of generating a new story from scratch, you can "prompt" the model by writing the first few words
|
28 |
|
29 |
The model can struggle with some prompts, especially those outside of its limited domain. If a response isn't satisfactory, try repeating the generation, or make the following modifications:
|
30 |
"""
|
@@ -73,8 +73,8 @@ def main():
|
|
73 |
# model_version = st.radio("Which model would you like to use?", ["smoll", "beeg"])
|
74 |
# small_model = load_casey_model(tokenizer, device)
|
75 |
model = load_big_model(tokenizer, device)
|
76 |
-
|
77 |
-
|
78 |
|
79 |
|
80 |
if st.button('Write my story!'):
|
|
|
9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
tokenizer, EMPTY_TOKENS = load_tokenizer(device)
|
11 |
|
12 |
+
st.title("TinyStories Transformer Demo 🤖")
|
13 |
|
14 |
st.subheader("Data and Training")
|
15 |
|
|
|
24 |
st.subheader("How Do I Prompt?")
|
25 |
st.markdown(
|
26 |
"""
|
27 |
+
Instead of generating a new story from scratch, you can "prompt" the model by writing the first few words of a story, and let it finish from there. It can even jump in mid-sentence!
|
28 |
|
29 |
The model can struggle with some prompts, especially those outside of its limited domain. If a response isn't satisfactory, try repeating the generation, or make the following modifications:
|
30 |
"""
|
|
|
73 |
# model_version = st.radio("Which model would you like to use?", ["smoll", "beeg"])
|
74 |
# small_model = load_casey_model(tokenizer, device)
|
75 |
model = load_big_model(tokenizer, device)
|
76 |
+
model.to('cuda')
|
77 |
+
model.cuda()
|
78 |
|
79 |
|
80 |
if st.button('Write my story!'):
|