Improve speed in combining per-channel data (#21563)
### Description <!-- Describe your changes. --> Improve the speed of combining `per-channel` data by using a single `np.concatenate` call instead of multiple `np.concatenate` calls within a for loop. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Fixes https://github.com/microsoft/onnxruntime/issues/21562 Signed-off-by: duansheng.liu <44742794+duanshengliu@users.noreply.github.com>
This commit is contained in:
Родитель
4ad87ca2e1
Коммит
b95aa0563f
|
@ -418,6 +418,9 @@ class BaseQuantizer:
|
|||
zero_point_list = []
|
||||
scale_list = []
|
||||
quantized_per_channel_data_list = []
|
||||
weights_shape = list(weights.shape)
|
||||
reshape_dims = list(weights_shape) # deep copy
|
||||
reshape_dims[channel_axis] = 1 # only one per channel for reshape
|
||||
for i in range(channel_count):
|
||||
per_channel_data = weights.take(i, channel_axis)
|
||||
channel_override_index = i if i < num_channel_overrides else 0
|
||||
|
@ -460,17 +463,10 @@ class BaseQuantizer:
|
|||
|
||||
zero_point_list.append(zero_point)
|
||||
scale_list.append(scale)
|
||||
quantized_per_channel_data_list.append(quantized_per_channel_data)
|
||||
quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims))
|
||||
|
||||
# combine per_channel_data into one
|
||||
weights_shape = list(weights.shape)
|
||||
reshape_dims = list(weights_shape) # deep copy
|
||||
reshape_dims[channel_axis] = 1 # only one per channel for reshape
|
||||
quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims)
|
||||
for i in range(1, len(quantized_per_channel_data_list)):
|
||||
channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims)
|
||||
quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis)
|
||||
|
||||
quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis)
|
||||
q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
|
||||
zp_name = weight_name + "_zero_point"
|
||||
scale_name = weight_name + "_scale"
|
||||
|
|
Загрузка…
Ссылка в новой задаче