Spaces:
Runtime error
Runtime error
import gradio as gr | |
# !python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'" | |
import json | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig | |
# from google.colab import userdata | |
import os | |
model_id = "somosnlp/GemmaColRAC-AeroExpert" | |
bnb_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_quant_type="nf4", | |
bnb_4bit_compute_dtype=torch.bfloat16 | |
) | |
max_seq_length=400 | |
# if torch.cuda.get_device_capability()[0] >= 8: | |
# # print("Flash Attention") | |
# attn_implementation="flash_attention_2" | |
# else: | |
# attn_implementation=None | |
attn_implementation=None | |
tokenizer = AutoTokenizer.from_pretrained(model_id, | |
max_length = max_seq_length) | |
model = AutoModelForCausalLM.from_pretrained(model_id, | |
# quantization_config=bnb_config, | |
device_map = {"":0}, | |
attn_implementation = attn_implementation, # A100 o H100 | |
).eval() | |
class ListOfTokensStoppingCriteria(StoppingCriteria): | |
""" | |
Clase para definir un criterio de parada basado en una lista de tokens específicos. | |
""" | |
def __init__(self, tokenizer, stop_tokens): | |
self.tokenizer = tokenizer | |
# Codifica cada token de parada y guarda sus IDs en una lista | |
self.stop_token_ids_list = [tokenizer.encode(stop_token, add_special_tokens=False) for stop_token in stop_tokens] | |
def __call__(self, input_ids, scores, **kwargs): | |
# Verifica si los últimos tokens generados coinciden con alguno de los conjuntos de tokens de parada | |
for stop_token_ids in self.stop_token_ids_list: | |
len_stop_tokens = len(stop_token_ids) | |
if len(input_ids[0]) >= len_stop_tokens: | |
if input_ids[0, -len_stop_tokens:].tolist() == stop_token_ids: | |
return True | |
return False | |
# Uso del criterio de parada personalizado | |
stop_tokens = ["<end_of_turn>"] # Lista de tokens de parada | |
# Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada | |
stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens) | |
# Añade tu criterio de parada a una StoppingCriteriaList | |
stopping_criteria_list = StoppingCriteriaList([stopping_criteria]) | |
def generate_text(prompt, max_length=2100): | |
# prompt="""What were the main contributions of Eratosthenes to the development of mathematics in ancient Greece?""" | |
prompt=prompt.replace("\n", "").replace("¿","").replace("?","") | |
#EXAMPLE | |
input_text = f'''<bos><start_of_turn>system\nYou are a helpful AI assistant.\nResponde en formato json.\nEres un agente experto en la normativa aeronautica Colombiana.<end_of_turn>\n<start_of_turn>user\n¿{prompt}?<end_of_turn>\n<start_of_turn>model\n''' | |
inputs = tokenizer.encode(input_text, | |
return_tensors="pt", | |
add_special_tokens=False).to("cuda:0") | |
max_new_tokens=max_length | |
generation_config = GenerationConfig( | |
max_new_tokens=max_new_tokens, | |
temperature=0.15, | |
#top_p=0.9, | |
top_k=40, # 45 | |
repetition_penalty=1., #1.1 | |
do_sample=True, | |
) | |
outputs = model.generate(generation_config=generation_config, | |
input_ids=inputs, | |
stopping_criteria=stopping_criteria_list,) | |
return tokenizer.decode(outputs[0], skip_special_tokens=False) #True | |
def mostrar_respuesta(pregunta): | |
json_obj={} | |
json_obj['respuesta']='Esperando' | |
json_obj['pagina']='Esperando' | |
json_obj['rac']='Esperando' | |
if pregunta!="": | |
try: | |
res= generate_text(pregunta, max_length=500) | |
# print(">> RES:",res) | |
inicio_json = res.find('{') | |
fin_json = res.rfind('}') + 1 | |
json_str = res[inicio_json:fin_json] | |
json_obj = json.loads(json_str) | |
# print("json_obj:",json_obj) | |
return json_obj["respuesta"], json_obj["pagina"], json_obj["rac"] | |
except: | |
return json_obj["respuesta"], json_obj["pagina"], json_obj["rac"] | |
return json_obj["respuesta"], json_obj["pagina"], json_obj["rac"] | |
# Ejemplos de preguntas | |
ejemplos = [ | |
["¿Cuál fue la fecha de publicación del RAC 1 en el Diario Oficial?"], | |
["¿Qué se incorpora a los Reglamentos Aeronáuticos de Colombia?"], | |
["Cuál fue la fecha de publicación del RAC 1 en el Diario Oficial?"], | |
] | |
iface = gr.Interface( | |
fn=mostrar_respuesta, | |
inputs=gr.Textbox(label="Pregunta"), | |
outputs=[ | |
gr.Textbox(label="Respuesta", lines=2), | |
gr.Textbox(label="Pagina", lines=1), | |
gr.Textbox(label="Rac", lines=1) | |
], | |
title="Consultas Normativa Aeronáutica Colombiana", | |
description="Introduce tu pregunta sobre la normativa aeronáutica colombiana para obtener una respuesta.", | |
examples=ejemplos, | |
) | |
iface.queue(max_size=14).launch(debug=True) # share=True,debug=True | |