# from abc import ABC
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.utils.cpp_extension import load
import os
import sys
import platform
library_dir = os.path.dirname(os.path.abspath(__file__))
extension_name = "exllama_ext"
verbose = False
# another kludge to get things compiling in Windows
windows = os.name == "nt"
if windows:

    def find_msvc():

        for msvc_dir in [a + "\\Microsoft Visual Studio\\" + b + "\\" + c + "\\VC\\Tools\\MSVC\\"
                         for b in ["2022", "2019", "2017"]
                         for a in [os.environ["ProgramW6432"], os.environ["ProgramFiles(x86)"]]
                         for c in ["BuildTools", "Community", "Professional", "Enterprise", "Preview"]
                         ]:

            if not os.path.exists(msvc_dir):
                continue

            versions = sorted(os.listdir(msvc_dir), reverse = True)
            for version in versions:

                compiler_dir = msvc_dir + version + "\\bin\\Hostx64\\x64"
                if os.path.exists(compiler_dir) and os.path.exists(compiler_dir + "\\cl.exe"):
                    return compiler_dir

        return None

    import subprocess
    try:
        subprocess.check_output(["where", "/Q", "cl"])
    except subprocess.CalledProcessError:
        cl_path = find_msvc()
        if cl_path:
            if verbose:
                print("Injected compiler path:", cl_path)
            os.environ["path"] += ";" + cl_path
        else:
            print("Unable to find cl.exe; compilation will probably fail.", file = sys.stderr)
exllama_ext = load(
    name = extension_name,
    sources = [
        os.path.join(library_dir, "exllama_ext/exllama_ext.cpp"),
        os.path.join(library_dir, "exllama_ext/cuda_buffers.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/q4_matrix.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/q4_matmul.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/column_remap.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/rms_norm.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/rope.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/half_matmul.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/q4_attn.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/q4_mlp.cu"),
        os.path.join(library_dir, "exllama_ext/cpu_func/rep_penalty.cpp")
    ],
    extra_include_paths = [os.path.join(library_dir, "exllama_ext")],
    verbose = verbose,
    extra_ldflags = (["cublas.lib"] + ([f"/LIBPATH:{os.path.join(sys.base_prefix, 'libs')}"] if sys.base_prefix != sys.prefix else [])) if windows else [],
    extra_cuda_cflags = ["-lineinfo"] + (["-U__HIP_NO_HALF_CONVERSIONS__", "-O3"] if torch.version.hip else []),
    extra_cflags = ["-O3"]
    # extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
)
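# load() JIT-compiles the sources above on first import and caches the build (by default under
# ~/.cache/torch_extensions, or $TORCH_EXTENSIONS_DIR if set), so later imports reuse the binary.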
# from exllama_ext import set_tuning_params
# from exllama_ext import prepare_buffers
from exllama_ext import make_q4
from exllama_ext import q4_matmul
from exllama_ext import q4_matmul_lora
from exllama_ext import half_matmul
from exllama_ext import half_matmul_cublas
# from exllama_ext import q4_mlp
from exllama_ext import rms_norm
from exllama_ext import rope_
from exllama_ext import rep_penalty
from exllama_ext import apply_rep_penalty
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device = "meta")
# Construct Q4Matrix, return handle
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
    return make_q4(qweight,
                   qzeros,
                   scales,
                   g_idx if g_idx is not None else none_tensor,
                   device)
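# Note (assumption based on the argument names): qweight/qzeros/scales/g_idx follow the GPTQ
# checkpoint layout, and the return value is an opaque handle to the constructed Q4Matrix,
# intended to be passed back to q4_matmul below.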
# Matrix multiplication, returns x @ q4
def ext_q4_matmul(x, q4, q4_width, lora_A = None, lora_B = None):
    outshape = x.shape[:-1] + (q4_width,)
    x = x.view(-1, x.shape[-1])
    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)

    if lora_A is None:
        q4_matmul(x, q4, output)
    else:
        lora_temp = torch.empty((x.shape[0], lora_A.shape[1]), dtype = torch.float16, device = x.device)
        q4_matmul_lora(x, q4, output, lora_A, lora_B, lora_temp)

    return output.view(outshape)
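# Illustrative usage sketch (names like `layer` and `hidden` are placeholders, not part of this
# module; shapes assume the GPTQ layout where qweight.shape[1] is the output width):
#
#   q4  = ext_make_q4(layer.qweight, layer.qzeros, layer.scales, layer.g_idx, hidden.device.index)
#   out = ext_q4_matmul(hidden, q4, layer.qweight.shape[1])   # hidden: float16, shape (..., in_features)
#
# With a LoRA adapter, the half-precision A/B matrices are passed through so the kernel can also
# apply the low-rank update via the lora_temp buffer:
#
#   out = ext_q4_matmul(hidden, q4, layer.qweight.shape[1], lora_A = A, lora_B = B)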
# Matrix multiplication, returns x @ w, both half-precision tensors
def ext_half_matmul(x, w, cublas = False):
    outshape = x.shape[:-1] + (w.shape[1],)
    x = x.view(-1, x.shape[-1])

    if cublas:
        output = torch.empty((x.shape[0], w.shape[1]), dtype = torch.float16, device = x.device)
        half_matmul_cublas(x, w, output)
    else:
        output = torch.zeros((x.shape[0], w.shape[1]), dtype = torch.float16, device = x.device)
        half_matmul(x, w, output)

    return output.view(outshape)
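# Reference sketch (assumption): apart from accumulation details inside the kernels, both paths
# above should agree with a plain PyTorch half-precision matmul.
def _half_matmul_reference(x, w):
    # x: (..., k) float16, w: (k, n) float16 -> (..., n) float16
    return (x @ w).to(torch.float16)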
# RoPE embeddings, in_place
def ext_rope_(x, sin, cos, past_len, num_heads, head_dim):
    rope_(x, sin, cos, past_len, num_heads, head_dim)
# RMS norm: x = x * w / sqrt(row_mean(x * x) + epsilon)
def ext_rms_norm(x, w, epsilon):
    outshape = x.shape
    x = x.view(-1, x.shape[-1])
    output = torch.empty_like(x)
    rms_norm(x, w, output, epsilon)

    return output.view(outshape)
# In-place variant: writes the normalized result back into x
def ext_rms_norm_(x, w, epsilon):
    x = x.view(-1, x.shape[-1])
    rms_norm(x, w, x, epsilon)
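# Reference sketch (assumption): the rms_norm kernel is expected to match the formula in the
# comment above.
def _rms_norm_reference(x, w, epsilon):
    # x: (..., dim), w: (dim,); mirrors x * w / sqrt(row_mean(x * x) + epsilon)
    return x * w * torch.rsqrt(x.pow(2).mean(dim = -1, keepdim = True) + epsilon)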
# Repetition penalty
def ext_rep_penalty_mask_cpu(vocab_size, sequence, penalty_max, sustain, decay):
    rep_mask = torch.empty(vocab_size, dtype = torch.float32)
    rep_penalty(sequence, rep_mask, penalty_max, sustain, decay)
    return rep_mask
def ext_apply_rep_penalty_mask_cpu(sequence, penalty_max, sustain, decay, logits):
    apply_rep_penalty(sequence, penalty_max, sustain, decay, logits)
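# Illustrative usage sketch (values are placeholders, not defaults from this repo): build a
# per-token penalty mask on the CPU from the token ids generated so far, or apply the penalty to
# the logits directly in place.
#
#   mask = ext_rep_penalty_mask_cpu(vocab_size, sequence_ids, 1.15, 256, 128)
#   ext_apply_rep_penalty_mask_cpu(sequence_ids, 1.15, 256, 128, logits)   # modifies logits in place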