# Spaces: Build error
import datetime | |
import json | |
import logging | |
from collections import defaultdict | |
from collections.abc import Iterator | |
from json import JSONDecodeError | |
from typing import Optional | |
from pydantic import BaseModel, ConfigDict | |
from constants import HIDDEN_VALUE | |
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | |
from core.entities.provider_entities import ( | |
CustomConfiguration, | |
ModelSettings, | |
SystemConfiguration, | |
SystemConfigurationStatus, | |
) | |
from core.helper import encrypter | |
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | |
from core.model_runtime.entities.model_entities import FetchFrom, ModelType | |
from core.model_runtime.entities.provider_entities import ( | |
ConfigurateMethod, | |
CredentialFormSchema, | |
FormType, | |
ProviderEntity, | |
) | |
from core.model_runtime.model_providers import model_provider_factory | |
from core.model_runtime.model_providers.__base.ai_model import AIModel | |
from core.model_runtime.model_providers.__base.model_provider import ModelProvider | |
from extensions.ext_database import db | |
from models.provider import ( | |
LoadBalancingModelConfig, | |
Provider, | |
ProviderModel, | |
ProviderModelSetting, | |
ProviderType, | |
TenantPreferredModelProvider, | |
) | |
logger = logging.getLogger(name=__name__)

# Keyed by provider name: the configurate methods a provider had when it was first seen by
# ProviderConfiguration, recorded before anything is appended to them (first write wins).
original_provider_configurate_methods: "dict[str, list[ConfigurateMethod]]" = {}
class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.

    Holds one tenant's view of one model provider:
    - the provider schema;
    - the preferred and the actually used provider type (system vs. custom);
    - the system and custom configurations;
    - per-model settings.

    Several methods read from and write to the database through ``db.session``. They also
    delete ``ProviderCredentialsCache`` entries after credentials change.
    """

    tenant_id: str
    provider: ProviderEntity
    preferred_provider_type: ProviderType
    using_provider_type: ProviderType
    system_configuration: SystemConfiguration
    custom_configuration: CustomConfiguration
    model_settings: list[ModelSettings]

    # pydantic configs: the ``model_settings`` field starts with pydantic's reserved ``model_``
    # prefix, so protected namespaces are disabled to allow it.
    model_config = ConfigDict(protected_namespaces=())

    def __init__(self, **data):
        super().__init__(**data)

        # Record the provider's configurate methods the first time this provider is seen, before
        # the append below can change them. NOTE(review): the ProviderEntity is presumably shared
        # between configurations, which is why the record is module-level; confirm with callers.
        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = list(self.provider.configurate_methods)

        # A provider that originally supports only customizable models, but whose system quota
        # restricts models, also exposes the restricted models as predefined ones.
        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
            if (
                any(
                    len(quota_configuration.restrict_models) > 0
                    for quota_configuration in self.system_configuration.quota_configurations
                )
                and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
            ):
                self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
        """
        Get current credentials.

        :param model_type: model type
        :param model: model name
        :return: for the system provider type, a copy of the system credentials. ``base_model_name``
            is set in that copy when a restricted model of the current quota maps to a base model.
            For the custom provider type, the model-level credentials if configured, otherwise the
            provider-level credentials, otherwise None.
        :raises ValueError: if the model is disabled in the model settings
        """
        # check if model is disabled by admin
        for model_setting in self.model_settings:
            if model_setting.model_type == model_type and model_setting.model == model and not model_setting.enabled:
                raise ValueError(f"Model {model} is disabled.")

        if self.using_provider_type == ProviderType.SYSTEM:
            # Restricted models of the quota configuration matching the current quota type
            # (the last match wins if several share that type).
            restrict_models = []
            for quota_configuration in self.system_configuration.quota_configurations:
                if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                    continue
                restrict_models = quota_configuration.restrict_models

            copy_credentials = self.system_configuration.credentials.copy()
            for restrict_model in restrict_models:
                if (
                    restrict_model.model_type == model_type
                    and restrict_model.model == model
                    and restrict_model.base_model_name
                ):
                    copy_credentials["base_model_name"] = restrict_model.base_model_name
            return copy_credentials

        credentials = None
        if self.custom_configuration.models:
            for model_configuration in self.custom_configuration.models:
                if model_configuration.model_type == model_type and model_configuration.model == model:
                    credentials = model_configuration.credentials
                    break

        # Fall back to provider-level credentials when no model-level credentials exist.
        if not credentials and self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        return credentials

    def get_system_configuration_status(self) -> SystemConfigurationStatus:
        """
        Get system configuration status.

        :return: UNSUPPORTED when the system configuration is disabled or no quota configuration
            matches the current quota type; otherwise ACTIVE or QUOTA_EXCEEDED depending on whether
            the current quota configuration is valid.
        """
        if self.system_configuration.enabled is False:
            return SystemConfigurationStatus.UNSUPPORTED

        current_quota_type = self.system_configuration.current_quota_type
        current_quota_configuration = next(
            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
        )
        if current_quota_configuration is None:
            # Previously this dereferenced None and raised AttributeError.
            return SystemConfigurationStatus.UNSUPPORTED

        return (
            SystemConfigurationStatus.ACTIVE
            if current_quota_configuration.is_valid
            else SystemConfigurationStatus.QUOTA_EXCEEDED
        )

    def is_custom_configuration_available(self) -> bool:
        """
        Check custom configuration available.

        :return: True if provider-level custom credentials or at least one custom model are configured
        """
        return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0

    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Get custom credentials.

        :param obfuscated: obfuscated secret data in credentials
        :return: provider-level custom credentials, or None if none are configured
        """
        if self.custom_configuration.provider is None:
            return None

        credentials = self.custom_configuration.provider.credentials
        if not obfuscated:
            return credentials

        # Obfuscate credentials
        return self.obfuscated_credentials(
            credentials=credentials,
            credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema
            else [],
        )

    def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
        """
        Validate custom credentials.

        If a stored record exists, secret values sent as ``HIDDEN_VALUE`` are first replaced with
        the stored (decrypted) originals. Then every secret value of the validated result is
        encrypted. The caller's ``credentials`` dict is not modified.

        :param credentials: provider credentials
        :return: the existing custom provider record (or None) and the validated, encrypted credentials
        """
        provider_record = self._get_custom_provider_record()

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema
            else []
        )

        # Work on a copy so restoring hidden secrets does not mutate the caller's dict.
        credentials = dict(credentials)

        if provider_record:
            original_credentials = self._load_original_provider_credentials(provider_record.encrypted_config)
            self._restore_hidden_secrets(credentials, provider_credential_secret_variables, original_credentials)

        credentials = model_provider_factory.provider_credentials_validate(
            provider=self.provider.provider, credentials=credentials
        )

        self._encrypt_secrets(credentials, provider_credential_secret_variables)

        return provider_record, credentials

    def add_or_update_custom_credentials(self, credentials: dict) -> None:
        """
        Add or update custom provider credentials.

        Validates and encrypts the credentials, then upserts the custom ``Provider`` record. It
        also deletes the provider credentials cache entry and switches the preferred provider type
        to CUSTOM.

        :param credentials: provider credentials
        :return:
        """
        # validate custom provider config
        provider_record, credentials = self.custom_credentials_validate(credentials)

        # save provider
        if provider_record:
            provider_record.encrypted_config = json.dumps(credentials)
            provider_record.is_valid = True
            provider_record.updated_at = self._naive_utc_now()
        else:
            provider_record = Provider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_record)
        db.session.commit()

        self._delete_credentials_cache(provider_record.id, ProviderCredentialsCacheType.PROVIDER)

        # Saving custom provider credentials makes the custom provider type the preferred one.
        self.switch_preferred_provider_type(ProviderType.CUSTOM)

    def delete_custom_credentials(self) -> None:
        """
        Delete custom provider credentials.

        If a custom ``Provider`` record exists, the preferred provider type is switched to SYSTEM.
        The record is then deleted and its credentials cache entry removed.

        :return:
        """
        provider_record = self._get_custom_provider_record()
        if not provider_record:
            return

        self.switch_preferred_provider_type(ProviderType.SYSTEM)

        # Capture the id before the row is deleted and the session committed.
        provider_record_id = provider_record.id
        db.session.delete(provider_record)
        db.session.commit()

        self._delete_credentials_cache(provider_record_id, ProviderCredentialsCacheType.PROVIDER)

    def get_custom_model_credentials(
        self, model_type: ModelType, model: str, obfuscated: bool = False
    ) -> Optional[dict]:
        """
        Get custom model credentials.

        :param model_type: model type
        :param model: model name
        :param obfuscated: obfuscated secret data in credentials
        :return: credentials of the matching custom model, or None if it is not configured
        """
        if not self.custom_configuration.models:
            return None

        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                if not obfuscated:
                    return credentials

                # Obfuscate credentials
                return self.obfuscated_credentials(
                    credentials=credentials,
                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                    if self.provider.model_credential_schema
                    else [],
                )

        return None

    def custom_model_credentials_validate(
        self, model_type: ModelType, model: str, credentials: dict
    ) -> tuple[Optional[ProviderModel], dict]:
        """
        Validate custom model credentials.

        If a stored record exists, secret values sent as ``HIDDEN_VALUE`` are first replaced with
        the stored (decrypted) originals. Then every secret value of the validated result is
        encrypted. The caller's ``credentials`` dict is not modified.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return: the existing provider model record (or None) and the validated, encrypted credentials
        """
        provider_model_record = self._get_custom_provider_model_record(model_type, model)

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.model_credential_schema.credential_form_schemas
            if self.provider.model_credential_schema
            else []
        )

        # Work on a copy so restoring hidden secrets does not mutate the caller's dict.
        credentials = dict(credentials)

        if provider_model_record:
            try:
                original_credentials = (
                    json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
                )
            except JSONDecodeError:
                original_credentials = {}

            self._restore_hidden_secrets(credentials, provider_credential_secret_variables, original_credentials)

        credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
        )

        self._encrypt_secrets(credentials, provider_credential_secret_variables)

        return provider_model_record, credentials

    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
        """
        Add or update custom model credentials.

        Validates and encrypts the credentials, then upserts the ``ProviderModel`` record. It also
        deletes the model credentials cache entry.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        # validate custom model config
        provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)

        # save provider model
        # Note: Do not switch the preferred provider, which allows users to use quotas first
        if provider_model_record:
            provider_model_record.encrypted_config = json.dumps(credentials)
            provider_model_record.is_valid = True
            provider_model_record.updated_at = self._naive_utc_now()
        else:
            provider_model_record = ProviderModel(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=model_type.to_origin_model_type(),
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_model_record)
        db.session.commit()

        self._delete_credentials_cache(provider_model_record.id, ProviderCredentialsCacheType.MODEL)

    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
        """
        Delete custom model credentials.

        :param model_type: model type
        :param model: model name
        :return:
        """
        provider_model_record = self._get_custom_provider_model_record(model_type, model)
        if not provider_model_record:
            return

        # Capture the id before the row is deleted and the session committed.
        provider_model_record_id = provider_model_record.id
        db.session.delete(provider_model_record)
        db.session.commit()

        self._delete_credentials_cache(provider_model_record_id, ProviderCredentialsCacheType.MODEL)

    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model.

        :param model_type: model type
        :param model: model name
        :return: the saved provider model setting
        """
        return self._save_provider_model_setting(model_type, model, enabled=True)

    def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model.

        :param model_type: model type
        :param model: model name
        :return: the saved provider model setting
        """
        return self._save_provider_model_setting(model_type, model, enabled=False)

    def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
        """
        Get provider model setting.

        :param model_type: model type
        :param model: model name
        :return: this tenant's setting row for the model, or None if there is none
        """
        return (
            db.session.query(ProviderModelSetting)
            .filter(
                ProviderModelSetting.tenant_id == self.tenant_id,
                ProviderModelSetting.provider_name == self.provider.provider,
                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                ProviderModelSetting.model_name == model,
            )
            .first()
        )

    def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the saved provider model setting
        :raises ValueError: if the model has at most one load balancing configuration
        """
        load_balancing_config_count = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == self.tenant_id,
                LoadBalancingModelConfig.provider_name == self.provider.provider,
                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .count()
        )

        if load_balancing_config_count <= 1:
            raise ValueError("Model load balancing configuration must be more than 1.")

        return self._save_provider_model_setting(model_type, model, load_balancing_enabled=True)

    def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the saved provider model setting
        """
        return self._save_provider_model_setting(model_type, model, load_balancing_enabled=False)

    def get_provider_instance(self) -> ModelProvider:
        """
        Get provider instance.

        :return: the model provider instance from ``model_provider_factory``
        """
        return model_provider_factory.get_provider_instance(self.provider.provider)

    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
        """
        Get current model type instance.

        :param model_type: model type
        :return: the provider instance's model instance for ``model_type``
        """
        # Get provider instance
        provider_instance = self.get_provider_instance()

        # Get model instance of LLM
        return provider_instance.get_model_instance(model_type)

    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
        """
        Switch preferred provider type.

        Does nothing if ``provider_type`` already is the preferred type. It also does nothing when
        switching to SYSTEM while the system configuration is disabled.

        :param provider_type: provider type to prefer
        :return:
        """
        if provider_type == self.preferred_provider_type:
            return

        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
            return

        # get preferred provider
        preferred_model_provider = (
            db.session.query(TenantPreferredModelProvider)
            .filter(
                TenantPreferredModelProvider.tenant_id == self.tenant_id,
                TenantPreferredModelProvider.provider_name == self.provider.provider,
            )
            .first()
        )

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = provider_type.value
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                preferred_provider_type=provider_type.value,
            )
            db.session.add(preferred_model_provider)

        db.session.commit()

    def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
        """
        Extract secret input form variables.

        :param credential_form_schemas: credential form schemas
        :return: variable names of the schemas whose type is ``FormType.SECRET_INPUT``, in order
        """
        return [schema.variable for schema in credential_form_schemas if schema.type == FormType.SECRET_INPUT]

    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
        """
        Obfuscated credentials.

        :param credentials: credentials
        :param credential_form_schemas: credential form schemas
        :return: a new dict in which secret values are passed through ``encrypter.obfuscated_token``
        """
        # Get provider credential secret variables
        credential_secret_variables = self.extract_secret_variables(credential_form_schemas)

        # Obfuscate provider credentials without touching the input dict
        return {
            key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
            for key, value in credentials.items()
        }

    def get_provider_model(
        self, model_type: ModelType, model: str, only_active: bool = False
    ) -> Optional[ModelWithProviderEntity]:
        """
        Get provider model.

        :param model_type: model type
        :param model: model name
        :param only_active: return active model only
        :return: the first provider model named ``model``, or None
        """
        return next(
            (m for m in self.get_provider_models(model_type, only_active) if m.model == model),
            None,
        )

    def get_provider_models(
        self, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get provider models.

        :param model_type: model type; when omitted, all model types supported by the provider schema
        :param only_active: only active models
        :return: provider models for the provider type in use, sorted by model type value
        """
        provider_instance = self.get_provider_instance()

        if model_type:
            model_types = [model_type]
        else:
            model_types = provider_instance.get_provider_schema().supported_model_types

        # Group model settings by model type and model
        model_setting_map = defaultdict(dict)
        for model_setting in self.model_settings:
            model_setting_map[model_setting.model_type][model_setting.model] = model_setting

        if self.using_provider_type == ProviderType.SYSTEM:
            provider_models = self._get_system_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )
        else:
            provider_models = self._get_custom_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )

        if only_active:
            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

        # resort provider_models
        return sorted(provider_models, key=lambda x: x.model_type.value)

    def _get_system_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get system provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return: predefined models, plus quota-restricted models for customizable-only providers.
            When the current quota restricts models, statuses are downgraded: LLMs outside the
            restriction get NO_PERMISSION, and other models get QUOTA_EXCEEDED if the quota is invalid.
        """
        provider_models = []
        for model_type in model_types:
            for m in provider_instance.models(model_type):
                model_setting = self._lookup_model_setting(model_setting_map, m.model_type, m.model)
                status = (
                    ModelStatus.DISABLED
                    if model_setting is not None and model_setting.enabled is False
                    else ModelStatus.ACTIVE
                )

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                    )
                )

        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = list(
                provider_instance.get_provider_schema().configurate_methods
            )

        # Customizable-only providers build the quota's restricted models from system credentials.
        should_use_custom_model = original_provider_configurate_methods[self.provider.provider] == [
            ConfigurateMethod.CUSTOMIZABLE_MODEL
        ]

        for quota_configuration in self.system_configuration.quota_configurations:
            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                continue

            restrict_models = quota_configuration.restrict_models
            if len(restrict_models) == 0:
                break

            if should_use_custom_model:
                provider_models.extend(
                    self._get_restricted_custom_models(
                        restrict_models=restrict_models,
                        model_types=model_types,
                        provider_instance=provider_instance,
                        model_setting_map=model_setting_map,
                    )
                )

            # if llm name not in restricted llm list, remove it
            restrict_model_names = {rm.model for rm in restrict_models}
            for m in provider_models:
                if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                    m.status = ModelStatus.NO_PERMISSION
                elif not quota_configuration.is_valid:
                    m.status = ModelStatus.QUOTA_EXCEEDED

        return provider_models

    def _get_restricted_custom_models(
        self,
        restrict_models: list,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """Build entries for quota-restricted models from their customizable schemas and system credentials."""
        provider_models = []
        for restrict_model in restrict_models:
            copy_credentials = self.system_configuration.credentials.copy()
            if restrict_model.base_model_name:
                copy_credentials["base_model_name"] = restrict_model.base_model_name

            try:
                custom_model_schema = provider_instance.get_model_instance(
                    restrict_model.model_type
                ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
            except Exception as ex:
                # Best effort: a model whose schema cannot be built is skipped, not fatal.
                logger.warning("get custom model schema failed, %s", ex)
                continue

            if not custom_model_schema:
                continue

            if custom_model_schema.model_type not in model_types:
                continue

            model_setting = self._lookup_model_setting(
                model_setting_map, custom_model_schema.model_type, custom_model_schema.model
            )
            status = (
                ModelStatus.DISABLED
                if model_setting is not None and model_setting.enabled is False
                else ModelStatus.ACTIVE
            )

            provider_models.append(
                ModelWithProviderEntity(
                    model=custom_model_schema.model,
                    label=custom_model_schema.label,
                    model_type=custom_model_schema.model_type,
                    features=custom_model_schema.features,
                    fetch_from=FetchFrom.PREDEFINED_MODEL,
                    model_properties=custom_model_schema.model_properties,
                    deprecated=custom_model_schema.deprecated,
                    provider=SimpleModelProviderEntity(self.provider),
                    status=status,
                )
            )
        return provider_models

    def _get_custom_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get custom provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return: the predefined models first, then the custom models.
            Predefined models come only for model types the provider supports; they are
            NO_CONFIGURE when no provider-level credentials exist. Custom models are built from
            the model-level credentials. Load balancing is flagged when a model has more than one
            load balancing config.
        """
        provider_models = []

        credentials = None
        if self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        # predefined models
        for model_type in model_types:
            if model_type not in self.provider.supported_model_types:
                continue

            for m in provider_instance.models(model_type):
                status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
                load_balancing_enabled = False
                model_setting = self._lookup_model_setting(model_setting_map, m.model_type, m.model)
                if model_setting is not None:
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED
                    if len(model_setting.load_balancing_configs) > 1:
                        load_balancing_enabled = True

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                        load_balancing_enabled=load_balancing_enabled,
                    )
                )

        # custom models
        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type not in model_types:
                continue

            try:
                custom_model_schema = provider_instance.get_model_instance(
                    model_configuration.model_type
                ).get_customizable_model_schema_from_credentials(
                    model_configuration.model, model_configuration.credentials
                )
            except Exception as ex:
                # Best effort: a model whose schema cannot be built is skipped, not fatal.
                logger.warning("get custom model schema failed, %s", ex)
                continue

            if not custom_model_schema:
                continue

            status = ModelStatus.ACTIVE
            load_balancing_enabled = False
            model_setting = self._lookup_model_setting(
                model_setting_map, custom_model_schema.model_type, custom_model_schema.model
            )
            if model_setting is not None:
                if model_setting.enabled is False:
                    status = ModelStatus.DISABLED
                if len(model_setting.load_balancing_configs) > 1:
                    load_balancing_enabled = True

            provider_models.append(
                ModelWithProviderEntity(
                    model=custom_model_schema.model,
                    label=custom_model_schema.label,
                    model_type=custom_model_schema.model_type,
                    features=custom_model_schema.features,
                    fetch_from=custom_model_schema.fetch_from,
                    model_properties=custom_model_schema.model_properties,
                    deprecated=custom_model_schema.deprecated,
                    provider=SimpleModelProviderEntity(self.provider),
                    status=status,
                    load_balancing_enabled=load_balancing_enabled,
                )
            )

        return provider_models

    # ----- private helpers -----

    @staticmethod
    def _naive_utc_now() -> datetime.datetime:
        """Return the current UTC time without tzinfo, the form written to ``updated_at`` here."""
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _lookup_model_setting(
        model_setting_map: dict[ModelType, dict[str, ModelSettings]], model_type: ModelType, model: str
    ) -> Optional[ModelSettings]:
        """Return the settings for ``model`` of ``model_type``, or None; never inserts into a defaultdict."""
        return model_setting_map.get(model_type, {}).get(model)

    @staticmethod
    def _load_original_provider_credentials(encrypted_config: Optional[str]) -> dict:
        """Parse a stored provider ``encrypted_config`` into a dict ({} if empty or invalid JSON)."""
        if not encrypted_config:
            return {}
        # fix origin data: a config that is not a JSON object is treated as a bare OpenAI API key
        # (presumably a legacy storage format).
        if not encrypted_config.startswith("{"):
            return {"openai_api_key": encrypted_config}
        try:
            return json.loads(encrypted_config)
        except JSONDecodeError:
            return {}

    def _restore_hidden_secrets(
        self, credentials: dict, secret_variables: list[str], original_credentials: dict
    ) -> None:
        """In place, replace secret values sent as HIDDEN_VALUE with the decrypted stored originals."""
        for key, value in credentials.items():
            # if send [__HIDDEN__] in secret input, it will be same as original value
            if key in secret_variables and value == HIDDEN_VALUE and key in original_credentials:
                credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

    def _encrypt_secrets(self, credentials: dict, secret_variables: list[str]) -> None:
        """In place, encrypt every value of ``credentials`` whose key is a secret variable."""
        for key, value in credentials.items():
            if key in secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

    def _delete_credentials_cache(self, identity_id, cache_type: ProviderCredentialsCacheType) -> None:
        """Delete this tenant's cached credentials entry for ``identity_id`` of ``cache_type``."""
        ProviderCredentialsCache(tenant_id=self.tenant_id, identity_id=identity_id, cache_type=cache_type).delete()

    def _get_custom_provider_record(self) -> Optional[Provider]:
        """Return this tenant's CUSTOM ``Provider`` row for this provider, or None."""
        return (
            db.session.query(Provider)
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value,
            )
            .first()
        )

    def _get_custom_provider_model_record(self, model_type: ModelType, model: str) -> Optional[ProviderModel]:
        """Return this tenant's ``ProviderModel`` row for (model_type, model), or None."""
        return (
            db.session.query(ProviderModel)
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type(),
            )
            .first()
        )

    def _save_provider_model_setting(self, model_type: ModelType, model: str, **values) -> ProviderModelSetting:
        """Set ``values`` on the model's setting row (creating it if missing), commit, and return it."""
        model_setting = self.get_provider_model_setting(model_type, model)
        if model_setting:
            for column, value in values.items():
                setattr(model_setting, column, value)
            model_setting.updated_at = self._naive_utc_now()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                **values,
            )
            db.session.add(model_setting)
        db.session.commit()
        return model_setting
class ProviderConfigurations(BaseModel):
    """
    Dict-like container of one tenant's ``ProviderConfiguration`` objects.

    Entries are keyed by string (presumably the provider name — confirm with the code that
    fills it). Supports ``[]`` get/set, ``get``, iteration over keys, and ``values()``.
    """

    tenant_id: str
    configurations: dict[str, ProviderConfiguration] = {}

    def __init__(self, tenant_id: str):
        super().__init__(tenant_id=tenant_id)

    def get_models(
        self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Collect the models of every provider configuration (or of one provider).

        Preferred provider type ``system``:
            Use the current **system mode** when the provider supports it. If no system mode is
            available (no quota), fall back to the **custom credential mode**. If nothing is
            configured in custom mode either, the models are treated as no_configure.
            Precedence: system > custom > no_configure.

        Preferred provider type ``custom``:
            Use custom mode when custom credentials are configured. Otherwise use the current
            **system mode** if supported. If no system mode is available (no quota), the models
            are treated as no_configure.
            Precedence: custom > system > no_configure.

        Effective mode ``system``:
            Models are obtained with system credentials.
            Quota order: paid quotas > provider free quotas > system free quotas.
            Pre-defined models are included, except GPT-4, whose status is ``no_permission``.

        Effective mode ``custom``:
            Models are obtained with the workspace's custom credentials. Pre-defined models and
            manually added custom models are included.

        Effective mode ``no_configure``:
            Only the pre-defined models from ``model runtime`` are returned. Their status is
            ``no_configure`` when the preferred type is ``custom``, otherwise ``quota_exceeded``.

        Only models whose status is ``active`` are usable.

        :param provider: restrict to this provider name; falsy means all providers
        :param model_type: model type
        :param only_active: only active models
        :return: the concatenated model lists, in configuration order
        """
        return [
            provider_model
            for configuration in self.values()
            if not provider or configuration.provider.provider == provider
            for provider_model in configuration.get_provider_models(model_type, only_active)
        ]

    def to_list(self) -> list[ProviderConfiguration]:
        """
        Return all provider configurations as a list.

        :return: list of the stored configurations
        """
        return [*self.values()]

    def __getitem__(self, key):
        # Plain dict lookup; a missing key raises KeyError.
        return self.configurations[key]

    def __setitem__(self, key, value):
        self.configurations[key] = value

    def __iter__(self):
        # Iterates the keys of ``configurations`` (replacing pydantic's field iteration).
        return iter(self.configurations)

    def values(self) -> Iterator[ProviderConfiguration]:
        return self.configurations.values()

    def get(self, key, default=None):
        return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel):
    """
    Container pairing a provider configuration with a provider instance and a
    model type instance.
    """

    configuration: ProviderConfiguration
    provider_instance: ModelProvider
    model_type_instance: AIModel

    # pydantic configs: arbitrary types are allowed for the provider/model instance fields
    # (presumably not pydantic models themselves), and protected namespaces are disabled so
    # ``model_``-prefixed names such as ``model_type_instance`` are accepted.
    model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)