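# Spaces: Build error
# (The line above appears to be page-header residue from the listing this file was copied from; it is not part of the module.)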
import datetime
import json
import logging
from json import JSONDecodeError
from typing import Optional

from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
    ModelCredentialSchema,
    ProviderCredentialSchema,
)
from core.model_runtime.model_providers import model_provider_factory
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from models.provider import LoadBalancingModelConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
    """
    Service for managing model load balancing of a workspace.

    It toggles load balancing for a provider model and reads, creates, updates,
    deletes and validates the ``LoadBalancingModelConfig`` rows that hold the
    (encrypted) credentials of each load balancing entry.
    """

    def __init__(self) -> None:
        self.provider_manager = ProviderManager()

    def _get_provider_configuration(self, tenant_id: str, provider: str) -> ProviderConfiguration:
        """
        Look up the provider configuration of a workspace.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: the provider configuration registered under ``provider``
        :raises ValueError: if the workspace has no configuration for the provider
        """
        # Get all provider configurations of the current workspace, then pick the provider
        provider_configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")

        return provider_configuration

    @staticmethod
    def _load_stored_credentials(load_balancing_model_config: LoadBalancingModelConfig) -> dict:
        """
        Parse the JSON credentials stored on a load balancing config.

        :param load_balancing_model_config: load balancing model config
        :return: parsed ``encrypted_config``; an empty dict when it is empty or not valid JSON
        """
        if not load_balancing_model_config.encrypted_config:
            return {}

        try:
            return json.loads(load_balancing_model_config.encrypted_config)
        except JSONDecodeError:
            # Corrupt stored data is treated the same as "no credentials stored"
            return {}

    def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
        """
        enable model load balancing.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return:
        :raises ValueError: if the provider does not exist
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Enable model load balancing
        provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))

    def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
        """
        disable model load balancing.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return:
        :raises ValueError: if the provider does not exist
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # disable model load balancing
        provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))

    def get_load_balancing_configs(
        self, tenant_id: str, provider: str, model: str, model_type: str
    ) -> tuple[bool, list[dict]]:
        """
        Get load balancing configurations.

        When the provider has a custom provider configuration and no ``__inherit__``
        entry exists yet for the model, one is created (and committed) as a side effect.
        The ``__inherit__`` entry is always returned first.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return: ``(is_load_balancing_enabled, configs)`` where each config is a dict with the keys
            ``id``, ``name``, ``credentials`` (obfuscated), ``enabled``, ``in_cooldown`` and ``ttl``
        :raises ValueError: if the provider does not exist
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)

        # Get provider model setting
        provider_model_setting = provider_configuration.get_provider_model_setting(
            model_type=model_type_enum,
            model=model,
        )
        is_load_balancing_enabled = bool(provider_model_setting and provider_model_setting.load_balancing_enabled)

        # Get load balancing configurations
        load_balancing_configs = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .order_by(LoadBalancingModelConfig.created_at)
            .all()
        )

        if provider_configuration.custom_configuration.provider:
            # check if the inherit configuration exists,
            # inherit is represented for the provider or model custom credentials
            inherit_config_exists = any(
                load_balancing_config.name == "__inherit__" for load_balancing_config in load_balancing_configs
            )

            if not inherit_config_exists:
                # Initialize the inherit configuration under the same provider name the query above
                # filters on, so that the next call finds it again.
                inherit_config = self._init_inherit_config(
                    tenant_id, provider_configuration.provider.provider, model, model_type_enum
                )

                # prepend the inherit configuration
                load_balancing_configs.insert(0, inherit_config)
            else:
                # move the inherit configuration to the first
                # (iterate over a copy; pop + insert(0) keeps later indexes aligned with the copy)
                for i, load_balancing_config in enumerate(load_balancing_configs[:]):
                    if load_balancing_config.name == "__inherit__":
                        inherit_config = load_balancing_configs.pop(i)
                        load_balancing_configs.insert(0, inherit_config)

        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Get decoding rsa key and cipher for decrypting credentials
        decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)

        # Get provider credential secret variables (same for every config, so computed once)
        credential_secret_variables = provider_configuration.extract_secret_variables(
            credential_schemas.credential_form_schemas
        )

        # fetch status and ttl for each config
        datas = []
        for load_balancing_config in load_balancing_configs:
            in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
                tenant_id=tenant_id,
                provider=provider,
                model=model,
                model_type=model_type_enum,
                config_id=load_balancing_config.id,
            )

            credentials = self._load_stored_credentials(load_balancing_config)

            # decrypt credentials
            for variable in credential_secret_variables:
                if variable in credentials:
                    try:
                        credentials[variable] = encrypter.decrypt_token_with_decoding(
                            credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
                        )
                    except ValueError:
                        # best effort: keep the stored value as-is when it cannot be decrypted
                        pass

            # Obfuscate credentials
            credentials = provider_configuration.obfuscated_credentials(
                credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
            )

            datas.append(
                {
                    "id": load_balancing_config.id,
                    "name": load_balancing_config.name,
                    "credentials": credentials,
                    "enabled": load_balancing_config.enabled,
                    "in_cooldown": in_cooldown,
                    "ttl": ttl,
                }
            )

        return is_load_balancing_enabled, datas

    def get_load_balancing_config(
        self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
    ) -> Optional[dict]:
        """
        Get load balancing configuration.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param config_id: load balancing config id
        :return: dict with the keys ``id``, ``name``, ``credentials`` (obfuscated) and ``enabled``,
            or None if no matching config exists
        :raises ValueError: if the provider does not exist
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)

        # Get load balancing configurations
        load_balancing_model_config = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
                LoadBalancingModelConfig.id == config_id,
            )
            .first()
        )

        if not load_balancing_model_config:
            return None

        # Note: unlike get_load_balancing_configs, the stored values are obfuscated without decrypting first
        credentials = self._load_stored_credentials(load_balancing_model_config)

        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Obfuscate credentials
        credentials = provider_configuration.obfuscated_credentials(
            credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
        )

        return {
            "id": load_balancing_model_config.id,
            "name": load_balancing_model_config.name,
            "credentials": credentials,
            "enabled": load_balancing_model_config.enabled,
        }

    def _init_inherit_config(
        self, tenant_id: str, provider: str, model: str, model_type: ModelType
    ) -> LoadBalancingModelConfig:
        """
        Initialize the inherit configuration.

        Creates and commits a ``LoadBalancingModelConfig`` row named ``__inherit__``.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return: the newly created inherit configuration
        """
        # Initialize the inherit configuration
        inherit_config = LoadBalancingModelConfig(
            tenant_id=tenant_id,
            provider_name=provider,
            model_type=model_type.to_origin_model_type(),
            model_name=model,
            name="__inherit__",
        )
        db.session.add(inherit_config)
        db.session.commit()

        return inherit_config

    def update_load_balancing_configs(
        self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
    ) -> None:
        """
        Update load balancing configurations.

        Entries carrying an ``id`` update the existing config, entries without an ``id``
        create a new one, and existing configs whose id is not present in ``configs`` are deleted.
        Each change is committed as it is processed.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param configs: load balancing configs; each is a dict with ``name`` and ``enabled``
            and optionally ``id`` and ``credentials`` (``credentials`` is required on create)
        :return:
        :raises ValueError: if the provider does not exist or a config entry is invalid
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)

        if not isinstance(configs, list):
            raise ValueError("Invalid load balancing configs")

        current_load_balancing_configs = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .all()
        )

        # id as key, config as value
        current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
        updated_config_ids = set()

        for config in configs:
            if not isinstance(config, dict):
                raise ValueError("Invalid load balancing config")

            config_id = config.get("id")
            name = config.get("name")
            credentials = config.get("credentials")
            enabled = config.get("enabled")

            if not name:
                raise ValueError("Invalid load balancing config name")

            if enabled is None:
                raise ValueError("Invalid load balancing config enabled")

            # is config exists
            if config_id:
                config_id = str(config_id)

                if config_id not in current_load_balancing_configs_dict:
                    raise ValueError(f"Invalid load balancing config id: {config_id}")

                updated_config_ids.add(config_id)

                load_balancing_config = current_load_balancing_configs_dict[config_id]

                # check duplicate name (against every other stored config)
                if any(
                    existing_config.id != config_id and existing_config.name == name
                    for existing_config in current_load_balancing_configs
                ):
                    raise ValueError(f"Load balancing config name {name} already exists")

                # credentials are optional on update; when omitted the stored ones are kept
                if credentials:
                    if not isinstance(credentials, dict):
                        raise ValueError("Invalid load balancing config credentials")

                    # validate custom provider config
                    credentials = self._custom_credentials_validate(
                        tenant_id=tenant_id,
                        provider_configuration=provider_configuration,
                        model_type=model_type_enum,
                        model=model,
                        credentials=credentials,
                        load_balancing_model_config=load_balancing_config,
                        validate=False,
                    )

                    # update load balancing config
                    load_balancing_config.encrypted_config = json.dumps(credentials)

                load_balancing_config.name = name
                load_balancing_config.enabled = enabled
                # stored as a naive datetime holding the current UTC time
                load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()

                self._clear_credentials_cache(tenant_id, config_id)
            else:
                # create load balancing config
                if name == "__inherit__":
                    raise ValueError("Invalid load balancing config name")

                # check duplicate name
                if any(existing_config.name == name for existing_config in current_load_balancing_configs):
                    raise ValueError(f"Load balancing config name {name} already exists")

                if not credentials or not isinstance(credentials, dict):
                    raise ValueError("Invalid load balancing config credentials")

                # validate custom provider config
                credentials = self._custom_credentials_validate(
                    tenant_id=tenant_id,
                    provider_configuration=provider_configuration,
                    model_type=model_type_enum,
                    model=model,
                    credentials=credentials,
                    validate=False,
                )

                # create load balancing config
                # NOTE(review): the validated "enabled" flag is not passed to the new row here,
                # so it takes whatever default the model defines -- confirm this is intended.
                load_balancing_model_config = LoadBalancingModelConfig(
                    tenant_id=tenant_id,
                    provider_name=provider_configuration.provider.provider,
                    model_type=model_type_enum.to_origin_model_type(),
                    model_name=model,
                    name=name,
                    encrypted_config=json.dumps(credentials),
                )

                db.session.add(load_balancing_model_config)
                db.session.commit()

        # get deleted config ids: every stored config that was not referenced by id in the request
        deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids
        for deleted_config_id in deleted_config_ids:
            db.session.delete(current_load_balancing_configs_dict[deleted_config_id])
            db.session.commit()

            self._clear_credentials_cache(tenant_id, deleted_config_id)

    def validate_load_balancing_credentials(
        self,
        tenant_id: str,
        provider: str,
        model: str,
        model_type: str,
        credentials: dict,
        config_id: Optional[str] = None,
    ) -> None:
        """
        Validate load balancing credentials.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param credentials: credentials
        :param config_id: load balancing config id
        :return:
        :raises ValueError: if the provider does not exist, or ``config_id`` is given
            but no matching load balancing config exists
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)

        load_balancing_model_config = None
        if config_id:
            # Get load balancing config
            # (filtered on the same provider name the other queries and inserts of this service use)
            load_balancing_model_config = (
                db.session.query(LoadBalancingModelConfig)
                .filter(
                    LoadBalancingModelConfig.tenant_id == tenant_id,
                    LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                    LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                    LoadBalancingModelConfig.model_name == model,
                    LoadBalancingModelConfig.id == config_id,
                )
                .first()
            )

            if not load_balancing_model_config:
                raise ValueError(f"Load balancing config {config_id} does not exist.")

        # Validate custom provider config
        self._custom_credentials_validate(
            tenant_id=tenant_id,
            provider_configuration=provider_configuration,
            model_type=model_type_enum,
            model=model,
            credentials=credentials,
            load_balancing_model_config=load_balancing_model_config,
        )

    def _custom_credentials_validate(
        self,
        tenant_id: str,
        provider_configuration: ProviderConfiguration,
        model_type: ModelType,
        model: str,
        credentials: dict,
        load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
        validate: bool = True,
    ) -> dict:
        """
        Validate custom credentials.

        Secret values sent as ``HIDDEN_VALUE`` are replaced by the decrypted stored value of
        ``load_balancing_model_config`` (when given), the credentials are optionally validated,
        and every secret variable is encrypted in the returned dict.
        The ``credentials`` argument itself is not modified.

        :param tenant_id: workspace id
        :param provider_configuration: provider configuration
        :param model_type: model type
        :param model: model name
        :param credentials: credentials
        :param load_balancing_model_config: load balancing model config
        :param validate: validate credentials
        :return: credentials with the secret variables encrypted
        """
        # Work on a shallow copy so the caller's dict is never mutated
        credentials = dict(credentials)

        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Get provider credential secret variables
        provider_credential_secret_variables = provider_configuration.extract_secret_variables(
            credential_schemas.credential_form_schemas
        )

        if load_balancing_model_config:
            # fix origin data
            original_credentials = self._load_stored_credentials(load_balancing_model_config)

            # restore hidden secrets
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    # if send [__HIDDEN__] in secret input, it will be same as original value
                    if value == HIDDEN_VALUE and key in original_credentials:
                        credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])

        if validate:
            if isinstance(credential_schemas, ModelCredentialSchema):
                credentials = model_provider_factory.model_credentials_validate(
                    provider=provider_configuration.provider.provider,
                    model_type=model_type,
                    model=model,
                    credentials=credentials,
                )
            else:
                credentials = model_provider_factory.provider_credentials_validate(
                    provider=provider_configuration.provider.provider, credentials=credentials
                )

        # encrypt credentials
        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(tenant_id, value)

        return credentials

    def _get_credential_schema(
        self, provider_configuration: ProviderConfiguration
    ) -> ModelCredentialSchema | ProviderCredentialSchema:
        """
        Get form schemas.

        :param provider_configuration: provider configuration
        :return: the provider's model credential schema when it is set,
            otherwise its provider credential schema
        """
        # Get credential form schemas from model credential schema or provider credential schema
        provider_entity = provider_configuration.provider
        return provider_entity.model_credential_schema or provider_entity.provider_credential_schema

    def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
        """
        Clear credentials cache.

        :param tenant_id: workspace id
        :param config_id: load balancing config id
        :return:
        """
        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
        )

        provider_model_credentials_cache.delete()