guillaumetell-7b / prompt_demo_rag.py
Pclanglais's picture
Rename prompt_demo_inference.py to prompt_demo_rag.py
5e8dd09 verified
raw
history blame
5.23 kB
#Full demo of the Guillaume-Tell reference model with three references.
#Guillaume-Tell is currently trained by default on five references but future version will enhance the flexibility of the model.
#Example of generated text:
#Le meilleur moyen de cuire une blanquette est d'utiliser un mélange de viande et de légumes, tels que des champignons de Paris<ref text="Les meilleures blanquettes se font avec des champignons de Paris">hash49080806</ref>.
#Il est recommandé de faire chauffer la blanquette à feu doux pendant 46 heures<ref text="faîtes chauffer la blanquette à feu doux pendant 46 heures.">hash49080806</ref>.
#Enfin, pour achever la préparation, il faut ajouter une crème fraîche, un jaune d’œuf et du jus de citron juste avant de servir<ref text="Dans un bol, bien mélanger la crème fraîche, le jaune d’oeuf et le jus de citron. Ajouter ce mélange au dernier moment, bien remuer et servir tout de suite.">hash49080806</ref>.
import sys, os
from pprint import pprint
from jinja2 import Environment, FileSystemLoader, meta
import yaml
import pandas as pd
from vllm import LLM, SamplingParams
sys.path.append(".")
os.chdir(os.path.dirname(os.path.abspath(__file__)))
def get_llm_response(prompt_template):
sampling_params = SamplingParams(temperature=0.4, top_p=.95, max_tokens=2000, presence_penalty = 2)
prompts = [prompt_template]
outputs = llm.generate(prompts, sampling_params, use_tqdm = False)
generated_text = outputs[0].outputs[0].text
prompt = prompt_template + generated_text
return prompt, generated_text
#Typical example:
if __name__ == "__main__":
with open('prompt_config.yaml') as f:
config = yaml.safe_load(f)
print("prompt format:", config.get("prompt_format"))
print(config)
print()
for prompt in config["prompts"]:
if prompt["mode"] == "rag":
print(f'--- prompt mode: {prompt["mode"]} ---')
env = Environment(loader=FileSystemLoader("."))
template = env.get_template(prompt["template"])
source = template.environment.loader.get_source(template.environment, template.name)
variables = meta.find_undeclared_variables(env.parse(source[0]))
print("variables:", variables)
print("---")
data = {
"query": "Quel est le meilleur moyen de cuire une blanquette?",
"chunks" : [
{
"url": "http://data.gouv.fr",
"h": "hash49080805",
"title": "A chunk title",
"text": "Moi j'aime la blanquette avec du beurre dedans\nEt une sauce bien épaisse.",
},
{
"url": "http://...",
"h": "hash49080806",
"title": "A chunk title",
"text": "faîtes chauffer la blanquette à feu doux pendant 46 heures.",
"context": "Recette de blanquette"
},
{
"url": "http://...",
"h": "hash49080806",
"title": "A chunk title",
"text": "Les meilleures blanquettes se font avec des champignons de Paris",
"context": "Avis de grand-mère"
},
{
"url": "http://...",
"h": "hash49080806",
"title": "A chunk title",
"text": """Étape 1 Faire revenir la viande dans un peu de beurre doux jusqu'à ce que les morceaux soient un peu dorés.
Étape 2: Saupoudrer de 2 cuillères de farine. Bien remuer.
Étape 3: Ajouter 2 ou 3 verres d'eau, les cubes de bouillon, le vin et remuer. Ajouter de l'eau si nécessaire pour couvrir.
Étape 4: Couper les carottes en rondelles et émincer les oignons puis les incorporer à la viande, ainsi que les champignons.
Étape 5: Laisser mijoter à feu très doux environ 1h30 à 2h00 en remuant.
Étape 6: Si nécessaire, ajouter de l'eau de temps en temps.
Étape 7: Dans un bol, bien mélanger la crème fraîche, le jaune d’oeuf et le jus de citron. Ajouter ce mélange au dernier moment, bien remuer et servir tout de suite.
""",
"context": "Recette Marmiton"
},
]
}
if "system_prompt" in variables:
data["system_prompt"] = prompt["system_prompt"]
rendered_template = template.render(**data)
print(rendered_template)
print("---")
llm = LLM("mistral-mfs-reference-2/mistral-mfs-reference-2")
sampling_params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=1500)
prompt, generated_text = get_llm_response(rendered_template)
print("Albert : ", generated_text)