# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: tags,-all
#     cell_metadata_json: true
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% {"tags": []}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Benchmark all the baseline agents
on a given CyberBattleSim environment and compare
them to the dumb 'random agent' baseline.

NOTE: You can run this `.py`-notebook directly from VSCode.
You can also generate a traditional Jupyter Notebook
using the VSCode command `Export Current Python File As Jupyter Notebook`.
"""

# pylint: disable=invalid-name

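# %% {"tags": []}
# Note: the jupytext CLI that maintains the pairing declared in the header
# above can also produce the notebook representation directly; a sketch
# (this file's name is assumed here, not taken from its contents):
#
#   jupytext --to ipynb notebook_benchmark.py
#
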
# %% {"tags": []}
import sys
import os
import logging
from typing import cast

import gymnasium as gym

import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

# %% {"tags": []}
# %matplotlib inline

# %% {"tags": ["parameters"]}
# Papermill notebook parameters
gymid = "CyberBattleChain-v0"
env_size = 10
iteration_count = 9000
training_episode_count = 50
eval_episode_count = 5
maximum_node_count = 22
maximum_total_credentials = 22
plots_dir = "output/plots"

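# %% {"tags": []}
# The cell above is tagged "parameters", so papermill can override these
# values from the command line. A sketch of such an invocation (the .ipynb
# file names are assumptions, not taken from this file):
#
#   papermill notebook_benchmark.ipynb benchmark_out.ipynb \
#       -p gymid CyberBattleChain-v0 -p env_size 4 -p training_episode_count 5
#
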
# %% {"tags": []}
os.makedirs(plots_dir, exist_ok=True)

# Load the Gym environment
if env_size:
    _gym_env = gym.make(gymid, size=env_size)
else:
    _gym_env = gym.make(gymid)

gym_env = cast(CyberBattleEnv, _gym_env.unwrapped)
assert isinstance(gym_env, CyberBattleEnv), f"Expected CyberBattleEnv, got {type(gym_env)}"

ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=maximum_node_count, maximum_total_credentials=maximum_total_credentials, identifiers=gym_env.identifiers)

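# %% {"tags": []}
# A quick way to list which CyberBattle environment ids are available for
# the `gymid` parameter (gymnasium keeps its registry as a dict of
# id -> EnvSpec); purely informational, a sketch:
print(sorted(k for k in gym.envs.registry if "CyberBattle" in k))
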
# %% {"tags": []}
# Interactive exploration helpers: set `debugging = True` to inspect the
# environment, its action/observation spaces, and the feature encodings.
debugging = False
if debugging:
    print(f"port_count = {ep.port_count}, property_count = {ep.property_count}")

    gym_env.environment
    # training_env.environment.plot_environment_graph()
    gym_env.environment.network.nodes
    gym_env.action_space
    gym_env.action_space.sample()
    gym_env.observation_space.sample()
    o0, _ = gym_env.reset()
    o_test, r, d, t, i = gym_env.step(gym_env.sample_valid_action())
    o0, _ = gym_env.reset()

    o0.keys()

    fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])
    a = w.StateAugmentation(o0)
    w.Feature_discovered_ports(ep).get(a)
    fe_example.encode_at(a, 0)

# %% {"tags": []}
# Evaluate a random agent that opportunistically exploits
# credentials gathered in its local cache
credlookup_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=rca.CredentialCacheExploiter(),
    episode_count=10,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    epsilon_exponential_decay=10000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="Credential lookups (ϵ-greedy)",
)

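# %% {"tags": []}
# A minimal sketch of how the exploration parameters above interact: the
# usual form of exponential ϵ-decay anneals epsilon from its initial value
# toward epsilon_minimum as iterations accumulate. Illustrative only; the
# exact schedule implemented by learner.epsilon_greedy_search may differ.
import math


def decayed_epsilon(step: int, epsilon: float = 0.90, epsilon_minimum: float = 0.10, epsilon_exponential_decay: int = 10000) -> float:
    """Illustrative exponential decay from `epsilon` down to `epsilon_minimum`."""
    return epsilon_minimum + (epsilon - epsilon_minimum) * math.exp(-step / epsilon_exponential_decay)


# e.g. after 10000 steps, exploration has dropped to roughly 0.10 + 0.80/e ≈ 0.39
decayed_epsilon(10000)
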
# %% {"tags": []}
# Evaluate a Tabular Q-learning agent
tabularq_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=tqa.QTabularLearner(ep, gamma=0.015, learning_rate=0.01, exploit_percentile=100),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.01,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="Tabular Q-learning",
)

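# %% {"tags": []}
# For reference, the textbook update that the gamma/learning_rate parameters
# above parameterize. This is a sketch with a hypothetical dict-of-floats
# Q-table, not QTabularLearner's actual data structure:
from collections import defaultdict

Q_example: dict = defaultdict(float)


def q_update(state, action, reward, next_state, next_actions, gamma=0.015, learning_rate=0.01) -> None:
    # Q(s, a) <- (1 - lr) * Q(s, a) + lr * (r + gamma * max_a' Q(s', a'))
    best_next = max((Q_example[(next_state, a)] for a in next_actions), default=0.0)
    Q_example[(state, action)] = (1 - learning_rate) * Q_example[(state, action)] + learning_rate * (reward + gamma * best_next)
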
# %% {"tags": []}
# Evaluate an agent that exploits the Q-table learnt above
tabularq_exploit_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=tqa.QTabularLearner(ep, trained=tabularq_run["learner"], gamma=0.0, learning_rate=0.0, exploit_percentile=90),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    render=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting Q-matrix",
)

# %% {"tags": []}
# Evaluate the Deep Q-learning agent
dql_run = learner.epsilon_greedy_search(
    cyberbattle_gym_env=gym_env,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(
        ep=ep,
        gamma=0.015,
        replay_memory_size=10000,
        target_update=10,
        batch_size=512,
        # torch default learning rate is 1e-2
        # a large value helps converge in fewer episodes
        learning_rate=0.01,
    ),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="DQL",
)

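# %% {"tags": []}
# The replay_memory_size/target_update/batch_size knobs above follow the
# standard DQN recipe: sample minibatches from a replay buffer and regress
# toward Bellman targets computed with a periodically re-synced target
# network. A minimal sketch of the target computation in generic PyTorch
# (not DeepQLearnerPolicy's actual code):
import torch


def dqn_targets(reward: torch.Tensor, next_q_target: torch.Tensor, done: torch.Tensor, gamma: float = 0.015) -> torch.Tensor:
    """Bellman targets r + gamma * max_a' Q_target(s', a'), zeroed at episode end."""
    return reward + gamma * next_q_target.max(dim=1).values * (1.0 - done)


# Every `target_update` optimization steps, the target network would be
# re-synced, conventionally via:
#   target_net.load_state_dict(policy_net.state_dict())
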
# %% {"tags": []}
# Evaluate an agent that exploits the Q-function learnt above
dql_exploit_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=dql_run["learner"],
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    epsilon_minimum=0.00,
    render=False,
    plot_episodes_length=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting DQL",
)

# %% {"tags": []}
# Evaluate the random agent
random_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=learner.RandomPolicy(),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=1.0,  # purely random
    render=False,
    verbosity=Verbosity.Quiet,
    plot_episodes_length=False,
    title="Random search",
)

# %% {"tags": []}
# Compare and plot results for all the agents
all_runs = [random_run, credlookup_run, tabularq_run, tabularq_exploit_run, dql_run, dql_exploit_run]

# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
    all_runs=all_runs,
    title=f"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n"
    f"State: {[f.name() for f in themodel.state_space.feature_selection]} "
    f"({len(themodel.state_space.feature_selection)})\n"
    f"Action: abstract_action ({themodel.action_space.flat_size()})",
    save_at=os.path.join(plots_dir, f"benchmark-{gymid}-cumrewards.png"),
)

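# %% {"tags": []}
# Optionally persist the benchmark results for later analysis. A sketch
# using plain pickle (the output file name is an assumption; the run
# dictionaries may contain torch models, which pickle can serialize but
# which make the file large):
import pickle

with open(os.path.join(plots_dir, "benchmark_runs.pkl"), "wb") as f_out:
    pickle.dump(all_runs, f_out)
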
# %% {"tags": []}
contenders = [credlookup_run, tabularq_run, dql_run, dql_exploit_run]
p.plot_episodes_length(contenders)
p.plot_averaged_cummulative_rewards(
    title=f"Agent Benchmark top contenders\nmax_nodes:{ep.maximum_node_count}\n",
    all_runs=contenders,
    save_at=os.path.join(plots_dir, f"benchmark-{gymid}-cumreward_contenders.png"),
)

# %% {"tags": []}
# Plot cumulative rewards for all episodes
for r in contenders:
    p.plot_all_episodes(r)