---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.1
pipeline_tag: text-generation
datasets:
- bugdaryan/sql-create-context-instruction
tags:
- Mistral
- PEFT
- LoRA
- SQL
---

### Model Description

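A SQL generation model: a LoRA adapter for Mistral-7B-Instruct-v0.1, fine-tuned on the `bugdaryan/sql-create-context-instruction` dataset. Given a database schema (`CREATE TABLE` statements) and a natural-language question, it generates an SQL query that answers the question.
Inspired by https://huggingface.co/kanxxyc/Mistral-7B-SQLTuned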

### Code 
```py
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

peft_model_id = "AhmedSSoliman/Mistral-Instruct-SQL-Generation"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model (Mistral-7B-Instruct-v0.1) in 4-bit; requires bitsandbytes and a supported GPU
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Attach the LoRA adapter to the base model
model = PeftModel.from_pretrained(model, peft_model_id)

# Build the pipeline once and reuse it for every query
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

def predict_SQL(table, question):
    prompt = f"[INST] Write SQL query to answer the following question given the database schema. Please wrap your code answer using ```: Schema: {table} Question: {question} [/INST] Here is the SQL query to answer to the question: {question}: ``` "
    #prompt = f"### Schema: {table} ### Question: {question} # "
    # return_full_text=False returns only the generated continuation, not the prompt
    ans = pipe(prompt, max_new_tokens=200, return_full_text=False)
    # The model closes its answer with ```, so keep the text before the first fence
    generated_sql = ans[0]["generated_text"].split("```")[0].strip()
    return generated_sql


table = "CREATE TABLE Employee (name VARCHAR, salary INTEGER);"
question = "Show names for all employees with salary more than the average."

generated_sql = predict_SQL(table, question)
print(generated_sql)
```
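`predict_SQL` depends on two things that can be checked without loading the 7B model: the exact prompt template, and the assumption that the model ends its answer with a triple-backtick fence. The sketch below pulls these into two hypothetical helpers, `build_prompt` and `extract_sql` (they are not part of this repository), and exercises the extraction on a hand-written completion. That completion string is made up for illustration and is not real model output.

```python
# Same prompt text as in predict_SQL, written as a str.format template.
PROMPT_TEMPLATE = (
    "[INST] Write SQL query to answer the following question given the database schema. "
    "Please wrap your code answer using ```: Schema: {table} Question: {question} [/INST] "
    "Here is the SQL query to answer to the question: {question}: ``` "
)


def build_prompt(table: str, question: str) -> str:
    """Fill the prompt template (hypothetical helper)."""
    return PROMPT_TEMPLATE.format(table=table, question=question)


def extract_sql(completion: str) -> str:
    """Return the text before the first closing fence (hypothetical helper).

    Assumes `completion` holds only the newly generated text, i.e. the
    pipeline was called with return_full_text=False. If the model never
    emits a closing fence, the whole completion is returned.
    """
    return completion.split("```", 1)[0].strip()


table = "CREATE TABLE Employee (name VARCHAR, salary INTEGER);"
question = "Show names for all employees with salary more than the average."
prompt = build_prompt(table, question)

# Simulated completion, used only to exercise extract_sql (not real model output).
fake_completion = " SELECT name FROM Employee WHERE salary > (SELECT AVG(salary) FROM Employee) ```\n"
print(extract_sql(fake_completion))
```

As long as the schema and question contain no backticks, taking the first split of the new text alone gives the same result as the original approach of taking index 2 of the full prompt-plus-completion text. The difference is that it returns the raw completion instead of raising an `IndexError` when the model omits the closing fence.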