from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Optional

import yaml
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from openai import OpenAI
from pydantic import BaseModel, ValidationError


class OAuthProvider(IntEnum):
    """OAuth provider associated with a :class:`User`."""

    NONE = 0
    GOOGLE = 1


@dataclass
class User:
    """A user record.

    Attributes:
        oauth: OAuth provider associated with the user.
        username: The user's name.
        permissions_id: Identifier of the user's permission set; presumably a
            key into ``InferenceConfig.permissions`` -- TODO confirm.
    """

    oauth: OAuthProvider
    username: str
    permissions_id: str


class PileConfig(BaseModel):
    """``pile`` section of a model's ``config.yaml``.

    ``persona2system`` is the table :class:`Client` uses as its persona set
    (going by the field name, persona name -> system prompt; confirm).
    """

    file2persona: Dict[str, str]
    file2prefix: Dict[str, str]
    persona2system: Dict[str, str]
    prompt: str


class InferenceConfig(BaseModel):
    """``inference`` section of a model's ``config.yaml``."""

    chat_template: str
    # pydantic copies field defaults per instance, so this ``{}`` is not
    # shared between model instances.
    permissions: Dict[str, list] = {}


class RepoConfig(BaseModel):
    """``repo`` section of a model's ``config.yaml``."""

    name: str
    tag: str


class ModelConfig(BaseModel):
    """Top-level schema of a model's ``config.yaml``."""

    pile: PileConfig
    inference: InferenceConfig
    repo: RepoConfig

    @classmethod
    def from_yaml(cls, yaml_file="datasets/config.yaml"):
        """Load and validate a model config from a YAML file.

        Args:
            yaml_file: Path of the YAML file to read.

        Returns:
            The validated ``ModelConfig``.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            TypeError: If the top-level YAML value is not a mapping.
            ValidationError: If the data does not match the schema. An empty
                file is validated as an empty mapping and so lands here too.
        """
        with open(yaml_file, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
        # An empty document loads as None. Validate it as ``{}`` so the caller
        # gets a pydantic error naming the missing sections instead of an
        # opaque TypeError from ``cls(**None)``.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"{yaml_file}: expected a YAML mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)


class Client:
    """OpenAI-compatible API client bundled with model metadata and personas.

    On construction it connects to ``{api_url}/v1`` and reads the served model
    id, which is either ``name`` or ``name@revision``. When a revision is
    present, it also loads personas from ``datasets/config.yaml`` in that
    revision of the Hugging Face Hub repo ``name``.
    """

    def __init__(self, api_url, api_key, personas=None):
        """Create the client and run :meth:`init_all`.

        Args:
            api_url: Base URL of the server (presumably a vLLM server, given
                the ``vllm_model_name`` attribute); ``/v1`` is appended.
            api_key: API key passed to the OpenAI client.
            personas: Extra persona entries. They are merged with the repo's
                personas, and the repo's entries win on a name collision.
                ``None`` (the default) means no extra personas.
        """
        self.api_url = api_url
        self.api_key = api_key
        # Default to None rather than ``{}`` so no dict is shared across calls.
        self.input_personas = {} if personas is None else personas
        self.init_all()

    def init_all(self):
        """(Re)initialise the API client, model metadata and personas.

        The order matters: :meth:`get_personas` needs the ``model_name`` and
        ``revision`` that :meth:`get_metadata` sets.
        """
        self.init_client()
        self.get_metadata()
        self.get_personas()

    def init_client(self):
        """Create ``self.openai``, an OpenAI client pointed at ``{api_url}/v1``."""
        self.openai = OpenAI(
            base_url=f"{self.api_url}/v1",
            api_key=self.api_key,
        )

    def get_metadata(self):
        """Read the first served model id and split it into name and revision.

        Sets three attributes:

        - ``vllm_model_name``: the raw id.
        - ``model_name``: the text before the first ``@``.
        - ``revision``: the text between the first and second ``@``, or
          ``None`` when the id contains no ``@``.

        Raises:
            IndexError: If the server lists no models.
        """
        served = self.openai.models.list().data
        if not served:
            # Same exception type the bare ``data[0]`` raised, with a clearer message.
            raise IndexError(f"no models are served at {self.api_url}")
        vllm_model_name = served[0].id
        model_name, *suffix = vllm_model_name.split("@")
        self.vllm_model_name = vllm_model_name
        self.model_name = model_name
        self.revision = suffix[0] if suffix else None

    def get_personas(self):
        """Build ``self.personas`` and set ``self.config``.

        If ``revision`` is known, downloads ``datasets/config.yaml`` for
        ``model_name`` at that revision from the Hugging Face Hub. The parsed
        config is stored on ``self.config``, and its ``pile.persona2system``
        table supplies the personas.

        If there is no revision, or the file is not in the repo at that
        revision, ``self.config`` is ``None`` and no repo personas are added.

        A ``"vanilla"`` entry mapped to ``None`` is always added. Repo entries
        (including ``"vanilla"``) override same-named input personas.
        """
        self.config = None
        personas = {}
        if self.revision is not None:
            try:
                config_path = hf_hub_download(
                    self.model_name,
                    "config.yaml",
                    subfolder="datasets",
                    revision=self.revision,
                )
            except EntryNotFoundError:
                pass  # no config at this revision: use only the input personas
            else:
                self.config = ModelConfig.from_yaml(config_path)
                # Copy so that adding "vanilla" below does not mutate the
                # validated config's own Dict[str, str].
                personas = dict(self.config.pile.persona2system)
        # NOTE(review): presumably "vanilla" means "no system prompt" -- confirm.
        personas["vanilla"] = None
        self.personas = self.input_personas | personas