Fix TypeError when CublasLinear is not installed
- float8_quantize.py +2 -2
- lora_loading.py +1 -1
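
Both files now compare CublasLinear against type(None), which presumes the optional import binds that sentinel when the package is missing. A minimal sketch of the assumed pattern (the cublas_ops module path is an assumption for illustration, not shown in this diff):

    try:
        # Optional dependency providing CublasLinear (import path assumed).
        from cublas_ops import CublasLinear
    except ImportError:
        # Sentinel used when the package is absent; the guards in this commit
        # detect it via `CublasLinear == type(None)`.
        CublasLinear = type(None)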
float8_quantize.py

@@ -371,7 +371,7 @@ def recursive_swap_linears(
 
 @torch.inference_mode()
 def swap_to_cublaslinear(model: nn.Module):
-    if
+    if CublasLinear == type(None):
         return
     for name, child in model.named_children():
         if isinstance(child, nn.Linear) and not isinstance(
@@ -485,7 +485,7 @@ def quantize_flow_transformer_and_dispatch_float8(
     if (
         swap_linears_with_cublaslinear
         and flow_dtype == torch.float16
-        and
+        and CublasLinear != type(None)
     ):
         swap_to_cublaslinear(flow_model)
     elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
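
With that sentinel in place, the new check makes swap_to_cublaslinear a no-op when the dependency is missing instead of attempting to swap layers to a class that does not exist. A rough, self-contained illustration of the guarded early return (everything other than swap_to_cublaslinear's signature is a stand-in):

    import torch
    import torch.nn as nn

    CublasLinear = type(None)  # simulate the not-installed fallback


    @torch.inference_mode()
    def swap_to_cublaslinear(model: nn.Module):
        # Early return: nothing to swap when the optional dependency is absent.
        if CublasLinear == type(None):
            return
        # ...the real swapping loop from float8_quantize.py would run here...


    model = nn.Sequential(nn.Linear(8, 8))
    swap_to_cublaslinear(model)  # returns immediately; no TypeError is raised
    print(type(model[0]))        # still torch.nn.Linear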
lora_loading.py

@@ -626,7 +626,7 @@ def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear])
         )
     elif isinstance(linear, torch.nn.Linear):
         weight = linear.weight.clone().detach().float()
-    elif isinstance(linear, CublasLinear):
+    elif isinstance(linear, CublasLinear) and CublasLinear != type(None):
         weight = linear.weight.clone().detach().float()
     return weight, weight_is_f8, dtype
 
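The lora_loading.py branch gains the same sentinel check before a layer is treated as a CublasLinear. A small illustration of how the combined condition behaves in each install state (FakeCublasLinear is a hypothetical stand-in for the real class):

    import torch.nn as nn

    linear = nn.Linear(4, 4)

    # Package missing: CublasLinear is the type(None) sentinel, so the branch
    # is skipped for any real layer.
    CublasLinear = type(None)
    print(isinstance(linear, CublasLinear) and CublasLinear != type(None))  # False


    class FakeCublasLinear(nn.Linear):
        """Hypothetical stand-in for the real CublasLinear class."""


    # Package installed: a genuine CublasLinear instance still takes the branch.
    CublasLinear = FakeCublasLinear
    cublas_layer = FakeCublasLinear(4, 4)
    print(isinstance(cublas_layer, CublasLinear) and CublasLinear != type(None))  # True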