update WeightDtype params
Browse files- src/display/utils.py +7 -1
src/display/utils.py
CHANGED
@@ -237,19 +237,25 @@ class WeightDtype(Enum):
|
|
237 |
nf4 = ModelDetails("nf4")
|
238 |
fp4 = ModelDetails("fp4")
|
239 |
|
|
|
240 |
Unknown = ModelDetails("?")
|
241 |
|
|
|
|
|
|
|
242 |
def from_str(weight_dtype):
|
243 |
if weight_dtype in ["int2"]:
|
244 |
return WeightDtype.int2
|
245 |
if weight_dtype in ["int3"]:
|
246 |
-
return WeightDtype.int3
|
247 |
if weight_dtype in ["int4"]:
|
248 |
return WeightDtype.int4
|
249 |
if weight_dtype in ["nf4"]:
|
250 |
return WeightDtype.nf4
|
251 |
if weight_dtype in ["fp4"]:
|
252 |
return WeightDtype.fp4
|
|
|
|
|
253 |
return WeightDtype.Unknown
|
254 |
|
255 |
class ComputeDtype(Enum):
|
|
|
237 |
nf4 = ModelDetails("nf4")
|
238 |
fp4 = ModelDetails("fp4")
|
239 |
|
240 |
+
|
241 |
Unknown = ModelDetails("?")
|
242 |
|
243 |
+
all = ModelDetails("All")
|
244 |
+
|
245 |
+
|
246 |
def from_str(weight_dtype):
|
247 |
if weight_dtype in ["int2"]:
|
248 |
return WeightDtype.int2
|
249 |
if weight_dtype in ["int3"]:
|
250 |
+
return WeightDtype.int3
|
251 |
if weight_dtype in ["int4"]:
|
252 |
return WeightDtype.int4
|
253 |
if weight_dtype in ["nf4"]:
|
254 |
return WeightDtype.nf4
|
255 |
if weight_dtype in ["fp4"]:
|
256 |
return WeightDtype.fp4
|
257 |
+
if weight_dtype in ["All"]:
|
258 |
+
return WeightDtype.all
|
259 |
return WeightDtype.Unknown
|
260 |
|
261 |
class ComputeDtype(Enum):
|