asoria HF staff committed on
Commit
d61f780
1 Parent(s): a127a18

Use transformers

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -7,6 +7,7 @@ from huggingface_hub import HfApi
7
  from huggingface_hub.utils import logging
8
  from llama_cpp import Llama
9
  import pandas as pd
 
10
 
11
  load_dotenv()
12
 
@@ -27,10 +28,13 @@ client = Client(headers=headers)
27
  api = HfApi(token=HF_TOKEN)
28
 
29
  print("About to load DuckDB-NSQL-7B model")
 
30
  llama = Llama(
31
  model_path="DuckDB-NSQL-7B-v0.1-q8_0.gguf",
32
  n_ctx=2048,
33
  )
 
 
34
  print("DuckDB-NSQL-7B model has been loaded")
35
 
36
  def get_first_parquet(dataset: str):
@@ -47,6 +51,14 @@ def query_remote_model(text):
47
  pred = response.json()
48
  return pred[0]["generated_text"]
49
 
 
 
 
 
 
 
 
 
50
 
51
  def query_local_model(text):
52
  pred = llama(text, temperature=0.1, max_tokens=500)
@@ -84,7 +96,7 @@ def text2sql(dataset_name, query_input):
84
 
85
  # sql_output = query_remote_model(text)
86
 
87
- sql_output = query_local_model(text)
88
 
89
  try:
90
  query_result = con.sql(sql_output).df()
 
7
  from huggingface_hub.utils import logging
8
  from llama_cpp import Llama
9
  import pandas as pd
10
+ from transformers import pipeline
11
 
12
  load_dotenv()
13
 
 
28
  api = HfApi(token=HF_TOKEN)
29
 
30
print("About to load DuckDB-NSQL-7B model")
# NOTE(review): the previous llama-cpp loader
#   llama = Llama(model_path="DuckDB-NSQL-7B-v0.1-q8_0.gguf", n_ctx=2048)
# was disabled; it had been left in as a bare triple-quoted string expression,
# which still executes as a (no-op) statement at import time. Removed here —
# dead code should not ship. Note that query_local_model() still references
# `llama` and will fail if called while this loader is disabled.
# Text-generation pipeline for the DuckDB NSQL model; downloads/loads the
# weights at import time, so startup is slow and memory-heavy.
pipe = pipeline("text-generation", model="motherduckdb/DuckDB-NSQL-7B-v0.1")
print("DuckDB-NSQL-7B model has been loaded")
39
 
40
  def get_first_parquet(dataset: str):
 
51
  pred = response.json()
52
  return pred[0]["generated_text"]
53
 
54
def query_local_model_transformers(text):
    """Generate SQL for *text* using the local transformers pipeline.

    Parameters
    ----------
    text : str
        The prompt (schema/DDL context plus the natural-language question).

    Returns
    -------
    str
        The generated SQL string, ready to be executed (e.g. via
        ``con.sql(...)``).
    """
    # `pipe` (a "text-generation" pipeline) returns a list of dicts shaped
    # like [{"generated_text": "..."}]; extract the string instead of
    # returning the raw list — the caller passes the result to con.sql(),
    # which requires a str. This also matches query_remote_model(), which
    # already returns pred[0]["generated_text"].
    # NOTE(review): by default a text-generation pipeline's generated_text
    # includes the prompt; confirm whether return_full_text=False is needed
    # for the output to be directly executable SQL.
    pred = pipe(text)
    return pred[0]["generated_text"]
62
 
63
  def query_local_model(text):
64
  pred = llama(text, temperature=0.1, max_tokens=500)
 
96
 
97
  # sql_output = query_remote_model(text)
98
 
99
+ sql_output = query_local_model_transformers(text)
100
 
101
  try:
102
  query_result = con.sql(sql_output).df()