Stable-baselines3 agent working - fix shapes of observation spaces (#144)

* Stable-baselines3 agent working - fix shapes of observation spaces

chore: Update stable-baselines3 agent with env_checker

chore: Update ports list with "Null" value


---------

Co-authored-by: William Blum <william.blum@microsoft.com>
This commit is contained in:
William Blum 2024-08-07 14:10:08 -07:00 коммит произвёл GitHub
Родитель 4eabac5e60
Коммит 8f7078c814
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
9 изменённых файлов: 222 добавлений и 143 удалений

Просмотреть файл

@ -22,7 +22,7 @@ from cyberbattle._env.defender import DefenderAgent
from cyberbattle.simulation.model import PortName, PrivilegeLevel from cyberbattle.simulation.model import PortName, PrivilegeLevel
from ..simulation import commandcontrol, model, actions from ..simulation import commandcontrol, model, actions
from .discriminatedunion import DiscriminatedUnion from .discriminatedunion import DiscriminatedUnion
import numpy as np
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
@ -71,7 +71,7 @@ Observation = TypedDict(
# whether a lateral move was just performed # whether a lateral move was just performed
"lateral_move": numpy.int32, "lateral_move": numpy.int32,
# whether customer data were just discovered # whether customer data were just discovered
"customer_data_found": Tuple[numpy.int32], "customer_data_found": numpy.int32,
# 0 if there were no probing attempt # 0 if there were no probing attempt
# 1 if an attempted probing failed # 1 if an attempted probing failed
# 2 if an attempted probing succeeded # 2 if an attempted probing succeeded
@ -90,7 +90,7 @@ Observation = TypedDict(
# total nodes discovered so far # total nodes discovered so far
"discovered_node_count": int, "discovered_node_count": int,
# Matrix of properties for all the discovered nodes # Matrix of properties for all the discovered nodes
"discovered_nodes_properties": Tuple[numpy.ndarray, ...], "discovered_nodes_properties": numpy.ndarray,
# Node privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system) # Node privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system)
"nodes_privilegelevel": numpy.ndarray, "nodes_privilegelevel": numpy.ndarray,
# Tuple encoding of the credential cache matrix. # Tuple encoding of the credential cache matrix.
@ -183,14 +183,14 @@ class EnvironmentBounds(NamedTuple):
remote_attacks_count - Unique remote vulnerabilities remote_attacks_count - Unique remote vulnerabilities
""" """
maximum_total_credentials: int maximum_total_credentials: np.int32
maximum_node_count: int maximum_node_count: np.int32
maximum_discoverable_credentials_per_action: int maximum_discoverable_credentials_per_action: np.int32
port_count: int port_count: np.int32
property_count: int property_count: np.int32
local_attacks_count: int local_attacks_count: np.int32
remote_attacks_count: int remote_attacks_count: np.int32
@classmethod @classmethod
def of_identifiers( def of_identifiers(
@ -200,16 +200,27 @@ class EnvironmentBounds(NamedTuple):
maximum_node_count: int, maximum_node_count: int,
maximum_discoverable_credentials_per_action: Optional[int] = None, maximum_discoverable_credentials_per_action: Optional[int] = None,
): ):
if not maximum_discoverable_credentials_per_action:
maximum_discoverable_credentials_per_action = maximum_total_credentials maximum_discoverable_credentials_per_action = maximum_discoverable_credentials_per_action or maximum_total_credentials
assert np.can_cast(maximum_total_credentials, np.int32), "maximum_total_credentials must be a 32-bit integer"
assert np.can_cast(maximum_node_count, np.int32), "maximum_node_count must be a 32-bit integer"
assert maximum_total_credentials > 0, "maximum_total_credentials must be positive"
assert maximum_node_count > 0, "maximum_node_count must be positive"
assert np.can_cast(len(identifiers.ports), np.int32), "port_count must be a 32-bit integer"
assert np.can_cast(len(identifiers.properties), np.int32), "property_count must be a 32-bit integer"
assert np.can_cast(len(identifiers.local_vulnerabilities), np.int32), "local_attacks_count must be a 32-bit integer"
assert np.can_cast(len(identifiers.remote_vulnerabilities), np.int32), "remote_attacks_count must be a 32-bit integer"
assert np.can_cast(maximum_discoverable_credentials_per_action, np.int32), "maximum_discoverable_credentials_per_action must be a 32-bit integer"
return EnvironmentBounds( return EnvironmentBounds(
maximum_total_credentials=maximum_total_credentials, maximum_total_credentials=np.int32(maximum_total_credentials),
maximum_node_count=maximum_node_count, maximum_node_count=np.int32(maximum_node_count),
maximum_discoverable_credentials_per_action=maximum_discoverable_credentials_per_action, maximum_discoverable_credentials_per_action=np.int32(maximum_discoverable_credentials_per_action),
port_count=len(identifiers.ports), port_count=np.int32(len(identifiers.ports)),
property_count=len(identifiers.properties), property_count=np.int32(len(identifiers.properties)),
local_attacks_count=len(identifiers.local_vulnerabilities), local_attacks_count=np.int32(len(identifiers.local_vulnerabilities)),
remote_attacks_count=len(identifiers.remote_vulnerabilities), remote_attacks_count=np.int32(len(identifiers.remote_vulnerabilities)),
) )
@ -252,7 +263,7 @@ class ObservationSpaceType(spaces.Dict):
# successfully moved to the target node (1) or not (0) # successfully moved to the target node (1) or not (0)
"lateral_move": spaces.Discrete(2), "lateral_move": spaces.Discrete(2),
# boolean: 1 if customer secret data were discovered, 0 otherwise # boolean: 1 if customer secret data were discovered, 0 otherwise
"customer_data_found": spaces.MultiBinary(2), "customer_data_found": spaces.Discrete(2),
# whether an attempted probing succeeded or not # whether an attempted probing succeeded or not
"probe_result": spaces.Discrete(3), "probe_result": spaces.Discrete(3),
# Escalation result # Escalation result
@ -271,12 +282,12 @@ class ObservationSpaceType(spaces.Dict):
# that was used to authenticate to target node 56 on port number 22 (e.g. SSH) # that was used to authenticate to target node 56 on port number 22 (e.g. SSH)
[ [
spaces.MultiDiscrete( spaces.MultiDiscrete(
[ np.array([
NA + 1, NA + 1,
bounds.maximum_total_credentials, bounds.maximum_total_credentials,
bounds.maximum_node_count, bounds.maximum_node_count,
bounds.port_count, bounds.port_count,
] ], dtype=np.int32)
) )
] ]
* bounds.maximum_discoverable_credentials_per_action * bounds.maximum_discoverable_credentials_per_action
@ -287,9 +298,9 @@ class ObservationSpaceType(spaces.Dict):
# even though such action would 'fail' and potentially yield a negative reward. # even though such action would 'fail' and potentially yield a negative reward.
"action_mask": spaces.Dict( "action_mask": spaces.Dict(
{ {
"local_vulnerability": spaces.MultiBinary(bounds.maximum_node_count * bounds.local_attacks_count), "local_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.local_attacks_count])),
"remote_vulnerability": spaces.MultiBinary(bounds.maximum_node_count * bounds.maximum_node_count * bounds.remote_attacks_count), "remote_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.remote_attacks_count])),
"connect": spaces.MultiBinary(bounds.maximum_node_count * bounds.maximum_node_count * bounds.port_count * bounds.maximum_total_credentials), "connect": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.port_count, bounds.maximum_total_credentials], dtype=np.int32))
} }
), ),
# size of the credential stack # size of the credential stack
@ -298,7 +309,7 @@ class ObservationSpaceType(spaces.Dict):
"discovered_node_count": spaces.Discrete(bounds.maximum_node_count), "discovered_node_count": spaces.Discrete(bounds.maximum_node_count),
# Matrix of properties for all the discovered nodes # Matrix of properties for all the discovered nodes
# 3 values for each matrix cell: set, unset, unknown # 3 values for each matrix cell: set, unset, unknown
"discovered_nodes_properties": spaces.MultiDiscrete([3] * bounds.maximum_node_count * bounds.property_count), "discovered_nodes_properties": spaces.MultiDiscrete(np.full(shape=(bounds.maximum_node_count, bounds.property_count), fill_value=3)),
# Escalation level on every discovered node (e.g., 0 if not owned, 1 for admin, 2 for system) # Escalation level on every discovered node (e.g., 0 if not owned, 1 for admin, 2 for system)
"nodes_privilegelevel": spaces.MultiDiscrete([CyberBattleEnv.privilege_levels] * bounds.maximum_node_count), "nodes_privilegelevel": spaces.MultiDiscrete([CyberBattleEnv.privilege_levels] * bounds.maximum_node_count),
# Encoding of the credential cache of shape: (credential_cache_length, 2) # Encoding of the credential cache of shape: (credential_cache_length, 2)
@ -306,7 +317,7 @@ class ObservationSpaceType(spaces.Dict):
# Each row represent a discovered credential, # Each row represent a discovered credential,
# the credential index is given by the row index (i.e. order of discovery) # the credential index is given by the row index (i.e. order of discovery)
# A row is of the form: (target_node_discover_index, port_index) # A row is of the form: (target_node_discover_index, port_index)
"credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete([bounds.maximum_node_count, bounds.port_count])] * bounds.maximum_total_credentials), "credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete(np.array([bounds.maximum_node_count, bounds.port_count],dtype=np.int32))] * bounds.maximum_total_credentials),
# --------------------------------------------------------- # ---------------------------------------------------------
# Fields that were previously in the 'info' dict: # Fields that were previously in the 'info' dict:
# --------------------------------------------------------- # ---------------------------------------------------------
@ -460,7 +471,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
self, self,
initial_environment: model.Environment, initial_environment: model.Environment,
maximum_total_credentials: int = 1000, maximum_total_credentials: int = 1000,
maximum_node_count: int = 100, maximum_node_count: int = 100,
maximum_discoverable_credentials_per_action: int = 5, maximum_discoverable_credentials_per_action: int = 5,
defender_agent: Optional[DefenderAgent] = None, defender_agent: Optional[DefenderAgent] = None,
attacker_goal: Optional[AttackerGoal] = AttackerGoal(own_atleast_percent=1.0), attacker_goal: Optional[AttackerGoal] = AttackerGoal(own_atleast_percent=1.0),
@ -523,27 +534,27 @@ class CyberBattleEnv(CyberBattleSpaceKind):
# The Space object defining the valid actions of an attacker. # The Space object defining the valid actions of an attacker.
local_vulnerabilities_count = self.__bounds.local_attacks_count local_vulnerabilities_count = self.__bounds.local_attacks_count
remote_vulnerabilities_count = self.__bounds.remote_attacks_count remote_vulnerabilities_count = self.__bounds.remote_attacks_count
maximum_node_count = self.__bounds.maximum_node_count maximum_node_count_int32 = self.__bounds.maximum_node_count
port_count = self.__bounds.port_count port_count = self.__bounds.port_count
action_spaces = { action_spaces = {
"local_vulnerability": spaces.MultiDiscrete( "local_vulnerability": spaces.MultiDiscrete(
# source_node_id, vulnerability_id # source_node_id, vulnerability_id
[maximum_node_count, local_vulnerabilities_count] np.array([maximum_node_count_int32, local_vulnerabilities_count], dtype=np.int32)
), ),
"remote_vulnerability": spaces.MultiDiscrete( "remote_vulnerability": spaces.MultiDiscrete(
# source_node_id, target_node_id, vulnerability_id # source_node_id, target_node_id, vulnerability_id
[maximum_node_count, maximum_node_count, remote_vulnerabilities_count] np.array([maximum_node_count_int32, maximum_node_count_int32, remote_vulnerabilities_count], dtype=np.int32)
), ),
"connect": spaces.MultiDiscrete( "connect": spaces.MultiDiscrete(
# source_node_id, target_node_id, target_port, credential_id # source_node_id, target_node_id, target_port, credential_id
# (by index of discovery: 0 for initial node, 1 for first discovered node, ...) # (by index of discovery: 0 for initial node, 1 for first discovered node, ...)
[ np.array([
maximum_node_count, maximum_node_count_int32,
maximum_node_count, maximum_node_count_int32,
port_count, port_count,
maximum_total_credentials, maximum_total_credentials,
] ], dtype=np.int32)
), ),
} }
@ -613,10 +624,10 @@ class CyberBattleEnv(CyberBattleSpaceKind):
local_vulnerabilities_count = self.__bounds.local_attacks_count local_vulnerabilities_count = self.__bounds.local_attacks_count
remote_vulnerabilities_count = self.__bounds.remote_attacks_count remote_vulnerabilities_count = self.__bounds.remote_attacks_count
port_count = self.__bounds.port_count port_count = self.__bounds.port_count
local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int32) local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int8)
remote = numpy.zeros( remote = numpy.zeros(
shape=(max_node_count, max_node_count, remote_vulnerabilities_count), shape=(max_node_count, max_node_count, remote_vulnerabilities_count),
dtype=numpy.int32, dtype=numpy.int8,
) )
connect = numpy.zeros( connect = numpy.zeros(
shape=( shape=(
@ -625,7 +636,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
port_count, port_count,
self.__bounds.maximum_total_credentials, self.__bounds.maximum_total_credentials,
), ),
dtype=numpy.int32, dtype=numpy.int8,
) )
return ActionMask(local_vulnerability=local, remote_vulnerability=remote, connect=connect) return ActionMask(local_vulnerability=local, remote_vulnerability=remote, connect=connect)
@ -744,14 +755,14 @@ class CyberBattleEnv(CyberBattleSpaceKind):
newly_discovered_nodes_count=numpy.int32(0), newly_discovered_nodes_count=numpy.int32(0),
leaked_credentials=tuple([numpy.array([UNUSED_SLOT, 0, 0, 0], dtype=numpy.int32)] * self.__bounds.maximum_discoverable_credentials_per_action), leaked_credentials=tuple([numpy.array([UNUSED_SLOT, 0, 0, 0], dtype=numpy.int32)] * self.__bounds.maximum_discoverable_credentials_per_action),
lateral_move=numpy.int32(0), lateral_move=numpy.int32(0),
customer_data_found=(numpy.int32(0),), customer_data_found=numpy.int32(0),
escalation=numpy.int32(PrivilegeLevel.NoAccess), escalation=numpy.int32(PrivilegeLevel.NoAccess),
action_mask=self.__get_blank_action_mask(), action_mask=self.__get_blank_action_mask(),
probe_result=numpy.int32(0), probe_result=numpy.int32(0),
credential_cache_matrix=tuple([numpy.zeros((2))] * self.__bounds.maximum_total_credentials), credential_cache_matrix=tuple([numpy.zeros((2), dtype=numpy.int64)] * self.__bounds.maximum_total_credentials),
credential_cache_length=0, credential_cache_length=0,
discovered_node_count=len(self.__discovered_nodes), discovered_node_count=len(self.__discovered_nodes),
discovered_nodes_properties=tuple([numpy.full((self.__bounds.property_count,), 2, dtype=numpy.int32)] * self.__bounds.maximum_node_count), discovered_nodes_properties=numpy.full((self.__bounds.maximum_node_count, self.__bounds.property_count,), 2, dtype=numpy.int32),
nodes_privilegelevel=numpy.zeros((self.bounds.maximum_node_count,), dtype=numpy.int32), nodes_privilegelevel=numpy.zeros((self.bounds.maximum_node_count,), dtype=numpy.int32),
# raw data not actually encoded as a proper gym numeric space # raw data not actually encoded as a proper gym numeric space
# (were previously returned in the 'info' dict) # (were previously returned in the 'info' dict)
@ -797,7 +808,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
vector[properties_indices] = 1 vector[properties_indices] = 1
return vector return vector
def __get_property_matrix(self) -> Tuple[numpy.ndarray, ...]: def __get_property_matrix(self) -> numpy.ndarray:
"""Return the Node-Property matrix, """Return the Node-Property matrix,
where 0 means the property is not set for that node where 0 means the property is not set for that node
1 means the property is set for that node 1 means the property is set for that node
@ -810,11 +821,13 @@ class CyberBattleEnv(CyberBattleSpaceKind):
2nd row: no known properties for the 2nd discovered node 2nd row: no known properties for the 2nd discovered node
3rd row: properties of 3rd discovered and owned node""" 3rd row: properties of 3rd discovered and owned node"""
property_discovered = [self.__property_vector(node_id, node_info) for node_id, node_info in self._actuator.discovered_nodes()] property_discovered = [self.__property_vector(node_id, node_info) for node_id, node_info in self._actuator.discovered_nodes()]
return self.__pad_tuple_if_requested( as_numpy = numpy.array(self.__pad_tuple_if_requested(
property_discovered, property_discovered,
self.__bounds.property_count, self.__bounds.property_count,
self.__bounds.maximum_node_count, self.__bounds.maximum_node_count,
) ))
assert as_numpy.shape == (self.__bounds.maximum_node_count, self.__bounds.property_count)
return as_numpy
def __get__owned_nodes_indices(self) -> List[int]: def __get__owned_nodes_indices(self) -> List[int]:
"""Get list of indices of all owned nodes""" """Get list of indices of all owned nodes"""
@ -896,7 +909,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
elif isinstance(outcome, model.LateralMove): elif isinstance(outcome, model.LateralMove):
obs["lateral_move"] = numpy.int32(1) obs["lateral_move"] = numpy.int32(1)
elif isinstance(outcome, model.CustomerData): elif isinstance(outcome, model.CustomerData):
obs["customer_data_found"] = (numpy.int32(1),) obs["customer_data_found"] = numpy.int32(1)
elif isinstance(outcome, model.ProbeSucceeded): elif isinstance(outcome, model.ProbeSucceeded):
obs["probe_result"] = numpy.int32(2) obs["probe_result"] = numpy.int32(2)
elif isinstance(outcome, model.ProbeFailed): elif isinstance(outcome, model.ProbeFailed):
@ -1179,7 +1192,8 @@ class CyberBattleEnv(CyberBattleSpaceKind):
) -> Tuple[Observation, dict]: ) -> Tuple[Observation, dict]:
LOGGER.info("Resetting the CyberBattle environment") LOGGER.info("Resetting the CyberBattle environment")
self.__reset_environment() self.__reset_environment()
self.seed(seed) self.np_random, seed = seeding.np_random(seed)
observation = self.__get_blank_observation() observation = self.__get_blank_observation()
observation["action_mask"] = self.compute_action_mask() observation["action_mask"] = self.compute_action_mask()
observation["discovered_nodes_properties"] = self.__get_property_matrix() observation["discovered_nodes_properties"] = self.__get_property_matrix()
@ -1209,12 +1223,5 @@ class CyberBattleEnv(CyberBattleSpaceKind):
fig = self.render_as_fig() fig = self.render_as_fig()
fig.show(renderer=self.__renderer) fig.show(renderer=self.__renderer)
def seed(self, seed: Optional[int] = None) -> None:
if seed is None:
self._seed = seed
return
self.np_random, seed = seeding.np_random(seed)
def close(self) -> None: def close(self) -> None:
return None return None

Просмотреть файл

@ -4,11 +4,12 @@ for CyberBattleEnv gym environment.
from collections import OrderedDict from collections import OrderedDict
from sqlite3 import NotSupportedError from sqlite3 import NotSupportedError
from gymnasium import spaces from gymnasium import Env, spaces
import numpy as np import numpy as np
from cyberbattle._env.cyberbattle_env import DummySpace, CyberBattleEnv, Action, CyberBattleSpaceKind
from gymnasium.core import ObservationWrapper, ActionWrapper from gymnasium.core import ObservationWrapper, ActionWrapper
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv
class FlattenObservationWrapper(ObservationWrapper): class FlattenObservationWrapper(ObservationWrapper):
""" """
@ -22,69 +23,92 @@ class FlattenObservationWrapper(ObservationWrapper):
if isinstance(space, spaces.MultiBinary): if isinstance(space, spaces.MultiBinary):
if type(space.n) in [tuple, list, np.ndarray]: if type(space.n) in [tuple, list, np.ndarray]:
flatten_dim = np.multiply.reduce(space.n) flatten_dim = np.multiply.reduce(space.n)
print(f"// MultiBinary flattened from {space.n} -> {flatten_dim}") flatten_space = spaces.MultiBinary(flatten_dim)
return spaces.MultiBinary(flatten_dim) print(f"// MultiBinary flattened from {space.n} -> {flatten_space.n} - dtype: {space.dtype} -> {flatten_space.dtype}")
return flatten_space
else: else:
print(f"// MultiBinary already flat: {space.n}") print(f"// MultiBinary already flat: {space.n}")
return space return space
else: else:
return space return space
def __init__(self, env: CyberBattleSpaceKind, ignore_fields=["action_mask"]): def flatten_multidiscrete_space(self, space: spaces.Space):
if isinstance(space, spaces.MultiDiscrete):
if type(space.nvec) in [tuple, list, np.ndarray]:
flatten_space = spaces.MultiDiscrete(space.nvec.flatten())
print(f"// MultiDiscrete flattened from {space.nvec} -> {flatten_space.nvec}")
return flatten_space
else:
print(f"// MultiDiscrete already flat: {space.nvec}")
return space
def __init__(self, env: Env, ignore_fields=["action_mask"]):
ObservationWrapper.__init__(self, env) ObservationWrapper.__init__(self, env)
self.env = env self.env = env
self.ignore_fields = ignore_fields self.ignore_fields = ignore_fields
if isinstance(env.observation_space, spaces.Dict):
space_dict = OrderedDict({})
for key, space in env.observation_space.spaces.items():
if key in ignore_fields:
print("Filtering out field", key)
elif isinstance(space, spaces.Dict):
for subkey, subspace in space.items():
space_dict[f"{key}_{subkey}"] = self.flatten_multibinary_space(subspace)
elif isinstance(space, spaces.Tuple):
for i, subspace in enumerate(space.spaces):
space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
elif isinstance(space, spaces.MultiBinary):
space_dict[key] = self.flatten_multibinary_space(space)
elif isinstance(space, spaces.Discrete):
space_dict[key] = space
elif isinstance(space, spaces.MultiDiscrete):
space_dict[key] = self.flatten_multidiscrete_space(space)
else:
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
space_dict = OrderedDict({}) self.observation_space = spaces.Dict(space_dict)
for key, space in env.observation_space.spaces.items():
if key in ignore_fields:
print("Filtering out field", key)
elif isinstance(space, spaces.Dict):
for k2, subspace in space.items():
space_dict[f"{key}_{k2}"] = self.flatten_multibinary_space(subspace)
elif isinstance(space, spaces.Tuple):
for i, subspace in enumerate(space.spaces):
space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
elif isinstance(space, spaces.MultiBinary):
space_dict[key] = self.flatten_multibinary_space(space)
elif isinstance(space, spaces.Discrete) or isinstance(space, spaces.MultiDiscrete):
space_dict[key] = space
elif isinstance(space, DummySpace):
print(f"warning: unsupported observation space: {space} : {type(space)}")
else:
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
self.observation_space = spaces.Dict(space_dict)
def flatten_multibinary_observation(self, space, o): def flatten_multibinary_observation(self, space, o):
if isinstance(space, spaces.MultiBinary) and isinstance(space.n, tuple) and len(space.n) > 1: if isinstance(space, spaces.MultiBinary) and isinstance(space.n, tuple) and len(space.n) > 1:
flatten_dim = np.multiply.reduce(space.n) flatten_dim = np.multiply.reduce(space.n)
return tuple(o.reshape(flatten_dim)) # print(f"dtype: {o.dtype} shape: {o.shape} -> {flatten_dim}")
reshaped = o.reshape(flatten_dim)
# print(f"reshaped: {reshaped.dtype} shape: {reshaped.shape}")
return reshaped
else:
return o
def flatten_multidiscrete_observation(self, space, o):
if isinstance(space, spaces.MultiDiscrete):
return o.flatten()
else: else:
return o return o
def observation(self, observation): def observation(self, observation):
o = OrderedDict({}) if isinstance(self.env.observation_space, spaces.Dict):
for key, space in self.env.observation_space.spaces.items(): o = OrderedDict({})
value = observation[key] for key, space in self.env.observation_space.spaces.items():
if key in self.ignore_fields: value = observation[key]
continue if key in self.ignore_fields:
elif isinstance(space, spaces.Dict): continue
for subkey, subspace in space.items(): elif isinstance(space, spaces.Dict):
o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey]) for subkey, subspace in space.items():
elif isinstance(space, spaces.Tuple): o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
for i, subspace in enumerate(space.spaces): elif isinstance(space, spaces.Tuple):
o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i]) for i, subspace in enumerate(space.spaces):
elif isinstance(space, spaces.MultiBinary): o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
o[key] = self.flatten_multibinary_observation(space, value) elif isinstance(space, spaces.MultiBinary):
elif isinstance(space, spaces.Discrete) or isinstance(space, spaces.MultiDiscrete): o[key] = self.flatten_multibinary_observation(space, value)
o[key] = value elif isinstance(space, spaces.Discrete):
elif isinstance(space, DummySpace): o[key] = value
continue elif isinstance(space, spaces.MultiDiscrete):
else: o[key] = self.flatten_multidiscrete_observation(space, value)
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}") else:
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
return o return o
else:
return observation
def step(self, action): def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
@ -105,7 +129,7 @@ class FlattenActionWrapper(ActionWrapper):
self.env = env self.env = env
self.action_space = spaces.MultiDiscrete( self.action_space = spaces.MultiDiscrete(
[ np.array([
# connect, local vulnerabilities, remote vulnerabilities # connect, local vulnerabilities, remote vulnerabilities
1 + env.bounds.local_attacks_count + env.bounds.remote_attacks_count, 1 + env.bounds.local_attacks_count + env.bounds.remote_attacks_count,
# source node # source node
@ -116,7 +140,7 @@ class FlattenActionWrapper(ActionWrapper):
env.bounds.port_count, env.bounds.port_count,
# credential used (for connect action only) # credential used (for connect action only)
env.bounds.maximum_total_credentials, env.bounds.maximum_total_credentials,
] ], dtype=np.int32)
) )
def action(self, action: np.ndarray) -> Action: def action(self, action: np.ndarray) -> Action:

Просмотреть файл

@ -168,7 +168,7 @@ class Feature_discovered_nodeproperties_sliding(GlobalFeature):
def get_global(self, a: StateAugmentation) -> ndarray: def get_global(self, a: StateAugmentation) -> ndarray:
n = a.observation["discovered_node_count"] n = a.observation["discovered_node_count"]
node_prop = np.array(a.observation["discovered_nodes_properties"])[:n] node_prop = a.observation["discovered_nodes_properties"][:n]
# keep last window of entries # keep last window of entries
node_prop_window = node_prop[-self.window_size :, :] node_prop_window = node_prop[-self.window_size :, :]
@ -255,7 +255,7 @@ class Feature_discovered_notowned_node_count(GlobalFeature):
"""number of nodes discovered that are not owned yet (optionally clipped)""" """number of nodes discovered that are not owned yet (optionally clipped)"""
def __init__(self, p: EnvironmentBounds, clip: Optional[int]): def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
self.clip = p.maximum_node_count if clip is None else clip self.clip = np.int32(clip or p.maximum_node_count)
super().__init__(p, [self.clip + 1]) super().__init__(p, [self.clip + 1])
def get_global(self, a: StateAugmentation): def get_global(self, a: StateAugmentation):
@ -263,8 +263,8 @@ class Feature_discovered_notowned_node_count(GlobalFeature):
node_props = np.array(a.observation["discovered_nodes_properties"][:discovered]) node_props = np.array(a.observation["discovered_nodes_properties"][:discovered])
# here we assume that a node is owned just if all its properties are known # here we assume that a node is owned just if all its properties are known
owned = np.count_nonzero(np.all(node_props != 2, axis=1)) owned = np.count_nonzero(np.all(node_props != 2, axis=1))
diff = discovered - owned diff = np.int32(discovered - owned)
return np.array([min(diff, self.clip)], dtype=np.int_) return np.array( [np.min((diff, self.clip))], dtype=np.int32)
class Feature_owned_node_count(GlobalFeature): class Feature_owned_node_count(GlobalFeature):

Просмотреть файл

@ -0,0 +1,48 @@
'''Stable-baselines agent for CyberBattle Gym environment'''
import os
from typing import cast
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
from cyberbattle._env.flatten_wrapper import (
FlattenObservationWrapper,
FlattenActionWrapper,
)
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.a2c.a2c import A2C
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
def test_stablebaseline3(training_steps: int = 3, eval_steps: int = 10) -> None:
    """Smoke-test an A2C Stable-Baselines3 agent on the flattened ToyCTF environment.

    Builds a small ``CyberBattleToyCtf`` environment, wraps it with the
    action- and observation-flattening wrappers, validates the wrapped
    environment with the Stable-Baselines3 environment checker, then trains,
    saves, reloads and runs an A2C agent for a few steps.

    Args:
        training_steps: Number of timesteps passed to ``A2C.learn``.
        eval_steps: Number of environment steps taken with the reloaded model.
    """
    # Local import: only this test needs a scratch directory.
    import tempfile

    cybersim_env = CyberBattleToyCtf(
        maximum_node_count=12,
        maximum_total_credentials=10,
        observation_padding=True,
        throws_on_invalid_actions=False,
    )

    flatten_action_env = FlattenActionWrapper(cybersim_env)
    flatten_obs_env = FlattenObservationWrapper(
        flatten_action_env,
        # NOTE(review): presumably these are the raw, non-numeric observation
        # fields that cannot be flattened into a gym space -- confirm against
        # the environment's observation space definition.
        ignore_fields=[
            "_credential_cache",
            "_discovered_nodes",
            "_explored_network",
        ],
    )
    # Same object as flatten_obs_env; the cast only satisfies the type checker.
    env_as_gym = cast(GymEnv, flatten_obs_env)

    check_env(flatten_obs_env)

    # Save into a temporary directory so the test leaves no artifact behind
    # in the current working directory.
    with tempfile.TemporaryDirectory() as scratch_dir:
        model_path = os.path.join(scratch_dir, "a2c_trained_toyctf")

        model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(training_steps)
        model_a2c.save(model_path)

        # `load` is a classmethod: call it on the class instead of building a
        # throwaway model first.
        model = A2C.load(model_path, env=env_as_gym)

    obs, _reset_info = env_as_gym.reset()
    for _ in range(eval_steps):
        assert isinstance(obs, dict)
        action, _states = model.predict(obs, deterministic=True)
        obs, _reward, done, truncated, _step_info = flatten_obs_env.step(action)
        if done or truncated:
            # Stepping a finished episode is undefined in the Gymnasium API;
            # start a fresh episode before continuing.
            obs, _reset_info = env_as_gym.reset()

    # Render once after the evaluation loop, then release the environment.
    flatten_obs_env.render()
    flatten_obs_env.close()

Просмотреть файл

@ -355,7 +355,7 @@ class Identifiers(NamedTuple):
# Array of all possible node property identifiers # Array of all possible node property identifiers
properties: List[PropertyName] = [] properties: List[PropertyName] = []
# Array of all possible port names # Array of all possible port names
ports: List[PortName] = [] ports: List[PortName] = ["Null"]
# Array of all possible local vulnerabilities names # Array of all possible local vulnerabilities names
local_vulnerabilities: List[VulnerabilityID] = [] local_vulnerabilities: List[VulnerabilityID] = []
# Array of all possible remote vulnerabilities names # Array of all possible remote vulnerabilities names

Просмотреть файл

@ -57,7 +57,7 @@
"from cyberbattle._env.cyberbattle_env import CyberBattleEnv\n", "from cyberbattle._env.cyberbattle_env import CyberBattleEnv\n",
"\n", "\n",
"envs = [cast(CyberBattleEnv, gym.make(gymid)) for gymid in gymids]\n", "envs = [cast(CyberBattleEnv, gym.make(gymid)) for gymid in gymids]\n",
"map(lambda g: g.seed(1), envs)\n", "map(lambda g: g.reset(seed=1), envs)\n",
"ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=30, maximum_total_credentials=50, identifiers=envs[0].identifiers)" "ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=30, maximum_total_credentials=50, identifiers=envs[0].identifiers)"
] ]
}, },
@ -18119,7 +18119,7 @@
"source": [ "source": [
"tiny = cast(CyberBattleEnv, gym.make(f\"ActiveDirectory-v{ngyms}\"))\n", "tiny = cast(CyberBattleEnv, gym.make(f\"ActiveDirectory-v{ngyms}\"))\n",
"current_o, _ = tiny.reset()\n", "current_o, _ = tiny.reset()\n",
"tiny.seed(1)\n", "tiny.reset(seed=1)\n",
"wrapped_env = AgentWrapper(tiny, ActionTrackingStateAugmentation(ep, current_o))\n", "wrapped_env = AgentWrapper(tiny, ActionTrackingStateAugmentation(ep, current_o))\n",
"# Use the trained agent to run the steps one by one\n", "# Use the trained agent to run the steps one by one\n",
"max_steps = 1000\n", "max_steps = 1000\n",
@ -18166,4 +18166,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 4 "nbformat_minor": 4
} }

Просмотреть файл

@ -394461,9 +394461,9 @@
"cell_metadata_filter": "title,-all" "cell_metadata_filter": "title,-all"
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python [conda env:cybersim]", "display_name": "cybersim",
"language": "python", "language": "python",
"name": "conda-env-cybersim-py" "name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -394492,4 +394492,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 5 "nbformat_minor": 5
} }

Просмотреть файл

@ -96,8 +96,8 @@ run notebook_withdefender -y "
" "
run dql_active_directory -y " run dql_active_directory -y "
ngyms: 3 ngyms: 3
iteration_count: 50 iteration_count: 50
" "

Просмотреть файл

@ -1,9 +1,7 @@
# %% '''Stable-baselines agent for CyberBattle Gym environment'''
# !pip install stable-baselines3[extra]
# %% # %%
from typing import cast from typing import cast
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
import logging import logging
import sys import sys
@ -14,15 +12,13 @@ from cyberbattle._env.flatten_wrapper import (
FlattenActionWrapper, FlattenActionWrapper,
) )
import os import os
import numpy as np
from stable_baselines3.common.type_aliases import GymEnv from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.env_checker import check_env
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
retrain = ["a2c"] retrain = ["a2c", "ppo"]
# %%
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s") logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %% # %%
@ -33,29 +29,27 @@ env = CyberBattleToyCtf(
throws_on_invalid_actions=False, throws_on_invalid_actions=False,
) )
# %%
flatten_action_env = FlattenActionWrapper(env)
# %% # %%
env1 = FlattenActionWrapper(env) flatten_obs_env = FlattenObservationWrapper(flatten_action_env, ignore_fields=[
# %%
# MultiBinary
# 'action_mask',
# 'customer_data_found',
# MultiDiscrete space
# 'nodes_privilegelevel',
# 'leaked_credentials',
# 'credential_cache_matrix'
# 'discovered_nodes_properties',
ignore_fields = [
# DummySpace # DummySpace
"_credential_cache", "_credential_cache",
"_discovered_nodes", "_discovered_nodes",
"_explored_network", "_explored_network",
] ])
env2 = FlattenObservationWrapper(cast(CyberBattleEnv, env1), ignore_fields=ignore_fields)
#%%
env_as_gym = cast(GymEnv, flatten_obs_env)
#%%
o, _ = env_as_gym.reset()
print(o)
#%%
check_env(flatten_obs_env)
env_as_gym = cast(GymEnv, env2)
# %% # %%
if "a2c" in retrain: if "a2c" in retrain:
@ -75,10 +69,16 @@ model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf")
# %% # %%
obs = env2.reset() obs , _= env_as_gym.reset()
for i in range(1000):
action, _states = model.predict(np.array(obs), deterministic=True)
obs, reward, done, truncated, info = env2.step(action)
env2.render()
env2.close() # %%
for i in range(1000):
assert isinstance(obs, dict)
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, truncated, info = flatten_obs_env.step(action)
flatten_obs_env.render()
flatten_obs_env.close()
# %%