Spaces:
Build error
Build error
import datetime | |
import pytz | |
from flask import request | |
from flask_login import current_user | |
from flask_restful import Resource, fields, marshal_with, reqparse | |
from configs import dify_config | |
from constants.languages import supported_language | |
from controllers.console import api | |
from controllers.console.workspace.error import ( | |
AccountAlreadyInitedError, | |
CurrentPasswordIncorrectError, | |
InvalidInvitationCodeError, | |
RepeatPasswordNotMatchError, | |
) | |
from controllers.console.wraps import account_initialization_required, setup_required | |
from extensions.ext_database import db | |
from fields.member_fields import account_fields | |
from libs.helper import TimestampField, timezone | |
from libs.login import login_required | |
from models import AccountIntegrate, InvitationCode | |
from services.account_service import AccountService | |
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError | |
class AccountInitApi(Resource):
    """Complete first-time setup of the logged-in account."""

    # Only setup/login are enforced here: this endpoint is what moves an account
    # out of its non-"active" state (see the status check below), so it must not
    # additionally require an already-initialized account.
    @setup_required
    @login_required
    def post(self):
        """Set language, timezone and theme, then mark the current account active.

        Request JSON:
            interface_language: value accepted by ``supported_language``.
            timezone: value accepted by ``libs.helper.timezone``.
            invitation_code: CLOUD edition only; must match an "unused" code.

        Returns:
            ``{"result": "success"}``.

        Raises:
            AccountAlreadyInitedError: the account status is already "active".
            ValueError: CLOUD edition and no invitation code was supplied.
            InvalidInvitationCodeError: no unused invitation code matches.
        """
        account = current_user
        if account.status == "active":
            raise AccountAlreadyInitedError()

        is_cloud = dify_config.EDITION == "CLOUD"

        parser = reqparse.RequestParser()
        if is_cloud:
            parser.add_argument("invitation_code", type=str, location="json")
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        parser.add_argument("timezone", type=timezone, required=True, location="json")
        args = parser.parse_args()

        # Naive UTC timestamp, taken once so used_at and initialized_at agree.
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        if is_cloud:
            if not args["invitation_code"]:
                raise ValueError("invitation_code is required")

            # Look up a matching invitation code that has not been consumed yet.
            invitation_code = (
                db.session.query(InvitationCode)
                .filter(
                    InvitationCode.code == args["invitation_code"],
                    InvitationCode.status == "unused",
                )
                .first()
            )
            if not invitation_code:
                raise InvalidInvitationCodeError()

            invitation_code.status = "used"
            invitation_code.used_at = now
            invitation_code.used_by_tenant_id = account.current_tenant_id
            invitation_code.used_by_account_id = account.id

        account.interface_language = args["interface_language"]
        account.timezone = args["timezone"]
        account.interface_theme = "light"
        account.status = "active"
        account.initialized_at = now
        # One commit covers both the invitation-code and the account changes.
        db.session.commit()

        return {"result": "success"}
class AccountProfileApi(Resource):
    """Read-only view of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return the current account, serialized through ``account_fields``.

        ``current_user`` is a model object rather than a dict, so it must pass
        through ``marshal_with`` to produce a JSON response.
        """
        return current_user
class AccountNameApi(Resource):
    """Rename the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the account's name.

        Request JSON:
            name: new name, between 3 and 30 characters inclusive.

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.

        Raises:
            ValueError: the name is missing/null or its length is outside 3..30.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")
        args = parser.parse_args()

        # reqparse lets a JSON null through as None; treat it as an empty name so
        # it fails the length check instead of crashing in len().
        name = args["name"] or ""
        if not 3 <= len(name) <= 30:
            raise ValueError("Account name must be between 3 and 30 characters.")

        return AccountService.update_account(current_user, name=name)
class AccountAvatarApi(Resource):
    """Change the logged-in account's avatar."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the ``avatar`` string from the request JSON on the account.

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("avatar", type=str, required=True, location="json")
        args = parser.parse_args()

        return AccountService.update_account(current_user, avatar=args["avatar"])
class AccountInterfaceLanguageApi(Resource):
    """Change the logged-in account's interface language."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``interface_language`` (validated by ``supported_language``).

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        args = parser.parse_args()

        return AccountService.update_account(current_user, interface_language=args["interface_language"])
class AccountInterfaceThemeApi(Resource):
    """Change the logged-in account's interface theme."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``interface_theme`` to either "light" or "dark".

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
        args = parser.parse_args()

        return AccountService.update_account(current_user, interface_theme=args["interface_theme"])
class AccountTimezoneApi(Resource):
    """Change the logged-in account's timezone."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account timezone to an IANA name such as "America/New_York".

        Returns:
            The account returned by ``AccountService.update_account``,
            serialized through ``account_fields``.

        Raises:
            ValueError: the string is not a timezone known to pytz.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("timezone", type=str, required=True, location="json")
        args = parser.parse_args()

        # all_timezones_set gives O(1) membership; all_timezones is a plain list.
        if args["timezone"] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        return AccountService.update_account(current_user, timezone=args["timezone"])
class AccountPasswordApi(Resource):
    """Change the logged-in account's password."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Replace the account password.

        Request JSON:
            password: current password; optional, forwarded to the service as-is
                (None when omitted).
            new_password: the new password.
            repeat_new_password: must equal ``new_password``.

        Returns:
            ``{"result": "success"}``.

        Raises:
            RepeatPasswordNotMatchError: the two new-password fields differ.
            CurrentPasswordIncorrectError: the service rejected the current password.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("password", type=str, required=False, location="json")
        parser.add_argument("new_password", type=str, required=True, location="json")
        parser.add_argument("repeat_new_password", type=str, required=True, location="json")
        args = parser.parse_args()

        if args["new_password"] != args["repeat_new_password"]:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(current_user, args["password"], args["new_password"])
        except ServiceCurrentPasswordIncorrectError as e:
            # Translate the service-layer error into the console API error type.
            raise CurrentPasswordIncorrectError() from e

        return {"result": "success"}
class AccountIntegrateApi(Resource):
    """List OAuth provider bindings for the logged-in account."""

    # Serialization schema for one provider entry.
    integrate_fields = {
        "provider": fields.String,
        "created_at": TimestampField,
        "is_bound": fields.Boolean,
        "link": fields.String,
    }

    # Response envelope: {"data": [<integrate_fields>, ...]}.
    integrate_list_fields = {
        "data": fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Report, per supported provider, whether it is bound to this account.

        Bound providers carry their binding's ``created_at``; unbound ones carry
        a ``link`` to start the OAuth login flow for that provider.
        """
        account = current_user
        account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()

        # Index bindings by provider once; setdefault keeps the first row seen,
        # matching a first-match linear search.
        integrates_by_provider = {}
        for integrate in account_integrates:
            integrates_by_provider.setdefault(integrate.provider, integrate)

        base_url = request.url_root.rstrip("/")
        oauth_base_path = "/console/api/oauth/login"
        providers = ["github", "google"]

        # NOTE: "id" is not in integrate_fields, so marshal_with drops it.
        integrate_data = []
        for provider in providers:
            existing_integrate = integrates_by_provider.get(provider)
            if existing_integrate is not None:
                integrate_data.append(
                    {
                        "id": existing_integrate.id,
                        "provider": provider,
                        "created_at": existing_integrate.created_at,
                        "is_bound": True,
                        "link": None,
                    }
                )
            else:
                integrate_data.append(
                    {
                        "id": None,
                        "provider": provider,
                        "created_at": None,
                        "is_bound": False,
                        "link": f"{base_url}{oauth_base_path}/{provider}",
                    }
                )

        return {"data": integrate_data}
# Account endpoint routing table: (resource class, URL path), registered on the
# console API in this order.
_ACCOUNT_ROUTES = (
    (AccountInitApi, "/account/init"),
    (AccountProfileApi, "/account/profile"),
    (AccountNameApi, "/account/name"),
    (AccountAvatarApi, "/account/avatar"),
    (AccountInterfaceLanguageApi, "/account/interface-language"),
    (AccountInterfaceThemeApi, "/account/interface-theme"),
    (AccountTimezoneApi, "/account/timezone"),
    (AccountPasswordApi, "/account/password"),
    (AccountIntegrateApi, "/account/integrates"),
)

for _resource, _path in _ACCOUNT_ROUTES:
    api.add_resource(_resource, _path)

# Email-change routes, currently disabled:
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')