
You need to agree to share your contact information to access this model.

This repository is publicly accessible, but you have to accept the conditions to access its files and content.


Introduction

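This natural-language-to-SQL model (NL2SQL/Text2SQL) was created by LoRA fine-tuning of the pretrained code-completion model replit-code-v1-3b. Only the LoRA weights (about 11M) are provided here. At inference time they must be loaded together with the original pretrained model, as shown in the example below.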
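As background on why the base model is still required, here is the standard LoRA formulation. The frozen base weight matrix W_0 stays untouched, and only the low-rank factors A and B are trained and shipped in the adapter. This card does not state the rank r, the scaling α, or which layers of replit-code-v1-3b are adapted, so those symbols are left generic here.

```latex
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d \times k},\;
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)
```

Each adapted matrix therefore adds only r(d + k) trainable parameters instead of dk. For deployment without PEFT at runtime, the update can be folded in as W = W_0 + (α/r)BA, which PEFT exposes as `merge_and_unload()`.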

Usage

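For the NL2SQL task, the input consists of the user's question plus the database table information. The model input text is currently assembled in the following format. Questions are written in Chinese; the one below asks "Show the student IDs of all male students.":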

# Table Allergy_Type , columns = [ Allergy , AllergyType ]
# Table Has_Allergy , columns = [ StuID , Allergy ]
# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]
# primary keys: [ Allergy_Type.Allergy , Student.StuID ]
# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]
# Create a query for question: 显示所有男生的学生ID。
query =
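The layout above can also be produced programmatically. Below is a minimal sketch of a hypothetical helper, `build_prompt` (not part of this repository), that reproduces the format exactly for the sample schema. The card does not say how empty key lists should be rendered, so that case is an assumption left unhandled.

```python
from typing import Dict, List, Tuple


def build_prompt(
    tables: Dict[str, List[str]],
    primary_keys: List[str],
    foreign_keys: List[Tuple[str, str]],
    question: str,
) -> str:
    """Hypothetical helper: assemble the NL2SQL input text in the card's layout."""
    lines = []
    # One "# Table" line per table, columns separated by " , "
    for table, columns in tables.items():
        lines.append(f"# Table {table} , columns = [ {' , '.join(columns)} ]")
    lines.append(f"# primary keys: [ {' , '.join(primary_keys)} ]")
    # Foreign keys are written as "child.col = parent.col"
    fks = " , ".join(f"{src} = {dst}" for src, dst in foreign_keys)
    lines.append(f"# foreign keys: [ {fks} ]")
    lines.append(f"# Create a query for question: {question}")
    # The model continues the text after "query =" with the SQL statement
    lines.append("query =")
    return "\n".join(lines)


prompt = build_prompt(
    tables={
        "Allergy_Type": ["Allergy", "AllergyType"],
        "Has_Allergy": ["StuID", "Allergy"],
        "Student": ["StuID", "LName", "Fname", "Age", "Sex",
                    "Major", "Advisor", "city_code"],
    },
    primary_keys=["Allergy_Type.Allergy", "Student.StuID"],
    foreign_keys=[("Has_Allergy.Allergy", "Allergy_Type.Allergy"),
                  ("Has_Allergy.StuID", "Student.StuID")],
    question="显示所有男生的学生ID。",
)
print(prompt)
```

Tables are emitted in dict insertion order, which Python preserves, so the table order in the prompt follows the order you pass them in.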

See the following example for complete usage:

import sqlparse
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

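device = 'cuda'
base_model_path = 'replit/replit-code-v1-3b'  # pretrained code-completion base model
lora_model_path = 'DMetaSoul/nl2sql-chinese-standard-3b-lora'  # LoRA adapter weights
sampling = False  # False: beam search; True: top-k/top-p sampling
# Left padding so that every prompt in a batch ends where generation begins
tokenizer = AutoTokenizer.from_pretrained(base_model_path,
    trust_remote_code=True, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(base_model_path,
    trust_remote_code=True, torch_dtype=torch.float16)
if lora_model_path:
    # Attach the LoRA weights on top of the frozen base model
    model = PeftModel.from_pretrained(model, lora_model_path,
        torch_dtype=torch.float16)
model.eval()
model.to(device)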

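# Prompts follow the format above. The Chinese questions mean:
#   1. "Show the first names, last names, and ages of all female students. Their sex is 'female'."
#   2. "Show the student IDs of all male students."
input_texts = [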
    "# Table Allergy_Type , columns = [ Allergy , AllergyType ]\n# Table Has_Allergy , columns = [ StuID , Allergy ]\n# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]\n# primary keys: [ Allergy_Type.Allergy , Student.StuID ]\n# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]\n# Create a query for question: 显示所有女学生的名、 姓氏、年龄。他们的性别是“女”.\nquery =",
    "# Table Allergy_Type , columns = [ Allergy , AllergyType ]\n# Table Has_Allergy , columns = [ StuID , Allergy ]\n# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]\n# primary keys: [ Allergy_Type.Allergy , Student.StuID ]\n# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]\n# Create a query for question: 显示所有男生的学生ID。\nquery =",
]
inputs = tokenizer(input_texts, max_length=512, return_tensors="pt",
    padding=True, truncation=True)
inputs = {k:v.to(device) for k,v in inputs.items()}

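with torch.no_grad():
    if sampling:
        # Stochastic decoding with top-k / nucleus sampling
        outputs = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.95,
            temperature=1.0, num_return_sequences=1,
            max_length=512, return_dict_in_generate=True, output_scores=True)
    else:
        # Deterministic beam search
        outputs = model.generate(**inputs, num_beams=4, num_return_sequences=1,
            max_length=512, return_dict_in_generate=True, output_scores=True)

# generate() returns prompt + completion for decoder-only models (return_full_text
# is a pipeline option, not a generate() argument). With left padding all prompts
# share the same length, so slicing it off leaves only the generated SQL.
prompt_len = inputs['input_ids'].shape[1]
output_ids = outputs.sequences[:, prompt_len:]
results = tokenizer.batch_decode(output_ids, skip_special_tokens=True,
            clean_up_tokenization_spaces=True)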

for question, sql in zip(input_texts, results):
    print(question)
    print('SQL: {}'.format(sqlparse.format(sql, reindent=True)))
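The tokenizer above is created with `padding_side='left'`, so every prompt in a batch ends at the same column. For a decoder-only model, the newly generated tokens in each row of `outputs.sequences` therefore start at index `input_ids.shape[1]`. The pure-Python sketch below illustrates this with made-up token ids. `PAD` and the ids are hypothetical values chosen for illustration, not real tokenizer ids.

```python
from typing import List

PAD = 0  # hypothetical pad id, for illustration only


def strip_prompt(sequences: List[List[int]], prompt_len: int) -> List[List[int]]:
    """Drop the left-padded prompt columns, keeping only generated token ids."""
    return [seq[prompt_len:] for seq in sequences]


# Two prompts of length 3 and 5, left-padded to a common width of 5
batch = [[PAD, PAD, 11, 12, 13],
         [21, 22, 23, 24, 25]]
prompt_len = len(batch[0])  # identical for every row after padding

# A decoder-only generate() returns prompt + new tokens; pretend each row got two
new_tokens = [[101, 102], [201, 202]]
sequences = [row + new for row, new in zip(batch, new_tokens)]

print(strip_prompt(sequences, prompt_len))  # [[101, 102], [201, 202]]
```

With right padding, the pad tokens would sit between a short prompt and its continuation, so a single slice index would no longer separate prompt from output. That is why left padding is the usual choice for batched generation.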