Corianas committed
Commit b96281c
Parent: 11407c5

Update to version 0.2

Stopped loading the model unnecessarily
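
In 0.1 all of the setup shown in the diff below (the TF32 flags, the autocast context, the sampling parameters, the checkpoint load, and the meta.pkl character vocabulary) ran at module import time, so the checkpoint was loaded whenever LLM discovered the plugin, even for commands that never invoke chargpt. In 0.2 that setup moves inside CharGPT.execute(), so nothing is loaded until a prompt is actually run.

One trade-off: as committed, execute() repeats the load on every prompt. Below is a minimal sketch of a cached variant, assuming the plugin's existing GPT and GPTConfig classes and its 16bit/ckpt.pt + meta.pkl layout; the load_chargpt helper and the caching are illustrative suggestions, not part of this commit:

import functools
import os
import pickle

import torch

@functools.lru_cache(maxsize=1)
def load_chargpt(model_dir='16bit', device='cuda'):
    # First call loads the checkpoint; subsequent calls return the cached result.
    checkpoint = torch.load(os.path.join(model_dir, 'ckpt.pt'), map_location=device)
    model = GPT(GPTConfig(**checkpoint['model_args']))  # assumed from this plugin's imports
    state_dict = checkpoint['model']
    for k in list(state_dict):  # strip the torch.compile wrapper prefix (see note after the diff)
        if k.startswith('_orig_mod.'):
            state_dict[k[len('_orig_mod.'):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    with open(os.path.join(model_dir, 'meta.pkl'), 'rb') as f:
        meta = pickle.load(f)
    return model, meta['stoi'], meta['itos']

execute() could call load_chargpt() at the top, keeping the lazy behaviour this commit introduces while paying the load cost only once per process.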

Files changed (2):
  1. llm_chargpt.py   +33 -31
  2. pyproject.toml   +1 -1
llm_chargpt.py CHANGED
@@ -273,46 +273,48 @@ def add_caseifer(text):
 model_dir = '16bit'
 device = 'cuda'
 dtype = 'bfloat16'
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-device_type = 'cuda' if 'cuda' in device else 'cpu'
-ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
-ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
-max_new_tokens = 2048 # number of tokens generated in each sample
-temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
-top_k = 24 # retain only the top_k most likely tokens, clamp others to have 0 probability
-
-ckpt_path = os.path.join(model_dir, 'ckpt.pt')
-checkpoint = torch.load(ckpt_path, map_location=device)
-gptconf = GPTConfig(**checkpoint['model_args'])
-model = GPT(gptconf)
-state_dict = checkpoint['model']
-unwanted_prefix = '_orig_mod.'
-for k,v in list(state_dict.items()):
-    if k.startswith(unwanted_prefix):
-        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-model.load_state_dict(state_dict)
-
-model.eval()
-model.to(device)
-meta_path = os.path.join(model_dir, 'meta.pkl')
-load_meta = os.path.exists(meta_path)
-with open(meta_path, 'rb') as f:
-    meta = pickle.load(f)
-# TODO want to make this more general to arbitrary encoder/decoder schemes
-stoi, itos = meta['stoi'], meta['itos']
-encode = lambda s: [stoi[c] for c in s]
-decode = lambda l: ''.join([itos[i] for i in l])
 
 class CharGPT(llm.Model):
     model_id = "chargpt"
-
+
+
     def execute(self, prompt, stream, response, conversation):
+
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        device_type = 'cuda' if 'cuda' in device else 'cpu'
+        ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+        ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+        max_new_tokens = 2048 # number of tokens generated in each sample
+        temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
+        top_k = 24 # retain only the top_k most likely tokens, clamp others to have 0 probability
+
+        ckpt_path = os.path.join(model_dir, 'ckpt.pt')
+        checkpoint = torch.load(ckpt_path, map_location=device)
+        gptconf = GPTConfig(**checkpoint['model_args'])
+        model = GPT(gptconf)
+        state_dict = checkpoint['model']
+        unwanted_prefix = '_orig_mod.'
+        for k,v in list(state_dict.items()):
+            if k.startswith(unwanted_prefix):
+                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+        model.load_state_dict(state_dict)
+
+        model.eval()
+        model.to(device)
+        meta_path = os.path.join(model_dir, 'meta.pkl')
+        with open(meta_path, 'rb') as f:
+            meta = pickle.load(f)
+        # TODO want to make this more general to arbitrary encoder/decoder schemes
+        stoi, itos = meta['stoi'], meta['itos']
+        encode = lambda s: [stoi[c] for c in s]
+        decode = lambda l: ''.join([itos[i] for i in l])
         text = prompt.prompt
         shift = False
        # generated_text = ''
         start_ids = encode(add_caseifer(text))
         x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+        print(text, end='', flush=True)
         for idx_next in model.generate_streaming(x, max_new_tokens, temperature=temperature, top_k=top_k):
             # convert the index to a character and print it to the screen
             char = decode([idx_next])
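
Two details in the moved block are worth a note. The unwanted_prefix loop handles checkpoints saved from a torch.compile()'d model: PyTorch 2.x wraps the module so that every key in its saved state_dict gains an _orig_mod. prefix, and the keys must be renamed before the plain GPT module will accept them. The rename, shown on a bare dict:

state_dict = {'_orig_mod.lm_head.weight': 'w1', 'transformer.wte.weight': 'w2'}
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
# state_dict is now {'transformer.wte.weight': 'w2', 'lm_head.weight': 'w1'}

And encode/decode are plain table lookups over the character vocabulary pickled at training time (nanoGPT-style stoi/itos maps). With a toy two-character vocabulary:

stoi = {'h': 0, 'i': 1}
itos = {0: 'h', 1: 'i'}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
assert decode(encode('hi')) == 'hi'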
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
 [project]
 name = "chargpt"
-version = "0.1"
+version = "0.2"
 
 [project.entry-points.llm]
 chargpt = "llm_chargpt"
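
With the version bumped to 0.2, reinstalling into the LLM environment (e.g. pip install -e ., or llm install -e .) picks up the change. The [project.entry-points.llm] table is what registers llm_chargpt as an LLM plugin; once installed, the model can be invoked with something like: llm -m chargpt 'Once upon a time'.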