|
from pydantic import BaseModel, field_validator |
|
from typing import Optional |
|
import os |
|
from llmdantic import LLMdantic, LLMdanticConfig |
|
from sambanova.langchain_wrappers import SambaNovaFastAPI |
|
from dotenv import load_dotenv |
|
from llmdantic import LLMdanticResult |
|
|
|
|
|
# Load environment variables from a `.env` file located one directory above
# the *current working directory* (not above this file's location).
# override=True lets values from the .env file replace variables that are
# already set in the process environment.
current_dir = os.getcwd()
utils_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
load_dotenv(dotenv_path=os.path.join(utils_dir, ".env"), override=True)
|
|
|
|
|
|
|
class Catergories_Classify_Input(BaseModel):
    # Input schema for the LLMdantic classifier: the raw text to categorise.
    # NOTE: documented with comments rather than a class docstring on purpose;
    # pydantic copies a class docstring into the model's JSON-schema
    # "description", which may end up in the prompt sent to the LLM.
    text: str
|
|
|
class Catergories_Classify_Output(BaseModel):
    # Output schema the LLM response is parsed into: the predicted category
    # as free text.
    # NOTE: no class docstring on purpose; pydantic copies it into the
    # JSON-schema "description", which may be included in the LLM prompt.
    result: str

    @field_validator("result")
    @classmethod
    def catergory_result_must_not_be_empty(cls, v: str) -> str:
        """Category result must not be empty"""
        # NOTE(review): the docstring above is kept verbatim because llmdantic
        # appears to surface validator docstrings to the LLM as output rules;
        # confirm before rewording it.
        # Reject empty or whitespace-only answers. pydantic reports the
        # ValueError as a ValidationError for the "result" field.
        if not v.strip():
            raise ValueError("Category result must not be empty")
        # Return the value unchanged (it is not stripped).
        return v
|
|
|
|
|
class Evaluator:
    """Classify free text into a category using an LLM with validated output.

    Wraps a ``SambaNovaFastAPI`` LLM in ``LLMdantic`` so that each request is
    built from ``Catergories_Classify_Input`` and each response is parsed and
    validated as ``Catergories_Classify_Output``.
    """

    def __init__(
        self,
        llm: Optional[str],
        prompt: str,
        *,
        fastapi_url: str = "https://fast-api.snova.ai/v1/chat/completions",
        fastapi_api_key: Optional[str] = None,
        retries: int = 5,
    ) -> None:
        """Build the LLM client and the LLMdantic pipeline.

        Args:
            llm: Model name, passed to ``SambaNovaFastAPI`` as ``model``.
            prompt: Classification instructions, used as the LLMdantic
                ``objective``.
            fastapi_url: Chat-completions endpoint. Defaults to the URL that
                was previously hard-coded here.
            fastapi_api_key: API key. When omitted, it is read from the
                ``FASTAPI_API_KEY`` environment variable (which the module's
                ``.env`` loading can populate).
            retries: Retry count handed to ``LLMdanticConfig``.

        Raises:
            ValueError: If no API key is given and ``FASTAPI_API_KEY`` is
                unset or empty.
        """
        # Credentials must never live in source. A key was previously
        # hard-coded on this line; treat it as leaked and rotate it.
        api_key = fastapi_api_key or os.environ.get("FASTAPI_API_KEY")
        if not api_key:
            raise ValueError(
                "No SambaNova API key found: pass fastapi_api_key or set the "
                "FASTAPI_API_KEY environment variable (e.g. in the .env file)."
            )

        self.llm = SambaNovaFastAPI(
            model=llm,
            fastapi_url=fastapi_url,
            fastapi_api_key=api_key,
        )
        self.prompt = prompt
        self.config = LLMdanticConfig(
            objective=self.prompt,
            inp_schema=Catergories_Classify_Input,
            out_schema=Catergories_Classify_Output,
            retries=retries,
        )
        self.llmdantic = LLMdantic(llm=self.llm, config=self.config)

    def classify_text(self, text: str) -> Optional[Catergories_Classify_Output]:
        """Classify ``text`` and return the validated output model.

        Args:
            text: The text to categorise.

        Returns:
            The parsed ``Catergories_Classify_Output``, or ``None``. ``None``
            is presumably returned when llmdantic cannot obtain a valid
            response within the configured retries (TODO confirm against the
            llmdantic docs).
        """
        data = Catergories_Classify_Input(text=text)
        result: LLMdanticResult = self.llmdantic.invoke(data)
        return result.output