Dynamic backend issues fix. (#321)
* fix dynamic backend slot number bug, set snapshot default value to 0 to align with static backend, update related tests * remove supply chain scenario * correct the scenario number
This commit is contained in:
Родитель
0f380f44f9
Коммит
803c3fc2ee
|
@ -233,6 +233,8 @@ cdef class RawBackend(BackendAbc):
|
|||
cdef list slots_not_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +:
|
||||
return self.where(index, attr_type, lambda x : x != value)
|
||||
|
||||
cdef SLOT_INDEX get_slot_number(self, NODE_INDEX index, ATTR_TYPE attr_type) except +:
|
||||
return self._frame.get_slot_number(index, attr_type)
|
||||
|
||||
cdef class RawSnapshotList(SnapshotListAbc):
|
||||
def __cinit__(self, RawBackend backend, USHORT total_snapshots):
|
||||
|
|
|
@ -139,7 +139,7 @@ cdef class BackendAbc:
|
|||
# Filter slots that greater than specified value.
|
||||
cdef list slots_greater_than(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
|
||||
|
||||
# Filter slots that greater equeal to specified value.
|
||||
# Filter slots that greater equal to specified value.
|
||||
cdef list slots_greater_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
|
||||
|
||||
# Filter slots that less than specified value.
|
||||
|
@ -151,5 +151,8 @@ cdef class BackendAbc:
|
|||
# Filter slots that equal to specified value.
|
||||
cdef list slots_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
|
||||
|
||||
# Filter slots that not euqal to specified value.
|
||||
# Filter slots that not equal to specified value.
|
||||
cdef list slots_not_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +
|
||||
|
||||
# Get slot number for specified attribute, only support dynamic backend.
|
||||
cdef SLOT_INDEX get_slot_number(self, NODE_INDEX index, ATTR_TYPE attr_type) except +
|
||||
|
|
|
@ -126,3 +126,6 @@ cdef class BackendAbc:
|
|||
|
||||
cdef list slots_not_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +:
|
||||
pass
|
||||
|
||||
cdef SLOT_INDEX get_slot_number(self, NODE_INDEX index, ATTR_TYPE attr_type) except +:
|
||||
pass
|
||||
|
|
|
@ -142,14 +142,17 @@ cdef class _NodeAttributeAccessor:
|
|||
Current attribute must be a list.
|
||||
|
||||
Args:
|
||||
value(object): Value to append, the data type must fit the decleared one.
|
||||
value(object): Value to append, the data type must fit the declared one.
|
||||
"""
|
||||
if not self._is_list:
|
||||
raise BackendsAppendToNonListAttributeException()
|
||||
|
||||
self._backend.append_to_list(self._node_index, self._attr_type, value)
|
||||
|
||||
self._slot_number += 1
|
||||
self._slot_number = self._backend.get_slot_number(self._node_index, self._attr_type)
|
||||
|
||||
if "_cb" in self.__dict__:
|
||||
self._cb(None)
|
||||
|
||||
def resize(self, new_size: int):
|
||||
"""Resize current list attribute with specified new size.
|
||||
|
@ -165,7 +168,10 @@ cdef class _NodeAttributeAccessor:
|
|||
|
||||
self._backend.resize_list(self._node_index, self._attr_type, new_size)
|
||||
|
||||
self._slot_number = new_size
|
||||
self._slot_number = self._backend.get_slot_number(self._node_index, self._attr_type)
|
||||
|
||||
if "_cb" in self.__dict__:
|
||||
self._cb(None)
|
||||
|
||||
def clear(self):
|
||||
"""Clear all items in current list attribute.
|
||||
|
@ -180,6 +186,9 @@ cdef class _NodeAttributeAccessor:
|
|||
|
||||
self._slot_number = 0
|
||||
|
||||
if "_cb" in self.__dict__:
|
||||
self._cb(None)
|
||||
|
||||
def insert(self, slot_index: int, value: object):
|
||||
"""Insert a value to specified slot.
|
||||
|
||||
|
@ -192,7 +201,10 @@ cdef class _NodeAttributeAccessor:
|
|||
|
||||
self._backend.insert_to_list(self._node_index, self._attr_type, slot_index, value)
|
||||
|
||||
self._slot_number += 1
|
||||
self._slot_number = self._backend.get_slot_number(self._node_index, self._attr_type)
|
||||
|
||||
if "_cb" in self.__dict__:
|
||||
self._cb(None)
|
||||
|
||||
def remove(self, slot_index: int):
|
||||
"""Remove specified slot.
|
||||
|
@ -205,7 +217,10 @@ cdef class _NodeAttributeAccessor:
|
|||
|
||||
self._backend.remove_from_list(self._node_index, self._attr_type, slot_index)
|
||||
|
||||
self._slot_number -= 1
|
||||
self._slot_number = self._backend.get_slot_number(self._node_index, self._attr_type)
|
||||
|
||||
if "_cb" in self.__dict__:
|
||||
self._cb(None)
|
||||
|
||||
def where(self, filter_func: callable):
|
||||
"""Filter current attribute slots with input function.
|
||||
|
@ -214,7 +229,7 @@ cdef class _NodeAttributeAccessor:
|
|||
filter_func (callable): Function to filter slot value.
|
||||
|
||||
Returns:
|
||||
List[int]: List of slot index whoes value match the filter function.
|
||||
List[int]: List of slot index whose value match the filter function.
|
||||
"""
|
||||
return self._backend.where(self._node_index, self._attr_type, filter_func)
|
||||
|
||||
|
@ -345,12 +360,12 @@ cdef class _NodeAttributeAccessor:
|
|||
else:
|
||||
raise BackendsSetItemInvalidException()
|
||||
|
||||
# Check and invoke value changed callback, except list attribute.
|
||||
if not self._is_list and "_cb" in self.__dict__:
|
||||
# Check and invoke value changed callback.
|
||||
if "_cb" in self.__dict__:
|
||||
self._cb(value)
|
||||
|
||||
def __len__(self):
|
||||
return self._slot_number
|
||||
return self._backend.get_slot_number(self._node_index, self._attr_type)
|
||||
|
||||
def on_value_changed(self, cb):
|
||||
"""Set the value changed callback."""
|
||||
|
@ -386,7 +401,9 @@ cdef class NodeBase:
|
|||
cdef str cb_name
|
||||
cdef _NodeAttributeAccessor attr_acc
|
||||
|
||||
for name, attr in type(self).__dict__.items():
|
||||
for name in dir(type(self)):
|
||||
attr = getattr(self, name)
|
||||
|
||||
# Append an attribute access wrapper to current instance.
|
||||
if isinstance(attr, NodeAttribute):
|
||||
# Register attribute.
|
||||
|
@ -399,12 +416,12 @@ cdef class NodeBase:
|
|||
|
||||
# Bind a value changed callback if available, named as _on_<attr name>_changed.
|
||||
# Except list attribute.
|
||||
if not attr_acc._is_list:
|
||||
cb_name = f"_on_{name}_changed"
|
||||
cb_func = getattr(self, cb_name, None)
|
||||
# if not attr_acc._is_list:
|
||||
cb_name = f"_on_{name}_changed"
|
||||
cb_func = getattr(self, cb_name, None)
|
||||
|
||||
if cb_func is not None:
|
||||
attr_acc.on_value_changed(cb_func)
|
||||
if cb_func is not None:
|
||||
attr_acc.on_value_changed(cb_func)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Used to avoid attribute overriding, and an easy way to set for 1 slot attribute."""
|
||||
|
@ -633,7 +650,9 @@ cdef class FrameBase:
|
|||
cdef NODE_INDEX i
|
||||
|
||||
# Register node and attribute in backend.
|
||||
for frame_attr_name, frame_attr in type(self).__dict__.items():
|
||||
for frame_attr_name in dir(type(self)):
|
||||
frame_attr = getattr(self, frame_attr_name)
|
||||
|
||||
# We only care about FrameNode instance.
|
||||
if isinstance(frame_attr, FrameNode):
|
||||
node_cls = frame_attr._node_cls
|
||||
|
@ -656,8 +675,10 @@ cdef class FrameBase:
|
|||
attr_name_type_dict = {}
|
||||
|
||||
# Register attributes.
|
||||
for node_attr_name, node_attr in node_cls.__dict__.items():
|
||||
if isinstance(node_attr, NodeAttribute):
|
||||
for node_attr_name in dir(node_cls):
|
||||
node_attr = getattr(node_cls, node_attr_name)
|
||||
|
||||
if node_attr and isinstance(node_attr, NodeAttribute):
|
||||
attr_type = self._backend.add_attr(node_type, node_attr_name, node_attr._dtype, node_attr._slot_number, node_attr._is_const, node_attr._is_list)
|
||||
|
||||
attr_name_type_dict[node_attr_name] = attr_type
|
||||
|
|
|
@ -271,6 +271,25 @@ namespace maro
|
|||
list.clear();
|
||||
}
|
||||
|
||||
UINT list_index = 0;
|
||||
|
||||
for (auto& attr_def : _attribute_definitions)
|
||||
{
|
||||
if (attr_def.is_list)
|
||||
{
|
||||
// Assign each attribute with the index of actual list.
|
||||
for (NODE_INDEX i = 0; i < _defined_node_number; i++)
|
||||
{
|
||||
auto& target_attr = _dynamic_block[_dynamic_size_per_node * i + attr_def.offset];
|
||||
|
||||
// Save the index of list in list store.
|
||||
target_attr = list_index;
|
||||
|
||||
list_index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset bitset masks.
|
||||
_node_instance_masks.resize(_defined_node_number);
|
||||
_node_instance_masks.reset(true);
|
||||
|
@ -413,7 +432,7 @@ namespace maro
|
|||
|
||||
const auto list_index = target_attr.get_value<ATTR_UINT>();
|
||||
|
||||
// Then get the actual list reference for furthure operation.
|
||||
// Then get the actual list reference for further operation.
|
||||
auto& target_list = _list_store[list_index];
|
||||
|
||||
return target_list[slot_index];
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
#include "common.h"
|
||||
#include "attribute.h"
|
||||
|
|
|
@ -131,7 +131,10 @@ namespace maro
|
|||
}
|
||||
|
||||
shape.attr_number++;
|
||||
shape.max_slot_number = max(attr_def.slot_number, shape.max_slot_number);
|
||||
|
||||
shape.max_slot_number = attr_def.slot_number > shape.max_slot_number ? attr_def.slot_number : shape.max_slot_number;
|
||||
|
||||
shape.max_slot_number = MAX(attr_def.slot_number, shape.max_slot_number); //std::max(attr_def.slot_number, shape.max_slot_number);
|
||||
}
|
||||
}
|
||||
else
|
||||
|
@ -165,7 +168,7 @@ namespace maro
|
|||
|
||||
if (target_tick_pair == _snapshots.end())
|
||||
{
|
||||
throw SnapshotQueryNoSnapshotsError();
|
||||
throw SnapshotQueryInvalidTickError();
|
||||
}
|
||||
|
||||
auto& snapshot = target_tick_pair->second;
|
||||
|
|
|
@ -23,6 +23,8 @@ namespace maro
|
|||
{
|
||||
namespace raw
|
||||
{
|
||||
#define MAX(a, b) a > b ? a : b
|
||||
|
||||
/// <summary>
|
||||
/// Shape of current querying.
|
||||
/// </summary>
|
||||
|
|
|
@ -233,6 +233,8 @@ cdef class RawBackend(BackendAbc):
|
|||
cdef list slots_not_equal(self, NODE_INDEX index, ATTR_TYPE attr_type, object value) except +:
|
||||
return self.where(index, attr_type, lambda x : x != value)
|
||||
|
||||
cdef SLOT_INDEX get_slot_number(self, NODE_INDEX index, ATTR_TYPE attr_type) except +:
|
||||
return self._frame.get_slot_number(index, attr_type)
|
||||
|
||||
cdef class RawSnapshotList(SnapshotListAbc):
|
||||
def __cinit__(self, RawBackend backend, USHORT total_snapshots):
|
||||
|
@ -289,7 +291,7 @@ cdef class RawSnapshotList(SnapshotListAbc):
|
|||
cdef QUERY_FLOAT[:, :, :, :] result = view.array(shape=(shape.tick_number, shape.max_node_number, shape.attr_number, shape.max_slot_number), itemsize=sizeof(QUERY_FLOAT), format="f")
|
||||
|
||||
# Default result value
|
||||
result[:, :, :, :] = np.nan
|
||||
result[:, :, :, :] = 0
|
||||
|
||||
# Do query
|
||||
self._snapshots.query(&result[0][0][0][0])
|
||||
|
|
14
setup.py
14
setup.py
|
@ -3,6 +3,7 @@
|
|||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import numpy
|
||||
|
||||
# NOTE: DO NOT change the import order, as sometimes there is a conflict between setuptools and distutils,
|
||||
|
@ -14,6 +15,11 @@ from distutils.extension import Extension
|
|||
|
||||
from maro import __version__
|
||||
|
||||
compile_flag = '-std=c++11'
|
||||
|
||||
if sys.platform == "win32":
|
||||
compile_flag = '/std:c++14'
|
||||
|
||||
# Set environment variable to skip deployment process of MARO
|
||||
os.environ["SKIP_DEPLOYMENT"] = "TRUE"
|
||||
|
||||
|
@ -39,7 +45,7 @@ extensions.append(
|
|||
Extension(
|
||||
f"{BASE_MODULE_NAME}.backend",
|
||||
sources=[f"{BASE_SRC_PATH}/backend.cpp"],
|
||||
extra_compile_args=['-std=c++11'])
|
||||
extra_compile_args=[compile_flag])
|
||||
)
|
||||
|
||||
|
||||
|
@ -50,7 +56,7 @@ extensions.append(
|
|||
f"{BASE_MODULE_NAME}.np_backend",
|
||||
sources=[f"{BASE_SRC_PATH}/np_backend.cpp"],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=['-std=c++11'])
|
||||
extra_compile_args=[compile_flag])
|
||||
)
|
||||
|
||||
# raw implementation
|
||||
|
@ -60,7 +66,7 @@ extensions.append(
|
|||
f"{BASE_MODULE_NAME}.raw_backend",
|
||||
sources=[f"{BASE_SRC_PATH}/raw_backend.cpp"],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=['-std=c++11'])
|
||||
extra_compile_args=[compile_flag])
|
||||
)
|
||||
|
||||
# frame
|
||||
|
@ -69,7 +75,7 @@ extensions.append(
|
|||
f"{BASE_MODULE_NAME}.frame",
|
||||
sources=[f"{BASE_SRC_PATH}/frame.cpp"],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=['-std=c++11'])
|
||||
extra_compile_args=[compile_flag])
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -130,7 +130,7 @@ class TestEnv(unittest.TestCase):
|
|||
vals_after_reset = dummies_ss[env.frame_index::"val"]
|
||||
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(np.isnan(vals_after_reset).all())
|
||||
self.assertTrue((vals_after_reset == 0).all())
|
||||
else:
|
||||
self.assertListEqual(list(vals_after_reset.flatten()), [
|
||||
0]*dummy_number, msg=f"we should have padding values")
|
||||
|
@ -271,7 +271,7 @@ class TestEnv(unittest.TestCase):
|
|||
def test_get_avaiable_envs(self):
|
||||
scenario_names = get_scenarios()
|
||||
|
||||
# we have 2 built-in scenarios
|
||||
# we have 3 built-in scenarios
|
||||
self.assertEqual(3, len(scenario_names))
|
||||
|
||||
self.assertTrue("cim" in scenario_names)
|
||||
|
@ -293,5 +293,6 @@ class TestEnv(unittest.TestCase):
|
|||
self.assertListEqual([0, 1], ticks[0])
|
||||
self.assertListEqual([8, 9], ticks[4])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -246,7 +246,7 @@ class TestFrame(unittest.TestCase):
|
|||
self.assertListEqual([0.0, 0.0, 0.0, 0.0, 9.0], list(states)[0:5])
|
||||
|
||||
# 2 padding (NAN) in the end
|
||||
self.assertTrue(np.isnan(states[-2:]).all())
|
||||
self.assertTrue((states[-2:].astype(np.int)==0).all())
|
||||
|
||||
states = static_snapshot[1::"a3"]
|
||||
|
||||
|
@ -329,7 +329,7 @@ class TestFrame(unittest.TestCase):
|
|||
states = states.flatten()
|
||||
|
||||
# 2nd is padding value
|
||||
self.assertTrue(np.isnan(states[1]))
|
||||
self.assertEqual(0, int(states[1]))
|
||||
|
||||
self.assertListEqual([0.0, 0.0, 0.0, 123.0],
|
||||
list(states[[0, 2, 3, 4]]))
|
||||
|
@ -411,6 +411,7 @@ class TestFrame(unittest.TestCase):
|
|||
@node("test")
|
||||
class TestNode(NodeBase):
|
||||
a1 = NodeAttribute("i", 1, is_list=True)
|
||||
a4 = NodeAttribute("i", 1, is_list=True)
|
||||
a2 = NodeAttribute("i", 2, is_const=True)
|
||||
a3 = NodeAttribute("i")
|
||||
|
||||
|
@ -437,6 +438,9 @@ class TestFrame(unittest.TestCase):
|
|||
n1.a1.append(11)
|
||||
n1.a1.append(12)
|
||||
|
||||
n1.a4.append(100)
|
||||
n1.a4.append(101)
|
||||
|
||||
expected_value = [10, 11, 12]
|
||||
|
||||
# check if value set append correct
|
||||
|
@ -508,6 +512,26 @@ class TestFrame(unittest.TestCase):
|
|||
self.assertEqual(3, len(states))
|
||||
self.assertListEqual([10, 11, 12], list(states))
|
||||
|
||||
# check states after reset
|
||||
frame.reset()
|
||||
frame.snapshots.reset()
|
||||
|
||||
# list attribute should be cleared
|
||||
self.assertEqual(0, len(n1.a1))
|
||||
self.assertEqual(0, len(n1.a4))
|
||||
|
||||
# then append value to each list attribute to test if value will be mixed
|
||||
n1.a1.append(10)
|
||||
n1.a1.append(20)
|
||||
|
||||
n1.a4.append(100)
|
||||
n1.a4.append(200)
|
||||
|
||||
self.assertEqual(10, n1.a1[0])
|
||||
self.assertEqual(20, n1.a1[1])
|
||||
self.assertEqual(100, n1.a4[0])
|
||||
self.assertEqual(200, n1.a4[1])
|
||||
|
||||
def test_list_attribute_with_large_size(self):
|
||||
@node("test")
|
||||
class TestNode(NodeBase):
|
||||
|
@ -650,5 +674,6 @@ class TestFrame(unittest.TestCase):
|
|||
|
||||
self.assertListEqual([i for i in range(99)], results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -68,7 +68,7 @@ class TestFrame(unittest.TestCase):
|
|||
msg="slicing with 1 tick, 1 node and 1 attr, should return array with 1 result")
|
||||
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(np.isnan(static_node_a2_states).all())
|
||||
self.assertTrue((static_node_a2_states == 0).all())
|
||||
else:
|
||||
self.assertEqual(0, static_node_a2_states.astype(
|
||||
"i")[0], msg="states before taking snapshot should be 0")
|
||||
|
@ -159,7 +159,7 @@ class TestFrame(unittest.TestCase):
|
|||
3, len(states), msg="states should contains 3 row")
|
||||
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(np.isnan(states[0]).all())
|
||||
self.assertTrue((states[0] == 0).all())
|
||||
else:
|
||||
self.assertListEqual([0]*len(frame.static_nodes),
|
||||
list(states[0].astype("i")), msg="over-wrote tick should return 0")
|
||||
|
@ -206,7 +206,7 @@ class TestFrame(unittest.TestCase):
|
|||
# NOTE: raw backend will padding with nan while numpy padding with 0
|
||||
if backend_name == "dynamic":
|
||||
# all should be nan
|
||||
self.assertTrue(np.isnan(states).all())
|
||||
self.assertTrue((states==0).all())
|
||||
else:
|
||||
self.assertListEqual(list(states.astype("I")), [
|
||||
0]*STATIC_NODE_NUM)
|
||||
|
@ -222,7 +222,7 @@ class TestFrame(unittest.TestCase):
|
|||
i for i in range(STATIC_NODE_NUM)])
|
||||
|
||||
if backend_name == "dynamic":
|
||||
self.assertTrue(np.isnan(states[1]).all())
|
||||
self.assertTrue((states[1] == 0).all())
|
||||
else:
|
||||
self.assertListEqual(list(states[1].astype("i")), [
|
||||
0]*STATIC_NODE_NUM)
|
||||
|
|
Загрузка…
Ссылка в новой задаче