Spaces:
Runtime error
Runtime error
""" | |
1. working! 用自己的方式实现了一个Qwen的数据分析代码。 | |
1. working!在执行streamlit run命令前,先执行python add_fonts.py。这样可以确保中文字体被加载到当前的python执行环境中。 | |
1. working!测试可以在局域网中访问。 | |
1. 多个阶段实现这一功能: | |
1. 根据用户的prompt,用Qwen生成代码。 | |
1. 根据生成的代码,用Python的exec等模块在本地执行代码。 | |
1. 将所有的内容放入hist_message中,然后用Qwen生成结论。 | |
1. 变量名尽量用’‘引号括起来,避免出现错误。 | |
1. plt.rcParams['font.sans-serif'] = ['Microsoft YaHei UI' ## 在qwen数据分析模块中,这里最好用plt。 | |
1. 需要把datafram的信息(如列名,数据类型)加入到prompt中,这样可以帮助用户更好的选择数据进行分析。 | |
1. | |
""" | |
##TODO: 1. re-submit button. | |
# -*- coding: utf-8 -*- | |
import requests | |
import random | |
# from keras.utils.np_utils import to_categorical | |
# from keras.optimizers import SGD | |
# from keras.layers import Dense, Activation, LSTM, Dropout, SimpleRNN, SimpleRNNCell | |
# from keras.models import Sequential | |
# import keras | |
from sklearn.preprocessing import StandardScaler | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from dateutil.relativedelta import relativedelta | |
from scipy.optimize import minimize | |
import statsmodels.formula.api as smf | |
import statsmodels.tsa.api as smt | |
import statsmodels.api as sm | |
import scipy.stats as scs | |
from itertools import product | |
from tqdm import tqdm_notebook, tqdm, trange | |
import time | |
import pretty_errors | |
import seaborn as sns | |
import sklearn | |
from matplotlib.pyplot import style | |
from sklearn.metrics import r2_score, median_absolute_error, mean_absolute_error | |
from sklearn.metrics import median_absolute_error, mean_squared_error, mean_squared_log_error | |
from sklearn.linear_model import LinearRegression | |
from sklearn.model_selection import cross_val_score | |
from sklearn.model_selection import TimeSeriesSplit | |
import matplotlib.pylab as plt | |
import warnings | |
import re | |
from re import sub | |
import smtplib | |
import io | |
import os | |
import contextlib | |
import streamlit as st | |
import streamlit_authenticator as stauth | |
import random | |
from http import HTTPStatus | |
import dashscope | |
from io import StringIO | |
from PIL import Image | |
import add_fonts | |
import tempfile | |
from tempfile import NamedTemporaryFile | |
warnings.filterwarnings('ignore') | |
### 参考: | |
## export DASHSCOPE_API_KEY="sk-948adb3e65414e55961a9ad9d22d186b" | |
dashscope.api_key = "sk-948adb3e65414e55961a9ad9d22d186b" | |
### make it look nice from the start | |
st.set_page_config(layout='wide', initial_sidebar_state='auto', page_icon="🤖", page_title="本地化大模型智能数据分析演示") | |
## layout settings. | |
st.title("本地化大模型智能数据分析演示") | |
st.subheader("Artificial Intelligence Data Analysis Center for Professionals") | |
st.markdown("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_") | |
data_show = st.empty() | |
### authentication with a local yaml file. | |
import yaml | |
from yaml.loader import SafeLoader | |
with open('./config.yaml') as file: | |
config = yaml.load(file, Loader=SafeLoader) | |
authenticator = stauth.Authenticate( | |
config['credentials'], | |
config['cookie']['name'], | |
config['cookie']['key'], | |
config['cookie']['expiry_days'], | |
config['preauthorized'] | |
) | |
# authentication with a remove cloud-based database. | |
# 导入云端用户数据库。 | |
# DETA_KEY = "c0zegv33efm_4MBTaoQAn76GzUfsZeKV64Uh9qMY3WZb" | |
# load_dotenv(".env") | |
# DETA_KEY = os.getenv("DETA_KEY") | |
# print(DETA_KEY) | |
# deta = Deta(DETA_KEY) | |
# mybase is the name of the database in Deta. You can change it to any name you want. | |
# credentials = {"usernames":{}} | |
# # credentials = {"users": {}} | |
# # db = db() | |
# users = [] | |
# email = [] | |
# passwords = [] | |
# names = [] | |
# for row in db.fetch_all_users(): | |
# users.append(row["username"]) | |
# email.append(row["email"]) | |
# names.append(row["key"]) | |
# passwords.append(row["password"]) | |
# hashed_passwords = stauth.Hasher(passwords).generate() | |
## 需要严格的按照yaml文件的格式来定义如下几个字段。 | |
# for un, name, pw in zip(users, names, hashed_passwords): | |
# # user_dict = {"name":name,"password":pw} | |
# user_dict = {"name": un, "password": pw} | |
# # credentials["usernames"].update({un:user_dict}) | |
# credentials["usernames"].update({un: user_dict}) | |
# ## sign-up模块,未完成。 | |
# database_table = [] | |
# # print(pd.DataFrame(credentials)) | |
# for i in credentials['usernames'].keys(): | |
# # print("i:",i) | |
# # print("name",credentials['usernames'][i]['name']) | |
# # print("password",credentials['usernames'][i]['password']) | |
# database_table.append([i,credentials['usernames'][i]['name'],credentials['usernames'][i]['password']]) | |
# print("database_table:",database_table) | |
# authenticator = stauth.Authenticate( | |
# credentials=credentials, cookie_name="joeshi_gpt", key='abcedefg', cookie_expiry_days=30) | |
user, authentication_status, username = authenticator.login('用户登录', 'main') | |
# print("name", name, "username", username) | |
# ## sign-up widget,未完成。 | |
# try: | |
# if authenticator.register_user('新用户注册', preauthorization=False): | |
# # for list in database_table: | |
# # db.update_user(username=list[0], name=list[1], password=list[2]) | |
# db.update_user(username=list[-1][0], name=list[-1][1], password=list[-1][2]) | |
# # st.success('User registered successfully') | |
# st.success('注册成功!') | |
# except Exception as e: | |
# st.error(e) | |
# ## clear conversion. | |
def reset_all(): | |
# st.session_state.conversation = None | |
st.session_state.chat_history = None | |
st.session_state.messages = [] | |
# st.session_state.messages | |
message_placeholder = st.empty() | |
data_show = st.empty() | |
return None | |
##TODO:在带有聊天历史的情况下,重新整理这个模块。 | |
def regenerate(user_input): | |
## 因为可能没有历史,第一次的时候,所以需要处理异常。 | |
try: | |
reset_all() | |
main(user_input) | |
except Exception as e: | |
print('Error:', e) | |
pass | |
return None | |
### 对长传数据进行描述性统计,获得列表名 | |
def dataframe_describe(df): | |
df = pd.DataFrame(df) | |
return pd.DataFrame({'column_name': df.columns, 'data_type': df.dtypes.values}) | |
## get the current time | |
from datetime import datetime | |
from pytz import timezone | |
def get_current_time(): | |
beijing_tz = timezone('Asia/Shanghai') | |
beijing_time = datetime.now(beijing_tz) | |
current_time = beijing_time.strftime('%H:%M:%S') | |
return current_time | |
if authentication_status: | |
with st.sidebar: | |
st.markdown( | |
""" | |
<style> | |
[data-testid="stSidebar"][aria-expanded="true"]{ | |
min-width: 550px; | |
max-width: 550px; | |
} | |
""", | |
unsafe_allow_html=True, | |
) | |
### siderbar的题目。 | |
st.header(f'**欢迎 **{username}** 来到人工智能的世界** ♠') | |
st.write(f'_Welcome and Hope U Enjoy Staying Here_') | |
authenticator.logout('登出', 'sidebar') | |
# reset_button_key = "reset_button" | |
# reset_button = st.button(label=("清除所有记录,并开启一轮新对话 ▶"), | |
# key=reset_button_key, use_container_width=True, type="primary") | |
# ## 清除所有对话记录, reset all conversation. | |
# if reset_button: | |
# reset_all() | |
# st.markdown("#### 大语言模型设置") | |
# with st.expander(label='**选择一个大语言模型基座**', expanded=True): | |
radio_1 = st.selectbox( | |
label='选择一个大语言模型基座', | |
options=["阿里通义千问","ChatGPT-3.5", "ChatGPT-4", "清华GLM3-6B", "百川Baichuan-13B", "阿里通义千问-130B", "阿里通义千问-14B", "Llama-2", "Mistral", "Vicuna"], | |
index=0, | |
placeholder="大语言模型列表", | |
) | |
### reset button | |
reset_button_key = "reset_button" | |
reset_button = st.button(label=("清除所有记录,并开启一轮新对话 ▶"), | |
key=reset_button_key, use_container_width=True, type="primary") | |
if reset_button: | |
reset_all() | |
## regenerate button | |
# re_btn = st.sidebar.button("重新生成答案 🔛", use_container_width=True, type='secondary') | |
# if re_btn: | |
# regenerate(my_orginal_question) | |
## 在sidebar上的三个分页显示,用st.tabs实现。 | |
tab_1, tab_2, tab_3, tab_4 = st.tabs(['使用须知', '模型参数', '提示词模板', '系统角色设定']) | |
# with st.expander(label='**使用须知**', expanded=False): | |
with tab_1: | |
# st.markdown("#### 快速上手指南") | |
# with st.text(body="说明"): | |
# st.markdown("* 重启一轮新对话时,只需要刷新页面(按Ctrl/Command + R)即可。") | |
with st.text(body="说明"): | |
st.markdown( | |
"* 利用大语言模型将自然语言查询转化为数据分析语言(如Python)。模型可以理解非技术用户的查询并将其转化为可执行的数据分析命令。") | |
with st.text(body="说明"): | |
st.markdown("* 语言模型可以辅助生成用于数据清洗和预处理的脚本,如Python中的pandas库代码片段。通过自然语言指令,非技术人员也可以执行复杂的数据预处理任务。") | |
with st.text(body="说明"): | |
st.markdown("* 为了保护数据与隐私,所有对话均不会被保存,刷新页面立即删除。敬请放心。") | |
with st.text(body="说明"): | |
st.markdown("* 推荐上传csv格式的数据文件,部分Excel文件容易出现数据不兼容的情况。") | |
## 大模型参数 | |
# with st.expander(label='**大语言模型参数**', expanded=True): | |
with tab_2: | |
max_tokens = st.slider(label='Max_Token(生成结果时最大字数)', min_value=100, max_value=8096, value=4096,step=100) | |
temperature = st.slider(label='Temperature (温度)', min_value=0.0, max_value=1.0, value=0.8, step=0.1) | |
top_p = st.slider(label='Top_P (核采样)', min_value=0.0, max_value=1.0, value=0.6, step=0.1) | |
frequency_penalty = st.slider(label='Frequency Penalty (重复度惩罚因子)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1) | |
presence_penalty = st.slider(label='Presence Penalty (控制主题的重复度)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1) | |
## reset password widget | |
# try: | |
# if authenticator.reset_password(st.session_state["username"], 'Reset password'): | |
# st.success('Password modified successfully') | |
# except Exception as e: | |
# st.error(e) | |
# with st.header(body="欢迎"): | |
# st.markdown("# 欢迎使用大语言模型商业智能中心") | |
# with st.expander(label=("**重要的使用注意事项**"), expanded=True): | |
# with st.container(): | |
with tab_3: | |
# st.markdown("#### Prompt提示词参考资料") | |
# with st.expander(label="**大语言模型基础提示词Prompt示例**", expanded=False): | |
# st.code( | |
# body="继续用中文写一篇关于 [文章主题] 的文章,以下列句子开头:[文章开头]。", language='plaintext') | |
# st.code(body="将以下文字概括为 100 个字,使其易于阅读和理解。避免使用复杂的句子结构或技术术语。", | |
# language='plaintext') | |
# st.code(body="给我出一个迪奥2023春季发布会活动策划。", language='plaintext') | |
# st.code(body="帮我按照正式会议结构写一个会邀:主题是xx手机游戏立项会议。", language='plaintext') | |
# st.code(body="帮我写一个车内健康监测全场景落地的项目计划,用表格。", language='plaintext') | |
# st.code( | |
# body="同时掷两枚质地均匀的骰子,则两枚骰子向上的点数之和为 7 的概率是多少。", language='plaintext') | |
# st.code(body="写一篇产品经理的演讲稿,注意使用以下词汇: 赋能,抓手,中台,闭环,落地,漏斗,沉淀,给到,同步,对齐,对标,迭代,拉通,打通,升级,交付,聚焦,倒逼,复盘,梳理,方案,联动,透传,咬合,洞察,渗透,兜底,解耦,耦合,复用,拆解。", language='plaintext') | |
with st.expander(label="**数据分析模式的专用提示词Prompt示例**", expanded=True): | |
# with st.subheader(body="提示词Prompt"): | |
st.code(body="对文件中的'话务量'和'平均通话时长'进行线性回归分析。", language='python') | |
st.code(body="对文件中的'话务量'和'平均通话时长'进行相关性分析。", language='python') | |
st.code(body="对于对于上传的'话务量'数据给我一个'直方图'',并给出分析结果。", language='python') | |
st.code(body="对文件中'话务量'和'平均通话时长'进行Kmeans聚类分析。", language='python') | |
st.code(body="对于上传的数据中'组别'和'话务量'数据给我一个'swarmplot'分析,并给出分析结果。", language='python') | |
st.code(body="对文件中的'话务量'和'平均通话时长'进行heatmap分析。", language='python') | |
st.code(body="对于'话务量'数据给我一个'小提琴图',并给出分析结果。", language='python') | |
st.code(body="对文件中的基于'组别'对'话务量'进行分类分析。", language='python') | |
st.code(body="对文件中的基于'组别'对'话务量'进行‘swarmplot’分析。", language='python') | |
st.code(body="在上传的数据中,基于'组别'对'话务量'数据给我一个分类汇总,并给出分析结果。", language='python') | |
st.code( | |
body="对于[X,Y,Z]数据在一个'分布散点图 (stripplot)',所有的数据在一张图上展现, 并给出分析结果。", language='python') | |
st.code(body="对于[X,Y]数据,进行'T检验',你需要展示图表,并给出分析结果。", | |
language='python') | |
st.code(body="对于[X,Y]数据给我一个3个类别的'聚类分析',并给出分析结果。", | |
language='python') | |
with tab_4: | |
st.text_area(label='系统角色设定', value='你是一个人工智能,你需要回答我提出的问题,或者完成我交代的任务。你需要使用我提问的语言(如中文、英文)来回答。', height=200, label_visibility='hidden') | |
elif authentication_status == False: | |
st.error('⛔ 用户名或密码错误!') | |
elif authentication_status == None: | |
st.warning('⬆ 请先登录!') | |
### 上传文件的模块 | |
def upload_file(uploaded_file): | |
if uploaded_file is not None: | |
filename = uploaded_file.name | |
data_show = st.empty() | |
if '.csv' in filename: | |
csv_file = pd.read_csv(uploaded_file) | |
csv_file.to_csv(f'./{username}/{username}_upload.csv', encoding='utf-8', index=False) | |
st.write(csv_file[:3]) # 这里只是显示文件,后面需要定位文件所在的绝对路径。 | |
else: | |
xls_file = pd.read_excel(uploaded_file) | |
xls_file.to_csv(f'./{username}/{username}_upload.csv', index=False) | |
data_show.write(xls_file[:3]) | |
# st.write(xls_file[:3]) ### original code here. | |
# upload_df = pd.read_csv(f"/Users/yunshi/Downloads/Qwen/{username}/{username}_upload.csv") | |
# upload_df = pd.read_csv(f'./{username}/{username}_upload.csv') ##TODO: 这里需要改成实际用户的目录 | |
# st.write(upload_df[:3]) | |
# df_info = dataframe_describe(upload_df) | |
uploaded_file_name = "File_provided" | |
temp_dir = tempfile.TemporaryDirectory() | |
return None | |
uploaded_file = st.file_uploader( | |
"选择上传一个文件", type=(["csv", "xlsx", "xls"])) | |
# 默认状态下没有上传文件,None,会报错。需要判断。 | |
if uploaded_file is not None: | |
uploaded_file_path = upload_file(uploaded_file) | |
### 环境设定 | |
my_orginal_question = "" | |
hist_message = [] | |
# messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, | |
# {'role': 'user', 'content': ""}] | |
current_time = get_current_time() | |
rand_num = str(current_time) + str(random.randint(1, 10000)) | |
sys_prompt = f""" | |
1. 你是一个python的数据分析专家。 | |
2. 你需要加入以下python代码: | |
-------------------------------- | |
import matplotlib | |
import matplotlib.pyplot as plt | |
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei UI'] | |
-------------------------------- | |
3. 你所需要分析的文件位置在:./{username}/{username}_upload.csv | |
4. 如果你需要作图时,你把图片保持在以下路径:/Users/yunshi/Downloads/Qwen/joeshi/{rand_num}.png' | |
""" | |
### 用Qwen来作为大模型输出代码。 | |
def call_with_messages(prompt): | |
### 获得所有dataframe的描述性统计信息。 | |
df_upload = pd.read_csv(f"/Users/yunshi/Downloads/Qwen/{username}/{username}_upload.csv") | |
df_info = dataframe_describe(df_upload) | |
### 将描述性统计信息加入到prompt中。 | |
prompt = f"数据文件的描述性统计信息如下:\n {df_info}。 你只能从这些column names中选择数据进行分析。\n\n 我的要求如下: {prompt}", | |
hist_message.append(prompt) | |
messages = [{'role': 'system', 'content': sys_prompt}, {'role': 'user', 'content': str(hist_message)}] | |
# # {'role': 'user', 'content': '如何做西红柿炒鸡蛋?'}] | |
## Qwen的调用 | |
response = dashscope.Generation.call( | |
"qwen-turbo", | |
messages=messages, | |
# set the random seed, optional, default to 1234 if not set | |
seed=random.randint(1, 10000), | |
# set the result to be "message" format. | |
result_format='message', | |
temperature=0.1, | |
) | |
if response.status_code == HTTPStatus.OK: | |
print(response) | |
else: | |
print('Request id: %s, Status code: %s, error code: %s, error message: %s' % ( | |
response.request_id, response.status_code, | |
response.code, response.message | |
)) | |
llm_response = response['output']['choices'][0]['message']['content'] | |
print('llm_response:', llm_response) | |
hist_message.append(llm_response) | |
# return response['output']['choices'][0]['message']['content'] ### 这里是content的内容,不是message的全部内容。 | |
return llm_response ### 这里是content的内容,不是message的全部内容。 | |
### clean code | |
def clean_code(string): | |
pattern = r'```python\n([\s\S]*?)\n```' | |
code_blocks = re.findall(pattern, string, re.DOTALL) | |
if code_blocks: | |
code = "\n".join(code_blocks) | |
print("-------------------") | |
print('clean code: \n', code) | |
print("-------------------") | |
else: | |
code = "无法进行有效分析" | |
print("No Python code found in the given text.") | |
return code | |
### 运行python code的核心代码。 | |
def execute_Python_code(code): | |
# A string stream to capture the outputs of exec | |
output = io.StringIO() | |
try: | |
# Redirect stdout to the StringIO object | |
with contextlib.redirect_stdout(output): | |
# Allow imports | |
exec(code, globals()) | |
except Exception as e: | |
# If an error occurs, capture it as part of the output | |
print(f"Error: {e}", file=output) | |
return output.getvalue() | |
# prompt = f""" | |
# 你用python对如下文件中的cyl和displ进行聚类分析,然后画出聚类图: | |
# {"/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/code_interpreter/rawdata/Cars.csv"} | |
# """ | |
###! working! | |
# prompt = f""" | |
# 你用python对如下文件中的cyl和displ进行相关性分析,然后画出散点图: | |
# {"/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/code_interpreter/rawdata/Cars.csv"} | |
# """ | |
###! working! | |
# prompt = f""" | |
# 你用python对如下文件中的cyl和displ进行线性回归分析: | |
# {"/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/code_interpreter/rawdata/Cars.csv"} | |
# """ | |
# llm_response = call_with_messages(prompt=prompt) | |
# print('llm_response:', llm_response) | |
# myCode = clean_code(llm_response) | |
# python_res = execute_Python_code(myCode) | |
# hist_message.append(python_res) | |
# analysis_prompt = f""" | |
# 你是一个数据分析专家。根据如下所有的内容,你告诉我本次数据分析的结论。你需要根据已知的所有信息,结合你所知道的行业信息,给出所有可能的结论和对应的分析。 | |
# 本次数据分析相关的内容下: | |
# ----------------------- | |
# {hist_message} | |
# """ | |
# final_res = call_with_messages(prompt=analysis_prompt) | |
# print(final_res) | |
## 核心函数 | |
def main(): | |
# Initialize chat history | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
# Display chat messages from history on app rerun | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# Display assistant response in chat message container | |
# if prompt := st.chat_input("说点什么吧"): | |
prompt = st.chat_input("说点什么吧...") | |
# global my_orginal_question | |
my_orginal_question = str(prompt) | |
print('prompt now:', prompt) | |
print('----------'*5) | |
if prompt: | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
with st.chat_message("user"): | |
st.markdown(prompt) | |
with st.chat_message("assistant", avatar="./llm_icon.png" ): | |
message_placeholder = st.empty() | |
# full_response = "" | |
llm_response = call_with_messages(prompt=prompt) | |
message_placeholder.markdown(llm_response) | |
st.session_state.messages.append({"role": "assistant", "content": llm_response}) | |
with st.chat_message("assistant", avatar="./log_icon.png"): | |
code_placeholder = st.empty() | |
myCode = clean_code(llm_response) | |
python_res = execute_Python_code(myCode) | |
# code_placeholder.info(python_res) ## 用streamlit info来显示内容,浅蓝色底框。 | |
### 有时没有任何内容输出,所以需要判断一下。 | |
if len(python_res) > 0: | |
code_placeholder.code(python_res) | |
# code_placeholder.markdown(python_res) | |
## 确认是否有图片输出。 | |
folder_path = './joeshi/' ##TODO: 需要改成实际用户的目录 | |
search_string = str(rand_num) | |
for filename in os.listdir(folder_path): | |
file_path = os.path.join(folder_path, filename) | |
if os.path.isfile(file_path) and filename.startswith(search_string): | |
st.image(file_path, use_column_width=True) | |
st.session_state.messages.append({"role": "assistant", "content": python_res}) | |
with st.chat_message("assistant", avatar="./solution_icon.png"): | |
analysis_prompt = f""" | |
你是一个数据分析专家。根据如下所有的内容,你告诉我本次数据分析的结论。你需要根据已知的所有信息,结合你所知道的行业信息,给出所有可能的结论和对应的分析。 | |
本次数据分析相关的内容下: | |
----------------------- | |
{st.session_state.messages} | |
""" | |
final_messsage = st.empty() | |
final_res = call_with_messages(prompt=analysis_prompt) | |
final_messsage.markdown(final_res) | |
if __name__ == '__main__': | |
main() | |