Spaces:
Build error
Build error
import flask_restful | |
from flask import request | |
from flask_login import current_user | |
from flask_restful import Resource, marshal, marshal_with, reqparse | |
from werkzeug.exceptions import Forbidden, NotFound | |
import services | |
from configs import dify_config | |
from controllers.console import api | |
from controllers.console.apikey import api_key_fields, api_key_list | |
from controllers.console.app.error import ProviderNotInitializeError | |
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError | |
from controllers.console.wraps import account_initialization_required, setup_required | |
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |
from core.indexing_runner import IndexingRunner | |
from core.model_runtime.entities.model_entities import ModelType | |
from core.provider_manager import ProviderManager | |
from core.rag.datasource.vdb.vector_type import VectorType | |
from core.rag.extractor.entity.extract_setting import ExtractSetting | |
from core.rag.retrieval.retrieval_methods import RetrievalMethod | |
from extensions.ext_database import db | |
from fields.app_fields import related_app_list | |
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields | |
from fields.document_fields import document_status_fields | |
from libs.login import login_required | |
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile | |
from models.dataset import DatasetPermissionEnum | |
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService | |
def _validate_name(name): | |
if not name or len(name) < 1 or len(name) > 40: | |
raise ValueError("Name must be between 1 to 40 characters.") | |
return name | |
def _validate_description_length(description): | |
if len(description) > 400: | |
raise ValueError("Description cannot exceed 400 characters.") | |
return description | |
class DatasetListApi(Resource):
    """Console endpoints for listing datasets and creating an empty dataset.

    Both handlers are wrapped with the console's setup / login /
    account-initialisation decorators, which this module imports.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return one page of datasets for the current user's tenant.

        Query parameters read: ``page`` (default 1), ``limit`` (default 20),
        ``ids``, ``keyword`` and ``tag_ids``. When ``ids`` is supplied the
        datasets are fetched by id and the other filters are not passed on.

        Returns:
            A ``(body, 200)`` tuple. Each dataset entry is extended with
            ``embedding_available`` and ``partial_member_list``.
        """
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        ids = request.args.getlist("ids")
        search = request.args.get("keyword", default=None, type=str)
        tag_ids = request.args.getlist("tag_ids")
        tenant_id = current_user.current_tenant_id

        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)

        # Collect the tenant's active embedding models once, as a set, so the
        # per-dataset availability check below is a constant-time lookup.
        configurations = ProviderManager().get_configurations(tenant_id=tenant_id)
        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
        model_names = {f"{model.model}:{model.provider.provider}" for model in embedding_models}

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item["indexing_technique"] == "high_quality":
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                item["embedding_available"] = item_model in model_names
            else:
                # Only "high_quality" datasets are checked against the embedding models.
                item["embedding_available"] = True

            if item.get("permission") == "partial_members":
                item["partial_member_list"] = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
            else:
                item["partial_member_list"] = []

        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Create an empty dataset owned by the current user.

        Returns:
            A ``(dataset, 201)`` tuple with the marshalled new dataset.

        Raises:
            Forbidden: If the current user is not a dataset editor.
            DatasetNameDuplicateError: If the service reports a duplicate name.
        """
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            nullable=True,
            required=False,
        )
        parser.add_argument(
            "provider",
            type=str,
            nullable=True,
            choices=Dataset.PROVIDER_LIST,
            required=False,
            default="vendor",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            nullable=True,
            required=False,
        )
        args = parser.parse_args()

        # Only users flagged as dataset editors may create datasets.
        if not current_user.is_dataset_editor:
            raise Forbidden()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args["name"],
                description=args["description"],
                indexing_technique=args["indexing_technique"],
                account=current_user,
                permission=DatasetPermissionEnum.ONLY_ME,
                provider=args["provider"],
                external_knowledge_api_id=args["external_knowledge_api_id"],
                external_knowledge_id=args["external_knowledge_id"],
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            # Translate the service-layer error into the console API error.
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201
class DatasetApi(Resource):
    """Console endpoints for reading, updating and deleting a single dataset.

    All handlers are wrapped with the console's setup / login /
    account-initialisation decorators, which this module imports.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return the details of one dataset.

        Args:
            dataset_id: Dataset identifier from the URL.

        Returns:
            A ``(data, 200)`` tuple. ``data`` is the marshalled dataset plus
            ``embedding_available`` and, for ``partial_members`` datasets,
            ``partial_member_list``.

        Raises:
            NotFound: If the dataset does not exist.
            Forbidden: If the permission check rejects the current user.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        data = marshal(dataset, dataset_detail_fields)

        # Compare the dataset's embedding model against the tenant's active ones.
        configurations = ProviderManager().get_configurations(tenant_id=current_user.current_tenant_id)
        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
        model_names = {f"{model.model}:{model.provider.provider}" for model in embedding_models}

        if data["indexing_technique"] == "high_quality":
            item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
            data["embedding_available"] = item_model in model_names
        else:
            data["embedding_available"] = True

        # Look the member list up once; a second identical query would return
        # the same rows.
        if data.get("permission") == "partial_members":
            data["partial_member_list"] = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)

        return data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
        """Update a dataset's settings and its partial member list.

        Args:
            dataset_id: Dataset identifier from the URL.

        Returns:
            A ``(data, 200)`` tuple with the marshalled updated dataset and
            its current ``partial_member_list``.

        Raises:
            NotFound: If the dataset does not exist before or after the update.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "permission",
            type=str,
            location="json",
            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
            help="Invalid permission.",
        )
        parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
        parser.add_argument(
            "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
        )
        parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
        parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
        parser.add_argument(
            "external_retrieval_model",
            type=dict,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external retrieval model.",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge id.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge api id.",
        )
        args = parser.parse_args()
        # The raw JSON body is used for the checks below; the parsed args are
        # what gets passed to the update itself.
        data = request.get_json()
        permission = data.get("permission")
        partial_member_list = data.get("partial_member_list")

        # Validate the embedding model before switching to "high_quality" indexing.
        if data.get("indexing_technique") == "high_quality":
            DatasetService.check_embedding_model_setting(
                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
            )

        DatasetPermissionService.check_permission(current_user, dataset, permission, partial_member_list)

        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
        if dataset is None:
            raise NotFound("Dataset not found.")

        result_data = marshal(dataset, dataset_detail_fields)
        tenant_id = current_user.current_tenant_id

        if partial_member_list and permission == "partial_members":
            DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, partial_member_list)
        elif permission in (DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM):
            # Clear the partial member list when permission is only_me or all_team_members.
            DatasetPermissionService.clear_partial_member_list(dataset_id_str)

        result_data["partial_member_list"] = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
        return result_data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        """Delete a dataset and clear its partial member list.

        Args:
            dataset_id: Dataset identifier from the URL.

        Returns:
            A ``({"result": "success"}, 204)`` tuple on success.

        Raises:
            Forbidden: If the user is not an editor, or is a dataset operator.
            NotFound: If the service reports nothing was deleted.
            DatasetInUseError: If the service reports the dataset is in use.
        """
        dataset_id_str = str(dataset_id)

        # Parenthesised to make the precedence explicit: non-editors and
        # dataset operators are both refused.
        if (not current_user.is_editor) or current_user.is_dataset_operator:
            raise Forbidden()

        try:
            if DatasetService.delete_dataset(dataset_id_str, current_user):
                DatasetPermissionService.clear_partial_member_list(dataset_id_str)
                return {"result": "success"}, 204
            raise NotFound("Dataset not found.")
        except services.errors.dataset.DatasetInUseError:
            raise DatasetInUseError()
class DatasetUseCheckApi(Resource):
    """Console endpoint reporting whether a dataset is currently in use."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return ``({"is_using": <result>}, 200)`` for the given dataset.

        Args:
            dataset_id: Dataset identifier from the URL.
        """
        dataset_is_using = DatasetService.dataset_use_check(str(dataset_id))
        return {"is_using": dataset_is_using}, 200
class DatasetQueryApi(Resource):
    """Console endpoint listing the queries recorded for a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return one page of dataset queries.

        Query parameters read: ``page`` (default 1) and ``limit`` (default 20).

        Args:
            dataset_id: Dataset identifier from the URL.

        Returns:
            A ``(body, 200)`` tuple with ``data``, ``has_more``, ``limit``,
            ``total`` and ``page``.

        Raises:
            NotFound: If the dataset does not exist.
            Forbidden: If the permission check rejects the current user.
        """
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)

        response = {
            "data": marshal(dataset_queries, dataset_query_detail_fields),
            # A full page is taken as a hint that more rows may follow.
            "has_more": len(dataset_queries) == limit,
            "limit": limit,
            "total": total,
            "page": page,
        }
        return response, 200
class DatasetIndexingEstimateApi(Resource):
    """Console endpoint that estimates indexing for a set of data sources."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Build extract settings from the request and run an indexing estimate.

        Supported ``info_list.data_source_type`` values are ``upload_file``,
        ``notion_import`` and ``website_crawl``.

        Returns:
            A ``(response, 200)`` tuple with the indexing runner's result.

        Raises:
            NotFound: If none of the requested upload files exist for the tenant.
            ValueError: If the data source type is not one of the supported values.
            ProviderNotInitializeError: If the runner raises ``LLMBadRequestError``
                or ``ProviderTokenNotInitError``.
            IndexingEstimateError: For any other failure raised by the runner.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
        parser.add_argument(
            "indexing_technique",
            type=str,
            required=True,
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            location="json",
        )
        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
        parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
        parser.add_argument(
            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
        )
        args = parser.parse_args()
        # validate args
        DocumentService.estimate_args_validate(args)

        tenant_id = current_user.current_tenant_id
        info_list = args["info_list"]
        data_source_type = info_list["data_source_type"]
        extract_settings = []

        if data_source_type == "upload_file":
            file_ids = info_list["file_info_list"]["file_ids"]
            file_details = (
                db.session.query(UploadFile)
                .filter(UploadFile.tenant_id == tenant_id, UploadFile.id.in_(file_ids))
                .all()
            )
            # .all() returns a list (never None), so "not found" has to be
            # detected as an empty result.
            if not file_details:
                raise NotFound("File not found.")
            extract_settings.extend(
                ExtractSetting(datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"])
                for file_detail in file_details
            )
        elif data_source_type == "notion_import":
            for notion_info in info_list["notion_info_list"]:
                workspace_id = notion_info["workspace_id"]
                for page in notion_info["pages"]:
                    extract_settings.append(
                        ExtractSetting(
                            datasource_type="notion_import",
                            notion_info={
                                "notion_workspace_id": workspace_id,
                                "notion_obj_id": page["page_id"],
                                "notion_page_type": page["type"],
                                "tenant_id": tenant_id,
                            },
                            document_model=args["doc_form"],
                        )
                    )
        elif data_source_type == "website_crawl":
            website_info_list = info_list["website_info_list"]
            for url in website_info_list["urls"]:
                extract_settings.append(
                    ExtractSetting(
                        datasource_type="website_crawl",
                        website_info={
                            "provider": website_info_list["provider"],
                            "job_id": website_info_list["job_id"],
                            "url": url,
                            "tenant_id": tenant_id,
                            "mode": "crawl",
                            "only_main_content": website_info_list["only_main_content"],
                        },
                        document_model=args["doc_form"],
                    )
                )
        else:
            raise ValueError("Data source type not support")

        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.indexing_estimate(
                tenant_id,
                extract_settings,
                args["process_rule"],
                args["doc_form"],
                args["doc_language"],
                args["dataset_id"],
                args["indexing_technique"],
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except Exception as e:
            # Boundary handler: surface any other runner failure as an API error.
            raise IndexingEstimateError(str(e))

        return response, 200
class DatasetRelatedAppListApi(Resource):
    """Console endpoint listing the apps joined to a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    # The handler returns app model objects, so the response must be
    # serialised through the imported ``related_app_list`` fields.
    @marshal_with(related_app_list)
    def get(self, dataset_id):
        """Return the apps related to a dataset.

        Args:
            dataset_id: Dataset identifier from the URL.

        Returns:
            A ``({"data": apps, "total": count}, 200)`` tuple; joins whose
            ``app`` is falsy are skipped.

        Raises:
            NotFound: If the dataset does not exist.
            Forbidden: If the permission check rejects the current user.
        """
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        app_dataset_joins = DatasetService.get_related_apps(dataset.id)
        related_apps = [join.app for join in app_dataset_joins if join.app]
        return {"data": related_apps, "total": len(related_apps)}, 200
class DatasetIndexingStatusApi(Resource):
    """Console endpoint reporting per-document indexing progress for a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return the indexing status of every document in the dataset.

        Args:
            dataset_id: Dataset identifier from the URL.

        Returns:
            ``{"data": [...]}`` with one marshalled status entry per document
            of the current tenant, each carrying ``completed_segments`` and
            ``total_segments`` counts.
        """
        dataset_id = str(dataset_id)
        documents = (
            db.session.query(Document)
            .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
            .all()
        )

        documents_status = []
        for document in documents:
            document_id = str(document.id)
            # Segments with status "re_segment" are excluded from both counts.
            completed_segments = DocumentSegment.query.filter(
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.document_id == document_id,
                DocumentSegment.status != "re_segment",
            ).count()
            total_segments = DocumentSegment.query.filter(
                DocumentSegment.document_id == document_id, DocumentSegment.status != "re_segment"
            ).count()
            # The counts are attached to the document so marshal() can read them.
            document.completed_segments = completed_segments
            document.total_segments = total_segments
            documents_status.append(marshal(document, document_status_fields))

        return {"data": documents_status}
class DatasetApiKeyApi(Resource):
    """Console endpoints for listing and creating dataset API keys."""

    # Maximum number of dataset API keys allowed per tenant.
    max_keys = 10
    # Prefix passed to ApiToken.generate_api_key for new keys.
    token_prefix = "dataset-"
    # Value stored in, and filtered on, ApiToken.type.
    resource_type = "dataset"

    @setup_required
    @login_required
    @account_initialization_required
    # The handler returns ApiToken rows; serialise them with ``api_key_list``.
    @marshal_with(api_key_list)
    def get(self):
        """Return all dataset API keys of the current tenant as ``{"items": [...]}``."""
        keys = (
            db.session.query(ApiToken)
            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
            .all()
        )
        return {"items": keys}

    @setup_required
    @login_required
    @account_initialization_required
    # The handler returns an ApiToken row; serialise it with ``api_key_fields``.
    @marshal_with(api_key_fields)
    def post(self):
        """Create a new dataset API key for the current tenant.

        Returns:
            A ``(api_token, 200)`` tuple with the newly stored token.

        Raises:
            Forbidden: If the current user is not an admin or owner.

        Aborts with HTTP 400 (``code="max_keys_exceeded"``) when the tenant
        already has ``max_keys`` keys.
        """
        # Only admins and owners may create API keys.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        current_key_count = (
            db.session.query(ApiToken)
            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
            .count()
        )
        if current_key_count >= self.max_keys:
            flask_restful.abort(
                400,
                message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
                code="max_keys_exceeded",
            )

        key = ApiToken.generate_api_key(self.token_prefix, 24)
        api_token = ApiToken()
        api_token.tenant_id = current_user.current_tenant_id
        api_token.token = key
        api_token.type = self.resource_type
        db.session.add(api_token)
        db.session.commit()
        return api_token, 200
class DatasetApiDeleteApi(Resource):
    """Console endpoint for deleting a dataset API key."""

    # Value filtered on in ApiToken.type.
    resource_type = "dataset"

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, api_key_id):
        """Delete one dataset API key belonging to the current tenant.

        Args:
            api_key_id: API key identifier from the URL.

        Returns:
            A ``({"result": "success"}, 204)`` tuple on success.

        Raises:
            Forbidden: If the current user is not an admin or owner.

        Aborts with HTTP 404 when no matching key exists for the tenant.
        """
        api_key_id = str(api_key_id)

        # Only admins and owners may delete API keys.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        # Scope the lookup to the tenant and resource type before deleting, so
        # a key of another tenant is reported as not found.
        key = (
            db.session.query(ApiToken)
            .filter(
                ApiToken.tenant_id == current_user.current_tenant_id,
                ApiToken.type == self.resource_type,
                ApiToken.id == api_key_id,
            )
            .first()
        )
        if key is None:
            flask_restful.abort(404, message="API key not found")

        db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
        db.session.commit()
        return {"result": "success"}, 204
class DatasetApiBaseUrlApi(Resource):
    """Console endpoint exposing the base URL of the service API."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return ``{"api_base_url": "<base>/v1"}``.

        The base is ``dify_config.SERVICE_API_URL`` when set, otherwise the
        host URL of the current request.
        """
        base_url = dify_config.SERVICE_API_URL or request.host_url
        # Strip trailing slashes from either source so a configured URL that
        # ends in "/" does not produce "//v1".
        return {"api_base_url": base_url.rstrip("/") + "/v1"}
class DatasetRetrievalSettingApi(Resource): | |
def get(self): | |
vector_type = dify_config.VECTOR_STORE | |
match vector_type: | |
case ( | |
VectorType.MILVUS | |
| VectorType.RELYT | |
| VectorType.PGVECTOR | |
| VectorType.TIDB_VECTOR | |
| VectorType.CHROMA | |
| VectorType.TENCENT | |
| VectorType.PGVECTO_RS | |
| VectorType.BAIDU | |
| VectorType.VIKINGDB | |
| VectorType.UPSTASH | |
| VectorType.OCEANBASE | |
): | |
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} | |
case ( | |
VectorType.QDRANT | |
| VectorType.WEAVIATE | |
| VectorType.OPENSEARCH | |
| VectorType.ANALYTICDB | |
| VectorType.MYSCALE | |
| VectorType.ORACLE | |
| VectorType.ELASTICSEARCH | |
| VectorType.PGVECTOR | |
| VectorType.TIDB_ON_QDRANT | |
| VectorType.LINDORM | |
| VectorType.COUCHBASE | |
): | |
return { | |
"retrieval_method": [ | |
RetrievalMethod.SEMANTIC_SEARCH.value, | |
RetrievalMethod.FULL_TEXT_SEARCH.value, | |
RetrievalMethod.HYBRID_SEARCH.value, | |
] | |
} | |
case _: | |
raise ValueError(f"Unsupported vector db type {vector_type}.") | |
class DatasetRetrievalSettingMockApi(Resource): | |
def get(self, vector_type): | |
match vector_type: | |
case ( | |
VectorType.MILVUS | |
| VectorType.RELYT | |
| VectorType.TIDB_VECTOR | |
| VectorType.CHROMA | |
| VectorType.TENCENT | |
| VectorType.PGVECTO_RS | |
| VectorType.BAIDU | |
| VectorType.VIKINGDB | |
| VectorType.UPSTASH | |
| VectorType.OCEANBASE | |
): | |
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} | |
case ( | |
VectorType.QDRANT | |
| VectorType.WEAVIATE | |
| VectorType.OPENSEARCH | |
| VectorType.ANALYTICDB | |
| VectorType.MYSCALE | |
| VectorType.ORACLE | |
| VectorType.ELASTICSEARCH | |
| VectorType.COUCHBASE | |
| VectorType.PGVECTOR | |
| VectorType.LINDORM | |
): | |
return { | |
"retrieval_method": [ | |
RetrievalMethod.SEMANTIC_SEARCH.value, | |
RetrievalMethod.FULL_TEXT_SEARCH.value, | |
RetrievalMethod.HYBRID_SEARCH.value, | |
] | |
} | |
case _: | |
raise ValueError(f"Unsupported vector db type {vector_type}.") | |
class DatasetErrorDocs(Resource):
    """Console endpoint listing the error documents of a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return the dataset's error documents.

        Args:
            dataset_id: Dataset identifier from the URL.

        Returns:
            A ``({"data": [...], "total": count}, 200)`` tuple with the
            marshalled documents returned by the document service.

        Raises:
            NotFound: If the dataset does not exist.
        """
        dataset_id_str = str(dataset_id)
        if DatasetService.get_dataset(dataset_id_str) is None:
            raise NotFound("Dataset not found.")

        results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
        documents = [marshal(item, document_status_fields) for item in results]
        return {"data": documents, "total": len(results)}, 200
class DatasetPermissionUserListApi(Resource):
    """Console endpoint returning the partial member list of a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return the dataset's partial member list.

        Args:
            dataset_id: Dataset identifier from the URL.

        Returns:
            A ``({"data": members}, 200)`` tuple.

        Raises:
            NotFound: If the dataset does not exist.
            Forbidden: If the permission check rejects the current user.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
        return {"data": partial_members_list}, 200
# Route table for the dataset console endpoints. Entries are registered in
# the order listed.
_DATASET_ROUTES = (
    (DatasetListApi, "/datasets"),
    (DatasetApi, "/datasets/<uuid:dataset_id>"),
    (DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check"),
    (DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries"),
    (DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs"),
    (DatasetIndexingEstimateApi, "/datasets/indexing-estimate"),
    (DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps"),
    (DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status"),
    (DatasetApiKeyApi, "/datasets/api-keys"),
    (DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>"),
    (DatasetApiBaseUrlApi, "/datasets/api-base-info"),
    (DatasetRetrievalSettingApi, "/datasets/retrieval-setting"),
    (DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>"),
    (DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users"),
)

for _resource_cls, _url_rule in _DATASET_ROUTES:
    api.add_resource(_resource_cls, _url_rule)