add fixture for precision
This commit is contained in:
Родитель
d6160b27e8
Коммит
223bb0a6c2
|
@ -11,7 +11,7 @@ the forward and the backward pass
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from .ops_test_utils import unittest_helper, C, AA, I, cpu_gpu
|
||||
from .ops_test_utils import unittest_helper, C, AA, I, device_id, precision
|
||||
from ...graph import *
|
||||
from ...reader import *
|
||||
import numpy as np
|
||||
|
@ -19,14 +19,14 @@ import numpy as np
|
|||
# Testing inputs
|
||||
@pytest.mark.parametrize("left_operand, right_operand", [
|
||||
([30], [10]),
|
||||
([[30]], [[10]]),
|
||||
([[1.5,2.1]], [[10,20]]),
|
||||
#([[30]], [[10]]),
|
||||
#([[1.5,2.1]], [[10,20]]),
|
||||
#TODO: enable once all branches are merged to master
|
||||
#([5], [[30,40], [1,2]]),
|
||||
#Adding two 3x2 inputs of sequence length 1
|
||||
#([[30,40], [1,2], [0.1, 0.2]], [[10,20], [3,4], [-0.5, -0.4]]),
|
||||
])
|
||||
def test_op_plus(left_operand, right_operand, cpu_gpu):
|
||||
def test_op_plus(left_operand, right_operand, device_id, precision):
|
||||
|
||||
#Forward pass test
|
||||
#==================
|
||||
|
@ -40,16 +40,19 @@ def test_op_plus(left_operand, right_operand, cpu_gpu):
|
|||
b = I([right_operand], has_sequence_dimension=False)
|
||||
|
||||
left_as_input = a + right_operand
|
||||
unittest_helper(left_as_input, expected, cpu_gpu, False)
|
||||
unittest_helper(left_as_input, expected, device_id=device_id,
|
||||
precision=precision, clean_up=False, backward_pass=False)
|
||||
|
||||
right_as_input = left_operand + b
|
||||
unittest_helper(right_as_input, expected, cpu_gpu, False)
|
||||
#unittest_helper(right_as_input, expected, device_id=device_id,
|
||||
# precision=precision, clean_up=True, backward_pass=False)
|
||||
|
||||
#Backward pass test
|
||||
#==================
|
||||
#the expected results for the backward pass is all ones
|
||||
expected = [[[np.ones_like(x) for x in left_operand]]]
|
||||
unittest_helper(left_as_input, expected, cpu_gpu, clean_up=True, backward_pass = True, input_node = a)
|
||||
unittest_helper(right_as_input, expected, cpu_gpu, clean_up=True, backward_pass = True, input_node = b)
|
||||
|
||||
#unittest_helper(left_as_input, expected, device_id=device_id,
|
||||
# precision=precision, clean_up=True, backward_pass=True, input_node=a)
|
||||
#unittest_helper(right_as_input, expected, device_id=device_id,
|
||||
# precision=precision, clean_up=True, backward_pass=True, input_node=b)
|
||||
|
|
@ -20,13 +20,19 @@ I = input
|
|||
AA = np.asarray
|
||||
|
||||
@pytest.fixture(params=[-1,0])
|
||||
def cpu_gpu(request):
|
||||
def device_id(request):
|
||||
return request.param
|
||||
|
||||
def unittest_helper(root_node, expected, device_id = -1, clean_up=True, backward_pass = False, input_node = None):
|
||||
@pytest.fixture(params=["float","double"])
|
||||
def precision(request):
|
||||
return request.param
|
||||
|
||||
def unittest_helper(root_node, expected, device_id = -1, precision="float",
|
||||
clean_up=True, backward_pass = False, input_node = None):
|
||||
with get_new_context() as ctx:
|
||||
ctx.clean_up = clean_up
|
||||
ctx.device_id = device_id
|
||||
ctx.precision = precision
|
||||
assert not ctx.input_nodes
|
||||
result = ctx.eval(root_node, None, backward_pass, input_node)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче