model files uploaded
- app.py +118 -0
- better_transformer.py +399 -0
- bt_8_LAYERs_100_DATA_PCT_768_EMBD_DIM_epoch_10.pt +3 -0
- requirements +70 -0
app.py  ADDED  @@ -0,0 +1,118 @@
import streamlit as st
import time

from better_transformer import *

def main():

    # Enable CUDA if available and load in tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer, EMPTY_TOKENS = load_tokenizer(device)

    st.title("Scaling Transformers")
    st.subheader("UCLA DSU Project, Fall 2023")
    st.markdown("Daniel Mendelevitch \n Terry Ming \n Casey Tattersall \n Sean Tjoa")

    st.header("What Are Transformers? 🚗🔄🤖")

    header_text = """A transformer is a specific type of neural network that uses a mechanism called self-attention to learn the context (and
    thus meaning) of sequential data. Transformer-based models can be used in many different domains, such as processing language, predicting
    the weather, or even generating images. \n\n You might be familiar with ChatGPT, a Transformer-based model which cost over \$100 million to train. \n In contrast, we spent \$40*.
    """
    st.markdown(header_text)

    st.header("Let's make some stories! 📖")

    # Input from user
    user_input = st.text_input("Enter your prompt:", placeholder="Write a prompt to make a story of your own or leave it empty for a random story!").strip()

    if st.checkbox("Show Prompting Tips"):
        st.markdown("Our model was trained on the TinyStories dataset, a collection of synthetic short stories generated by GPT-4. These stories only contain words and themes that a typical 3-4 year old would understand.")
        st.markdown(
            """
            - Use simple vocabulary - words and themes that would appear in a children's story
            - Avoid using idioms - for example, instead of "hit the gym", say "went to the gym"
            - Include plenty of descriptive adjectives
            - The model often struggles with names - using common names and only including a person's first name can help
            """
        )

    ## Default values for advanced settings
    user_seed = 27  # Remove if we're not rigging the "random" demo
    generation_method = "top-k"
    specified_k = 5
    specified_nucleus = 0.5
    specified_temperature = 0.9
    max_tokens = 400

    if st.checkbox("Show Advanced Settings"):
        user_seed = st.number_input("Randomness Seed:", value=None, step=1, placeholder="Use to replicate response", min_value=1)
        generation_method = st.selectbox("Method of Generation:", ("top-k", "multinomial", "temperature", "greedy", "nucleus"), index=0).strip()

        if generation_method == "top-k":
            specified_k = st.number_input("Value for k:", value=5, step=1)

        if generation_method == "nucleus":
            specified_nucleus = st.number_input("Value for p:", value=0.5, step=0.05, min_value=0.0, max_value=1.0)

        if generation_method == "temperature":
            specified_temperature = st.number_input("Value for temperature:", value=0.9, step=0.05, min_value=0.0, max_value=1.0)

        max_tokens = st.slider('Max Tokens Generated:', 100, 500, 400)

    ## Settings clean-up
    if not user_seed:
        user_seed = 7

    # model_version = st.radio("Which model would you like to use?", ["smoll", "beeg"])
    # small_model = load_casey_model(tokenizer, device)
    model = load_big_model(tokenizer, device)

    if st.button('Write my story!'):
        placeholder = st.empty()
        # if model_version == 'smoll':
        #     model = load_casey_model(tokenizer, device)
        # elif model_version == 'beeg':
        #     model = load_big_model(tokenizer, device)
        # with placeholder.container():
        #     st.write("Model Loaded! Preparing to Generate...")

        with st.spinner(""):
            result = generate(model, tokenizer, device, method=generation_method, k=specified_k,
                              p_nucleus=specified_nucleus, temp=specified_temperature, max_new_tokens=max_tokens,
                              cond=user_input, deterministic=user_seed)

        # Stream the user's prompt back in bold, word by word
        streamed_input = ""
        for word in user_input.split(' '):
            streamed_input += word
            with placeholder.container():
                st.markdown(f"**{streamed_input}**")
            streamed_input += " "
            time.sleep(0.1)

        if user_input != "":  ## conditional generation: strip the echoed prompt and delimiter from the output
            result = result[len(user_input) + 3:]
            streamed_result = f"**{streamed_input[:-1]}**"
            time.sleep(1)
        else:  ## unconditional generation
            streamed_result = ""

        # Stream the generated story word by word
        for word in result.split(' '):
            streamed_result += word + ' '
            with placeholder.container():
                st.write(streamed_result)
            time.sleep(0.1)

    if st.button('Clear Output'):
        placeholder = st.empty()

if __name__ == "__main__":
    main()
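app.py is only the Streamlit front end: it pulls load_tokenizer, load_big_model, and generate from better_transformer.py (added below) and streams the result back word by word. To try it outside the Space, the checkpoint named in MODEL_FILE has to sit in the same directory as app.py (see the comment in better_transformer.py), after which the app can be started locally with `streamlit run app.py`.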
better_transformer.py  ADDED  @@ -0,0 +1,399 @@
import os
import json
import random
import time
import streamlit as st
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

MODEL_FILE = r'bt_8_LAYERs_100_DATA_PCT_768_EMBD_DIM_epoch_10.pt'  ## place model file in same directory as app.py

# Better Transformer Class –––––––––––––––––––––––––––––––––––––––––––––––

class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),  # replaced ReLU
            nn.Dropout(p=dropout),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()

        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head  # Dimension of each head's key, query, and value
        assert self.head_dim * n_head == self.n_embd, "n_embd must be divisible by n_head"
        self.seq_length = seq_length
        self.drop = nn.Dropout(p=dropout)

        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd, bias=False)  # multi-head combining weight matrix

    def split_heads(self, x):
        B, S, D = x.size()
        # split dimension into n_head * head_dim, then transpose the sequence length w/ n_head
        # output: [B, n_head, S, head_dim]
        return x.view(B, S, self.n_head, self.head_dim).transpose(1, 2)

    def combine_heads(self, x):
        # use permute or transpose to reverse
        # taking a view earlier may produce a non-contiguous tensor, so we convert back because view needs a contiguous input
        B, _, S, head_dim = x.size()  # _ is n_head, which we will merge
        # output: [B, S, n_embd]
        return x.transpose(1, 2).contiguous().view(B, S, self.n_embd)

    def scaled_dot_product(self, q, k, v, dropout, mask=None):
        # q, k, v are [B, n_head, S, head_dim]
        # the key transpose sets up batch multiplication s.t. wei = [B, n_head, S, S]
        wei = q @ k.transpose(-2, -1) / np.sqrt(self.head_dim)
        # mask is [B, 1, S, S], so it is simply broadcast across each head and works as expected
        if mask is not None:
            wei = wei.masked_fill(mask, float('-inf'))
        wei = dropout(F.softmax(wei, dim=-1))
        out = wei @ v
        return out

    def forward(self, x, mask=None):
        # x: (B, S, n_embd)
        # Step 1 and 2: Project full query, key, value, then split via reshaping
        q = self.split_heads(self.query(x))
        k = self.split_heads(self.key(x))
        v = self.split_heads(self.value(x))

        # Step 3: Compute scaled dot-product attention with causal mask
        attn = self.scaled_dot_product(q, k, v, self.drop, mask)

        # Step 4 and 5: Concatenate attention scores, return projected output matrix
        out = self.out(self.combine_heads(attn))  # (B, S, n_embd)
        return out

class Block(nn.Module):
    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, seq_length, dropout)
        self.mlp = MLP(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        # experimentally, apply layer norm before attention/MLP (pre-norm)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        # residual connection (stream)
        x = x + self.drop(self.sa(self.ln1(x), mask))
        x = x + self.drop(self.mlp(self.ln2(x)))
        return x

class PositionalEncoding(nn.Module):
    """
    Formula taken from the original Transformer paper:
    PE(pos, 2i (even)) = sin(pos/(10000^{2i/d_model}))
    PE(pos, 2i+1 (odd)) = cos(pos/(10000^{2i/d_model}))

    See reference for more details:
    https://kikaben.com/transformers-positional-encoding/
    """
    def __init__(self, d_model, max_len):
        # just set d_model = n_embd and max_len = seq_len
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        divisor = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))  # [d_model / 2], half for each of sin and cos
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * divisor)
        pe[:, 1::2] = torch.cos(position * divisor)
        self.register_buffer('pe', pe)  # result: self.pe = [max_len, d_model], mapping each token index to a vector of length d_model as desired

    def forward(self, x):
        # x = torch.arange(seq_length) has shape [seq_length], so x.size(0) extracts it, then we index self.pe for the first seq_length mappings
        # note we do not add the positional embeddings to x itself yet, we simply return them
        # output: (seq_length, d_model=n_embd)
        return self.pe[:x.size(0)]

class BetterTransformer(nn.Module):
    def __init__(self, vocab_size, seq_length, n_embd, n_head, n_layer, pad_idx, eos_token_id, device, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd, padding_idx=pad_idx)
        # padding_idx makes the embedding ignore (never update) the padding token
        self.position_embedding = PositionalEncoding(n_embd, seq_length)
        self.blocks = nn.Sequential(*[Block(n_embd,
                                            n_head,
                                            seq_length,
                                            dropout) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.drop = nn.Dropout(dropout)
        self.seq_length = seq_length
        self.pad_idx = pad_idx
        self.eos_token_id = eos_token_id
        self.device = device
        self.init_params()

    # optional weight initialization (e.g. Xavier uniform)
    def init_params(self, default_initialization=False):
        if not default_initialization:
            for name, p in self.named_parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    def get_causal_mask(self, x):
        """
        Generates causal mask for decoding
        """
        seq_len = x.size(-1)  # x = (batch_size x seq_len)
        attn_shape = (1, seq_len, seq_len)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')  # k=1 shifts the diagonal, so that the main diagonal gets 0's
        return (torch.from_numpy(subsequent_mask) == 0).to(self.device)  # (1, seq_len, seq_len)
        # True along main diagonal + below, False elsewhere

    def get_pad_mask(self, x, pad_idx):
        """
        Generates padding mask
        """
        return (x != pad_idx).unsqueeze(1).unsqueeze(-2).to(self.device)
        # (batch_size x 1 x 1 x seq_len)

    def forward(self, x, targets=None):

        # should already be int64 tokens, but cast explicitly just in case
        x = x.to(torch.int64)
        B, S = x.shape

        # get mask
        mask = self.get_pad_mask(x, self.pad_idx) & self.get_causal_mask(x).to(self.device)
        # mask = (batch_size x 1 x seq_len x seq_len)

        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(S))
        x = self.drop(tok_emb + pos_emb)
        # (B, S, n_embd)
        for block in self.blocks:
            x = block(x, ~mask)  # (batch_size, seq_length, n_embd)
            # negate mask so the originally False positions get filled with -inf later
        logits = self.lm_head(x)  # (batch_size, seq_length, vocab_size)

        # this code assumes teacher forcing: for each text of seq length S we have S autoregressive predictions,
        # thus we have B*S logits and B*S targets
        if targets is None:
            loss = None
        else:
            B, S, C = logits.shape
            logits = logits.view(B*S, C)
            targets = targets.view(B*S)
            loss = F.cross_entropy(logits, targets, ignore_index=self.pad_idx)
            # the loss ignores the padding token, which avoids wasting compute on learning PAD -> PAD, etc.

        return logits, loss

    def generate(self, input_ids, method='multinomial',
                 max_new_tokens=1000, temp=None,
                 num_beams=None, p_nucleus=None, k=None):

        # TODO: see Huggingface's .generate() function
        # https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html

        if method == 'temperature':
            assert (temp is not None) and (0 < temp) and (temp <= 1)
        # if method == 'num_beams':
        #     assert isinstance(num_beams, int) and (num_beams > 0) and (num_beams < 100)
        if method == 'top-k':
            assert isinstance(k, int) and (k > 0)

        # input_ids begins as (batch_size, seq_length)

        for _ in range(max_new_tokens):
            if method in ['multinomial', 'temperature', 'greedy', 'nucleus', 'top-k']:
                # i) Truncate to the most recent `max length` tokens
                text_cond = input_ids[:, -self.seq_length:]
                # ii) Retrieve predictions
                logits, loss = self(text_cond)  # no loss because no targets, of course
                # model output: (batch_size, seq_length, vocab_size)
                # iii) Find last token logits of each
                logits = logits[:, -1, :]  # (batch_size, vocab_size)

                # aside: if temperature sampling, divide logits by temp before applying softmax
                if method == 'temperature':
                    logits = logits / temp

                # iv) Take softmax along each
                probs = F.softmax(logits, dim=-1)

                # v) Sample next token depending on method
                if method == 'greedy':
                    next_idx = probs.argmax(dim=-1).unsqueeze(-1)

                elif method in ['multinomial', 'temperature', 'nucleus', 'top-k']:
                    if method == 'nucleus':
                        assert p_nucleus is not None and (0 < p_nucleus) and (p_nucleus <= 1)

                        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
                        prob_cumsum = sorted_probs.cumsum(dim=-1)
                        idx_remove = prob_cumsum > p_nucleus
                        # shift one right to ensure the first token is above the threshold
                        idx_remove[..., 1:] = idx_remove[..., :-1].clone()
                        idx_remove[..., 0] = False
                        # retrieve original indices by reverse-sorting
                        remove_mask = idx_remove.gather(dim=-1,
                                                        index=sorted_idx.argsort(dim=-1))
                        # ^ specifically, we do this by argsorting the indices returned by the sort. this is crazy, y'all
                        # you can show that this returns indices that, when used to subset a sorted array, return the original array in unsorted order
                        # https://stackoverflow.com/questions/52127723/pytorch-better-way-to-get-back-original-tensor-order-after-torch-sort
                        # torch.gather is how we apply a multi-dimensional index
                        # https://stackoverflow.com/questions/50999977/what-does-the-gather-function-do-in-pytorch-in-layman-terms
                        probs[remove_mask] = 0

                    if method == 'top-k':
                        remove_mask = probs < torch.topk(probs, k).values[..., -1, None]  # topk returns (B, k); [..., -1, None] keeps only the
                        # kth largest prob (i.e. the cutoff value for each row), so the mask is the same size as probs (B, vocab_size)
                        probs[remove_mask] = 0

                    # Sample probabilistically via scores
                    next_idx = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

                # vi) Autoregressively append to input_ids
                input_ids = torch.cat((input_ids, next_idx), dim=-1)
                # end prematurely if <EOS> generated
                if next_idx == self.eos_token_id:
                    break
                # now input_ids = (batch_size, seq_length + 1)

        return input_ids

# END OF Better Transformer Class –––––––––––––––––––––––––––––––––––––––––––––––
+
|
279 |
+
def set_seed(seed = 42):
|
280 |
+
random.seed(seed)
|
281 |
+
np.random.seed(seed)
|
282 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
283 |
+
torch.manual_seed(seed)
|
284 |
+
torch.cuda.manual_seed(seed)
|
285 |
+
# torch.cuda.manual_seed_all(seed) # if multi-GPU
|
286 |
+
torch.backends.cudnn.deterministic=True # only applies to CUDA convolution operations
|
287 |
+
torch.backends.cudnn.benchmark = False
|
288 |
+
# usually CuDNN has heuristics as to which algorithm to pick. cudnn.benchmark benchmarks several algorithms and picks the fastest
|
289 |
+
# often helpful if your input shapes are fixed and not changing a lot during training
|
290 |
+
# however, this means it may pick a different algorithm even when the deterministic flag is set.
|
291 |
+
# As such it is good practice to turn off cudnn.benchmark when turning on cudnn.deterministic
|
292 |
+
|
293 |
+
def load_tokenizer(device):
|
294 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
295 |
+
if tokenizer.pad_token is None:
|
296 |
+
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
297 |
+
EMPTY_TOKENS = torch.full((1,1), tokenizer.bos_token_id, dtype=torch.long).to(device)
|
298 |
+
return tokenizer, EMPTY_TOKENS
|
299 |
+
|
300 |
+
|
301 |
+
def load_big_model(tokenizer, device):
|
302 |
+
## Model architecture
|
303 |
+
set_seed(42)
|
304 |
+
N_HEAD = 16
|
305 |
+
N_LAYER = 8
|
306 |
+
N_EMBD = 768
|
307 |
+
VOCAB_SIZE = 50258
|
308 |
+
SEQ_LENGTH = 384
|
309 |
+
|
310 |
+
model = BetterTransformer(VOCAB_SIZE, SEQ_LENGTH, N_EMBD, N_HEAD, N_LAYER, tokenizer.pad_token_id, tokenizer.eos_token_id, device=device)
|
311 |
+
model.init_params()
|
312 |
+
path = MODEL_FILE
|
313 |
+
model.load_state_dict(torch.load(path, map_location=device)["model_state_dict"])
|
314 |
+
|
315 |
+
return model
|
316 |
+
|
317 |
+
def generate(model, tokenizer, device, method=None, k=None,
|
318 |
+
p_nucleus=None, temp=None, max_new_tokens=None, cond="", deterministic=None):
|
319 |
+
"""
|
320 |
+
Wrapper for generating text using the specified model. Generates unconditionally if cond=None.
|
321 |
+
|
322 |
+
Inputs:
|
323 |
+
-model: Decoder model to be used for text generation
|
324 |
+
-tokenizer: Compatible tokenizer
|
325 |
+
-device: Device of model (CPU/CUDA)
|
326 |
+
-method (str): Decoding method for text generation ('multinomial', 'temperature', 'greedy', 'nucleus', or 'top-k')
|
327 |
+
-k (int): Positive integer for top-k logits to sample if top-k decoding
|
328 |
+
-p_nucleus (float/int): Cumulative probability cutoff if nucleus/top-p decoding
|
329 |
+
-temp (float/int): Temperature if temperature decoding
|
330 |
+
-max_new_tokens (int): Maximum number of tokens to generate
|
331 |
+
-cond (str=None): If provided, will serve as conditional prompt for text generation
|
332 |
+
-deterministic (int): If deterministic, uses the specified seed for model generation
|
333 |
+
Returns:
|
334 |
+
-res (str): Generated text string
|
335 |
+
"""
|
336 |
+
assert method in ['multinomial', 'temperature', 'greedy', 'nucleus', 'top-k'], \
|
337 |
+
"method must be 'multinomial', 'temperature', 'greedy', 'nucleus', or 'top-k'"
|
338 |
+
|
339 |
+
#if method == 'temperature':
|
340 |
+
# assert (temp is not None) and isinstance(temp, (int, float)) and (0 < temp) and (temp <= 1), \
|
341 |
+
# "temp must be defined as a number between (0, 1]"
|
342 |
+
#if method == 'nucleus':
|
343 |
+
# assert (p_nucleus is not None) and isinstance(p_nucleus, (int, float)) and (0 < p_nucleus) and (p_nucleus <= 1), \
|
344 |
+
# "p_nucleus must be defined as a number between (0, 1]"
|
345 |
+
## if method == 'num_beams':
|
346 |
+
## assert isinstance(num_beams, int) and (num_beams) > 0 and (num_beams) < 100
|
347 |
+
#if method == 'top-k':
|
348 |
+
# assert (k is not None) and isinstance(k, int) and (k > 0) and (k < SEQ_LENGTH), \
|
349 |
+
# "k must be defined as an integer greater than 0 and less than the model sequence length"
|
350 |
+
|
351 |
+
#if max_new_tokens is None:
|
352 |
+
# print('No max_new_tokens provided, using a default value of 250\n')
|
353 |
+
# max_new_tokens = 250
|
354 |
+
|
355 |
+
#assert (max_new_tokens is not None) and isinstance(max_new_tokens, int) and (max_new_tokens) > 0 and (max_new_tokens) <= 1000, \
|
356 |
+
#"max_new_tokens must be an integer between (0, 1000]"
|
357 |
+
|
358 |
+
if deterministic is not None:
|
359 |
+
set_seed(deterministic)
|
360 |
+
|
361 |
+
if cond != "":
|
362 |
+
|
363 |
+
cond_tokens = tokenizer(cond).input_ids
|
364 |
+
|
365 |
+
gen_tokens = model.generate(torch.tensor(cond_tokens).unsqueeze(0).long().to(device),
|
366 |
+
method=method, k=k, p_nucleus=p_nucleus, temp=temp,
|
367 |
+
max_new_tokens=max_new_tokens)[0]
|
368 |
+
|
369 |
+
# Insert delimiter to indicate where prompt ends
|
370 |
+
gen_prep = torch.zeros(len(gen_tokens)+2).long() # make space for two more tokens for delimiter
|
371 |
+
gen_prep -= 1
|
372 |
+
gen_prep[:len(cond_tokens)] = gen_tokens[:len(cond_tokens)]
|
373 |
+
gen_prep[-(len(gen_tokens)-len(cond_tokens)):] = gen_tokens[-(len(gen_tokens)-len(cond_tokens)):]
|
374 |
+
gen_prep[gen_prep == -1] = torch.tensor(tokenizer.encode(' || ')) # insert tokens for || in between
|
375 |
+
|
376 |
+
res = tokenizer.decode(gen_prep)
|
377 |
+
res = re.sub(re.escape(tokenizer.bos_token), '', res, count=1) ## Remove end token
|
378 |
+
|
379 |
+
|
380 |
+
else:
|
381 |
+
empty_tokens = torch.full((1,1), tokenizer.bos_token_id, dtype=torch.long).to(device)
|
382 |
+
|
383 |
+
res = tokenizer.batch_decode(model.generate(empty_tokens,
|
384 |
+
method=method, k=k,
|
385 |
+
p_nucleus=p_nucleus, temp=temp,
|
386 |
+
max_new_tokens=max_new_tokens))[0]
|
387 |
+
|
388 |
+
res = re.sub(re.escape(tokenizer.bos_token), '', res, count=2) ## Remove start and end tokens
|
389 |
+
|
390 |
+
# Clean up Unicode character issues
|
391 |
+
# '“' then 'â€' = opening and closing double quotes
|
392 |
+
# '’' = apostrophe
|
393 |
+
res = re.sub(r'“', '"', res)
|
394 |
+
res = re.sub(r'’', "'", res)
|
395 |
+
res = re.sub(r'â€', '"', res)
|
396 |
+
res = res + " [END]" ## better end token
|
397 |
+
return res
|
398 |
+
|
399 |
+
|
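The module above contains everything needed to generate a story outside of Streamlit. Below is a minimal headless sketch, assuming the checkpoint named in MODEL_FILE has been downloaded next to the script and that generation runs on CPU; the prompt and settings are arbitrary examples, not values from the repo.

# Headless usage sketch (illustrative; assumes the MODEL_FILE checkpoint is present locally).
import torch
from better_transformer import load_tokenizer, load_big_model, generate

device = torch.device('cpu')               # CPU keeps the sketch simple; the app picks CUDA when available
tokenizer, _ = load_tokenizer(device)      # GPT-Neo tokenizer with an added [PAD] token
model = load_big_model(tokenizer, device)  # 8-layer, 768-dim BetterTransformer restored from the .pt file
model.eval()                               # disable dropout for generation

story = generate(model, tokenizer, device,
                 method="top-k", k=5,      # the app's default decoding settings
                 max_new_tokens=400,
                 cond="One day, a little dog named Max",  # example prompt
                 deterministic=27)
print(story)

This is the same call app.py issues when the "Write my story!" button is pressed, with the settings taken from the advanced-settings widgets.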
bt_8_LAYERs_100_DATA_PCT_768_EMBD_DIM_epoch_10.pt  ADDED  @@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb144d619ba6f662571efa9936e655274289d6773d3f4ee37601f16e9484a20d
size 1608405399
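The checkpoint itself is stored via Git LFS, so the three lines above are only the pointer file; the actual weights (roughly 1.6 GB, per the size field) are fetched with `git lfs pull` after cloning the Space.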
requirements  ADDED  @@ -0,0 +1,70 @@
aiohttp==3.9.1
aiosignal==1.3.1
altair==4.0.0
async-timeout==4.0.3
attrs==23.1.0
blinker==1.7.0
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
datasets==2.15.0
dill==0.3.7
entrypoints==0.4
filelock==3.13.1
frozenlist==1.4.0
fsspec==2023.10.0
gitdb==4.0.11
GitPython==3.1.40
huggingface-hub==0.19.4
idna==3.6
importlib-metadata==7.0.0
Jinja2==3.1.2
jsonschema==4.20.0
jsonschema-specifications==2023.11.2
markdown-it-py==3.0.0
MarkupSafe==2.1.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx==3.2.1
numpy==1.26.2
packaging==23.2
pandas==2.1.3
Pillow==10.1.0
protobuf==3.20.3
pyarrow==14.0.1
pyarrow-hotfix==0.6
pydeck==0.8.1b0
Pygments==2.17.2
Pympler==1.0.1
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
referencing==0.31.1
regex==2023.10.3
requests==2.31.0
rich==13.7.0
rpds-py==0.13.2
safetensors==0.4.1
semver==3.0.2
six==1.16.0
smmap==5.0.1
streamlit==1.12.0
sympy==1.12
tokenizers==0.15.0
toml==0.10.2
toolz==0.12.0
torch==2.1.1
tornado==6.4
tqdm==4.66.1
transformers==4.35.2
typing_extensions==4.8.0
tzdata==2023.3
tzlocal==5.2
urllib3==2.1.0
validators==0.22.0
xxhash==3.4.1
yarl==1.9.3
zipp==3.17.0
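Note: Hugging Face Spaces installs Python dependencies from a file named requirements.txt; if this file is committed as `requirements` without the extension, the Space may not pick it up automatically, and the packages above would then need to be installed manually (e.g. with pip).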