Spaces:
Build error
Build error
import logging | |
from datetime import datetime, timezone | |
from typing import Optional | |
import requests | |
from flask import current_app, redirect, request | |
from flask_restful import Resource | |
from werkzeug.exceptions import Unauthorized | |
from configs import dify_config | |
from constants.languages import languages | |
from events.tenant_event import tenant_was_created | |
from extensions.ext_database import db | |
from libs.helper import extract_remote_ip | |
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | |
from models import Account | |
from models.account import AccountStatus | |
from services.account_service import AccountService, RegisterService, TenantService | |
from services.errors.account import AccountNotFoundError | |
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError | |
from services.feature_service import FeatureService | |
from .. import api | |
def get_oauth_providers():
    """Return the OAuth clients keyed by provider name.

    The mapping always has the keys ``"github"`` and ``"google"``. A provider
    maps to ``None`` unless both its client id and client secret are set in
    ``dify_config``. Otherwise it maps to a client whose redirect URI is
    ``CONSOLE_API_URL`` followed by ``/console/api/oauth/authorize/<provider>``.
    """
    with current_app.app_context():
        github_oauth = None
        if dify_config.GITHUB_CLIENT_ID and dify_config.GITHUB_CLIENT_SECRET:
            github_oauth = GitHubOAuth(
                client_id=dify_config.GITHUB_CLIENT_ID,
                client_secret=dify_config.GITHUB_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
            )

        google_oauth = None
        if dify_config.GOOGLE_CLIENT_ID and dify_config.GOOGLE_CLIENT_SECRET:
            google_oauth = GoogleOAuth(
                client_id=dify_config.GOOGLE_CLIENT_ID,
                client_secret=dify_config.GOOGLE_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
            )

        return {"github": github_oauth, "google": google_oauth}
class OAuthLogin(Resource):
    """Start an OAuth sign-in by redirecting the browser to the provider."""

    def get(self, provider: str):
        """Redirect to ``provider``'s authorization URL.

        Args:
            provider: Provider key taken from the URL path (a key of the
                mapping returned by ``get_oauth_providers``).

        Returns:
            A redirect to the provider's authorization page, or
            ``({"error": "Invalid provider"}, 400)`` when the provider is
            unknown or not configured.
        """
        # An empty ``invite_token`` query parameter is treated as absent.
        invite_token = request.args.get("invite_token") or None
        OAUTH_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_PROVIDERS.get(provider)
        # Check before touching the provider object: it is None for unknown or
        # unconfigured providers. Never dump it either, since it holds the
        # client secret.
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400

        auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
        return redirect(auth_url)
class OAuthCallback(Resource):
    """Provider callback endpoint: finish the OAuth flow and sign the user in.

    The provider redirects here with a ``code`` query parameter and an optional
    ``state``. A non-empty ``state`` is treated as an invite token, presumably
    the one ``OAuthLogin`` passed to ``get_authorization_url``.
    """

    def get(self, provider: str):
        """Exchange the authorization code, resolve the account and log it in.

        Args:
            provider: Provider key taken from the URL path.

        Returns:
            ``({"error": ...}, 400)`` when any of these holds:

            - the provider is unknown or not configured;
            - the ``code`` parameter is missing;
            - the HTTP exchange with the provider fails.

            Otherwise, a redirect to one of:

            - the invite-settings page, for a valid invite token;
            - the sign-in page with a ``message``, on account or workspace
              problems;
            - the console web app with an access/refresh token pair, on
              success.
        """
        OAUTH_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_PROVIDERS.get(provider)
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400

        code = request.args.get("code")
        if not code:
            # Nothing to exchange. An OAuth 2.0 error redirect (e.g. the user
            # refused consent) carries ``error`` instead of ``code``.
            return {"error": "OAuth process failed"}, 400
        invite_token = request.args.get("state") or None

        try:
            token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(token)
        except requests.exceptions.RequestException as e:
            # HTTPError carries the provider's response; connection errors and
            # timeouts do not. Compare with None explicitly, because a
            # requests.Response is falsy whenever its status is an error.
            error_text = e.response.text if e.response is not None else str(e)
            logging.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
            return {"error": "OAuth process failed"}, 400

        if invite_token and RegisterService.is_valid_invite_token(invite_token):
            invitation = RegisterService._get_invitation_by_token(token=invite_token)
            if invitation:
                invitation_email = invitation.get("email", None)
                # The invitation must belong to the identity the provider returned.
                if invitation_email != user_info.email:
                    return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")

        try:
            account = _generate_account(provider, user_info)
        except AccountNotFoundError:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
        except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
            return redirect(
                f"{dify_config.CONSOLE_WEB_URL}/signin"
                "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
            )

        # Check account status
        if account.status == AccountStatus.BANNED.value:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")

        if account.status == AccountStatus.PENDING.value:
            # First successful sign-in activates a pending account; the
            # timestamp is stored as naive UTC.
            account.status = AccountStatus.ACTIVE.value
            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
            db.session.commit()

        try:
            TenantService.create_owner_tenant_if_not_exist(account)
        except Unauthorized:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
        except WorkSpaceNotAllowedCreateError:
            return redirect(
                f"{dify_config.CONSOLE_WEB_URL}/signin"
                "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
            )

        token_pair = AccountService.login(
            account=account,
            ip_address=extract_remote_ip(request),
        )

        return redirect(
            f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
        )
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
    """Look up the account for an OAuth identity.

    An account linked to ``(provider, user_info.id)`` wins. Failing that, the
    first account whose email equals ``user_info.email`` is returned, or
    ``None`` if there is none.
    """
    linked_account = Account.get_by_openid(provider, user_info.id)
    if linked_account:
        return linked_account
    return Account.query.filter_by(email=user_info.email).first()
def _generate_account(provider: str, user_info: OAuthUserInfo):
    """Return the account for an OAuth sign-in, creating what is missing.

    For an existing account (matched by open id, then email), a new workspace
    is created when ``TenantService.get_join_tenants`` returns nothing. The
    account becomes that workspace's owner and its current tenant.

    Otherwise a new account is registered, named after the provider's display
    name (or "Dify"). Its interface language is the best match from the
    request's Accept-Language header, defaulting to ``languages[0]``.

    In both cases the OAuth identity is linked to the account before it is
    returned.

    Raises:
        WorkSpaceNotAllowedCreateError: an existing account has no workspace
            and workspace creation is disabled.
        AccountNotFoundError: no account matches and registration is disabled.
    """
    account = _get_account_by_openid_or_email(provider, user_info)

    if account:
        if not TenantService.get_join_tenants(account):
            if not FeatureService.get_system_features().is_allow_create_workspace:
                raise WorkSpaceNotAllowedCreateError()
            new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
            TenantService.create_tenant_member(new_tenant, account, role="owner")
            account.current_tenant = new_tenant
            tenant_was_created.send(new_tenant)
    else:
        if not FeatureService.get_system_features().is_allow_register:
            raise AccountNotFoundError()
        account = RegisterService.register(
            email=user_info.email,
            name=user_info.name or "Dify",
            password=None,
            open_id=user_info.id,
            provider=provider,
        )

        # Pick the interface language from the browser's preferences.
        best_lang = request.accept_languages.best_match(languages)
        account.interface_language = best_lang if best_lang and best_lang in languages else languages[0]
        db.session.commit()

    AccountService.link_account_integrate(provider, user_info.id, account)
    return account
# Route registration. The login route starts the redirect to the provider, and
# the authorize route receives the provider's callback. NOTE(review): the
# redirect_uri built in get_oauth_providers targets
# "/console/api/oauth/authorize/<provider>", which assumes `api` is mounted
# under "/console/api". Confirm the blueprint prefix.
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")