# Captured from a Hugging Face Spaces page (Space status: "Running").
from dataclasses import dataclass | |
from enum import IntEnum | |
import yaml | |
from typing import Dict, Optional, List | |
from pydantic import BaseModel, ValidationError | |
from huggingface_hub import hf_hub_download | |
from huggingface_hub.utils import EntryNotFoundError | |
from openai import OpenAI | |
class OAuthProvider(IntEnum):
    """Identity provider a user signed in with; ``NONE`` means no OAuth."""

    # Values are fixed explicitly: IntEnum members compare equal to these ints.
    NONE = 0  # no OAuth provider
    GOOGLE = 1  # Google sign-in
@dataclass
class User:
    """A user of the application.

    Declared as a dataclass so the annotated fields become real instance
    attributes with a generated ``__init__``/``__repr__``/``__eq__``. As a
    plain class the annotations alone would define no attributes at all.
    """

    oauth: OAuthProvider  # provider the user signed in with
    username: str
    # NOTE(review): presumably a key into ``InferenceConfig.permissions``;
    # confirm against callers.
    permissions_id: str
# Persona / prompt settings: the ``pile`` section of ModelConfig.
# (Comments rather than a docstring so the pydantic schema is unchanged.)
class PileConfig(BaseModel):
    file2persona: dict[str, str]
    file2prefix: dict[str, str]
    # Persona name -> system prompt; Client.get_personas exposes this mapping.
    persona2system: dict[str, str]
    prompt: str
# Inference settings: the ``inference`` section of ModelConfig.
class InferenceConfig(BaseModel):
    chat_template: str
    # Optional. Pydantic copies this default for each instance, so the
    # literal ``{}`` is not shared between model objects.
    permissions: dict[str, list] = {}
# Repository settings: the ``repo`` section of ModelConfig.
class RepoConfig(BaseModel):
    name: str  # repository name
# Top-level model configuration. Client.get_personas loads it from the
# model repo's ``datasets/config.yaml``.
class ModelConfig(BaseModel):
    pile: PileConfig
    inference: InferenceConfig
    repo: RepoConfig

    @classmethod
    def from_yaml(cls, yaml_file="datasets/config.yaml"):
        """Load and validate a config from a YAML file.

        Args:
            yaml_file: Path to the YAML file.

        Returns:
            A validated instance of ``cls``.

        Raises:
            TypeError: If the top-level YAML value is not a mapping.
            pydantic.ValidationError: If the mapping does not match the
                schema. An empty file is treated as an empty mapping, so it
                reports the missing fields.
        """
        with open(yaml_file, "r", encoding="utf-8") as file:
            # safe_load never constructs arbitrary Python objects.
            data = yaml.safe_load(file)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"{yaml_file}: expected a mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
class Client:
    """Client for an OpenAI-compatible inference server.

    On construction it connects to ``{api_url}/v1``, reads the served model
    id (optionally suffixed with ``@revision``) and resolves the available
    personas. The ``vllm_`` naming suggests a vLLM server; that is not
    verified here.
    """

    def __init__(
        self,
        api_url: str,
        api_key: str,
        personas: Optional[Dict[str, Optional[str]]] = None,
    ):
        """Create the client and initialise it eagerly (performs network I/O).

        Args:
            api_url: Base URL of the server; ``/v1`` is appended.
            api_key: API key passed to the OpenAI client.
            personas: Optional mapping of persona name -> system prompt.
                Personas loaded from the model's config override these, and
                ``"vanilla"`` is always mapped to ``None``.
        """
        self.api_url = api_url
        self.api_key = api_key
        # ``None`` sentinel instead of a shared mutable ``{}`` default.
        self.input_personas = {} if personas is None else personas
        self.init_all()

    def init_all(self):
        """Run every initialisation step in dependency order."""
        self.init_client()
        self.get_metadata()  # needs self.openai
        self.get_personas()  # needs self.model_name / self.revision

    def init_client(self):
        """Create the OpenAI SDK client pointed at ``{api_url}/v1``."""
        self.openai = OpenAI(
            base_url=f"{self.api_url}/v1",
            api_key=self.api_key,
        )

    def get_metadata(self):
        """Read the served model id and split off an optional revision.

        The first model listed by the server is used. An id ``"name@rev"``
        gives ``model_name="name"`` and ``revision="rev"``. Without ``@`` the
        revision is ``None``.
        """
        models = self.openai.models.list()
        # NOTE(review): raises IndexError if the server lists no models.
        vllm_model_name = models.data[0].id
        model_name, *suffix = vllm_model_name.split("@")
        self.vllm_model_name = vllm_model_name
        self.model_name = model_name
        # Only the first "@"-separated segment is taken as the revision.
        self.revision = suffix[0] if suffix else None

    def get_personas(self):
        """Populate ``self.personas`` (and ``self.config``).

        When a revision is known, ``datasets/config.yaml`` is downloaded from
        the model's Hub repo at that revision, and its
        ``pile.persona2system`` entries are merged over the caller-supplied
        personas. A missing config file is tolerated. ``self.config`` is the
        loaded ModelConfig, or ``None`` when nothing was loaded.
        """
        self.config = None  # reset so a re-init never leaves a stale config
        personas = {}
        if self.revision is not None:
            try:
                config_path = hf_hub_download(
                    self.model_name,
                    "config.yaml",
                    subfolder="datasets",
                    revision=self.revision,
                )
            except EntryNotFoundError:
                # No datasets/config.yaml at this revision: use caller
                # personas only.
                pass
            else:
                self.config = ModelConfig.from_yaml(config_path)
                # Copy so adding "vanilla" below cannot mutate the config.
                personas = dict(self.config.pile.persona2system)
        personas["vanilla"] = None
        # The right operand wins on key clashes: config overrides caller input.
        self.personas = self.input_personas | personas