Create modeling_llama3.py
modeling_llama3.py  ADDED  (+66 -0)
@@ -0,0 +1,66 @@
# coding=utf-8
import torch
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import AutoTokenizer


class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    """Applies a frequency penalty: each token's logit is reduced by `penalty` times
    how often that token appears in the penalised dialog history plus the tokens
    generated so far in the current call."""

    def __init__(self, penalty: float, penalty_dialog: torch.LongTensor, input_length: int):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.input_length = input_length
        self.penalty_dialog = penalty_dialog

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.penalty == 0.0:
            return scores

        new_scores = []
        for input_, score in zip(input_ids, scores):
            # Count every token in the penalised history plus what has been generated after the prompt.
            generated_tokens = torch.cat((self.penalty_dialog, input_[self.input_length:]), dim=-1)
            token_frequency = torch.bincount(generated_tokens, minlength=scores.size(-1)).to(scores.device)
            new_scores.append(score - self.penalty * token_frequency)

        return torch.stack(new_scores).float()


class LlamaForConditionalGeneration(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)

    def generate(self, **kwargs):
        # Extra kwargs consumed by this override: penalty strength, how many recent
        # assistant turns to penalise, and the chat history those turns come from.
        history_penalty = kwargs.pop("history_penalty", 0.0)
        penalty_turns = kwargs.pop("penalty_turns", 0)
        messages = kwargs.pop("messages", [])

        if history_penalty != 0.0 and penalty_turns >= 0:
            input_ids = kwargs.get("input_ids", torch.tensor([[]]))
            input_length = input_ids.size(-1)

            # Collect all assistant turns from the conversation history.
            dialogs = []
            for message in messages:
                if message['role'] == 'assistant':
                    dialogs.append(message['content'])

            # Keep the last `penalty_turns` assistant turns, stripping ASCII and full-width parentheses.
            penalty_dialog = []
            for i in range(penalty_turns, 0, -1):
                if i <= len(dialogs):
                    dialog = dialogs[-i].replace("(", " ").replace(")", " ").replace("（", " ").replace("）", " ")
                    penalty_dialog.append(dialog)

            model_id = "Collective-Ai/collective-v0.1-chinese-roleplay-8b"
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            penalty_token = torch.LongTensor(tokenizer.encode(' '.join(penalty_dialog))).to(input_ids.device)

            logits_processor = LogitsProcessorList([
                FrequencyPenaltyLogitsProcessor(
                    penalty=history_penalty, penalty_dialog=penalty_token, input_length=input_length
                )
            ])
            result = super().generate(logits_processor=logits_processor, **kwargs)
        else:
            result = super().generate(**kwargs)

        return result
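For reference, a minimal usage sketch (not part of the committed file). It imports the custom class straight from this module rather than relying on any auto_map wiring; the chat messages, penalty values, and generation settings are illustrative, and it assumes the repo's tokenizer ships a chat template, as Llama-3 chat checkpoints normally do.

# Usage sketch: pass the extra kwargs consumed by the overridden generate().
import torch
from transformers import AutoTokenizer
from modeling_llama3 import LlamaForConditionalGeneration

model_id = "Collective-Ai/collective-v0.1-chinese-roleplay-8b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Illustrative conversation; the penalty only has an effect once assistant turns exist.
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, nice to meet you. What would you like to talk about?"},
    {"role": "user", "content": "Tell me a short story."},
]

input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output = model.generate(
    input_ids=input_ids,
    max_new_tokens=256,
    history_penalty=0.3,   # strength of the frequency penalty (illustrative value)
    penalty_turns=1,       # penalise tokens from the most recent assistant turn
    messages=messages,     # chat history the penalised turns are taken from
)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))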