FIRE / src /modules /xfastertransformer.py
zhangbofei
feat: change to fstchat
6dc0c9c
raw
history blame
1.3 kB
from dataclasses import dataclass
import sys
@dataclass
class XftConfig:
    """Runtime options for an xFasterTransformer model.

    Within this module only ``data_type`` is consumed by ``load_xft_model``;
    the full object is stored on the returned ``XftModel`` as ``config`` so
    downstream code can read the remaining fields.
    """

    # Sequence length and beam-search settings.
    max_seq_len: int = 4096
    beam_width: int = 1
    # Special token ids. NOTE(review): -1 presumably means "not set" -- confirm
    # against the code that consumes these values.
    eos_token_id: int = -1
    pad_token_id: int = -1
    # Output / batching behaviour.
    num_return_sequences: int = 1
    is_encoder_decoder: bool = False
    padding: bool = True
    early_stopping: bool = False
    # Precision string forwarded as ``dtype`` to the xFasterTransformer loader.
    data_type: str = "bf16_fp16"
class XftModel:
    """Bundle a loaded xFasterTransformer model with its ``XftConfig``.

    Attributes:
        model: The backend model object passed in as ``xft_model``.
        config: The configuration object passed in as ``xft_config``.
    """

    def __init__(self, xft_model, xft_config):
        # Plain attribute storage; no validation or copying is performed.
        self.config = xft_config
        self.model = xft_model
def load_xft_model(model_path, xft_config: XftConfig):
    """Load an xFasterTransformer model together with its tokenizer.

    Args:
        model_path: Path or hub id passed to both
            ``transformers.AutoTokenizer.from_pretrained`` and
            ``xfastertransformer.AutoModel.from_pretrained``.
        xft_config: Runtime options. Only ``data_type`` is read here; an
            empty or ``None`` value falls back to ``"bf16_fp16"``. The object
            itself is stored unchanged on the returned ``XftModel``.

    Returns:
        A ``(model, tokenizer)`` tuple where ``model`` is an ``XftModel``.
        Only processes whose ``model.model.rank`` is 0 ever return; every
        other rank stays in this function forever calling ``generate()``.

    Raises:
        SystemExit: With code -1 (via ``sys.exit``) if ``xfastertransformer``
            or ``transformers`` cannot be imported.
    """
    # Imported inside the function so this module can be imported even when
    # the optional backend packages are missing.
    try:
        import xfastertransformer
        from transformers import AutoTokenizer
    except ImportError as e:
        # Diagnostics go to stderr so they are not mixed into normal stdout.
        print(f"Error: Failed to load xFasterTransformer. {e}", file=sys.stderr)
        sys.exit(-1)

    # Fall back to the default precision when the config leaves it unset.
    data_type = xft_config.data_type or "bf16_fp16"

    # NOTE(review): trust_remote_code=True lets transformers execute Python
    # code shipped with the model repository; only load trusted model paths.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=False, padding_side="left", trust_remote_code=True
    )
    xft_model = xfastertransformer.AutoModel.from_pretrained(
        model_path, dtype=data_type
    )
    model = XftModel(xft_model=xft_model, xft_config=xft_config)

    # NOTE(review): presumably multi-process inference in which non-zero ranks
    # act as workers that follow rank 0's generate() calls -- confirm against
    # the xFasterTransformer docs. Such ranks never leave this loop.
    if model.model.rank > 0:
        while True:
            model.model.generate()
    return model, tokenizer