tafxle committed on
Commit 07d1ab1
1 Parent(s): 047c577

Change SDK

Files changed (2)
  1. README.md +2 -2
  2. app.py +95 -18
README.md CHANGED
@@ -3,8 +3,8 @@ title: Bloom Chat
 emoji: ⚡
 colorFrom: purple
 colorTo: green
-sdk: gradio
-sdk_version: 3.4
+sdk: streamlit
+sdk_version: 1.10.0
 app_file: app.py
 pinned: false
 license: openrail
app.py CHANGED
@@ -1,27 +1,104 @@
-# import streamlit as st

-# text = st.text_area("Prefix", value="DM: You enter the room.")
-# batch = st.number_input("Variants", value=5)
-# st.markdown(f"{text} {batch}")


-import streamlit as st
-from transformers import pipeline
-from PIL import Image

-pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")

-st.title("Hot Dog? Or Not?")

-file_name = st.file_uploader("Upload a hot dog candidate image")

-if file_name is not None:
-    col1, col2 = st.columns(2)

-    image = Image.open(file_name)
-    col1.image(image, use_column_width=True)
-    predictions = pipeline(image)

-    col2.header("Probabilities")
-    for p in predictions:
-        col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")

+import torch
+import transformers
+import time
+from huggingface_hub import snapshot_download
+import streamlit as st
+import copy
+from transformers import AutoConfig, GPTJForCausalLM
+from transformers.models.gptj.modeling_gptj import GPTJBlock
+from tqdm import trange


+@st.cache(allow_output_mutation=True)
+def load_model():
+    for down in trange(1, disable=True):
+        fpath = snapshot_download("OpenDungeon/gpt-j-8bit-ffbgem", revision="separate")
+    config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
+    qconfig = torch.quantization.get_default_qconfig('fbgemm')
+    torch.backends.quantized.engine = 'fbgemm'
+    n_layer, config.n_layer = config.n_layer, 0  # keep the real depth; build the model with zero blocks first

+    model = GPTJForCausalLM(config)
+    model.load_state_dict(torch.load(fpath + "/blocks/base.pt"))
+    ref_block = torch.quantization.quantize_dynamic(
+        GPTJBlock(config),
+        {torch.nn.Linear: qconfig},
+        dtype=torch.qint8,
+        inplace=True
+    )
+
+    for i in trange(n_layer):
+        new_block = copy.deepcopy(ref_block)
+        new_block.load_state_dict(torch.load(f"{fpath}/blocks/block{i}.pt"))
+        model.transformer.h.append(new_block)
+
+    config.n_layer = len(model.transformer.h)
+    del ref_block
+
+    return transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"), model
+
+
+def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_tokens=50):
+    past_key_values = None  # used to keep track of conversation history
+    input_dict = tokenizer([prompt] * batch, return_tensors='pt', padding=False)
+    output = [""] * batch
+    batch_time = 0
+
+    with torch.inference_mode():
+        for i in range(limit_tokens + 20):
+            if i == 5:
+                start_time = time.perf_counter()  # start timing after a short warm-up
+
+            outputs = local_model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)
+            last_logits = outputs.logits[:, -1]
+
+            for j in range(batch):
+                last_logits[j, last_logits[j].topk(k=10).indices] += 10  # boost the ten most likely tokens before sampling
+
+            past_key_values = outputs.past_key_values
+            token_ix = torch.multinomial(last_logits.softmax(-1), 1)
+            output = [stream + tokenizer.decode(ix) for stream, ix in zip(output, token_ix)]
+
+            if single_hook is not None:
+                single_hook(tokenizer.decode(token_ix[0]))
+            if i == limit_tokens:
+                batch_time = (time.perf_counter() - start_time) / (i - 4)
+                break
+
+            input_dict = dict(input_ids=token_ix)  # feed only the new token; past_key_values carries the context
+    return output, batch_time
+
+import sys
+
+def Sureprint(text):
+    text = f"\nDDBG: {text}\n"
+    print(text, flush=True)
+    print(text, file=sys.stderr, flush=True)
+
+Sureprint("ready to load")
+tokenizer, model = load_model()
+Sureprint("loaded")
+text = st.text_area("Prefix", value="DM: You enter the room.")
+Sureprint(f"text acquired '{text}'")
+batch = st.number_input("Variants", value=5)

+t = st.empty()
+firstline = ""

+def PrintSome(text):
+    global t, firstline
+    firstline += text
+    t.markdown(f"{firstline}...")

+Sureprint("before inference")
+choices, batch_time = PrintContinuation(text, model, PrintSome, batch, 50)
+Sureprint("after inference")

+final_page = ""
+for i in range(batch):
+    final_page += f"#### choice №{i + 1} \n{choices[i]} \n______ \n"
+final_page += f"Seconds per batch: {batch_time}, Batch: {batch}"

+t.markdown(final_page)

+Sureprint("all done")
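
The new load_model() expects a base checkpoint without transformer blocks (blocks/base.pt) plus one dynamically quantized GPTJBlock state dict per layer (blocks/block{i}.pt) in the OpenDungeon/gpt-j-8bit-ffbgem snapshot. The sketch below is a hypothetical export script, not part of this commit, showing how such files could be produced; it assumes enough RAM for the full fp32 GPT-J-6B checkpoint and uses the simpler {torch.nn.Linear} qconfig spec instead of the explicit fbgemm qconfig that load_model() passes to quantize_dynamic.

# Hypothetical export script (not part of this commit): split GPT-J-6B into the
# blocks/base.pt + blocks/block{i}.pt layout that load_model() above expects.
import os
import torch
from transformers import GPTJForCausalLM

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")  # needs enough RAM for the full fp32 checkpoint
os.makedirs("blocks", exist_ok=True)

# Detach the transformer blocks so base.pt holds only the embeddings, final layer norm and lm_head.
blocks = list(model.transformer.h)
model.transformer.h = torch.nn.ModuleList()
torch.save(model.state_dict(), "blocks/base.pt")

# Dynamically quantize each block's Linear layers to int8 and save it on its own,
# so the app can rebuild the model one block at a time.
torch.backends.quantized.engine = "fbgemm"
for i, block in enumerate(blocks):
    qblock = torch.quantization.quantize_dynamic(block, {torch.nn.Linear}, dtype=torch.qint8)
    torch.save(qblock.state_dict(), f"blocks/block{i}.pt")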