CyberBattleSim/notebooks/notebook_benchmark.py

# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: tags,-all
#     cell_metadata_json: true
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---
# %% {"tags": []}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Benchmark all the baseline agents
on a given CyberBattleSim environment and compare
them to the dumb 'random agent' baseline.
NOTE: You can run this `.py`-notebook directly from VSCode.
You can also generate a traditional Jupyter Notebook
# using the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# pylint: disable=invalid-name
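# The `.ipynb` counterpart can also be regenerated from this file with jupytext
# (the header above declares `formats: ipynb,py:percent`), for example:
#   jupytext --to ipynb notebook_benchmark.py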
# %% {"tags": []}
import sys
import os
import logging
import gymnasium as gym
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %% {"tags": []}
# %matplotlib inline
# %% {"tags": ["parameters"]}
# Papermill notebook parameters
gymid = "CyberBattleChain-v0"
env_size = 10
iteration_count = 9000
training_episode_count = 50
eval_episode_count = 5
maximum_node_count = 22
maximum_total_credentials = 22
plots_dir = "output/plots"
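# %% {"tags": []}
# The cell above is tagged "parameters" so papermill can override these values
# when executing the exported notebook. A typical invocation (assuming the
# notebook was exported as notebook_benchmark.ipynb) would be:
#   papermill notebook_benchmark.ipynb output.ipynb -p gymid CyberBattleChain-v0 -p env_size 4 -p iteration_count 1000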
# %% {"tags": []}
os.makedirs(plots_dir, exist_ok=True)
# Load the Gym environment
if env_size:
    _gym_env = gym.make(gymid, size=env_size)
else:
    _gym_env = gym.make(gymid)
from typing import cast
gym_env = cast(CyberBattleEnv, _gym_env.unwrapped)
assert isinstance(gym_env, CyberBattleEnv), f"Expected CyberBattleEnv, got {type(gym_env)}"
ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=maximum_node_count, maximum_total_credentials=maximum_total_credentials, identifiers=gym_env.identifiers)
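# %% {"tags": []}
# `ep` bounds the feature encodings used by every agent below, so all runs
# share the same observation/action shapes. An optional quick look at the
# derived counts (the same attributes the debugging cell below relies on):
# print(f"nodes <= {ep.maximum_node_count}, ports = {ep.port_count}, properties = {ep.property_count}")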
# %% {"tags": []}
debugging = False
if debugging:
    print(f"port_count = {ep.port_count}, property_count = {ep.property_count}")
    gym_env.environment
    # training_env.environment.plot_environment_graph()
    gym_env.environment.network.nodes
    gym_env.action_space
    gym_env.action_space.sample()
    gym_env.observation_space.sample()
    o0, _ = gym_env.reset()
    o_test, r, d, t, i = gym_env.step(gym_env.sample_valid_action())
    o0, _ = gym_env.reset()
    o0.keys()
    fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])
    a = w.StateAugmentation(o0)
    w.Feature_discovered_ports(ep).get(a)
    fe_example.encode_at(a, 0)
# %% {"tags": []}
# Evaluate a random agent that opportunistically exploits
# credentials gathered in its local cache
credlookup_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=rca.CredentialCacheExploiter(),
    episode_count=10,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    epsilon_exponential_decay=10000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="Credential lookups (ϵ-greedy)",
)
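# %% {"tags": []}
# To eyeball this run on its own before the final comparison, the same
# plotting helper used at the end of this notebook can be applied here:
# p.plot_all_episodes(credlookup_run)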
# %% {"tags": []}
# Evaluate a Tabular Q-learning agent
tabularq_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=tqa.QTabularLearner(ep, gamma=0.015, learning_rate=0.01, exploit_percentile=100),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.01,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="Tabular Q-learning",
)
# %% {"tags": []}
# Evaluate an agent that exploits the Q-table learnt above
tabularq_exploit_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=tqa.QTabularLearner(ep, trained=tabularq_run["learner"], gamma=0.0, learning_rate=0.0, exploit_percentile=90),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    render=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting Q-matrix",
)
# %% {"tags": []}
# Evaluate the Deep Q-learning agent
dql_run = learner.epsilon_greedy_search(
    cyberbattle_gym_env=gym_env,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(
        ep=ep,
        gamma=0.015,
        replay_memory_size=10000,
        target_update=10,
        batch_size=512,
        # torch default learning rate is 1e-2
        # a large value helps the agent converge in fewer episodes
        learning_rate=0.01,
    ),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="DQL",
)
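# %% {"tags": []}
# The trained policy can optionally be persisted for later reuse. torch.save
# pickles arbitrary Python objects, so saving the whole learner is a minimal
# sketch that assumes the DeepQLearnerPolicy object is picklable in this
# configuration:
# import torch
# torch.save(dql_run["learner"], os.path.join(plots_dir, f"dql-{gymid}.pt"))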
# %% {"tags": []}
# Evaluate an agent that exploits the Q-function learnt above
dql_exploit_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=dql_run["learner"],
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    epsilon_minimum=0.00,
    render=False,
    plot_episodes_length=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting DQL",
)
# %% {"tags": []}
# Evaluate the random agent
random_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=learner.RandomPolicy(),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=1.0,  # purely random
    render=False,
    verbosity=Verbosity.Quiet,
    plot_episodes_length=False,
    title="Random search",
)
# %% {"tags": []}
# Compare and plot results for all the agents
all_runs = [random_run, credlookup_run, tabularq_run, tabularq_exploit_run, dql_run, dql_exploit_run]
# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
    all_runs=all_runs,
    title=f"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n"
    f"State: {[f.name() for f in themodel.state_space.feature_selection]} "
    f"({len(themodel.state_space.feature_selection)})\n"
    f"Action: abstract_action ({themodel.action_space.flat_size()})",
    save_at=os.path.join(plots_dir, f"benchmark-{gymid}-cumrewards.png"),
)
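# %% {"tags": []}
# Optionally persist the raw run data next to the plots so the comparison can
# be re-plotted without re-training; plain pickle is a minimal sketch and
# assumes the run dictionaries (which hold the trained learners) are picklable:
# import pickle
# with open(os.path.join(plots_dir, f"benchmark-{gymid}-runs.pkl"), "wb") as f_pkl:
#     pickle.dump(all_runs, f_pkl)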
# %% {"tags": []}
contenders = [credlookup_run, tabularq_run, dql_run, dql_exploit_run]
p.plot_episodes_length(contenders)
p.plot_averaged_cummulative_rewards(
    title=f"Agent Benchmark top contenders\nmax_nodes:{ep.maximum_node_count}\n",
    all_runs=contenders,
    save_at=os.path.join(plots_dir, f"benchmark-{gymid}-cumreward_contenders.png"),
)
# %% {"tags": []}
# Plot cumulative rewards for all episodes
for r in contenders:
    p.plot_all_episodes(r)