Stable-baselines3 agent working - fix shapes of observation spaces (#144)

* Stable-baselines3 agent working - fix shapes of observation spaces

chore: Update stable-baselines3 agent with env_checker

chore: Update ports list with "Null" value


---------

Co-authored-by: William Blum <william.blum@microsoft.com>
This commit is contained in:
William Blum 2024-08-07 14:10:08 -07:00 коммит произвёл GitHub
Родитель 4eabac5e60
Коммит 8f7078c814
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
9 изменённых файлов: 222 добавлений и 143 удалений

Просмотреть файл

@ -22,7 +22,7 @@ from cyberbattle._env.defender import DefenderAgent
from cyberbattle.simulation.model import PortName, PrivilegeLevel
from ..simulation import commandcontrol, model, actions
from .discriminatedunion import DiscriminatedUnion
import numpy as np
LOGGER = logging.getLogger(__name__)
@ -71,7 +71,7 @@ Observation = TypedDict(
# whether a lateral move was just performed
"lateral_move": numpy.int32,
# whether customer data were just discovered
"customer_data_found": Tuple[numpy.int32],
"customer_data_found": numpy.int32,
# 0 if there were no probing attempt
# 1 if an attempted probing failed
# 2 if an attempted probing succeeded
@ -90,7 +90,7 @@ Observation = TypedDict(
# total nodes discovered so far
"discovered_node_count": int,
# Matrix of properties for all the discovered nodes
"discovered_nodes_properties": Tuple[numpy.ndarray, ...],
"discovered_nodes_properties": numpy.ndarray,
# Node privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system)
"nodes_privilegelevel": numpy.ndarray,
# Tuple encoding of the credential cache matrix.
@ -183,14 +183,14 @@ class EnvironmentBounds(NamedTuple):
remote_attacks_count - Unique remote vulnerabilities
"""
maximum_total_credentials: int
maximum_node_count: int
maximum_discoverable_credentials_per_action: int
maximum_total_credentials: np.int32
maximum_node_count: np.int32
maximum_discoverable_credentials_per_action: np.int32
port_count: int
property_count: int
local_attacks_count: int
remote_attacks_count: int
port_count: np.int32
property_count: np.int32
local_attacks_count: np.int32
remote_attacks_count: np.int32
@classmethod
def of_identifiers(
@ -200,16 +200,27 @@ class EnvironmentBounds(NamedTuple):
maximum_node_count: int,
maximum_discoverable_credentials_per_action: Optional[int] = None,
):
if not maximum_discoverable_credentials_per_action:
maximum_discoverable_credentials_per_action = maximum_total_credentials
maximum_discoverable_credentials_per_action = maximum_discoverable_credentials_per_action or maximum_total_credentials
assert np.can_cast(maximum_total_credentials, np.int32), "maximum_total_credentials must be a 32-bit integer"
assert np.can_cast(maximum_node_count, np.int32), "maximum_node_count must be a 32-bit integer"
assert maximum_total_credentials > 0, "maximum_total_credentials must be positive"
assert maximum_node_count > 0, "maximum_node_count must be positive"
assert np.can_cast(len(identifiers.ports), np.int32), "port_count must be a 32-bit integer"
assert np.can_cast(len(identifiers.properties), np.int32), "property_count must be a 32-bit integer"
assert np.can_cast(len(identifiers.local_vulnerabilities), np.int32), "local_attacks_count must be a 32-bit integer"
assert np.can_cast(len(identifiers.remote_vulnerabilities), np.int32), "remote_attacks_count must be a 32-bit integer"
assert np.can_cast(maximum_discoverable_credentials_per_action, np.int32), "maximum_discoverable_credentials_per_action must be a 32-bit integer"
return EnvironmentBounds(
maximum_total_credentials=maximum_total_credentials,
maximum_node_count=maximum_node_count,
maximum_discoverable_credentials_per_action=maximum_discoverable_credentials_per_action,
port_count=len(identifiers.ports),
property_count=len(identifiers.properties),
local_attacks_count=len(identifiers.local_vulnerabilities),
remote_attacks_count=len(identifiers.remote_vulnerabilities),
maximum_total_credentials=np.int32(maximum_total_credentials),
maximum_node_count=np.int32(maximum_node_count),
maximum_discoverable_credentials_per_action=np.int32(maximum_discoverable_credentials_per_action),
port_count=np.int32(len(identifiers.ports)),
property_count=np.int32(len(identifiers.properties)),
local_attacks_count=np.int32(len(identifiers.local_vulnerabilities)),
remote_attacks_count=np.int32(len(identifiers.remote_vulnerabilities)),
)
@ -252,7 +263,7 @@ class ObservationSpaceType(spaces.Dict):
# successuflly moved to the target node (1) or not (0)
"lateral_move": spaces.Discrete(2),
# boolean: 1 if customer secret data were discovered, 0 otherwise
"customer_data_found": spaces.MultiBinary(2),
"customer_data_found": spaces.Discrete(2),
# whether an attempted probing succeeded or not
"probe_result": spaces.Discrete(3),
# Esclation result
@ -271,12 +282,12 @@ class ObservationSpaceType(spaces.Dict):
# that was used to authenticat to target node 56 on port number 22 (e.g. SSH)
[
spaces.MultiDiscrete(
[
np.array([
NA + 1,
bounds.maximum_total_credentials,
bounds.maximum_node_count,
bounds.port_count,
]
], dtype=np.int32)
)
]
* bounds.maximum_discoverable_credentials_per_action
@ -287,9 +298,9 @@ class ObservationSpaceType(spaces.Dict):
# even though such action would 'fail' and potentially yield a negative reward.
"action_mask": spaces.Dict(
{
"local_vulnerability": spaces.MultiBinary(bounds.maximum_node_count * bounds.local_attacks_count),
"remote_vulnerability": spaces.MultiBinary(bounds.maximum_node_count * bounds.maximum_node_count * bounds.remote_attacks_count),
"connect": spaces.MultiBinary(bounds.maximum_node_count * bounds.maximum_node_count * bounds.port_count * bounds.maximum_total_credentials),
"local_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.local_attacks_count])),
"remote_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.remote_attacks_count])),
"connect": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.port_count, bounds.maximum_total_credentials], dtype=np.int32))
}
),
# size of the credential stack
@ -298,7 +309,7 @@ class ObservationSpaceType(spaces.Dict):
"discovered_node_count": spaces.Discrete(bounds.maximum_node_count),
# Matrix of properties for all the discovered nodes
# 3 values for each matrix cell: set, unset, unknown
"discovered_nodes_properties": spaces.MultiDiscrete([3] * bounds.maximum_node_count * bounds.property_count),
"discovered_nodes_properties": spaces.MultiDiscrete(np.full(shape=(bounds.maximum_node_count, bounds.property_count), fill_value=3)),
# Escalation level on every discovered node (e.g., 0 if not owned, 1 for admin, 2 for system)
"nodes_privilegelevel": spaces.MultiDiscrete([CyberBattleEnv.privilege_levels] * bounds.maximum_node_count),
# Encoding of the credential cache of shape: (credential_cache_length, 2)
@ -306,7 +317,7 @@ class ObservationSpaceType(spaces.Dict):
# Each row represent a discovered credential,
# the credential index is given by the row index (i.e. order of discovery)
# A row is of the form: (target_node_discover_index, port_index)
"credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete([bounds.maximum_node_count, bounds.port_count])] * bounds.maximum_total_credentials),
"credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete(np.array([bounds.maximum_node_count, bounds.port_count],dtype=np.int32))] * bounds.maximum_total_credentials),
# ---------------------------------------------------------
# Fields that were previously in the 'info' dict:
# ---------------------------------------------------------
@ -460,7 +471,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
self,
initial_environment: model.Environment,
maximum_total_credentials: int = 1000,
maximum_node_count: int = 100,
maximum_node_count: int = 100,
maximum_discoverable_credentials_per_action: int = 5,
defender_agent: Optional[DefenderAgent] = None,
attacker_goal: Optional[AttackerGoal] = AttackerGoal(own_atleast_percent=1.0),
@ -523,27 +534,27 @@ class CyberBattleEnv(CyberBattleSpaceKind):
# The Space object defining the valid actions of an attacker.
local_vulnerabilities_count = self.__bounds.local_attacks_count
remote_vulnerabilities_count = self.__bounds.remote_attacks_count
maximum_node_count = self.__bounds.maximum_node_count
maximum_node_count_int32 = self.__bounds.maximum_node_count
port_count = self.__bounds.port_count
action_spaces = {
"local_vulnerability": spaces.MultiDiscrete(
# source_node_id, vulnerability_id
[maximum_node_count, local_vulnerabilities_count]
np.array([maximum_node_count_int32, local_vulnerabilities_count], dtype=np.int32)
),
"remote_vulnerability": spaces.MultiDiscrete(
# source_node_id, target_node_id, vulnerability_id
[maximum_node_count, maximum_node_count, remote_vulnerabilities_count]
np.array([maximum_node_count_int32, maximum_node_count_int32, remote_vulnerabilities_count], dtype=np.int32)
),
"connect": spaces.MultiDiscrete(
# source_node_id, target_node_id, target_port, credential_id
# (by index of discovery: 0 for initial node, 1 for first discovered node, ...)
[
maximum_node_count,
maximum_node_count,
np.array([
maximum_node_count_int32,
maximum_node_count_int32,
port_count,
maximum_total_credentials,
]
], dtype=np.int32)
),
}
@ -613,10 +624,10 @@ class CyberBattleEnv(CyberBattleSpaceKind):
local_vulnerabilities_count = self.__bounds.local_attacks_count
remote_vulnerabilities_count = self.__bounds.remote_attacks_count
port_count = self.__bounds.port_count
local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int32)
local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int8)
remote = numpy.zeros(
shape=(max_node_count, max_node_count, remote_vulnerabilities_count),
dtype=numpy.int32,
dtype=numpy.int8,
)
connect = numpy.zeros(
shape=(
@ -625,7 +636,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
port_count,
self.__bounds.maximum_total_credentials,
),
dtype=numpy.int32,
dtype=numpy.int8,
)
return ActionMask(local_vulnerability=local, remote_vulnerability=remote, connect=connect)
@ -744,14 +755,14 @@ class CyberBattleEnv(CyberBattleSpaceKind):
newly_discovered_nodes_count=numpy.int32(0),
leaked_credentials=tuple([numpy.array([UNUSED_SLOT, 0, 0, 0], dtype=numpy.int32)] * self.__bounds.maximum_discoverable_credentials_per_action),
lateral_move=numpy.int32(0),
customer_data_found=(numpy.int32(0),),
customer_data_found=numpy.int32(0),
escalation=numpy.int32(PrivilegeLevel.NoAccess),
action_mask=self.__get_blank_action_mask(),
probe_result=numpy.int32(0),
credential_cache_matrix=tuple([numpy.zeros((2))] * self.__bounds.maximum_total_credentials),
credential_cache_matrix=tuple([numpy.zeros((2), dtype=numpy.int64)] * self.__bounds.maximum_total_credentials),
credential_cache_length=0,
discovered_node_count=len(self.__discovered_nodes),
discovered_nodes_properties=tuple([numpy.full((self.__bounds.property_count,), 2, dtype=numpy.int32)] * self.__bounds.maximum_node_count),
discovered_nodes_properties=numpy.full((self.__bounds.maximum_node_count, self.__bounds.property_count,), 2, dtype=numpy.int32),
nodes_privilegelevel=numpy.zeros((self.bounds.maximum_node_count,), dtype=numpy.int32),
# raw data not actually encoded as a proper gym numeric space
# (were previously returned in the 'info' dict)
@ -797,7 +808,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
vector[properties_indices] = 1
return vector
def __get_property_matrix(self) -> Tuple[numpy.ndarray, ...]:
def __get_property_matrix(self) -> numpy.ndarray:
"""Return the Node-Property matrix,
where 0 means the property is not set for that node
1 means the property is set for that node
@ -810,11 +821,13 @@ class CyberBattleEnv(CyberBattleSpaceKind):
2nd row: no known properties for the 2nd discovered node
3rd row: properties of 3rd discovered and owned node"""
property_discovered = [self.__property_vector(node_id, node_info) for node_id, node_info in self._actuator.discovered_nodes()]
return self.__pad_tuple_if_requested(
as_numpy = numpy.array(self.__pad_tuple_if_requested(
property_discovered,
self.__bounds.property_count,
self.__bounds.maximum_node_count,
)
))
assert as_numpy.shape == (self.__bounds.maximum_node_count, self.__bounds.property_count)
return as_numpy
def __get__owned_nodes_indices(self) -> List[int]:
"""Get list of indices of all owned nodes"""
@ -896,7 +909,7 @@ class CyberBattleEnv(CyberBattleSpaceKind):
elif isinstance(outcome, model.LateralMove):
obs["lateral_move"] = numpy.int32(1)
elif isinstance(outcome, model.CustomerData):
obs["customer_data_found"] = (numpy.int32(1),)
obs["customer_data_found"] = numpy.int32(1)
elif isinstance(outcome, model.ProbeSucceeded):
obs["probe_result"] = numpy.int32(2)
elif isinstance(outcome, model.ProbeFailed):
@ -1179,7 +1192,8 @@ class CyberBattleEnv(CyberBattleSpaceKind):
) -> Tuple[Observation, dict]:
LOGGER.info("Resetting the CyberBattle environment")
self.__reset_environment()
self.seed(seed)
self.np_random, seed = seeding.np_random(seed)
observation = self.__get_blank_observation()
observation["action_mask"] = self.compute_action_mask()
observation["discovered_nodes_properties"] = self.__get_property_matrix()
@ -1209,12 +1223,5 @@ class CyberBattleEnv(CyberBattleSpaceKind):
fig = self.render_as_fig()
fig.show(renderer=self.__renderer)
def seed(self, seed: Optional[int] = None) -> None:
if seed is None:
self._seed = seed
return
self.np_random, seed = seeding.np_random(seed)
def close(self) -> None:
return None

Просмотреть файл

@ -4,11 +4,12 @@ for CyberBattleEnv gym environment.
from collections import OrderedDict
from sqlite3 import NotSupportedError
from gymnasium import spaces
from gymnasium import Env, spaces
import numpy as np
from cyberbattle._env.cyberbattle_env import DummySpace, CyberBattleEnv, Action, CyberBattleSpaceKind
from gymnasium.core import ObservationWrapper, ActionWrapper
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv
class FlattenObservationWrapper(ObservationWrapper):
"""
@ -22,69 +23,92 @@ class FlattenObservationWrapper(ObservationWrapper):
if isinstance(space, spaces.MultiBinary):
if type(space.n) in [tuple, list, np.ndarray]:
flatten_dim = np.multiply.reduce(space.n)
print(f"// MultiBinary flattened from {space.n} -> {flatten_dim}")
return spaces.MultiBinary(flatten_dim)
flatten_space = spaces.MultiBinary(flatten_dim)
print(f"// MultiBinary flattened from {space.n} -> {flatten_space.n} - dtype: {space.dtype} -> {flatten_space.dtype}")
return flatten_space
else:
print(f"// MultiBinary already flat: {space.n}")
return space
else:
return space
def __init__(self, env: CyberBattleSpaceKind, ignore_fields=["action_mask"]):
def flatten_multidiscrete_space(self, space: spaces.Space):
if isinstance(space, spaces.MultiDiscrete):
if type(space.nvec) in [tuple, list, np.ndarray]:
flatten_space = spaces.MultiDiscrete(space.nvec.flatten())
print(f"// MultiDiscrete flattened from {space.nvec} -> {flatten_space.nvec}")
return flatten_space
else:
print(f"// MultiDiscrete already flat: {space.nvec}")
return space
def __init__(self, env: Env, ignore_fields=["action_mask"]):
ObservationWrapper.__init__(self, env)
self.env = env
self.ignore_fields = ignore_fields
if isinstance(env.observation_space, spaces.Dict):
space_dict = OrderedDict({})
for key, space in env.observation_space.spaces.items():
if key in ignore_fields:
print("Filtering out field", key)
elif isinstance(space, spaces.Dict):
for subkey, subspace in space.items():
space_dict[f"{key}_{subkey}"] = self.flatten_multibinary_space(subspace)
elif isinstance(space, spaces.Tuple):
for i, subspace in enumerate(space.spaces):
space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
elif isinstance(space, spaces.MultiBinary):
space_dict[key] = self.flatten_multibinary_space(space)
elif isinstance(space, spaces.Discrete):
space_dict[key] = space
elif isinstance(space, spaces.MultiDiscrete):
space_dict[key] = self.flatten_multidiscrete_space(space)
else:
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
space_dict = OrderedDict({})
for key, space in env.observation_space.spaces.items():
if key in ignore_fields:
print("Filtering out field", key)
elif isinstance(space, spaces.Dict):
for k2, subspace in space.items():
space_dict[f"{key}_{k2}"] = self.flatten_multibinary_space(subspace)
elif isinstance(space, spaces.Tuple):
for i, subspace in enumerate(space.spaces):
space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
elif isinstance(space, spaces.MultiBinary):
space_dict[key] = self.flatten_multibinary_space(space)
elif isinstance(space, spaces.Discrete) or isinstance(space, spaces.MultiDiscrete):
space_dict[key] = space
elif isinstance(space, DummySpace):
print(f"warning: unsupported observation space: {space} : {type(space)}")
else:
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
self.observation_space = spaces.Dict(space_dict)
self.observation_space = spaces.Dict(space_dict)
def flatten_multibinary_observation(self, space, o):
if isinstance(space, spaces.MultiBinary) and isinstance(space.n, tuple) and len(space.n) > 1:
flatten_dim = np.multiply.reduce(space.n)
return tuple(o.reshape(flatten_dim))
# print(f"dtype: {o.dtype} shape: {o.shape} -> {flatten_dim}")
reshaped = o.reshape(flatten_dim)
# print(f"reshaped: {reshaped.dtype} shape: {reshaped.shape}")
return reshaped
else:
return o
def flatten_multidiscrete_observation(self, space, o):
if isinstance(space, spaces.MultiDiscrete):
return o.flatten()
else:
return o
def observation(self, observation):
o = OrderedDict({})
for key, space in self.env.observation_space.spaces.items():
value = observation[key]
if key in self.ignore_fields:
continue
elif isinstance(space, spaces.Dict):
for subkey, subspace in space.items():
o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
elif isinstance(space, spaces.Tuple):
for i, subspace in enumerate(space.spaces):
o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
elif isinstance(space, spaces.MultiBinary):
o[key] = self.flatten_multibinary_observation(space, value)
elif isinstance(space, spaces.Discrete) or isinstance(space, spaces.MultiDiscrete):
o[key] = value
elif isinstance(space, DummySpace):
continue
else:
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
if isinstance(self.env.observation_space, spaces.Dict):
o = OrderedDict({})
for key, space in self.env.observation_space.spaces.items():
value = observation[key]
if key in self.ignore_fields:
continue
elif isinstance(space, spaces.Dict):
for subkey, subspace in space.items():
o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
elif isinstance(space, spaces.Tuple):
for i, subspace in enumerate(space.spaces):
o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
elif isinstance(space, spaces.MultiBinary):
o[key] = self.flatten_multibinary_observation(space, value)
elif isinstance(space, spaces.Discrete):
o[key] = value
elif isinstance(space, spaces.MultiDiscrete):
o[key] = self.flatten_multidiscrete_observation(space, value)
else:
raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")
return o
return o
else:
return observation
def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
@ -105,7 +129,7 @@ class FlattenActionWrapper(ActionWrapper):
self.env = env
self.action_space = spaces.MultiDiscrete(
[
np.array([
# connect, local vulnerabilities, remote vulnerabilities
1 + env.bounds.local_attacks_count + env.bounds.remote_attacks_count,
# source node
@ -116,7 +140,7 @@ class FlattenActionWrapper(ActionWrapper):
env.bounds.port_count,
# target port (credentials used, for connect action only)
env.bounds.maximum_total_credentials,
]
], dtype=np.int32)
)
def action(self, action: np.ndarray) -> Action:

Просмотреть файл

@ -168,7 +168,7 @@ class Feature_discovered_nodeproperties_sliding(GlobalFeature):
def get_global(self, a: StateAugmentation) -> ndarray:
n = a.observation["discovered_node_count"]
node_prop = np.array(a.observation["discovered_nodes_properties"])[:n]
node_prop = a.observation["discovered_nodes_properties"][:n]
# keep last window of entries
node_prop_window = node_prop[-self.window_size :, :]
@ -255,7 +255,7 @@ class Feature_discovered_notowned_node_count(GlobalFeature):
"""number of nodes discovered that are not owned yet (optionally clipped)"""
def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
self.clip = p.maximum_node_count if clip is None else clip
self.clip = np.int32(clip or p.maximum_node_count)
super().__init__(p, [self.clip + 1])
def get_global(self, a: StateAugmentation):
@ -263,8 +263,8 @@ class Feature_discovered_notowned_node_count(GlobalFeature):
node_props = np.array(a.observation["discovered_nodes_properties"][:discovered])
# here we assume that a node is owned just if all its properties are known
owned = np.count_nonzero(np.all(node_props != 2, axis=1))
diff = discovered - owned
return np.array([min(diff, self.clip)], dtype=np.int_)
diff = np.int32(discovered - owned)
return np.array( [np.min((diff, self.clip))], dtype=np.int32)
class Feature_owned_node_count(GlobalFeature):

Просмотреть файл

@ -0,0 +1,48 @@
'''Stable-baselines agent for CyberBattle Gym environment'''
import os
from typing import cast
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
from cyberbattle._env.flatten_wrapper import (
FlattenObservationWrapper,
FlattenActionWrapper,
)
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.a2c.a2c import A2C
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
def test_stablebaseline3(training_steps=3, eval_steps=10):
cybersinm_env = CyberBattleToyCtf(
maximum_node_count=12,
maximum_total_credentials=10,
observation_padding=True,
throws_on_invalid_actions=False,
)
flatten_action_env = FlattenActionWrapper(cybersinm_env)
flatten_obs_env = FlattenObservationWrapper(flatten_action_env, ignore_fields=[
"_credential_cache",
"_discovered_nodes",
"_explored_network",
])
env_as_gym = cast(GymEnv, flatten_obs_env)
check_env(flatten_obs_env)
model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(training_steps)
model_a2c.save("a2c_trained_toyctf")
model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf")
obs , _= env_as_gym.reset()
for i in range(eval_steps):
assert isinstance(obs, dict)
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, truncated, info = flatten_obs_env.step(action)
flatten_obs_env.render()
flatten_obs_env.close()

Просмотреть файл

@ -355,7 +355,7 @@ class Identifiers(NamedTuple):
# Array of all possible node property identifiers
properties: List[PropertyName] = []
# Array of all possible port names
ports: List[PortName] = []
ports: List[PortName] = ["Null"]
# Array of all possible local vulnerabilities names
local_vulnerabilities: List[VulnerabilityID] = []
# Array of all possible remote vulnerabilities names

Просмотреть файл

@ -57,7 +57,7 @@
"from cyberbattle._env.cyberbattle_env import CyberBattleEnv\n",
"\n",
"envs = [cast(CyberBattleEnv, gym.make(gymid)) for gymid in gymids]\n",
"map(lambda g: g.seed(1), envs)\n",
"map(lambda g: g.reset(seed=1), envs)\n",
"ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=30, maximum_total_credentials=50, identifiers=envs[0].identifiers)"
]
},
@ -18119,7 +18119,7 @@
"source": [
"tiny = cast(CyberBattleEnv, gym.make(f\"ActiveDirectory-v{ngyms}\"))\n",
"current_o, _ = tiny.reset()\n",
"tiny.seed(1)\n",
"tiny.reset(seed=1)\n",
"wrapped_env = AgentWrapper(tiny, ActionTrackingStateAugmentation(ep, current_o))\n",
"# Use the trained agent to run the steps one by one\n",
"max_steps = 1000\n",
@ -18166,4 +18166,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

Просмотреть файл

@ -394461,9 +394461,9 @@
"cell_metadata_filter": "title,-all"
},
"kernelspec": {
"display_name": "Python [conda env:cybersim]",
"display_name": "cybersim",
"language": "python",
"name": "conda-env-cybersim-py"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -394492,4 +394492,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

Просмотреть файл

@ -96,8 +96,8 @@ run notebook_withdefender -y "
"
run dql_active_directory -y "
ngyms: 3
iteration_count: 50
ngyms: 3
iteration_count: 50
"

Просмотреть файл

@ -1,9 +1,7 @@
# %%
# !pip install stable-baselines3[extra]
'''Stable-baselines agent for CyberBattle Gym environment'''
# %%
from typing import cast
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
import logging
import sys
@ -14,15 +12,13 @@ from cyberbattle._env.flatten_wrapper import (
FlattenActionWrapper,
)
import os
import numpy as np
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.env_checker import check_env
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
retrain = ["a2c"]
retrain = ["a2c", "ppo"]
# %%
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %%
@ -33,29 +29,27 @@ env = CyberBattleToyCtf(
throws_on_invalid_actions=False,
)
# %%
flatten_action_env = FlattenActionWrapper(env)
# %%
env1 = FlattenActionWrapper(env)
# %%
# MultiBinary
# 'action_mask',
# 'customer_data_found',
# MultiDiscrete space
# 'nodes_privilegelevel',
# 'leaked_credentials',
# 'credential_cache_matrix'
# 'discovered_nodes_properties',
ignore_fields = [
flatten_obs_env = FlattenObservationWrapper(flatten_action_env, ignore_fields=[
# DummySpace
"_credential_cache",
"_discovered_nodes",
"_explored_network",
]
env2 = FlattenObservationWrapper(cast(CyberBattleEnv, env1), ignore_fields=ignore_fields)
])
#%%
env_as_gym = cast(GymEnv, flatten_obs_env)
#%%
o, _ = env_as_gym.reset()
print(o)
#%%
check_env(flatten_obs_env)
env_as_gym = cast(GymEnv, env2)
# %%
if "a2c" in retrain:
@ -75,10 +69,16 @@ model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf")
# %%
obs = env2.reset()
for i in range(1000):
action, _states = model.predict(np.array(obs), deterministic=True)
obs, reward, done, truncated, info = env2.step(action)
obs , _= env_as_gym.reset()
env2.render()
env2.close()
# %%
for i in range(1000):
assert isinstance(obs, dict)
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, truncated, info = flatten_obs_env.step(action)
flatten_obs_env.render()
flatten_obs_env.close()
# %%