Stable-baselines3 agent working - fix shapes of observation spaces (#144)
* Stable-baselines3 agent working - fix shapes of observation spaces
chore: Update stable-baselines3 agent with env_checker
chore: Update ports list with "Null" value
---------
Co-authored-by: William Blum <william.blum@microsoft.com>
This commit is contained in:
Родитель
4eabac5e60
Коммит
8f7078c814
|
@ -22,7 +22,7 @@ from cyberbattle._env.defender import DefenderAgent
|
||||||
from cyberbattle.simulation.model import PortName, PrivilegeLevel
|
from cyberbattle.simulation.model import PortName, PrivilegeLevel
|
||||||
from ..simulation import commandcontrol, model, actions
|
from ..simulation import commandcontrol, model, actions
|
||||||
from .discriminatedunion import DiscriminatedUnion
|
from .discriminatedunion import DiscriminatedUnion
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ Observation = TypedDict(
|
||||||
# whether a lateral move was just performed
|
# whether a lateral move was just performed
|
||||||
"lateral_move": numpy.int32,
|
"lateral_move": numpy.int32,
|
||||||
# whether customer data were just discovered
|
# whether customer data were just discovered
|
||||||
"customer_data_found": Tuple[numpy.int32],
|
"customer_data_found": numpy.int32,
|
||||||
# 0 if there was no probing attempt
|
# 0 if there was no probing attempt
|
||||||
# 1 if an attempted probing failed
|
# 1 if an attempted probing failed
|
||||||
# 2 if an attempted probing succeeded
|
# 2 if an attempted probing succeeded
|
||||||
|
@ -90,7 +90,7 @@ Observation = TypedDict(
|
||||||
# total nodes discovered so far
|
# total nodes discovered so far
|
||||||
"discovered_node_count": int,
|
"discovered_node_count": int,
|
||||||
# Matrix of properties for all the discovered nodes
|
# Matrix of properties for all the discovered nodes
|
||||||
"discovered_nodes_properties": Tuple[numpy.ndarray, ...],
|
"discovered_nodes_properties": numpy.ndarray,
|
||||||
# Node privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system)
|
# Node privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system)
|
||||||
"nodes_privilegelevel": numpy.ndarray,
|
"nodes_privilegelevel": numpy.ndarray,
|
||||||
# Tuple encoding of the credential cache matrix.
|
# Tuple encoding of the credential cache matrix.
|
||||||
|
@ -183,14 +183,14 @@ class EnvironmentBounds(NamedTuple):
|
||||||
remote_attacks_count - Unique remote vulnerabilities
|
remote_attacks_count - Unique remote vulnerabilities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
maximum_total_credentials: int
|
maximum_total_credentials: np.int32
|
||||||
maximum_node_count: int
|
maximum_node_count: np.int32
|
||||||
maximum_discoverable_credentials_per_action: int
|
maximum_discoverable_credentials_per_action: np.int32
|
||||||
|
|
||||||
port_count: int
|
port_count: np.int32
|
||||||
property_count: int
|
property_count: np.int32
|
||||||
local_attacks_count: int
|
local_attacks_count: np.int32
|
||||||
remote_attacks_count: int
|
remote_attacks_count: np.int32
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def of_identifiers(
|
def of_identifiers(
|
||||||
|
@ -200,16 +200,27 @@ class EnvironmentBounds(NamedTuple):
|
||||||
maximum_node_count: int,
|
maximum_node_count: int,
|
||||||
maximum_discoverable_credentials_per_action: Optional[int] = None,
|
maximum_discoverable_credentials_per_action: Optional[int] = None,
|
||||||
):
|
):
|
||||||
if not maximum_discoverable_credentials_per_action:
|
|
||||||
maximum_discoverable_credentials_per_action = maximum_total_credentials
|
maximum_discoverable_credentials_per_action = maximum_discoverable_credentials_per_action or maximum_total_credentials
|
||||||
|
|
||||||
|
assert np.can_cast(maximum_total_credentials, np.int32), "maximum_total_credentials must be a 32-bit integer"
|
||||||
|
assert np.can_cast(maximum_node_count, np.int32), "maximum_node_count must be a 32-bit integer"
|
||||||
|
assert maximum_total_credentials > 0, "maximum_total_credentials must be positive"
|
||||||
|
assert maximum_node_count > 0, "maximum_node_count must be positive"
|
||||||
|
assert np.can_cast(len(identifiers.ports), np.int32), "port_count must be a 32-bit integer"
|
||||||
|
assert np.can_cast(len(identifiers.properties), np.int32), "property_count must be a 32-bit integer"
|
||||||
|
assert np.can_cast(len(identifiers.local_vulnerabilities), np.int32), "local_attacks_count must be a 32-bit integer"
|
||||||
|
assert np.can_cast(len(identifiers.remote_vulnerabilities), np.int32), "remote_attacks_count must be a 32-bit integer"
|
||||||
|
assert np.can_cast(maximum_discoverable_credentials_per_action, np.int32), "maximum_discoverable_credentials_per_action must be a 32-bit integer"
|
||||||
|
|
||||||
return EnvironmentBounds(
|
return EnvironmentBounds(
|
||||||
maximum_total_credentials=maximum_total_credentials,
|
maximum_total_credentials=np.int32(maximum_total_credentials),
|
||||||
maximum_node_count=maximum_node_count,
|
maximum_node_count=np.int32(maximum_node_count),
|
||||||
maximum_discoverable_credentials_per_action=maximum_discoverable_credentials_per_action,
|
maximum_discoverable_credentials_per_action=np.int32(maximum_discoverable_credentials_per_action),
|
||||||
port_count=len(identifiers.ports),
|
port_count=np.int32(len(identifiers.ports)),
|
||||||
property_count=len(identifiers.properties),
|
property_count=np.int32(len(identifiers.properties)),
|
||||||
local_attacks_count=len(identifiers.local_vulnerabilities),
|
local_attacks_count=np.int32(len(identifiers.local_vulnerabilities)),
|
||||||
remote_attacks_count=len(identifiers.remote_vulnerabilities),
|
remote_attacks_count=np.int32(len(identifiers.remote_vulnerabilities)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -252,7 +263,7 @@ class ObservationSpaceType(spaces.Dict):
|
||||||
# successfully moved to the target node (1) or not (0)
|
# successfully moved to the target node (1) or not (0)
|
||||||
"lateral_move": spaces.Discrete(2),
|
"lateral_move": spaces.Discrete(2),
|
||||||
# boolean: 1 if customer secret data were discovered, 0 otherwise
|
# boolean: 1 if customer secret data were discovered, 0 otherwise
|
||||||
"customer_data_found": spaces.MultiBinary(2),
|
"customer_data_found": spaces.Discrete(2),
|
||||||
# whether an attempted probing succeeded or not
|
# whether an attempted probing succeeded or not
|
||||||
"probe_result": spaces.Discrete(3),
|
"probe_result": spaces.Discrete(3),
|
||||||
# Escalation result
|
# Escalation result
|
||||||
|
@ -271,12 +282,12 @@ class ObservationSpaceType(spaces.Dict):
|
||||||
# that was used to authenticate to target node 56 on port number 22 (e.g. SSH)
|
# that was used to authenticate to target node 56 on port number 22 (e.g. SSH)
|
||||||
[
|
[
|
||||||
spaces.MultiDiscrete(
|
spaces.MultiDiscrete(
|
||||||
[
|
np.array([
|
||||||
NA + 1,
|
NA + 1,
|
||||||
bounds.maximum_total_credentials,
|
bounds.maximum_total_credentials,
|
||||||
bounds.maximum_node_count,
|
bounds.maximum_node_count,
|
||||||
bounds.port_count,
|
bounds.port_count,
|
||||||
]
|
], dtype=np.int32)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
* bounds.maximum_discoverable_credentials_per_action
|
* bounds.maximum_discoverable_credentials_per_action
|
||||||
|
@ -287,9 +298,9 @@ class ObservationSpaceType(spaces.Dict):
|
||||||
# even though such action would 'fail' and potentially yield a negative reward.
|
# even though such action would 'fail' and potentially yield a negative reward.
|
||||||
"action_mask": spaces.Dict(
|
"action_mask": spaces.Dict(
|
||||||
{
|
{
|
||||||
"local_vulnerability": spaces.MultiBinary(bounds.maximum_node_count * bounds.local_attacks_count),
|
"local_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.local_attacks_count])),
|
||||||
"remote_vulnerability": spaces.MultiBinary(bounds.maximum_node_count * bounds.maximum_node_count * bounds.remote_attacks_count),
|
"remote_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.remote_attacks_count])),
|
||||||
"connect": spaces.MultiBinary(bounds.maximum_node_count * bounds.maximum_node_count * bounds.port_count * bounds.maximum_total_credentials),
|
"connect": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.port_count, bounds.maximum_total_credentials], dtype=np.int32))
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
# size of the credential stack
|
# size of the credential stack
|
||||||
|
@ -298,7 +309,7 @@ class ObservationSpaceType(spaces.Dict):
|
||||||
"discovered_node_count": spaces.Discrete(bounds.maximum_node_count),
|
"discovered_node_count": spaces.Discrete(bounds.maximum_node_count),
|
||||||
# Matrix of properties for all the discovered nodes
|
# Matrix of properties for all the discovered nodes
|
||||||
# 3 values for each matrix cell: set, unset, unknown
|
# 3 values for each matrix cell: set, unset, unknown
|
||||||
"discovered_nodes_properties": spaces.MultiDiscrete([3] * bounds.maximum_node_count * bounds.property_count),
|
"discovered_nodes_properties": spaces.MultiDiscrete(np.full(shape=(bounds.maximum_node_count, bounds.property_count), fill_value=3)),
|
||||||
# Escalation level on every discovered node (e.g., 0 if not owned, 1 for admin, 2 for system)
|
# Escalation level on every discovered node (e.g., 0 if not owned, 1 for admin, 2 for system)
|
||||||
"nodes_privilegelevel": spaces.MultiDiscrete([CyberBattleEnv.privilege_levels] * bounds.maximum_node_count),
|
"nodes_privilegelevel": spaces.MultiDiscrete([CyberBattleEnv.privilege_levels] * bounds.maximum_node_count),
|
||||||
# Encoding of the credential cache of shape: (credential_cache_length, 2)
|
# Encoding of the credential cache of shape: (credential_cache_length, 2)
|
||||||
|
@ -306,7 +317,7 @@ class ObservationSpaceType(spaces.Dict):
|
||||||
# Each row represents a discovered credential,
|
# Each row represents a discovered credential,
|
||||||
# the credential index is given by the row index (i.e. order of discovery)
|
# the credential index is given by the row index (i.e. order of discovery)
|
||||||
# A row is of the form: (target_node_discover_index, port_index)
|
# A row is of the form: (target_node_discover_index, port_index)
|
||||||
"credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete([bounds.maximum_node_count, bounds.port_count])] * bounds.maximum_total_credentials),
|
"credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete(np.array([bounds.maximum_node_count, bounds.port_count],dtype=np.int32))] * bounds.maximum_total_credentials),
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
# Fields that were previously in the 'info' dict:
|
# Fields that were previously in the 'info' dict:
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
|
@ -460,7 +471,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
self,
|
self,
|
||||||
initial_environment: model.Environment,
|
initial_environment: model.Environment,
|
||||||
maximum_total_credentials: int = 1000,
|
maximum_total_credentials: int = 1000,
|
||||||
maximum_node_count: int = 100,
|
maximum_node_count: int = 100,
|
||||||
maximum_discoverable_credentials_per_action: int = 5,
|
maximum_discoverable_credentials_per_action: int = 5,
|
||||||
defender_agent: Optional[DefenderAgent] = None,
|
defender_agent: Optional[DefenderAgent] = None,
|
||||||
attacker_goal: Optional[AttackerGoal] = AttackerGoal(own_atleast_percent=1.0),
|
attacker_goal: Optional[AttackerGoal] = AttackerGoal(own_atleast_percent=1.0),
|
||||||
|
@ -523,27 +534,27 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
# The Space object defining the valid actions of an attacker.
|
# The Space object defining the valid actions of an attacker.
|
||||||
local_vulnerabilities_count = self.__bounds.local_attacks_count
|
local_vulnerabilities_count = self.__bounds.local_attacks_count
|
||||||
remote_vulnerabilities_count = self.__bounds.remote_attacks_count
|
remote_vulnerabilities_count = self.__bounds.remote_attacks_count
|
||||||
maximum_node_count = self.__bounds.maximum_node_count
|
maximum_node_count_int32 = self.__bounds.maximum_node_count
|
||||||
port_count = self.__bounds.port_count
|
port_count = self.__bounds.port_count
|
||||||
|
|
||||||
action_spaces = {
|
action_spaces = {
|
||||||
"local_vulnerability": spaces.MultiDiscrete(
|
"local_vulnerability": spaces.MultiDiscrete(
|
||||||
# source_node_id, vulnerability_id
|
# source_node_id, vulnerability_id
|
||||||
[maximum_node_count, local_vulnerabilities_count]
|
np.array([maximum_node_count_int32, local_vulnerabilities_count], dtype=np.int32)
|
||||||
),
|
),
|
||||||
"remote_vulnerability": spaces.MultiDiscrete(
|
"remote_vulnerability": spaces.MultiDiscrete(
|
||||||
# source_node_id, target_node_id, vulnerability_id
|
# source_node_id, target_node_id, vulnerability_id
|
||||||
[maximum_node_count, maximum_node_count, remote_vulnerabilities_count]
|
np.array([maximum_node_count_int32, maximum_node_count_int32, remote_vulnerabilities_count], dtype=np.int32)
|
||||||
),
|
),
|
||||||
"connect": spaces.MultiDiscrete(
|
"connect": spaces.MultiDiscrete(
|
||||||
# source_node_id, target_node_id, target_port, credential_id
|
# source_node_id, target_node_id, target_port, credential_id
|
||||||
# (by index of discovery: 0 for initial node, 1 for first discovered node, ...)
|
# (by index of discovery: 0 for initial node, 1 for first discovered node, ...)
|
||||||
[
|
np.array([
|
||||||
maximum_node_count,
|
maximum_node_count_int32,
|
||||||
maximum_node_count,
|
maximum_node_count_int32,
|
||||||
port_count,
|
port_count,
|
||||||
maximum_total_credentials,
|
maximum_total_credentials,
|
||||||
]
|
], dtype=np.int32)
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -613,10 +624,10 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
local_vulnerabilities_count = self.__bounds.local_attacks_count
|
local_vulnerabilities_count = self.__bounds.local_attacks_count
|
||||||
remote_vulnerabilities_count = self.__bounds.remote_attacks_count
|
remote_vulnerabilities_count = self.__bounds.remote_attacks_count
|
||||||
port_count = self.__bounds.port_count
|
port_count = self.__bounds.port_count
|
||||||
local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int32)
|
local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int8)
|
||||||
remote = numpy.zeros(
|
remote = numpy.zeros(
|
||||||
shape=(max_node_count, max_node_count, remote_vulnerabilities_count),
|
shape=(max_node_count, max_node_count, remote_vulnerabilities_count),
|
||||||
dtype=numpy.int32,
|
dtype=numpy.int8,
|
||||||
)
|
)
|
||||||
connect = numpy.zeros(
|
connect = numpy.zeros(
|
||||||
shape=(
|
shape=(
|
||||||
|
@ -625,7 +636,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
port_count,
|
port_count,
|
||||||
self.__bounds.maximum_total_credentials,
|
self.__bounds.maximum_total_credentials,
|
||||||
),
|
),
|
||||||
dtype=numpy.int32,
|
dtype=numpy.int8,
|
||||||
)
|
)
|
||||||
return ActionMask(local_vulnerability=local, remote_vulnerability=remote, connect=connect)
|
return ActionMask(local_vulnerability=local, remote_vulnerability=remote, connect=connect)
|
||||||
|
|
||||||
|
@ -744,14 +755,14 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
newly_discovered_nodes_count=numpy.int32(0),
|
newly_discovered_nodes_count=numpy.int32(0),
|
||||||
leaked_credentials=tuple([numpy.array([UNUSED_SLOT, 0, 0, 0], dtype=numpy.int32)] * self.__bounds.maximum_discoverable_credentials_per_action),
|
leaked_credentials=tuple([numpy.array([UNUSED_SLOT, 0, 0, 0], dtype=numpy.int32)] * self.__bounds.maximum_discoverable_credentials_per_action),
|
||||||
lateral_move=numpy.int32(0),
|
lateral_move=numpy.int32(0),
|
||||||
customer_data_found=(numpy.int32(0),),
|
customer_data_found=numpy.int32(0),
|
||||||
escalation=numpy.int32(PrivilegeLevel.NoAccess),
|
escalation=numpy.int32(PrivilegeLevel.NoAccess),
|
||||||
action_mask=self.__get_blank_action_mask(),
|
action_mask=self.__get_blank_action_mask(),
|
||||||
probe_result=numpy.int32(0),
|
probe_result=numpy.int32(0),
|
||||||
credential_cache_matrix=tuple([numpy.zeros((2))] * self.__bounds.maximum_total_credentials),
|
credential_cache_matrix=tuple([numpy.zeros((2), dtype=numpy.int64)] * self.__bounds.maximum_total_credentials),
|
||||||
credential_cache_length=0,
|
credential_cache_length=0,
|
||||||
discovered_node_count=len(self.__discovered_nodes),
|
discovered_node_count=len(self.__discovered_nodes),
|
||||||
discovered_nodes_properties=tuple([numpy.full((self.__bounds.property_count,), 2, dtype=numpy.int32)] * self.__bounds.maximum_node_count),
|
discovered_nodes_properties=numpy.full((self.__bounds.maximum_node_count, self.__bounds.property_count,), 2, dtype=numpy.int32),
|
||||||
nodes_privilegelevel=numpy.zeros((self.bounds.maximum_node_count,), dtype=numpy.int32),
|
nodes_privilegelevel=numpy.zeros((self.bounds.maximum_node_count,), dtype=numpy.int32),
|
||||||
# raw data not actually encoded as a proper gym numeric space
|
# raw data not actually encoded as a proper gym numeric space
|
||||||
# (were previously returned in the 'info' dict)
|
# (were previously returned in the 'info' dict)
|
||||||
|
@ -797,7 +808,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
vector[properties_indices] = 1
|
vector[properties_indices] = 1
|
||||||
return vector
|
return vector
|
||||||
|
|
||||||
def __get_property_matrix(self) -> Tuple[numpy.ndarray, ...]:
|
def __get_property_matrix(self) -> numpy.ndarray:
|
||||||
"""Return the Node-Property matrix,
|
"""Return the Node-Property matrix,
|
||||||
where 0 means the property is not set for that node
|
where 0 means the property is not set for that node
|
||||||
1 means the property is set for that node
|
1 means the property is set for that node
|
||||||
|
@ -810,11 +821,13 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
2nd row: no known properties for the 2nd discovered node
|
2nd row: no known properties for the 2nd discovered node
|
||||||
3rd row: properties of 3rd discovered and owned node"""
|
3rd row: properties of 3rd discovered and owned node"""
|
||||||
property_discovered = [self.__property_vector(node_id, node_info) for node_id, node_info in self._actuator.discovered_nodes()]
|
property_discovered = [self.__property_vector(node_id, node_info) for node_id, node_info in self._actuator.discovered_nodes()]
|
||||||
return self.__pad_tuple_if_requested(
|
as_numpy = numpy.array(self.__pad_tuple_if_requested(
|
||||||
property_discovered,
|
property_discovered,
|
||||||
self.__bounds.property_count,
|
self.__bounds.property_count,
|
||||||
self.__bounds.maximum_node_count,
|
self.__bounds.maximum_node_count,
|
||||||
)
|
))
|
||||||
|
assert as_numpy.shape == (self.__bounds.maximum_node_count, self.__bounds.property_count)
|
||||||
|
return as_numpy
|
||||||
|
|
||||||
def __get__owned_nodes_indices(self) -> List[int]:
|
def __get__owned_nodes_indices(self) -> List[int]:
|
||||||
"""Get list of indices of all owned nodes"""
|
"""Get list of indices of all owned nodes"""
|
||||||
|
@ -896,7 +909,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
elif isinstance(outcome, model.LateralMove):
|
elif isinstance(outcome, model.LateralMove):
|
||||||
obs["lateral_move"] = numpy.int32(1)
|
obs["lateral_move"] = numpy.int32(1)
|
||||||
elif isinstance(outcome, model.CustomerData):
|
elif isinstance(outcome, model.CustomerData):
|
||||||
obs["customer_data_found"] = (numpy.int32(1),)
|
obs["customer_data_found"] = numpy.int32(1)
|
||||||
elif isinstance(outcome, model.ProbeSucceeded):
|
elif isinstance(outcome, model.ProbeSucceeded):
|
||||||
obs["probe_result"] = numpy.int32(2)
|
obs["probe_result"] = numpy.int32(2)
|
||||||
elif isinstance(outcome, model.ProbeFailed):
|
elif isinstance(outcome, model.ProbeFailed):
|
||||||
|
@ -1179,7 +1192,8 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
) -> Tuple[Observation, dict]:
|
) -> Tuple[Observation, dict]:
|
||||||
LOGGER.info("Resetting the CyberBattle environment")
|
LOGGER.info("Resetting the CyberBattle environment")
|
||||||
self.__reset_environment()
|
self.__reset_environment()
|
||||||
self.seed(seed)
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
|
|
||||||
observation = self.__get_blank_observation()
|
observation = self.__get_blank_observation()
|
||||||
observation["action_mask"] = self.compute_action_mask()
|
observation["action_mask"] = self.compute_action_mask()
|
||||||
observation["discovered_nodes_properties"] = self.__get_property_matrix()
|
observation["discovered_nodes_properties"] = self.__get_property_matrix()
|
||||||
|
@ -1209,12 +1223,5 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
||||||
fig = self.render_as_fig()
|
fig = self.render_as_fig()
|
||||||
fig.show(renderer=self.__renderer)
|
fig.show(renderer=self.__renderer)
|
||||||
|
|
||||||
def seed(self, seed: Optional[int] = None) -> None:
|
|
||||||
if seed is None:
|
|
||||||
self._seed = seed
|
|
||||||
return
|
|
||||||
|
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -4,11 +4,12 @@ for CyberBattleEnv gym environment.
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from sqlite3 import NotSupportedError
|
from sqlite3 import NotSupportedError
|
||||||
from gymnasium import spaces
|
from gymnasium import Env, spaces
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from cyberbattle._env.cyberbattle_env import DummySpace, CyberBattleEnv, Action, CyberBattleSpaceKind
|
|
||||||
from gymnasium.core import ObservationWrapper, ActionWrapper
|
from gymnasium.core import ObservationWrapper, ActionWrapper
|
||||||
|
|
||||||
|
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv
|
||||||
|
|
||||||
|
|
||||||
class FlattenObservationWrapper(ObservationWrapper):
|
class FlattenObservationWrapper(ObservationWrapper):
|
||||||
"""
|
"""
|
||||||
|
@ -22,69 +23,92 @@ class FlattenObservationWrapper(ObservationWrapper):
|
||||||
if isinstance(space, spaces.MultiBinary):
|
if isinstance(space, spaces.MultiBinary):
|
||||||
if type(space.n) in [tuple, list, np.ndarray]:
|
if type(space.n) in [tuple, list, np.ndarray]:
|
||||||
flatten_dim = np.multiply.reduce(space.n)
|
flatten_dim = np.multiply.reduce(space.n)
|
||||||
print(f"// MultiBinary flattened from {space.n} -> {flatten_dim}")
|
flatten_space = spaces.MultiBinary(flatten_dim)
|
||||||
return spaces.MultiBinary(flatten_dim)
|
print(f"// MultiBinary flattened from {space.n} -> {flatten_space.n} - dtype: {space.dtype} -> {flatten_space.dtype}")
|
||||||
|
return flatten_space
|
||||||
else:
|
else:
|
||||||
print(f"// MultiBinary already flat: {space.n}")
|
print(f"// MultiBinary already flat: {space.n}")
|
||||||
return space
|
return space
|
||||||
else:
|
else:
|
||||||
return space
|
return space
|
||||||
|
|
||||||
def __init__(self, env: CyberBattleSpaceKind, ignore_fields=["action_mask"]):
|
def flatten_multidiscrete_space(self, space: spaces.Space):
|
||||||
|
if isinstance(space, spaces.MultiDiscrete):
|
||||||
|
if type(space.nvec) in [tuple, list, np.ndarray]:
|
||||||
|
flatten_space = spaces.MultiDiscrete(space.nvec.flatten())
|
||||||
|
print(f"// MultiDiscrete flattened from {space.nvec} -> {flatten_space.nvec}")
|
||||||
|
return flatten_space
|
||||||
|
else:
|
||||||
|
print(f"// MultiDiscrete already flat: {space.nvec}")
|
||||||
|
return space
|
||||||
|
|
||||||
|
def __init__(self, env: Env, ignore_fields=["action_mask"]):
|
||||||
ObservationWrapper.__init__(self, env)
|
ObservationWrapper.__init__(self, env)
|
||||||
self.env = env
|
self.env = env
|
||||||
self.ignore_fields = ignore_fields
|
self.ignore_fields = ignore_fields
|
||||||
|
if isinstance(env.observation_space, spaces.Dict):
|
||||||
|
space_dict = OrderedDict({})
|
||||||
|
for key, space in env.observation_space.spaces.items():
|
||||||
|
if key in ignore_fields:
|
||||||
|
print("Filtering out field", key)
|
||||||
|
elif isinstance(space, spaces.Dict):
|
||||||
|
for subkey, subspace in space.items():
|
||||||
|
space_dict[f"{key}_{subkey}"] = self.flatten_multibinary_space(subspace)
|
||||||
|
elif isinstance(space, spaces.Tuple):
|
||||||
|
for i, subspace in enumerate(space.spaces):
|
||||||
|
space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
|
||||||
|
elif isinstance(space, spaces.MultiBinary):
|
||||||
|
space_dict[key] = self.flatten_multibinary_space(space)
|
||||||
|
elif isinstance(space, spaces.Discrete):
|
||||||
|
space_dict[key] = space
|
||||||
|
elif isinstance(space, spaces.MultiDiscrete):
|
||||||
|
space_dict[key] = self.flatten_multidiscrete_space(space)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
|
||||||
|
|
||||||
space_dict = OrderedDict({})
|
self.observation_space = spaces.Dict(space_dict)
|
||||||
for key, space in env.observation_space.spaces.items():
|
|
||||||
if key in ignore_fields:
|
|
||||||
print("Filtering out field", key)
|
|
||||||
elif isinstance(space, spaces.Dict):
|
|
||||||
for k2, subspace in space.items():
|
|
||||||
space_dict[f"{key}_{k2}"] = self.flatten_multibinary_space(subspace)
|
|
||||||
elif isinstance(space, spaces.Tuple):
|
|
||||||
for i, subspace in enumerate(space.spaces):
|
|
||||||
space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
|
|
||||||
elif isinstance(space, spaces.MultiBinary):
|
|
||||||
space_dict[key] = self.flatten_multibinary_space(space)
|
|
||||||
elif isinstance(space, spaces.Discrete) or isinstance(space, spaces.MultiDiscrete):
|
|
||||||
space_dict[key] = space
|
|
||||||
elif isinstance(space, DummySpace):
|
|
||||||
print(f"warning: unsupported observation space: {space} : {type(space)}")
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
|
|
||||||
|
|
||||||
self.observation_space = spaces.Dict(space_dict)
|
|
||||||
|
|
||||||
def flatten_multibinary_observation(self, space, o):
|
def flatten_multibinary_observation(self, space, o):
|
||||||
if isinstance(space, spaces.MultiBinary) and isinstance(space.n, tuple) and len(space.n) > 1:
|
if isinstance(space, spaces.MultiBinary) and isinstance(space.n, tuple) and len(space.n) > 1:
|
||||||
flatten_dim = np.multiply.reduce(space.n)
|
flatten_dim = np.multiply.reduce(space.n)
|
||||||
return tuple(o.reshape(flatten_dim))
|
# print(f"dtype: {o.dtype} shape: {o.shape} -> {flatten_dim}")
|
||||||
|
reshaped = o.reshape(flatten_dim)
|
||||||
|
# print(f"reshaped: {reshaped.dtype} shape: {reshaped.shape}")
|
||||||
|
return reshaped
|
||||||
|
else:
|
||||||
|
return o
|
||||||
|
|
||||||
|
def flatten_multidiscrete_observation(self, space, o):
|
||||||
|
if isinstance(space, spaces.MultiDiscrete):
|
||||||
|
return o.flatten()
|
||||||
else:
|
else:
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
o = OrderedDict({})
|
if isinstance(self.env.observation_space, spaces.Dict):
|
||||||
for key, space in self.env.observation_space.spaces.items():
|
o = OrderedDict({})
|
||||||
value = observation[key]
|
for key, space in self.env.observation_space.spaces.items():
|
||||||
if key in self.ignore_fields:
|
value = observation[key]
|
||||||
continue
|
if key in self.ignore_fields:
|
||||||
elif isinstance(space, spaces.Dict):
|
continue
|
||||||
for subkey, subspace in space.items():
|
elif isinstance(space, spaces.Dict):
|
||||||
o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
|
for subkey, subspace in space.items():
|
||||||
elif isinstance(space, spaces.Tuple):
|
o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
|
||||||
for i, subspace in enumerate(space.spaces):
|
elif isinstance(space, spaces.Tuple):
|
||||||
o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
|
for i, subspace in enumerate(space.spaces):
|
||||||
elif isinstance(space, spaces.MultiBinary):
|
o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
|
||||||
o[key] = self.flatten_multibinary_observation(space, value)
|
elif isinstance(space, spaces.MultiBinary):
|
||||||
elif isinstance(space, spaces.Discrete) or isinstance(space, spaces.MultiDiscrete):
|
o[key] = self.flatten_multibinary_observation(space, value)
|
||||||
o[key] = value
|
elif isinstance(space, spaces.Discrete):
|
||||||
elif isinstance(space, DummySpace):
|
o[key] = value
|
||||||
continue
|
elif isinstance(space, spaces.MultiDiscrete):
|
||||||
else:
|
o[key] = self.flatten_multidiscrete_observation(space, value)
|
||||||
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
|
else:
|
||||||
|
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
|
||||||
|
|
||||||
return o
|
return o
|
||||||
|
else:
|
||||||
|
return observation
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
||||||
|
@ -105,7 +129,7 @@ class FlattenActionWrapper(ActionWrapper):
|
||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
self.action_space = spaces.MultiDiscrete(
|
self.action_space = spaces.MultiDiscrete(
|
||||||
[
|
np.array([
|
||||||
# connect, local vulnerabilities, remote vulnerabilities
|
# connect, local vulnerabilities, remote vulnerabilities
|
||||||
1 + env.bounds.local_attacks_count + env.bounds.remote_attacks_count,
|
1 + env.bounds.local_attacks_count + env.bounds.remote_attacks_count,
|
||||||
# source node
|
# source node
|
||||||
|
@ -116,7 +140,7 @@ class FlattenActionWrapper(ActionWrapper):
|
||||||
env.bounds.port_count,
|
env.bounds.port_count,
|
||||||
# credential used (for connect action only)
|
# credential used (for connect action only)
|
||||||
env.bounds.maximum_total_credentials,
|
env.bounds.maximum_total_credentials,
|
||||||
]
|
], dtype=np.int32)
|
||||||
)
|
)
|
||||||
|
|
||||||
def action(self, action: np.ndarray) -> Action:
|
def action(self, action: np.ndarray) -> Action:
|
||||||
|
|
|
@ -168,7 +168,7 @@ class Feature_discovered_nodeproperties_sliding(GlobalFeature):
|
||||||
|
|
||||||
def get_global(self, a: StateAugmentation) -> ndarray:
|
def get_global(self, a: StateAugmentation) -> ndarray:
|
||||||
n = a.observation["discovered_node_count"]
|
n = a.observation["discovered_node_count"]
|
||||||
node_prop = np.array(a.observation["discovered_nodes_properties"])[:n]
|
node_prop = a.observation["discovered_nodes_properties"][:n]
|
||||||
|
|
||||||
# keep last window of entries
|
# keep last window of entries
|
||||||
node_prop_window = node_prop[-self.window_size :, :]
|
node_prop_window = node_prop[-self.window_size :, :]
|
||||||
|
@ -255,7 +255,7 @@ class Feature_discovered_notowned_node_count(GlobalFeature):
|
||||||
"""number of nodes discovered that are not owned yet (optionally clipped)"""
|
"""number of nodes discovered that are not owned yet (optionally clipped)"""
|
||||||
|
|
||||||
def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
|
def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
|
||||||
self.clip = p.maximum_node_count if clip is None else clip
|
self.clip = np.int32(clip or p.maximum_node_count)
|
||||||
super().__init__(p, [self.clip + 1])
|
super().__init__(p, [self.clip + 1])
|
||||||
|
|
||||||
def get_global(self, a: StateAugmentation):
|
def get_global(self, a: StateAugmentation):
|
||||||
|
@ -263,8 +263,8 @@ class Feature_discovered_notowned_node_count(GlobalFeature):
|
||||||
node_props = np.array(a.observation["discovered_nodes_properties"][:discovered])
|
node_props = np.array(a.observation["discovered_nodes_properties"][:discovered])
|
||||||
# here we assume that a node is owned just if all its properties are known
|
# here we assume that a node is owned just if all its properties are known
|
||||||
owned = np.count_nonzero(np.all(node_props != 2, axis=1))
|
owned = np.count_nonzero(np.all(node_props != 2, axis=1))
|
||||||
diff = discovered - owned
|
diff = np.int32(discovered - owned)
|
||||||
return np.array([min(diff, self.clip)], dtype=np.int_)
|
return np.array( [np.min((diff, self.clip))], dtype=np.int32)
|
||||||
|
|
||||||
|
|
||||||
class Feature_owned_node_count(GlobalFeature):
|
class Feature_owned_node_count(GlobalFeature):
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
'''Stable-baselines agent for CyberBattle Gym environment'''
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import cast
|
||||||
|
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
|
||||||
|
from cyberbattle._env.flatten_wrapper import (
|
||||||
|
FlattenObservationWrapper,
|
||||||
|
FlattenActionWrapper,
|
||||||
|
)
|
||||||
|
from stable_baselines3.common.type_aliases import GymEnv
|
||||||
|
from stable_baselines3.common.env_checker import check_env
|
||||||
|
from stable_baselines3.a2c.a2c import A2C
|
||||||
|
|
||||||
|
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||||
|
|
||||||
|
def test_stablebaseline3(training_steps=3, eval_steps=10):
|
||||||
|
|
||||||
|
cybersinm_env = CyberBattleToyCtf(
|
||||||
|
maximum_node_count=12,
|
||||||
|
maximum_total_credentials=10,
|
||||||
|
observation_padding=True,
|
||||||
|
throws_on_invalid_actions=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
flatten_action_env = FlattenActionWrapper(cybersinm_env)
|
||||||
|
|
||||||
|
flatten_obs_env = FlattenObservationWrapper(flatten_action_env, ignore_fields=[
|
||||||
|
"_credential_cache",
|
||||||
|
"_discovered_nodes",
|
||||||
|
"_explored_network",
|
||||||
|
])
|
||||||
|
|
||||||
|
env_as_gym = cast(GymEnv, flatten_obs_env)
|
||||||
|
|
||||||
|
check_env(flatten_obs_env)
|
||||||
|
|
||||||
|
model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(training_steps)
|
||||||
|
model_a2c.save("a2c_trained_toyctf")
|
||||||
|
model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf")
|
||||||
|
|
||||||
|
obs , _= env_as_gym.reset()
|
||||||
|
for i in range(eval_steps):
|
||||||
|
assert isinstance(obs, dict)
|
||||||
|
action, _states = model.predict(obs, deterministic=True)
|
||||||
|
obs, reward, done, truncated, info = flatten_obs_env.step(action)
|
||||||
|
|
||||||
|
flatten_obs_env.render()
|
||||||
|
flatten_obs_env.close()
|
|
@ -355,7 +355,7 @@ class Identifiers(NamedTuple):
|
||||||
# Array of all possible node property identifiers
|
# Array of all possible node property identifiers
|
||||||
properties: List[PropertyName] = []
|
properties: List[PropertyName] = []
|
||||||
# Array of all possible port names
|
# Array of all possible port names
|
||||||
ports: List[PortName] = []
|
ports: List[PortName] = ["Null"]
|
||||||
# Array of all possible local vulnerabilities names
|
# Array of all possible local vulnerabilities names
|
||||||
local_vulnerabilities: List[VulnerabilityID] = []
|
local_vulnerabilities: List[VulnerabilityID] = []
|
||||||
# Array of all possible remote vulnerabilities names
|
# Array of all possible remote vulnerabilities names
|
||||||
|
|
|
@ -57,7 +57,7 @@
|
||||||
"from cyberbattle._env.cyberbattle_env import CyberBattleEnv\n",
|
"from cyberbattle._env.cyberbattle_env import CyberBattleEnv\n",
|
||||||
"\n",
|
"\n",
|
||||||
"envs = [cast(CyberBattleEnv, gym.make(gymid)) for gymid in gymids]\n",
|
"envs = [cast(CyberBattleEnv, gym.make(gymid)) for gymid in gymids]\n",
|
||||||
"map(lambda g: g.seed(1), envs)\n",
|
"map(lambda g: g.reset(seed=1), envs)\n",
|
||||||
"ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=30, maximum_total_credentials=50, identifiers=envs[0].identifiers)"
|
"ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=30, maximum_total_credentials=50, identifiers=envs[0].identifiers)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -18119,7 +18119,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"tiny = cast(CyberBattleEnv, gym.make(f\"ActiveDirectory-v{ngyms}\"))\n",
|
"tiny = cast(CyberBattleEnv, gym.make(f\"ActiveDirectory-v{ngyms}\"))\n",
|
||||||
"current_o, _ = tiny.reset()\n",
|
"current_o, _ = tiny.reset()\n",
|
||||||
"tiny.seed(1)\n",
|
"tiny.reset(seed=1)\n",
|
||||||
"wrapped_env = AgentWrapper(tiny, ActionTrackingStateAugmentation(ep, current_o))\n",
|
"wrapped_env = AgentWrapper(tiny, ActionTrackingStateAugmentation(ep, current_o))\n",
|
||||||
"# Use the trained agent to run the steps one by one\n",
|
"# Use the trained agent to run the steps one by one\n",
|
||||||
"max_steps = 1000\n",
|
"max_steps = 1000\n",
|
||||||
|
@ -18166,4 +18166,4 @@
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 4
|
"nbformat_minor": 4
|
||||||
}
|
}
|
||||||
|
|
|
@ -394461,9 +394461,9 @@
|
||||||
"cell_metadata_filter": "title,-all"
|
"cell_metadata_filter": "title,-all"
|
||||||
},
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python [conda env:cybersim]",
|
"display_name": "cybersim",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "conda-env-cybersim-py"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
@ -394492,4 +394492,4 @@
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 5
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,8 +96,8 @@ run notebook_withdefender -y "
|
||||||
"
|
"
|
||||||
|
|
||||||
run dql_active_directory -y "
|
run dql_active_directory -y "
|
||||||
ngyms: 3
|
ngyms: 3
|
||||||
iteration_count: 50
|
iteration_count: 50
|
||||||
"
|
"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
# %%
|
'''Stable-baselines agent for CyberBattle Gym environment'''
|
||||||
# !pip install stable-baselines3[extra]
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
|
|
||||||
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
|
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
@ -14,15 +12,13 @@ from cyberbattle._env.flatten_wrapper import (
|
||||||
FlattenActionWrapper,
|
FlattenActionWrapper,
|
||||||
)
|
)
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
|
||||||
from stable_baselines3.common.type_aliases import GymEnv
|
from stable_baselines3.common.type_aliases import GymEnv
|
||||||
|
from stable_baselines3.common.env_checker import check_env
|
||||||
|
|
||||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||||
retrain = ["a2c"]
|
retrain = ["a2c", "ppo"]
|
||||||
|
|
||||||
|
|
||||||
# %%
|
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
@ -33,29 +29,27 @@ env = CyberBattleToyCtf(
|
||||||
throws_on_invalid_actions=False,
|
throws_on_invalid_actions=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
flatten_action_env = FlattenActionWrapper(env)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
env1 = FlattenActionWrapper(env)
|
flatten_obs_env = FlattenObservationWrapper(flatten_action_env, ignore_fields=[
|
||||||
|
|
||||||
# %%
|
|
||||||
# MultiBinary
|
|
||||||
# 'action_mask',
|
|
||||||
# 'customer_data_found',
|
|
||||||
# MultiDiscrete space
|
|
||||||
# 'nodes_privilegelevel',
|
|
||||||
# 'leaked_credentials',
|
|
||||||
# 'credential_cache_matrix'
|
|
||||||
# 'discovered_nodes_properties',
|
|
||||||
|
|
||||||
ignore_fields = [
|
|
||||||
# DummySpace
|
# DummySpace
|
||||||
"_credential_cache",
|
"_credential_cache",
|
||||||
"_discovered_nodes",
|
"_discovered_nodes",
|
||||||
"_explored_network",
|
"_explored_network",
|
||||||
]
|
])
|
||||||
env2 = FlattenObservationWrapper(cast(CyberBattleEnv, env1), ignore_fields=ignore_fields)
|
|
||||||
|
#%%
|
||||||
|
env_as_gym = cast(GymEnv, flatten_obs_env)
|
||||||
|
|
||||||
|
#%%
|
||||||
|
o, _ = env_as_gym.reset()
|
||||||
|
print(o)
|
||||||
|
|
||||||
|
#%%
|
||||||
|
check_env(flatten_obs_env)
|
||||||
|
|
||||||
env_as_gym = cast(GymEnv, env2)
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
if "a2c" in retrain:
|
if "a2c" in retrain:
|
||||||
|
@ -75,10 +69,16 @@ model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf")
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
obs = env2.reset()
|
obs , _= env_as_gym.reset()
|
||||||
for i in range(1000):
|
|
||||||
action, _states = model.predict(np.array(obs), deterministic=True)
|
|
||||||
obs, reward, done, truncated, info = env2.step(action)
|
|
||||||
|
|
||||||
env2.render()
|
|
||||||
env2.close()
|
# %%
|
||||||
|
for i in range(1000):
|
||||||
|
assert isinstance(obs, dict)
|
||||||
|
action, _states = model.predict(obs, deterministic=True)
|
||||||
|
obs, reward, done, truncated, info = flatten_obs_env.step(action)
|
||||||
|
|
||||||
|
flatten_obs_env.render()
|
||||||
|
flatten_obs_env.close()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
Загрузка…
Ссылка в новой задаче