---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.1
pipeline_tag: text-generation
datasets:
- bugdaryan/sql-create-context-instruction
tags:
- Mistral
- PEFT
- LoRA
- SQL
---

### Model Description

A SQL generation model built by fine-tuning Mistral-7B-Instruct-v0.1 with a LoRA adapter (PEFT). Given a table schema (`CREATE TABLE ...`) and a natural-language question, it generates a SQL query that answers the question.

Inspired by https://huggingface.co/kanxxyc/Mistral-7B-SQLTuned

### Code

```py
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

peft_model_id = "AhmedSSoliman/Mistral-Instruct-SQL-Generation"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model in 4-bit (needs the bitsandbytes package and, typically, a CUDA GPU)
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    trust_remote_code=True,
    return_dict=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(model, peft_model_id)

# Build the text-generation pipeline once and reuse it for every query
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)


def predict_SQL(table, question):
    prompt = f"""[INST] Write SQL query to answer the following question given the database schema. Please wrap your code answer using ```:
Schema: {table}
Question: {question} [/INST] Here is the SQL query to answer to the question: {question}: ```
"""
    # Alternative prompt format:
    # prompt = f"### Schema: {table} ### Question: {question} # "

    ans = pipe(prompt, max_new_tokens=200)
    # generated_text is the prompt followed by the completion, so the SQL
    # sits after the second ``` (the one that ends the prompt)
    generatedSql = ans[0]["generated_text"].split("```")[2].strip()
    return generatedSql


table = "CREATE TABLE Employee (name VARCHAR, salary INTEGER);"
question = "Show names for all employees with salary more than the average."

generatedSql = predict_SQL(table, question)
print(generatedSql)
```
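The prompt construction and the post-processing in `predict_SQL` can be checked without loading the model. The sketch below assumes the same prompt wording; `build_prompt`, `extract_sql` and `FENCE` are hypothetical helper names, not part of this repository. Instead of counting fences with `split("```")[2]`, `extract_sql` cuts off the prompt prefix and keeps the completion up to the model's closing fence. This relies on the text-generation pipeline returning the prompt followed by the completion (its default `return_full_text=True` behaviour).

```python
# Hypothetical helpers (not part of the model repository) mirroring predict_SQL.

# Three backticks, built programmatically so the literal does not clash with Markdown fences.
FENCE = "`" * 3


def build_prompt(table: str, question: str) -> str:
    """Rebuild the instruction prompt used by predict_SQL, with the same wording."""
    return (
        f"[INST] Write SQL query to answer the following question given the "
        f"database schema. Please wrap your code answer using {FENCE}:\n"
        f"Schema: {table}\n"
        f"Question: {question} [/INST] Here is the SQL query to answer to the "
        f"question: {question}: {FENCE}\n"
    )


def extract_sql(generated_text: str, prompt: str) -> str:
    """Return the SQL that follows the prompt in the pipeline output."""
    # Assumes the pipeline output starts with the exact prompt text.
    if not generated_text.startswith(prompt):
        raise ValueError("generated_text does not start with the prompt")
    completion = generated_text[len(prompt):]
    # The SQL ends at the model's closing fence; if it never closes one, keep everything.
    return completion.split(FENCE, 1)[0].strip()


if __name__ == "__main__":
    prompt = build_prompt(
        "CREATE TABLE Employee (name VARCHAR, salary INTEGER);",
        "Show names for all employees with salary more than the average.",
    )
    # Hard-coded stand-in for ans[0]["generated_text"]; not real model output.
    fake_output = (
        prompt
        + "SELECT name FROM Employee WHERE salary > (SELECT AVG(salary) FROM Employee)\n"
        + FENCE
    )
    print(extract_sql(fake_output, prompt))
```

With the real pipeline, the call would be `extract_sql(ans[0]["generated_text"], prompt)`. Slicing by prompt length keeps the result independent of how many fences the prompt itself contains.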