[RELAY] Add missing arg in vgg (#2329)
This commit is contained in:
Родитель
14acb80adc
Коммит
dc8fd79c43
|
@ -98,7 +98,8 @@ def get_workload(batch_size,
|
|||
num_classes=1000,
|
||||
image_shape=(3, 224, 224),
|
||||
dtype="float32",
|
||||
num_layers=11):
|
||||
num_layers=11,
|
||||
batch_norm=False):
|
||||
"""Get benchmark workload for VGG nets.
|
||||
|
||||
Parameters
|
||||
|
@ -118,6 +119,9 @@ def get_workload(batch_size,
|
|||
num_layers : int
|
||||
Number of layers for the variant of vgg. Options are 11, 13, 16, 19.
|
||||
|
||||
batch_norm : bool
|
||||
Use batch normalization.
|
||||
|
||||
Returns
|
||||
-------
|
||||
net : nnvm.Symbol
|
||||
|
@ -126,5 +130,5 @@ def get_workload(batch_size,
|
|||
params : dict of str to NDArray
|
||||
The parameters.
|
||||
"""
|
||||
net = get_net(batch_size, image_shape, num_classes, dtype, num_layers)
|
||||
net = get_net(batch_size, image_shape, num_classes, dtype, num_layers, batch_norm)
|
||||
return create_workload(net)
|
||||
|
|
Загрузка…
Ссылка в новой задаче