Spaces:
Build error
Build error
import json | |
import logging | |
import random | |
import re | |
import string | |
import subprocess | |
import time | |
import uuid | |
from collections.abc import Generator | |
from datetime import datetime | |
from hashlib import sha256 | |
from typing import Any, Optional, Union | |
from zoneinfo import available_timezones | |
from flask import Response, stream_with_context | |
from flask_restful import fields | |
from configs import dify_config | |
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | |
from core.file import helpers as file_helpers | |
from extensions.ext_redis import redis_client | |
from models.account import Account | |
def run(script):
    """Run ``script`` through the shell after sourcing ``/root/.bashrc``.

    Args:
        script: Shell command text appended after ``source /root/.bashrc && ``.

    Returns:
        The ``(exitcode, output)`` tuple produced by
        ``subprocess.getstatusoutput``.
    """
    # NOTE(review): ``script`` is spliced into the command line unescaped, so
    # only trusted text should ever reach this function.
    # NOTE(review): ``source`` is not a POSIX ``sh`` builtin; presumably the
    # deployment's default shell accepts it -- confirm.
    command = "source /root/.bashrc && " + script
    return subprocess.getstatusoutput(command)
class AppIconUrlField(fields.Raw):
    """Output field that exposes a signed URL for image-type app icons."""

    def output(self, key, obj):
        """Return a signed file URL for ``obj.icon``, or ``None``.

        ``None`` is returned when ``obj`` is ``None`` or when ``obj.icon_type``
        is not ``IconType.IMAGE.value``. ``key`` is not used.
        """
        if obj is None:
            return None

        # Imported at call time rather than module level
        # (presumably to avoid a circular import -- confirm).
        from models.model import IconType

        is_image_icon = obj.icon_type == IconType.IMAGE.value
        return file_helpers.get_signed_file_url(obj.icon) if is_image_icon else None
class TimestampField(fields.Raw):
    """Output field that renders a value via its ``timestamp()`` method."""

    def format(self, value) -> int:
        """Return ``value.timestamp()`` truncated to an ``int``."""
        raw_timestamp = value.timestamp()
        return int(raw_timestamp)
def email(email):
    """Validate an email address and return it unchanged.

    Args:
        email: The candidate address (a string).

    Returns:
        The same ``email`` value when it matches the address pattern.

    Raises:
        ValueError: If ``email`` does not match the address pattern.
    """
    pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
    # fullmatch is required here: with re.match, "$" also matches just before
    # a trailing newline, so an address followed by a newline would be accepted.
    if re.fullmatch(pattern, email) is not None:
        return email
    raise ValueError(f"{email} is not a valid email.")
def uuid_value(value):
    """Validate a UUID string and return its canonical string form.

    An empty string is passed through unchanged. Any other value is parsed
    with ``uuid.UUID`` and returned via ``str()`` of the parsed object.

    Raises:
        ValueError: If ``uuid.UUID`` rejects ``value``.
    """
    if value == "":
        return str(value)

    try:
        parsed = uuid.UUID(value)
    except ValueError:
        raise ValueError(f"{value} is not a valid uuid.")
    return str(parsed)
def alphanumeric(value: str):
    """Validate that ``value`` contains only ASCII letters, digits and underscores.

    Args:
        value: The string to check; it must be non-empty.

    Returns:
        ``value`` unchanged when it is valid.

    Raises:
        ValueError: If ``value`` is empty or contains any other character.
    """
    # fullmatch is required here: with re.match, "$" also matches just before
    # a trailing newline, so "abc" followed by a newline would be accepted.
    if re.fullmatch(r"^[a-zA-Z0-9_]+$", value):
        return value
    raise ValueError(f"{value} is not a valid alphanumeric value")
def timestamp_value(timestamp):
    """Validate a non-negative integer timestamp.

    Args:
        timestamp: Any value accepted by ``int()`` (for example a numeric
            string or a number).

    Returns:
        ``int(timestamp)`` when it is zero or positive.

    Raises:
        ValueError: If ``timestamp`` cannot be converted to an ``int`` or the
            converted value is negative.
    """
    try:
        int_timestamp = int(timestamp)
    except (TypeError, ValueError, OverflowError) as exc:
        # TypeError (e.g. None) and OverflowError (e.g. infinity) are folded
        # into ValueError so every rejection surfaces the same exception type.
        raise ValueError(f"{timestamp} is not a valid timestamp.") from exc

    if int_timestamp < 0:
        raise ValueError(f"{timestamp} is not a valid timestamp.")
    return int_timestamp
class StrLen:
    """Callable validator that rejects values longer than ``max_length``."""

    def __init__(self, max_length, argument="argument"):
        """Store the length limit and the argument name used in error messages."""
        self.argument = argument
        self.max_length = max_length

    def __call__(self, value):
        """Return ``value`` unchanged, or raise ``ValueError`` if it is too long."""
        if len(value) > self.max_length:
            raise ValueError(
                f"Invalid {self.argument}: {value}. {self.argument} cannot exceed length {self.max_length}"
            )
        return value
class FloatRange:
    """Callable validator restricting input to a float in ``[low, high]`` (inclusive)."""

    def __init__(self, low, high, argument="argument"):
        """Store the inclusive bounds and the argument name used in error messages."""
        self.low = low
        self.high = high
        self.argument = argument

    def __call__(self, value):
        """Convert ``value`` to a float and return it if it lies within the range.

        Raises:
            ValueError: If ``value`` is not a valid float, is NaN, or falls
                outside ``[low, high]``.
        """
        value = _get_float(value)
        # A positive containment test is used on purpose: NaN compares False
        # against everything, so "value < low or value > high" lets it through.
        if not (self.low <= value <= self.high):
            raise ValueError(
                f"Invalid {self.argument}: {value}. {self.argument} must be within the range "
                f"{self.low} - {self.high}"
            )
        return value
class DatetimeString:
    """Callable validator checking that a string parses with a ``strptime`` format."""

    def __init__(self, format, argument="argument"):
        """Store the ``strptime`` format and the argument name used in error messages."""
        self.argument = argument
        self.format = format

    def __call__(self, value):
        """Return ``value`` unchanged if ``datetime.strptime`` accepts it.

        Raises:
            ValueError: If ``value`` does not conform to ``self.format``.
        """
        try:
            # Only the parse attempt matters; the parsed datetime is discarded.
            datetime.strptime(value, self.format)
        except ValueError:
            raise ValueError(
                f"Invalid {self.argument}: {value}. {self.argument} must be conform to the format {self.format}"
            )
        return value
def _get_float(value):
    """Convert ``value`` with ``float()``, reporting any failure as ``ValueError``.

    Both ``TypeError`` and ``ValueError`` from ``float()`` are replaced by a
    ``ValueError`` carrying a uniform message.
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        raise ValueError(f"{value} is not a valid float")
    return parsed
def timezone(timezone_string):
    """Validate a timezone name against ``zoneinfo.available_timezones()``.

    Returns:
        ``timezone_string`` unchanged when it is truthy and a known zone name.

    Raises:
        ValueError: If ``timezone_string`` is falsy or not a known zone name.
    """
    is_known = bool(timezone_string) and timezone_string in available_timezones()
    if not is_known:
        raise ValueError(f"{timezone_string} is not a valid timezone.")
    return timezone_string
def generate_string(n):
    """Return a random string of ``n`` ASCII letters and digits.

    Args:
        n: Number of characters to generate; zero or a negative value yields
            an empty string.

    Returns:
        A string of length ``max(n, 0)`` drawn from ``string.ascii_letters``
        and ``string.digits``.
    """
    # NOTE(review): ``random`` is not a cryptographic source; if callers use
    # this for secrets, the ``secrets`` module would be the safer choice.
    letters_digits = string.ascii_letters + string.digits
    # Joining once replaces repeated "+=" concatenation; the sequence of
    # random.choice calls (and so the output for a given seed) is unchanged.
    return "".join(random.choice(letters_digits) for _ in range(n))
def extract_remote_ip(request) -> str:
    """Return the client IP address for ``request``.

    Lookup order:
        1. The ``CF-Connecting-IP`` header, when present and non-empty.
        2. The first address in the first ``X-Forwarded-For`` header value.
        3. ``request.remote_addr``.

    Args:
        request: An object exposing ``headers.get``, ``headers.getlist`` and
            ``remote_addr``.
    """
    # NOTE(review): both headers are client-supplied unless a trusted proxy
    # overwrites them -- confirm the deployment guarantees that.
    cf_ip = request.headers.get("CF-Connecting-IP")
    if cf_ip:
        return cf_ip

    forwarded = request.headers.getlist("X-Forwarded-For")
    if forwarded:
        # A single header value may carry a whole "client, proxy1, proxy2"
        # chain; only the left-most entry is the originating client.
        first_hop = forwarded[0].split(",")[0].strip()
        if first_hop:
            return first_hop

    return request.remote_addr
def generate_text_hash(text: str) -> str:
    """Return the SHA-256 hex digest of ``str(text)`` with ``"None"`` appended.

    The literal suffix ``"None"`` is part of the hashed input, so the result
    differs from a plain SHA-256 of the text.
    """
    payload = (str(text) + "None").encode()
    digest = sha256()
    digest.update(payload)
    return digest.hexdigest()
def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
    """Wrap a generation result in a Flask ``Response``.

    A ``dict`` is serialized with ``json.dumps`` and returned as
    ``application/json``. Anything else is iterated and streamed as
    ``text/event-stream``. Both responses use status 200.
    """
    if not isinstance(response, dict):

        def stream() -> Generator:
            yield from response

        return Response(stream_with_context(stream()), status=200, mimetype="text/event-stream")

    payload = json.dumps(response)
    return Response(response=payload, status=200, mimetype="application/json")
class TokenManager:
    """Issue, look up and revoke typed tokens stored through ``redis_client``.

    Two kinds of keys are written:
        * ``{token_type}:token:{token}`` holds the JSON token payload.
        * ``{token_type}:account:{account_id}`` holds the account's current
          token, so a newly issued token can revoke the previous one.

    Every method is a classmethod; no instance state is used.
    """

    @classmethod
    def generate_token(
        cls,
        token_type: str,
        account: Optional["Account"] = None,
        email: Optional[str] = None,
        additional_data: Optional[dict] = None,
    ) -> str:
        """Create and store a new token, revoking the account's previous one.

        Args:
            token_type: Token category; also selects the
                ``{TOKEN_TYPE}_TOKEN_EXPIRY_MINUTES`` config entry.
            account: Account the token belongs to (optional).
            email: Email used when no account is given.
            additional_data: Extra keys merged into the stored payload.

        Returns:
            The newly generated token string (a UUID4).

        Raises:
            ValueError: If neither ``account`` nor ``email`` is provided, or
                if no expiry is configured for ``token_type``.
        """
        if account is None and email is None:
            raise ValueError("Account or email must be provided")

        account_id = account.id if account else None
        account_email = account.email if account else email

        if account_id:
            old_token = cls._get_current_token_for_account(account_id, token_type)
            if old_token:
                # The Redis client may hand back bytes; keys are built from str.
                if isinstance(old_token, bytes):
                    old_token = old_token.decode("utf-8")
                cls.revoke_token(old_token, token_type)

        token = str(uuid.uuid4())
        token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
        if additional_data:
            token_data.update(additional_data)

        expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES")
        if expiry_minutes is None:
            # Fail with a clear message instead of a TypeError from "None * 60".
            raise ValueError(f"Expiry minutes for {token_type} token is not set")

        token_key = cls._get_token_key(token, token_type)
        expiry_time = int(expiry_minutes * 60)
        redis_client.setex(token_key, expiry_time, json.dumps(token_data))

        if account_id:
            # The helper expects hours; convert so the account pointer expires
            # together with the token rather than 60 times later.
            cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes / 60)
        return token

    @classmethod
    def _get_token_key(cls, token: str, token_type: str) -> str:
        """Return the key under which a token's payload is stored."""
        return f"{token_type}:token:{token}"

    @classmethod
    def revoke_token(cls, token: str, token_type: str):
        """Delete the stored payload for ``token``."""
        token_key = cls._get_token_key(token, token_type)
        redis_client.delete(token_key)

    @classmethod
    def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
        """Return the decoded payload for ``token``, or ``None`` if it is not stored."""
        key = cls._get_token_key(token, token_type)
        token_data_json = redis_client.get(key)
        if token_data_json is None:
            logging.warning("%s token %s not found with key %s", token_type, token, key)
            return None
        return json.loads(token_data_json)

    @classmethod
    def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
        """Return the token currently recorded for the account, if any.

        The value is returned exactly as the Redis client provides it, which
        may be ``bytes``; ``generate_token`` decodes it before use.
        """
        key = cls._get_account_token_key(account_id, token_type)
        return redis_client.get(key)

    @classmethod
    def _set_current_token_for_account(
        cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float]
    ):
        """Record ``token`` as the account's current token for ``expiry_hours`` hours."""
        key = cls._get_account_token_key(account_id, token_type)
        # round() rather than int() so a fractional-hour input such as 5 / 60
        # is not truncated one second short by floating-point error.
        expiry_time = round(expiry_hours * 60 * 60)
        redis_client.setex(key, expiry_time, token)

    @classmethod
    def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
        """Return the key under which an account's current token is stored."""
        return f"{token_type}:account:{account_id}"
class RateLimiter:
    """Sliding-window attempt counter kept in a Redis sorted set per email.

    Each recorded attempt is a sorted-set member scored with its Unix
    timestamp in seconds. Members scored at or before
    ``now - time_window`` are discarded before counting.
    """

    def __init__(self, prefix: str, max_attempts: int, time_window: int):
        """Store the key prefix, the attempt limit and the window length in seconds."""
        self.prefix = prefix
        self.max_attempts = max_attempts
        self.time_window = time_window

    def _get_key(self, email: str) -> str:
        """Return the sorted-set key for ``email``."""
        return f"{self.prefix}:{email}"

    def is_rate_limited(self, email: str) -> bool:
        """Return ``True`` when ``email`` has reached ``max_attempts`` in the window."""
        key = self._get_key(email)
        current_time = int(time.time())
        window_start_time = current_time - self.time_window

        # Drop attempts that have aged out of the window before counting.
        redis_client.zremrangebyscore(key, "-inf", window_start_time)
        attempts = redis_client.zcard(key)
        return bool(attempts) and int(attempts) >= self.max_attempts

    def increment_rate_limit(self, email: str):
        """Record one attempt for ``email`` and refresh the key's expiry."""
        key = self._get_key(email)
        current_time = int(time.time())
        # The member must be unique per attempt: using the whole-second
        # timestamp as the member makes attempts within the same second
        # overwrite each other and be counted once.
        member = f"{current_time}:{uuid.uuid4().hex}"
        redis_client.zadd(key, {member: current_time})
        redis_client.expire(key, self.time_window * 2)