Chirayu commited on
Commit
30f36e8
·
1 Parent(s): 18efae2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +53 -0
README.md CHANGED
@@ -1,3 +1,56 @@
1
  ---
2
  license: mit
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ tags:
4
+ - code
5
+ language:
6
+ - en
7
  ---
8
+ # What does this model do?
9
+ This model converts the natural language input to MongoDB (MQL) query. It is a fine-tuned CodeT5+ 220M. This model is a part of nl2query repository which is present at https://github.com/Chirayu-Tripathi/nl2query
10
+
11
+ You can use this model via the github repository or via following code.
12
+
13
+ ```python
14
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
15
+ import torch
16
+ model = AutoModelForSeq2SeqLM.from_pretrained("Chirayu/nl2mongo")
17
+ tokenizer = AutoTokenizer.from_pretrained("Chirayu/nl2mongo")
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model = model.to(device)
20
+ def generate_query(
21
+ textual_query: str,
22
+ num_beams: int = 10,
23
+ max_length: int = 128,
24
+ repetition_penalty: int = 2.5,
25
+ length_penalty: int = 1,
26
+ early_stopping: bool = True,
27
+ top_p: int = 0.95,
28
+ top_k: int = 50,
29
+ num_return_sequences: int = 1,
30
+ ) -> str:
31
+ input_ids = tokenizer.encode(
32
+ textual_query, return_tensors="pt", add_special_tokens=True
33
+ )
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ input_ids = input_ids.to(device)
36
+ generated_ids = model.generate(
37
+ input_ids=input_ids,
38
+ num_beams=num_beams,
39
+ max_length=max_length,
40
+ repetition_penalty=repetition_penalty,
41
+ length_penalty=length_penalty,
42
+ early_stopping=early_stopping,
43
+ top_p=top_p,
44
+ top_k=top_k,
45
+ num_return_sequences=num_return_sequences,
46
+ )
47
+ query = [
48
+ tokenizer.decode(
49
+ generated_id,
50
+ skip_special_tokens=True,
51
+ clean_up_tokenization_spaces=True,
52
+ )
53
+ for generated_id in generated_ids
54
+ ][0]
55
+ return query
56
+ ```