Cache + Measure time
app.py
CHANGED
@@ -1,20 +1,23 @@
 import torch
 import transformers
-import
+import time
 from huggingface_hub import hf_hub_download
+import streamlit as st
 
 
-
-
-hf_hub_download("OpenDungeon/gpt-j-8bit-ffbgem", "model.pt")
+@st.cache
+def load_model():
+    hf_hub_download("OpenDungeon/gpt-j-8bit-ffbgem", "model.pt")
+    qmodel = torch.load("model.pt")
+    return transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"), qmodel
 
-qmodel = torch.load("model.pt")
 
 def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_tokens = 50):
     past_key_values = None # used to keep track of conversation history
     input_dict = tokenizer([prompt] * batch, return_tensors='pt', padding=False)
     output = [""] * batch
-
+    batch_time = 0
+
     with torch.inference_mode():
         for i in range(limit_tokens + 20):
             if i == 5:
@@ -33,16 +36,25 @@ def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_toke
             if single_hook is not None:
                 single_hook(tokenizer.decode(token_ix[0]))
             if i == limit_tokens:
-
-                print((time.perf_counter() - start_time) / (i - 4), "s per token")
+                batch_time = (time.perf_counter() - start_time) / (i - 4)
                 break
 
             input_dict = dict(input_ids=token_ix)
-
-
-
-
+    return output, batch_time
+
+
+tokenizer, model = load_model()
+text = st.text_area("Prefix")
+batch = st.number_input("Variants", value=1)
+
+t = st.empty()
+firstline = ""
+
+def PrintSome(text):
+    global t, firstline
+    firstline += text
+    t.markdown(f"## {firstline}...")
 
+choices, batch_time = PrintContinuation(text, model, PrintSome, batch, 50)
 
-
-PrintContinuation(text, qmodel, lambda x: t.markdown(f"## {x}..."))
+t.markdown(" \n\n".join(choices) + f" \n\nSeconds per batch: {batch_time}, Batch: {batch}")
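
For context on the two changes in this commit: load_model() is wrapped in @st.cache so the GPT-J checkpoint is downloaded and deserialized once per Streamlit session rather than on every rerun, and PrintContinuation now returns the measured generation speed instead of printing it. The timer starts at decode step i == 5, treating the first five steps as warm-up, so dividing the elapsed time by i - 4 at i == limit_tokens gives the average seconds per step. A minimal standalone sketch of that timing pattern (the step callback is a hypothetical stand-in for one forward-and-sample iteration, not part of the app):

import time

def average_step_time(step, limit_tokens=50, warmup=5):
    """Average seconds per step, ignoring the first `warmup` warm-up steps."""
    start_time = None
    for i in range(limit_tokens + 1):
        if i == warmup:
            start_time = time.perf_counter()  # clock starts before step `warmup` runs
        step(i)  # one decoding iteration in the real app
    # steps warmup..limit_tokens ran on the clock: limit_tokens - warmup + 1 of them
    return (time.perf_counter() - start_time) / (limit_tokens - warmup + 1)

One caveat on the caching side: with the legacy st.cache decorator, returning a mutable torch model usually calls for @st.cache(allow_output_mutation=True), since by default Streamlit re-hashes the cached return value on every rerun to detect mutation, which is slow and can fail outright on large models.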