aredden committed on
Commit 56c313c
1 Parent(s): 9e376d8

Fix issue where cublas linear not being installed caused a TypeError

Files changed (2)
  1. float8_quantize.py +2 -2
  2. lora_loading.py +1 -1
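
Both changes replace metaclass-based isinstance checks with a direct comparison against type(None). The guards only make sense given an optional import of CublasLinear; below is a minimal sketch of that pattern, assuming the module is named cublas_ops and that the missing-import fallback is type(None) — both inferred from the new checks, not shown in this diff.

# Hedged sketch: optional import with a placeholder fallback.
# The module name `cublas_ops` and the `type(None)` placeholder are
# assumptions inferred from the `CublasLinear == type(None)` checks
# introduced in this commit.
try:
    from cublas_ops import CublasLinear  # may be absent on many systems
except ImportError:
    CublasLinear = type(None)  # placeholder so isinstance()/== checks still have a type to test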
float8_quantize.py CHANGED
@@ -371,7 +371,7 @@ def recursive_swap_linears(
 
 @torch.inference_mode()
 def swap_to_cublaslinear(model: nn.Module):
-    if not isinstance(CublasLinear, type(torch.nn.Module)):
+    if CublasLinear == type(None):
         return
     for name, child in model.named_children():
         if isinstance(child, nn.Linear) and not isinstance(
@@ -485,7 +485,7 @@ def quantize_flow_transformer_and_dispatch_float8(
     if (
         swap_linears_with_cublaslinear
         and flow_dtype == torch.float16
-        and isinstance(CublasLinear, type(torch.nn.Linear))
+        and CublasLinear != type(None)
     ):
         swap_to_cublaslinear(flow_model)
     elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
lora_loading.py CHANGED
@@ -626,7 +626,7 @@ def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear])
         )
     elif isinstance(linear, torch.nn.Linear):
         weight = linear.weight.clone().detach().float()
-    elif isinstance(linear, CublasLinear):
+    elif isinstance(linear, CublasLinear) and CublasLinear != type(None):
         weight = linear.weight.clone().detach().float()
     return weight, weight_is_f8, dtype
 
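
The lora_loading.py change applies the same idea to the weight-extraction dispatch: the extra CublasLinear != type(None) clause keeps the CublasLinear branch inert when only the placeholder is present. A hedged, self-contained sketch of that dispatch under the same assumed fallback (simplified: the real function also handles F8Linear and returns weight, weight_is_f8, dtype):

import torch.nn as nn

CublasLinear = type(None)  # assumed fallback when cublas_ops is missing

def extract_weight(linear: nn.Module):
    # Simplified sketch of the dispatch order after this commit; the F8Linear
    # branch and the (weight, weight_is_f8, dtype) return shape are omitted.
    if isinstance(linear, nn.Linear):
        return linear.weight.clone().detach().float()
    elif isinstance(linear, CublasLinear) and CublasLinear != type(None):
        return linear.weight.clone().detach().float()
    raise TypeError(f"Unsupported linear type: {type(linear)}")

print(extract_weight(nn.Linear(8, 4)).shape)  # torch.Size([4, 8])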