Spaces:
Runtime error
Runtime error
from .py_generate import PyGenerator | |
from .rs_generate import RsGenerator | |
from .go_generate import GoGenerator | |
from .generator_types import Generator | |
from .model import CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci, Samba, GPT4o, GroqBase | |
def generator_factory(lang: str) -> Generator:
    """Build the code generator that corresponds to a language identifier.

    Args:
        lang: Language identifier. Accepted values are ``"py"``/``"python"``,
            ``"rs"``/``"rust"`` and ``"go"``/``"golang"``.

    Returns:
        A newly constructed ``PyGenerator``, ``RsGenerator`` or
        ``GoGenerator``, depending on ``lang``.

    Raises:
        ValueError: If ``lang`` is not one of the accepted identifiers.
    """
    # Each guard handles one language family; the first match wins.
    if lang in ("py", "python"):
        return PyGenerator()
    if lang in ("rs", "rust"):
        return RsGenerator()
    if lang in ("go", "golang"):
        return GoGenerator()

    # Nothing matched: reject the identifier explicitly.
    raise ValueError(f"Invalid language for generator: {lang}")
def model_factory(model_name: str) -> ModelBase:
    """Build the model wrapper that corresponds to a model name.

    Exact-match names:
        ``"gpt-4"`` -> ``GPT4``, ``"gpt-4o"`` -> ``GPT4o``,
        ``"samba"`` -> ``Samba``, ``"groq"`` -> ``GroqBase``,
        ``"gpt-3.5-turbo-0613"`` -> ``GPT35``, ``"starchat"`` -> ``StarChat``.

    Prefix-match names:
        * Names starting with ``"codellama"`` yield a ``CodeLlama``. If the
          name contains a hyphen, the segment between the first and second
          hyphen is forwarded as the ``version`` keyword argument.
        * Names starting with ``"text-davinci"`` yield a ``GPTDavinci``
          constructed with the full model name.

    Args:
        model_name: Name of the model to instantiate.

    Returns:
        A newly constructed model object for ``model_name``.

    Raises:
        ValueError: If ``model_name`` matches none of the known names or
            prefixes.
    """
    # Exact-name matches, checked first.
    if model_name == "gpt-4":
        return GPT4()
    if model_name == "gpt-4o":
        return GPT4o()
    if model_name == "samba":
        return Samba()
    if model_name == "groq":
        return GroqBase()
    if model_name == "gpt-3.5-turbo-0613":
        return GPT35()
    if model_name == "starchat":
        return StarChat()

    # Prefix-based families.
    if model_name.startswith("codellama"):
        # A hyphen in the name means a version was specified; without one,
        # CodeLlama is constructed with no arguments.
        if "-" in model_name:
            return CodeLlama(version=model_name.split("-")[1])
        return CodeLlama()
    if model_name.startswith("text-davinci"):
        return GPTDavinci(model_name)

    raise ValueError(f"Invalid model name: {model_name}")