tafxle committed
Commit 09c0a40
1 Parent(s): 2bb3084

Update app.py

Files changed (1)
  1. app.py +53 -47
app.py CHANGED
@@ -1,48 +1,54 @@
- from transformers import BloomTokenizerFast, BloomModel
  import torch
- import gradio as gr
-
- mname = "bigscience/bloom-1b7"
- tokenizer = BloomTokenizerFast.from_pretrained(mname, use_cache=True)
- model = BloomModel.from_pretrained(mname, use_cache=True)
-
- def take_last_tokens(inputs, note_history, history):
-     """Filter the last 256 tokens"""
-     if inputs['input_ids'].shape[1] > 256:
-         inputs['input_ids'] = torch.tensor([inputs['input_ids'][0][-256:].tolist()])
-         inputs['attention_mask'] = torch.tensor([inputs['attention_mask'][0][-256:].tolist()])
-         note_history = ['</s> <s>'.join(note_history[0].split('</s> <s>')[2:])]
-         history = history[1:]
-     return inputs, note_history, history
-
- def add_note_to_history(note, note_history):
-     """Add a note to the historical information"""
-     note_history.append(note)
-     note_history = '</s> <s>'.join(note_history)
-     return [note_history]
-
- def chat(message, history):
-     history = history or []
-     if history:
-         history_useful = ['</s> <s>'.join([str(a[0])+'</s> <s>'+str(a[1]) for a in history])]
-     else:
-         history_useful = []
-     history_useful = add_note_to_history(message, history_useful)
-     inputs = tokenizer(history_useful, return_tensors="pt")
-     inputs, history_useful, history = take_last_tokens(inputs, history_useful, history)
-     reply_ids = model.generate(**inputs)
-     response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
-     history_useful = add_note_to_history(response, history_useful)
-     list_history = history_useful[0].split('</s> <s>')
-     history.append((list_history[-2], list_history[-1]))
-     return history, history
-
- gr.Interface(
-     fn=chat,
-     theme="huggingface",
-     css=".footer {display:none !important}",
-     inputs=["text", "state"],
-     outputs=["message", "state"],
-     title="Bloom 1b3 chat",
-     allow_flagging="never",
- ).launch()
+ import time
+ import transformers
+ import numpy as np
+ import streamlit as st
+ from huggingface_hub import hf_hub_download
+
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+
+ # hf_hub_download returns the path of the cached file; load from that path
+ weights_path = hf_hub_download("OpenDungeon/gpt-j-8bit-ffbgem", "model.pt")
+ qmodel = torch.load(weights_path)
+
+ def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_tokens=50):
+     past_key_values = None  # attention key/value cache reused between steps
+     input_dict = tokenizer([prompt] * batch, return_tensors='pt', padding=False)
+     output = [""] * batch
+
+     with torch.inference_mode():
+         for i in range(limit_tokens + 20):
+             if i == 5:  # start timing after a few warm-up iterations
+                 start_time = time.perf_counter()
+
+             outputs = local_model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)
+             last_logits = outputs.logits[:, -1]
+
+             # crude top-k bias: boost the 10 most likely logits before sampling
+             for j in range(batch):
+                 last_logits[j, last_logits[j].topk(k=10).indices] += 10
+
+             past_key_values = outputs.past_key_values
+             token_ix = torch.multinomial(last_logits.softmax(-1), 1)
+             output = [stream + tokenizer.decode(ix) for stream, ix in zip(output, token_ix)]
+
+             if single_hook is not None:
+                 single_hook(tokenizer.decode(token_ix[0]))
+             if i == limit_tokens:
+                 print()
+                 print((time.perf_counter() - start_time) / (i - 4), "s per token")
+                 break
+
+             # with use_cache, only the newly sampled token is fed back in
+             input_dict = dict(input_ids=token_ix)
+     print()
+     return output
+
+ def process(text):
+     return text[::-1]
+
+
+ text = st.text_area("Prompt")
+ st.markdown(f"## {process(text)}...")
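
As committed, the Streamlit UI never calls PrintContinuation: the page only shows the reversed prompt via process, apparently as a placeholder. Below is a minimal sketch of how the generator could be hooked up to the page, assuming the tokenizer, qmodel, and PrintContinuation defined in app.py above; the show callback and the box placeholder are illustrative names, not part of this commit.

import streamlit as st

# Illustrative glue code (assumes PrintContinuation and qmodel from app.py):
# stream each sampled token into the page through the single_hook callback.
text = st.text_area("Prompt")

if text:
    box = st.empty()           # placeholder element, rewritten as tokens arrive
    state = {"generated": ""}

    def show(token):
        # single_hook is called with each decoded token as soon as it is sampled
        state["generated"] += token
        box.markdown(f"## {state['generated']}...")

    PrintContinuation(text, qmodel, single_hook=show, batch=1, limit_tokens=50)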