Improve speed in combining per-channel data (#21563)

### Description
<!-- Describe your changes. -->
Speed up combining `per-channel` data by using a single
`np.concatenate` call instead of repeated `np.concatenate` calls inside
a for loop.
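
The difference between the two patterns can be sketched as follows. This is a minimal illustration, not the quantizer code itself: the helper names `combine_incremental` and `combine_once`, the chunk shapes, and the random data are all assumptions made for the example. The incremental version copies the accumulated array on every iteration, so the total copying grows roughly quadratically with the number of channels, while the single call copies each chunk once.

```python
import numpy as np


def combine_incremental(chunks, axis):
    """Old pattern: grow the result with one np.concatenate per chunk."""
    combined = chunks[0]
    for chunk in chunks[1:]:
        # Each call allocates a new array and copies everything accumulated so far.
        combined = np.concatenate((combined, chunk), axis)
    return combined


def combine_once(chunks, axis):
    """New pattern: one np.concatenate over the whole list."""
    return np.concatenate(chunks, axis)


# Hypothetical data: 16 per-channel chunks, each with size 1 along the channel axis.
rng = np.random.default_rng(0)
channel_axis = 1
chunks = [rng.integers(-128, 128, size=(4, 1, 3)).astype(np.int8) for _ in range(16)]

old = combine_incremental(chunks, channel_axis)
new = combine_once(chunks, channel_axis)

# Both patterns produce the same array; only the amount of copying differs.
assert old.shape == (4, 16, 3)
assert np.array_equal(old, new)
```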

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Fixes https://github.com/microsoft/onnxruntime/issues/21562

Signed-off-by: duansheng.liu <44742794+duanshengliu@users.noreply.github.com>
duanshengliu 2024-08-07 07:23:20 +08:00 committed by GitHub
Parent 4ad87ca2e1
Commit b95aa0563f
No known key found for this signature
GPG Key ID: B5690EEEBB952194
1 changed file: 5 additions and 9 deletions


@@ -418,6 +418,9 @@ class BaseQuantizer:
         zero_point_list = []
         scale_list = []
         quantized_per_channel_data_list = []
+        weights_shape = list(weights.shape)
+        reshape_dims = list(weights_shape)  # deep copy
+        reshape_dims[channel_axis] = 1  # only one per channel for reshape
         for i in range(channel_count):
             per_channel_data = weights.take(i, channel_axis)
             channel_override_index = i if i < num_channel_overrides else 0
@@ -460,17 +463,10 @@ class BaseQuantizer:
             zero_point_list.append(zero_point)
             scale_list.append(scale)
-            quantized_per_channel_data_list.append(quantized_per_channel_data)
+            quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims))
         # combine per_channel_data into one
-        weights_shape = list(weights.shape)
-        reshape_dims = list(weights_shape)  # deep copy
-        reshape_dims[channel_axis] = 1  # only one per channel for reshape
-        quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims)
-        for i in range(1, len(quantized_per_channel_data_list)):
-            channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims)
-            quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis)
+        quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis)
         q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
         zp_name = weight_name + "_zero_point"
         scale_name = weight_name + "_scale"
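
The role of `reshape_dims` in the change above can be illustrated with a small standalone sketch. It assumes a toy 3-D `weights` array and uses an identity round trip through a flat list as a stand-in for the real per-channel quantization, which is not reproduced here. `weights.take(i, channel_axis)` drops the channel axis, so each piece is reshaped back to size 1 along that axis before the single `np.concatenate`.

```python
import numpy as np

# Toy weights tensor; the shape and values are assumptions for illustration only.
weights = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
channel_axis = 1

# Same shape as weights, but with size 1 along the channel axis: [2, 1, 4].
reshape_dims = list(weights.shape)
reshape_dims[channel_axis] = 1

pieces = []
for i in range(weights.shape[channel_axis]):
    # take() removes the channel axis: shape (2, 4) here.
    per_channel_data = weights.take(i, channel_axis)
    # Stand-in for quantization: flatten to a plain list and back, values unchanged.
    flat = per_channel_data.flatten().tolist()
    # Restore the channel axis as a length-1 dimension so pieces can be concatenated.
    pieces.append(np.asarray(flat, dtype=weights.dtype).reshape(reshape_dims))

# One concatenate along the channel axis rebuilds the original layout.
rebuilt = np.concatenate(pieces, channel_axis)
assert rebuilt.shape == weights.shape
assert np.array_equal(rebuilt, weights)
```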