|
import numpy as np
|
|
import math
|
|
import struct
|
|
import os
|
|
import threading
|
|
import torch
|
|
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList, TextIteratorStreamer
|
|
from auto_gptq import AutoGPTQForCausalLM
|
|
|
|
|
|
|
|
class Prompt:
|
|
def __init__(self, contexto_maximo=10000):
|
|
|
|
|
|
self.contexto_maximo = contexto_maximo
|
|
|
|
self.template = [
|
|
"Abaixo está uma instrução que descreve uma tarefa",
|
|
"Escreva uma resposta que apropriadamente satisfaça os pedidos",
|
|
"### Instrução:",
|
|
"Seu nome é Capivarinha, você é uma assistente e amiga. Converse naturalmente de forma alegre dando dicas, recomendações e respondendo perguntas, de o máximo de informações que puder se mantendo dentro do tema"
|
|
|
|
]
|
|
|
|
self.contexto = []
|
|
self.entrada = "### Entrada:"
|
|
self.resposta = "### Resposta:"
|
|
self.tamanhoAtual = len("\n".join(self.template))
|
|
self.resposta_recente = ""
|
|
|
|
def adicionar_contexto(self, texto):
|
|
self.contexto.append(texto)
|
|
self.tamanhoAtual = len("\n".join(self.template + self.contexto))
|
|
|
|
if self.tamanhoAtual > self.contexto_maximo:
|
|
self.contexto.pop(0)
|
|
while prmp[0] != self.entrada:
|
|
self.contexto.pop(0)
|
|
|
|
def adiciona_entrada(self, entrada):
|
|
|
|
self.adicionar_contexto(self.entrada)
|
|
self.adicionar_contexto(entrada)
|
|
self.adicionar_contexto(self.resposta)
|
|
|
|
def adiciona_resposta(self, resposta):
|
|
self.resposta_recente = resposta.split(self.resposta)[-1].split(self.entrada)[0].split("### Instrução:")[0].split("\n###")[0].split("###")[0].strip()
|
|
self.adicionar_contexto(self.resposta_recente)
|
|
|
|
def limpar_contexto(self):
|
|
|
|
self.contexto = []
|
|
|
|
def ultima_resposta(self):
|
|
return self.resposta_recente
|
|
|
|
|
|
def retorna_tamanho_atual(self):
|
|
return self.tamanhoAtual
|
|
|
|
def retorna_prompt(self):
|
|
|
|
|
|
return "\n".join(self.template + self.contexto)
|
|
|
|
def tester(frase, words):
|
|
for word in words:
|
|
if frase.lower().find(word) > -1:
|
|
return True
|
|
return False
|
|
|
|
|
|
|
|
class ModeloAutoGPTQ:
|
|
|
|
def __init__(self, criterio_parada, caminho_modelo_tokenizer="./modelo/", t_padding_side="right", t_use_fast=True, nome_modelo="capivarita_gptq_model-4bit-128g",tamanho_contexto=4096,tensores_seguros=True,processador="cuda:0",usar_trion=False):
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(caminho_modelo_tokenizer, padding_side=t_padding_side, use_fast=t_use_fast)
|
|
self.modelo = AutoGPTQForCausalLM.from_quantized(caminho_modelo_tokenizer,
|
|
model_basename=nome_modelo,
|
|
|
|
max_position_embeddings=tamanho_contexto,
|
|
|
|
use_safetensors=tensores_seguros,
|
|
|
|
device=processador,
|
|
|
|
use_triton=usar_trion,
|
|
|
|
quantize_config=None
|
|
)
|
|
self.criterio_parada = criterio_parada
|
|
self.processador = processador
|
|
|
|
def generate(self, text_prompt, return_tensors='pt', skip_prompt=True, max_new_tokens=512, repetition_penalty=1.2, temperature=0.9):
|
|
input_ids = self.tokenizer(text_prompt, return_tensors=return_tensors).to(self.processador)
|
|
|
|
self.criterio_parada.comp_inicial = len(input_ids)
|
|
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=skip_prompt)
|
|
generation_kwargs = dict(**input_ids, max_new_tokens=max_new_tokens, streamer=streamer ,repetition_penalty=repetition_penalty, temperature=temperature)
|
|
return threading.Thread(target=self.modelo.generate, kwargs=generation_kwargs), streamer
|
|
|
|
|
|
|
|
class KeywordsStoppingCriteria(StoppingCriteria):
|
|
def __init__(self, keywords_ids:list):
|
|
self.keywords = keywords_ids
|
|
self.comp_inicial = 0
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
|
|
if input_ids[0] > self.comp_inicial + 7:
|
|
|
|
for w in self.keywords:
|
|
if tokenizer.decode(input_ids[0][-8:]).find(w) > 0:
|
|
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
prompt = Prompt()
|
|
|
|
criterio_parada = KeywordsStoppingCriteria([prompt.entrada])
|
|
modelo = ModeloAutoGPTQ(criterio_parada=criterio_parada)
|
|
|
|
|
|
def analizar_entrada(frase):
|
|
|
|
if tester(frase,["limpar contexto do monitor", "nova conversa", "reiniciar monitoria"]):
|
|
prompt.limpar_contexto()
|
|
add_message("\n\n[Reiniciado]\n", "italic")
|
|
|
|
else:
|
|
try:
|
|
|
|
prompt.adiciona_entrada(frase)
|
|
modelo_thread, streamer = modelo.generate(prompt.retorna_prompt())
|
|
modelo_thread.start()
|
|
|
|
generated_text = ""
|
|
add_message("\nCapivarinha: ", "bold_violet")
|
|
for new_text in streamer:
|
|
if new_text in ["### Entrada:","### Entrada","### ", "\n###","### ", "\n###","##"]:
|
|
break
|
|
generated_text += new_text
|
|
add_message(new_text)
|
|
add_message("\n")
|
|
prompt.adiciona_resposta(generated_text)
|
|
except:
|
|
add_message("\n\n[Erro de geração]\n", "italic")
|
|
|
|
|
|
def add_message(mes, style=None):
|
|
if len(mes) > 0:
|
|
if mes[-1] == " ":
|
|
print(mes, end="", flush=True)
|
|
else:
|
|
print(mes, end="")
|
|
|
|
def main():
|
|
print("Escreva sua entrada (escreva 'sair' pra sair):\n")
|
|
while True:
|
|
user_input = input("> ")
|
|
if user_input.lower() == "sair":
|
|
break
|
|
analizar_entrada(user_input)
|
|
|
|
if __name__ == "__main__":
|
|
main() |