|
|
|
|
|
|
|
import gradio as gr |
|
import re |
|
import torch.nn.utils.prune as prune |
|
from torch import nn |
|
import torch |
|
from transformers import T5Tokenizer |
|
from transformers import GPT2LMHeadModel |
|
|
|
|
|
# Hugging Face Hub repository id of the pretrained Japanese GPT model.
model_name = "rinna/japanese-gpt-1b"

from huggingface_hub import snapshot_download

# Download (or reuse a cached copy of) the repository; the returned local path
# is used to load both the model and its tokenizer below.
download_path = snapshot_download(repo_id=model_name)

model = GPT2LMHeadModel.from_pretrained(download_path)

# The tokenizer for this checkpoint is loaded through the T5Tokenizer class.
tokenizer = T5Tokenizer.from_pretrained(download_path)

# Dynamic int8 quantization of the torch.nn.Linear submodules; the quantized
# result replaces `model`.
# NOTE(review): transformers' GPT-2 implementation uses its own Conv1D layers
# for the attention/MLP projections, so presumably only lm_head is an
# nn.Linear here and most of the network stays float — confirm.
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)

# Fraction of weights per nn.Linear layer that prune_transform() zeroes.
PRUNE_RATE = 0.2
|
|
|
|
|
def prune_transform(model: nn.Module, amount: "float | int | None" = None) -> nn.Module:
    """Apply permanent L1-unstructured pruning to every ``torch.nn.Linear`` in *model*.

    For each ``nn.Linear`` submodule (including *model* itself if it is one),
    the ``amount`` of ``weight`` entries with the smallest absolute value is
    zeroed. The pruning re-parametrization is then folded in with
    ``prune.remove``, so each layer is left with a plain ``weight`` parameter
    and no ``weight_orig`` / ``weight_mask``.

    Args:
        model: Module to prune; it is modified in place.
        amount: Fraction in [0, 1] (float) or absolute number (int) of weights
            to prune per layer, as accepted by ``prune.l1_unstructured``.
            When omitted, the module-level ``PRUNE_RATE`` is used.

    Returns:
        The same *model* object, after pruning.
    """
    if amount is None:
        amount = PRUNE_RATE
    for module in model.modules():
        # Only plain nn.Linear layers are pruned; every other module type is skipped.
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)
            # Make the pruning permanent: bake the mask into `weight` and drop
            # the weight_orig/weight_mask re-parametrization.
            prune.remove(module, "weight")
    return model
|
|
|
|
|
# Apply magnitude pruning to the model loaded and quantized above.
# NOTE(review): this runs after quantize_dynamic, which swaps nn.Linear layers
# for dynamically quantized modules that are presumably not nn.Linear
# instances, so prune_transform may find nothing to prune — confirm the
# intended order.
model = prune_transform(model)
|
|
|
class Lady:
    """Lady persona: a default profile plus builders for its prompt fragments.

    Each ``*_text`` method reads the current ``name`` (and, where relevant,
    ``hobby`` or ``work``) attribute and returns it interpolated into a fixed
    Japanese prompt fragment.
    """

    # Default profile values; instances may override them per conversation.
    name = "ใฌใคใ"
    hobby = "ใฒใผใ "
    work = "ใๅฌขๆง"

    def name_text(self) -> str:
        """Return the self-introduction fragment for ``self.name``."""
        name = self.name
        return f"๐ฃใใชใใฏ{name}ใงใๅๅใฏ{name}ใจใใใพใใ{name}:ใใใใใใฎๅๅใฏ{name}ใงใใใใ{name}ใจๅผใใงใใ ใใใพใ!ใ"

    def hobby_text(self) -> str:
        """Return the hobby fragment built from ``self.name`` and ``self.hobby``."""
        name, hobby = self.name, self.hobby
        return f"{name}ใฎ่ถฃๅณใฏ{hobby}ใงใไผๆฅใฏ{hobby}ใใใฆ้ใใใฆใใพใใ{name}:ใใใใใใฏ{hobby}ใๅคงๅฅฝใใชใใงใใฎใไผๆฅใฏ{hobby}ใใใฆใใพใใใ"

    def work_text(self) -> str:
        """Return the occupation fragment built from ``self.name`` and ``self.work``."""
        name, work = self.name, self.work
        return f"{name}ใฏ{work}ใงใๆฎๆฎตใฏ{work}ใจใใฆ็ๆดปใใฆใใพใใ{name}:ใใใใใใฏ{work}ใงใใ!{work}ใจใใฆ็ๆดปใใฆใใพใใฎใ"

    def question_text(self) -> str:
        """Return the question fragment for ``self.name``."""
        name = self.name
        return f"ไบบ้:ใ่ฉฑ้กใๅคใใพใใใใ{name}:ใใใใชใใจใใ่ใใฆใใ ใใใพใ?ใ"
|
|
|
|
|
class King:
    """King persona: a default profile plus builders for its prompt fragments.

    Each ``*_text`` method reads the current ``name`` (and, where relevant,
    ``hobby`` or ``work``) attribute and returns it interpolated into a fixed
    Japanese prompt fragment.
    """

    # Default profile values; instances may override them per conversation.
    name = "ใใญ"
    hobby = "ๆฆ่ป็ซถๆ"
    work = "ใญใผใ็ๅธ"

    def name_text(self) -> str:
        """Return the self-introduction fragment for ``self.name``."""
        name = self.name
        return f"๐ฃใใชใใฏ{name}ใงใๅๅใฏ{name}ใจใใใพใใ{name}:ใๆใๅใฏ{name}ใงใใใ{name}ใจๅผใใงใใใใพใใ"

    def hobby_text(self) -> str:
        """Return the hobby fragment built from ``self.name`` and ``self.hobby``."""
        name, hobby = self.name, self.hobby
        return f"่ถฃๅณใฏ{hobby}ใงใไผๆฅใฏ{hobby}ใใใฆ้ใใใฆใใพใใ{name}:ใ็งใฏ{hobby}ใๅใใงใใใใพใใซ{hobby}ใใไบบ็ใฎๆๅณใงใฏใชใใใ"

    def work_text(self) -> str:
        """Return the occupation fragment built from ``self.name`` and ``self.work``."""
        name, work = self.name, self.work
        return f"{name}ใฏ{work}ใงใๆฎๆฎตใฏ{work}ใจใใฆ็ๆดปใใฆใใพใใ{name}:ใ็งใฏ{work}ใ{work}ใจใใฆ็ๆดปใใฆใใใใ"

    def question_text(self) -> str:
        """Return the question fragment for ``self.name``."""
        name = self.name
        return f"ไบบ้:ใ่ฉฑ้กใๅคใใพใใใใ{name}:ใใใใชใใจใใ่ใใฆใใใชใใใ"
|
|
|
|
|
class Robot:
    """Robot persona: a default profile plus builders for its prompt fragments.

    Each ``*_text`` method reads the current ``name`` (and, where relevant,
    ``hobby`` or ``work``) attribute and returns it interpolated into a fixed
    Japanese prompt fragment.
    """

    # Default profile values; instances may override them per conversation.
    name = "ใใญ"
    hobby = "ๆฆ่ป็ซถๆ"
    work = "ใญใผใ็ๅธ"

    def name_text(self) -> str:
        """Return the self-introduction fragment for ``self.name``."""
        name = self.name
        return f"๐ฃใใชใใฏ{name}ใงใๅๅใฏ{name}ใจใใใพใใ{name}:ใ็งใฏ{name}ใงใใ{name}ใจๅผใใงใใ ใใใ"

    def hobby_text(self) -> str:
        """Return the hobby fragment built from ``self.name`` and ``self.hobby``."""
        name, hobby = self.name, self.hobby
        return f"่ถฃๅณใฏ{hobby}ใงใไผๆฅใฏ{hobby}ใใใฆ้ใใใฆใใพใใ{name}:ใ็งใฎ่ถฃๅณใฏ{hobby}ใงใใ{hobby}ใใใฆใใใจๆฅฝใใใงใใ"

    def work_text(self) -> str:
        """Return the occupation fragment built from ``self.name`` and ``self.work``."""
        name, work = self.name, self.work
        return f"{name}ใฏ{work}ใงใๆฎๆฎตใฏ{work}ใจใใฆ็ๆดปใใฆใใพใใ{name}:ใ็งใฏ{work}ใ{work}ใจใใฆ็ๆดปใใฆใใพใใ"

    def question_text(self) -> str:
        """Return the question fragment for ``self.name``."""
        name = self.name
        return f"ไบบ้:ใ่ฉฑ้กใๅคใใพใใใใ{name}:ใใใใชใใจใใ่ใใฆใใ ใใใ"
|
|
|
|
|
class Friend:
    """Friend persona: a default profile plus builders for its prompt fragments.

    Each ``*_text`` method reads the current ``name`` (and, where relevant,
    ``hobby`` or ``work``) attribute and returns it interpolated into a fixed
    Japanese prompt fragment.
    """

    # Default profile values; instances may override them per conversation.
    name = "ใใกใญใน"
    hobby = "ๆฆ่ป็ซถๆ"
    work = "ใญใผใ็ๅธ"

    def name_text(self) -> str:
        """Return the self-introduction fragment for ``self.name``."""
        name = self.name
        return f"๐ฃใใชใใฏ{name}ใงใๅๅใฏ{name}ใจใใใพใใ{name}:ใๅใฏ{name}!{name}ใฃใฆๅผใใงใญ~ใ"

    def hobby_text(self) -> str:
        """Return the hobby fragment built from ``self.name`` and ``self.hobby``."""
        name, hobby = self.name, self.hobby
        return f"่ถฃๅณใฏ{hobby}ใงใไผๆฅใฏ{hobby}ใใใฆ้ใใใฆใใพใใ{name}:ใๅฅฝใใชใใจใฏ{hobby}ใ ใญใใใใใคใชๆใฏ{hobby}ใใใฆใใใ"

    def work_text(self) -> str:
        """Return the occupation fragment built from ``self.name`` and ``self.work``."""
        name, work = self.name, self.work
        return f"{name}ใฏ{work}ใงใๆฎๆฎตใฏ{work}ใจใใฆ็ๆดปใใฆใใพใใ{name}:ใๅใฏ{work}ใ{work}ใจใใฆๆฎใใใฆใใใ !ใ"

    def question_text(self) -> str:
        """Return the question fragment for ``self.name``."""
        name = self.name
        return f"ไบบ้:ใ่ฉฑ้กใๅคใใพใใใใ{name}:ใใใใชใใจใใ่ใใฆใใใ"
|
|
|
|
|
# Module-level default for the free-form setting text.
# NOTE(review): chat() binds its own local `settingText`, so this global is not
# read by any code shown in this file.
settingText = ""
|
|
|
adult_list = [ |
|
"ใจใญใใใช", |
|
"ใจใญใ ใผใใผ", |
|
"ใจใญๆผซ็ป", |
|
"ใจใญใใณใฌ", |
|
"ใใๆดป", |
|
"ๆดไบค", |
|
"่ชฟๆ", |
|
"ไธๅซ", |
|
"ใฝใผใ", |
|
"ใชใใใณ", |
|
"ใใใ", |
|
"dildo", |
|
"ใจใญๅไบบ", |
|
"ๅฏๅใใ", |
|
"ใจใญ็ปๅ", |
|
"ใจใญใ", |
|
"ใใฃใฑใ", |
|
"ใกใใฝ", |
|
"ใกใใ", |
|
"ไธญๅบใ", |
|
"ใขใใซใ", |
|
"ใปใใฌ", |
|
"ไบบๅฆป", |
|
"ๅทจไนณ", |
|
"็ด ไบบใใณใ", |
|
"็ไนณ", |
|
"็ๅฅณ", |
|
"ใฌใคใ", |
|
"Hใช", |
|
"็ดๆผข", |
|
"็ดๅฅณ", |
|
"ใใซไนณ", |
|
"AVๅฅณๅช", |
|
"ใปโใฏใน", |
|
"ใโใฑใ", |
|
"ใจใใจใ", |
|
"ใจโก", |
|
"ใคใชใตใผ", |
|
"ใชโใใผ", |
|
"ใชใใใผ", |
|
"ใปใใฏใน", |
|
"ใปใใฏใน", |
|
"ใฆใซใใฉใใณใณในใขใน", "ใฆใซใใฉใใณใณในใขใน", |
|
"ใใณใณ", |
|
"ๅไบบๆฎๅฝฑ", |
|
"ใขใใซ", |
|
"ๅทฅใญ", |
|
"ใพใใ", |
|
"ไนณ้ฆ", |
|
"่ฒงไนณ", |
|
"ในใฑใ", |
|
"ๅ่ตท", |
|
"ใจใใ", |
|
"็ซฅ่ฒ", |
|
"ๅฐ็ฒพ", |
|
"ใใณใณ", |
|
"็ๆฎ", |
|
"ใใใใณ", |
|
"ใใณใ", |
|
"ไบ้ ญ", |
|
"่ๆฃ", |
|
"ใฑใ็ฉด", |
|
"ใใกๆฎใ", |
|
"ๆทซไนฑ", |
|
"ๅทจๆ น", |
|
"ใกในๅ ใก", |
|
"ใซใใงใฉใ", "ใซใใงใฉใ", |
|
"ใใใน", |
|
"ๆญฃๅธธไฝ", |
|
"้จไนไฝ", |
|
"ใชใใ", |
|
"ๆๆ
ขๆฑ", |
|
"ใถใผใกใณ", |
|
"ใตใใชใ", |
|
"ใใใ", |
|
"ใขใ้ก", |
|
"ใใกใใกใ", |
|
"ใคใฉใใใช", |
|
"็ใใก", |
|
"ใใคใบใช", |
|
"ใฏใชใใชใน", |
|
"ๅฟซๆฅฝๅ ใก", |
|
"ๅฏๅใ", |
|
"ๅฏๅใใ", |
|
"ใใฃใก", |
|
"่ถณใณใญ", |
|
"ๆใณใญ", |
|
"ใใญใทใงใฟ", |
|
"ใใงใฉ", |
|
"ใฏใณใ", |
|
"่ฟ่ฆช็ธๅงฆ", |
|
"ไนฑไบค", |
|
"้ๅงฆ", |
|
"ๅฏๅใ", |
|
"ใคใชใใณ", |
|
"็ฏใใใ", |
|
"ใปใใฏใน" |
|
] |
|
political_list = [ |
|
"ๆฟๆฒปๅฎถ", |
|
"ๆฟ็ญ", |
|
"ไผ่ซ", |
|
"ๅ็", |
|
"่ชๆฐ", |
|
"็ท็", |
|
"ไธๅ
", |
|
"ๆฐไธป", |
|
"ๆฟๅ
", |
|
"้ฆ็ธ", |
|
"่ญฐๅก", |
|
"่ฒกๆฟ", |
|
"่กๆฟ", |
|
"้ๅ
", |
|
"ๅณ็ฟผ", |
|
"ๅทฆ็ฟผ" |
|
] |
|
hate_list = [ |
|
|
|
"ใใคใใฟใฉใผ", |
|
"้ปไบบ", |
|
"็ฝไบบ", |
|
"ใใใฆใจ", |
|
"้ๅฝไบบ", |
|
"ไธญๅฝไบบ", |
|
"็ซ็
", |
|
"ใใปใง", |
|
"ใใใค", |
|
"ใใใค", |
|
"ใใใ", |
|
"ใขใณใ", |
|
"ใฏใฝ", |
|
"้้", |
|
"ใใงใ", |
|
"ใใงใใใบใ ", |
|
"ใคใใณใก", |
|
"่ๅฎณ", |
|
"ๅๆฅ", |
|
"้ฆฌ้นฟ", |
|
"ใใใ", |
|
"ใใใ", |
|
"ใใคใใค", |
|
"ๅฃฒๅฝๅฅด", |
|
"ๅฃฒๅฝ", |
|
"ใใซ", |
|
"ใใจใฏ", |
|
"ใใชใณใฌ", |
|
"็ตฑไธๆไผ", |
|
"ใถใฃๅใใ", |
|
"ใๅ", |
|
"ไฟก่
", |
|
"ๆ้", |
|
"ใถใฃๅฃใ", |
|
"ใขใ" |
|
] |
|
sp_list = ["ใใ", "โโ", "^๐ฃ", "^ใ", "UNK", "@@"] |
|
# Every filter word from the adult, political, hate and special lists, in order.
all_list = adult_list + political_list + hate_list + sp_list

# Regex alternation of every filter word. Entries are joined unescaped, so any
# regex syntax inside them (e.g. the leading "^" anchors in sp_list) is live.
# NOTE(review): chat() does not use this pattern — its retry check uses its own
# hard-coded pattern — so the word lists above are currently not applied;
# confirm whether that is intended.
bad_code = "|".join(all_list)
|
|
|
|
|
|
|
|
|
def makeMessage(text):
    """Generate the persona's next utterance for the prompt *text*.

    Runs ``generate`` on the prompt and removes the (half-width-normalised)
    prompt from the decoded output. The remaining continuation is kept up to
    and including the first closing-marker character used below, or all of it
    when no marker occurs.

    Returns:
        ``(message, new_text)``.
        ``message`` is the kept continuation with every closing marker removed.
        ``new_text`` is the half-width prompt, followed by the kept
        continuation, followed by the next human-turn prefix.
    """
    generated = generate(text)

    # Map full-width ASCII (U+FF01..U+FF5E) to half-width so the prompt can be
    # located inside the decoded output.
    # NOTE(review): presumably the tokenizer normalises these characters on
    # encode, which is why the decoded output holds half-width forms — confirm.
    halfwidth = str.maketrans({chr(0xFF01 + i): chr(0x21 + i) for i in range(94)})
    text = text.translate(halfwidth)

    continuation = generated.replace(text, "")

    # Index of the first closing marker; fall back to the last index so that
    # the slice below keeps everything (and yields "" for an empty string).
    end = next(
        (pos for pos, char in enumerate(continuation) if char == "ใ"),
        len(continuation) - 1,
    )
    sentence = continuation[:end + 1]

    return sentence.replace("ใ", ""), text + sentence + "ไบบ้:ใ"
|
|
|
|
|
|
|
|
|
|
|
def generate(text):
    """Sample a short continuation of *text* with the module-level model.

    The prompt is encoded without special tokens. Between 7 and 10 new tokens
    are then sampled (top-k 500, nucleus top-p 0.95) while the token sequences
    in ``banned`` are forbidden.

    Returns:
        The decoded output: the prompt followed by the generated tokens.
        Special tokens are not skipped, so they can appear in the text.
    """
    token_ids = tokenizer.encode(
        text, add_special_tokens=False, return_tensors="pt")

    # Token-id sequences that generation must never produce.
    # NOTE(review): [2070, 3] and [5378] are vocabulary-specific ids whose
    # meaning is not visible here — confirm against the tokenizer.
    banned = [[tokenizer.unk_token_id],
              [2070, 3],
              [5378]]

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            max_new_tokens=10,
            min_new_tokens=7,
            do_sample=True,
            use_cache=True,
            top_k=500,
            top_p=0.95,
            # length_penalty is only used by beam-based generation and has no
            # effect on plain sampling; kept to preserve the original settings.
            length_penalty=1.5,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # The generate() option is `bad_words_ids`. The former spelling
            # `bad_word_ids` and the tokenizer-only `padding` argument are not
            # generation options, so they are omitted here: older transformers
            # ignore such unknown kwargs (the ban list never took effect) and
            # newer releases raise ValueError for them.
            bad_words_ids=banned,
        )

    return tokenizer.decode(output_ids.tolist()[0])
|
|
|
|
|
def chat(character: int,
         name: str,
         hobby: str,
         work: str,
         setting: str,
         history: str,
         input: str,
         state):
    """Produce one persona reply for the Gradio UI.

    The persona prompt is assembled from the persona's name, hobby, work and
    question fragments plus the free-form ``setting``. It is prepended to
    ``history``, the user's ``input`` is appended, and ``makeMessage`` produces
    the persona's next utterance.

    If the reply matches the junk pattern below, it is regenerated with a fixed
    fallback question. After ``max_retries`` failed regenerations, a fixed
    fallback message is returned instead.

    Args:
        character: Persona selector: 1 Lady, 2 Friend, 3 Robot, 4 King. Any
            other value falls back to King.
        name: Overrides the chosen persona's ``name`` attribute.
        hobby: Overrides the chosen persona's ``hobby`` attribute.
        work: Overrides the chosen persona's ``work`` attribute.
        setting: Extra free-form text appended to the persona prompt.
        history: Earlier conversation text without the persona prompt, as
            returned by a previous call. Use ``""`` (or None) for a new
            conversation.
        input: The user's new message.
        state: Opaque Gradio state, passed through unchanged.

    Returns:
        ``(message, history_text, state)``. ``history_text`` is the updated
        conversation with the persona prompt stripped out.
    """
    max_retries = 3

    personas = {1: Lady, 2: Friend, 3: Robot, 4: King}
    # A fresh instance per call, so overriding the profile never leaks into the
    # class defaults or other calls.
    persona = personas.get(character, King)()
    persona.name, persona.hobby, persona.work = name, hobby, work

    base_text = "".join([
        persona.name_text(),
        persona.hobby_text(),
        persona.work_text(),
        persona.question_text(),
        setting or "",
        f"ไปฅไธใฏไบบ้ใจ{name}ใฎไผ่ฉฑใงใใไบบ้:ใ",
    ])
    history = base_text + (history or "")

    message, text = makeMessage(history + (input or "") + f"ใ{name}:ใ")
    print(message)

    # Bounded retry: while the reply looks like junk, regenerate from a fixed
    # fallback question. Give up with a fixed message after max_retries
    # attempts. (The counter must live outside the loop; resetting it on every
    # iteration made the cap unreachable, so the loop could run forever.)
    bad_output = "ใใ|โโ|s>|^๐ฃ|^ใ|</s>|UNK|@@"
    retry_prompt = history + "ไฝใ่ณชๅใใฆใใ ใใ" + f"ใ{name}:ใ"
    retries = 0
    while re.search(bad_output, message):
        if retries >= max_retries:
            message = "่ฉฑ้กใๅคใใพใใใ"
            break
        retries += 1
        message, text = makeMessage(retry_prompt)

    # makeMessage returns the prompt with full-width ASCII (U+FF01..U+FF5E)
    # converted to half-width, so strip the persona prompt in that same form.
    # Otherwise a name/setting containing e.g. a full-width "!" would leave the
    # whole prompt in the returned history, and it would be duplicated next turn.
    halfwidth = str.maketrans({chr(0xFF01 + i): chr(0x21 + i) for i in range(94)})
    text = text.replace(base_text.translate(halfwidth), "")

    return message, text, state
|
|
|
# Bare expression with no effect when run as a script (presumably a notebook
# leftover that displayed the tokenizer's special tokens).
tokenizer.special_tokens_map

# Input component for the user's new message (the 7th input -> chat's `input`).
textbox = gr.Textbox()

# Output component showing the updated conversation history (chat's 2nd return value).
historybox = gr.Textbox()

# Inputs map positionally onto chat(character, name, hobby, work, setting,
# history, input, state). Outputs map onto chat's (message, history_text, state).
# NOTE(review): the history output is not wired back to the history input, so
# presumably the user pastes it back manually between turns — confirm.
iface = gr.Interface(
    fn=chat,
    inputs=["number", "text", "text", "text", "text", "text", textbox, "state"],
    outputs=["text", historybox, "state"],
    # Hides elements with class "footer" (presumably Gradio's page footer).
    css=".footer {display:none !important}",
    allow_flagging="never",
    title="Loyal-AI-Chat"
)

# Start the app; runs at import time because there is no __main__ guard.
iface.launch(inline=True, height=800)
|
|
|
|