[AutoTVM] fix argument type for curve feature (#3004)

This commit is contained in:
Lianmin Zheng 2019-04-11 10:58:54 +08:00 коммит произвёл GitHub
Родитель 5178506255
Коммит 5a27632e27
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 21 добавлений и 2 удалений

Просмотреть файл

@ -514,10 +514,10 @@ TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten")
TVM_REGISTER_API("autotvm.feature.GetCurveSampleFeatureFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Stmt stmt = args[0];
bool take_log = args[1];
int sample_n = args[1];
std::vector<float> ret_feature;
GetCurveSampleFeatureFlatten(stmt, take_log, &ret_feature);
GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature);
TVMByteArray arr;
arr.size = sizeof(float) * ret_feature.size();

Просмотреть файл

@ -61,6 +61,23 @@ def test_iter_feature_gemm():
assert ans[pair[0]] == pair[1:], "%s: %s vs %s" % (pair[0], ans[pair[0]], pair[1:])
def test_curve_feature_gemm():
N = 128
k = tvm.reduce_axis((0, N), 'k')
A = tvm.placeholder((N, N), name='A')
B = tvm.placeholder((N, N), name='B')
C = tvm.compute(
A.shape,
lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
name='C')
s = tvm.create_schedule(C.op)
feas = feature.get_buffer_curve_sample_flatten(s, [A, B, C], sample_n=30)
# sample_n * #buffers * #curves * 2 numbers per curve
assert len(feas) == 30 * 3 * 4 * 2
def test_feature_shape():
"""test the dimensions of flatten feature are the same"""
@ -112,4 +129,6 @@ def test_feature_shape():
if __name__ == "__main__":
test_iter_feature_gemm()
test_curve_feature_gemm()
test_feature_shape()