[AutoTVM] fix argument type for curve feature (#3004)
This commit is contained in:
Родитель
5178506255
Коммит
5a27632e27
|
@ -514,10 +514,10 @@ TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten")
|
|||
TVM_REGISTER_API("autotvm.feature.GetCurveSampleFeatureFlatten")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
Stmt stmt = args[0];
|
||||
bool take_log = args[1];
|
||||
int sample_n = args[1];
|
||||
std::vector<float> ret_feature;
|
||||
|
||||
GetCurveSampleFeatureFlatten(stmt, take_log, &ret_feature);
|
||||
GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature);
|
||||
|
||||
TVMByteArray arr;
|
||||
arr.size = sizeof(float) * ret_feature.size();
|
||||
|
|
|
@ -61,6 +61,23 @@ def test_iter_feature_gemm():
|
|||
assert ans[pair[0]] == pair[1:], "%s: %s vs %s" % (pair[0], ans[pair[0]], pair[1:])
|
||||
|
||||
|
||||
def test_curve_feature_gemm():
|
||||
N = 128
|
||||
|
||||
k = tvm.reduce_axis((0, N), 'k')
|
||||
A = tvm.placeholder((N, N), name='A')
|
||||
B = tvm.placeholder((N, N), name='B')
|
||||
C = tvm.compute(
|
||||
A.shape,
|
||||
lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
|
||||
name='C')
|
||||
|
||||
s = tvm.create_schedule(C.op)
|
||||
|
||||
feas = feature.get_buffer_curve_sample_flatten(s, [A, B, C], sample_n=30)
|
||||
# sample_n * #buffers * #curves * 2 numbers per curve
|
||||
assert len(feas) == 30 * 3 * 4 * 2
|
||||
|
||||
def test_feature_shape():
|
||||
"""test the dimensions of flatten feature are the same"""
|
||||
|
||||
|
@ -112,4 +129,6 @@ def test_feature_shape():
|
|||
|
||||
if __name__ == "__main__":
|
||||
test_iter_feature_gemm()
|
||||
test_curve_feature_gemm()
|
||||
test_feature_shape()
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче