[NNVM][DARKNET]Logistic activation added (#1477)
This commit is contained in:
Родитель
5846775fba
Коммит
55a08deca0
|
@ -301,7 +301,9 @@ def _darknet_region(inputs, attrs):
|
|||
def _darknet_activations(inputs, attrs):
|
||||
"""Process the activation function."""
|
||||
act = _darknet_required_attr(attrs, 'activation')
|
||||
if ACTIVATION.RELU == act:
|
||||
if ACTIVATION.LOGISTIC == act:
|
||||
act_type = 'sigmoid'
|
||||
elif ACTIVATION.RELU == act:
|
||||
act_type = 'relu'
|
||||
elif ACTIVATION.TANH == act:
|
||||
act_type = 'tanh'
|
||||
|
@ -323,6 +325,9 @@ def _darknet_activations(inputs, attrs):
|
|||
sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
|
||||
elif act_type in ['elu']:
|
||||
sym = -1 * _sym.relu(1 - _sym.exp(*inputs)) + _sym.relu(*inputs)
|
||||
elif act_type in ['sigmoid']:
|
||||
op_name, new_attrs = act_type, {}
|
||||
sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
|
||||
else:
|
||||
_darknet_raise_not_supported('act_type: ' + act_type)
|
||||
return sym, None
|
||||
|
|
|
@ -324,6 +324,32 @@ def test_forward_rnn():
|
|||
test_rnn_forward(net)
|
||||
LIB.free_network(net)
|
||||
|
||||
def test_forward_activation_logistic():
|
||||
'''test logistic activation layer'''
|
||||
net = LIB.make_network(1)
|
||||
batch = 1
|
||||
h = 224
|
||||
w = 224
|
||||
c = 3
|
||||
n = 32
|
||||
groups = 1
|
||||
size = 3
|
||||
stride = 2
|
||||
padding = 0
|
||||
activation = 0
|
||||
batch_normalize = 0
|
||||
binary = 0
|
||||
xnor = 0
|
||||
adam = 0
|
||||
layer_1 = LIB.make_convolutional_layer(batch, h, w, c, n, groups, size, stride, padding,
|
||||
activation, batch_normalize, binary, xnor, adam)
|
||||
net.layers[0] = layer_1
|
||||
net.w = w
|
||||
net.h = h
|
||||
LIB.resize_network(net, net.w, net.h)
|
||||
test_forward(net)
|
||||
LIB.free_network(net)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_forward_resnet50()
|
||||
test_forward_alexnet()
|
||||
|
@ -342,3 +368,5 @@ if __name__ == '__main__':
|
|||
test_forward_reorg()
|
||||
test_forward_region()
|
||||
test_forward_elu()
|
||||
test_forward_rnn()
|
||||
test_forward_activation_logistic()
|
Загрузка…
Ссылка в новой задаче