Merge pull request #183 from OpenAccess-AI-Collective/inference-from-stdin
Files changed:

- README.md (+5 -0)
- scripts/finetune.py (+18 -5)
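
This change lets `scripts/finetune.py` accept an inference prompt piped in on stdin, and adds a `prompter` keyword so prompt templating can be disabled with `--prompter=None`.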

README.md

````diff
@@ -495,6 +495,11 @@ Pass the appropriate flag to the train command:
 ```bash
 --inference --base_model ./completed-model
 ```
+- Full weights finetune w/ a prompt from a text file:
+```bash
+cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
+    --base_model ./completed-model --inference --prompter=None --load_in_8bit=True
+```
 
 ### Merge LORA to base
 
````
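
As the `scripts/finetune.py` changes below show, `--prompter=None` skips prompt templating entirely, so the piped file contents are tokenized verbatim, while `--load_in_8bit=True` loads the base model with 8-bit weights to reduce memory use.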

scripts/finetune.py

````diff
@@ -71,7 +71,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         if not (cfg.special_tokens and token in cfg.special_tokens):
             tokenizer.add_special_tokens({token: symbol})
 
-    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
+    prompter_module = None
+    if prompter:
+        prompter_module = getattr(
+            importlib.import_module("axolotl.prompters"), prompter
+        )
 
     while True:
         print("=" * 80)
@@ -79,9 +83,12 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         instruction = get_multi_line_input()
         if not instruction:
             return
-        prompt: str = next(
-            prompter_module().build_prompt(instruction=instruction.strip("\n"))
-        )
+        if prompter_module:
+            prompt: str = next(
+                prompter_module().build_prompt(instruction=instruction.strip("\n"))
+            )
+        else:
+            prompt = instruction.strip()
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
         print("=" * 40)
         model.eval()
````
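
The loop above reads the prompt via `get_multi_line_input()`, which is not part of this diff. As a rough sketch of what such a helper plausibly does (the real axolotl implementation may differ), it reads stdin until EOF, which is exactly why the `cat /tmp/prompt.txt | ...` invocation from the README works:

```python
import sys

def get_multi_line_input() -> str:
    # Hypothetical sketch, not the actual axolotl helper: collect all of
    # stdin until EOF. When stdin is a pipe (cat file | ...), this returns
    # the whole file; interactively, the user ends input with Ctrl+D.
    print("Give me an instruction (Ctrl+D to submit):")
    return sys.stdin.read()
```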

````diff
@@ -242,7 +249,13 @@ def train(
 
     if "inference" in kwargs:
         logging.info("calling do_inference function")
-        do_inference(cfg, model, tokenizer)
+        inf_kwargs: Dict[str, Any] = {}
+        if "prompter" in kwargs:
+            if kwargs["prompter"] == "None":
+                inf_kwargs["prompter"] = None
+            else:
+                inf_kwargs["prompter"] = kwargs["prompter"]
+        do_inference(cfg, model, tokenizer, **inf_kwargs)
         return
 
     if "shard" in kwargs:
````
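
Note that the hunk above compares `kwargs["prompter"]` against the string `"None"`, not Python `None`, presumably because flag values forwarded through `**kwargs` arrive from the command line as strings. A standalone sketch of that normalization (the helper name is hypothetical):

```python
from typing import Any, Dict

def build_inference_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper mirroring the train() hunk above: map the literal
    # CLI string "None" to Python None so do_inference() falls back to the
    # raw, untemplated prompt; pass any other prompter name (for example
    # "AlpacaPrompter") through unchanged.
    inf_kwargs: Dict[str, Any] = {}
    if "prompter" in kwargs:
        prompter = kwargs["prompter"]
        inf_kwargs["prompter"] = None if prompter == "None" else prompter
    return inf_kwargs

# e.g. build_inference_kwargs({"prompter": "None"}) -> {"prompter": None}
```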