* fix dynamic backend slot number bug, set snapshot default value to 0 to align with static backend, update related tests

* remove supply chain scenario

* correct the scenario number
This commit is contained in:
Chaos Yu 2021-04-13 14:23:33 +08:00 коммит произвёл GitHub
Родитель 0f380f44f9
Коммит 803c3fc2ee
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
13 изменённых файлов: 124 добавлений и 36 удалений

Просмотреть файл

@ -233,6 +233,8 @@ cdef class RawBackend(BackendAbc):
cdef list slots_not_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +:
return self.where(index, attr_type, lambda x : x != value)
cdef SLOT_INDEX get_slot_number(self, NODE_INDEX index, ATTR_TYPE attr_type) except +:
return self._frame.get_slot_number(index, attr_type)
cdef class RawSnapshotList(SnapshotListAbc):
def __cinit__(self, RawBackend backend, USHORT total_snapshots):

Просмотреть файл

@ -139,7 +139,7 @@ cdef class BackendAbc:
# Filter slots that greater than specified value.
cdef list slots_greater_than(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
# Filter slots that greater equeal to specified value.
# Filter slots that greater equal to specified value.
cdef list slots_greater_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
# Filter slots that less than specified value.
@ -151,5 +151,8 @@ cdef class BackendAbc:
# Filter slots that equal to specified value.
cdef list slots_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
# Filter slots that not euqal to specified value.
# Filter slots that not equal to specified value.
cdef list slots_not_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
# Get slot number for specified attribute, only support dynamic backend.
cdef SLOT_INDEX get_slot_number(self, NODE_INDEX index, ATTR_TYPE attr_type) except +

Просмотреть файл

@ -126,3 +126,6 @@ cdef class BackendAbc:
cdef list slots_not_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +:
pass
cdef SLOT_INDEX get_slot_number(self, NODE_INDEX index, ATTR_TYPE attr_type) except +:
pass

Просмотреть файл

@ -142,14 +142,17 @@ cdef class _NodeAttributeAccessor:
Current attribute must be a list.
Args:
value(object): Value to append, the data type must fit the decleared one.
value(object): Value to append, the data type must fit the declared one.
"""
if not self._is_list:
raise BackendsAppendToNonListAttributeException()
self._backend.append_to_list(self._node_index, self._attr_type, value)
self._slot_number += 1
self._slot_number = self._backend.get_slot_number(self._node_index, self._attr_type)
if "_cb" in self.__dict__:
self._cb(None)
def resize(self, new_size: int):
"""Resize current list attribute with specified new size.
@ -165,7 +168,10 @@ cdef class _NodeAttributeAccessor:
self._backend.resize_list(self._node_index, self._attr_type, new_size)
self._slot_number = new_size
self._slot_number = self._backend.get_slot_number(self._node_index, self._attr_type)
if "_cb" in self.__dict__:
self._cb(None)
def clear(self):
"""Clear all items in current list attribute.
@ -180,6 +186,9 @@ cdef class _NodeAttributeAccessor:
self._slot_number = 0
if "_cb" in self.__dict__:
self._cb(None)
def insert(self, slot_index: int, value: object):
"""Insert a value to specified slot.
@ -192,7 +201,10 @@ cdef class _NodeAttributeAccessor:
self._backend.insert_to_list(self._node_index, self._attr_type, slot_index, value)
self._slot_number += 1
self._slot_number = self._backend.get_slot_number(self._node_index, self._attr_type)
if "_cb" in self.__dict__:
self._cb(None)
def remove(self, slot_index: int):
"""Remove specified slot.
@ -205,7 +217,10 @@ cdef class _NodeAttributeAccessor:
self._backend.remove_from_list(self._node_index, self._attr_type, slot_index)
self._slot_number -= 1
self._slot_number = self._backend.get_slot_number(self._node_index, self._attr_type)
if "_cb" in self.__dict__:
self._cb(None)
def where(self, filter_func: callable):
"""Filter current attribute slots with input function.
@ -214,7 +229,7 @@ cdef class _NodeAttributeAccessor:
filter_func (callable): Function to filter slot value.
Returns:
List[int]: List of slot index whoes value match the filter function.
List[int]: List of slot index whose value match the filter function.
"""
return self._backend.where(self._node_index, self._attr_type, filter_func)
@ -345,12 +360,12 @@ cdef class _NodeAttributeAccessor:
else:
raise BackendsSetItemInvalidException()
# Check and invoke value changed callback, except list attribute.
if not self._is_list and "_cb" in self.__dict__:
# Check and invoke value changed callback.
if "_cb" in self.__dict__:
self._cb(value)
def __len__(self):
return self._slot_number
return self._backend.get_slot_number(self._node_index, self._attr_type)
def on_value_changed(self, cb):
"""Set the value changed callback."""
@ -386,7 +401,9 @@ cdef class NodeBase:
cdef str cb_name
cdef _NodeAttributeAccessor attr_acc
for name, attr in type(self).__dict__.items():
for name in dir(type(self)):
attr = getattr(self, name)
# Append an attribute access wrapper to current instance.
if isinstance(attr, NodeAttribute):
# Register attribute.
@ -399,12 +416,12 @@ cdef class NodeBase:
# Bind a value changed callback if available, named as _on_<attr name>_changed.
# Except list attribute.
if not attr_acc._is_list:
cb_name = f"_on_{name}_changed"
cb_func = getattr(self, cb_name, None)
# if not attr_acc._is_list:
cb_name = f"_on_{name}_changed"
cb_func = getattr(self, cb_name, None)
if cb_func is not None:
attr_acc.on_value_changed(cb_func)
if cb_func is not None:
attr_acc.on_value_changed(cb_func)
def __setattr__(self, name, value):
"""Used to avoid attribute overriding, and an easy way to set for 1 slot attribute."""
@ -633,7 +650,9 @@ cdef class FrameBase:
cdef NODE_INDEX i
# Register node and attribute in backend.
for frame_attr_name, frame_attr in type(self).__dict__.items():
for frame_attr_name in dir(type(self)):
frame_attr = getattr(self, frame_attr_name)
# We only care about FrameNode instance.
if isinstance(frame_attr, FrameNode):
node_cls = frame_attr._node_cls
@ -656,8 +675,10 @@ cdef class FrameBase:
attr_name_type_dict = {}
# Register attributes.
for node_attr_name, node_attr in node_cls.__dict__.items():
if isinstance(node_attr, NodeAttribute):
for node_attr_name in dir(node_cls):
node_attr = getattr(node_cls, node_attr_name)
if node_attr and isinstance(node_attr, NodeAttribute):
attr_type = self._backend.add_attr(node_type, node_attr_name, node_attr._dtype, node_attr._slot_number, node_attr._is_const, node_attr._is_list)
attr_name_type_dict[node_attr_name] = attr_type

Просмотреть файл

@ -271,6 +271,25 @@ namespace maro
list.clear();
}
UINT list_index = 0;
for (auto& attr_def : _attribute_definitions)
{
if (attr_def.is_list)
{
// Assign each attribute with the index of actual list.
for (NODE_INDEX i = 0; i < _defined_node_number; i++)
{
auto& target_attr = _dynamic_block[_dynamic_size_per_node * i + attr_def.offset];
// Save the index of list in list store.
target_attr = list_index;
list_index++;
}
}
}
// Reset bitset masks.
_node_instance_masks.resize(_defined_node_number);
_node_instance_masks.reset(true);
@ -413,7 +432,7 @@ namespace maro
const auto list_index = target_attr.get_value<ATTR_UINT>();
// Then get the actual list reference for furthure operation.
// Then get the actual list reference for further operation.
auto& target_list = _list_store[list_index];
return target_list[slot_index];

Просмотреть файл

@ -6,6 +6,7 @@
#include <vector>
#include <string>
#include <iostream>
#include "common.h"
#include "attribute.h"

Просмотреть файл

@ -131,7 +131,10 @@ namespace maro
}
shape.attr_number++;
shape.max_slot_number = max(attr_def.slot_number, shape.max_slot_number);
shape.max_slot_number = attr_def.slot_number > shape.max_slot_number ? attr_def.slot_number : shape.max_slot_number;
shape.max_slot_number = MAX(attr_def.slot_number, shape.max_slot_number); //std::max(attr_def.slot_number, shape.max_slot_number);
}
}
else
@ -165,7 +168,7 @@ namespace maro
if (target_tick_pair == _snapshots.end())
{
throw SnapshotQueryNoSnapshotsError();
throw SnapshotQueryInvalidTickError();
}
auto& snapshot = target_tick_pair->second;

Просмотреть файл

@ -23,6 +23,8 @@ namespace maro
{
namespace raw
{
#define MAX(a, b) a > b ? a : b
/// <summary>
/// Shape of current querying.
/// </summary>

Просмотреть файл

@ -233,6 +233,8 @@ cdef class RawBackend(BackendAbc):
cdef list slots_not_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +:
return self.where(index, attr_type, lambda x : x != value)
cdef SLOT_INDEX get_slot_number(self, NODE_INDEX index, ATTR_TYPE attr_type) except +:
return self._frame.get_slot_number(index, attr_type)
cdef class RawSnapshotList(SnapshotListAbc):
def __cinit__(self, RawBackend backend, USHORT total_snapshots):
@ -289,7 +291,7 @@ cdef class RawSnapshotList(SnapshotListAbc):
cdef QUERY_FLOAT[:, :, :, :] result = view.array(shape=(shape.tick_number, shape.max_node_number, shape.attr_number, shape.max_slot_number), itemsize=sizeof(QUERY_FLOAT), format="f")
# Default result value
result[:, :, :, :] = np.nan
result[:, :, :, :] = 0
# Do query
self._snapshots.query(&result[0][0][0][0])

Просмотреть файл

@ -3,6 +3,7 @@
import io
import os
import sys
import numpy
# NOTE: DO NOT change the import order, as sometimes there is a conflict between setuptools and distutils,
@ -14,6 +15,11 @@ from distutils.extension import Extension
from maro import __version__
compile_flag = '-std=c++11'
if sys.platform == "win32":
compile_flag = '/std:c++14'
# Set environment variable to skip deployment process of MARO
os.environ["SKIP_DEPLOYMENT"] = "TRUE"
@ -39,7 +45,7 @@ extensions.append(
Extension(
f"{BASE_MODULE_NAME}.backend",
sources=[f"{BASE_SRC_PATH}/backend.cpp"],
extra_compile_args=['-std=c++11'])
extra_compile_args=[compile_flag])
)
@ -50,7 +56,7 @@ extensions.append(
f"{BASE_MODULE_NAME}.np_backend",
sources=[f"{BASE_SRC_PATH}/np_backend.cpp"],
include_dirs=include_dirs,
extra_compile_args=['-std=c++11'])
extra_compile_args=[compile_flag])
)
# raw implementation
@ -60,7 +66,7 @@ extensions.append(
f"{BASE_MODULE_NAME}.raw_backend",
sources=[f"{BASE_SRC_PATH}/raw_backend.cpp"],
include_dirs=include_dirs,
extra_compile_args=['-std=c++11'])
extra_compile_args=[compile_flag])
)
# frame
@ -69,7 +75,7 @@ extensions.append(
f"{BASE_MODULE_NAME}.frame",
sources=[f"{BASE_SRC_PATH}/frame.cpp"],
include_dirs=include_dirs,
extra_compile_args=['-std=c++11'])
extra_compile_args=[compile_flag])
)

Просмотреть файл

@ -130,7 +130,7 @@ class TestEnv(unittest.TestCase):
vals_after_reset = dummies_ss[env.frame_index::"val"]
if backend_name == "dynamic":
self.assertTrue(np.isnan(vals_after_reset).all())
self.assertTrue((vals_after_reset == 0).all())
else:
self.assertListEqual(list(vals_after_reset.flatten()), [
0]*dummy_number, msg=f"we should have padding values")
@ -271,7 +271,7 @@ class TestEnv(unittest.TestCase):
def test_get_avaiable_envs(self):
scenario_names = get_scenarios()
# we have 2 built-in scenarios
# we have 3 built-in scenarios
self.assertEqual(3, len(scenario_names))
self.assertTrue("cim" in scenario_names)
@ -293,5 +293,6 @@ class TestEnv(unittest.TestCase):
self.assertListEqual([0, 1], ticks[0])
self.assertListEqual([8, 9], ticks[4])
if __name__ == "__main__":
unittest.main()

Просмотреть файл

@ -246,7 +246,7 @@ class TestFrame(unittest.TestCase):
self.assertListEqual([0.0, 0.0, 0.0, 0.0, 9.0], list(states)[0:5])
# 2 padding (NAN) in the end
self.assertTrue(np.isnan(states[-2:]).all())
self.assertTrue((states[-2:].astype(np.int)==0).all())
states = static_snapshot[1::"a3"]
@ -329,7 +329,7 @@ class TestFrame(unittest.TestCase):
states = states.flatten()
# 2nd is padding value
self.assertTrue(np.isnan(states[1]))
self.assertEqual(0, int(states[1]))
self.assertListEqual([0.0, 0.0, 0.0, 123.0],
list(states[[0, 2, 3, 4]]))
@ -411,6 +411,7 @@ class TestFrame(unittest.TestCase):
@node("test")
class TestNode(NodeBase):
a1 = NodeAttribute("i", 1, is_list=True)
a4 = NodeAttribute("i", 1, is_list=True)
a2 = NodeAttribute("i", 2, is_const=True)
a3 = NodeAttribute("i")
@ -437,6 +438,9 @@ class TestFrame(unittest.TestCase):
n1.a1.append(11)
n1.a1.append(12)
n1.a4.append(100)
n1.a4.append(101)
expected_value = [10, 11, 12]
# check if value set append correct
@ -508,6 +512,26 @@ class TestFrame(unittest.TestCase):
self.assertEqual(3, len(states))
self.assertListEqual([10, 11, 12], list(states))
# check states after reset
frame.reset()
frame.snapshots.reset()
# list attribute should be cleared
self.assertEqual(0, len(n1.a1))
self.assertEqual(0, len(n1.a4))
# then append value to each list attribute to test if value will be mixed
n1.a1.append(10)
n1.a1.append(20)
n1.a4.append(100)
n1.a4.append(200)
self.assertEqual(10, n1.a1[0])
self.assertEqual(20, n1.a1[1])
self.assertEqual(100, n1.a4[0])
self.assertEqual(200, n1.a4[1])
def test_list_attribute_with_large_size(self):
@node("test")
class TestNode(NodeBase):
@ -650,5 +674,6 @@ class TestFrame(unittest.TestCase):
self.assertListEqual([i for i in range(99)], results)
if __name__ == "__main__":
unittest.main()

Просмотреть файл

@ -68,7 +68,7 @@ class TestFrame(unittest.TestCase):
msg="slicing with 1 tick, 1 node and 1 attr, should return array with 1 result")
if backend_name == "dynamic":
self.assertTrue(np.isnan(static_node_a2_states).all())
self.assertTrue((static_node_a2_states == 0).all())
else:
self.assertEqual(0, static_node_a2_states.astype(
"i")[0], msg="states before taking snapshot should be 0")
@ -159,7 +159,7 @@ class TestFrame(unittest.TestCase):
3, len(states), msg="states should contains 3 row")
if backend_name == "dynamic":
self.assertTrue(np.isnan(states[0]).all())
self.assertTrue((states[0] == 0).all())
else:
self.assertListEqual([0]*len(frame.static_nodes),
list(states[0].astype("i")), msg="over-wrote tick should return 0")
@ -206,7 +206,7 @@ class TestFrame(unittest.TestCase):
# NOTE: raw backend will padding with nan while numpy padding with 0
if backend_name == "dynamic":
# all should be nan
self.assertTrue(np.isnan(states).all())
self.assertTrue((states==0).all())
else:
self.assertListEqual(list(states.astype("I")), [
0]*STATIC_NODE_NUM)
@ -222,7 +222,7 @@ class TestFrame(unittest.TestCase):
i for i in range(STATIC_NODE_NUM)])
if backend_name == "dynamic":
self.assertTrue(np.isnan(states[1]).all())
self.assertTrue((states[1] == 0).all())
else:
self.assertListEqual(list(states[1].astype("i")), [
0]*STATIC_NODE_NUM)