Andrei Panferov committed
Commit f1a2023
Parent: 115e749

slightly faster inference

Files changed (1):
inference.py: +10 -8
inference.py CHANGED
@@ -161,6 +161,7 @@ def forward_pass_quantized_linear(
         "num_input_groups",
         "num_input_groups_next_power_of_2",
         "compute_in_fp32",
+        "has_bias",
     ],
 )
 @triton.jit
@@ -180,6 +181,7 @@ def _aqlm_gemv_simple(
     num_input_groups: tl.constexpr,
     num_input_groups_next_power_of_2: tl.constexpr,
     compute_in_fp32: tl.constexpr,
+    has_bias: tl.constexpr,
     UNUSED: tl.constexpr,
 ):
     # variables ending with "_i" mean "for i-th output unit"
@@ -188,11 +190,11 @@ def _aqlm_gemv_simple(
     # Stage 1: load input data
     input_vec = tl.load(
         input_vec_ptr
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
-        + tl.arange(0, in_group_size)[None, None, :],
-        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
+        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size
+        + tl.arange(0, in_group_size)[None, None, None, :],
+        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups,
     )
-    # [in_features//in_group_size, 1, group_size]
+    # [in_features//in_group_size, 1, 1, group_size]
     # Note: we could simply load input_vec then reshape
     # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features]
     # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
@@ -237,19 +239,17 @@ def _aqlm_gemv_simple(
     weights_i = weights_i.to(tl.float32)
     input_vec = input_vec.to(tl.float32)
     # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-    weights_i = tl.sum(weights_i, axis=1)  # sum codebooks as per additive quantization
-    # ^-- [in_features // in_group_size, out_group_size, in_group_size]

     if out_group_size == 1:
         scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
         output_i = tl.sum(weights_i * input_vec) * scale
-        if bias_ptr:
+        if has_bias:
             output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
         tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
     else:
         output_i = tl.sum(tl.sum(weights_i, axis=2) * input_vec, axis=0)  # [out_group_size]
         output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
-        if bias_ptr:
+        if has_bias:
             output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
         tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))

@@ -296,6 +296,7 @@ def aqlm_gemv_simple(
         num_input_groups,
         next_power_of_2(num_input_groups),
         compute_in_fp32,
+        bias is not None,
     )

     return output_vec
@@ -339,6 +340,7 @@ def aqlm_gemm_stupid(
         num_input_groups,
         next_power_of_2(num_input_groups),
         compute_in_fp32,
+        bias is not None,
     )

     return output
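Why the `has_bias` change helps: a `tl.constexpr` argument is part of Triton's specialization key, so each value of the flag compiles its own kernel variant and the bias-free variant contains no bias code at all, instead of deciding based on the `bias_ptr` argument inside the kernel. Adding "has_bias" to the autotune key keeps the cached configurations for the two variants separate. A minimal standalone sketch of the pattern, not the AQLM kernel itself (all names here are illustrative; bias is assumed to have the same shape as x):

import torch
import triton
import triton.language as tl

@triton.jit
def _maybe_add_bias(x_ptr, bias_ptr, out_ptr, n_elements, BLOCK: tl.constexpr, HAS_BIAS: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    if HAS_BIAS:  # resolved at compile time; the HAS_BIAS=False variant has no branch
        x += tl.load(bias_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)

def maybe_add_bias(x, bias=None):
    out = torch.empty_like(x)
    n = x.numel()
    # pass x itself as a dummy pointer when there is no bias; the constexpr
    # flag guarantees the kernel never reads it
    _maybe_add_bias[(triton.cdiv(n, 1024),)](
        x, bias if bias is not None else x, out, n, BLOCK=1024, HAS_BIAS=bias is not None
    )
    return out

This mirrors the updated call sites above, which now pass `bias is not None` into the kernel rather than relying on a test of `bias_ptr` at run time.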
 
161
  "num_input_groups",
162
  "num_input_groups_next_power_of_2",
163
  "compute_in_fp32",
164
+ "has_bias",
165
  ],
166
  )
167
  @triton.jit
 
181
  num_input_groups: tl.constexpr,
182
  num_input_groups_next_power_of_2: tl.constexpr,
183
  compute_in_fp32: tl.constexpr,
184
+ has_bias: tl.constexpr,
185
  UNUSED: tl.constexpr,
186
  ):
187
  # variables ending with "_i" mean "for i-th output unit"
 
190
  # Stage 1: load input data
191
  input_vec = tl.load(
192
  input_vec_ptr
193
+ + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] * in_group_size
194
+ + tl.arange(0, in_group_size)[None, None, None, :],
195
+ mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None, None] < num_input_groups,
196
  )
197
+ # [in_features//in_group_size, 1, 1, group_size]
198
  # Note: we could simply load input_vec then reshape
199
  # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features]
200
  # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
 
239
  weights_i = weights_i.to(tl.float32)
240
  input_vec = input_vec.to(tl.float32)
241
  # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
 
 
242
 
243
  if out_group_size == 1:
244
  scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar
245
  output_i = tl.sum(weights_i * input_vec) * scale
246
+ if has_bias:
247
  output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
248
  tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
249
  else:
250
  output_i = tl.sum(tl.sum(weights_i, axis=2) * input_vec, axis=0) # [out_group_size]
251
  output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
252
+ if has_bias:
253
  output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
254
  tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))
255
 
 
296
  num_input_groups,
297
  next_power_of_2(num_input_groups),
298
  compute_in_fp32,
299
+ bias is not None,
300
  )
301
 
302
  return output_vec
 
340
  num_input_groups,
341
  next_power_of_2(num_input_groups),
342
  compute_in_fp32,
343
+ bias is not None,
344
  )
345
 
346
  return output
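The input-loading change pairs with the deleted `weights_i = tl.sum(weights_i, axis=1)`: instead of materializing the codebook sum as a separate reduction, the input is now loaded with an extra singleton axis ([in_features//in_group_size, 1, 1, in_group_size]) so it broadcasts across the codebook dimension, and the existing full `tl.sum` folds the codebook sum into the dot product. A quick PyTorch check that the two formulations agree (sizes are arbitrary; shapes follow the kernel's comments):

import torch

groups, codebooks, out_group_size, in_group_size = 8, 2, 1, 4
weights = torch.randn(groups, codebooks, out_group_size, in_group_size)
x = torch.randn(groups, in_group_size)

# old path: reduce the codebook axis first, then multiply by a 3-D input view
old = (weights.sum(dim=1) * x[:, None, :]).sum()
# new path: broadcast a 4-D input view and let one reduction do both sums
new = (weights * x[:, None, None, :]).sum()
assert torch.allclose(old, new, atol=1e-5)

By distributivity, summing w * x over groups, codebooks, and the in-group axis at once equals summing (sum over codebooks of w) * x, so the fused form computes the same value while skipping one intermediate tensor.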