# LLäMmlein 1B Chat

This is a chat adapter for the German TinyLlama 1B language model.
Find more details on our [page](https://www.informatik.uni-wuerzburg.de/datascience/projects/nlp/llammlein/) and our [preprint](https://arxiv.org/abs/2411.11171)!

## Run it
```py
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(42)

# script config
base_model_name = "LSX-UniWue/llammchen_1b"
chat_adapter_name = "LSX-UniWue/LLaMmlein_1B_chat_selected"
device = "mps"  # or "cuda"

# chat history ("Na wie geht's?" = "So, how's it going?")
messages = [
    {
        "role": "user",
        "content": """Na wie geht's?""",
    },
]

# load the base model, then attach the chat adapter on top of it
config = PeftConfig.from_pretrained(chat_adapter_name)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    attn_implementation="flash_attention_2" if device == "cuda" else None,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
base_model.resize_token_embeddings(32064)  # make room for the adapter's added chat tokens
model = PeftModel.from_pretrained(base_model, chat_adapter_name)
tokenizer = AutoTokenizer.from_pretrained(chat_adapter_name)

# encode the message in "ChatML" format
chat = tokenizer.apply_chat_template(
    messages,
    return_tensors="pt",
    add_generation_prompt=True,
).to(device)

# generate and decode the response
print(
    tokenizer.decode(
        model.generate(
            chat,
            max_new_tokens=300,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )[0],
        skip_special_tokens=False,
    )
)
```
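
Since the chat model ships as a PEFT adapter on top of the base weights, you can optionally merge the two after loading so inference no longer routes through the adapter indirection. A minimal sketch building on the snippet above, assuming a LoRA-style adapter (the output directory name is just an example):

```py
# optional: fold the adapter weights into the base model (LoRA-style adapters only)
merged_model = model.merge_and_unload()

# save the standalone merged model together with its tokenizer
# ("llammlein_1b_chat_merged" is an illustrative path, not an official artifact)
merged_model.save_pretrained("llammlein_1b_chat_merged")
tokenizer.save_pretrained("llammlein_1b_chat_merged")
```

The merged directory can then be loaded directly with `AutoModelForCausalLM.from_pretrained` without `peft` installed.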