improve inference
- scripts/finetune.py (+25 -25)
- src/axolotl/utils/models.py (+18 -15)
scripts/finetune.py

The old body of do_inference is removed and replaced with an interactive loop: read an instruction from stdin, build a prompt with ReflectAlpacaPrompter, and sample a completion from the model. Added lines, with context:

@@ -79,31 +79,31 @@ def do_inference(cfg, model, tokenizer):

     from axolotl.prompters import ReflectAlpacaPrompter

+    while True:
+        instruction = str(input("Give me an instruction: "))
+        if not instruction:
+            return
+        prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
+        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        model.eval()
+        with torch.no_grad():
+            # gc = GenerationConfig()  # TODO swap out and use this
+            generated = model.generate(
+                inputs=batch["input_ids"].to("cuda"),
+                do_sample=True,
+                use_cache=True,
+                repetition_penalty=1.1,
+                max_new_tokens=100,
+                temperature=0.9,
+                top_p=0.95,
+                top_k=40,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))


 def choose_config(path: Path):
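The commented-out TODO above points at moving these sampling arguments into a GenerationConfig object. A minimal sketch of what that could look like; this is not part of the commit and assumes a transformers version that ships GenerationConfig (4.27+), with the same sampling settings as the inline call:

```python
# Sketch only, not part of this commit: the TODO suggests swapping the inline
# keyword arguments for a GenerationConfig (available in transformers >= 4.27).
from transformers import GenerationConfig

gc = GenerationConfig(
    do_sample=True,
    use_cache=True,
    repetition_penalty=1.1,
    max_new_tokens=100,
    temperature=0.9,
    top_p=0.95,
    top_k=40,
    return_dict_in_generate=True,
)

# The generate call would then shrink to roughly:
#   generated = model.generate(
#       inputs=batch["input_ids"].to("cuda"),
#       generation_config=gc,
#   )
```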
src/axolotl/utils/models.py

The model-path lookup for the 4-bit LLaMA loader is wrapped in a try/except: snapshot_download now honors an optional cfg.base_model_ignore_patterns, the missing-cached-weight-file case falls back to the snapshot directory itself, and any failure falls back to cfg.base_model. Added lines, with context:

@@ -66,22 +66,25 @@ def load_model(
     from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
     from huggingface_hub import snapshot_download

+    try:
+        snapshot_download_kwargs = {}
+        if cfg.base_model_ignore_patterns:
+            snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
+        cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
+        files = (
+            list(cache_model_path.glob("*.pt"))
+            + list(cache_model_path.glob("*.safetensors"))
+            + list(cache_model_path.glob("*.bin"))
+        )
+        if len(files) > 0:
+            model_path = str(files[0])
+        else:
+            logging.warning(
+                "unable to find a cached model file, this will likely fail..."
+            )
+            model_path = str(cache_model_path)
+    except:
+        model_path = cfg.base_model
     model, tokenizer = load_llama_model_4bit_low_ram(
         base_model_config if base_model_config else base_model,
         model_path,
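For reference, ignore_patterns on huggingface_hub.snapshot_download takes glob patterns, so base_model_ignore_patterns can be used to skip files that the 4-bit path does not need. A standalone sketch under that assumption; the repo id and patterns below are hypothetical, not part of the commit:

```python
# Standalone sketch, not part of this commit. The repo id and ignore patterns are
# hypothetical; in practice cfg.base_model_ignore_patterns would carry the patterns.
from pathlib import Path

from huggingface_hub import snapshot_download

cache_model_path = Path(
    snapshot_download(
        "some-org/some-llama-4bit",        # hypothetical repo id
        ignore_patterns=["*.safetensors"],  # e.g. skip full-precision shards
    )
)

# Same lookup as above: prefer an explicit weight file, else fall back to the directory.
files = (
    list(cache_model_path.glob("*.pt"))
    + list(cache_model_path.glob("*.safetensors"))
    + list(cache_model_path.glob("*.bin"))
)
model_path = str(files[0]) if files else str(cache_model_path)
print(model_path)
```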