Nanobit committed
Commit 572d114
1 parent: a6190c8

Set mem cache args on inference

Files changed (1)
  1. scripts/finetune.py +6 -0
scripts/finetune.py CHANGED

@@ -77,6 +77,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         importlib.import_module("axolotl.prompters"), prompter
     )
 
+    if cfg.landmark_attention:
+        model.set_mem_cache_args(
+            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+        )
+
     while True:
         print("=" * 80)
         # support for multiline inputs
@@ -90,6 +95,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         else:
             prompt = instruction.strip()
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
         print("=" * 40)
         model.eval()
         with torch.no_grad():
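
For orientation, below is a minimal sketch of how the inference path in do_inference reads once this patch is applied. The guarded set_mem_cache_args call and its argument values are taken directly from the diff; everything else (the input() placeholder, the elided prompter lookup and generation call) is abbreviated illustration rather than the script's actual code, and set_mem_cache_args is assumed to be provided by the landmark-attention-patched model rather than by stock Transformers.

    import torch

    def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
        # ... prompter module lookup elided (see the context lines above) ...

        # New in this commit: configure the landmark-attention memory cache
        # once, before entering the interactive prompt loop.
        if cfg.landmark_attention:
            model.set_mem_cache_args(
                max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
            )

        while True:
            print("=" * 80)
            # placeholder for the script's multiline-input handling
            instruction = input("> ")
            prompt = instruction.strip()
            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

            print("=" * 40)
            model.eval()
            with torch.no_grad():
                ...  # generation on `batch` elided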