Aziz Alto commited on
Commit
7c72cf3
β€’
1 Parent(s): 51c318e

Suggest customized questions for any dataset πŸ”₯ powered by GPT

Browse files
Files changed (1) hide show
  1. app.py +143 -63
app.py CHANGED
@@ -6,12 +6,12 @@ import pandas as pd
6
  import streamlit as st
7
  import streamlit_ace as stace
8
  import duckdb
9
- import numpy as np # for user session
10
- import scipy # for user session
11
  import plotly_express
12
- import plotly.express as px # for user session
13
- import plotly.figure_factory as ff # for user session
14
- import matplotlib.pyplot as plt # for user session
15
  import sklearn
16
  from ydata_profiling import ProfileReport
17
  from streamlit_pandas_profiling import st_profile_report
@@ -24,7 +24,16 @@ header = """
24
  > `GPT-powered` and `Jupyter notebook-inspired`
25
  """
26
  st.markdown(header, unsafe_allow_html=True)
27
- st.markdown("> <sub>[NYC AI Hackathon](https://tech.cornell.edu/events/nyc-gpt-llm-hackathon/) April, 23 2023</sub>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
28
 
29
  if "OPENAI_API_KEY" not in os.environ:
30
  os.environ["OPENAI_API_KEY"] = st.text_input("OpenAI API Key", type="password")
@@ -34,6 +43,7 @@ p = st.write
34
  print = st.write
35
  display = st.write
36
 
 
37
  @st.cache_data
38
  def _read_csv(f, **kwargs):
39
  df = pd.read_csv(f, on_bad_lines="skip", **kwargs)
@@ -45,8 +55,9 @@ def _read_csv(f, **kwargs):
45
  def timer(func):
46
  def wrapper_function(*args, **kwargs):
47
  start_time = time.time()
48
- func(*args, **kwargs)
49
  st.write(f"`{(time.time() - start_time):.2f}s.`")
 
50
  return wrapper_function
51
 
52
 
@@ -59,7 +70,7 @@ SAMPLE_DATA = {
59
  "Country Table": "https://raw.githubusercontent.com/datasciencedojo/datasets/master/WorldDBTables/CountryTable.csv",
60
  "World Cities": "https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/csv/cities.csv",
61
  "World States": "https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/csv/states.csv",
62
- "World Countries": "https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/csv/countries.csv"
63
  }
64
 
65
 
@@ -78,7 +89,9 @@ def read_data():
78
  if url:
79
  file_ = url
80
  with col3:
81
- selected = st.selectbox("Select a sample dataset", options=[""] + list(SAMPLE_DATA))
 
 
82
  if selected:
83
  file_ = SAMPLE_DATA[selected]
84
 
@@ -122,12 +135,28 @@ def code_editor(language, hint, show_panel, key=None, content=None):
122
  _KEYBINDINGS = stace.KEYBINDINGS
123
  col21, col22 = st.columns(2)
124
  with col21:
125
- theme = st.selectbox("Theme", options=[default_theme] + _THEMES, key=f"{language}1{key}")
126
- tab_size = st.slider("Tab size", min_value=1, max_value=8, value=4, key=f"{language}2{key}")
 
 
 
 
127
  with col22:
128
- keybinding = st.selectbox("Keybinding", options=[_KEYBINDINGS[-2]] + _KEYBINDINGS, key=f"{language}3{key}")
129
- font_size = st.slider("Font size", min_value=5, max_value=24, value=14, key=f"{language}4{key}")
130
- height = st.slider("Editor height", value=130, max_value=777,key=f"{language}5{key}")
 
 
 
 
 
 
 
 
 
 
 
 
131
  # kwargs = {theme: theme, keybinding: keybinding} # TODO: DRY
132
  if not show_panel:
133
  placeholder.empty()
@@ -143,7 +172,7 @@ def code_editor(language, hint, show_panel, key=None, content=None):
143
  theme=theme,
144
  font_size=font_size,
145
  tab_size=tab_size,
146
- key=key
147
  )
148
 
149
  # Display editor's content as you type
@@ -167,13 +196,7 @@ def download(df, key, save_as="results.csv"):
167
  return _df.to_csv().encode("utf-8")
168
 
169
  csv = convert_df(df)
170
- st.download_button(
171
- "Download",
172
- csv,
173
- save_as,
174
- "text/csv",
175
- key=key
176
- )
177
 
178
 
179
  def display_results(query: str, result: pd.DataFrame, key: str):
@@ -186,7 +209,7 @@ def display_results(query: str, result: pd.DataFrame, key: str):
186
  def run_python_script(user_script, key):
187
  if user_script.startswith("st.") or ";" in user_script:
188
  py = user_script
189
- elif user_script.endswith("?"): # -- same as ? in Jupyter Notebook
190
  in_ = user_script.replace("?", "")
191
  py = f"st.help({in_})"
192
  else:
@@ -278,7 +301,7 @@ def display_example_snippets():
278
 
279
 
280
  class GPTWrapper:
281
- def __init__(self):#, df_info):
282
 
283
  from gpt import AnthropicSerivce, OpenAIService
284
 
@@ -289,6 +312,7 @@ class GPTWrapper:
289
  @st.cache_data
290
  def ask_sql(df_info, question):
291
  from gpt import OpenAIService
 
292
  openai_model = OpenAIService()
293
  prompt = GPTWrapper().build_sql_prompt(df_info, question)
294
  res = openai_model.prompt(prompt)
@@ -298,18 +322,19 @@ class GPTWrapper:
298
  @st.cache_data
299
  def ask_python(df_info, question):
300
  from gpt import OpenAIService
 
301
  openai_model = OpenAIService()
302
  prompt = GPTWrapper().build_python_prompt(df_info, question)
303
  res = openai_model.prompt(prompt)
304
  return res, prompt
305
 
306
-
307
  @staticmethod
308
  @st.cache_data
309
  def build_sql_prompt(df_info, question):
310
  prompt = f"""I have data in a pandas dataframe, here is the data schema: {df_info}
311
  Next, I will ask you a question. Assume the table name is `df`.
312
- And you will answer in writing a SQL query only. {question}
 
313
  """
314
  return prompt
315
 
@@ -317,30 +342,49 @@ class GPTWrapper:
317
  @st.cache_data
318
  def build_python_prompt(df_info, question):
319
  prompt = f"""I have data in a pandas dataframe, here is the dataframe schema: {df_info}
320
- Next, I will ask you a question. And you will answer in writing a Python code only.
321
- Assume the data is stored in a variable named `df`.
322
- Here are some instructions for the generated Python code:
 
323
 
324
- - You should always use the variable `df` to refer to the dataframe.
325
- - You should not include any markdown syntax or any other syntax that is not Python in the answer.
326
  - Import any required libraries in the first line of the generated code.
327
- - Just show the Python code only, don't include any Python comments or English explanation in the answer text.
328
- - If the generarted code has multiple Python lines, every Python line must end with a semicolon (;).
329
- - If the answer is not a plot or a figure, always use print to print the answer using print().
330
- - If the answer requires plotting, generate a plot using plotly_express and show it using st.plotly_chart(fig).
 
 
 
331
 
332
  Here is the question: {question}
333
  """
334
  return prompt
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
  def ask_gpt_sql(df_info, key):
338
  # -- GPT AI
339
  # agi = GPTWrapper(df_info=df_info)
340
- question = st.text_input("Ask a question about the dataset to get a SQL query that answers the question",
341
- placeholder="How many rows are there in the dataset?",
342
- key=key
343
- )
 
344
  if question:
345
  # res, prompt = agi.ask_sql(df_info, question)
346
  res, prompt = GPTWrapper().ask_sql(df_info, question)
@@ -349,22 +393,34 @@ def ask_gpt_sql(df_info, key):
349
  st.code(sql_query, language="sql")
350
  return sql_query
351
 
 
 
 
 
 
 
 
352
  def ask_gpt_python(df_info, key):
353
  # -- GPT AI
354
- # agi = GPTWrapper(df_info=df_info)
355
- question = st.text_input("Ask a question about the dataset to get a Python code that answers the question",
356
- placeholder="How many rows and columns are there in the dataset?",
357
- key=key
358
- )
 
359
  if question:
360
- # res, prompt = agi.ask_python(df_info, question)
361
  res, prompt = GPTWrapper().ask_python(df_info, question)
362
- # st.markdown(f"```{prompt}```")
363
  python_code = res.choices[0].message.content
364
  st.code(python_code, language="python")
365
- # st.markdown(f"```{python_code}```", unsafe_allow_html=True)
366
  return python_code
367
 
 
 
 
 
 
 
368
 
369
  if __name__ == "__main__":
370
  show_examples = docs()
@@ -381,8 +437,6 @@ if __name__ == "__main__":
381
  df.info(buf=sio)
382
  df_info = sio.getvalue()
383
  # st.markdown(f"```{df_info}```", unsafe_allow_html=True)
384
-
385
-
386
 
387
  # run and execute SQL script
388
  def sql_cells(df):
@@ -394,24 +448,32 @@ if __name__ == "__main__":
394
  Describe the table:
395
  DESCRIBE TABLE df
396
  """
397
- number_cells = st.sidebar.number_input("Number of SQL cells to use", value=1, max_value=40)
 
 
398
  for i in range(number_cells):
399
  key = f"sql{i}"
400
  col1, col2 = st.columns([2, 1])
401
  st.markdown("<br>", unsafe_allow_html=True)
402
- show_panel = False #col2.checkbox("Show cell config panel", key=f"{i}-sql")
403
-
 
404
 
405
  col1.write(f"> `IN[{i+1}]`")
406
 
407
- # with col2:
408
  # -- GPT AI
409
  query = ask_gpt_sql(df_info, key=f"{key}-gpt")
410
  content = None
411
- if query and st.button("Use SQL", key=f"{key}-use-sql"):
412
  content = query
413
- # with col1:
414
- sql = code_editor("sql", hint, show_panel=show_panel, key=key, content=content if content else None)
 
 
 
 
 
 
415
  if sql:
416
  st.code(sql, language="sql")
417
  st.write(f"`OUT[{i+1}]`")
@@ -451,27 +513,42 @@ if __name__ == "__main__":
451
  st.bar_chart(groups[i].mean())
452
  ```
453
  """
454
- number_cells = st.sidebar.number_input("Number of Python cells to use", value=1, max_value=40, min_value=1, help=help)
 
 
 
 
 
 
455
  for i in range(number_cells):
456
  # st.markdown("<br><br><br>", unsafe_allow_html=True)
457
  col1, col2 = st.columns([2, 1])
458
  # col1.write(f"> `IN[{i+1}]`")
459
- show_panel = False # col2.checkbox("Show cell config panel", key=f"panel{i}")
 
 
460
 
461
  # -- GPT AI
462
  query = ask_gpt_python(df_info, key=f"{i}-gpt")
463
  content = None
464
- if query and st.checkbox("Use generated code", key=f"{i}-use-python"):
465
  content = query
466
- user_script = code_editor("python", hint, show_panel=show_panel, key=i, content=content if content else None)
 
 
 
 
 
 
467
  if user_script:
468
- df.rename(columns={"lng": "lon"}, inplace=True) # hot-fix for "World Population" dataset
 
 
469
  st.write(f"> `IN[{i+1}]`")
470
  st.code(user_script, language="python")
471
  st.write(f"> `OUT[{i+1}]`")
472
  run_python_script(user_script, key=f"{user_script}{i}")
473
 
474
-
475
  if st.sidebar.checkbox("Show SQL cells", value=True):
476
  sql_cells(df)
477
  if st.sidebar.checkbox("Show Python cells", value=True):
@@ -479,7 +556,10 @@ if __name__ == "__main__":
479
 
480
  st.sidebar.write("---")
481
 
482
- if st.sidebar.checkbox("Generate Data Profile Report", help="pandas profiling, generated by [ydata-profiling](https://github.com/ydataai/ydata-profiling)"):
 
 
 
483
  st.write("---")
484
  st.header("Data Profiling")
485
  profile = data_profiler(df)
 
6
  import streamlit as st
7
  import streamlit_ace as stace
8
  import duckdb
9
+ import numpy as np # for user session
10
+ import scipy # for user session
11
  import plotly_express
12
+ import plotly.express as px # for user session
13
+ import plotly.figure_factory as ff # for user session
14
+ import matplotlib.pyplot as plt # for user session
15
  import sklearn
16
  from ydata_profiling import ProfileReport
17
  from streamlit_pandas_profiling import st_profile_report
 
24
  > `GPT-powered` and `Jupyter notebook-inspired`
25
  """
26
  st.markdown(header, unsafe_allow_html=True)
27
+ st.markdown(
28
+ "> <sub>[NYC AI Hackathon](https://tech.cornell.edu/events/nyc-gpt-llm-hackathon/) April, 23 2023</sub>",
29
+ unsafe_allow_html=True,
30
+ )
31
+
32
+
33
+ if "ANTHROPIC_API_KEY" not in os.environ:
34
+ os.environ["ANTHROPIC_API_KEY"] = st.text_input(
35
+ "Anthropic API Key", type="password"
36
+ )
37
 
38
  if "OPENAI_API_KEY" not in os.environ:
39
  os.environ["OPENAI_API_KEY"] = st.text_input("OpenAI API Key", type="password")
 
43
  print = st.write
44
  display = st.write
45
 
46
+
47
  @st.cache_data
48
  def _read_csv(f, **kwargs):
49
  df = pd.read_csv(f, on_bad_lines="skip", **kwargs)
 
55
  def timer(func):
56
  def wrapper_function(*args, **kwargs):
57
  start_time = time.time()
58
+ func(*args, **kwargs)
59
  st.write(f"`{(time.time() - start_time):.2f}s.`")
60
+
61
  return wrapper_function
62
 
63
 
 
70
  "Country Table": "https://raw.githubusercontent.com/datasciencedojo/datasets/master/WorldDBTables/CountryTable.csv",
71
  "World Cities": "https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/csv/cities.csv",
72
  "World States": "https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/csv/states.csv",
73
+ "World Countries": "https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/csv/countries.csv",
74
  }
75
 
76
 
 
89
  if url:
90
  file_ = url
91
  with col3:
92
+ selected = st.selectbox(
93
+ "Select a sample dataset", options=[""] + list(SAMPLE_DATA)
94
+ )
95
  if selected:
96
  file_ = SAMPLE_DATA[selected]
97
 
 
135
  _KEYBINDINGS = stace.KEYBINDINGS
136
  col21, col22 = st.columns(2)
137
  with col21:
138
+ theme = st.selectbox(
139
+ "Theme", options=[default_theme] + _THEMES, key=f"{language}1{key}"
140
+ )
141
+ tab_size = st.slider(
142
+ "Tab size", min_value=1, max_value=8, value=4, key=f"{language}2{key}"
143
+ )
144
  with col22:
145
+ keybinding = st.selectbox(
146
+ "Keybinding",
147
+ options=[_KEYBINDINGS[-2]] + _KEYBINDINGS,
148
+ key=f"{language}3{key}",
149
+ )
150
+ font_size = st.slider(
151
+ "Font size",
152
+ min_value=5,
153
+ max_value=24,
154
+ value=14,
155
+ key=f"{language}4{key}",
156
+ )
157
+ height = st.slider(
158
+ "Editor height", value=130, max_value=777, key=f"{language}5{key}"
159
+ )
160
  # kwargs = {theme: theme, keybinding: keybinding} # TODO: DRY
161
  if not show_panel:
162
  placeholder.empty()
 
172
  theme=theme,
173
  font_size=font_size,
174
  tab_size=tab_size,
175
+ key=key,
176
  )
177
 
178
  # Display editor's content as you type
 
196
  return _df.to_csv().encode("utf-8")
197
 
198
  csv = convert_df(df)
199
+ st.download_button("Download", csv, save_as, "text/csv", key=key)
 
 
 
 
 
 
200
 
201
 
202
  def display_results(query: str, result: pd.DataFrame, key: str):
 
209
  def run_python_script(user_script, key):
210
  if user_script.startswith("st.") or ";" in user_script:
211
  py = user_script
212
+ elif user_script.endswith("?"): # -- same as ? in Jupyter Notebook
213
  in_ = user_script.replace("?", "")
214
  py = f"st.help({in_})"
215
  else:
 
301
 
302
 
303
  class GPTWrapper:
304
+ def __init__(self): # , df_info):
305
 
306
  from gpt import AnthropicSerivce, OpenAIService
307
 
 
312
  @st.cache_data
313
  def ask_sql(df_info, question):
314
  from gpt import OpenAIService
315
+
316
  openai_model = OpenAIService()
317
  prompt = GPTWrapper().build_sql_prompt(df_info, question)
318
  res = openai_model.prompt(prompt)
 
322
  @st.cache_data
323
  def ask_python(df_info, question):
324
  from gpt import OpenAIService
325
+
326
  openai_model = OpenAIService()
327
  prompt = GPTWrapper().build_python_prompt(df_info, question)
328
  res = openai_model.prompt(prompt)
329
  return res, prompt
330
 
 
331
  @staticmethod
332
  @st.cache_data
333
  def build_sql_prompt(df_info, question):
334
  prompt = f"""I have data in a pandas dataframe, here is the data schema: {df_info}
335
  Next, I will ask you a question. Assume the table name is `df`.
336
+ And you will answer in writing a SQL query only by using the table `df` and shema above.
337
+ Here is the question: {question}.
338
  """
339
  return prompt
340
 
 
342
  @st.cache_data
343
  def build_python_prompt(df_info, question):
344
  prompt = f"""I have data in a pandas dataframe, here is the dataframe schema: {df_info}
345
+ Next, I will ask you a question. Assume the data is stored in a variable named `df`.
346
+ And you will answer in writing a Python code only by using the variable `df` and shema above.
347
+
348
+ Here are some instructions you must follow when writing the code:
349
 
350
+ - The answer must be Python code only.
351
+ - The code must include column names from the dataframe schema above only.
352
  - Import any required libraries in the first line of the generated code.
353
+ - Use `df` as the variable name for the dataframe.
354
+ - Don't include any comments in the code.
355
+ - Every line of code must end with `;`.
356
+ - For non-plotting answers, you must use `print()` to print the answer.
357
+ - For plotting answers, one of the folowing options must be used:
358
+ - `st.pyplot(fig)` to display the plot in the Streamlit app.
359
+ - plotly_express to generate a plot and `st.plotly_chart()` to show it.
360
 
361
  Here is the question: {question}
362
  """
363
  return prompt
364
 
365
+ @staticmethod
366
+ @st.cache_data
367
+ def suggest_questions(df_info, language):
368
+ prompt = f"""
369
+ {df_info}
370
+
371
+ What questions (exploratory or explanatory) can be asked about this dataset to analyze the data as a whole using {language}? Be as specific as possible based on the data schema above.
372
+ """
373
+ from gpt import OpenAIService
374
+
375
+ openai_model = OpenAIService()
376
+ res = openai_model.prompt(prompt)
377
+ return res, prompt
378
+
379
 
380
  def ask_gpt_sql(df_info, key):
381
  # -- GPT AI
382
  # agi = GPTWrapper(df_info=df_info)
383
+ question = st.text_input(
384
+ "Ask a question about the dataset to get a SQL query that answers the question",
385
+ placeholder="How many rows are there in the dataset?",
386
+ key=key,
387
+ )
388
  if question:
389
  # res, prompt = agi.ask_sql(df_info, question)
390
  res, prompt = GPTWrapper().ask_sql(df_info, question)
 
393
  st.code(sql_query, language="sql")
394
  return sql_query
395
 
396
+ with st.expander("Example questions"):
397
+ res, prompt = GPTWrapper().suggest_questions(df_info, "SQL")
398
+ suggestions = res.choices[0].message.content
399
+ st.markdown("Here are some example questions:")
400
+ st.markdown(f"```{suggestions}```", unsafe_allow_html=True)
401
+
402
+
403
  def ask_gpt_python(df_info, key):
404
  # -- GPT AI
405
+
406
+ question = st.text_input(
407
+ "Ask a question about the dataset to get a Python code that answers the question",
408
+ placeholder="How many rows and columns are there in the dataset?",
409
+ key=key,
410
+ )
411
  if question:
 
412
  res, prompt = GPTWrapper().ask_python(df_info, question)
 
413
  python_code = res.choices[0].message.content
414
  st.code(python_code, language="python")
415
+
416
  return python_code
417
 
418
+ with st.expander("Example questions"):
419
+ res, prompt = GPTWrapper().suggest_questions(df_info, "Python")
420
+ suggestions = res.choices[0].message.content
421
+ st.markdown("Here are some example questions:")
422
+ st.markdown(suggestions, unsafe_allow_html=True)
423
+
424
 
425
  if __name__ == "__main__":
426
  show_examples = docs()
 
437
  df.info(buf=sio)
438
  df_info = sio.getvalue()
439
  # st.markdown(f"```{df_info}```", unsafe_allow_html=True)
 
 
440
 
441
  # run and execute SQL script
442
  def sql_cells(df):
 
448
  Describe the table:
449
  DESCRIBE TABLE df
450
  """
451
+ number_cells = st.sidebar.number_input(
452
+ "Number of SQL cells to use", value=1, max_value=40
453
+ )
454
  for i in range(number_cells):
455
  key = f"sql{i}"
456
  col1, col2 = st.columns([2, 1])
457
  st.markdown("<br>", unsafe_allow_html=True)
458
+ show_panel = (
459
+ False # col2.checkbox("Show cell config panel", key=f"{i}-sql")
460
+ )
461
 
462
  col1.write(f"> `IN[{i+1}]`")
463
 
 
464
  # -- GPT AI
465
  query = ask_gpt_sql(df_info, key=f"{key}-gpt")
466
  content = None
467
+ if query and st.button("Run the generated code", key=f"{key}-use-sql"):
468
  content = query
469
+
470
+ sql = code_editor(
471
+ "sql",
472
+ hint,
473
+ show_panel=show_panel,
474
+ key=key,
475
+ content=content if content else None,
476
+ )
477
  if sql:
478
  st.code(sql, language="sql")
479
  st.write(f"`OUT[{i+1}]`")
 
513
  st.bar_chart(groups[i].mean())
514
  ```
515
  """
516
+ number_cells = st.sidebar.number_input(
517
+ "Number of Python cells to use",
518
+ value=1,
519
+ max_value=40,
520
+ min_value=1,
521
+ help=help,
522
+ )
523
  for i in range(number_cells):
524
  # st.markdown("<br><br><br>", unsafe_allow_html=True)
525
  col1, col2 = st.columns([2, 1])
526
  # col1.write(f"> `IN[{i+1}]`")
527
+ show_panel = (
528
+ False # col2.checkbox("Show cell config panel", key=f"panel{i}")
529
+ )
530
 
531
  # -- GPT AI
532
  query = ask_gpt_python(df_info, key=f"{i}-gpt")
533
  content = None
534
+ if query and st.checkbox("Run the generated code", key=f"{i}-use-python"):
535
  content = query
536
+ user_script = code_editor(
537
+ "python",
538
+ hint,
539
+ show_panel=show_panel,
540
+ key=i,
541
+ content=content if content else None,
542
+ )
543
  if user_script:
544
+ df.rename(
545
+ columns={"lng": "lon"}, inplace=True
546
+ ) # hot-fix for "World Population" dataset
547
  st.write(f"> `IN[{i+1}]`")
548
  st.code(user_script, language="python")
549
  st.write(f"> `OUT[{i+1}]`")
550
  run_python_script(user_script, key=f"{user_script}{i}")
551
 
 
552
  if st.sidebar.checkbox("Show SQL cells", value=True):
553
  sql_cells(df)
554
  if st.sidebar.checkbox("Show Python cells", value=True):
 
556
 
557
  st.sidebar.write("---")
558
 
559
+ if st.sidebar.checkbox(
560
+ "Generate Data Profile Report",
561
+ help="pandas profiling, generated by [ydata-profiling](https://github.com/ydataai/ydata-profiling)",
562
+ ):
563
  st.write("---")
564
  st.header("Data Profiling")
565
  profile = data_profiler(df)