|
""" |
|
monkeypatch to add a get_turns method |
|
""" |
|
|
|
import logging |
|
from typing import Generator, Tuple |
|
|
|
from fastchat.conversation import SeparatorStyle |
|
|
|
# Module-level logger for this monkeypatch.
# NOTE(review): LOG is not referenced by any code visible in this file.
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
|
|
|
|
def get_prompt(self) -> str:
    """Build the full prompt by concatenating every turn from ``get_turns``.

    Each item yielded by ``self.get_turns()`` is a ``(role_prefix, message)``
    pair of strings; the prompt is the in-order concatenation of every
    prefix followed by its message.

    Returns:
        The complete prompt text (``""`` if no turns are yielded).
    """
    # A single join is linear in the output size, unlike repeated ``+=``.
    return "".join(role + msg for role, msg in self.get_turns())
|
|
|
|
|
def get_turns(
    self,
) -> Generator[Tuple[str, str], None, None]:
    """Yield the prompt for generation as ``(role_prefix, message)`` turns.

    The text is laid out according to ``self.sep_style`` (the per-style
    layouts follow fastchat's ``Conversation.get_prompt``). Instead of one
    string, the pieces are yielded separately so callers can tell which text
    belongs to which role. Concatenating ``role_prefix + message`` over all
    yielded turns gives the full prompt.

    Most styles first yield a turn with an empty role prefix that carries the
    formatted system prompt. Messages that are empty/falsy are yielded as a
    bare role prefix with an empty message.

    Yields:
        ``(role_prefix, message)`` tuples of strings.

    Raises:
        ValueError: if ``self.sep_style`` is not one of the handled styles
            (raised when the generator is first advanced).
    """
    system_prompt = self.system_template.format(system_message=self.system_message)

    if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
        yield "", system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + ": ", message + self.sep
            else:
                yield role + ":", ""
        return

    if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
        seps = [self.sep, self.sep2]
        yield "", system_prompt + seps[0]
        for i, (role, message) in enumerate(self.messages):
            if message:
                yield role + ": ", message + seps[i % 2]
            else:
                yield role + ":", ""
        return

    if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
        yield "", system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + ": ", message + self.sep
            else:
                # Unlike the other colon styles, the open slot keeps the
                # trailing space after the colon.
                yield role + ": ", ""
        return

    if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
        yield "", "" if system_prompt == "" else system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + "\n", message + self.sep
            else:
                yield role + "\n", ""
        return

    if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
        yield "", system_prompt
        for role, message in self.messages:
            if message:
                yield role, message + self.sep
            else:
                yield role, ""
        return

    if self.sep_style == SeparatorStyle.NO_COLON_TWO:
        seps = [self.sep, self.sep2]
        yield "", system_prompt
        for i, (role, message) in enumerate(self.messages):
            if message:
                yield role, message + seps[i % 2]
            else:
                yield role, ""
        return

    if self.sep_style == SeparatorStyle.RWKV:
        yield "", system_prompt
        for role, message in self.messages:
            if message:
                # Normalize CRLF and collapse double newlines inside the message.
                normalized = message.replace("\r\n", "\n").replace("\n\n", "\n")
                yield role + ": ", normalized + "\n\n"
            else:
                yield role + ":", ""
        return

    if self.sep_style == SeparatorStyle.LLAMA2:
        seps = [self.sep, self.sep2]
        # Opening block: the formatted system prompt when a system message is
        # set, otherwise a bare "[INST] ".
        opening = system_prompt if self.system_message else "[INST] "
        first_message = self.messages[0][1] if self.messages else None
        if first_message:
            # The first message has no role tag of its own: it is appended to
            # the opening block, followed by a space. (Iterating
            # ``self.messages[1:]`` used to drop it from the prompt entirely.)
            opening += first_message + " "
        yield "", opening
        # Enumerate the full list so ``i % 2`` selects ``sep`` after
        # even-indexed messages and ``sep2`` after odd-indexed ones.
        for i, (role, message) in enumerate(self.messages):
            if not message:
                yield role, ""
            elif i > 0:
                # A non-empty message at i == 0 was already folded into
                # ``opening`` above.
                yield role + " ", message + seps[i % 2]
        return

    if self.sep_style == SeparatorStyle.CHATGLM:
        # "chatglm2" numbers rounds from 1; every other name numbers from 0.
        round_add_n = 1 if self.name == "chatglm2" else 0
        if system_prompt:
            yield "", system_prompt + self.sep

        for i, (role, message) in enumerate(self.messages):
            # A "[Round N]" header precedes every even-indexed message.
            if i % 2 == 0:
                yield "", f"[Round {i//2 + round_add_n}]{self.sep}"

            if message:
                yield f"{role}:", f"{message}{self.sep}"
            else:
                yield f"{role}:", ""
        return

    if self.sep_style == SeparatorStyle.CHATML:
        yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
        for role, message in self.messages:
            if message:
                yield role + "\n", message + self.sep + "\n"
            else:
                yield role + "\n", ""
        return

    if self.sep_style == SeparatorStyle.CHATINTERN:
        seps = [self.sep, self.sep2]
        yield "", system_prompt
        for i, (role, message) in enumerate(self.messages):
            # Every even-indexed turn starts with "<s>", including an empty
            # (open) one, so the open slot matches the filled layout.
            prefix = "<s>" if i % 2 == 0 else ""
            if message:
                yield prefix + role + ":", message + seps[i % 2] + "\n"
            else:
                yield prefix + role + ":", ""
        return

    if self.sep_style == SeparatorStyle.DOLLY:
        seps = [self.sep, self.sep2]
        yield "", system_prompt
        for i, (role, message) in enumerate(self.messages):
            if message:
                # A blank line follows each odd-indexed message.
                suffix = "\n\n" if i % 2 == 1 else ""
                yield role + ":\n", message + seps[i % 2] + suffix
            else:
                yield role + ":\n", ""
        return

    if self.sep_style == SeparatorStyle.PHOENIX:
        yield "", system_prompt
        for role, message in self.messages:
            if message:
                yield role + ": ", "<s>" + message + "</s>"
            else:
                yield role + ": " + "<s>", ""
        return

    if self.sep_style == SeparatorStyle.ROBIN:
        yield "", system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + ":\n", message + self.sep
            else:
                yield role + ":\n", ""
        return

    if self.sep_style == SeparatorStyle.FALCON_CHAT:
        # The system block is only emitted when a system message is set.
        if self.system_message:
            yield "", system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + ": ", message + self.sep
            else:
                yield role + ":", ""
        return

    raise ValueError(f"Invalid style: {self.sep_style}")
|
|
|
|
|
def add_get_turns_to_conversation():
    """Install this module's turn helpers on fastchat's ``Conversation`` class.

    Sets ``Conversation.get_turns`` to the module-level ``get_turns`` and then
    replaces ``Conversation.get_prompt`` with the module-level, turn-based
    ``get_prompt``. The patch is applied to the class itself, so it affects
    every ``Conversation`` instance.
    """
    import fastchat.conversation

    conversation_cls = fastchat.conversation.Conversation
    # Order matters only for readability: get_turns first, then get_prompt.
    patches = (
        ("get_turns", get_turns),
        ("get_prompt", get_prompt),
    )
    for attr_name, replacement in patches:
        setattr(conversation_cls, attr_name, replacement)
|
|