---
library_name: peft
base_model: microsoft/phi-2
---

# Model Card for phi-2-mongodb

phi-2-mongodb is a fine-tuned version of microsoft/phi-2 that generates MongoDB pipeline queries from natural-language requests. It was fine-tuned on a custom-curated dataset of natural-language-to-MongoDB query pairs, which I'll be releasing next week.

## Model Details

Further details about the fine-tuned model can be found at https://github.com/Chirayu-Tripathi/nl2query. The model can also be used through the nl2query library.

### Model Description

- **Fine-tuned by:** [`Chirayu Tripathi`](http://www.linkedin.com/in/chirayu-tripathi)
- **Base model developed by:** Microsoft
- **Language(s) (NLP):** English
- **License:** MIT
- **Fine-tuned from model:** [`microsoft/phi-2`](https://huggingface.co/microsoft/phi-2)

### Prompt Template

```
prompt_template = f"""<s>
Task Description:
Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.

MongoDB Schema:
{db_schema}

### Instruct:
{text}

### Output:
"""
```
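
Because this card presents the template above as the model's prompt format, it is probably safest to reproduce it exactly. One way to do that is to define the template once and fill it with `str.format`. The sketch below assumes a hypothetical helper named `build_prompt`, which is not part of nl2query or of this repository.

```python
# Hypothetical helper that fills the prompt template shown above.
# str.format only interprets braces in the template itself, so the braces
# inside a JSON schema string are inserted into the prompt unchanged.
PROMPT_TEMPLATE = """<s>
Task Description:
Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.

MongoDB Schema:
{db_schema}

### Instruct:
{text}

### Output:
"""


def build_prompt(db_schema: str, text: str) -> str:
    """Return the full prompt for one schema/instruction pair."""
    return PROMPT_TEMPLATE.format(db_schema=db_schema, text=text)
```

For example, `build_prompt(db_schema, "Count all shipwrecks")` returns a string that can be passed directly to the tokenizer.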
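
## How to Get Started with the Model

Use the code sample below to interact with the model. It loads the base model in 4-bit precision with bitsandbytes (which requires a CUDA GPU), attaches the adapter, and decodes the generated query.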

```python
import torch
from peft import PeftModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

db_schema = '''{
    "collections": [
        {
            "name": "shipwrecks",
            "indexes": [
                {
                    "key": {
                        "_id": 1
                    }
                },
                {
                    "key": {
                        "feature_type": 1
                    }
                },
                {
                    "key": {
                        "chart": 1
                    }
                },
                {
                    "key": {
                        "latdec": 1,
                        "londec": 1
                    }
                }
            ],
            "uniqueIndexes": [],
            "document": {
                "properties": {
                    "_id": {
                        "bsonType": "string"
                    },
                    "recrd": {
                        "bsonType": "string"
                    },
                    "vesslterms": {
                        "bsonType": "string"
                    },
                    "feature_type": {
                        "bsonType": "string"
                    },
                    "chart": {
                        "bsonType": "string"
                    },
                    "latdec": {
                        "bsonType": "double"
                    },
                    "londec": {
                        "bsonType": "double"
                    },
                    "gp_quality": {
                        "bsonType": "string"
                    },
                    "depth": {
                        "bsonType": "string"
                    },
                    "sounding_type": {
                        "bsonType": "string"
                    },
                    "history": {
                        "bsonType": "string"
                    },
                    "quasou": {
                        "bsonType": "string"
                    },
                    "watlev": {
                        "bsonType": "string"
                    },
                    "coordinates": {
                        "bsonType": "array",
                        "items": {
                            "bsonType": "double"
                        }
                    }
                }
            }
        }
    ],
    "version": 1
}'''

# Natural-language request; field names should match the schema above
text = '''Find the count of shipwrecks for each unique combination of "latdec" and "londec"'''

# Same layout as the prompt template
prompt = f"""<s>
Task Description:
Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.

MongoDB Schema:
{db_schema}

### Instruct:
{text}

### Output:
"""

# 4-bit bitsandbytes loading needs a CUDA GPU; matches device_map below
device = torch.device("cuda:0")
base_model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)

# NF4 4-bit quantization with double quantization, computing in float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    quantization_config=bnb_config,
    revision="refs/pr/23",  # pinned phi-2 revision used with trust_remote_code
    device_map={"": 0},  # place the whole model on GPU 0
    torch_dtype="auto",
    # Options for the remote modeling code at this revision
    flash_attn=True,
    flash_rotary=True,
    fused_dense=True,
)

# Attach the fine-tuned PEFT adapter. The quantized base model already
# lives on GPU 0, so no extra .to(device) call is needed.
adapter = "Chirayu/phi-2-mongodb"
model = PeftModel.from_pretrained(model, adapter)

model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
output = model.generate(
    **model_inputs,
    max_length=1024,  # total length, prompt tokens included
    no_repeat_ngram_size=10,
    repetition_penalty=1.02,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)[0]

# Decode only the newly generated tokens and cut at the "</s>" marker
prompt_length = model_inputs["input_ids"].shape[1]
query = tokenizer.decode(output[prompt_length:], skip_special_tokens=False)
stop_idx = query.find("</s>")
if stop_idx == -1:
    stop_idx = len(query)
print(query[:stop_idx].strip())
```
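
To sanity-check the generated output, it can help to compare it against a hand-written answer. The sketch below is not model output. It is one plausible aggregation pipeline for the example instruction, written as a pymongo-style Python list, plus a small pure-Python function that computes the same grouping over in-memory documents so the example runs without a database. The function name `count_by_lat_lon`, the toy documents, and the database/collection names in the commented-out pymongo call are assumptions for illustration.

```python
from collections import Counter

# Hand-written reference pipeline (not model output): group by the
# (latdec, londec) pair and count the documents in each group.
REFERENCE_PIPELINE = [
    {
        "$group": {
            "_id": {"latdec": "$latdec", "londec": "$londec"},
            "count": {"$sum": 1},
        }
    }
]

# With pymongo and a live MongoDB, it could be run roughly like this:
#   from pymongo import MongoClient
#   coll = MongoClient()["sample_geospatial"]["shipwrecks"]  # assumed names
#   results = list(coll.aggregate(REFERENCE_PIPELINE))


def count_by_lat_lon(docs):
    """Pure-Python equivalent of REFERENCE_PIPELINE for in-memory dicts.

    Assumes every document has both "latdec" and "londec" fields.
    """
    counts = Counter((d["latdec"], d["londec"]) for d in docs)
    return [
        {"_id": {"latdec": lat, "londec": lon}, "count": n}
        for (lat, lon), n in counts.items()
    ]


# Made-up toy documents, only to show the shape of the result
toy_docs = [
    {"latdec": 9.33, "londec": -79.92},
    {"latdec": 9.33, "londec": -79.92},
    {"latdec": 18.56, "londec": -64.75},
]
print(count_by_lat_lon(toy_docs))
```

MongoDB does not guarantee the order of `$group` results, so compare groups as a set rather than by position.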

### Framework versions

- PEFT 0.10.0