Stable-baselines3 agent working - fix shapes of observation spaces (#144)
* Stable-baselines3 agent working - fix shapes of observation spaces chore: Update stable-baselines3 agent with env_checker chore: Update ports list with "Null" value --------- Co-authored-by: William Blum <william.blum@microsoft.com>
This commit is contained in:
Родитель
4eabac5e60
Коммит
8f7078c814
|
@ -22,7 +22,7 @@ from cyberbattle._env.defender import DefenderAgent
|
|||
from cyberbattle.simulation.model import PortName, PrivilegeLevel
|
||||
from ..simulation import commandcontrol, model, actions
|
||||
from .discriminatedunion import DiscriminatedUnion
|
||||
|
||||
import numpy as np
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -71,7 +71,7 @@ Observation = TypedDict(
|
|||
# whether a lateral move was just performed
|
||||
"lateral_move": numpy.int32,
|
||||
# whether customer data were just discovered
|
||||
"customer_data_found": Tuple[numpy.int32],
|
||||
"customer_data_found": numpy.int32,
|
||||
# 0 if there were no probing attempt
|
||||
# 1 if an attempted probing failed
|
||||
# 2 if an attempted probing succeeded
|
||||
|
@ -90,7 +90,7 @@ Observation = TypedDict(
|
|||
# total nodes discovered so far
|
||||
"discovered_node_count": int,
|
||||
# Matrix of properties for all the discovered nodes
|
||||
"discovered_nodes_properties": Tuple[numpy.ndarray, ...],
|
||||
"discovered_nodes_properties": numpy.ndarray,
|
||||
# Node privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system)
|
||||
"nodes_privilegelevel": numpy.ndarray,
|
||||
# Tuple encoding of the credential cache matrix.
|
||||
|
@ -183,14 +183,14 @@ class EnvironmentBounds(NamedTuple):
|
|||
remote_attacks_count - Unique remote vulnerabilities
|
||||
"""
|
||||
|
||||
maximum_total_credentials: int
|
||||
maximum_node_count: int
|
||||
maximum_discoverable_credentials_per_action: int
|
||||
maximum_total_credentials: np.int32
|
||||
maximum_node_count: np.int32
|
||||
maximum_discoverable_credentials_per_action: np.int32
|
||||
|
||||
port_count: int
|
||||
property_count: int
|
||||
local_attacks_count: int
|
||||
remote_attacks_count: int
|
||||
port_count: np.int32
|
||||
property_count: np.int32
|
||||
local_attacks_count: np.int32
|
||||
remote_attacks_count: np.int32
|
||||
|
||||
@classmethod
|
||||
def of_identifiers(
|
||||
|
@ -200,16 +200,27 @@ class EnvironmentBounds(NamedTuple):
|
|||
maximum_node_count: int,
|
||||
maximum_discoverable_credentials_per_action: Optional[int] = None,
|
||||
):
|
||||
if not maximum_discoverable_credentials_per_action:
|
||||
maximum_discoverable_credentials_per_action = maximum_total_credentials
|
||||
|
||||
maximum_discoverable_credentials_per_action = maximum_discoverable_credentials_per_action or maximum_total_credentials
|
||||
|
||||
assert np.can_cast(maximum_total_credentials, np.int32), "maximum_total_credentials must be a 32-bit integer"
|
||||
assert np.can_cast(maximum_node_count, np.int32), "maximum_node_count must be a 32-bit integer"
|
||||
assert maximum_total_credentials > 0, "maximum_total_credentials must be positive"
|
||||
assert maximum_node_count > 0, "maximum_node_count must be positive"
|
||||
assert np.can_cast(len(identifiers.ports), np.int32), "port_count must be a 32-bit integer"
|
||||
assert np.can_cast(len(identifiers.properties), np.int32), "property_count must be a 32-bit integer"
|
||||
assert np.can_cast(len(identifiers.local_vulnerabilities), np.int32), "local_attacks_count must be a 32-bit integer"
|
||||
assert np.can_cast(len(identifiers.remote_vulnerabilities), np.int32), "remote_attacks_count must be a 32-bit integer"
|
||||
assert np.can_cast(maximum_discoverable_credentials_per_action, np.int32), "maximum_discoverable_credentials_per_action must be a 32-bit integer"
|
||||
|
||||
return EnvironmentBounds(
|
||||
maximum_total_credentials=maximum_total_credentials,
|
||||
maximum_node_count=maximum_node_count,
|
||||
maximum_discoverable_credentials_per_action=maximum_discoverable_credentials_per_action,
|
||||
port_count=len(identifiers.ports),
|
||||
property_count=len(identifiers.properties),
|
||||
local_attacks_count=len(identifiers.local_vulnerabilities),
|
||||
remote_attacks_count=len(identifiers.remote_vulnerabilities),
|
||||
maximum_total_credentials=np.int32(maximum_total_credentials),
|
||||
maximum_node_count=np.int32(maximum_node_count),
|
||||
maximum_discoverable_credentials_per_action=np.int32(maximum_discoverable_credentials_per_action),
|
||||
port_count=np.int32(len(identifiers.ports)),
|
||||
property_count=np.int32(len(identifiers.properties)),
|
||||
local_attacks_count=np.int32(len(identifiers.local_vulnerabilities)),
|
||||
remote_attacks_count=np.int32(len(identifiers.remote_vulnerabilities)),
|
||||
)
|
||||
|
||||
|
||||
|
@ -252,7 +263,7 @@ class ObservationSpaceType(spaces.Dict):
|
|||
# successuflly moved to the target node (1) or not (0)
|
||||
"lateral_move": spaces.Discrete(2),
|
||||
# boolean: 1 if customer secret data were discovered, 0 otherwise
|
||||
"customer_data_found": spaces.MultiBinary(2),
|
||||
"customer_data_found": spaces.Discrete(2),
|
||||
# whether an attempted probing succeeded or not
|
||||
"probe_result": spaces.Discrete(3),
|
||||
# Esclation result
|
||||
|
@ -271,12 +282,12 @@ class ObservationSpaceType(spaces.Dict):
|
|||
# that was used to authenticat to target node 56 on port number 22 (e.g. SSH)
|
||||
[
|
||||
spaces.MultiDiscrete(
|
||||
[
|
||||
np.array([
|
||||
NA + 1,
|
||||
bounds.maximum_total_credentials,
|
||||
bounds.maximum_node_count,
|
||||
bounds.port_count,
|
||||
]
|
||||
], dtype=np.int32)
|
||||
)
|
||||
]
|
||||
* bounds.maximum_discoverable_credentials_per_action
|
||||
|
@ -287,9 +298,9 @@ class ObservationSpaceType(spaces.Dict):
|
|||
# even though such action would 'fail' and potentially yield a negative reward.
|
||||
"action_mask": spaces.Dict(
|
||||
{
|
||||
"local_vulnerability": spaces.MultiBinary(bounds.maximum_node_count * bounds.local_attacks_count),
|
||||
"remote_vulnerability": spaces.MultiBinary(bounds.maximum_node_count * bounds.maximum_node_count * bounds.remote_attacks_count),
|
||||
"connect": spaces.MultiBinary(bounds.maximum_node_count * bounds.maximum_node_count * bounds.port_count * bounds.maximum_total_credentials),
|
||||
"local_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.local_attacks_count])),
|
||||
"remote_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.remote_attacks_count])),
|
||||
"connect": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.port_count, bounds.maximum_total_credentials], dtype=np.int32))
|
||||
}
|
||||
),
|
||||
# size of the credential stack
|
||||
|
@ -298,7 +309,7 @@ class ObservationSpaceType(spaces.Dict):
|
|||
"discovered_node_count": spaces.Discrete(bounds.maximum_node_count),
|
||||
# Matrix of properties for all the discovered nodes
|
||||
# 3 values for each matrix cell: set, unset, unknown
|
||||
"discovered_nodes_properties": spaces.MultiDiscrete([3] * bounds.maximum_node_count * bounds.property_count),
|
||||
"discovered_nodes_properties": spaces.MultiDiscrete(np.full(shape=(bounds.maximum_node_count, bounds.property_count), fill_value=3)),
|
||||
# Escalation level on every discovered node (e.g., 0 if not owned, 1 for admin, 2 for system)
|
||||
"nodes_privilegelevel": spaces.MultiDiscrete([CyberBattleEnv.privilege_levels] * bounds.maximum_node_count),
|
||||
# Encoding of the credential cache of shape: (credential_cache_length, 2)
|
||||
|
@ -306,7 +317,7 @@ class ObservationSpaceType(spaces.Dict):
|
|||
# Each row represent a discovered credential,
|
||||
# the credential index is given by the row index (i.e. order of discovery)
|
||||
# A row is of the form: (target_node_discover_index, port_index)
|
||||
"credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete([bounds.maximum_node_count, bounds.port_count])] * bounds.maximum_total_credentials),
|
||||
"credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete(np.array([bounds.maximum_node_count, bounds.port_count],dtype=np.int32))] * bounds.maximum_total_credentials),
|
||||
# ---------------------------------------------------------
|
||||
# Fields that were previously in the 'info' dict:
|
||||
# ---------------------------------------------------------
|
||||
|
@ -460,7 +471,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
self,
|
||||
initial_environment: model.Environment,
|
||||
maximum_total_credentials: int = 1000,
|
||||
maximum_node_count: int = 100,
|
||||
maximum_node_count: int = 100,
|
||||
maximum_discoverable_credentials_per_action: int = 5,
|
||||
defender_agent: Optional[DefenderAgent] = None,
|
||||
attacker_goal: Optional[AttackerGoal] = AttackerGoal(own_atleast_percent=1.0),
|
||||
|
@ -523,27 +534,27 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
# The Space object defining the valid actions of an attacker.
|
||||
local_vulnerabilities_count = self.__bounds.local_attacks_count
|
||||
remote_vulnerabilities_count = self.__bounds.remote_attacks_count
|
||||
maximum_node_count = self.__bounds.maximum_node_count
|
||||
maximum_node_count_int32 = self.__bounds.maximum_node_count
|
||||
port_count = self.__bounds.port_count
|
||||
|
||||
action_spaces = {
|
||||
"local_vulnerability": spaces.MultiDiscrete(
|
||||
# source_node_id, vulnerability_id
|
||||
[maximum_node_count, local_vulnerabilities_count]
|
||||
np.array([maximum_node_count_int32, local_vulnerabilities_count], dtype=np.int32)
|
||||
),
|
||||
"remote_vulnerability": spaces.MultiDiscrete(
|
||||
# source_node_id, target_node_id, vulnerability_id
|
||||
[maximum_node_count, maximum_node_count, remote_vulnerabilities_count]
|
||||
np.array([maximum_node_count_int32, maximum_node_count_int32, remote_vulnerabilities_count], dtype=np.int32)
|
||||
),
|
||||
"connect": spaces.MultiDiscrete(
|
||||
# source_node_id, target_node_id, target_port, credential_id
|
||||
# (by index of discovery: 0 for initial node, 1 for first discovered node, ...)
|
||||
[
|
||||
maximum_node_count,
|
||||
maximum_node_count,
|
||||
np.array([
|
||||
maximum_node_count_int32,
|
||||
maximum_node_count_int32,
|
||||
port_count,
|
||||
maximum_total_credentials,
|
||||
]
|
||||
], dtype=np.int32)
|
||||
),
|
||||
}
|
||||
|
||||
|
@ -613,10 +624,10 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
local_vulnerabilities_count = self.__bounds.local_attacks_count
|
||||
remote_vulnerabilities_count = self.__bounds.remote_attacks_count
|
||||
port_count = self.__bounds.port_count
|
||||
local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int32)
|
||||
local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int8)
|
||||
remote = numpy.zeros(
|
||||
shape=(max_node_count, max_node_count, remote_vulnerabilities_count),
|
||||
dtype=numpy.int32,
|
||||
dtype=numpy.int8,
|
||||
)
|
||||
connect = numpy.zeros(
|
||||
shape=(
|
||||
|
@ -625,7 +636,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
port_count,
|
||||
self.__bounds.maximum_total_credentials,
|
||||
),
|
||||
dtype=numpy.int32,
|
||||
dtype=numpy.int8,
|
||||
)
|
||||
return ActionMask(local_vulnerability=local, remote_vulnerability=remote, connect=connect)
|
||||
|
||||
|
@ -744,14 +755,14 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
newly_discovered_nodes_count=numpy.int32(0),
|
||||
leaked_credentials=tuple([numpy.array([UNUSED_SLOT, 0, 0, 0], dtype=numpy.int32)] * self.__bounds.maximum_discoverable_credentials_per_action),
|
||||
lateral_move=numpy.int32(0),
|
||||
customer_data_found=(numpy.int32(0),),
|
||||
customer_data_found=numpy.int32(0),
|
||||
escalation=numpy.int32(PrivilegeLevel.NoAccess),
|
||||
action_mask=self.__get_blank_action_mask(),
|
||||
probe_result=numpy.int32(0),
|
||||
credential_cache_matrix=tuple([numpy.zeros((2))] * self.__bounds.maximum_total_credentials),
|
||||
credential_cache_matrix=tuple([numpy.zeros((2), dtype=numpy.int64)] * self.__bounds.maximum_total_credentials),
|
||||
credential_cache_length=0,
|
||||
discovered_node_count=len(self.__discovered_nodes),
|
||||
discovered_nodes_properties=tuple([numpy.full((self.__bounds.property_count,), 2, dtype=numpy.int32)] * self.__bounds.maximum_node_count),
|
||||
discovered_nodes_properties=numpy.full((self.__bounds.maximum_node_count, self.__bounds.property_count,), 2, dtype=numpy.int32),
|
||||
nodes_privilegelevel=numpy.zeros((self.bounds.maximum_node_count,), dtype=numpy.int32),
|
||||
# raw data not actually encoded as a proper gym numeric space
|
||||
# (were previously returned in the 'info' dict)
|
||||
|
@ -797,7 +808,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
vector[properties_indices] = 1
|
||||
return vector
|
||||
|
||||
def __get_property_matrix(self) -> Tuple[numpy.ndarray, ...]:
|
||||
def __get_property_matrix(self) -> numpy.ndarray:
|
||||
"""Return the Node-Property matrix,
|
||||
where 0 means the property is not set for that node
|
||||
1 means the property is set for that node
|
||||
|
@ -810,11 +821,13 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
2nd row: no known properties for the 2nd discovered node
|
||||
3rd row: properties of 3rd discovered and owned node"""
|
||||
property_discovered = [self.__property_vector(node_id, node_info) for node_id, node_info in self._actuator.discovered_nodes()]
|
||||
return self.__pad_tuple_if_requested(
|
||||
as_numpy = numpy.array(self.__pad_tuple_if_requested(
|
||||
property_discovered,
|
||||
self.__bounds.property_count,
|
||||
self.__bounds.maximum_node_count,
|
||||
)
|
||||
))
|
||||
assert as_numpy.shape == (self.__bounds.maximum_node_count, self.__bounds.property_count)
|
||||
return as_numpy
|
||||
|
||||
def __get__owned_nodes_indices(self) -> List[int]:
|
||||
"""Get list of indices of all owned nodes"""
|
||||
|
@ -896,7 +909,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
elif isinstance(outcome, model.LateralMove):
|
||||
obs["lateral_move"] = numpy.int32(1)
|
||||
elif isinstance(outcome, model.CustomerData):
|
||||
obs["customer_data_found"] = (numpy.int32(1),)
|
||||
obs["customer_data_found"] = numpy.int32(1)
|
||||
elif isinstance(outcome, model.ProbeSucceeded):
|
||||
obs["probe_result"] = numpy.int32(2)
|
||||
elif isinstance(outcome, model.ProbeFailed):
|
||||
|
@ -1179,7 +1192,8 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
) -> Tuple[Observation, dict]:
|
||||
LOGGER.info("Resetting the CyberBattle environment")
|
||||
self.__reset_environment()
|
||||
self.seed(seed)
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
|
||||
observation = self.__get_blank_observation()
|
||||
observation["action_mask"] = self.compute_action_mask()
|
||||
observation["discovered_nodes_properties"] = self.__get_property_matrix()
|
||||
|
@ -1209,12 +1223,5 @@ class CyberBattleEnv(CyberBattleSpaceKind):
|
|||
fig = self.render_as_fig()
|
||||
fig.show(renderer=self.__renderer)
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> None:
|
||||
if seed is None:
|
||||
self._seed = seed
|
||||
return
|
||||
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
|
||||
def close(self) -> None:
|
||||
return None
|
||||
|
|
|
@ -4,11 +4,12 @@ for CyberBattleEnv gym environment.
|
|||
|
||||
from collections import OrderedDict
|
||||
from sqlite3 import NotSupportedError
|
||||
from gymnasium import spaces
|
||||
from gymnasium import Env, spaces
|
||||
import numpy as np
|
||||
from cyberbattle._env.cyberbattle_env import DummySpace, CyberBattleEnv, Action, CyberBattleSpaceKind
|
||||
from gymnasium.core import ObservationWrapper, ActionWrapper
|
||||
|
||||
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv
|
||||
|
||||
|
||||
class FlattenObservationWrapper(ObservationWrapper):
|
||||
"""
|
||||
|
@ -22,69 +23,92 @@ class FlattenObservationWrapper(ObservationWrapper):
|
|||
if isinstance(space, spaces.MultiBinary):
|
||||
if type(space.n) in [tuple, list, np.ndarray]:
|
||||
flatten_dim = np.multiply.reduce(space.n)
|
||||
print(f"// MultiBinary flattened from {space.n} -> {flatten_dim}")
|
||||
return spaces.MultiBinary(flatten_dim)
|
||||
flatten_space = spaces.MultiBinary(flatten_dim)
|
||||
print(f"// MultiBinary flattened from {space.n} -> {flatten_space.n} - dtype: {space.dtype} -> {flatten_space.dtype}")
|
||||
return flatten_space
|
||||
else:
|
||||
print(f"// MultiBinary already flat: {space.n}")
|
||||
return space
|
||||
else:
|
||||
return space
|
||||
|
||||
def __init__(self, env: CyberBattleSpaceKind, ignore_fields=["action_mask"]):
|
||||
def flatten_multidiscrete_space(self, space: spaces.Space):
|
||||
if isinstance(space, spaces.MultiDiscrete):
|
||||
if type(space.nvec) in [tuple, list, np.ndarray]:
|
||||
flatten_space = spaces.MultiDiscrete(space.nvec.flatten())
|
||||
print(f"// MultiDiscrete flattened from {space.nvec} -> {flatten_space.nvec}")
|
||||
return flatten_space
|
||||
else:
|
||||
print(f"// MultiDiscrete already flat: {space.nvec}")
|
||||
return space
|
||||
|
||||
def __init__(self, env: Env, ignore_fields=["action_mask"]):
|
||||
ObservationWrapper.__init__(self, env)
|
||||
self.env = env
|
||||
self.ignore_fields = ignore_fields
|
||||
if isinstance(env.observation_space, spaces.Dict):
|
||||
space_dict = OrderedDict({})
|
||||
for key, space in env.observation_space.spaces.items():
|
||||
if key in ignore_fields:
|
||||
print("Filtering out field", key)
|
||||
elif isinstance(space, spaces.Dict):
|
||||
for subkey, subspace in space.items():
|
||||
space_dict[f"{key}_{subkey}"] = self.flatten_multibinary_space(subspace)
|
||||
elif isinstance(space, spaces.Tuple):
|
||||
for i, subspace in enumerate(space.spaces):
|
||||
space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
|
||||
elif isinstance(space, spaces.MultiBinary):
|
||||
space_dict[key] = self.flatten_multibinary_space(space)
|
||||
elif isinstance(space, spaces.Discrete):
|
||||
space_dict[key] = space
|
||||
elif isinstance(space, spaces.MultiDiscrete):
|
||||
space_dict[key] = self.flatten_multidiscrete_space(space)
|
||||
else:
|
||||
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
|
||||
|
||||
space_dict = OrderedDict({})
|
||||
for key, space in env.observation_space.spaces.items():
|
||||
if key in ignore_fields:
|
||||
print("Filtering out field", key)
|
||||
elif isinstance(space, spaces.Dict):
|
||||
for k2, subspace in space.items():
|
||||
space_dict[f"{key}_{k2}"] = self.flatten_multibinary_space(subspace)
|
||||
elif isinstance(space, spaces.Tuple):
|
||||
for i, subspace in enumerate(space.spaces):
|
||||
space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
|
||||
elif isinstance(space, spaces.MultiBinary):
|
||||
space_dict[key] = self.flatten_multibinary_space(space)
|
||||
elif isinstance(space, spaces.Discrete) or isinstance(space, spaces.MultiDiscrete):
|
||||
space_dict[key] = space
|
||||
elif isinstance(space, DummySpace):
|
||||
print(f"warning: unsupported observation space: {space} : {type(space)}")
|
||||
else:
|
||||
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
|
||||
|
||||
self.observation_space = spaces.Dict(space_dict)
|
||||
self.observation_space = spaces.Dict(space_dict)
|
||||
|
||||
def flatten_multibinary_observation(self, space, o):
|
||||
if isinstance(space, spaces.MultiBinary) and isinstance(space.n, tuple) and len(space.n) > 1:
|
||||
flatten_dim = np.multiply.reduce(space.n)
|
||||
return tuple(o.reshape(flatten_dim))
|
||||
# print(f"dtype: {o.dtype} shape: {o.shape} -> {flatten_dim}")
|
||||
reshaped = o.reshape(flatten_dim)
|
||||
# print(f"reshaped: {reshaped.dtype} shape: {reshaped.shape}")
|
||||
return reshaped
|
||||
else:
|
||||
return o
|
||||
|
||||
def flatten_multidiscrete_observation(self, space, o):
|
||||
if isinstance(space, spaces.MultiDiscrete):
|
||||
return o.flatten()
|
||||
else:
|
||||
return o
|
||||
|
||||
def observation(self, observation):
|
||||
o = OrderedDict({})
|
||||
for key, space in self.env.observation_space.spaces.items():
|
||||
value = observation[key]
|
||||
if key in self.ignore_fields:
|
||||
continue
|
||||
elif isinstance(space, spaces.Dict):
|
||||
for subkey, subspace in space.items():
|
||||
o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
|
||||
elif isinstance(space, spaces.Tuple):
|
||||
for i, subspace in enumerate(space.spaces):
|
||||
o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
|
||||
elif isinstance(space, spaces.MultiBinary):
|
||||
o[key] = self.flatten_multibinary_observation(space, value)
|
||||
elif isinstance(space, spaces.Discrete) or isinstance(space, spaces.MultiDiscrete):
|
||||
o[key] = value
|
||||
elif isinstance(space, DummySpace):
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
|
||||
if isinstance(self.env.observation_space, spaces.Dict):
|
||||
o = OrderedDict({})
|
||||
for key, space in self.env.observation_space.spaces.items():
|
||||
value = observation[key]
|
||||
if key in self.ignore_fields:
|
||||
continue
|
||||
elif isinstance(space, spaces.Dict):
|
||||
for subkey, subspace in space.items():
|
||||
o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
|
||||
elif isinstance(space, spaces.Tuple):
|
||||
for i, subspace in enumerate(space.spaces):
|
||||
o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
|
||||
elif isinstance(space, spaces.MultiBinary):
|
||||
o[key] = self.flatten_multibinary_observation(space, value)
|
||||
elif isinstance(space, spaces.Discrete):
|
||||
o[key] = value
|
||||
elif isinstance(space, spaces.MultiDiscrete):
|
||||
o[key] = self.flatten_multidiscrete_observation(space, value)
|
||||
else:
|
||||
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
|
||||
|
||||
return o
|
||||
return o
|
||||
else:
|
||||
return observation
|
||||
|
||||
def step(self, action):
|
||||
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
||||
|
@ -105,7 +129,7 @@ class FlattenActionWrapper(ActionWrapper):
|
|||
self.env = env
|
||||
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
np.array([
|
||||
# connect, local vulnerabilities, remote vulnerabilities
|
||||
1 + env.bounds.local_attacks_count + env.bounds.remote_attacks_count,
|
||||
# source node
|
||||
|
@ -116,7 +140,7 @@ class FlattenActionWrapper(ActionWrapper):
|
|||
env.bounds.port_count,
|
||||
# target port (credentials used, for connect action only)
|
||||
env.bounds.maximum_total_credentials,
|
||||
]
|
||||
], dtype=np.int32)
|
||||
)
|
||||
|
||||
def action(self, action: np.ndarray) -> Action:
|
||||
|
|
|
@ -168,7 +168,7 @@ class Feature_discovered_nodeproperties_sliding(GlobalFeature):
|
|||
|
||||
def get_global(self, a: StateAugmentation) -> ndarray:
|
||||
n = a.observation["discovered_node_count"]
|
||||
node_prop = np.array(a.observation["discovered_nodes_properties"])[:n]
|
||||
node_prop = a.observation["discovered_nodes_properties"][:n]
|
||||
|
||||
# keep last window of entries
|
||||
node_prop_window = node_prop[-self.window_size :, :]
|
||||
|
@ -255,7 +255,7 @@ class Feature_discovered_notowned_node_count(GlobalFeature):
|
|||
"""number of nodes discovered that are not owned yet (optionally clipped)"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
|
||||
self.clip = p.maximum_node_count if clip is None else clip
|
||||
self.clip = np.int32(clip or p.maximum_node_count)
|
||||
super().__init__(p, [self.clip + 1])
|
||||
|
||||
def get_global(self, a: StateAugmentation):
|
||||
|
@ -263,8 +263,8 @@ class Feature_discovered_notowned_node_count(GlobalFeature):
|
|||
node_props = np.array(a.observation["discovered_nodes_properties"][:discovered])
|
||||
# here we assume that a node is owned just if all its properties are known
|
||||
owned = np.count_nonzero(np.all(node_props != 2, axis=1))
|
||||
diff = discovered - owned
|
||||
return np.array([min(diff, self.clip)], dtype=np.int_)
|
||||
diff = np.int32(discovered - owned)
|
||||
return np.array( [np.min((diff, self.clip))], dtype=np.int32)
|
||||
|
||||
|
||||
class Feature_owned_node_count(GlobalFeature):
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
'''Stable-baselines agent for CyberBattle Gym environment'''
|
||||
|
||||
import os
|
||||
from typing import cast
|
||||
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
|
||||
from cyberbattle._env.flatten_wrapper import (
|
||||
FlattenObservationWrapper,
|
||||
FlattenActionWrapper,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import GymEnv
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.a2c.a2c import A2C
|
||||
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
|
||||
def test_stablebaseline3(training_steps=3, eval_steps=10):
|
||||
|
||||
cybersinm_env = CyberBattleToyCtf(
|
||||
maximum_node_count=12,
|
||||
maximum_total_credentials=10,
|
||||
observation_padding=True,
|
||||
throws_on_invalid_actions=False,
|
||||
)
|
||||
|
||||
flatten_action_env = FlattenActionWrapper(cybersinm_env)
|
||||
|
||||
flatten_obs_env = FlattenObservationWrapper(flatten_action_env, ignore_fields=[
|
||||
"_credential_cache",
|
||||
"_discovered_nodes",
|
||||
"_explored_network",
|
||||
])
|
||||
|
||||
env_as_gym = cast(GymEnv, flatten_obs_env)
|
||||
|
||||
check_env(flatten_obs_env)
|
||||
|
||||
model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(training_steps)
|
||||
model_a2c.save("a2c_trained_toyctf")
|
||||
model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf")
|
||||
|
||||
obs , _= env_as_gym.reset()
|
||||
for i in range(eval_steps):
|
||||
assert isinstance(obs, dict)
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, truncated, info = flatten_obs_env.step(action)
|
||||
|
||||
flatten_obs_env.render()
|
||||
flatten_obs_env.close()
|
|
@ -355,7 +355,7 @@ class Identifiers(NamedTuple):
|
|||
# Array of all possible node property identifiers
|
||||
properties: List[PropertyName] = []
|
||||
# Array of all possible port names
|
||||
ports: List[PortName] = []
|
||||
ports: List[PortName] = ["Null"]
|
||||
# Array of all possible local vulnerabilities names
|
||||
local_vulnerabilities: List[VulnerabilityID] = []
|
||||
# Array of all possible remote vulnerabilities names
|
||||
|
|
|
@ -57,7 +57,7 @@
|
|||
"from cyberbattle._env.cyberbattle_env import CyberBattleEnv\n",
|
||||
"\n",
|
||||
"envs = [cast(CyberBattleEnv, gym.make(gymid)) for gymid in gymids]\n",
|
||||
"map(lambda g: g.seed(1), envs)\n",
|
||||
"map(lambda g: g.reset(seed=1), envs)\n",
|
||||
"ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=30, maximum_total_credentials=50, identifiers=envs[0].identifiers)"
|
||||
]
|
||||
},
|
||||
|
@ -18119,7 +18119,7 @@
|
|||
"source": [
|
||||
"tiny = cast(CyberBattleEnv, gym.make(f\"ActiveDirectory-v{ngyms}\"))\n",
|
||||
"current_o, _ = tiny.reset()\n",
|
||||
"tiny.seed(1)\n",
|
||||
"tiny.reset(seed=1)\n",
|
||||
"wrapped_env = AgentWrapper(tiny, ActionTrackingStateAugmentation(ep, current_o))\n",
|
||||
"# Use the trained agent to run the steps one by one\n",
|
||||
"max_steps = 1000\n",
|
||||
|
@ -18166,4 +18166,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
|
@ -394461,9 +394461,9 @@
|
|||
"cell_metadata_filter": "title,-all"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda env:cybersim]",
|
||||
"display_name": "cybersim",
|
||||
"language": "python",
|
||||
"name": "conda-env-cybersim-py"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
@ -394492,4 +394492,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
|
|
@ -96,8 +96,8 @@ run notebook_withdefender -y "
|
|||
"
|
||||
|
||||
run dql_active_directory -y "
|
||||
ngyms: 3
|
||||
iteration_count: 50
|
||||
ngyms: 3
|
||||
iteration_count: 50
|
||||
"
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
# %%
|
||||
# !pip install stable-baselines3[extra]
|
||||
'''Stable-baselines agent for CyberBattle Gym environment'''
|
||||
|
||||
# %%
|
||||
from typing import cast
|
||||
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
|
||||
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
|
||||
import logging
|
||||
import sys
|
||||
|
@ -14,15 +12,13 @@ from cyberbattle._env.flatten_wrapper import (
|
|||
FlattenActionWrapper,
|
||||
)
|
||||
import os
|
||||
import numpy as np
|
||||
from stable_baselines3.common.type_aliases import GymEnv
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
retrain = ["a2c"]
|
||||
retrain = ["a2c", "ppo"]
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
# %%
|
||||
|
@ -33,29 +29,27 @@ env = CyberBattleToyCtf(
|
|||
throws_on_invalid_actions=False,
|
||||
)
|
||||
|
||||
# %%
|
||||
flatten_action_env = FlattenActionWrapper(env)
|
||||
|
||||
# %%
|
||||
env1 = FlattenActionWrapper(env)
|
||||
|
||||
# %%
|
||||
# MultiBinary
|
||||
# 'action_mask',
|
||||
# 'customer_data_found',
|
||||
# MultiDiscrete space
|
||||
# 'nodes_privilegelevel',
|
||||
# 'leaked_credentials',
|
||||
# 'credential_cache_matrix'
|
||||
# 'discovered_nodes_properties',
|
||||
|
||||
ignore_fields = [
|
||||
flatten_obs_env = FlattenObservationWrapper(flatten_action_env, ignore_fields=[
|
||||
# DummySpace
|
||||
"_credential_cache",
|
||||
"_discovered_nodes",
|
||||
"_explored_network",
|
||||
]
|
||||
env2 = FlattenObservationWrapper(cast(CyberBattleEnv, env1), ignore_fields=ignore_fields)
|
||||
])
|
||||
|
||||
#%%
|
||||
env_as_gym = cast(GymEnv, flatten_obs_env)
|
||||
|
||||
#%%
|
||||
o, _ = env_as_gym.reset()
|
||||
print(o)
|
||||
|
||||
#%%
|
||||
check_env(flatten_obs_env)
|
||||
|
||||
env_as_gym = cast(GymEnv, env2)
|
||||
|
||||
# %%
|
||||
if "a2c" in retrain:
|
||||
|
@ -75,10 +69,16 @@ model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf")
|
|||
|
||||
|
||||
# %%
|
||||
obs = env2.reset()
|
||||
for i in range(1000):
|
||||
action, _states = model.predict(np.array(obs), deterministic=True)
|
||||
obs, reward, done, truncated, info = env2.step(action)
|
||||
obs , _= env_as_gym.reset()
|
||||
|
||||
env2.render()
|
||||
env2.close()
|
||||
|
||||
# %%
|
||||
for i in range(1000):
|
||||
assert isinstance(obs, dict)
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, truncated, info = flatten_obs_env.step(action)
|
||||
|
||||
flatten_obs_env.render()
|
||||
flatten_obs_env.close()
|
||||
|
||||
# %%
|
||||
|
|
Загрузка…
Ссылка в новой задаче