Spaces:
Running
Running
from dataclasses import dataclass, field | |
import sys | |
class ExllamaConfig: | |
max_seq_len: int | |
gpu_split: str = None | |
cache_8bit: bool = False | |
class ExllamaModel: | |
def __init__(self, exllama_model, exllama_cache): | |
self.model = exllama_model | |
self.cache = exllama_cache | |
self.config = self.model.config | |
def load_exllama_model(model_path, exllama_config: ExllamaConfig): | |
try: | |
from exllamav2 import ( | |
ExLlamaV2Config, | |
ExLlamaV2Tokenizer, | |
ExLlamaV2, | |
ExLlamaV2Cache, | |
ExLlamaV2Cache_8bit, | |
) | |
except ImportError as e: | |
print(f"Error: Failed to load Exllamav2. {e}") | |
sys.exit(-1) | |
exllamav2_config = ExLlamaV2Config() | |
exllamav2_config.model_dir = model_path | |
exllamav2_config.prepare() | |
exllamav2_config.max_seq_len = exllama_config.max_seq_len | |
exllamav2_config.cache_8bit = exllama_config.cache_8bit | |
exllama_model = ExLlamaV2(exllamav2_config) | |
tokenizer = ExLlamaV2Tokenizer(exllamav2_config) | |
split = None | |
if exllama_config.gpu_split: | |
split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] | |
exllama_model.load(split) | |
cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache | |
exllama_cache = cache_class(exllama_model) | |
model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) | |
return model, tokenizer | |