import os
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
# Load API keys (e.g. TYPHOON_CHAT_KEY, GROQ_CHAT_KEY) from a local .env file.
load_dotenv()

# Registry of available chat models with the endpoint and credentials needed to reach each one.
model_list = {
    "typhoon-v1.5x-70b-instruct": {
        "model_name": "typhoon-v1.5x-70b-instruct",
        "model_type": "openai",
        "base_url": "https://api.opentyphoon.ai/v1",
        "api_key": os.getenv("TYPHOON_CHAT_KEY")
    },
    "openthaigpt": {
        "model_name": ".",
        "model_type": "openai",
        "base_url": "https://api.aieat.or.th/v1",
        "api_key": "dummy"
    },
    "llama-3.3-70b-versatile": {
        "model_name": "llama-3.3-70b-versatile",
        "model_type": "groq",
        "base_url": None,
        "api_key": os.getenv("GROQ_CHAT_KEY")
    }
}

def get_model(model_name, base_url=None, api_key=None):
    """Return a configured chat model for the given model_list entry.

    Explicit base_url / api_key arguments override the registry values.
    """
    config = model_list[model_name]
    api_key = api_key or config["api_key"]
    base_url = base_url or config["base_url"]
    model = config["model_name"]
    model_type = config["model_type"]
    if model_type == "openai":
        return ChatOpenAI(
            temperature=0.3,
            timeout=15,
            base_url=base_url, model=model, api_key=api_key, max_retries=0)
    elif model_type == "groq":
        return ChatGroq(temperature=0.3, timeout=15, groq_api_key=api_key,
                        model_name=model, max_retries=0)
    else:
        raise ValueError("Invalid model type. Supported types are 'openai' and 'groq'.")
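

# Minimal usage sketch: assumes the matching key (here GROQ_CHAT_KEY) is set in .env
# and that the name passed to get_model is one of the keys in model_list.
if __name__ == "__main__":
    llm = get_model("llama-3.3-70b-versatile")
    reply = llm.invoke("Hello! Briefly introduce yourself.")
    print(reply.content)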