ariG23498 (HF staff) committed on
Commit 48b0a57
1 parent: ae81e0f

chore: adding the app

Files changed (1)
  1. app.py +168 -4
app.py CHANGED
@@ -1,7 +1,171 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
 
+ import sys
+ sys.path.append("../")
+
+ import torch
  import gradio as gr
+ from omegaconf import OmegaConf
+ from transformers import AutoTokenizer
+ from huggingface_hub import hf_hub_download
+
+ from src.utils.setup import seed_everything
+ from src.utils.logging import print_header
+ from src.model.pretrained import get_pretrained_loader
+ from src.model.load_model import load_and_convert_attns, load_and_convert_finetune
+
+ def load_model_from_checkpoint(
+     attn_mlp_checkpoint_path: str = None,
+     finetune_checkpoint_path: str = None,
+     model_config_path: str = None,
+     distill_config_path: str = None,
+     finetune_config_path: str = None,
+     config_dir: str = 'configs',
+     print_model: bool = False,
+     debug: bool = False,
+     huggingface_token: str = None,
+     use_cuda_kernels: bool = False,
+     use_attention: bool = False
+ ):
+
+     # Paths ending in .pt are treated as local checkpoints; otherwise they are Hub repo ids.
+     is_local = attn_mlp_checkpoint_path.endswith(".pt")
+
+     model_config = OmegaConf.load(model_config_path)
+     distill_config = OmegaConf.load(distill_config_path)
+     finetune_config = OmegaConf.load(finetune_config_path)
+
+     model_loader = get_pretrained_loader(**model_config.model,
+                                          huggingface_token=huggingface_token)
+     tokenizer = model_loader.load_tokenizer()
+     tokenizer.pad_token_id = tokenizer.eos_token_id
+     tokenizer.padding_side = 'left'
+     if use_attention:
+         model = model_loader.load('softmax')
+         return model, model_config, tokenizer
+
+     model = model_loader.load(model_config['attention']['attention_type'])
+     if use_cuda_kernels:
+         print('*** Using TK CUDA kernels ***')
+         model_config['attention']['attention_type'] = 'lolcats_llama_window_tk_gen'
+
+     # Swap in the linearized attention layers.
+     if is_local:
+         checkpoint_path = attn_mlp_checkpoint_path
+     else:
+         checkpoint_path = None
+     model, distill_peft_config = load_and_convert_attns(
+         model, model_config,
+         attention_type=None,
+         checkpoint_path=checkpoint_path,
+         print_model=debug,
+         merge_loras=False,
+         peft_gradient_checkpointing=False,
+         train_attention=False)
+
+     # Apply the LoRA finetuning conversion.
+     if is_local:
+         checkpoint_path = finetune_checkpoint_path
+     else:
+         checkpoint_path = None
+     model, ft_peft_config = load_and_convert_finetune(
+         model, finetune_config,
+         checkpoint_path=checkpoint_path,
+         print_model=debug,
+         merge_loras=False,
+         peft_gradient_checkpointing=False)
+
+     # For Hub repo ids, download and load both sets of weights.
+     if not is_local:
+         model = load_hf_weights(
+             model,
+             attn_mlp_checkpoint_path, finetune_checkpoint_path,
+             filename="model.pt"
+         )
+     if use_cuda_kernels:
+         print('*** Using TK CUDA kernels ***')
+
+     if print_model:
+         print('*** Model after checkpoint load ***')
+         print(model)
+
+     return model, model_config, tokenizer
+
+ def load_hf_weights(model, distill_repo_id, ft_repo_id, filename="model.pt"):
+     for repo_id in [distill_repo_id, ft_repo_id]:
+         if repo_id is None: continue
+
+         print(f"Loading weights from {repo_id}")
+
+         local_file_path = hf_hub_download(repo_id=repo_id, filename=filename)
+         state_dict = torch.load(local_file_path)
+         if 'model_state_dict' in state_dict:
+             state_dict = state_dict['model_state_dict']
+         _keys = model.load_state_dict(state_dict, strict=False)
+         # Retry with remapped key prefixes if the checkpoint keys do not match the wrapped model.
+         if len(_keys.unexpected_keys) > 0:
+             new_state_dict = {k.replace('model.', 'model.model.'): v for k, v in state_dict.items()}
+             _keys = model.load_state_dict(new_state_dict, strict=False)
+         if len(_keys.unexpected_keys) > 0:
+             new_state_dict = {k.replace('model.', 'base_model.model.model.'): v for k, v in state_dict.items()}
+             _keys = model.load_state_dict(new_state_dict, strict=False)
+
+         try:
+             assert len(_keys.unexpected_keys) == 0
+             print('*** All expected keys matched successfully ***')
+         except Exception as e:
+             print(e)
+             print('*** Error: unexpected keys in checkpoint - please fix ***')
+             print('Unexpected keys:')
+             for k in _keys.unexpected_keys:
+                 print(k)
+             exit()
+
+     return model
+
+ def load_model_and_tokenizer():
+     CONFIG_DIR = 'configs'  # Update to your path
+
+     model_config_path = f"{CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml"
+     distill_config_path = f"{CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml"
+     finetune_config_path = f"{CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml"
+     attn_mlp_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-distill'
+     finetune_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-ft-lora'
+
+     model, model_config, tokenizer = load_model_from_checkpoint(
+         attn_mlp_checkpoint_path=attn_mlp_checkpoint_path,
+         finetune_checkpoint_path=finetune_checkpoint_path,
+         model_config_path=model_config_path,
+         distill_config_path=distill_config_path,
+         finetune_config_path=finetune_config_path,
+         config_dir=CONFIG_DIR,
+         print_model=False,
+         debug=False,
+         huggingface_token=None,
+         use_cuda_kernels=False,
+         use_attention=False
+     )
+     model = model.to('cuda')
+     model.eval()
+     return model, tokenizer
+
+ model, tokenizer = load_model_and_tokenizer()
+
+ def generate_response(prompt):
+     all_prompts = [prompt]
+
+     with torch.no_grad():
+         model_input = tokenizer(all_prompts, return_tensors="pt").to(model.device)
+         model_output = model.generate(
+             **model_input, use_cache=True,
+             max_new_tokens=50,
+             do_sample=False,
+             top_k=1,
+             top_p=1.0,
+             num_return_sequences=1,
+             pad_token_id=tokenizer.eos_token_id)
+         generated_tokens = model_output[0]
+         input_len = model_input['input_ids'].shape[1]
+         generated_tokens = generated_tokens[input_len:]
+         generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+     return generated_text

+ iface = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="LOLcats Model Demo")

+ iface.launch()