File size: 1,888 Bytes
579cd1a
 
ccfdebe
 
 
 
16c6a3a
 
 
 
 
95564ae
 
 
 
 
 
 
 
 
 
 
 
16c6a3a
e9cb845
16c6a3a
 
95564ae
bb273c4
16c6a3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
---
license: apache-2.0
datasets:
- jojo0217/korean_rlhf_dataset
language:
- ko
---

์„ฑ๊ท ๊ด€๋Œ€ํ•™๊ต ์‚ฐํ•™ํ˜‘๋ ฅ ๊ณผ์ •์—์„œ ๋งŒ๋“  ํ…Œ์ŠคํŠธ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.   
ํ•™์Šต ๋ฐ์ดํ„ฐ์˜ ์ฐธ๊ณ  ๋ชจ๋ธ์ด๋ผ๊ณ  ์ƒ๊ฐํ•˜์‹œ๋ฉด ์ข‹์„ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.   
๊ธฐ์กด 10๋งŒ 7์ฒœ๊ฐœ์˜ ๋ฐ์ดํ„ฐ + 2์ฒœ๊ฐœ ์ผ์ƒ๋Œ€ํ™” ์ถ”๊ฐ€ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒจ๊ฐ€ํ•˜์—ฌ ํ•™์Šตํ•˜์˜€์Šต๋‹ˆ๋‹ค.   
ํ•™์Šต parameter์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.   
___
batch_size: 128   
micro_batch_size: 8   
num_epochs: 3   
learning_rate: 3e-4   
cutoff_len: 1024   
lora_r: 8   
lora_alpha: 16   
lora_dropout: 0.05   
weight_decay: 0.1   
___

์ธก์ •ํ•œ kobest 10shot ์ ์ˆ˜๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.   
![score](./asset/score.png)    

___
๋ชจ๋ธ prompt template๋Š” kullm์˜ template๋ฅผ ์‚ฌ์šฉํ•˜์˜€์Šต๋‹ˆ๋‹ค.   
ํ…Œ์ŠคํŠธ ์ฝ”๋“œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.   
```
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

model_name="jojo0217/ChatSKKU5.8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            load_in_8bit=True,#๋งŒ์•ฝ ์–‘์žํ™” ๋„๊ณ  ์‹ถ๋‹ค๋ฉด false
        )
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=model_name,
            device_map="auto"
        )

def answer(message):
  prompt=f"์•„๋ž˜๋Š” ์ž‘์—…์„ ์„ค๋ช…ํ•˜๋Š” ๋ช…๋ น์–ด์ž…๋‹ˆ๋‹ค. ์š”์ฒญ์„ ์ ์ ˆํžˆ ์™„๋ฃŒํ•˜๋Š” ์‘๋‹ต์„ ์ž‘์„ฑํ•˜์„ธ์š”.\n\n### ๋ช…๋ น์–ด:\n{message}"
  ans = pipe(
        prompt + "\n\n### ์‘๋‹ต:",
        do_sample=True,
        max_new_tokens=512,
        temperature=0.9,
        num_beams = 1,
        repetition_penalty = 1.0,
        return_full_text=False,
        eos_token_id=2,
    )
  msg = ans[0]["generated_text"]
  return msg
answer('์„ฑ๊ท ๊ด€๋Œ€ํ•™๊ต์—๋Œ€ํ•ด ์•Œ๋ ค์ค˜')
```