File size: 3,128 Bytes
aa37f44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import time

from langchain_openai import AzureChatOpenAI
from msal import ConfidentialClientApplication

DEFAULT_TOKEN_UPDATE_FREQUENCY = 3300  # Default token duration is 1 hour (3600 s.)

# LOAD ENVIRONMENT VARIABLES
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass


client_id = os.environ.get("AZURE_CLIENT_ID", None)
client_credential = os.environ.get("AZURE_CLIENT_CREDENTIAL", None)
tenant_name = os.environ.get("AZURE_TENANT_NAME", None)
scopes = [os.environ.get("AZURE_SCOPE", None)]

azure_ad_token_frequency = int(
    os.environ.get("TOKEN_UPDATE_FREQUENCY", DEFAULT_TOKEN_UPDATE_FREQUENCY)
)
azure_ad_token = None
azure_ad_token_timestamp = 0.0


def _get_azure_ad_token():
    global azure_ad_token
    global azure_ad_token_timestamp
    now = time.time()

    # Return current token if not outdated:
    if (azure_ad_token is not None) and (
        azure_ad_token_timestamp + azure_ad_token_frequency > now
    ):
        print("Using current token (not expired)...")
        return azure_ad_token

    # Else, generate a new token:
    print("Generating new token...")
    app = ConfidentialClientApplication(
        client_id=client_id,
        client_credential=client_credential,
        authority=f"https://login.microsoftonline.com/{tenant_name}",
    )
    result = app.acquire_token_for_client(scopes=scopes)
    if "access_token" not in result:
        raise ValueError("No access token in result")

    if result["access_token"] != azure_ad_token:
        print("New token received.")
        azure_ad_token = result["access_token"]
        azure_ad_token_timestamp = now
    else:
        print("Same token received.")

    return azure_ad_token


def get_llm(
    max_tokens: int = 1024,
    temperature: float = 0.0,
    verbose: bool = True,
    streaming: bool = False,
    **kwargs,
) -> AzureChatOpenAI:
    auth_dict = dict(openai_api_type="azure")
    # Note: OPENAI_API_VERSION is automatically taken from environment variables.

    # First option: provide AZURE_OPENAI_API_BASE_URL, OPENAI_API_VERSION, AZURE_CLIENT_ID,
    # AZURE_CLIENT_CREDENTIAL, AZURE_TENANT_NAME & AZURE_SCOPE:
    if (
        (client_id is not None)
        and (client_credential is not None)
        and (tenant_name is not None)
    ):
        print("Using Azure AD token")
        auth_dict["openai_api_base"] = os.environ["AZURE_OPENAI_API_BASE_URL"]
        auth_dict["azure_ad_token_provider"] = _get_azure_ad_token

    # Second option: provide AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_DEPLOYMENT_NAME,
    # OPENAI_API_VERSION & AZURE_OPENAI_API_KEY:
    else:
        print("Using AZURE_OPENAI_API_DEPLOYMENT_NAME and AZURE_OPENAI_API_KEY")
        auth_dict["deployment_name"] = os.environ["AZURE_OPENAI_API_DEPLOYMENT_NAME"]
        # Note: AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY are automatically taken
        # from environment variable.

    llm = AzureChatOpenAI(
        **auth_dict,
        max_tokens=max_tokens,
        temperature=temperature,
        verbose=verbose,
        streaming=streaming,
        **kwargs,
    )
    return llm