Initial CyberBattleSim release

William Blum 2021-01-07 21:12:17 +00:00 committed by William Blum
Commit 17035cb761
99 changed files: 52090 additions, 0 deletions

5
.dockerignore Normal file

@@ -0,0 +1,5 @@
/venv/**
/.cache/**
**/*.pyc
**/__pycache__
**/log.txt

13
.editorconfig Normal file

@@ -0,0 +1,13 @@
root=true
[*]
end_of_line = lf
insert_final_newline = true
[*.{py,pyx}]
indent_style = space
indent_size = 4
charset = utf-8
file_type_emacs = python
trim_trailing_whitespace = true
max_line_length = 120

1
.gitattributes vendored Normal file

@@ -0,0 +1 @@
* -text

8
.gitconfig Normal file

@@ -0,0 +1,8 @@
[branch "master"]
rebase = true
[branch]
autosetuprebase = always
[push]
default = simple
[core]
whitespace = cr-at-eol,-trailing-space

60
.gitignore vendored Normal file

@@ -0,0 +1,60 @@
# python
build/
dist/
__pycache__*
*.pyc
*.pyo
*.egg-info
# mypy type checking
/.mypy_cache/
/.pytest_cache/
**/*.pyc
# jupyter / ipython
.ipynb_checkpoints/
.init
Untitled*.ipynb
data/
# tests
.coverage
.hypothesis
# doc
doc/build/
# sublime
*.sublime-workspace
sftp*-config.json
# line profiler
*.lprof
# visual studio
.vs
*.pyproj
*.pyperf
*.sln
# vscode
.ropeproject
# pycharm
.idea
# data
*.zip
/venv/**
/.cache/**
/.dmypy.json
src/.dmypy.json
notebooks/untracked/**
cyberbattle.code-workspace
typings/**
cyberbattle/agents/baseline/notebooks/images/*.png
cyberbattle/agents/baseline/notebooks/images/*.gif
log.txt

38
.pre-commit-config.yaml Normal file

@@ -0,0 +1,38 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
exclude: 'typings/.*'
repos:
- repo: local
hooks:
- id: pyright
name: pyright
entry: pyright
language: node
pass_filenames: false
types: [python]
additional_dependencies: ['pyright@1.1.64']
- id: flake8
name: flake8
description: '`flake8` is a command-line utility for enforcing style consistency across Python projects.'
entry: flake8
language: python
types: [python]
require_serial: true
- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v1.5.2 # Use the sha / tag you want to point at
hooks:
- id: autopep8
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
exclude: '.*\.ipynb$'
- id: check-yaml
- id: check-added-large-files
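# Usage sketch (assuming a standard pre-commit setup): enable these hooks locally
# with `pre-commit install`; run them on demand with `pre-commit run --all-files`.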

550
.pylintrc Normal file

@@ -0,0 +1,550 @@
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,.ipynb_checkpoints
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint.
jobs=1
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=HIGH
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=print-statement,
parameter-unpacking,
unpacking-in-except,
old-raise-syntax,
backtick,
long-suffix,
old-ne-operator,
old-octal-literal,
import-star-module-level,
non-ascii-bytes-literal,
raw-checker-failed,
bad-inline-option,
locally-disabled,
locally-enabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
apply-builtin,
basestring-builtin,
buffer-builtin,
cmp-builtin,
coerce-builtin,
execfile-builtin,
file-builtin,
long-builtin,
raw_input-builtin,
reduce-builtin,
standarderror-builtin,
unicode-builtin,
xrange-builtin,
coerce-method,
delslice-method,
getslice-method,
setslice-method,
no-absolute-import,
old-division,
dict-iter-method,
dict-view-method,
next-method-called,
metaclass-assignment,
indexing-exception,
raising-string,
reload-builtin,
oct-method,
hex-method,
nonzero-method,
cmp-method,
input-builtin,
round-builtin,
intern-builtin,
unichr-builtin,
map-builtin-not-iterating,
zip-builtin-not-iterating,
range-builtin-not-iterating,
filter-builtin-not-iterating,
using-cmp-argument,
eq-without-hash,
div-method,
idiv-method,
rdiv-method,
exception-message-attribute,
invalid-str-codec,
sys-max-int,
bad-python3-import,
deprecated-string-function,
deprecated-str-translate-call,
deprecated-itertools-function,
deprecated-types-field,
next-method-defined,
dict-items-not-iterating,
dict-keys-not-iterating,
dict-values-not-iterating
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio).You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=optparse.Values,sys.exit
[BASIC]
# Naming style matching correct argument names
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style
#argument-rgx=
# Naming style matching correct attribute names
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Naming style matching correct class attribute names
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style
#class-attribute-rgx=
# Naming style matching correct class names
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-style
#class-rgx=
# Naming style matching correct constant names
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style
#function-rgx=
# Good variable names which should always be accepted, separated by a comma
good-names=i,
j,
k,
ex,
Run,
t,
u,
v,
n,
x,
y,
o,
h,
r,
c,
Q,
_
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# Naming style matching correct inline iteration names
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style
#inlinevar-rgx=
# Naming style matching correct method names
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style
#method-rgx=
# Naming style matching correct module names
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty
# Naming style matching correct variable names
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style
#variable-rgx=
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=0
# Maximum number of lines in a module
max-module-lines=1000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# Maximum number of arguments for function / method
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in a if statement
max-bool-expr=5
# Maximum number of branch for function / method body
max-branches=12
# Maximum number of locals for function / method body
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body
max-returns=6
# Maximum number of statements in function / method body
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception

22
.vscode/launch.json vendored Normal file

@@ -0,0 +1,22 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: cyberbattle_gym",
"type": "python",
"request": "launch",
"program": "src/cyberbattle_gym/samples/run.py",
"console": "integratedTerminal"
},
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}

43
.vscode/settings.json vendored Normal file

@@ -0,0 +1,43 @@
{
"python.pythonPath": "/usr/bin/python",
"python.venvPath": "venv",
"python.linting.pylintEnabled": false,
"python.linting.mypyEnabled": false,
"python.linting.flake8Enabled": true,
"python.linting.enabled": true,
"python.linting.mypyPath": "venv\\Scripts\\mypy.exe",
"python.linting.mypyArgs": [
"--config-file",
"setup.cfg"
],
"python.testing.pytestArgs": [
"cyberbattle"
],
"python.testing.unittestEnabled": false,
"python.testing.nosetestsEnabled": false,
"python.testing.pytestEnabled": true,
"python.linting.lintOnSave": true,
"python.formatting.provider": "autopep8",
"editor.formatOnPaste": true,
"editor.formatOnSave": true,
"extensions.ignoreRecommendations": true,
"files.autoSave": "off",
"debug.allowBreakpointsEverywhere": true,
"cSpell.enableFiletypes": [
"!yaml",
"!python",
"!json"
],
"python.languageServer": "Pylance",
"python.analysis.typeCheckingMode": "basic",
"search.useGlobalIgnoreFiles": true,
"files.watcherExclude": {
"typings/**": true,
"venv/**": true
},
"files.exclude": {
"venv/": true
},
"jupyter.jupyterServerType": "local",
"files.trimFinalNewlines": true
}

9
CODE_OF_CONDUCT.md Normal file

@@ -0,0 +1,9 @@
# Microsoft Open Source Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns

25
Dockerfile Normal file

@@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
FROM mcr.microsoft.com/azureml/onnxruntime:latest-cuda
WORKDIR /root
ADD *.sh ./
ADD *.txt ./
ADD *.py ./
RUN export TERM=dumb && ./init.sh -n
# Override conda python 3.7 install
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.8 2
ENV PATH="/usr/bin:${PATH}"
COPY . .
# To build the docker image:
# docker build -t cyberbattle:1.1 .
#
# To run
# docker run -it --rm cyberbattle:1.1 bash
#
# Pushing to private repository
# docker login -u spinshot-team-token-writer --password-stdin spinshot.azurecr.io
# docker tag cyberbattle:1.1 spinshot.azurecr.io/cyberbattle:1.1
# docker push spinshot.azurecr.io/cyberbattle:1.1

21
LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

246
README.md Normal file

@@ -0,0 +1,246 @@
# CyberBattleSim
CyberBattleSim is an experimentation research platform to investigate the interaction
of automated agents operating in a simulated abstract enterprise network environment.
The simulation provides a high-level abstraction of computer networks
and cyber security concepts.
Its Python-based OpenAI Gym interface allows for training of
automated agents using reinforcement learning algorithms.
The simulation environment is parameterized by a fixed network topology
with an associated set of vulnerabilities that attacker agents can utilize
to move laterally in the network.
The goal of the attacker is to take ownership of a portion of the network by exploiting
vulnerabilities that are planted in the computer nodes.
While the attacker attempts to spread throughout the network,
a defender agent watches the network activity and tries to contain the attack.
We provide a basic stochastic defender that detects
and mitigates ongoing attacks based on pre-defined probabilities of success.
Mitigation consists of re-imaging the infected nodes, a process
abstractly modeled as an operation spanning multiple simulation steps.
To compare the performance of the agents we look at two metrics: the number of simulation steps taken to
attain their goal and the cumulative rewards over simulation steps across training epochs.
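As a minimal sketch of the OpenAI Gym interface (the environment name and the `sample_valid_action` helper are taken from the provided samples and tests; episode length and rewards depend on the configured goals), an agent interacts with the simulation as follows:
```python
import gym
import cyberbattle  # importing the package registers the CyberBattle Gym environments

env = gym.make('CyberBattleToyCtf-v0')
observation = env.reset()
for t in range(100):
    action = env.sample_valid_action()  # sample an action valid in the current state
    observation, reward, done, info = env.step(action)
    if done:  # attacker or defender goal reached
        break
env.close()
```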
## Project goals
We view this project as an experimentation platform to conduct research on the interaction of automated agents in abstract simulated network environments. By open-sourcing it, we hope to encourage the research community to investigate how cyber-agents interact and evolve in such network environments.
The simulation we provide is admittedly simplistic, but this has advantages. Its highly abstract nature prohibits direct application to real-world systems, thus providing a safeguard against potential nefarious use of automated agents trained with it.
At the same time, its simplicity allows us to focus on specific security aspects we aim to study and quickly experiment with recent machine learning and AI algorithms.
For instance, the current implementation focuses on
lateral movement cyber-attack techniques, with the hope of understanding how network topology and configuration affect them. With this goal in mind, we felt that modeling actual network traffic was not necessary. This is just one example of a significant limitation in our system that future contributions might want to address.
On the algorithmic side, we provide some basic agents as starting points, but we
would be curious to find out how state-of-the-art reinforcement learning algorithms compare to them. We found that the large action space
intrinsic to any computer system is a particular challenge for
reinforcement learning, in contrast to other applications such as video games or robot control. Training agents that can store and retrieve credentials is another challenge we faced when applying RL techniques,
where agents typically do not feature internal memory. These are other areas of research where the provided simulation can perhaps be used for benchmarking purposes.
Other areas of interest include the responsible and ethical use of autonomous cyber-security systems:
How to design an enterprise network that gives an intrinsic advantage to defender agents?
How to conduct safe research aimed at defending enterprises against autonomous cyber-attacks
while preventing nefarious use of such technology?
## Documentation
Read the [Quick introduction](/docs/quickintro.md) to the project.
## Build status
| Type | Branch | Status |
| --- | ------ | ------ |
| CI | master | [![Build Status](https://mscodehub.visualstudio.com/Asterope/_apis/build/status/CyberBattle-ContinuousIntegration?branchName=master)](https://mscodehub.visualstudio.com/Asterope/_build/latest?definitionId=1359&branchName=master) |
| Docker image | master | [![Build Status](https://mscodehub.visualstudio.com/Asterope/_apis/build/status/CyberBattle-Docker?branchName=master)](https://mscodehub.visualstudio.com/Asterope/_build/latest?definitionId=1454&branchName=master) |
## Benchmark
See [Benchmark](/docs/benchmark.md).
## Setting up a dev environment
It is strongly recommended to work under a Linux environment, either directly or via WSL on Windows.
Running Python on Windows directly should work but is no longer supported.
Start by checking out the repository:
```bash
git clone https://github.com/microsoft/CyberBattleSim.git
```
### On Linux or WSL
The instructions were tested on a Linux Ubuntu distribution (both native and via WSL). Run the following command to set up your dev environment and install all the required dependencies (apt and pip packages):
```bash
./init.sh
```
The script installs python3.8 if not present. If you are running a version of Ubuntu older than 20, it will automatically add an additional apt repository to install python3.8.
The script will create a [virtual Python environment](https://docs.python.org/3/library/venv.html) under a `venv` subdirectory; you can then
run Python with `venv/bin/python`.
> Note: If you prefer Python from a global installation instead of a virtual environment, you can skip the creation of the virtual environment by running the script with `./init.sh -n`. This will instead install all the Python packages on a system-wide installation of Python 3.8.
#### Windows Subsystem for Linux
The supported dev environment on Windows is via WSL.
You first need to install an Ubuntu WSL distribution on your Windows machine,
and then proceed with the Linux instructions (previous section).
#### Git authentication from WSL
To authenticate with Git you can either use SSH-based authentication, or
alternatively use the credential-helper trick to automatically generate a
PAT token. The latter can be done by running the following command under WSL
([more info here](https://docs.microsoft.com/en-us/windows/wsl/tutorials/wsl-git)):
```bash
git config --global credential.helper "/mnt/c/Program\ Files/Git/mingw64/libexec/git-core/git-credential-manager.exe"
```
#### Docker on WSL
To run your environment within a docker container, we recommend running `docker` via Windows Subsystem for Linux (WSL) using the following instructions:
[Installing Docker on Windows under WSL](https://docs.docker.com/docker-for-windows/wsl-tech-preview/).
### Windows (unsupported)
This method is no longer supported; prefer running under
a Linux environment via WSL instead.
If you insist, start by installing [Python 3.8](https://www.python.org/downloads/windows/), then run the `./init.ps1` script from a PowerShell prompt.
## Getting started quickly using Docker (internal only at this stage)
> NOTE: We do not currently redistribute build artifacts or Docker containers externally for this project. We provide the Dockerfile and CI yaml files if you need to recreate those artifacts.
The quickest way to get up and running is to use the Docker container.
Note: you first need to request access to the Docker registry `spinshot.azurecr.io`. (Not publicly available.)
```bash
docker login spinshot.azurecr.io
docker pull spinshot.azurecr.io/cyberbattle:157884
docker run -it spinshot.azurecr.io/cyberbattle:157884 cyberbattle/agents/baseline/run.py
```
## Check your environment
Run the following command to run a simulation with a baseline RL agent:
```bash
python cyberbattle/agents/baseline/run.py --training_episode_count 1 --eval_episode_count 1 --iteration_count 10 --rewardplot_with 80 --chain_size=20 --ownership_goal 1.0
```
If everything is set up correctly you should get an output that looks like this:
```bash
torch cuda available=True
###### DQL
Learning with: episode_count=1,iteration_count=10,ϵ=0.9,ϵ_min=0.1, ϵ_expdecay=5000,γ=0.015, lr=0.01, replaymemory=10000,
batch=512, target_update=10
## Episode: 1/1 'DQL' ϵ=0.9000, γ=0.015, lr=0.01, replaymemory=10000,
batch=512, target_update=10
Episode 1|Iteration 10|reward: 139.0|Elapsed Time: 0:00:00|###################################################################|
###### Random search
Learning with: episode_count=1,iteration_count=10,ϵ=1.0,ϵ_min=0.0,
## Episode: 1/1 'Random search' ϵ=1.0000,
Episode 1|Iteration 10|reward: 194.0|Elapsed Time: 0:00:00|###################################################################|
simulation ended
Episode duration -- DQN=Red, Random=Green
10.00 ┼
Cumulative rewards -- DQN=Red, Random=Green
194.00 ┼ ╭──╴
174.60 ┤ │
155.20 ┤╭─────╯
135.80 ┤│ ╭──╴
116.40 ┤│ │
97.00 ┤│ ╭╯
77.60 ┤│ │
58.20 ┤╯ ╭──╯
38.80 ┤ │
19.40 ┤ │
0.00 ┼──╯
```
## Jupyter notebooks
To quickly get familiar with the project you can open one
of the provided Jupyter notebooks to play interactively with
the gym environments. Just start jupyter with `jupyter notebook`, or
`venv/bin/jupyter notebook` if you are using a virtual environment setup.
- Capture The Flag Toy environment notebooks:
- [Random agent](notebooks/toyctf-random.ipynb)
- [Interactive session for a human player](notebooks/toyctf-blank.ipynb)
- [Interactive session - fully solved](notebooks/toyctf-solved.ipynb)
- Chain environment notebooks:
- [Random agent](notebooks/chainnetwork-random.ipynb)
- Other environments:
- [Interactive session with a randomly generated environment](notebooks/randomnetwork.ipynb)
- [Random agent playing on randomly generated networks](notebooks/c2_interactive_interface.ipynb)
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
### Ideas for contributions
Here are some ideas on how to contribute: enhance the simulation (event-based, refine the simulation, …), train an RL algorithm on the existing simulation,
implement benchmarks to evaluate and compare the novelty of agents, add more network generative modes to train RL agents on, contribute to the docs, fix bugs.
See also the [wiki for more ideas](https://github.com/microsoft/CyberBattleGym/wiki/Possible-contributions).
## Citing this project
```bibtex
@misc{msft:cyberbattlesim,
Author = {Brandon Marken and Christian Seifert and Emily Goren and Haoran Wei and James Bono and Joshua Neil and Jugal Parikh and Justin Grana and Kate Farris and Kristian Holsheimer and Michael Betser and Nicole Nichols and William Blum},
Publisher = {GitHub},
Howpublished = {\url{https://github.com/microsoft/cyberbattlesim}},
Title = {CyberBattleSim},
Year = {2021}
}
```
## Note on privacy
This project does not include any customer data.
The provided models and network topologies are purely fictitious.
Users of the provided code provide all the input to the simulation
and must have the necessary permissions to use any provided data.
## Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos is subject to those third parties' policies.

41
SECURITY.md Normal file

@@ -0,0 +1,41 @@
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
<!-- END MICROSOFT SECURITY.MD BLOCK -->

11
SUPPORT.md Normal file

@@ -0,0 +1,11 @@
# Support
## How to file issues and get help
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.
## Microsoft Support Policy
Support for **CyberBattle** is limited to the resources listed above.

39
ado-build-containter.yml Normal file

@@ -0,0 +1,39 @@
# Docker
# Build and push an image to Azure Container Registry
# https://docs.microsoft.com/azure/devops/pipelines/languages/docker
trigger:
- master
resources:
- repo: self
variables:
# Container registry service connection established during pipeline creation
dockerRegistryServiceConnection: 'ac2df822-b01f-4588-a9bd-8195602c3995'
imageRepository: 'cyberbattle'
containerRegistry: 'spinshot.azurecr.io'
dockerfilePath: '$(Build.SourcesDirectory)/Dockerfile'
tag: '$(Build.BuildId)'
# Agent VM image name
vmImageName: 'ubuntu-latest'
stages:
- stage: Build
displayName: Build and push stage
jobs:
- job: Build
displayName: Build
pool:
vmImage: $(vmImageName)
steps:
- task: Docker@2
displayName: Build and push an image to container registry
inputs:
command: buildAndPush
repository: $(imageRepository)
dockerfile: $(dockerfilePath)
containerRegistry: $(dockerRegistryServiceConnection)
tags: |
$(tag)

79
ado-ci.yml Normal file

@@ -0,0 +1,79 @@
# Python package
# Create and test a Python package on multiple Python versions.
# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more:
# https://docs.microsoft.com/azure/devops/pipelines/languages/python
#
# Adapted from the template available here https://github.com/microsoft/azure-pipelines-yaml/blob/master/templates/python-package.yml
trigger:
- master
pool:
vmImage: "ubuntu-latest"
strategy:
matrix:
Python38:
python.version: "3.8"
steps:
- checkout: self
submodules: false # does not work even though it's supposed to (authentication issue)
persistCredentials: false
- task: UsePythonVersion@0
name: pythonver
inputs:
versionSpec: "$(python.version)"
displayName: "Use Python $(python.version)"
- task: NodeTool@0
inputs:
versionSpec: '12.x'
#checkLatest: false # Optional
displayName: "Use node tools"
- script: |
cat apt-requirements.txt | xargs sudo apt install
displayName: "Install apt dependencies"
- script: |
python -m pip install flake8
flake8 --benchmark
displayName: "Lint with flake8"
- task: Cache@2
displayName: "Pull pip packages from cache"
inputs:
key: 'pip | "$(Agent.OS)" | requirements.txt | requirements.dev.txt | setup.py'
restoreKeys: |
pip | "$(Agent.OS)"
path: $(pythonver.pythonLocation)/lib/python3.8/site-packages
- script: |
./install-pythonpackages.sh --nocoax
displayName: "Pull pip dependencies"
- script: |
npm install -g pyright
displayName: "Install pyright"
- task: Cache@2
displayName: "Pull typing stubs from cache"
inputs:
key: 'typingstubs | "$(Agent.OS)" | createstubs.sh'
restoreKeys: |
typingstubs | "$(Agent.OS)" | createstubs.sh
path: typings/
- script: |
./createstubs.sh
displayName: "create type stubs"
- script: |
./pyright.sh
displayName: "Typecheck with pyright"
- script: |
pip install pytest-azurepipelines
python -m pytest -v cyberbattle
displayName: "Test with pytest"


@@ -0,0 +1,5 @@
libgtk-3-0
xvfb
chromium-browser
libgconf-2-4
npm

16
apt-requirements.txt Normal file

@@ -0,0 +1,16 @@
software-properties-common
python3.8
python3.8-venv
python3-distutils
python3-pytest
python3-opengl
python3-dev
python3.8-dev
build-essential
nodejs
curl
dirmngr
apt-transport-https
lsb-release
ca-certificates
python-distutils-extra

54
createstubs.sh Executable file

@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
set -e
. ./getpythonpath.sh
echo "$(tput setaf 2)Creating type stubs$(tput sgr0)"
createstub() {
local name=$1
if [ ! -d "typings/$name" ]; then
pyright --createstub $name
else
echo stub $name already created
fi
}
createstub networkx
createstub pandas
createstub plotly
createstub progressbar
createstub pytest
createstub setuptools
createstub ordered_set
createstub asciichartpy
if [ ! -d "typings/gym" ]; then
pyright --createstub gym
# Patch gym stubs
echo ' spaces = ...' >> typings/gym/spaces/dict.pyi
echo ' nvec = ...' >> typings/gym/spaces/space.pyi
else
echo stub gym already created
fi
if [ ! -d "typings/IPython" ]; then
pyright --createstub IPython.core.display
else
echo stub 'IPython' already created
fi
if [ ! -d "boolean" ]; then
pyright --createstub boolean
sed -i '/class BooleanAlgebra(object):/a\ TRUE = ...\n FALSE = ...' typings/boolean/boolean.pyi
else
echo stub 'boolean' already created
fi
# Stubs that needed manual patching and that
# were instead checked in to git:
# pyright --createstub boolean
# pyright --createstub gym
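# Usage sketch (mirroring the CI pipeline): install pyright first, e.g.
#   npm install -g pyright
# then run ./createstubs.sh from the repository root.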

44
cyberbattle/NOTICE Normal file

@@ -0,0 +1,44 @@
NOTICES
This repository incorporates material as listed below or described in the code.
Component: PyTorch Reinforcement Learning Tutorial
https://github.com/pytorch/tutorials/blob/master/intermediate_source/reinforcement_q_learning.py
Open Source License/Copyright Notice.
BSD 3-Clause License
Copyright (c) 2017, Pytorch contributors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Additional Attribution.
Adam Paszke https://github.com/apaszke

76
cyberbattle/__init__.py Normal file

@@ -0,0 +1,76 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Initialize CyberBattleSim module"""
from gym.envs.registration import registry, EnvSpec
from gym.error import Error
from . import simulation
from . import agents
from ._env.cyberbattle_env import AttackerGoal, DefenderGoal
from .samples.chainpattern import chainpattern
from .samples.toyctf import toy_ctf
from .simulation import generate_network, model
__all__ = (
'simulation',
'agents',
)
def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
""" same as gym.envs.registry.register, but adds CyberBattle specs to env.spec """
if id in registry.env_specs:
raise Error('Cannot re-register id: {}'.format(id))
spec = EnvSpec(id, **kwargs)
# Map from port number to port names : List[model.PortName]
spec.ports = cyberbattle_env_identifiers.ports
# Array of all possible node properties (not necessarily all used in the network) : List[model.PropertyName]
spec.properties = cyberbattle_env_identifiers.properties
# Array defining an index for every possible local vulnerability name : List[model.VulnerabilityID]
spec.local_vulnerabilities = cyberbattle_env_identifiers.local_vulnerabilities
# Array defining an index for every possible remote vulnerability name : List[model.VulnerabilityID]
spec.remote_vulnerabilities = cyberbattle_env_identifiers.remote_vulnerabilities
registry.env_specs[id] = spec
if 'CyberBattleToyCtf-v0' in registry.env_specs:
del registry.env_specs['CyberBattleToyCtf-v0']
register(
id='CyberBattleToyCtf-v0',
cyberbattle_env_identifiers=toy_ctf.ENV_IDENTIFIERS,
entry_point='cyberbattle._env.cyberbattle_toyctf:CyberBattleToyCtf',
kwargs={'defender_agent': None,
'attacker_goal': AttackerGoal(reward=889),
'defender_goal': DefenderGoal(eviction=True)
},
# max_episode_steps=2600,
)
if 'CyberBattleRandom-v0' in registry.env_specs:
del registry.env_specs['CyberBattleRandom-v0']
register(
id='CyberBattleRandom-v0',
cyberbattle_env_identifiers=generate_network.ENV_IDENTIFIERS,
entry_point='cyberbattle._env.cyberbattle_random:CyberBattleRandom',
)
if 'CyberBattleChain-v0' in registry.env_specs:
del registry.env_specs['CyberBattleChain-v0']
register(
id='CyberBattleChain-v0',
cyberbattle_env_identifiers=chainpattern.ENV_IDENTIFIERS,
entry_point='cyberbattle._env.cyberbattle_chain:CyberBattleChain',
kwargs={'size': 4,
'defender_agent': None,
'attacker_goal': AttackerGoal(reward=2200),
'defender_goal': DefenderGoal(eviction=True),
'winning_reward': 5000.0,
'losing_reward': 0.0
},
reward_threshold=2200,
)
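# Example (sketch, mirroring the environment test suite): once this module is
# imported, the environments registered above can be created through the standard
# Gym factory, passing any of the kwargs declared here, e.g.:
#   import gym
#   from cyberbattle._env.cyberbattle_env import AttackerGoal
#   env = gym.make('CyberBattleChain-v0', size=10, attacker_goal=AttackerGoal(reward=4000))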


@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""CyberBattle environment based on a simple chain network structure"""
from ..samples.chainpattern import chainpattern
from . import cyberbattle_env
class CyberBattleChain(cyberbattle_env.CyberBattleEnv):
"""CyberBattle environment based on a simple chain network structure"""
def __init__(self, size, **kwargs):
self.size = size
super().__init__(
initial_environment=chainpattern.new_environment(size),
**kwargs)
@property
def name(self) -> str:
return f"CyberBattleChain-{self.size}"

The diff for this file is not shown because it is too large.


@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Test the CyberBattle Gym environment"""
import pytest
import gym
import numpy as np
from .cyberbattle_env import AttackerGoal
def test_few_gym_iterations() -> None:
"""Run a few iterations of the gym environment"""
env = gym.make('CyberBattleToyCtf-v0')
for _ in range(2):
env.reset()
action_mask = env.compute_action_mask()
assert action_mask
for t in range(12):
# env.render()
# sample a valid action
action = env.sample_valid_action()
observation, reward, done, info = env.step(action)
if done:
print("Episode finished after {} timesteps".format(t + 1))
break
env.close()
pass
def test_step_after_done() -> None:
actions = [
{'local_vulnerability': np.array([0, 1])}, # done=False, r=49.0
{'remote_vulnerability': np.array([0, 1, 0])}, # done=False, r=45.0
{'connect': np.array([0, 1, 2, 0])}, # done=False, r=100.0
{'local_vulnerability': np.array([1, 3])}, # done=False, r=49.0
{'connect': np.array([0, 2, 3, 1])}, # done=False, r=100.0
{'remote_vulnerability': np.array([1, 2, 1])}, # done=False, r=49.0
{'remote_vulnerability': np.array([1, 2, 0])}, # done=False, r=49.0
{'remote_vulnerability': np.array([2, 1, 1])}, # done=False, r=45.0
{'local_vulnerability': np.array([1, 0])}, # done=False, r=49.0
{'local_vulnerability': np.array([1, 1])}, # done=False, r=40.0
{'local_vulnerability': np.array([2, 1])}, # done=False, r=49.0
{'remote_vulnerability': np.array([2, 3, 0])}, # done=False, r=45.0
{'local_vulnerability': np.array([2, 4])}, # done=False, r=49.0
{'connect': np.array([0, 3, 2, 2])}, # done=False, r=100.0
{'local_vulnerability': np.array([3, 3])}, # done=False, r=49.0
{'local_vulnerability': np.array([3, 0])}, # done=False, r=49.0
{'remote_vulnerability': np.array([0, 4, 1])}, # done=False, r=49.0
{'local_vulnerability': np.array([3, 1])}, # done=False, r=40.0
{'connect': np.array([2, 4, 3, 3])}, # done=False, r=100.0
{'remote_vulnerability': np.array([1, 3, 1])}, # done=False, r=45.0
{'remote_vulnerability': np.array([1, 4, 0])}, # done=False, r=49.0
{'local_vulnerability': np.array([4, 1])}, # done=False, r=49.0
{'remote_vulnerability': np.array([0, 5, 0])}, # done=False, r=45.0
{'local_vulnerability': np.array([4, 4])}, # done=False, r=49.0
{'connect': np.array([3, 5, 2, 4])}, # done=False, r=100.0
{'remote_vulnerability': np.array([2, 5, 1])}, # done=False, r=45.0
{'local_vulnerability': np.array([5, 3])}, # done=False, r=49.0
{'connect': np.array([2, 6, 3, 5])}, # done=False, r=100.0
{'remote_vulnerability': np.array([4, 6, 1])}, # done=False, r=49.0
{'local_vulnerability': np.array([5, 0])}, # done=False, r=49.0
{'remote_vulnerability': np.array([4, 6, 0])}, # done=False, r=49.0
{'local_vulnerability': np.array([5, 1])}, # done=False, r=40.0
{'local_vulnerability': np.array([6, 1])}, # done=False, r=49.0
{'remote_vulnerability': np.array([6, 7, 0])}, # done=False, r=45.0
{'remote_vulnerability': np.array([0, 7, 1])}, # done=False, r=45.0
{'local_vulnerability': np.array([6, 4])}, # done=False, r=49.0
{'connect': np.array([4, 7, 2, 6])}, # done=False, r=100.0
{'local_vulnerability': np.array([7, 3])}, # done=False, r=49.0
{'connect': np.array([0, 8, 3, 7])}, # done=False, r=100.0
{'remote_vulnerability': np.array([0, 8, 0])}, # done=False, r=49.0
{'local_vulnerability': np.array([7, 0])}, # done=False, r=49.0
{'local_vulnerability': np.array([8, 4])}, # done=False, r=49.0
{'remote_vulnerability': np.array([3, 9, 1])}, # done=False, r=45.0
{'connect': np.array([3, 9, 2, 8])}, # done=False, r=100.0
{'remote_vulnerability': np.array([4, 9, 0])}, # done=False, r=45.0
{'local_vulnerability': np.array([9, 0])}, # done=False, r=49.0
{'remote_vulnerability': np.array([3, 8, 1])}, # done=False, r=49.0
{'remote_vulnerability': np.array([6, 10, 0])}, # done=False, r=49.0
{'local_vulnerability': np.array([9, 1])}, # done=False, r=40.0
{'local_vulnerability': np.array([9, 3])}, # done=False, r=49.0
{'remote_vulnerability': np.array([8, 10, 1])}, # done=False, r=49.0
{'local_vulnerability': np.array([7, 1])}, # done=False, r=40.0
{'connect': np.array([8, 10, 3, 9])}, # done=False, r=100.0
{'local_vulnerability': np.array([10, 4])}, # done=False, r=49.0
{'local_vulnerability': np.array([8, 1])}, # done=False, r=49.0
{'connect': np.array([7, 11, 2, 10])}, # done=False, r=1000.0
{'connect': np.array([5, 10, 3, 9])}, # done=True, r=5000.0
# this is one too many (after done)
{'connect': np.array([10, 5, 2, 4])}, # done=True, r=5000.0
]
env = gym.make('CyberBattleChain-v0', size=10, attacker_goal=AttackerGoal(reward=4000))
for a in actions[:-1]:
env.step(a)
with pytest.raises(RuntimeError, match=r'new episode must be started with env\.reset\(\)'):
env.step(actions[-1])
@pytest.mark.parametrize('env_name', ['CyberBattleToyCtf-v0', 'CyberBattleRandom-v0', 'CyberBattleChain-v0'])
def test_wrap_spec(env_name) -> None:
env = gym.make(env_name)
class DummyWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
assert hasattr(self, 'spec')
self.spec.dummy = 7
assert hasattr(env.spec, 'properties')
assert hasattr(env.spec, 'ports')
assert hasattr(env.spec, 'local_vulnerabilities')
assert hasattr(env.spec, 'remote_vulnerabilities')
env = DummyWrapper(env)
assert hasattr(env.spec, 'properties')
assert hasattr(env.spec, 'ports')
assert hasattr(env.spec, 'local_vulnerabilities')
assert hasattr(env.spec, 'remote_vulnerabilities')
assert hasattr(env.spec, 'dummy')


@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""A CyberBattle simulation over a randomly generated network"""
from ..simulation import generate_network
from . import cyberbattle_env
class CyberBattleRandom(cyberbattle_env.CyberBattleEnv):
"""A sample CyberBattle environment"""
def __init__(self):
super().__init__(initial_environment=generate_network.new_environment())


@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from ..samples.toyctf import toy_ctf
from . import cyberbattle_env
class CyberBattleToyCtf(cyberbattle_env.CyberBattleEnv):
"""CyberBattle simulation based on a toy CTF exercise"""
def __init__(self, **kwargs):
super().__init__(
initial_environment=toy_ctf.new_environment(),
**kwargs)


@@ -0,0 +1,151 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Defines stock defender agents for the CyberBattle simulation.
"""
import random
import numpy
from abc import abstractmethod
from cyberbattle.simulation.model import Environment
from cyberbattle.simulation.actions import DefenderAgentActions
from ..simulation import model
import logging
class DefenderAgent:
"""Define the step function for a defender agent.
Gets called after each step executed by the attacker agent."""
@abstractmethod
def step(self, environment: Environment, actions: DefenderAgentActions, t: int):
...
class ScanAndReimageCompromisedMachines(DefenderAgent):
"""A defender agent that scans a subset of network nodes
detects presence of an attacker on a given node with
some fixed probability and if detected re-image the compromised node.
probability -- probability that an attacker agent is detected when scanned given that the attacker agent is present
scan_capacity -- maxium number of machine that a defender agent can scan in one simulation step
scan_frequency -- frequencey of the scan in simulation steps
"""
def __init__(self, probability: float, scan_capacity: int, scan_frequency: int):
self.probability = probability
self.scan_capacity = scan_capacity
self.scan_frequency = scan_frequency
def step(self, environment: Environment, actions: DefenderAgentActions, t: int):
if t % self.scan_frequency == 0:
# scan nodes at random
scanned_nodes = random.choices(list(environment.network.nodes), k=self.scan_capacity)
for node_id in scanned_nodes:
node_info = environment.get_node(node_id)
if node_info.status == model.MachineStatus.Running and \
node_info.agent_installed:
is_malware_detected = numpy.random.random() <= self.probability
if is_malware_detected:
if node_info.reimagable:
logging.info(f"Defender detected malware, reimaging node {node_id}")
actions.reimage_node(node_id)
else:
logging.info(f"Defender detected malware, but node cannot be reimaged {node_id}")
class ExternalRandomEvents(DefenderAgent):
"""A 'defender' that randomly alters network node configuration"""
def step(self, environment: Environment, actions: DefenderAgentActions, t: int):
self.patch_vulnerabilities_at_random(environment)
self.stop_service_at_random(environment, actions)
self.plant_vulnerabilities_at_random(environment)
self.firewall_change_remove(environment)
self.firewall_change_add(environment)
def patch_vulnerabilities_at_random(self, environment: Environment, probability: float = 0.1) -> None:
# Iterate through every node.
for node_id, node_data in environment.nodes():
# Have a boolean remove_vulnerability decide if we will remove one.
remove_vulnerability = numpy.random.random() <= probability
if remove_vulnerability and len(node_data.vulnerabilities) > 0:
choice = random.choice(list(node_data.vulnerabilities))
node_data.vulnerabilities.pop(choice)
def stop_service_at_random(self, environment: Environment, actions: DefenderAgentActions, probability: float = 0.1) -> None:
for node_id, node_data in environment.nodes():
remove_service = numpy.random.random() <= probability
if remove_service and len(node_data.services) > 0:
service = random.choice(node_data.services)
actions.stop_service(node_id, service.name)
def plant_vulnerabilities_at_random(self, environment: Environment, probability: float = 0.1) -> None:
for node_id, node_data in environment.nodes():
add_vulnerability = numpy.random.random() <= probability
# See all differences between current node vulnerabilities and global ones.
new_vulnerabilities = numpy.setdiff1d(
list(environment.vulnerability_library.keys()), list(node_data.vulnerabilities.keys()))
# If we have decided that we will add a vulnerability and there are new vulnerabilities not already
# on the node, then add them.
if add_vulnerability and len(new_vulnerabilities) > 0:
new_vulnerability = random.choice(new_vulnerabilities)
node_data.vulnerabilities[new_vulnerability] = \
environment.vulnerability_library[new_vulnerability]
"""
TODO: Not sure how to access global (environment) services.
def serviceChangeAdd(self, probability: float) -> None:
# Iterate through every node.
for node_id, node_data in self.__environment.nodes():
# Have a boolean addService decide if we will add one.
addService = numpy.random.random() <= probability
# List all new services we can add.
newServices = numpy.setdiff1d(self.__environment.services, node_data.services)
# If we have decided to add a service and there are new services to add, go ahead and add them.
if addService and len(newServices) > 0:
newService = random.choice(newServices)
node_data.services.append(newService)
return None
"""
def firewall_change_remove(self, environment: Environment, probability: float = 0.1) -> None:
# Iterate through every node.
for node_id, node_data in environment.nodes():
# Have a boolean remove_rule decide if we will remove one.
remove_rule = numpy.random.random() <= probability
# The following logic sees if there are both incoming and outgoing rules.
# If there are, we remove one randomly.
if remove_rule and len(node_data.firewall.outgoing) > 0 and len(node_data.firewall.incoming) > 0:
incoming = numpy.random.random() <= 0.5
if incoming:
rule_to_remove = random.choice(node_data.firewall.incoming)
node_data.firewall.incoming.remove(rule_to_remove)
else:
rule_to_remove = random.choice(node_data.firewall.outgoing)
node_data.firewall.outgoing.remove(rule_to_remove)
# If there are only outgoing rules, we remove one random outgoing rule.
elif remove_rule and len(node_data.firewall.outgoing) > 0:
rule_to_remove = random.choice(node_data.firewall.outgoing)
node_data.firewall.outgoing.remove(rule_to_remove)
# If there are only incoming rules, we remove one random incoming rule.
elif remove_rule and len(node_data.firewall.incoming) > 0:
rule_to_remove = random.choice(node_data.firewall.incoming)
node_data.firewall.incoming.remove(rule_to_remove)
def firewall_change_add(self, environment: Environment, probability: float = 0.1) -> None:
# Iterate through every node.
for node_id, node_data in environment.nodes():
            # Have a boolean add_rule decide if we will add one.
add_rule = numpy.random.random() <= probability
if add_rule:
                # Add an ALLOW rule on a randomly chosen port.
rule_to_add = model.FirewallRule(port=random.choice(model.SAMPLE_IDENTIFIERS.ports),
permission=model.RulePermission.ALLOW)
# Randomly decide if we will add an incoming or outgoing rule.
incoming = numpy.random.random() <= 0.5
if incoming and rule_to_add not in node_data.firewall.incoming:
node_data.firewall.incoming.append(rule_to_add)
                elif not incoming and rule_to_add not in node_data.firewall.outgoing:
node_data.firewall.outgoing.append(rule_to_add)
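
A minimal wiring sketch for the defenders above: attach one to a simulation at construction time. The environment factory and the `defender_agent` keyword below are assumptions about the simulator's API, not guaranteed signatures.

# Hedged usage sketch (assumed imports and constructor keyword):
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
from cyberbattle.samples.chainpattern import chainpattern

# Scan 3 random nodes every 5 steps; a present attacker is detected with probability 0.2.
defender = ScanAndReimageCompromisedMachines(probability=0.2, scan_capacity=3, scan_frequency=5)
env = CyberBattleEnv(chainpattern.new_environment(size=4),  # assumed factory
                     defender_agent=defender)               # assumed keyword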

View file

@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""A discriminated union space for Gym"""
from collections import OrderedDict
from typing import Mapping, Union, List
from gym import spaces
from gym.utils import seeding
class DiscriminatedUnion(spaces.Dict): # type: ignore
"""
A discriminated union of simpler spaces.
Example usage:
self.observation_space = discriminatedunion.DiscriminatedUnion(
{"foo": spaces.Discrete(2), "Bar": spaces.Discrete(3)})
"""
def __init__(self,
spaces: Union[None, List[spaces.Space], Mapping[str, spaces.Space]] = None,
**spaces_kwargs: spaces.Space) -> None:
"""Create a discriminated union space"""
if spaces is None:
super().__init__(spaces_kwargs)
else:
super().__init__(spaces=spaces)
def seed(self, seed: Union[None, int] = None) -> None:
self._np_random, seed = seeding.np_random(seed)
super().seed(seed)
def sample(self) -> object:
space_count = len(self.spaces.items())
index_k = self.np_random.randint(space_count)
kth_key, kth_space = list(self.spaces.items())[index_k]
return OrderedDict([(kth_key, kth_space.sample())])
    def contains(self, candidate: object) -> bool:
        if not isinstance(candidate, dict) or len(candidate) != 1:
            return False
        key = next(iter(candidate))
        return key in self.spaces.keys()
@classmethod
def is_of_kind(cls, key: str, sample_n: Mapping[str, object]) -> bool:
"""Returns true if a given sample is of the specified discriminated kind"""
return key in sample_n.keys()
@classmethod
def kind(cls, sample_n: Mapping[str, object]) -> str:
"""Returns the discriminated kind of a given sample"""
keys = sample_n.keys()
assert len(keys) == 1
return list(keys)[0]
def __getitem__(self, key: str) -> spaces.Space:
return self.spaces[key]
def __repr__(self) -> str:
        return self.__class__.__name__ + "(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
def to_jsonable(self, sample_n: object) -> object:
return super().to_jsonable(sample_n)
def from_jsonable(self, sample_n: object) -> object:
ret = super().from_jsonable(sample_n)
assert len(ret) == 1
return ret
def __eq__(self, other: object) -> bool:
return isinstance(other, DiscriminatedUnion) and self.spaces == other.spaces
def test_sampling() -> None:
"""Simple sampling test"""
union = DiscriminatedUnion(spaces={"foo": spaces.Discrete(8), "Bar": spaces.Discrete(3)})
[union.sample() for i in range(100)]
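
To make the union semantics concrete, a short check using only the class defined above: every sample is a one-entry mapping whose key identifies the sub-space it was drawn from.

from gym import spaces

union = DiscriminatedUnion(spaces={"foo": spaces.Discrete(8), "Bar": spaces.Discrete(3)})
sample = union.sample()                       # e.g. OrderedDict([('foo', 5)])
assert union.contains(sample)
kind = DiscriminatedUnion.kind(sample)        # 'foo' or 'Bar'
assert DiscriminatedUnion.is_of_kind(kind, sample)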

View file

@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
import networkx as nx
from gym.spaces import Space, Dict
class BaseGraph(Space):
_nx_class: type
def __init__(
self,
max_num_nodes: int,
node_property_space: Optional[Dict] = None,
edge_property_space: Optional[Dict] = None):
self.max_num_nodes = max_num_nodes
self.node_property_space = Dict() if node_property_space is None else node_property_space
self.edge_property_space = Dict() if edge_property_space is None else edge_property_space
super().__init__()
def sample(self):
num_nodes = self.np_random.randint(self.max_num_nodes + 1)
graph = self._nx_class()
# add nodes with properties
for node_id in range(num_nodes):
node_properties = {k: s.sample() for k, s in self.node_property_space.spaces.items()}
graph.add_node(node_id, **node_properties)
if num_nodes < 2:
return graph
# add some edges with properties
seen, unseen = [], list(range(num_nodes)) # init
self.__pop_random(unseen, seen) # pop one node before we start
while unseen:
node_id_from, node_id_to = self.__sample_random(seen), self.__pop_random(unseen, seen)
edge_properties = {k: s.sample() for k, s in self.edge_property_space.spaces.items()}
graph.add_edge(node_id_from, node_id_to, **edge_properties)
return graph
def __pop_random(self, unseen: list, seen: list):
i = self.np_random.choice(len(unseen))
x = unseen[i]
seen.append(x)
del unseen[i]
return x
def __sample_random(self, seen: list):
i = self.np_random.choice(len(seen))
return seen[i]
def contains(self, x):
return (
isinstance(x, self._nx_class)
and all(node_property in self.node_property_space for node_property in x.nodes.values())
and all(edge_property in self.edge_property_space for edge_property in x.edges.values())
)
class Graph(BaseGraph):
_nx_class = nx.Graph
class DiGraph(BaseGraph):
_nx_class = nx.DiGraph
class MultiGraph(BaseGraph):
_nx_class = nx.MultiGraph
class MultiDiGraph(BaseGraph):
_nx_class = nx.MultiDiGraph
if __name__ == '__main__':
from gym.spaces import Box, Discrete
import matplotlib.pyplot as plt
space = DiGraph(
max_num_nodes=10,
node_property_space=Dict({'vector': Box(0, 1, (3,)), 'category': Discrete(7)}),
edge_property_space=Dict({'weight': Box(0, 1, ())}))
space.seed(42)
graph = space.sample()
assert graph in space
for node_id, node_properties in graph.nodes.items():
print(f"node_id: {node_id}, node_properties: {node_properties}")
for (node_id_from, node_id_to), edge_properties in graph.edges.items():
print(f"node_id_from: {node_id_from}, node_id_to: {node_id_to}, "
f"edge_properties: {edge_properties}")
pos = nx.spring_layout(graph)
nx.draw_networkx_nodes(graph, pos)
nx.draw_networkx_edges(graph, pos)
nx.draw_networkx_labels(graph, pos)
# nx.draw_networkx_labels(graph, pos, graph.nodes)
plt.show()

View file

@ -0,0 +1,214 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union, Tuple
import gym
import numpy as onp
import networkx as nx
from .graph_spaces import DiGraph
Action = Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]]
class CyberBattleGraph(gym.Wrapper):
"""
A wrapper for CyberBattleSim that maintains the agent's
knowledge graph containing information about the subset
of the network that was explored so far.
Currently the nodes of this graph are a subset of the environment nodes.
Eventually we will add new node types to represent various entities
like credentials and users. Edges will represent relationships between those entities
(e.g. user X is authenticated with machine Y using credential Z).
Actions
-------
Actions are of the form:
.. code:: python
(kind, *indicators)
The ``kind`` which is one of
.. code:: python
# kind
0: Local Vulnerability
1: Remote Vulnerability
2: Connect
The indicators vary in meaning and length, depending on the ``kind``:
.. code:: python
# kind=0 (Local Vulnerability)
indicators = (node_id, local_vulnerability_id)
# kind=1 (Remote Vulnerability)
indicators = (from_node_id, to_node_id, remote_vulnerability_id)
# kind=2 (Connect)
indicators = (from_node_id, to_node_id, port_id, credential_id)
The node ids can be obtained from the graph, e.g.
.. code:: python
node_ids = observation['graph'].keys()
The other indicators are listed below.
.. code:: python
# local_vulnerability_ids
0: ScanBashHistory
1: ScanExplorerRecentFiles
2: SudoAttempt
3: CrackKeepPassX
4: CrackKeepPass
# remote_vulnerability_ids
0: ProbeLinux
1: ProbeWindows
# port_ids
0: HTTPS
1: GIT
2: SSH
3: RDP
4: PING
5: MySQL
6: SSH-key
7: su
Examples
~~~~~~~~
Here are some example actions:
.. code:: python
a = (0, 5, 3) # try local vulnerability "CrackKeepPassX" on node 5
a = (1, 5, 7, 1) # try remote vulnerability "ProbeWindows" from node 5 to node 7
a = (2, 5, 7, 3, 2) # try to connect from node 5 to node 7 using credential 2 over RDP port
Observations
------------
Observations are graphs of the nodes that have been discovered so far. Each node is annotated
with a dict of properties of the form:
.. code:: python
node_properties = {
'name': 'FooServer', # human-readable identifier
'privilege_level': 1, # 0: not owned, 1: admin, 2: system
            'flags': array([-1, 0, 1, 0, 0, ..., 0]),  # 1: set, -1: unset, 0: unknown
'credentials': array([-1, 5, -1, ..., -1]), # array of ports (-1 means no cred)
'has_leaked_creds': True, # whether node has leaked any credentials so far
}
# flag_ids
0: Windows
1: Linux
2: ApacheWebSite
3: IIS_2019
4: IIS_2020_patched
5: MySql
6: Ubuntu
7: nginx/1.10.3
8: SMB_vuln
9: SMB_vuln_patched
10: SQLServer
11: Win10
12: Win10Patched
13: FLAG:Linux
Note that the **position** of a non-trivial port number in ``'credentials'`` corresponds to the
credential id. Therefore, for the node in the example above, we have a known credential on
:code:`port_id=5` with :code:`credential_id=1` (the position in the array).
"""
__kinds = ('local_vulnerability', 'remote_vulnerability', 'connect')
def __init__(self, env, maximum_total_credentials=22, maximum_node_count=22):
super().__init__(env)
self._bounds = self.env._bounds
self.__graph = None
self.observation_space = DiGraph(self._bounds.maximum_node_count)
def reset(self):
observation = self.env.reset()
self.__graph = nx.DiGraph()
self.__add_node(observation)
self.__update_nodes(observation)
return self.__graph
def step(self, action: Action):
"""
Take a step in the MDP.
Args:
action: An *abstract* action.
Returns:
observation: The next-step observation.
reward: The reward associated with the given action (and previous observation).
done: Whether the next-step observation is a terminal state.
info: Some additional info.
"""
kind_id, *indicators = action
observation, reward, done, info = self.env.step({self.__kinds[kind_id]: indicators})
for _ in range(observation['newly_discovered_nodes_count']):
self.__add_node(observation)
if True: # TODO: do we need to update edges and nodes every time?
self.__update_edges(observation)
self.__update_nodes(observation)
return self.__graph, reward, done, info
def __add_node(self, observation):
while self.__graph.number_of_nodes() < observation['discovered_node_count']:
node_index = self.__graph.number_of_nodes()
creds = onp.full(self._bounds.maximum_total_credentials, -1, dtype=onp.int8)
self.__graph.add_node(
node_index,
name=observation['discovered_nodes'][node_index],
privilege_level=None, flags=None, # these are set by __update_nodes()
credentials=creds,
has_leaked_creds=False,
)
def __update_edges(self, observation):
g_orig = observation['explored_network']
node_ids = {n: i for i, n in enumerate(observation['discovered_nodes'])}
for (from_name, to_name), edge_properties in g_orig.edges.items():
self.__graph.add_edge(node_ids[from_name], node_ids[to_name], **edge_properties)
def __update_nodes(self, observation):
node_properties = zip(
observation['nodes_privilegelevel'],
observation['discovered_nodes_properties'],
)
for node_id, (privilege_level, flags) in enumerate(node_properties):
# This value is already provided in self.__graph.nodes[node_id]['data'].privilege_level
self.__graph.nodes[node_id]['privilege_level'] = privilege_level
# This value is already provided in self.__graph.nodes[node_id]['data'].properties
self.__graph.nodes[node_id]['flags'] = flags
for cred_id, (node_id, port_id) in enumerate(observation['credential_cache_matrix']):
node_id, port_id = int(node_id), int(port_id)
            # NOTE: when the same cred_id appears with two different ports
            # (which can happen, even on the same node), only the last entry is kept.
self.__graph.nodes[node_id]['credentials'][cred_id] = port_id
            # Mark the node as having leaked credentials
self.__graph.nodes[node_id]['has_leaked_creds'] = True
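
A usage sketch for the wrapper. The gym id and kwargs below are illustrative assumptions (any registered CyberBattleSim environment would do); the action tuple follows the encoding documented in the class docstring.

import gym

env = CyberBattleGraph(gym.make('CyberBattleChain-v0', size=4))  # illustrative id/kwargs
graph = env.reset()
# kind=0: try local vulnerability 0 ("ScanBashHistory") on node 0
graph, reward, done, info = env.step((0, 0, 0))
print(graph.nodes[0]['privilege_level'], reward, done)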

View file

@ -0,0 +1,126 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import NamedTuple
import gym
from gym.spaces import Space, Discrete, Tuple
import numpy as onp
class Env(NamedTuple):
observation_space: Space
action_space: Space
def context_spaces(observation_space, action_space):
K = 3 # noqa: N806
N, L = action_space.spaces['local_vulnerability'].nvec # noqa: N806
N, N, R = action_space.spaces['remote_vulnerability'].nvec # noqa: N806
N, N, P, C = action_space.spaces['connect'].nvec # noqa: N806
return {
'kind': Env(observation_space, Discrete(K)),
'local_node_id': Env(Tuple((observation_space, Discrete(K))), Discrete(N)),
'local_vuln_id': Env(Tuple((observation_space, Discrete(N))), Discrete(L)),
'remote_node_id': Env(Tuple((observation_space, Discrete(K), Discrete(N))), Discrete(N)),
'remote_vuln_id': Env(Tuple((observation_space, Discrete(N), Discrete(N))), Discrete(R)),
'cred_id': Env(observation_space, Discrete(C)),
}
class ContextWrapper(gym.Wrapper):
__kinds = ('local_vulnerability', 'remote_vulnerability', 'connect')
def __init__(self, env, options):
super().__init__(env)
assert isinstance(options, dict) and set(options) == {
'kind', 'local_node_id', 'local_vuln_id', 'remote_node_id', 'remote_vuln_id', 'cred_id'}
self._options = options
self._bounds = self.env._bounds
self._action_context = []
def reset(self):
self._action_context = onp.full(5, -1, dtype=onp.int32)
self._observation = self.env.reset()
return self._observation
def step(self, dummy=None):
obs = self._observation
kind = self._options['kind'](obs)
local_node_id = self._options['local_node_id']((obs, kind))
if kind == 0:
local_vuln_id = self._options['local_vuln_id']((obs, local_node_id))
a = {self.__kinds[kind]: onp.array([local_node_id, local_vuln_id])}
else:
remote_node_id = self._options['remote_node_id']((obs, kind, local_node_id))
if kind == 1:
remote_vuln_id = \
self._options['remote_vuln_id']((obs, local_node_id, remote_node_id))
a = {self.__kinds[kind]: onp.array([local_node_id, remote_node_id, remote_vuln_id])}
else:
cred_id = self._options['cred_id'](obs)
assert cred_id < obs['credential_cache_length']
node_id, port_id = obs['credential_cache_matrix'][cred_id].astype('int32')
a = {self.__kinds[kind]: onp.array([local_node_id, node_id, port_id, cred_id])}
self._observation, reward, done, info = self.env.step(a)
return self._observation, reward, done, {**info, 'action': a}
# --- random option policies --------------------------------------------------------------------- #
def pi_kind(s):
kinds = ('local_vulnerability', 'remote_vulnerability', 'connect')
masked = onp.array([i for i, k in enumerate(kinds) if onp.any(s['action_mask'][k])])
return onp.random.choice(masked)
def pi_local_node_id(s):
s, k = s
if k == 0:
local_node_ids, _ = onp.argwhere(s['action_mask']['local_vulnerability']).T
elif k == 1:
local_node_ids, _, _ = onp.argwhere(s['action_mask']['remote_vulnerability']).T
else:
local_node_ids, _, _, _ = onp.argwhere(s['action_mask']['connect']).T
return onp.random.choice(local_node_ids)
def pi_local_vuln_id(s):
s, local_node_id = s
local_node_ids, local_vuln_ids = onp.argwhere(s['action_mask']['local_vulnerability']).T
masked = local_vuln_ids[local_node_ids == local_node_id]
return onp.random.choice(masked)
def pi_remote_node_id(s):
s, k, local_node_id = s
assert k != 0
if k == 1:
local_node_ids, remote_node_ids, _ = onp.argwhere(s['action_mask']['remote_vulnerability']).T
else:
local_node_ids, remote_node_ids, _, _ = onp.argwhere(s['action_mask']['connect']).T
return onp.random.choice(remote_node_ids[local_node_ids == local_node_id])
def pi_remote_vuln_id(s):
s, local_node_id, remote_node_id = s
local_node_ids, remote_node_ids, remote_vuln_ids = \
onp.argwhere(s['action_mask']['remote_vulnerability']).T
mask = (local_node_ids == local_node_id) & (remote_node_ids == remote_node_id)
return onp.random.choice(remote_vuln_ids[mask])
def pi_cred_id(s):
return onp.random.choice(s['credential_cache_length'])
random_options = {
'kind': pi_kind,
'local_node_id': pi_local_node_id,
'local_vuln_id': pi_local_vuln_id,
'remote_node_id': pi_remote_node_id,
'remote_vuln_id': pi_remote_vuln_id,
'cred_id': pi_cred_id,
}
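
With the random option policies above, the wrapper can drive a full episode on its own. A rollout sketch; the environment construction is an assumption, while the wrapper and options come from this file.

import gym

base_env = gym.make('CyberBattleChain-v0', size=4)    # illustrative id/kwargs
env = ContextWrapper(base_env, options=random_options)
obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step()              # the options pick the concrete action
    print(info['action'], reward)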

View file

@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This module contains all the agents to be used as baselines on the CyberBattle env.
"""
from .baseline.learner import Learner, AgentWrapper, EnvironmentBounds
__all__ = (
'Learner',
'AgentWrapper',
'EnvironmentBounds'
)

View file

@ -0,0 +1,489 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Function DeepQLearnerPolicy.optimize_model:
# Copyright (c) 2017, Pytorch contributors
# All rights reserved.
# https://github.com/pytorch/tutorials/blob/master/LICENSE
"""Deep Q-learning agent applied to chain network (notebook)
This notebooks can be run directly from VSCode, to generate a
traditional Jupyter Notebook to open in your browser
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
Requirements:
Nvidia CUDA drivers for WSL2: https://docs.nvidia.com/cuda/wsl-user-guide/index.html
PyTorch
"""
# pylint: disable=invalid-name
# %% [markdown]
# # Chain network CyberBattle Gym played by a Deep Q-learning agent
# %%
from numpy import ndarray
from cyberbattle._env import cyberbattle_env
import numpy as np
from typing import List, NamedTuple, Optional, Tuple, Union
import random
# deep learning packages
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
import torch.cuda
from .learner import Learner
from .agent_wrapper import EnvironmentBounds
import cyberbattle.agents.baseline.agent_wrapper as w
from .agent_randomcredlookup import CredentialCacheExploiter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class CyberBattleStateActionModel:
""" Define an abstraction of the state and action space
for a CyberBattle environment, to be used to train a Q-function.
"""
def __init__(self, ep: EnvironmentBounds):
self.ep = ep
self.global_features = w.ConcatFeatures(ep, [
# w.Feature_discovered_node_count(ep),
# w.Feature_owned_node_count(ep),
w.Feature_discovered_notowned_node_count(ep, None)
# w.Feature_discovered_ports(ep),
# w.Feature_discovered_ports_counts(ep),
# w.Feature_discovered_ports_sliding(ep),
# w.Feature_discovered_credential_count(ep),
# w.Feature_discovered_nodeproperties_sliding(ep),
])
self.node_specific_features = w.ConcatFeatures(ep, [
# w.Feature_actions_tried_at_node(ep),
w.Feature_success_actions_at_node(ep),
w.Feature_failed_actions_at_node(ep),
w.Feature_active_node_properties(ep),
w.Feature_active_node_age(ep)
# w.Feature_active_node_id(ep)
])
self.state_space = w.ConcatFeatures(ep, self.global_features.feature_selection +
self.node_specific_features.feature_selection)
self.action_space = w.AbstractAction(ep)
def get_state_astensor(self, state: w.StateAugmentation):
state_vector = self.state_space.get(state, node=None)
state_vector_float = np.array(state_vector, dtype=np.float32)
state_tensor = torch.from_numpy(state_vector_float).unsqueeze(0)
return state_tensor
def implement_action(
self,
wrapped_env: w.AgentWrapper,
actor_features: ndarray,
abstract_action: np.int32) -> Tuple[str, Optional[cyberbattle_env.Action], Optional[int]]:
"""Specialize an abstract model action into a CyberBattle gym action.
actor_features -- the desired features of the actor to use (source CyberBattle node)
abstract_action -- the desired type of attack (connect, local, remote).
        Returns a gym action implementing the desired attack at a node with the desired embedding.
"""
observation = wrapped_env.state.observation
# Pick source node at random (owned and with the desired feature encoding)
potential_source_nodes = [
from_node
for from_node in w.owned_nodes(observation)
if np.all(actor_features == self.node_specific_features.get(wrapped_env.state, from_node))
]
if len(potential_source_nodes) > 0:
source_node = np.random.choice(potential_source_nodes)
gym_action = self.action_space.specialize_to_gymaction(
source_node, observation, np.int32(abstract_action))
if not gym_action:
return "exploit[undefined]->explore", None, None
elif wrapped_env.env.is_action_valid(gym_action, observation['action_mask']):
return "exploit", gym_action, source_node
else:
return "exploit[invalid]->explore", None, None
else:
return "exploit[no_actor]->explore", None, None
# %%
# Deep Q-learning
class Transition(NamedTuple):
"""One taken transition and its outcome"""
state: Union[Tuple[Tensor], List[Tensor]]
action: Union[Tuple[Tensor], List[Tensor]]
next_state: Union[Tuple[Tensor], List[Tensor]]
reward: Union[Tuple[Tensor], List[Tensor]]
class ReplayMemory(object):
"""Transition replay memory"""
def __init__(self, capacity):
self.capacity = capacity
self.memory = []
self.position = 0
def push(self, *args):
"""Saves a transition."""
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory[self.position] = Transition(*args)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
class DQN(nn.Module):
"""The Deep Neural Network used to estimate the Q function"""
def __init__(self, ep: EnvironmentBounds):
super(DQN, self).__init__()
model = CyberBattleStateActionModel(ep)
linear_input_size = len(model.state_space.dim_sizes)
output_size = model.action_space.flat_size()
self.hidden_layer1 = nn.Linear(linear_input_size, 1024)
# self.bn1 = nn.BatchNorm1d(256)
self.hidden_layer2 = nn.Linear(1024, 512)
self.hidden_layer3 = nn.Linear(512, 128)
# self.hidden_layer4 = nn.Linear(128, 64)
self.head = nn.Linear(128, output_size)
    # Called with either one element to determine next action, or a batch
    # during optimization. Returns a tensor of Q-values, one per abstract action.
def forward(self, x):
x = F.relu(self.hidden_layer1(x))
# x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.hidden_layer2(x))
# x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.hidden_layer3(x))
# x = F.relu(self.hidden_layer4(x))
return self.head(x.view(x.size(0), -1))
def random_argmax(array):
"""Just like `argmax` but if there are multiple elements with the max
return a random index to break ties instead of returning the first one."""
max_value = np.max(array)
max_index = np.where(array == max_value)[0]
if max_index.shape[0] > 1:
max_index = int(np.random.choice(max_index, size=1))
else:
max_index = int(max_index)
return max_value, max_index
class ChosenActionMetadata(NamedTuple):
"""Additonal info about the action chosen by the DQN-induced policy"""
abstract_action: np.int32
actor_node: int
actor_features: ndarray
actor_state: ndarray
def __repr__(self) -> str:
return f"[abstract_action={self.abstract_action}, actor={self.actor_node}, state={self.actor_state}]"
class DeepQLearnerPolicy(Learner):
"""Deep Q-Learning on CyberBattle environments
Parameters
==========
ep -- global parameters of the environment
model -- define a state and action abstraction for the gym environment
gamma -- Q discount factor
replay_memory_size -- size of the replay memory
batch_size -- Deep Q-learning batch
    target_update -- target network update frequency (in number of episodes)
learning_rate -- the learning rate
Parameters from DeepDoubleQ paper
- learning_rate = 0.00025
- linear epsilon decay
- gamma = 0.99
Pytorch code from tutorial at
https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
"""
def __init__(self,
ep: EnvironmentBounds,
gamma: float,
replay_memory_size: int,
target_update: int,
batch_size: int,
learning_rate: float
):
self.stateaction_model = CyberBattleStateActionModel(ep)
self.batch_size = batch_size
self.gamma = gamma
self.learning_rate = learning_rate
self.policy_net = DQN(ep).to(device)
self.target_net = DQN(ep).to(device)
self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval()
self.target_update = target_update
self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=learning_rate)
self.memory = ReplayMemory(replay_memory_size)
self.credcache_policy = CredentialCacheExploiter()
def parameters_as_string(self):
return f'γ={self.gamma}, lr={self.learning_rate}, replaymemory={self.memory.capacity},\n' \
f'batch={self.batch_size}, target_update={self.target_update}'
def all_parameters_as_string(self) -> str:
model = self.stateaction_model
return f'{self.parameters_as_string()}\n' \
f'dimension={model.state_space.flat_size()}x{model.action_space.flat_size()}, ' \
f'Q={[f.name() for f in model.state_space.feature_selection]} ' \
f"-> 'abstract_action'"
def optimize_model(self, norm_clipping=False):
if len(self.memory) < self.batch_size:
return
transitions = self.memory.sample(self.batch_size)
# converts batch-array of Transitions to Transition of batch-arrays.
batch = Transition(*zip(*transitions))
# Compute a mask of non-final states and concatenate the batch elements
# (a final state would've been the one after which simulation ended)
non_final_mask = torch.tensor(tuple(map((lambda s: s is not None), batch.next_state)),
device=device, dtype=torch.bool)
non_final_next_states = torch.cat([s for s in batch.next_state
if s is not None])
state_batch = torch.cat(batch.state)
action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward)
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
# print(f'state_batch={state_batch.shape} input={len(self.stateaction_model.state_space.dim_sizes)}')
output = self.policy_net(state_batch)
# print(f'output={output.shape} batch.action={transitions[0].action.shape} action_batch={action_batch.shape}')
state_action_values = output.gather(1, action_batch)
# Compute V(s_{t+1}) for all next states.
# Expected values of actions for non_final_next_states are computed based
# on the "older" target_net; selecting their best reward with max(1)[0].
# This is merged based on the mask, such that we'll have either the expected
# state value or 0 in case the state was final.
next_state_values = torch.zeros(self.batch_size, device=device)
next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
# Compute the expected Q values
expected_state_action_values = (next_state_values * self.gamma) + reward_batch
# Compute Huber loss
loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
# Optimize the model
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping
if norm_clipping:
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
else:
for param in self.policy_net.parameters():
param.grad.data.clamp_(-1, 1)
self.optimizer.step()
def get_actor_state_vector(self, global_state: ndarray, actor_features: ndarray) -> ndarray:
return np.concatenate((np.array(global_state, dtype=np.float32),
np.array(actor_features, dtype=np.float32)))
def update_q_function(self,
reward: float,
actor_state: ndarray,
abstract_action: np.int32,
next_actor_state: Optional[ndarray]):
# store the transition in memory
reward_tensor = torch.tensor([reward], device=device, dtype=torch.float)
action_tensor = torch.tensor([[np.long(abstract_action)]], device=device, dtype=torch.long)
current_state_tensor = torch.as_tensor(actor_state, dtype=torch.float, device=device).unsqueeze(0)
if next_actor_state is None:
next_state_tensor = None
else:
next_state_tensor = torch.as_tensor(next_actor_state, dtype=torch.float, device=device).unsqueeze(0)
self.memory.push(current_state_tensor, action_tensor, next_state_tensor, reward_tensor)
        # optimize the policy network
self.optimize_model()
def on_step(self, wrapped_env: w.AgentWrapper,
observation, reward: float, done: bool, info, action_metadata):
agent_state = wrapped_env.state
if done:
self.update_q_function(reward,
actor_state=action_metadata.actor_state,
abstract_action=action_metadata.abstract_action,
next_actor_state=None)
else:
next_global_state = self.stateaction_model.global_features.get(agent_state, node=None)
next_actor_features = self.stateaction_model.node_specific_features.get(
agent_state, action_metadata.actor_node)
next_actor_state = self.get_actor_state_vector(next_global_state, next_actor_features)
self.update_q_function(reward,
actor_state=action_metadata.actor_state,
abstract_action=action_metadata.abstract_action,
next_actor_state=next_actor_state)
def end_of_episode(self, i_episode, t):
# Update the target network, copying all weights and biases in DQN
if i_episode % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
def lookup_dqn(self, states_to_consider: List[ndarray]) -> Tuple[List[np.int32], List[np.int32]]:
""" Given a set of possible current states return:
- index, in the provided list, of the state that would yield the best possible outcome
- the best action to take in such a state"""
with torch.no_grad():
# t.max(1) will return largest column value of each row.
# second column on max result is index of where max element was
# found, so we pick action with the larger expected reward.
# action: np.int32 = self.policy_net(states_to_consider).max(1)[1].view(1, 1).item()
state_batch = torch.tensor(states_to_consider).to(device)
dnn_output = self.policy_net(state_batch).max(1)
action_lookups = dnn_output[1].tolist()
expectedq_lookups = dnn_output[0].tolist()
return action_lookups, expectedq_lookups
def metadata_from_gymaction(self, wrapped_env, gym_action):
current_global_state = self.stateaction_model.global_features.get(wrapped_env.state, node=None)
actor_node = cyberbattle_env.sourcenode_of_action(gym_action)
actor_features = self.stateaction_model.node_specific_features.get(wrapped_env.state, actor_node)
abstract_action = self.stateaction_model.action_space.abstract_from_gymaction(gym_action)
return ChosenActionMetadata(
abstract_action=abstract_action,
actor_node=actor_node,
actor_features=actor_features,
actor_state=self.get_actor_state_vector(current_global_state, actor_features))
def explore(self, wrapped_env: w.AgentWrapper
) -> Tuple[str, cyberbattle_env.Action, object]:
"""Random exploration that avoids repeating actions previously taken in the same state"""
        # sample an action of any kind (local, remote, or connect)
gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])
metadata = self.metadata_from_gymaction(wrapped_env, gym_action)
return "explore", gym_action, metadata
def try_exploit_at_candidate_actor_states(
self,
wrapped_env,
current_global_state,
actor_features,
abstract_action):
actor_state = self.get_actor_state_vector(current_global_state, actor_features)
action_style, gym_action, actor_node = self.stateaction_model.implement_action(
wrapped_env, actor_features, abstract_action)
if gym_action:
assert actor_node is not None, 'actor_node should be set together with gym_action'
return action_style, gym_action, ChosenActionMetadata(
abstract_action=abstract_action,
actor_node=actor_node,
actor_features=actor_features,
actor_state=actor_state)
else:
# learn the failed exploit attempt in the current state
self.update_q_function(reward=0.0,
actor_state=actor_state,
next_actor_state=actor_state,
abstract_action=abstract_action)
return "exploit[undefined]->explore", None, None
def exploit(self,
wrapped_env,
observation
) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
        # First, we could attempt to exploit the credential cache
        # using the credcache_policy (currently disabled):
        # action_style, gym_action, _ = self.credcache_policy.exploit(wrapped_env, observation)
        # if gym_action:
        #     return action_style, gym_action, self.metadata_from_gymaction(wrapped_env, gym_action)
        # Otherwise, exploit the learnt Q-function
current_global_state = self.stateaction_model.global_features.get(wrapped_env.state, node=None)
# Gather the features of all the current active actors (i.e. owned nodes)
active_actors_features: List[ndarray] = [
self.stateaction_model.node_specific_features.get(wrapped_env.state, from_node)
for from_node in w.owned_nodes(observation)
]
unique_active_actors_features: List[ndarray] = np.unique(active_actors_features, axis=0)
# array of actor state vector for every possible set of node features
candidate_actor_state_vector: List[ndarray] = [
self.get_actor_state_vector(current_global_state, node_features)
for node_features in unique_active_actors_features]
remaining_action_lookups, remaining_expectedq_lookups = self.lookup_dqn(candidate_actor_state_vector)
remaining_candidate_indices = list(range(len(candidate_actor_state_vector)))
while remaining_candidate_indices:
_, remaining_candidate_index = random_argmax(remaining_expectedq_lookups)
actor_index = remaining_candidate_indices[remaining_candidate_index]
abstract_action = remaining_action_lookups[remaining_candidate_index]
actor_features = unique_active_actors_features[actor_index]
action_style, gym_action, metadata = self.try_exploit_at_candidate_actor_states(
wrapped_env,
current_global_state,
actor_features,
abstract_action)
if gym_action:
return action_style, gym_action, metadata
remaining_candidate_indices.pop(remaining_candidate_index)
remaining_expectedq_lookups.pop(remaining_candidate_index)
remaining_action_lookups.pop(remaining_candidate_index)
return "exploit[undefined]->explore", None, None
def stateaction_as_string(self, action_metadata) -> str:
return ''
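
The docstring above lists the DeepDoubleQ-style settings; a minimal instantiation sketch using them. It assumes `ep`, an EnvironmentBounds describing the target environment, is already in scope; the replay, batch and target-update values are illustrative.

# `ep` is assumed to be an EnvironmentBounds for the target environment
dql = DeepQLearnerPolicy(
    ep=ep,
    gamma=0.99,                # discount factor from the DeepDoubleQ settings
    replay_memory_size=10000,  # illustrative
    target_update=5,           # sync target_net every 5 episodes (illustrative)
    batch_size=512,            # illustrative
    learning_rate=0.00025)     # learning rate from the DeepDoubleQ settings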

View file

@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Random agent with credential lookup (notebook)
"""
# pylint: disable=invalid-name
from .agent_wrapper import AgentWrapper
from .learner import Learner
from typing import Optional
import cyberbattle._env.cyberbattle_env as cyberbattle_env
import numpy as np
import logging
import cyberbattle.agents.baseline.agent_wrapper as w
def exploit_credentialcache(observation) -> Optional[cyberbattle_env.Action]:
"""Exploit the credential cache to connect to
a node not owned yet."""
# Pick source node at random (owned and with the desired feature encoding)
potential_source_nodes = w.owned_nodes(observation)
if len(potential_source_nodes) == 0:
return None
source_node = np.random.choice(potential_source_nodes)
discovered_credentials = np.array(observation['credential_cache_matrix'])
n_discovered_creds = len(discovered_credentials)
if n_discovered_creds <= 0:
        # no credential available in the cache: cannot produce a valid connect action
return None
nodes_not_owned = w.discovered_nodes_notowned(observation)
match_port__target_notowned = [c for c in range(n_discovered_creds)
if discovered_credentials[c, 0] in nodes_not_owned]
if match_port__target_notowned:
logging.debug('found matching cred in the credential cache')
cred = np.int32(np.random.choice(match_port__target_notowned))
target = np.int32(discovered_credentials[cred, 0])
port = np.int32(discovered_credentials[cred, 1])
return {'connect': np.array([source_node, target, port, cred], dtype=np.int32)}
else:
return None
class CredentialCacheExploiter(Learner):
"""A learner that just exploits the credential cache"""
def parameters_as_string(self):
return ''
def explore(self, wrapped_env: AgentWrapper):
return "explore", wrapped_env.env.sample_valid_action([0, 1]), None
def exploit(self, wrapped_env: AgentWrapper, observation):
gym_action = exploit_credentialcache(observation)
if gym_action:
if wrapped_env.env.is_action_valid(gym_action, observation['action_mask']):
return 'exploit', gym_action, None
else:
# fallback on random exploration
return 'exploit[invalid]->explore', None, None
else:
return 'exploit[undefined]->explore', None, None
def stateaction_as_string(self, actionmetadata):
return ''
    def on_step(self, wrapped_env: AgentWrapper, observation, reward, done, info, action_metadata):
        pass
    def end_of_iteration(self, t, done):
        pass
    def end_of_episode(self, i_episode, t):
        pass
    def loss_as_string(self):
        return ''
    def new_episode(self):
        pass
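
The intended control flow is exploit-then-fall-back-to-explore. A sketch of one decision step, assuming `wrapped_env` (an AgentWrapper around the gym environment) and its latest `observation` are in scope:

agent = CredentialCacheExploiter()
# Try the credential cache first; fall back to random exploration on failure.
action_style, gym_action, metadata = agent.exploit(wrapped_env, observation)
if gym_action is None:
    action_style, gym_action, metadata = agent.explore(wrapped_env)
observation, reward, done, info = wrapped_env.step(gym_action)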

View file

@ -0,0 +1,422 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Q-learning agent applied to chain network (notebook)
This notebooks can be run directly from VSCode, to generate a
traditional Jupyter Notebook to open in your browser
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
"""
# pylint: disable=invalid-name
from typing import NamedTuple, Optional, Tuple
import numpy as np
import logging
from cyberbattle._env import cyberbattle_env
from .agent_wrapper import EnvironmentBounds
from .agent_randomcredlookup import CredentialCacheExploiter
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.learner as learner
def random_argmax(array):
"""Just like `argmax` but if there are multiple elements with the max
return a random index to break ties instead of returning the first one."""
max_value = np.max(array)
max_index = np.where(array == max_value)[0]
if max_index.shape[0] > 1:
max_index = int(np.random.choice(max_index, size=1))
else:
max_index = int(max_index)
return max_value, max_index
def random_argtop_percentile(array, percentile):
"""Just like `argmax` but if there are multiple elements with the max
return a random index to break ties instead of returning the first one."""
top_percentile = np.percentile(array, percentile)
indices = np.where(array >= top_percentile)[0]
if len(indices) == 0:
return random_argmax(array)
elif indices.shape[0] > 1:
max_index = int(np.random.choice(indices, size=1))
else:
max_index = int(indices)
return top_percentile, max_index
class QMatrix:
"""Q-Learning matrix for a given state and action space
state_space - Features defining the state space
action_space - Features defining the action space
qm - Optional: initialization values for the Q matrix
"""
# The Quality matrix
qm: np.ndarray
def __init__(self, name,
state_space: w.Feature,
action_space: w.Feature,
qm: Optional[np.ndarray] = None):
"""Initialize the Q-matrix"""
self.name = name
self.state_space = state_space
self.action_space = action_space
self.statedim = state_space.flat_size()
self.actiondim = action_space.flat_size()
self.qm = self.clear() if qm is None else qm
# error calculated for the last update to the Q-matrix
self.last_error = 0
def shape(self):
return (self.statedim, self.actiondim)
def clear(self):
"""Re-initialize the Q-matrix to 0"""
self.qm = np.zeros(shape=self.shape())
# self.qm = np.random.rand(*self.shape()) / 100
return self.qm
def print(self):
print(f"[{self.name}]\n"
f"state: {self.state_space}\n"
f"action: {self.action_space}\n"
f"shape = {self.shape()}")
def update(self, current_state: int, action: int, next_state: int, reward, gamma, learning_rate):
"""Update the Q matrix after taking `action` in state 'current_State'
and obtaining reward=R[current_state, action]"""
maxq_atnext, max_index = random_argmax(self.qm[next_state, ])
# bellman equation for Q-learning
temporal_difference = reward + gamma * maxq_atnext - self.qm[current_state, action]
self.qm[current_state, action] += learning_rate * temporal_difference
# The loss is calculated using the squared difference between
# target Q-Value and predicted Q-Value
square_error = temporal_difference * temporal_difference
self.last_error = square_error
return self.qm[current_state, action]
def exploit(self, features, percentile) -> Tuple[int, float]:
"""exploit: leverage the Q-matrix.
Returns the expected Q value and the chosen action."""
expected_q, action = random_argtop_percentile(self.qm[features, :], percentile)
return int(action), expected_q
class QLearnAttackSource(QMatrix):
""" Top-level Q matrix to pick the attack
State space: global state info
Action space: feature encodings of suggested nodes
"""
def __init__(self, ep: EnvironmentBounds, qm: Optional[np.ndarray] = None):
self.ep = ep
self.state_space = w.HashEncoding(ep, [
# Feature_discovered_node_count(),
# Feature_discovered_credential_count(),
w.Feature_discovered_ports_sliding(ep),
w.Feature_discovered_nodeproperties_sliding(ep),
w.Feature_discovered_notowned_node_count(ep, 3)
], 5000) # should not be too small, pick something big to avoid collision
self.action_space = w.RavelEncoding(ep, [
w.Feature_active_node_properties(ep)])
super().__init__("attack_source", self.state_space, self.action_space, qm)
class QLearnBestAttackAtSource(QMatrix):
""" Top-level Q matrix to pick the attack from a pre-chosen source node
State space: feature encodings of suggested node states
Action space: a SimpleAbstract action
"""
def __init__(self, ep: EnvironmentBounds, qm: Optional[np.ndarray] = None) -> None:
self.state_space = w.HashEncoding(ep, [
w.Feature_active_node_properties(ep),
w.Feature_active_node_age(ep)
# w.Feature_actions_tried_at_node(ep)
], 7000)
        # NOTE: For debugging purposes it's convenient to use
        # Ravel encoding for node properties instead
self.state_space_debugging = w.RavelEncoding(ep, [
w.HashEncoding(ep, [
# Feature_discovered_node_count(),
# Feature_discovered_credential_count(),
w.Feature_discovered_ports_sliding(ep),
w.Feature_discovered_nodeproperties_sliding(ep),
w.Feature_discovered_notowned_node_count(ep, 3),
], 100),
w.Feature_active_node_properties(ep)
])
self.action_space = w.AbstractAction(ep)
super().__init__("attack_at_source", self.state_space, self.action_space, qm)
# TODO: We should try scipy for sparse matrices, and MKL (Intel's BLAS implementation, typically faster than OpenBLAS) for numpy
# %%
class LossEval:
"""Loss evaluation for a Q-Learner,
learner -- The Q learner
"""
def __init__(self, qmatrix: QMatrix):
self.qmatrix = qmatrix
self.this_episode = []
self.all_episodes = []
def new_episode(self):
self.this_episode = []
def end_of_iteration(self, t, done):
self.this_episode.append(self.qmatrix.last_error)
def current_episode_loss(self):
return np.average(self.this_episode)
def end_of_episode(self, i_episode, t):
"""Average out the overall loss for this episode"""
self.all_episodes.append(self.current_episode_loss())
class ChosenActionMetadata(NamedTuple):
"""Additional information associated with the action chosen by the agent"""
Q_source_state: int
Q_source_expectedq: float
Q_attack_expectedq: float
source_node: int
source_node_encoding: int
abstract_action: np.int32
Q_attack_state: int
class QTabularLearner(learner.Learner):
"""Tabular Q-learning
Parameters
==========
gamma -- discount factor
learning_rate -- learning rate
ep -- environment global properties
trained -- another QTabularLearner that is pretrained to initialize the Q matrices from (referenced, not copied)
exploit_percentile -- (experimental) Randomly pick actions above this percentile in the Q-matrix.
Setting 100 gives the argmax as in standard Q-learning.
The idea is that a value less than 100 helps compensate for the
approximation made when updating the Q-matrix caused by
the abstraction of the action space (attack parameters are abstracted away
in the Q-matrix, and when an abstract action is picked, it
gets specialized via a random process.)
        When running in non-learning mode (lr=0), setting this value too close to 100
        may cause the agent to get stuck; being more permissive (e.g. in the 80-90 range)
        typically gives better results.
"""
def __init__(self,
ep: EnvironmentBounds,
gamma: float,
learning_rate: float,
exploit_percentile: float,
trained=None, # : Optional[QTabularLearner]
):
if trained:
self.qsource = trained.qsource
self.qattack = trained.qattack
else:
self.qsource = QLearnAttackSource(ep)
self.qattack = QLearnBestAttackAtSource(ep)
self.loss_qsource = LossEval(self.qsource)
self.loss_qattack = LossEval(self.qattack)
self.gamma = gamma
self.learning_rate = learning_rate
self.exploit_percentile = exploit_percentile
self.credcache_policy = CredentialCacheExploiter()
def on_step(self, wrapped_env: w.AgentWrapper, observation, reward, done, info, action_metadata: ChosenActionMetadata):
agent_state = wrapped_env.state
# Update the top-level Q matrix for the state of the selected source node
after_toplevel_state = self.qsource.state_space.encode(agent_state)
self.qsource.update(action_metadata.Q_source_state,
action_metadata.source_node_encoding,
after_toplevel_state,
reward, self.gamma, self.learning_rate)
# Update the second Q matrix for the abstract action chosen
qattack_state_after = self.qattack.state_space.encode_at(agent_state, action_metadata.source_node)
self.qattack.update(action_metadata.Q_attack_state,
int(action_metadata.abstract_action),
qattack_state_after,
reward, self.gamma, self.learning_rate)
def end_of_iteration(self, t, done):
self.loss_qsource.end_of_iteration(t, done)
self.loss_qattack.end_of_iteration(t, done)
def end_of_episode(self, i_episode, t):
self.loss_qsource.end_of_episode(i_episode, t)
self.loss_qattack.end_of_episode(i_episode, t)
def loss_as_string(self):
return f"[loss_source={self.loss_qsource.current_episode_loss():0.3f}" \
f" loss_attack={self.loss_qattack.current_episode_loss():0.3f}]"
def new_episode(self):
self.loss_qsource.new_episode()
self.loss_qattack.new_episode()
def exploit(self, wrapped_env: w.AgentWrapper, observation):
agent_state = wrapped_env.state
qsource_state = self.qsource.state_space.encode(agent_state)
#############
# first, attempt to exploit the credential cache
        # using the credcache_policy
action_style, gym_action, _ = self.credcache_policy.exploit(wrapped_env, observation)
if gym_action:
source_node = cyberbattle_env.sourcenode_of_action(gym_action)
return action_style, gym_action, ChosenActionMetadata(
Q_source_state=qsource_state,
Q_source_expectedq=-1,
Q_attack_expectedq=-1,
source_node=source_node,
source_node_encoding=self.qsource.action_space.encode_at(
agent_state, source_node),
abstract_action=np.int32(self.qattack.action_space.abstract_from_gymaction(gym_action)),
Q_attack_state=self.qattack.state_space.encode_at(agent_state, source_node)
)
#############
# Pick action: pick random source state among the ones with the maximum Q-value
action_style = "exploit"
source_node_encoding, qsource_expectedq = self.qsource.exploit(
qsource_state, percentile=100)
# Pick source node at random (owned and with the desired feature encoding)
potential_source_nodes = [
from_node
for from_node in w.owned_nodes(observation)
if source_node_encoding == self.qsource.action_space.encode_at(agent_state, from_node)
]
if len(potential_source_nodes) == 0:
logging.debug(f'No node with encoding {source_node_encoding}, fallback on explore')
            # NOTE: we should make sure that this does not happen too often;
            # the penalty should be much smaller than typical rewards
            # (a small nudge, not a new feedback signal).
# Learn the lack of node availability
self.qsource.update(qsource_state,
source_node_encoding,
qsource_state,
reward=0, gamma=self.gamma, learning_rate=self.learning_rate)
return "exploit-1->explore", None, None
else:
source_node = np.random.choice(potential_source_nodes)
qattack_state = self.qattack.state_space.encode_at(agent_state, source_node)
abstract_action, qattack_expectedq = self.qattack.exploit(
qattack_state, percentile=self.exploit_percentile)
gym_action = self.qattack.action_space.specialize_to_gymaction(
source_node, observation, np.int32(abstract_action))
assert int(abstract_action) < self.qattack.action_space.flat_size(), \
f'abstract_action={abstract_action} gym_action={gym_action}'
if gym_action and wrapped_env.env.is_action_valid(gym_action, observation['action_mask']):
logging.debug(f' exploit gym_action={gym_action} source_node_encoding={source_node_encoding}')
return action_style, gym_action, ChosenActionMetadata(
Q_source_state=qsource_state,
Q_source_expectedq=qsource_expectedq,
                Q_attack_expectedq=qattack_expectedq,
source_node=source_node,
source_node_encoding=source_node_encoding,
abstract_action=np.int32(abstract_action),
Q_attack_state=qattack_state
)
else:
# NOTE: We should make the penalty reward smaller than
# the average/typical non-zero reward of the env (e.g. 1/1000 smaller)
# The idea of weighing the learning_rate when taking a chance is
# related to "Inverse propensity weighting"
# Learn the non-validity of the action
self.qsource.update(qsource_state,
source_node_encoding,
qsource_state,
reward=0, gamma=self.gamma, learning_rate=self.learning_rate)
self.qattack.update(qattack_state,
int(abstract_action),
qattack_state,
reward=0, gamma=self.gamma, learning_rate=self.learning_rate)
# fallback on random exploration
return ('exploit[invalid]->explore' if gym_action else 'exploit[undefined]->explore'), None, None
def explore(self, wrapped_env: w.AgentWrapper):
agent_state = wrapped_env.state
gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])
abstract_action = self.qattack.action_space.abstract_from_gymaction(gym_action)
assert int(abstract_action) < self.qattack.action_space.flat_size(
), f'Q_attack_action={abstract_action} gym_action={gym_action}'
source_node = cyberbattle_env.sourcenode_of_action(gym_action)
return "explore", gym_action, ChosenActionMetadata(
Q_source_state=self.qsource.state_space.encode(agent_state),
Q_source_expectedq=-1,
Q_attack_expectedq=-1,
source_node=source_node,
source_node_encoding=self.qsource.action_space.encode_at(agent_state, source_node),
abstract_action=abstract_action,
Q_attack_state=self.qattack.state_space.encode_at(agent_state, source_node)
)
def stateaction_as_string(self, actionmetadata) -> str:
return f"Qsource[state={actionmetadata.Q_source_state} err={self.qsource.last_error:0.2f}"\
f"Q={actionmetadata.Q_source_expectedq:.2f}] " \
f"Qattack[state={actionmetadata.Q_attack_state} err={self.qattack.last_error:0.2f} "\
f"Q={actionmetadata.Q_attack_expectedq:.2f}] "
def parameters_as_string(self) -> str:
return f"γ={self.gamma}," \
f"learning_rate={self.learning_rate},"\
f"Q%={self.exploit_percentile}"
def all_parameters_as_string(self) -> str:
return f' dimension={self.qsource.state_space.flat_size()}x{self.qsource.action_space.flat_size()},' \
f'{self.qattack.state_space.flat_size()}x{self.qattack.action_space.flat_size()}\n' \
f'Q1={[f.name() for f in self.qsource.state_space.feature_selection]}' \
f' -> {[f.name() for f in self.qsource.action_space.feature_selection]}\n' \
f"Q2={[f.name() for f in self.qattack.state_space.feature_selection]} -> 'action'"

View file

@ -0,0 +1,507 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Agent wrapper for CyberBattle envrionments exposing additional
features extracted from the environment observations"""
from cyberbattle._env.cyberbattle_env import EnvironmentBounds
from typing import Optional, List
import enum
import numpy as np
from gym import spaces, Wrapper
from numpy import ndarray
import cyberbattle._env.cyberbattle_env as cyberbattle_env
import logging
class StateAugmentation:
"""Default agent state augmentation, consisting of the gym environment
observation itself and nothing more."""
def __init__(self, observation: Optional[cyberbattle_env.Observation] = None):
self.observation = observation
def on_step(self, action: cyberbattle_env.Action, reward: float, done: bool, observation: cyberbattle_env.Observation):
self.observation = observation
def on_reset(self, observation: cyberbattle_env.Observation):
self.observation = observation
class Feature(spaces.MultiDiscrete):
"""
Feature consisting of multiple discrete dimensions.
Parameters:
nvec: is a vector defining the number of possible values
for each discrete space.
"""
def __init__(self, env_properties: EnvironmentBounds, nvec):
self.env_properties = env_properties
super().__init__(nvec)
def flat_size(self):
return np.prod(self.nvec)
def name(self):
"""Return the name of the feature"""
p = len(type(Feature(self.env_properties, [])).__name__) + 1
return type(self).__name__[p:]
def get(self, a: StateAugmentation, node: Optional[int]) -> np.ndarray:
"""Compute the current value of a feature value at
the current observation and specific node"""
raise NotImplementedError
def pretty_print(self, v):
return v
class Feature_active_node_properties(Feature):
"""Bitmask of all properties set for the active node"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [2] * p.property_count)
def get(self, a: StateAugmentation, node) -> ndarray:
assert node is not None, 'feature only valid in the context of a node'
node_prop = a.observation['discovered_nodes_properties']
# list of all properties set/unset on the node
# Remap to get rid of unknown value 0: 1 -> 1, and -1 -> 0 (and 0-> 0)
assert node < len(node_prop), f'invalid node index {node} (not discovered yet)'
remapped = np.array((1 + node_prop[node]) / 2, dtype=np.int)
return remapped
class Feature_active_node_age(Feature):
"""How recently was this node discovered?
(measured by reverse position in the list of discovered nodes)"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [p.maximum_node_count])
def get(self, a: StateAugmentation, node) -> ndarray:
assert node is not None, 'feature only valid in the context of a node'
discovered_node_count = len(a.observation['discovered_nodes_properties'])
assert node < discovered_node_count, f'invalid node index {node} (not discovered yet)'
return np.array([discovered_node_count - node - 1], dtype=np.int)
class Feature_active_node_id(Feature):
"""Return the node id itself"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [p.maximum_node_count] * 1)
def get(self, a: StateAugmentation, node) -> ndarray:
return np.array([node], dtype=np.int)
class Feature_discovered_nodeproperties_sliding(Feature):
"""Bitmask indicating node properties seen in last few cache entries"""
window_size = 3
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [2] * p.property_count)
def get(self, a: StateAugmentation, node) -> ndarray:
node_prop = np.array(a.observation['discovered_nodes_properties'])
# keep last window of entries
node_prop_window = node_prop[-self.window_size:, :]
# Remap to get rid of unknown value 0: 1 -> 1, and -1 -> 0 (and 0-> 0)
node_prop_window_remapped = np.int32((1 + node_prop_window) / 2)
countby = np.sum(node_prop_window_remapped, axis=0)
bitmask = (countby > 0) * 1
return bitmask
class Feature_discovered_ports(Feature):
"""Bitmask vector indicating each port seen so far in discovered credentials"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [2] * p.port_count)
def get(self, a: StateAugmentation, node):
ccm = a.observation['credential_cache_matrix']
known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32)
known_credports[np.int32(ccm[:, 1])] = 1
return known_credports
class Feature_discovered_ports_sliding(Feature):
"""Bitmask indicating port seen in last few cache entries"""
window_size = 3
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [2] * p.port_count)
def get(self, a: StateAugmentation, node):
ccm = a.observation['credential_cache_matrix']
known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32)
known_credports[np.int32(ccm[-self.window_size:, 1])] = 1
return known_credports
class Feature_discovered_ports_counts(Feature):
"""Count of each port seen so far in discovered credentials"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [p.maximum_total_credentials + 1] * p.port_count)
def get(self, a: StateAugmentation, node):
ccm = a.observation['credential_cache_matrix']
return np.bincount(np.int32(ccm[:, 1]), minlength=self.env_properties.port_count)
class Feature_discovered_credential_count(Feature):
"""number of credentials discovered so far"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [p.maximum_total_credentials + 1])
def get(self, a: StateAugmentation, node):
return [len(a.observation['credential_cache_matrix'])]
class Feature_discovered_node_count(Feature):
"""number of nodes discovered so far"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [p.maximum_node_count + 1])
def get(self, a: StateAugmentation, node):
return [len(a.observation['discovered_nodes_properties'])]
class Feature_discovered_notowned_node_count(Feature):
"""number of nodes discovered that are not owned yet (optionally clipped)"""
def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
self.clip = p.maximum_node_count if clip is None else clip
super().__init__(p, [self.clip + 1])
def get(self, a: StateAugmentation, node):
node_props = a.observation['discovered_nodes_properties']
discovered = len(node_props)
# a node counts as owned when all its properties are known (non-zero)
owned = int(np.sum(np.all(node_props != 0, axis=1)))
diff = discovered - owned
return [min(diff, self.clip)]
class Feature_owned_node_count(Feature):
"""number of owned nodes so far"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [p.maximum_node_count + 1])
def get(self, a: StateAugmentation, node):
levels = a.observation['nodes_privilegelevel']
owned_nodes_indices = np.where(levels > 0)[0]
return [len(owned_nodes_indices)]
class ConcatFeatures(Feature):
""" Concatenate a list of features into a single feature
Parameters:
feature_selection - a selection of features to combine
"""
def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature]):
self.feature_selection = feature_selection
self.dim_sizes = np.concatenate([f.nvec for f in feature_selection])
super().__init__(p, [self.dim_sizes])
def pretty_print(self, v):
return v
def get(self, a: StateAugmentation, node=None) -> np.ndarray:
"""Return the feature vector"""
feature_vector = [f.get(a, node) for f in self.feature_selection]
return np.concatenate(feature_vector)
class FeatureEncoder(Feature):
""" Encode a list of featues as a unique index
"""
feature_selection: List[Feature]
def vector_to_index(self, feature_vector: np.ndarray) -> int:
raise NotImplementedError
def feature_vector_of_observation_at(self, a: StateAugmentation, node: Optional[int]) -> np.ndarray:
"""Return the current feature vector"""
feature_vector = [f.get(a, node) for f in self.feature_selection]
# print(f'feature_vector={feature_vector} self.feature_selection={self.feature_selection}')
return np.concatenate(feature_vector)
def feature_vector_of_observation(self, a: StateAugmentation):
return self.feature_vector_of_observation_at(a, None)
def encode(self, a: StateAugmentation, node=None) -> int:
"""Return the index encoding of the feature"""
feature_vector_concat = self.feature_vector_of_observation_at(a, node)
return self.vector_to_index(feature_vector_concat)
def encode_at(self, a: StateAugmentation, node) -> int:
"""Return the current feature vector encoding with a node context"""
feature_vector_concat = self.feature_vector_of_observation_at(a, node)
return self.vector_to_index(feature_vector_concat)
def get(self, a: StateAugmentation, node=None) -> np.ndarray:
"""Return the feature vector"""
return np.array([self.encode(a, node)])
def name(self):
"""Return a name for the feature encoding"""
n = ', '.join([f.name() for f in self.feature_selection])
return f'[{n}]'
class HashEncoding(FeatureEncoder):
""" Feature defined as a hash of another feature
Parameters:
feature_selection: a selection of features to combine
hash_size: dimension of the hash space, used as hash(str(feature_vector)) % hash_size
"""
def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature], hash_size: int):
self.feature_selection = feature_selection
self.hash_size = hash_size
super().__init__(p, [hash_size])
def flat_size(self):
return self.hash_size
def vector_to_index(self, feature_vector) -> int:
"""Hash the state vector"""
return hash(str(feature_vector)) % self.hash_size
def pretty_print(self, index):
return f'#{index}'
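# Illustrative sketch of the hashing scheme (hypothetical values): an
# arbitrarily large feature vector is folded onto a fixed-size index space,
# at the cost of possible collisions:
# v = np.array([3, 1, 0, 2])
# index = hash(str(v)) % 64  # some index in [0, 64), stable within a process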
class RavelEncoding(FeatureEncoder):
""" Combine a set of features into a single feature with a unique index
(calculated by raveling the original indices)
Parameters:
feature_selection - a selection of features to combine
"""
def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature]):
self.feature_selection = feature_selection
self.dim_sizes = np.concatenate([f.nvec for f in feature_selection])
self.ravelled_size: int = np.prod(self.dim_sizes)
assert np.shape(self.ravelled_size) == (), f'! {np.shape(self.ravelled_size)}'
super().__init__(p, [self.ravelled_size])
def vector_to_index(self, feature_vector):
assert len(self.dim_sizes) == len(feature_vector), \
f'feature vector of size {len(feature_vector)}, ' \
f'expecting {len(self.dim_sizes)}: {feature_vector} -- {self.dim_sizes}'
index: np.int32 = np.ravel_multi_index(feature_vector, self.dim_sizes)
assert index < self.ravelled_size, \
f'feature vector out of bound ({feature_vector}, dim={self.dim_sizes}) ' \
f'-> index={index}, max_index={self.ravelled_size-1})'
return index
def unravel_index(self, index) -> np.ndarray:
return np.unravel_index(index, self.dim_sizes)
def pretty_print(self, index):
return self.unravel_index(index)
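# Worked example of the raveling scheme (illustrative dimensions): two
# features with dim_sizes [3, 4] combine into a single index in [0, 12):
# np.ravel_multi_index([2, 1], [3, 4])  # -> 9, since 2 * 4 + 1 = 9
# np.unravel_index(9, [3, 4])           # -> (2, 1), the inverse mapping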
def owned_nodes(observation):
"""Return the list of owned nodes"""
return np.nonzero(observation['nodes_privilegelevel'])[0]
def discovered_nodes_notowned(observation):
"""Return the list of discovered nodes that are not owned yet"""
return np.nonzero(observation['nodes_privilegelevel'] == 0)[0]
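# Example (illustrative): with nodes_privilegelevel = [0, 2, 0, 1],
# owned_nodes returns indices [1, 3] and discovered_nodes_notowned
# returns indices [0, 2].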
class AbstractAction(Feature):
"""An abstraction of the gym state space that reduces
the space dimension for learning use to just
- local_attack(vulnid) (source_node provided)
- remote_attack(vulnid) (source_node provided, target_node forgotten)
- connect(port) (source_node provided, target_node forgotten, credentials inferred from cache)
"""
def __init__(self, p: EnvironmentBounds):
self.n_local_actions = p.local_attacks_count
self.n_remote_actions = p.remote_attacks_count
self.n_connect_actions = p.port_count
self.n_actions = self.n_local_actions + self.n_remote_actions + self.n_connect_actions
super().__init__(p, [self.n_actions])
def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_action_index: np.int32
) -> Optional[cyberbattle_env.Action]:
"""Specialize an abstract "q"-action into a gym action.
Return an adjustement weight (1.0 if the choice was deterministic, 1/n if a choice was made out of n)
and the gym action"""
abstract_action_index_int = int(abstract_action_index)
node_prop = np.array(observation['discovered_nodes_properties'])
if abstract_action_index_int < self.n_local_actions:
vuln = abstract_action_index_int
return {'local_vulnerability': np.array([source_node, vuln])}
abstract_action_index_int -= self.n_local_actions
if abstract_action_index_int < self.n_remote_actions:
vuln = abstract_action_index_int
# NOTE: We can do better here than random pick: ultimately this
# should be learnt from target node properties
# pick any node from the discovered ones
target = np.random.choice(len(node_prop))
return {'remote_vulnerability': np.array([source_node, target, vuln])}
abstract_action_index_int -= self.n_remote_actions
port = np.int32(abstract_action_index_int)
discovered_credentials = np.array(observation['credential_cache_matrix'])
n_discovered_creds = len(discovered_credentials)
if n_discovered_creds <= 0:
# no credential available in the cache: cannot produce a valid connect action
return None
# Pick a matching cred from the discovered_cred matrix
# (at random if more than one exist for this target port)
match_port = discovered_credentials[:, 1] == port
match_port_indices = np.where(match_port)[0]
nodes_not_owned = discovered_nodes_notowned(observation)
match_port__target_notowned = [c for c in match_port_indices
if discovered_credentials[c, 0] in nodes_not_owned]
if match_port__target_notowned:
logging.debug('found matching cred in the credential cache')
cred = np.int32(np.random.choice(match_port__target_notowned))
target = np.int32(discovered_credentials[cred, 0])
return {'connect': np.array([source_node, target, port, cred], dtype=np.int32)}
else:
logging.debug('no cred match')
return None
def abstract_from_gymaction(self, gym_action: cyberbattle_env.Action) -> np.int32:
"""Abstract a gym action into an action to be index in the Q-matrix"""
if 'local_vulnerability' in gym_action:
return gym_action['local_vulnerability'][1]
elif 'remote_vulnerability' in gym_action:
r = gym_action['remote_vulnerability']
return self.n_local_actions + r[2]
assert 'connect' in gym_action
c = gym_action['connect']
a = self.n_local_actions + self.n_remote_actions + c[2]
assert a < self.n_actions
return np.int32(a)
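# Layout of the abstract action index space (illustrative bounds: 2 local
# vulnerabilities, 3 remote vulnerabilities, 4 ports => n_actions = 9):
# indices 0..1 -> local_vulnerability 0..1
# indices 2..4 -> remote_vulnerability 0..2
# indices 5..8 -> connect on ports 0..3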
class ActionTrackingStateAugmentation(StateAugmentation):
"""An agent state augmentation consisting of
the environment observation augmented with the following dynamic information:
- success_action_count: count of actions taken that succeeded, per node and abstract action
- failed_action_count: count of actions taken that failed, per node and abstract action
"""
def __init__(self, p: EnvironmentBounds, observation: Optional[cyberbattle_env.Observation] = None):
self.aa = AbstractAction(p)
self.success_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
self.failed_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
self.env_properties = p
super().__init__(observation)
def on_step(self, action: cyberbattle_env.Action, reward: float, done: bool, observation: cyberbattle_env.Observation):
node = cyberbattle_env.sourcenode_of_action(action)
abstract_action = self.aa.abstract_from_gymaction(action)
if reward > 0:
self.success_action_count[node, abstract_action] += 1
else:
self.failed_action_count[node, abstract_action] += 1
super().on_step(action, reward, done, observation)
def on_reset(self, observation: cyberbattle_env.Observation):
p = self.env_properties
self.success_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
self.failed_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
super().on_reset(observation)
class Feature_actions_tried_at_node(Feature):
"""A bit mask indicating which actions were already tried
a the current node: 0 no tried, 1 tried"""
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [2] * AbstractAction(p).n_actions)
def get(self, a: ActionTrackingStateAugmentation, node: int):
return ((a.failed_action_count[node, :] + a.success_action_count[node, :]) != 0) * 1
class Feature_success_actions_at_node(Feature):
"""number of time each action succeeded at a given node"""
max_action_count = 100
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [self.max_action_count] * AbstractAction(p).n_actions)
def get(self, a: ActionTrackingStateAugmentation, node: int):
return np.minimum(a.success_action_count[node, :], self.max_action_count - 1)
class Feature_failed_actions_at_node(Feature):
"""number of time each action failed at a given node"""
max_action_count = 100
def __init__(self, p: EnvironmentBounds):
super().__init__(p, [self.max_action_count] * AbstractAction(p).n_actions)
def get(self, a: ActionTrackingStateAugmentation, node: int):
return np.minimum(a.failed_action_count[node, :], self.max_action_count - 1)
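# Note: the counts above are clipped to max_action_count - 1 so they stay
# within the declared MultiDiscrete bounds, e.g. np.minimum([150, 3], 99)
# evaluates to [99, 3].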
class Verbosity(enum.Enum):
"""Verbosity of the learning function"""
Quiet = 0
Normal = 1
Verbose = 2
class AgentWrapper(Wrapper):
"""Gym wrapper to update the agent state on every step"""
def __init__(self, env: cyberbattle_env.CyberBattleEnv, state: StateAugmentation):
super().__init__(env)
self.state = state
def step(self, action: cyberbattle_env.Action):
observation, reward, done, info = self.env.step(action)
self.state.on_step(action, reward, done, observation)
return observation, reward, done, info
def reset(self):
observation = self.env.reset()
self.state.on_reset(observation)
return observation

View file

@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# -*- coding: utf-8 -*-
"""Test training of baseline agents. """
import torch
import gym
import logging
import sys
import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.learner as learner
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
print(f"torch cuda available={torch.cuda.is_available()}")
cyberbattlechain = gym.make('CyberBattleChain-v0',
size=4,
attacker_goal=cyberbattle_env.AttackerGoal(
own_atleast_percent=1.0,
reward=100))
ep = w.EnvironmentBounds.of_identifiers(
maximum_total_credentials=10,
maximum_node_count=10,
identifiers=cyberbattlechain.identifiers
)
training_episode_count = 2
iteration_count = 5
def test_agent_training() -> None:
dqn_learning_run = learner.epsilon_greedy_search(
cyberbattle_gym_env=cyberbattlechain,
environment_properties=ep,
learner=dqla.DeepQLearnerPolicy(
ep=ep,
gamma=0.015,
replay_memory_size=10000,
target_update=10,
batch_size=512,
learning_rate=0.01), # torch default is 1e-2
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
# epsilon_multdecay=0.75, # 0.999,
epsilon_exponential_decay=5000, # 10000
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="DQL"
)
assert dqn_learning_run
random_run = learner.epsilon_greedy_search(
cyberbattlechain,
ep,
learner=learner.RandomPolicy(),
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=1.0, # purely random
render=False,
verbosity=Verbosity.Quiet,
title="Random search"
)
assert random_run

View file

@ -0,0 +1,385 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Learner helpers and epsilon greedy search"""
import math
import sys
from .plotting import PlotTraining, plot_averaged_cummulative_rewards
from .agent_wrapper import AgentWrapper, EnvironmentBounds, Verbosity, ActionTrackingStateAugmentation
import logging
import numpy as np
from cyberbattle._env import cyberbattle_env
from typing import Tuple, Optional, TypedDict, List
import progressbar
import abc
class Learner(abc.ABC):
"""Interface to be implemented by an epsilon-greedy learner"""
def new_episode(self) -> None:
return None
def end_of_episode(self, i_episode, t) -> None:
return None
def end_of_iteration(self, t, done) -> None:
return None
@abc.abstractmethod
def explore(self, wrapped_env: AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:
"""Exploration function.
Returns (action_type, gym_action, action_metadata) where
action_metadata is a custom object that gets passed to the on_step callback function"""
raise NotImplementedError
@abc.abstractmethod
def exploit(self, wrapped_env: AgentWrapper, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
"""Exploit function.
Returns (action_type, gym_action, action_metadata) where
action_metadata is a custom object that gets passed to the on_step callback function"""
raise NotImplementedError
@abc.abstractmethod
def on_step(self, wrapped_env: AgentWrapper, observation, reward, done, info, action_metadata) -> None:
raise NotImplementedError
def parameters_as_string(self) -> str:
return ''
def all_parameters_as_string(self) -> str:
return ''
def loss_as_string(self) -> str:
return ''
def stateaction_as_string(self, action_metadata) -> str:
return ''
class RandomPolicy(Learner):
"""A policy that does not learn and only explore"""
def explore(self, wrapped_env: AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:
gym_action = wrapped_env.env.sample_valid_action()
return "explore", gym_action, None
def exploit(self, wrapped_env: AgentWrapper, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
raise NotImplementedError
def on_step(self, wrapped_env: AgentWrapper, observation, reward, done, info, action_metadata):
return None
Breakdown = TypedDict('Breakdown', {
'local': int,
'remote': int,
'connect': int
})
Outcomes = TypedDict('Outcomes', {
'reward': Breakdown,
'noreward': Breakdown
})
Stats = TypedDict('Stats', {
'exploit': Outcomes,
'explore': Outcomes,
'exploit_deflected_to_explore': int
})
TrainedLearner = TypedDict('TrainedLearner', {
'all_episodes_rewards': List[List[float]],
'all_episodes_availability': List[List[float]],
'learner': Learner,
'trained_on': str,
'title': str
})
def print_stats(stats):
"""Print learning statistics"""
def print_breakdown(stats, actiontype: str):
def ratio(kind: str) -> str:
x, y = stats[actiontype]['reward'][kind], stats[actiontype]['noreward'][kind]
sum = x + y
if sum == 0:
return 'NaN'
else:
return f"{(x / sum):.2f}"
def print_kind(kind: str):
print(
f" {actiontype}-{kind}: {stats[actiontype]['reward'][kind]}/{stats[actiontype]['noreward'][kind]} "
f"({ratio(kind)})")
print_kind('local')
print_kind('remote')
print_kind('connect')
print(" Breakdown [Reward/NoReward (Success rate)]")
print_breakdown(stats, 'explore')
print_breakdown(stats, 'exploit')
print(f" exploit deflected to exploration: {stats['exploit_deflected_to_explore']}")
def epsilon_greedy_search(
cyberbattle_gym_env: cyberbattle_env.CyberBattleEnv,
environment_properties: EnvironmentBounds,
learner: Learner,
title: str,
episode_count: int,
iteration_count: int,
epsilon: float,
epsilon_minimum=0.0,
epsilon_multdecay: Optional[float] = None,
epsilon_exponential_decay: Optional[int] = None,
render=True,
render_last_episode_rewards_to: Optional[str] = None,
verbosity: Verbosity = Verbosity.Normal
) -> TrainedLearner:
"""Epsilon greedy search for CyberBattle gym environments
Parameters
==========
- cyberbattle_gym_env -- the CyberBattle environment to train on
- learner -- the policy learner/exploiter
- episode_count -- Number of training episodes
- iteration_count -- Maximum number of iterations in each episode
- epsilon -- explore vs exploit
- 0.0 to exploit the learnt policy only without exploration
- 1.0 to explore purely randomly
- epsilon_minimum -- epsilon decay clipped at this value.
Setting this value too close to 0 may lead the search to get stuck.
- epsilon_multdecay -- epsilon gets multiplied by this value after each episode
- epsilon_exponential_decay -- if set, use exponential decay; the larger the value,
the slower epsilon decays from the initial `epsilon` to `epsilon_minimum`.
- verbosity -- verbosity of the `print` logging
- render -- render the environment interactively after each episode
- render_last_episode_rewards_to -- render the environment to the specified file path
with an index appended to it each time there is a positive reward
for the last episode only
Note on convergence
===================
Setting 'epsilon_minimum' to 0 with an exponential decay < 1
makes the learning converge quickly (the loss function reaches 0),
but that is a forced convergence: as epsilon approaches 0,
only the q-values that were explored so far keep getting updated,
so only that subset of cells of the Q-matrix converges.
"""
print(f"###### {title}\n"
f"Learning with: episode_count={episode_count},"
f"iteration_count={iteration_count},"
f"ϵ={epsilon},"
f'ϵ_min={epsilon_minimum}, '
+ (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else '')
+ (f"ϵ_expdecay={epsilon_exponential_decay}," if epsilon_exponential_decay else '') +
f"{learner.parameters_as_string()}")
initial_epsilon = epsilon
all_episodes_rewards = []
all_episodes_availability = []
wrapped_env = AgentWrapper(cyberbattle_gym_env,
ActionTrackingStateAugmentation(environment_properties))
steps_done = 0
plot_title = f"{title} (epochs={episode_count}, ϵ={initial_epsilon}, ϵ_min={epsilon_minimum}," \
+ (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else '') \
+ (f"ϵ_expdecay={epsilon_exponential_decay}," if epsilon_exponential_decay else '') \
+ learner.parameters_as_string()
plottraining = PlotTraining(title=plot_title, render_each_episode=render)
render_file_index = 1
for i_episode in range(1, episode_count + 1):
print(f" ## Episode: {i_episode}/{episode_count} '{title}' "
f"ϵ={epsilon:.4f}, "
f"{learner.parameters_as_string()}")
observation = wrapped_env.reset()
total_reward = 0.0
all_rewards = []
all_availability = []
learner.new_episode()
stats = Stats(exploit=Outcomes(reward=Breakdown(local=0, remote=0, connect=0),
noreward=Breakdown(local=0, remote=0, connect=0)),
explore=Outcomes(reward=Breakdown(local=0, remote=0, connect=0),
noreward=Breakdown(local=0, remote=0, connect=0)),
exploit_deflected_to_explore=0
)
episode_ended_at = None
sys.stdout.flush()
bar = progressbar.ProgressBar(
widgets=[
'Episode ',
f'{i_episode}',
'|Iteration ',
progressbar.Counter(),
'|',
progressbar.Variable(name='reward', width=6, precision=10),
'|',
progressbar.Timer(),
progressbar.Bar()
],
redirect_stdout=False)
for t in bar(range(1, 1 + iteration_count)):
if epsilon_exponential_decay:
epsilon = epsilon_minimum + math.exp(-1. * steps_done /
epsilon_exponential_decay) * (initial_epsilon - epsilon_minimum)
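# The decay above moves epsilon from initial_epsilon towards
# epsilon_minimum following exp(-steps_done / epsilon_exponential_decay):
# after epsilon_exponential_decay steps, the remaining gap to the
# minimum has shrunk by a factor of e (about 63% of the way down).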
steps_done += 1
x = np.random.rand()
if x <= epsilon:
action_style, gym_action, action_metadata = learner.explore(wrapped_env)
else:
action_style, gym_action, action_metadata = learner.exploit(wrapped_env, observation)
if not gym_action:
stats['exploit_deflected_to_explore'] += 1
_, gym_action, action_metadata = learner.explore(wrapped_env)
# Take the step
logging.debug(f"gym_action={gym_action}, action_metadata={action_metadata}")
observation, reward, done, info = wrapped_env.step(gym_action)
action_type = 'exploit' if action_style == 'exploit' else 'explore'
outcome = 'reward' if reward > 0 else 'noreward'
if 'local_vulnerability' in gym_action:
stats[action_type][outcome]['local'] += 1
elif 'remote_vulnerability' in gym_action:
stats[action_type][outcome]['remote'] += 1
else:
stats[action_type][outcome]['connect'] += 1
learner.on_step(wrapped_env, observation, reward, done, info, action_metadata)
assert np.shape(reward) == ()
all_rewards.append(reward)
all_availability.append(info['network_availability'])
total_reward += reward
bar.update(t, reward=total_reward)
if verbosity == Verbosity.Verbose or (verbosity == Verbosity.Normal and reward > 0):
sign = ['-', '+'][reward > 0]
print(f" {sign} t={t} {action_style} r={reward} cum_reward:{total_reward} "
f"a={action_metadata}-{gym_action} "
f"creds={len(observation['credential_cache_matrix'])} "
f" {learner.stateaction_as_string(action_metadata)}")
if i_episode == episode_count \
and render_last_episode_rewards_to is not None \
and reward > 0:
fig = cyberbattle_gym_env.render_as_fig()
fig.write_image(f"{render_last_episode_rewards_to}-e{i_episode}-{render_file_index}.png")
render_file_index += 1
learner.end_of_iteration(t, done)
if done:
episode_ended_at = t
bar.finish(dirty=True)
break
sys.stdout.flush()
loss_string = learner.loss_as_string()
if loss_string:
loss_string = "loss={loss_string}"
if episode_ended_at:
print(f" Episode {i_episode} ended at t={episode_ended_at} {loss_string}")
else:
print(f" Episode {i_episode} stopped at t={iteration_count} {loss_string}")
print_stats(stats)
all_episodes_rewards.append(all_rewards)
all_episodes_availability.append(all_availability)
length = episode_ended_at if episode_ended_at else iteration_count
learner.end_of_episode(i_episode=i_episode, t=length)
plottraining.episode_done(length)
if render:
wrapped_env.render()
if epsilon_multdecay:
epsilon = max(epsilon_minimum, epsilon * epsilon_multdecay)
wrapped_env.close()
print("simulation ended")
plottraining.plot_end()
return TrainedLearner(
all_episodes_rewards=all_episodes_rewards,
all_episodes_availability=all_episodes_availability,
learner=learner,
trained_on=cyberbattle_gym_env.name,
title=plot_title
)
def transfer_learning_evaluation(
environment_properties: EnvironmentBounds,
trained_learner: TrainedLearner,
eval_env: cyberbattle_env.CyberBattleEnv,
eval_epsilon: float,
eval_episode_count: int,
iteration_count: int,
benchmark_policy=RandomPolicy(),
benchmark_training_args=dict(title="Benchmark", epsilon=1.0)
):
"""Evaluated a trained agent on another environment of different size"""
eval_oneshot_all = epsilon_greedy_search(
eval_env,
environment_properties,
learner=trained_learner['learner'],
episode_count=eval_episode_count,  # one shot from the learnt Q-matrix
iteration_count=iteration_count,
epsilon=eval_epsilon,
render=False,
verbosity=Verbosity.Quiet,
title=f"One shot on {eval_env.name} - Trained on {trained_learner['trained_on']}"
)
eval_random = epsilon_greedy_search(
eval_env,
environment_properties,
learner=benchmark_policy,
episode_count=eval_episode_count,
iteration_count=iteration_count,
render=False,
verbosity=Verbosity.Quiet,
**benchmark_training_args
)
plot_averaged_cummulative_rewards(
all_runs=[eval_oneshot_all, eval_random],
title=f"Transfer learning {trained_learner['trained_on']}->{eval_env.name} "
f'-- max_nodes={environment_properties.maximum_node_count}, '
f'episodes={eval_episode_count},\n'
f"{trained_learner['learner'].all_parameters_as_string()}")

View file

@ -0,0 +1,2 @@
ffmpeg -y -r 2 -i chain10-e10-%d.png -vf "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" chain10-dql.gif
ffmpeg -y -r 2 -i chain10-e10-%d.png -vf "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse,crop=480:400:320:0" chain10-dql-network.gif

View file

@ -0,0 +1,215 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Random agent with credential lookup (notebook)
This notebook can be run directly from VSCode. To generate a
traditional Jupyter Notebook to open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# pylint: disable=invalid-name
# %% [markdown]
# # Chain network CyberBattle Gym played by a random agent with credential cache lookup
# %%
from cyberbattle._env import cyberbattle_env
import gym
import logging
import sys
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
import importlib
importlib.reload(tqa)
importlib.reload(dqla)
importlib.reload(learner)
importlib.reload(cyberbattle_env)
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %% [markdown]
# # Gym environment: chain-like network
# See Jupyter notebook `chainenetwork-random` for an introduction to this network environment.
cyberbattlechain_10 = gym.make(
'CyberBattleChain-v0',
size=10,
attacker_goal=cyberbattle_env.AttackerGoal(reward=4000, own_atleast_percent=1.0)
)
cyberbattlechain_10.environment
# training_env.environment.plot_environment_graph()
cyberbattlechain_10.environment.network.nodes
cyberbattlechain_10.action_space
cyberbattlechain_10.action_space.sample()
cyberbattlechain_10.observation_space.sample()
o0 = cyberbattlechain_10.reset()
o_test, r, d, i = cyberbattlechain_10.step(cyberbattlechain_10.sample_valid_action())
o0 = cyberbattlechain_10.reset()
o0.keys()
# %%
ep = w.EnvironmentBounds.of_identifiers(
maximum_node_count=22,
maximum_total_credentials=22,
identifiers=cyberbattlechain_10.identifiers
)
print(f"port_count = {ep.port_count}, property_count = {ep.property_count}")
fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])
a = w.StateAugmentation(o0)
w.Feature_discovered_ports(ep).get(a, None)
fe_example.encode_at(a, 0)
iteration_count = 9000
training_episode_count = 50
eval_episode_count = 5
# %%
random_run = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=learner.RandomPolicy(),
episode_count=10, # training_episode_count,
iteration_count=iteration_count,
epsilon=1.0,
render=False,
verbosity=Verbosity.Quiet,
title="Random"
)
# %%
credlookup_run = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=rca.CredentialCacheExploiter(),
episode_count=10,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
epsilon_exponential_decay=10000,
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="Credential lookups (ϵ-greedy)"
)
# %%
tabularq_run = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=tqa.QTabularLearner(
ep,
gamma=0.10, learning_rate=0.90, exploit_percentile=100),
render=False,
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.90,
epsilon_exponential_decay=10000,
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="Tabular Q-learning"
)
# %%
tabularq_exploit_run = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=tqa.QTabularLearner(
ep,
trained=tabularq_run['learner'],
gamma=0.0,
learning_rate=0.0,
exploit_percentile=90),
episode_count=eval_episode_count,
iteration_count=iteration_count,
epsilon=0.0,
render=False,
verbosity=Verbosity.Quiet,
title="Exploiting Q-matrix"
)
# %%
dql_run = learner.epsilon_greedy_search(
cyberbattle_gym_env=cyberbattlechain_10,
environment_properties=ep,
learner=dqla.DeepQLearnerPolicy(
ep=ep,
gamma=0.015,
replay_memory_size=10000,
target_update=10,
batch_size=512,
learning_rate=0.01
),
episode_count=15,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
epsilon_exponential_decay=5000,
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="DQL"
)
# %%
dql_exploit_run = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=dql_run['learner'],
episode_count=50,
iteration_count=iteration_count,
epsilon=0.00,
epsilon_minimum=0.00,
render=False,
verbosity=Verbosity.Quiet,
title="Exploiting DQL"
)
# %%
all_runs = [
random_run,
credlookup_run,
tabularq_run,
tabularq_exploit_run,
dql_run,
dql_exploit_run
]
p.plot_episodes_length(all_runs)
p.plot_averaged_cummulative_rewards(
title=f'Agent Benchmark\n'
f'max_nodes:{ep.maximum_node_count}\n',
all_runs=all_runs)
# %%
contenders = [
credlookup_run,
tabularq_run,
dql_run,
dql_exploit_run
]
p.plot_episodes_length(contenders)
p.plot_averaged_cummulative_rewards(
title=f'Agent Benchmark top contenders\n'
f'max_nodes:{ep.maximum_node_count}\n',
all_runs=contenders)
# %%
for r in contenders:
p.plot_all_episodes(r)
# %%

View file

@ -0,0 +1,116 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tabular Q-learning agent (notebook)
This notebook can be run directly from VSCode. To generate a
traditional Jupyter Notebook to open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# pylint: disable=invalid-name
# %%
import sys
import logging
import gym
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import importlib
importlib.reload(learner)
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %%
ctf_env = gym.make('CyberBattleToyCtf-v0')
ep = w.EnvironmentBounds.of_identifiers(
maximum_node_count=22,
maximum_total_credentials=22,
identifiers=ctf_env.identifiers
)
iteration_count = 2000
training_episode_count = 10
eval_episode_count = 10
# %%
# Run Deep Q-learning
# 0.015
best_dqn_learning_run_10 = learner.epsilon_greedy_search(
cyberbattle_gym_env=ctf_env,
environment_properties=ep,
learner=dqla.DeepQLearnerPolicy(
ep=ep,
gamma=0.015,
replay_memory_size=10000,
target_update=10,
batch_size=512,
learning_rate=0.01), # torch default is 1e-2
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
# epsilon_multdecay=0.75, # 0.999,
epsilon_exponential_decay=5000, # 10000
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="DQL"
)
# %% Plot episode length
p.plot_episodes_length([best_dqn_learning_run_10])
# %%
dql_exploit_run = learner.epsilon_greedy_search(
ctf_env,
ep,
learner=best_dqn_learning_run_10['learner'],
episode_count=eval_episode_count,
iteration_count=iteration_count,
epsilon=0.0, # 0.35,
render=False,
title="Exploiting DQL",
verbosity=Verbosity.Quiet
)
# %%
random_run = learner.epsilon_greedy_search(
ctf_env,
ep,
learner=learner.RandomPolicy(),
episode_count=eval_episode_count,
iteration_count=iteration_count,
epsilon=1.0, # purely random
render=False,
verbosity=Verbosity.Quiet,
title="Random search"
)
# %%
# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
all_runs=[
best_dqn_learning_run_10,
random_run,
dql_exploit_run
],
title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n'
f'State: {[f.name() for f in themodel.state_space.feature_selection]} '
f'({len(themodel.state_space.feature_selection)})\n'
f"Action: abstract_action ({themodel.action_space.flat_size()})")
# %%
# plot cumulative rewards for all episodes
p.plot_all_episodes(best_dqn_learning_run_10)

View file

@ -0,0 +1,211 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# -*- coding: utf-8 -*-
# %%
"""Deep Q-learning agent (notebook)
This notebook can be run directly from VSCode. To generate a
traditional Jupyter Notebook to open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# %%
import os
import sys
import logging
import gym
import torch
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
import importlib
import cyberbattle._env.cyberbattle_env as cyberbattle_env
import cyberbattle._env.cyberbattle_chain as cyberbattle_chain
importlib.reload(learner)
importlib.reload(cyberbattle_env)
importlib.reload(cyberbattle_chain)
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %%
torch.cuda.is_available()
# %%
# To run once
# import plotly.io as pio
# pio.orca.config.use_xvfb = True
# pio.orca.config.save()
# %%
cyberbattlechain_4 = gym.make('CyberBattleChain-v0', size=4, attacker_goal=cyberbattle_env.AttackerGoal(reward=2180))
cyberbattlechain_10 = gym.make('CyberBattleChain-v0', size=10, attacker_goal=cyberbattle_env.AttackerGoal(reward=4000))
cyberbattlechain_20 = gym.make('CyberBattleChain-v0', size=20, attacker_goal=cyberbattle_env.AttackerGoal(reward=7000))
ep = w.EnvironmentBounds.of_identifiers(
maximum_total_credentials=22,
maximum_node_count=22,
identifiers=cyberbattlechain_10.identifiers
)
iteration_count = 9000
training_episode_count = 50
eval_episode_count = 10
# %%
# Run Deep Q-learning
# 0.015
best_dqn_learning_run_10 = learner.epsilon_greedy_search(
cyberbattle_gym_env=cyberbattlechain_10,
environment_properties=ep,
learner=dqla.DeepQLearnerPolicy(
ep=ep,
gamma=0.015,
replay_memory_size=10000,
target_update=10,
batch_size=512,
learning_rate=0.01), # torch default is 1e-2
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
# epsilon_multdecay=0.75, # 0.999,
epsilon_exponential_decay=5000, # 10000
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="DQL"
)
# %% Plot episode length
p.plot_episodes_length([best_dqn_learning_run_10])
# %%
if not os.path.exists("images"):
os.mkdir("images")
# %%
dql_exploit_run = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=best_dqn_learning_run_10['learner'],
episode_count=eval_episode_count,
iteration_count=iteration_count,
epsilon=0.0, # 0.35,
render=False,
render_last_episode_rewards_to='images/chain10',
title="Exploiting DQL",
verbosity=Verbosity.Quiet
)
# %%
random_run = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=learner.RandomPolicy(),
episode_count=eval_episode_count,
iteration_count=iteration_count,
epsilon=1.0, # purely random
render=False,
verbosity=Verbosity.Quiet,
title="Random search"
)
# %%
# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
all_runs=[
best_dqn_learning_run_10,
random_run,
dql_exploit_run
],
title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n'
f'State: {[f.name() for f in themodel.state_space.feature_selection]} '
f'({len(themodel.state_space.feature_selection)})\n'
f"Action: abstract_action ({themodel.action_space.flat_size()})")
# %%
# plot cumulative rewards for all episodes
p.plot_all_episodes(best_dqn_learning_run_10)
##################################################
# %%
# %%
best_dqn_4 = learner.epsilon_greedy_search(
cyberbattle_gym_env=cyberbattlechain_4,
environment_properties=ep,
learner=dqla.DeepQLearnerPolicy(
ep=ep,
gamma=0.15,
replay_memory_size=10000,
target_update=5,
batch_size=256,
learning_rate=0.01),
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
epsilon_exponential_decay=5000,
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="DQL"
)
# %%
learner.transfer_learning_evaluation(
environment_properties=ep,
trained_learner=best_dqn_learning_run_10,
eval_env=cyberbattlechain_20,
eval_epsilon=0.0,  # exploit the learnt policy only (no exploration)
eval_episode_count=eval_episode_count,
iteration_count=iteration_count,
benchmark_policy=rca.CredentialCacheExploiter(),
benchmark_training_args={'epsilon': 0.90,
'epsilon_exponential_decay': 10000,
'epsilon_minimum': 0.10,
'title': 'Credential lookups (ϵ-greedy)'}
)
# %%
learner.transfer_learning_evaluation(
environment_properties=ep,
trained_learner=best_dqn_4,
eval_env=cyberbattlechain_10,
eval_epsilon=0.0, # exploit Q-matrix only
eval_episode_count=eval_episode_count,
iteration_count=iteration_count,
benchmark_policy=rca.CredentialCacheExploiter(),
benchmark_training_args={'epsilon': 0.90,
'epsilon_exponential_decay': 10000,
'epsilon_minimum': 0.10,
'title': 'Credential lookups (ϵ-greedy)'}
)
# %%
learner.transfer_learning_evaluation(
environment_properties=ep,
trained_learner=best_dqn_4,
eval_env=cyberbattlechain_20,
eval_epsilon=0.0, # exploit Q-matrix only
eval_episode_count=eval_episode_count,
iteration_count=iteration_count,
benchmark_policy=rca.CredentialCacheExploiter(),
benchmark_training_args={'epsilon': 0.90,
'epsilon_exponential_decay': 10000,
'epsilon_minimum': 0.10,
'title': 'Credential lookups (ϵ-greedy)'}
)
# %%

View file

@ -0,0 +1,85 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Random exploration with credential lookup exploitation (notebook)
This notebook can be run directly from VSCode. To generate a
traditional Jupyter Notebook to open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# pylint: disable=invalid-name
# %%
from cyberbattle._env.cyberbattle_env import AttackerGoal
from cyberbattle.agents.baseline.agent_randomcredlookup import CredentialCacheExploiter
import cyberbattle.agents.baseline.learner as learner
import gym
import logging
import sys
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
# %%
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %%
cyberbattlechain_10 = gym.make('CyberBattleChain-v0', size=10,
attacker_goal=AttackerGoal(own_atleast_percent=1.0))
# %%
ep = w.EnvironmentBounds.of_identifiers(
maximum_total_credentials=12,
maximum_node_count=12,
identifiers=cyberbattlechain_10.identifiers
)
iteration_count = 9000
training_episode_count = 50
eval_episode_count = 5
# %%
credexplot = learner.epsilon_greedy_search(
cyberbattlechain_10,
learner=CredentialCacheExploiter(),
environment_properties=ep,
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
epsilon_multdecay=0.75, # 0.999,
epsilon_minimum=0.01,
verbosity=Verbosity.Quiet,
title="Random+CredLookup"
)
# %%
randomlearning_results = learner.epsilon_greedy_search(
cyberbattlechain_10,
environment_properties=ep,
learner=CredentialCacheExploiter(),
episode_count=eval_episode_count,
iteration_count=iteration_count,
epsilon=1.0, # purely random
render=False,
verbosity=Verbosity.Quiet,
title="Random search"
)
# %%
p.plot_episodes_length([credexplot])
p.plot_all_episodes(credexplot)
all_runs = [credexplot,
randomlearning_results
]
p.plot_averaged_cummulative_rewards(
title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n',
all_runs=all_runs)
# %%

View file

@ -0,0 +1,226 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tabular Q-learning agent (notebook)
This notebook can be run directly from VSCode. To generate a
traditional Jupyter Notebook to open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# pylint: disable=invalid-name
# %%
import sys
import logging
from typing import cast
import gym
import numpy as np
import matplotlib.pyplot as plt
from cyberbattle.agents.baseline.learner import TrainedLearner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_tabularqlearning as a
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.learner as learner
from cyberbattle._env.cyberbattle_env import AttackerGoal
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %%
# Benchmark parameters:
# Parameters from DeepDoubleQ paper
# - learning_rate = 0.00025
# - linear epsilon decay
# - gamma = 0.99
# Eliminated gamma_values
# 0.0,
# 0.0015, # too small
# 0.15, # too big
# 0.25, # too big
# 0.35, # too big
#
# NOTE: Given the relatively low number of training episodes (50),
# a high learning rate of 0.99 gives better results
# than a lower learning rate of 0.25 (i.e. maximal rewards reached faster on average).
# Ideally we should decay the learning rate just like gamma and train over a
# much larger number of episodes
cyberbattlechain_10 = gym.make('CyberBattleChain-v0', size=10, attacker_goal=AttackerGoal(own_atleast_percent=1.0))
ep = w.EnvironmentBounds.of_identifiers(
maximum_node_count=12,
maximum_total_credentials=12,
identifiers=cyberbattlechain_10.identifiers
)
iteration_count = 9000
training_episode_count = 5
eval_episode_count = 5
gamma_sweep = [
0.015, # about right
]
def qlearning_run(gamma, gym_env):
"""Execute one run of the q-learning algorithm for the
specified gamma value"""
return learner.epsilon_greedy_search(
gym_env,
ep,
a.QTabularLearner(ep, gamma=gamma, learning_rate=0.90, exploit_percentile=100),
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
epsilon_multdecay=0.75, # 0.999,
epsilon_minimum=0.01,
verbosity=Verbosity.Quiet,
title="Q-learning"
)
# %%
# Run Q-learning with gamma-sweep
qlearning_results = [qlearning_run(gamma, cyberbattlechain_10) for gamma in gamma_sweep]
qlearning_bestrun_10 = qlearning_results[0]
# %%
p.new_plot_loss()
for results in qlearning_results:
p.plot_all_episodes_loss(cast(a.QTabularLearner, results['learner']).loss_qsource.all_episodes, 'Q_source', results['title'])
p.plot_all_episodes_loss(cast(a.QTabularLearner, results['learner']).loss_qattack.all_episodes, 'Q_attack', results['title'])
plt.legend(loc="upper right")
plt.show()
# %% Plot episode length
p.plot_episodes_length(qlearning_results)
# %%
nolearning_results = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=a.QTabularLearner(ep, trained=qlearning_bestrun_10['learner'],
gamma=0.0, learning_rate=0.0, exploit_percentile=100),
episode_count=eval_episode_count,
iteration_count=iteration_count,
epsilon=0.30, # 0.35,
render=False,
title="Exploiting Q-matrix",
verbosity=Verbosity.Quiet
)
# %%
randomlearning_results = learner.epsilon_greedy_search(
cyberbattlechain_10,
ep,
learner=a.QTabularLearner(ep, trained=qlearning_bestrun_10['learner'],
gamma=0.0, learning_rate=0.0, exploit_percentile=100),
episode_count=eval_episode_count,
iteration_count=iteration_count,
epsilon=1.0, # purely random
render=False,
verbosity=Verbosity.Quiet,
title="Random search"
)
# %%
# Plot averaged cumulative rewards for Q-learning vs Random vs Q-Exploit
all_runs = [*qlearning_results,
randomlearning_results,
nolearning_results
]
Q_source_10 = cast(a.QTabularLearner, qlearning_bestrun_10['learner']).qsource
Q_attack_10 = cast(a.QTabularLearner, qlearning_bestrun_10['learner']).qattack
p.plot_averaged_cummulative_rewards(
all_runs=all_runs,
title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n'
f'dimension={Q_source_10.state_space.flat_size()}x{Q_source_10.action_space.flat_size()}, '
f'{Q_attack_10.state_space.flat_size()}x{Q_attack_10.action_space.flat_size()}\n'
f'Q1={[f.name() for f in Q_source_10.state_space.feature_selection]} '
f'-> {[f.name() for f in Q_source_10.action_space.feature_selection]})\n'
f"Q2={[f.name() for f in Q_attack_10.state_space.feature_selection]} -> 'action'")
# %%
# plot cumulative rewards for all episodes
p.plot_all_episodes(qlearning_results[0])
# %%
# Plot the Q-matrices
# %%
# Print non-zero coordinate in the Q matrix Q_source
i = np.where(Q_source_10.qm)
q = Q_source_10.qm[i]
list(zip(np.array([Q_source_10.state_space.pretty_print(i) for i in i[0]]),
np.array([Q_source_10.action_space.pretty_print(i) for i in i[1]]), q))
# %%
# Print non-zero coordinate in the Q matrix Q_attack
i2 = np.where(Q_attack_10.qm)
q2 = Q_attack_10.qm[i2]
list(zip([Q_attack_10.state_space.pretty_print(i) for i in i2[0]],
[Q_attack_10.action_space.pretty_print(i) for i in i2[1]], q2))
##################################################
# %% [markdown]
# ## Transfer learning from size 4 to size 10
# Exploiting Q-matrix learned from a different network.
# %%
# Train Q-matrix on CyberBattle network of size 4
cyberbattlechain_4 = gym.make('CyberBattleChain-v0', size=4,
attacker_goal=AttackerGoal(own_atleast_percent=1.0)
)
qlearning_bestrun_4 = qlearning_run(0.015, gym_env=cyberbattlechain_4)
def stop_learning(trained_learner):
return TrainedLearner(
learner=a.QTabularLearner(
ep,
gamma=0.0,
learning_rate=0.0,
exploit_percentile=0,
trained=trained_learner['learner']
),
title=trained_learner['title'],
trained_on=trained_learner['trained_on'],
all_episodes_rewards=trained_learner['all_episodes_rewards'],
all_episodes_availability=trained_learner['all_episodes_availability']
)
learner.transfer_learning_evaluation(
environment_properties=ep,
trained_learner=stop_learning(qlearning_bestrun_4),
eval_env=cyberbattlechain_10,
eval_epsilon=0.5, # alternate with exploration to help generalization to bigger network
eval_episode_count=eval_episode_count,
iteration_count=iteration_count
)
learner.transfer_learning_evaluation(
environment_properties=ep,
trained_learner=stop_learning(qlearning_bestrun_10),
eval_env=cyberbattlechain_4,
eval_epsilon=0.5,
eval_episode_count=eval_episode_count,
iteration_count=iteration_count
)
# %%

View file

@ -0,0 +1,131 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Attacker agent benchmark comparison in presence of a basic defender
This notebook can be run directly from VSCode. To generate a
traditional Jupyter Notebook to open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# %%
import sys
import logging
import gym
import importlib
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
from cyberbattle._env.defender import ScanAndReimageCompromisedMachines
from cyberbattle._env.cyberbattle_env import AttackerGoal, DefenderConstraint
importlib.reload(learner)
importlib.reload(p)
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
cyberbattlechain_defender = gym.make('CyberBattleChain-v0',
size=10,
attacker_goal=AttackerGoal(
# reward=2180,
own_atleast=0,
own_atleast_percent=1.0
),
defender_constraint=DefenderConstraint(
maintain_sla=0.80
),
defender_agent=ScanAndReimageCompromisedMachines(
probability=0.6,
scan_capacity=2,
scan_frequency=5))
ep = w.EnvironmentBounds.of_identifiers(
maximum_total_credentials=22,
maximum_node_count=22,
identifiers=cyberbattlechain_defender.identifiers
)
iteration_count = 600
training_episode_count = 10
# %%
dqn_with_defender = learner.epsilon_greedy_search(
cyberbattle_gym_env=cyberbattlechain_defender,
environment_properties=ep,
learner=dqla.DeepQLearnerPolicy(
ep=ep,
gamma=0.15,
replay_memory_size=10000,
target_update=5,
batch_size=256,
learning_rate=0.01),
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
epsilon_exponential_decay=5000,
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="DQL"
)
# %%
dql_exploit_run = learner.epsilon_greedy_search(
cyberbattlechain_defender,
ep,
learner=dqn_with_defender['learner'],
episode_count=training_episode_count,
iteration_count=iteration_count,
epsilon=0.0, # 0.35,
render=False,
# render_last_episode_rewards_to='images/chain10',
verbosity=Verbosity.Quiet,
title="Exploiting DQL"
)
# %%
credlookup_run = learner.epsilon_greedy_search(
cyberbattlechain_defender,
ep,
learner=rca.CredentialCacheExploiter(),
episode_count=10,
iteration_count=iteration_count,
epsilon=0.90,
render=False,
epsilon_exponential_decay=10000,
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="Credential lookups (ϵ-greedy)"
)
# %%
# Plots
all_runs = [
credlookup_run,
dqn_with_defender,
dql_exploit_run
]
p.plot_averaged_cummulative_rewards(
all_runs=all_runs,
title=f'Attacker agents vs Basic Defender -- rewards\n env={cyberbattlechain_defender.name}, episodes={training_episode_count}'
)
# p.plot_episodes_length(all_runs)
p.plot_averaged_availability(title=f"Attacker agents vs Basic Defender -- availability\n env={cyberbattlechain_defender.name}, episodes={training_episode_count}", all_runs=all_runs)
# %%

View file

@ -0,0 +1,203 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Plotting helpers for agent banchmarking"""
import matplotlib.pyplot as plt # type:ignore
import numpy as np
def new_plot(title):
"""Prepare a new plot of cumulative rewards"""
plt.figure(figsize=(10, 8))
plt.ylabel('cumulative reward', fontsize=20)
plt.xlabel('step', fontsize=20)
plt.xticks(size=20)
plt.yticks(size=20)
plt.title(title, fontsize=12)
def pad(array, length):
"""Pad an array with 0s to make it of desired length"""
padding = np.zeros((length,))
padding[:len(array)] = array
return padding
def plot_episodes_rewards_averaged(results):
"""Plot cumulative rewards for a given set of specified episodes"""
max_iteration_count = np.max([len(r) for r in results['all_episodes_rewards']])
all_episodes_rewards_padded = [pad(rewards, max_iteration_count) for rewards in results['all_episodes_rewards']]
cumrewards = np.cumsum(all_episodes_rewards_padded, axis=1)
avg = np.average(cumrewards, axis=0)
std = np.std(cumrewards, axis=0)
x = [i for i in range(len(std))]
plt.plot(x, avg, label=results['title'])
plt.fill_between(x, avg - std, avg + std, alpha=0.5)
def fill_with_latest_value(array, length):
pad = length - len(array)
if pad > 0:
return np.pad(array, (0, pad), mode='edge')
else:
return array
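# Example (illustrative): fill_with_latest_value([0.9, 0.8], 4) returns
# [0.9, 0.8, 0.8, 0.8] -- np.pad with mode='edge' repeats the last value.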
def plot_episodes_availability_averaged(results):
"""Plot availability for a given set of specified episodes"""
data = results['all_episodes_availability']
longest_episode_length = np.max([len(r) for r in data])
all_episodes_padded = [fill_with_latest_value(av, longest_episode_length) for av in data]
avg = np.average(all_episodes_padded, axis=0)
std = np.std(all_episodes_padded, axis=0)
x = [i for i in range(len(std))]
plt.plot(x, avg, label=results['title'])
plt.fill_between(x, avg - std, avg + std, alpha=0.5)
def plot_episodes_length(learning_results):
"""Plot length of every episode"""
plt.figure(figsize=(10, 8))
plt.ylabel('#iterations', fontsize=20)
plt.xlabel('episode', fontsize=20)
plt.xticks(size=20)
plt.yticks(size=20)
plt.title("Length of each episode", fontsize=12)
for results in learning_results:
iterations = [len(e) for e in results['all_episodes_rewards']]
episode = [i for i in range(len(results['all_episodes_rewards']))]
plt.plot(episode, iterations, label=f"{results['title']}")
plt.legend(loc="upper right")
plt.show()
def plot_each_episode(results):
"""Plot cumulative rewards for each episode"""
for i, episode in enumerate(results['all_episodes_rewards']):
cumrewards = np.cumsum(episode)
x = [i for i in range(len(cumrewards))]
plt.plot(x, cumrewards, label=f'Episode {i}')
def plot_all_episodes(r):
"""Plot cumulative rewards for every episode"""
new_plot(r['title'])
plot_each_episode(r)
plt.legend(loc="lower right")
plt.show()
def plot_averaged_cummulative_rewards(title, all_runs):
"""Plot averaged cumulative rewards"""
new_plot(title)
for r in all_runs:
plot_episodes_rewards_averaged(r)
plt.legend(loc="lower right")
plt.show()
def plot_averaged_availability(title, all_runs):
"""Plot averaged network availability"""
plt.figure(figsize=(10, 8))
plt.ylabel('network availability', fontsize=20)
plt.xlabel('step', fontsize=20)
plt.xticks(size=20)
plt.yticks(size=20)
plt.title(title, fontsize=12)
for r in all_runs:
plot_episodes_availability_averaged(r)
plt.legend(loc="lower right")
plt.show()
def new_plot_loss():
"""Plot MSE loss averaged over all episodes"""
plt.figure(figsize=(10, 8))
plt.ylabel('loss', fontsize=20)
plt.xlabel('episodes', fontsize=20)
plt.xticks(size=12)
plt.yticks(size=20)
plt.title("Loss", fontsize=12)
def plot_all_episodes_loss(all_episodes_losses, name, label):
"""Plot loss for one learning episode"""
x = [i for i in range(len(all_episodes_losses))]
plt.plot(x, all_episodes_losses, label=f'{name} {label}')
def running_mean(x, size):
"""return moving average of x for a window of lenght 'size'"""
cumsum = np.cumsum(np.insert(x, 0, 0))
return (cumsum[size:] - cumsum[:-size]) / float(size)
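# Example (illustrative): running_mean(np.array([1, 2, 3, 4, 5]), 3)
# returns array([2., 3., 4.]) -- the average of each window of 3 values.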
class PlotTraining:
"""Plot training-related stats"""
def __init__(self, title, render_each_episode):
self.episode_durations = []
self.title = title
self.render_each_episode = render_each_episode
def plot_durations(self, average_window=5):
# plt.figure(2)
plt.figure()
# plt.clf()
durations_t = np.array(self.episode_durations, dtype=np.float32)
plt.title('Training...')
plt.xlabel('Episode')
plt.ylabel('Duration')
plt.title(self.title, fontsize=12)
episodes = [i + 1 for i in range(len(self.episode_durations))]
plt.plot(episodes, durations_t)
# plot episode running averages
if len(durations_t) >= average_window:
means = running_mean(durations_t, average_window)
means = np.concatenate((np.zeros(average_window - 1), means))
plt.plot(episodes, means)
# display.display(plt.gcf())
plt.show()
def episode_done(self, length):
self.episode_durations.append(length)
if self.render_each_episode:
self.plot_durations()
def plot_end(self):
self.plot_durations()
plt.ioff() # type: ignore
# plt.show()
def length_of_all_episodes(run):
"""Get the length of every episode"""
return [len(e) for e in run['all_episodes_rewards']]
def reduce(x, desired_width):
return [np.average(c) for c in np.array_split(x, desired_width)]
def episodes_rewards_averaged(run):
"""Plot cumulative rewards for a given set of specified episodes"""
max_iteration_count = np.max([len(r) for r in run['all_episodes_rewards']])
all_episodes_rewards_padded = [pad(rewards, max_iteration_count) for rewards in run['all_episodes_rewards']]
cumrewards = np.cumsum(all_episodes_rewards_padded, axis=1)
avg = np.average(cumrewards, axis=0)
return list(avg)
def episodes_lengths_for_all_runs(all_runs):
return [length_of_all_episodes(run) for run in all_runs]
def averaged_cummulative_rewards(all_runs, width):
return [reduce(episodes_rewards_averaged(run), width) for run in all_runs]
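# Minimal usage sketch for the run-summary helpers above, assuming a "run" is
# a dict with a 'title' and an 'all_episodes_rewards' list of per-step reward
# lists (as produced by the learner), and that numpy is imported as np:
if __name__ == '__main__':
    demo_run = {'title': 'demo', 'all_episodes_rewards': [[1, 0, 2], [0, 3, 1]]}
    print(length_of_all_episodes(demo_run))       # [3, 3]
    print(reduce([1, 2, 3, 4], desired_width=2))  # [1.5, 3.5]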


@ -0,0 +1,121 @@
#!/usr/bin/python3.8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# -*- coding: utf-8 -*-
"""CLI to run the baseline Deep Q-learning and Random agents
on a sample CyberBattle gym environment and plot the respective
cumulative rewards in the terminal.
Example usage:
python3.8 -m run --training_episode_count 50 --iteration_count 9000 --rewardplot_with 80 --chain_size=20 --ownership_goal 1.0
"""
import torch
import gym
import logging
import sys
import asciichartpy
import argparse
import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.learner as learner
parser = argparse.ArgumentParser(description='Run simulation with DQL baseline agent.')
parser.add_argument('--training_episode_count', default=50, type=int,
help='number of training episodes')
parser.add_argument('--eval_episode_count', default=10, type=int,
help='number of evaluation episodes')
parser.add_argument('--iteration_count', default=9000, type=int,
help='number of simulation iterations for each episode')
parser.add_argument('--reward_goal', default=2180, type=int,
help='minimum cumulative reward the attacker must reach to achieve its goal')
parser.add_argument('--ownership_goal', default=1.0, type=float,
help='percentage of network nodes to own for the attacker to reach its goal')
parser.add_argument('--rewardplot_with', default=80, type=int,
help='width of the reward plot (values are averaged across iterations to fit in the desired width)')
parser.add_argument('--chain_size', default=4, type=int,
help='size of the chain of the CyberBattleChain sample environment')
parser.add_argument('--random_agent', dest='run_random_agent', action='store_true', help='run the random agent as a baseline for comparison')
parser.add_argument('--no-random_agent', dest='run_random_agent', action='store_false', help='do not run the random agent as a baseline for comparison')
parser.set_defaults(run_random_agent=True)
args = parser.parse_args()
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
print(f"torch cuda available={torch.cuda.is_available()}")
cyberbattlechain = gym.make('CyberBattleChain-v0',
size=args.chain_size,
attacker_goal=cyberbattle_env.AttackerGoal(
own_atleast_percent=args.ownership_goal,
reward=args.reward_goal))
ep = w.EnvironmentBounds.of_identifiers(
maximum_total_credentials=22,
maximum_node_count=22,
identifiers=cyberbattlechain.identifiers
)
all_runs = []
# Run Deep Q-learning
dqn_learning_run = learner.epsilon_greedy_search(
cyberbattle_gym_env=cyberbattlechain,
environment_properties=ep,
learner=dqla.DeepQLearnerPolicy(
ep=ep,
gamma=0.015,
replay_memory_size=10000,
target_update=10,
batch_size=512,
learning_rate=0.01), # torch default is 1e-2
episode_count=args.training_episode_count,
iteration_count=args.iteration_count,
epsilon=0.90,
render=False,
# epsilon_multdecay=0.75, # 0.999,
epsilon_exponential_decay=5000, # 10000
epsilon_minimum=0.10,
verbosity=Verbosity.Quiet,
title="DQL"
)
all_runs.append(dqn_learning_run)
if args.run_random_agent:
random_run = learner.epsilon_greedy_search(
cyberbattlechain,
ep,
learner=learner.RandomPolicy(),
episode_count=args.eval_episode_count,
iteration_count=args.iteration_count,
epsilon=1.0, # purely random
render=False,
verbosity=Verbosity.Quiet,
title="Random search"
)
all_runs.append(random_run)
colors = [asciichartpy.red, asciichartpy.green, asciichartpy.yellow, asciichartpy.blue]
print("Episode duration -- DQN=Red, Random=Green")
print(asciichartpy.plot(p.episodes_lengths_for_all_runs(all_runs), {'height': 30, 'colors': colors}))
print("Cumulative rewards -- DQN=Red, Random=Green")
c = p.averaged_cummulative_rewards(all_runs, args.rewardplot_with)
print(asciichartpy.plot(c, {'height': 10, 'colors': colors}))


@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Helper to run the random agent from a jupyter notebook"""
import cyberbattle._env.cyberbattle_env as cyberbattle_env
import logging
LOGGER = logging.getLogger(__name__)
def run_random_agent(episode_count: int, iteration_count: int, gym_env: cyberbattle_env.CyberBattleEnv):
"""Run a simple random agent on the specified gym environment and
plot exploration graph and reward function
"""
for i_episode in range(episode_count):
observation = gym_env.reset()
total_reward = 0.0
for t in range(iteration_count):
action = gym_env.sample_valid_action()
LOGGER.debug(f"action={action}")
observation, reward, done, info = gym_env.step(action)
total_reward += reward
if reward > 0:
print(f'+ rewarded action: {action} total_reward={total_reward} reward={reward} @t={t}')
gym_env.render()
if done:
print(f"Episode finished after {t+1} timesteps")
break
gym_env.render()
gym_env.close()
print("simulation ended")

Просмотреть файл

@ -0,0 +1,286 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Defines a set of networks following a speficic pattern
learnable from the properties associated with the nodes.
The network pattern is:
Start ---> (Linux ---> Windows ---> ... Linux ---> Windows)* ---> Linux[Flag]
The network is parameterized by the length of the central Linux-Windows chain.
The start node leaks the credentials to connect to all other nodes:
For each `XXX ---> Windows` section, the XXX node has:
- a local vulnerability exposing the RDP password to the Windows machine
- a bunch of other trap vulnerabilities (high cost with no outcome)
For each `XXX ---> Linux` section,
- the Windows node has a local vulnerability exposing the SSH password to the Linux machine
- a bunch of other trap vulnerabilities (high cost with no outcome)
The chain is terminated by one node with a flag (reward).
A Node-Property matrix would be three-valued (0,1,?) and look like this:
===== Initial state
Properties
Nodes L W SQL
1 1 0 0
2 ? ? ?
3 ? ? ?
...
10
======= After discovering node 2
Properties
Nodes L W SQL
1 1 0 0
2 0 1 1
3 ? ? ?
...
10
===========================
"""
from cyberbattle.simulation.model import Identifiers, NodeID, NodeInfo
from ...simulation import model as m
from typing import Dict
DEFAULT_ALLOW_RULES = [
m.FirewallRule("RDP", m.RulePermission.ALLOW),
m.FirewallRule("SSH", m.RulePermission.ALLOW),
m.FirewallRule("HTTPS", m.RulePermission.ALLOW),
m.FirewallRule("HTTP", m.RulePermission.ALLOW)]
# Environment constants used for all instances of the chain network
ENV_IDENTIFIERS = Identifiers(
properties=[
'Windows',
'Linux',
'ApacheWebSite',
'IIS_2019',
'IIS_2020_patched',
'MySql',
'Ubuntu',
'nginx/1.10.3',
'SMB_vuln',
'SMB_vuln_patched',
'SQLServer',
'Win10',
'Win10Patched',
'FLAG:Linux'
],
ports=[
'HTTPS',
'GIT',
'SSH',
'RDP',
'PING',
'MySQL',
'SSH-key',
'su'
],
local_vulnerabilities=[
'ScanBashHistory',
'ScanExplorerRecentFiles',
'SudoAttempt',
'CrackKeepPassX',
'CrackKeepPass'
],
remote_vulnerabilities=[
'ProbeLinux',
'ProbeWindows'
]
)
def prefix(x: int, name: str):
"""Prefix node name with an instance"""
return f"{x}_{name}"
def rdp_password(index):
"""Generate RDP password for the specified chain link"""
return f"WindowsPassword!{index}"
def ssh_password(index):
"""Generate SSH password for the specified chain link"""
return f"LinuxPassword!{index}"
def create_network_chain_link(n: int) -> Dict[NodeID, NodeInfo]:
"""Instantiate one link of the network chain with associated index n"""
def current(name):
return prefix(n, name)
def next(name):
return prefix(n + 1, name)
return {
current("LinuxNode"): m.NodeInfo(
services=[m.ListeningService("HTTPS"),
m.ListeningService("SSH", allowedCredentials=[ssh_password(n)])],
firewall=m.FirewallConfiguration(incoming=DEFAULT_ALLOW_RULES,
outgoing=DEFAULT_ALLOW_RULES),
value=100,
properties=["MySql", "Ubuntu", "nginx/1.10.3"],
owned_string="Intermediate chain node owned, no intinsic value",
vulnerabilities=dict(
ProbeLinux=m.VulnerabilityInfo(
description="Probe to check if the node runs Linux",
type=m.VulnerabilityType.REMOTE,
outcome=m.ProbeSucceeded(["Ubuntu"]),
reward_string="Remote machine is running Linux",
cost=5.0
),
ProbeWindows=m.VulnerabilityInfo(
description="Probe to check if the node runs Windows",
type=m.VulnerabilityType.REMOTE,
outcome=m.ProbeFailed(),
reward_string="Remote machine is not running Windows",
cost=5.0
),
ScanBashHistory=m.VulnerabilityInfo(
description="Scan bash history for possible references to other machines",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedNodesId([next("WindowsNode")]),
reward_string="Found a reference to a remote Windows node in bash history",
cost=1.0
),
ScanExplorerRecentFiles=m.VulnerabilityInfo(
description="Scan Windows Explorer recent files for possible references to other machines",
type=m.VulnerabilityType.LOCAL,
outcome=m.ExploitFailed(),
reward_string="Trap: feature not supported on Linux",
cost=10.0
),
SudoAttempt=m.VulnerabilityInfo(
description="Attempt to sudo into admin user",
type=m.VulnerabilityType.LOCAL,
outcome=m.ExploitFailed(),
reward_string="Trap: suspicious attempt to run sudo",
cost=100.0
),
CrackKeepPassX=m.VulnerabilityInfo(
description="Attempt to crack KeepPassX and look for credentials",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node=next("WindowsNode"), port="RDP",
credential=rdp_password(n + 1))]),
reward_string=f"Discovered password to Windows machine {n+1}",
cost=1.0
))),
next("WindowsNode"): m.NodeInfo(
services=[m.ListeningService("HTTPS"),
m.ListeningService("RDP", allowedCredentials=[rdp_password(n + 1)])],
value=100,
properties=["Windows", "Win10", "Win10Patched"],
vulnerabilities=dict(
ProbeLinux=m.VulnerabilityInfo(
description="Probe to check if the node runs Linux",
type=m.VulnerabilityType.REMOTE,
outcome=m.ProbeFailed(),
reward_string="Remote machine is not running Linux",
cost=1.0
),
ProbeWindows=m.VulnerabilityInfo(
description="Probe to check if the node runs Windows",
type=m.VulnerabilityType.REMOTE,
outcome=m.ProbeSucceeded(["Windows"]),
reward_string="Remote machine is running Windows",
cost=1.0
),
ScanBashHistory=m.VulnerabilityInfo(
description="Scan bash history for possible references to other machines",
type=m.VulnerabilityType.LOCAL,
outcome=m.ExploitFailed(),
reward_string="Trap: feature not supported on Windows!",
cost=100.0
),
ScanExplorerRecentFiles=m.VulnerabilityInfo(
description="Scan Windows Explorer recent files for possible references to other machines",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedNodesId([prefix(n + 2, "LinuxNode")]),
reward_string="Found a reference to a remote Linux node in bash history",
cost=1.0
),
SudoAttempt=m.VulnerabilityInfo(
description="Attempt to sudo into admin user",
type=m.VulnerabilityType.LOCAL,
outcome=m.ExploitFailed(),
reward_string="Trap: feature not supported on Windows!",
cost=100.0
),
CrackKeepPassX=m.VulnerabilityInfo(
description="Attempt to crack KeepPassX and look for credentials",
type=m.VulnerabilityType.LOCAL,
outcome=m.ExploitFailed(),
reward_string="Trap: feature not supported on Windows!",
cost=100.0
),
CrackKeepPass=m.VulnerabilityInfo(
description="Attempt to crack KeepPass and look for credentials",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node=prefix(n + 2, "LinuxNode"), port="SSH",
credential=ssh_password(n + 2))]),
reward_string=f"Discovered password to Linux machine {n+2}",
cost=1.0
)
))
}
def create_chain_network(size: int) -> Dict[NodeID, NodeInfo]:
"""Create a chain network with the chain section of specified size.
Size must be an even number
The number of nodes in the network is `size + 2` to account for the start node (0)
and final node (size + 1).
"""
if size % 2 == 1:
raise ValueError(f"Chain size must be even: {size}")
final_node_index = size + 1
nodes = {
'start': m.NodeInfo(
services=[],
value=0,
vulnerabilities=dict(
ScanExplorerRecentFiles=m.VulnerabilityInfo(
description="Scan Windows Explorer recent files for possible references to other machines",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node=prefix(1, "LinuxNode"), port="SSH",
credential=ssh_password(1))]),
reward_string="Found a reference to a remote Linux node in bash history",
cost=1.0
)),
agent_installed=True,
reimagable=False),
prefix(final_node_index, "LinuxNode"): m.NodeInfo(
services=[m.ListeningService("HTTPS"),
m.ListeningService("SSH", allowedCredentials=[ssh_password(final_node_index)])],
value=1000,
owned_string="FLAG: flag discovered!",
properties=["MySql", "Ubuntu", "nginx/1.10.3", "FLAG:Linux"],
vulnerabilities=dict()
)
}
# Add chain links
for i in range(1, size, 2):
nodes.update(create_network_chain_link(i))
return nodes
def new_environment(size) -> m.Environment:
return m.Environment(
network=m.create_network(create_chain_network(size)),
vulnerability_library=dict([]),
identifiers=ENV_IDENTIFIERS
)
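# Usage sketch, assuming the cyberbattle package is importable:
if __name__ == '__main__':
    # a chain of size 4 yields size + 2 nodes: 'start', four chain nodes,
    # and the final flag node
    net = create_chain_network(size=4)
    assert len(net) == 4 + 2
    print(sorted(net.keys()))
    env = new_environment(4)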


@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""A simple test sandbox to play with creation of simulation environments"""
import networkx as nx
import yaml
from cyberbattle.simulation import model, model_test, actions_test
def main() -> None:
"""Simple environment sandbox"""
# Create a toy graph
graph = nx.DiGraph()
graph.add_edges_from([('a', 'b'), ('b', 'c')])
print(graph)
# create a random graph
graph = nx.cubical_graph()
graph = model.assign_random_labels(graph)
vulnerabilities = actions_test.SAMPLE_VULNERABILITIES
model.setup_yaml_serializer()
# Define an environment from this graph
env = model.Environment(
network=graph,
vulnerability_library=vulnerabilities,
identifiers=actions_test.ENV_IDENTIFIERS
)
model_test.check_reserializing(env)
model_test.check_reserializing(vulnerabilities)
# Save the environment to file as Yaml
with open('./simpleenv.yaml', 'w') as file:
yaml.dump(env, file)
print(yaml.dump(env))
if __name__ == '__main__':
main()


@ -0,0 +1,240 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Model a toy Capture the flag exercise
See Jupyter notebook toyctf-simulation.ipynb for an example of
game played on this simulation.
"""
from cyberbattle.simulation import model as m
from cyberbattle.simulation.model import NodeID, NodeInfo, VulnerabilityID, VulnerabilityInfo
from typing import Dict, Iterator, cast, Tuple
default_allow_rules = [
m.FirewallRule("RDP", m.RulePermission.ALLOW),
m.FirewallRule("SSH", m.RulePermission.ALLOW),
m.FirewallRule("HTTPS", m.RulePermission.ALLOW),
m.FirewallRule("HTTP", m.RulePermission.ALLOW)]
# Network nodes involved in the Capture the flag game
nodes = {
"Website": m.NodeInfo(
services=[m.ListeningService("HTTPS"),
m.ListeningService("SSH", allowedCredentials=[
"ReusedMySqlCred-web"])],
firewall=m.FirewallConfiguration(incoming=default_allow_rules,
outgoing=default_allow_rules + [
m.FirewallRule("su", m.RulePermission.ALLOW),
m.FirewallRule("sudo", m.RulePermission.ALLOW)]),
value=100,
# If can SSH into server then gets FLAG "Shared credentials with
# database user"
properties=["MySql", "Ubuntu", "nginx/1.10.3"],
owned_string="FLAG: Login using insecure SSH user/password",
vulnerabilities=dict(
ScanPageContent=m.VulnerabilityInfo(
description="LeakedGitHubProjectUrl: Website page content shows a link to GitHub "
"repo",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedNodesId(["GitHubProject"]),
reward_string="WEBSITE page content has a link to github -> Github project discovered!",
cost=1.0
),
ScanPageSource=m.VulnerabilityInfo(
description="Website page source contains refrence to browseable "
"relative web directory",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedNodesId(["Website.Directory"]),
reward_string="Viewing the web page source reveals a URL to a .txt file and directory on the website",
cost=1.0
),
CredScanBashHistory=m.VulnerabilityInfo(
description="bash history leaking creds - FLAG Stealing "
"credentials for the monitoring user",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node="Website[user=monitor]", port="SSH",
credential="monitorBashCreds")]),
reward_string="FLAG: SSH history revealed credentials for the monitoring user (monitor)",
cost=1.0
))),
"Website.Directory": m.NodeInfo(
services=[m.ListeningService("HTTPS")],
value=50,
properties=["Ubuntu", "nginx/1.10.3",
"CTFFLAG:Readme.txt-Discover secret data"
],
vulnerabilities=dict(
NavigateWebDirectoryFurther=m.VulnerabilityInfo(
description="Discover MYSQL credentials MySql for user "
"'web' in (getting-started.txt)",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node="Website", port="MySQL",
credential="ReusedMySqlCred-web")]),
reward_string="Discover browseable web directory: Navigating to parent URL revealed file `readme.txt`"
"with secret data (aflag); and `getting-started.txt` with MYSQL credentials",
cost=1.0
),
NavigateWebDirectory=m.VulnerabilityInfo(
description="Discover URL to external sharepoint website "
"(in deprecation-checklist.txt)",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedNodesId(["Sharepoint"]),
reward_string="Navigating to parent URL revealed file `deprecation-checklist.txt` containing"
"a URL to an external sharepoint website",
cost=1.0
)
)),
"Website[user=monitor]": m.NodeInfo(
services=[m.ListeningService("SSH", allowedCredentials=[]),
m.ListeningService("SSH-key", allowedCredentials=["unkownkey"]),
m.ListeningService("su", allowedCredentials=["monitorBashCreds"])],
value=100,
properties=["MySql", "Ubuntu", "nginx/1.10.3"],
owned_string="FLAG User escalation by stealing credentials from bash history",
firewall=m.FirewallConfiguration(
outgoing=default_allow_rules,
incoming=[m.FirewallRule("SSH", m.RulePermission.BLOCK,
reason="password authentication disabled! SSH needs private key to authenticate."),
m.FirewallRule("sudo", m.RulePermission.BLOCK,
reason="`sudo -u monitor` failed. User 'monitor' not sudoable."
"This warning will be reported!"),
m.FirewallRule("su", m.RulePermission.ALLOW)] + default_allow_rules
),
vulnerabilities={
"CredScan-HomeDirectory":
m.VulnerabilityInfo(
description="azurecredential.txt file in home directory",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(
node="AzureResourceManager[user=monitor]",
port="HTTPS",
credential="azuread_user_credentials")]),
reward_string="SSH: cat ~/azurecreds.txt (running as monitor) revealed Azure user credential!",
cost=1.0),
}),
"GitHubProject": m.NodeInfo(
services=[m.ListeningService("GIT")],
value=10,
properties=["GitHub", "SasUrlInCommit"],
vulnerabilities=dict(
CredScanGitHistory=m.VulnerabilityInfo(
description="Some secure access token (SAS) leaked in a "
"reverted git commit",
type=m.VulnerabilityType.REMOTE,
precondition=m.Precondition('SasUrlInCommit&GitHub'),
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node="AzureStorage",
port="HTTPS",
credential="SASTOKEN1")]),
rates=m.Rates(probingDetectionRate=0.0,
exploitDetectionRate=0.0,
successRate=1.0),
reward_string="CredScan success: Some secure access token (SAS) was leaked in a reverted git commit",
cost=1.0
))),
"AzureStorage": m.NodeInfo(
services=[
m.ListeningService("HTTPS", allowedCredentials=["SASTOKEN1"])],
value=50,
properties=["CTFFLAG:LeakedCustomerData"],
vulnerabilities=dict(
AccessDataWithSASToken=m.VulnerabilityInfo(
description="Stealing secrets using a publicly shared "
"SAS token",
type=m.VulnerabilityType.REMOTE,
outcome=m.CustomerData(),
rates=m.Rates(successRate=1.0),
reward_string="Stole data using a publicly shared SAS token",
cost=1.0
)
)),
'Sharepoint': m.NodeInfo(
services=[m.ListeningService("HTTPS")],
value=100,
properties=["SharepointLeakingPassword"],
firewall=m.FirewallConfiguration(incoming=[m.FirewallRule("SSH", m.RulePermission.ALLOW),
m.FirewallRule("HTTP", m.RulePermission.ALLOW),
m.FirewallRule("HTTPS", m.RulePermission.ALLOW)],
outgoing=[]),
vulnerabilities=dict(
ScanSharepointParentDirectory=m.VulnerabilityInfo(
description="Navigate to SharePoint site, browse parent "
"directory",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node="AzureResourceManager",
port="HTTPS",
credential="ADPrincipalCreds")]),
rates=m.Rates(successRate=1.0),
reward_string="Navigating to the Sharepoint site revealed AD Service Principal Credentials",
cost=1.0)
)),
'AzureResourceManager': m.NodeInfo(
services=[m.ListeningService("HTTPS", allowedCredentials=["ADPrincipalCreds", "azuread_user_credentials"])],
owned_string="FLAG: Shared credentials with database user - Obtained secrets hidden in Azure Managed Resources",
value=50,
properties=["CTFFLAG:LeakedCustomerData2"],
vulnerabilities=dict(
ListAzureResources=m.VulnerabilityInfo(
description="AzureVM info, including public IP address",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedNodesId(["AzureVM"]),
reward_string="Obtained Azure VM and public IP information",
cost=1.0
))),
'AzureResourceManager[user=monitor]': m.NodeInfo(
services=[m.ListeningService("HTTPS", allowedCredentials=["azuread_user_credentials"])],
owned_string="More secrets stolen when logged as interactive `monitor` user in Azure with `az`",
value=50,
properties=[],
),
'AzureVM': m.NodeInfo(
services=[m.ListeningService("PING"),
m.ListeningService("SSH")],
value=100,
properties=["CTFFLAG:VMPRIVATEINFO"],
firewall=m.FirewallConfiguration(
incoming=[m.FirewallRule("SSH", m.RulePermission.BLOCK,
reason="internet incoming traffic blocked on the VM by NSG firewall")],
outgoing=[])),
'client': m.NodeInfo(
services=[],
value=0,
vulnerabilities=dict(
SearchEdgeHistory=m.VulnerabilityInfo(
description="Search web history for list of accessed websites",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedNodesId(["Website"]),
reward_string="Web browser history revealed website URL of interest",
cost=1.0
)),
agent_installed=True,
reimagable=False),
}
global_vulnerability_library: Dict[VulnerabilityID, VulnerabilityInfo] = dict([])
# Environment constants
ENV_IDENTIFIERS = m.infer_constants_from_nodes(
cast(Iterator[Tuple[NodeID, NodeInfo]], list(nodes.items())),
global_vulnerability_library)
def new_environment() -> m.Environment:
return m.Environment(
network=m.create_network(nodes),
vulnerability_library=global_vulnerability_library,
identifiers=ENV_IDENTIFIERS
)
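# Usage sketch, assuming the cyberbattle.simulation package imported above:
if __name__ == '__main__':
    env = new_environment()
    # 'client' is the only node with agent_installed=True, so an attacker
    # starts there and must discover the rest of the network
    print(list(env.network.nodes))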



@ -0,0 +1,649 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
actions.py
This file contains the AgentActions class and its associated methods.
AgentActions interacts directly with the environment; it is the class
through which both the user and RL agents manipulate the environment.
"""
from dataclasses import dataclass
import dataclasses
from datetime import time
import boolean
from collections import OrderedDict
import logging
from enum import Enum
from typing import Iterator, List, NamedTuple, Optional, Set, Tuple, Dict, TypedDict, cast
import IPython.core.display as d
import pandas as pd
from cyberbattle.simulation.model import MachineStatus, PrivilegeLevel, PropertyName, VulnerabilityID, VulnerabilityType
from . import model
logger = logging.getLogger(__name__)
Reward = float
DiscoveredNodeInfo = TypedDict('DiscoveredNodeInfo', {
'id': model.NodeID,
'status': str
})
class Penalty:
"""Penalties (=negative reward) returned for some actions taken in the simulation"""
# penalty for generic suspiciousness
SUPSPICIOUSNESS = -5.0
# penalty for attempting a connection to a port that was not open
SCANNING_UNOPEN_PORT = -10.0
# penalty for repeating the same exploit attempt
REPEAT = -1
LOCAL_EXPLOIT_FAILED = -20
FAILED_REMOTE_EXPLOIT = -50
# penalty for attempting to connect or execute an action on a node that's not in running state
MACHINE_NOT_RUNNING = 0
# penalty for attempting a connection with an invalid password
WRONG_PASSWORD = -10
# traffic blocked by outgoing rule in a local firewall
BLOCKED_BY_LOCAL_FIREWALL = -10
# traffic blocked by incoming rule in a remote firewall
BLOCKED_BY_REMOTE_FIREWALL = -10
# Reward for any successfully executed local or remote attack
# (the attack cost gets subtracted from this reward)
SUCCEEDED_ATTACK_REWARD = 50
class EdgeAnnotation(Enum):
"""Annotation added to the network edges created as the simulation is played"""
KNOWS = 0
REMOTE_EXPLOIT = 1
LATERAL_MOVE = 2
class ActionResult(NamedTuple):
"""Result from executing an action"""
reward: Reward
outcome: Optional[model.VulnerabilityOutcome]
ALGEBRA = boolean.BooleanAlgebra()
ALGEBRA.TRUE.dual = type(ALGEBRA.FALSE)
ALGEBRA.FALSE.dual = type(ALGEBRA.TRUE)
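# (The two `dual` assignments above appear to be a workaround so that the
# boolean library can simplify expressions after TRUE/FALSE substitution
# in _check_prerequisites below.)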
@dataclass
class NodeTrackingInformation:
"""Track information about nodes gathered throughout the simulation"""
# Map (vulnid, local_or_remote) to time of last attack.
# local_or_remote is true for local, false for remote
last_attack: Dict[Tuple[model.VulnerabilityID, bool], time] = dataclasses.field(default_factory=dict)
# Last time another node connected to this node
last_connection: Optional[time] = None
# All node properties discovered so far
discovered_properties: Set[int] = dataclasses.field(default_factory=set)
class AgentActions:
"""
This is the AgentActions class. It interacts with and makes changes to the environment.
"""
def __init__(self, environment: model.Environment):
"""
AgentActions Constructor
"""
self._environment = environment
self._gathered_credentials: Set[model.CredentialID] = set()
self._discovered_nodes: "OrderedDict[model.NodeID, NodeTrackingInformation]" = OrderedDict()
# List of all special tags indicating a privilege level reached on a node
self.privilege_tags = [model.PrivilegeEscalation(p).tag for p in list(PrivilegeLevel)]
# Mark all owned nodes as discovered
for i, node in environment.nodes():
if node.agent_installed:
self.__mark_node_as_owned(i, PrivilegeLevel.LocalUser)
def discovered_nodes(self) -> Iterator[Tuple[model.NodeID, model.NodeInfo]]:
for node_id in self._discovered_nodes:
yield (node_id, self._environment.get_node(node_id))
def _check_prerequisites(self, target: model.NodeID, vulnerability: model.VulnerabilityInfo) -> bool:
"""
This is a quick helper function to check the prerequisites to see if
they match the ones supplied.
"""
node: model.NodeInfo = self._environment.network.nodes[target]['data']
node_flags = node.properties
expr = vulnerability.precondition.expression
# this line seems redundant but it is necessary to declare the symbols used in the mapping
# pylint: disable=unused-variable
mapping = {i: ALGEBRA.TRUE if str(i) in node_flags else ALGEBRA.FALSE
for i in expr.get_symbols()}
is_true: bool = cast(boolean.Expression, expr.subs(mapping)).simplify() == ALGEBRA.TRUE
return is_true
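# Worked example: for a node with properties ['Windows', 'Win10'] and a
# precondition 'Windows&Win10&(~domain_joined)' (domain_joined being a
# hypothetical property), the mapping sends Windows and Win10 to TRUE and
# domain_joined to FALSE, the expression simplifies to TRUE, and the
# prerequisites are satisfied.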
def list_vulnerabilities_in_target(
self,
target: model.NodeID,
type_filter: Optional[model.VulnerabilityType] = None) -> List[model.VulnerabilityID]:
"""
This function takes a model.NodeID for the target to be scanned
and returns a list of vulnerability IDs.
It checks each vulnerability in the library against the properties of a given node
and determines which vulnerabilities it has.
"""
if not self._environment.network.has_node(target):
raise ValueError(f"invalid node id '{target}'")
target_node_data: model.NodeInfo = self._environment.get_node(target)
global_vuln: Set[model.VulnerabilityID] = {
vuln_id
for vuln_id, vulnerability in self._environment.vulnerability_library.items()
if (type_filter is None or vulnerability.type == type_filter)
and self._check_prerequisites(target, vulnerability)
}
local_vuln: Set[model.VulnerabilityID] = {
vuln_id
for vuln_id, vulnerability in target_node_data.vulnerabilities.items()
if (type_filter is None or vulnerability.type == type_filter)
and self._check_prerequisites(target, vulnerability)
}
return list(global_vuln.union(local_vuln))
def __annotate_edge(self, source_node_id: model.NodeID,
target_node_id: model.NodeID,
new_annotation: EdgeAnnotation) -> None:
"""Create the edge if it does not already exist, and annotate with the maximum
of the existing annotation and a specified new annotation"""
edge_annotation = self._environment.network.get_edge_data(source_node_id, target_node_id)
if edge_annotation is not None:
if 'kind' in edge_annotation:
new_annotation = EdgeAnnotation(max(edge_annotation['kind'].value, new_annotation.value))
else:
new_annotation = EdgeAnnotation(new_annotation.value)
self._environment.network.add_edge(source_node_id, target_node_id, kind=new_annotation, kind_as_float=float(new_annotation.value))
def get_discovered_properties(self, node_id: model.NodeID) -> Set[int]:
return self._discovered_nodes[node_id].discovered_properties
def __mark_node_as_discovered(self, node_id: model.NodeID) -> None:
logger.info('discovered node: ' + node_id)
if node_id not in self._discovered_nodes:
self._discovered_nodes[node_id] = NodeTrackingInformation()
def __mark_nodeproperties_as_discovered(self, node_id: model.NodeID, properties: List[PropertyName]):
properties_indices = [self._environment.identifiers.properties.index(p)
for p in properties
if p not in self.privilege_tags]
if node_id in self._discovered_nodes:
self._discovered_nodes[node_id].discovered_properties = self._discovered_nodes[node_id].discovered_properties.union(properties_indices)
else:
self._discovered_nodes[node_id] = NodeTrackingInformation(discovered_properties=set(properties_indices))
def __mark_allnodeproperties_as_discovered(self, node_id: model.NodeID):
node_info: model.NodeInfo = self._environment.network.nodes[node_id]['data']
self.__mark_nodeproperties_as_discovered(node_id, node_info.properties)
def __mark_node_as_owned(self,
node_id: model.NodeID,
privilege: PrivilegeLevel = model.PrivilegeLevel.LocalUser) -> None:
if node_id not in self._discovered_nodes:
self._discovered_nodes[node_id] = NodeTrackingInformation()
node_info = self._environment.get_node(node_id)
node_info.agent_installed = True
node_info.privilege_level = model.escalate(node_info.privilege_level, privilege)
self._environment.network.nodes[node_id].update({'data': node_info})
self.__mark_allnodeproperties_as_discovered(node_id)
def __mark_discovered_entities(self, reference_node: model.NodeID, outcome: model.VulnerabilityOutcome) -> None:
if isinstance(outcome, model.LeakedCredentials):
for credential in outcome.credentials:
self.__mark_node_as_discovered(credential.node)
self._gathered_credentials.add(credential.credential)
logger.info('discovered credential: ' + str(credential))
self.__annotate_edge(reference_node, credential.node, EdgeAnnotation.KNOWS)
elif isinstance(outcome, model.LeakedNodesId):
for node_id in outcome.nodes:
self.__mark_node_as_discovered(node_id)
self.__annotate_edge(reference_node, node_id, EdgeAnnotation.KNOWS)
def get_node_privilegelevel(self, node_id: model.NodeID) -> model.PrivilegeLevel:
"""Return the last recorded privilege level of the specified node"""
node_info = self._environment.get_node(node_id)
return node_info.privilege_level
def get_nodes_with_atleast_privilegelevel(self, level: PrivilegeLevel) -> List[model.NodeID]:
"""Return all nodes with at least the specified privilege level"""
return [n for n, info in self._environment.nodes() if info.privilege_level >= level]
def is_node_discovered(self, node_id: model.NodeID) -> bool:
"""Returns true if previous actions have revealed the specified node ID"""
return node_id in self._discovered_nodes
def __process_outcome(self,
expected_type: VulnerabilityType,
vulnerability_id: VulnerabilityID,
node_id: model.NodeID,
node_info: model.NodeInfo,
local_or_remote: bool,
failed_penalty: float,
throw_if_vulnerability_not_present: bool
) -> Tuple[bool, ActionResult]:
if node_info.status != model.MachineStatus.Running:
logger.info("target machine not in running state")
return False, ActionResult(reward=Penalty.MACHINE_NOT_RUNNING,
outcome=None)
is_global_vulnerability = vulnerability_id in self._environment.vulnerability_library
is_inplace_vulnerability = vulnerability_id in node_info.vulnerabilities
if is_global_vulnerability:
vulnerabilities = self._environment.vulnerability_library
elif is_inplace_vulnerability:
vulnerabilities = node_info.vulnerabilities
else:
if throw_if_vulnerability_not_present:
raise ValueError(f"Vulnerability '{vulnerability_id}' not supported by node='{node_id}'")
else:
logger.info(f"Vulnerability '{vulnerability_id}' not supported by node '{node_id}'")
return False, ActionResult(reward=Penalty.SUPSPICIOUSNESS, outcome=None)
vulnerability = vulnerabilities[vulnerability_id]
outcome = vulnerability.outcome
if vulnerability.type != expected_type:
raise ValueError(f"vulnerability id '{vulnerability_id}' is for an attack of type {vulnerability.type}, expecting: {expected_type}")
# check vulnerability prerequisites
if not self._check_prerequisites(node_id, vulnerability):
return False, ActionResult(reward=failed_penalty, outcome=model.ExploitFailed())
# if the vulnerability type is a privilege escalation
# and if the escalation level is not already reached on that node,
# then add the escalation tag to the node properties
if isinstance(outcome, model.PrivilegeEscalation):
if outcome.tag in node_info.properties:
return False, ActionResult(reward=Penalty.REPEAT, outcome=outcome)
self.__mark_node_as_owned(node_id, outcome.level)
node_info.properties.append(outcome.tag)
elif isinstance(outcome, model.ProbeSucceeded):
for p in outcome.discovered_properties:
assert p in node_info.properties, \
f'Discovered property {p} must belong to the set of properties associated with the node.'
self.__mark_nodeproperties_as_discovered(node_id, outcome.discovered_properties)
if node_id not in self._discovered_nodes:
self._discovered_nodes[node_id] = NodeTrackingInformation()
lookup_key = (vulnerability_id, local_or_remote)
already_executed = lookup_key in self._discovered_nodes[node_id].last_attack
if already_executed:
last_time = self._discovered_nodes[node_id].last_attack[lookup_key]
if node_info.last_reimaging is None or last_time >= node_info.last_reimaging:
return False, ActionResult(reward=Penalty.REPEAT, outcome=outcome)
self._discovered_nodes[node_id].last_attack[lookup_key] = time()
self.__mark_discovered_entities(node_id, outcome)
logger.info("GOT REWARD: " + vulnerability.reward_string)
return True, ActionResult(reward=0.0 if already_executed else SUCCEEDED_ATTACK_REWARD - vulnerability.cost,
outcome=vulnerability.outcome)
def exploit_remote_vulnerability(self,
node_id: model.NodeID,
target_node_id: model.NodeID,
vulnerability_id: model.VulnerabilityID
) -> ActionResult:
"""
Attempt to exploit a remote vulnerability
from a source node to another node using the specified
vulnerability.
"""
if node_id not in self._environment.network.nodes:
raise ValueError(f"invalid node id '{node_id}'")
if target_node_id not in self._environment.network.nodes:
raise ValueError(f"invalid target node id '{target_node_id}'")
source_node_info: model.NodeInfo = self._environment.get_node(node_id)
target_node_info: model.NodeInfo = self._environment.get_node(target_node_id)
if not source_node_info.agent_installed:
raise ValueError("Agent does not owned the source node '" + node_id + "'")
if target_node_id not in self._discovered_nodes:
raise ValueError("Agent has not discovered the target node '" + target_node_id + "'")
succeeded, result = self.__process_outcome(
model.VulnerabilityType.REMOTE,
vulnerability_id,
target_node_id,
target_node_info,
local_or_remote=False,
failed_penalty=Penalty.FAILED_REMOTE_EXPLOIT,
# We do not throw if the vulnerability is missing in order to
# allow agent attempts to explore potential remote vulnerabilities
throw_if_vulnerability_not_present=False
)
if succeeded:
self.__annotate_edge(node_id, target_node_id, EdgeAnnotation.REMOTE_EXPLOIT)
return result
def exploit_local_vulnerability(self, node_id: model.NodeID,
vulnerability_id: model.VulnerabilityID) -> ActionResult:
"""
This function exploits a local vulnerability on a node.
It takes a node ID for the target and a vulnerability ID, and
returns an ActionResult describing the outcome of the exploit.
"""
graph = self._environment.network
if node_id not in graph.nodes:
raise ValueError(f"invalid node id '{node_id}'")
node_info = self._environment.get_node(node_id)
if not node_info.agent_installed:
raise ValueError(f"Agent does not owned the node '{node_id}'")
succeeded, result = self.__process_outcome(
model.VulnerabilityType.LOCAL,
vulnerability_id,
node_id, node_info,
local_or_remote=True,
failed_penalty=Penalty.LOCAL_EXPLOIT_FAILED,
throw_if_vulnerability_not_present=True)
return result
def __is_passing_firewall_rules(self, rules: List[model.FirewallRule], port_name: model.PortName) -> bool:
"""Determine if traffic on the specified port is permitted by the specified sets of firewall rules"""
for rule in rules:
if rule.port == port_name:
if rule.permission == model.RulePermission.ALLOW:
return True
else:
logger.debug(f'BLOCKED TRAFFIC - PORT \'{port_name}\' Reason: ' + rule.reason)
return False
logger.debug(f"BLOCKED TRAFFIC - PORT '{port_name}' - Reason: no rule defined for this port.")
return False
def connect_to_remote_machine(
self,
source_node_id: model.NodeID,
target_node_id: model.NodeID,
port_name: model.PortName,
credential: model.CredentialID) -> ActionResult:
"""
This function connects to a remote machine with a credential, as opposed to via an exploit.
It takes a NodeID for the source machine, a NodeID for the target machine, and a credential ID
for the credential.
"""
graph = self._environment.network
if source_node_id not in graph.nodes:
raise ValueError(f"invalid node id '{source_node_id}'")
if target_node_id not in graph.nodes:
raise ValueError(f"invalid node id '{target_node_id}''")
target_node = self._environment.get_node(target_node_id)
source_node = self._environment.get_node(source_node_id)
# ensures that the source node is owned by the agent
# and that the target node is discovered
if not source_node.agent_installed:
raise ValueError(f"Agent does not owned the source node '{source_node_id}'")
if target_node_id not in self._discovered_nodes:
raise ValueError(f"Agent has not discovered the target node '{target_node_id}'")
if credential not in self._gathered_credentials:
raise ValueError(f"Agent has not discovered credential '{credential}'")
if not self.__is_passing_firewall_rules(source_node.firewall.outgoing, port_name):
logger.info(f"BLOCKED TRAFFIC: source node '{source_node_id}'" +
f" is blocking outgoing traffic on port '{port_name}'")
return ActionResult(reward=Penalty.BLOCKED_BY_LOCAL_FIREWALL,
outcome=None)
if not self.__is_passing_firewall_rules(target_node.firewall.incoming, port_name):
logger.info(f"BLOCKED TRAFFIC: target node '{target_node_id}'" +
f" is blocking outgoing traffic on port '{port_name}'")
return ActionResult(reward=Penalty.BLOCKED_BY_REMOTE_FIREWALL,
outcome=None)
target_node_is_listening = port_name in [i.name for i in target_node.services]
if not target_node_is_listening:
logger.info(f"target node '{target_node_id}' not listening on port '{port_name}'")
return ActionResult(reward=Penalty.SCANNING_UNOPEN_PORT,
outcome=None)
else:
target_node_data: model.NodeInfo = self._environment.get_node(target_node_id)
if target_node_data.status != model.MachineStatus.Running:
logger.info("target machine not in running state")
return ActionResult(reward=Penalty.MACHINE_NOT_RUNNING,
outcome=None)
# check the credentials before connecting
if not self._check_service_running_and_authorized(target_node_data, port_name, credential):
logger.info("invalid credentials supplied")
return ActionResult(reward=Penalty.WRONG_PASSWORD,
outcome=None)
is_already_owned = target_node_data.agent_installed
if is_already_owned:
return ActionResult(reward=Penalty.REPEAT,
outcome=model.LateralMove())
if target_node_id not in self._discovered_nodes:
self._discovered_nodes[target_node_id] = NodeTrackingInformation()
was_previously_owned_at = self._discovered_nodes[target_node_id].last_connection
self._discovered_nodes[target_node_id].last_connection = time()
if was_previously_owned_at is not None and \
target_node_data.last_reimaging is not None and \
was_previously_owned_at >= target_node_data.last_reimaging:
return ActionResult(reward=Penalty.REPEAT, outcome=model.LateralMove())
self.__annotate_edge(source_node_id, target_node_id, EdgeAnnotation.LATERAL_MOVE)
self.__mark_node_as_owned(target_node_id)
logger.info(f"Infected node '{target_node_id}' from '{source_node_id}'" +
f" via {port_name} with credential '{credential}'")
if target_node.owned_string:
logger.info("Owned message: " + target_node.owned_string)
return ActionResult(reward=float(target_node_data.value) if was_previously_owned_at is None else 0.0,
outcome=model.LateralMove())
def _check_service_running_and_authorized(self,
target_node_data: model.NodeInfo,
port_name: model.PortName,
credential: model.CredentialID) -> bool:
"""
Helper to check that the given service is running on the target node
and that the supplied credential is accepted by that service.
"""
for service in target_node_data.services:
if service.running and service.name == port_name and credential in service.allowedCredentials:
return True
return False
def list_nodes(self) -> List[DiscoveredNodeInfo]:
"""Returns the list of nodes ID that were discovered or owned by the attacker."""
return [cast(DiscoveredNodeInfo, {'id': node_id,
'status': 'owned' if node_info.agent_installed else 'discovered'
})
for node_id, node_info in self.discovered_nodes()
]
def list_remote_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
"""Return list of all remote attacks that may be executed onto the specified node."""
attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(
node_id, model.VulnerabilityType.REMOTE)
return attacks
def list_local_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
"""Return list of all local attacks that may be executed onto the specified node."""
attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(
node_id, model.VulnerabilityType.LOCAL)
return attacks
def list_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
"""Return list of all attacks that may be executed on the specified node."""
attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(
node_id)
return attacks
def list_all_attacks(self) -> List[Dict[str, object]]:
"""List all possible attacks from all the nodes currently owned by the attacker"""
on_owned_nodes: List[Dict[str, object]] = [
{'id': n['id'],
'status': n['status'],
'properties': self._environment.get_node(n['id']).properties,
'local_attacks': self.list_local_attacks(n['id']),
'remote_attacks': self.list_remote_attacks(n['id'])
}
for n in self.list_nodes() if n['status'] == 'owned']
on_discovered_nodes: List[Dict[str, object]] = [{'id': n['id'],
'status': n['status'],
'local_attacks': None,
'remote_attacks': self.list_remote_attacks(n['id'])}
for n in self.list_nodes() if n['status'] != 'owned']
return on_owned_nodes + on_discovered_nodes
def print_all_attacks(self) -> None:
"""Pretty print list of all possible attacks from all the nodes currently owned by the attacker"""
d.display(pd.DataFrame.from_dict(self.list_all_attacks())) # type: ignore
class DefenderAgentActions:
"""Actions reserved to defender agents"""
# Number of steps it takes to completely reimage a node
REIMAGING_DURATION = 15
def __init__(self, environment: model.Environment):
# map nodes being reimaged to the remaining number of steps to completion
self.node_reimaging_progress: Dict[model.NodeID, int] = dict()
# Last calculated availability of the network
self.__network_availability: float = 1.0
self._environment = environment
@property
def network_availability(self):
return self.__network_availability
def reimage_node(self, node_id: model.NodeID):
"""Re-image a computer node"""
# Mark the node for re-imaging and make it unavailable until re-imaging completes
self.node_reimaging_progress[node_id] = self.REIMAGING_DURATION
node_info = self._environment.get_node(node_id)
assert node_info.reimagable, f'Node {node_id} is not re-imageable'
node_info.agent_installed = False
node_info.privilege_level = model.PrivilegeLevel.NoAccess
node_info.status = model.MachineStatus.Imaging
node_info.last_reimaging = time()
self._environment.network.nodes[node_id].update({'data': node_info})
def on_attacker_step_taken(self):
"""Function to be called each time a step is take in the simulation"""
for node_id in list(self.node_reimaging_progress.keys()):
remaining_steps = self.node_reimaging_progress[node_id]
if remaining_steps > 0:
self.node_reimaging_progress[node_id] -= 1
else:
logger.info(f"Machine re-imaging completed: {node_id}")
node_data = self._environment.get_node(node_id)
node_data.status = model.MachineStatus.Running
self.node_reimaging_progress.pop(node_id)
# Calculate the network availability metric based on machines
# and services that are running
total_node_weights = 0
network_node_availability = 0
for node_id, node_info in self._environment.nodes():
total_service_weights = 0
running_service_weights = 0
for service in node_info.services:
total_service_weights += service.sla_weight
running_service_weights += service.sla_weight * int(service.running)
if node_info.status == MachineStatus.Running:
adjusted_node_availability = (1 + running_service_weights) / (1 + total_service_weights)
else:
adjusted_node_availability = 0.0
total_node_weights += node_info.sla_weight
network_node_availability += adjusted_node_availability * node_info.sla_weight
self.__network_availability = network_node_availability / total_node_weights
assert 0.0 <= self.__network_availability <= 1.0
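# Worked example: a single running node with one of its two equally
# weighted services up scores (1 + 1) / (1 + 2) = 2/3, which is also the
# network availability of that one-node network.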
def override_firewall_rule(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool, permission: model.RulePermission):
node_data = self._environment.get_node(node_id)
rules = node_data.firewall.incoming if incoming else node_data.firewall.outgoing
matching_rules = [r for r in rules if r.port == port_name]
if matching_rules:
for r in matching_rules:
r.permission = permission
else:
new_rule = model.FirewallRule(port_name, permission)
if incoming:
node_data.firewall.incoming = [new_rule] + node_data.firewall.incoming
else:
node_data.firewall.outgoing = [new_rule] + node_data.firewall.outgoing
def block_traffic(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool):
return self.override_firewall_rule(node_id, port_name, incoming, permission=model.RulePermission.BLOCK)
def allow_traffic(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool):
return self.override_firewall_rule(node_id, port_name, incoming, permission=model.RulePermission.ALLOW)
def stop_service(self, node_id: model.NodeID, port_name: model.PortName):
node_data = self._environment.get_node(node_id)
assert node_data.status == model.MachineStatus.Running, "Machine must be running to stop a service"
for service in node_data.services:
if service.name == port_name:
service.running = False
def start_service(self, node_id: model.NodeID, port_name: model.PortName):
node_data = self._environment.get_node(node_id)
assert node_data.status == model.MachineStatus.Running, "Machine must be running to start a service"
for service in node_data.services:
if service.name == port_name:
service.running = True
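# Usage sketch, assuming an `env: model.Environment` such as the toy CTF
# environment defined elsewhere in this commit:
#   attacker = AgentActions(env)
#   defender = DefenderAgentActions(env)
#   attacker.exploit_local_vulnerability('client', 'SearchEdgeHistory')
#   defender.on_attacker_step_taken()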


@ -0,0 +1,379 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This is the set of tests for actions.py which implements the actions an agent can take
in this simulation.
"""
import random
from datetime import datetime
from typing import Union, Dict, List
import pytest
import networkx as nx
from . import model, actions
ADMINTAG = model.AdminEscalation().tag
SYSTEMTAG = model.SystemEscalation().tag
# pylint: disable=redefined-outer-name, protected-access
# define a fixture type alias so mypy accepts the fixtures
Fixture = Union[actions.AgentActions]
empty_vuln_dict: Dict[model.VulnerabilityID, model.VulnerabilityInfo] = {}
SINGLE_VULNERABILITIES = {
"UACME61":
model.VulnerabilityInfo(
description="UACME UAC bypass #61",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 0.2, 1.0))}
# temporary vuln dictionary for development purposes only.
# Remove once the full list of vulnerabilities is put together.
# Here we have a UAC bypass, a credential dump, and a remote infection vulnerability.
SAMPLE_VULNERABILITIES = {
"UACME61":
model.VulnerabilityInfo(
description="UACME UAC bypass #61",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 0.2, 1.0)),
"UACME67":
model.VulnerabilityInfo(
description="UACME UAC bypass #67 (fake system escalation) ",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.SystemEscalation(),
rates=model.Rates(0, 0.2, 1.0)),
"MimikatzLogonpasswords":
model.VulnerabilityInfo(
description="Mimikatz sekurlsa::logonpasswords.",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/gentilkiwi/mimikatz",
precondition=model.Precondition(f"Windows&({ADMINTAG}|{SYSTEMTAG})"),
outcome=model.LeakedCredentials([]),
rates=model.Rates(0, 1.0, 1.0)),
"RDPBF":
model.VulnerabilityInfo(
description="RDP Brute Force",
type=model.VulnerabilityType.REMOTE,
URL="https://attack.mitre.org/techniques/T1110/",
precondition=model.Precondition("Windows&PortRDPOpen"),
outcome=model.LateralMove(),
rates=model.Rates(0, 0.2, 1.0),
cost=1.0)
}
ENV_IDENTIFIERS = model.Identifiers(
local_vulnerabilities=['UACME61', 'UACME67', 'MimikatzLogonpasswords'],
remote_vulnerabilities=['RDPBF'],
ports=['RDP', 'HTTP', 'HTTPS', 'SSH'],
properties=[
"Linux", "PortSSHOpen", "PortSQLOpen",
"Windows", "Win10", "PortRDPOpen",
"PortHTTPOpen", "PortHTTPsOpen",
"SharepointLeakingPassword"]
)
def sample_random_firewall_configuration() -> model.FirewallConfiguration:
    """Sample a random set of firewall rules"""
return model.FirewallConfiguration(
outgoing=[model.FirewallRule(p, permission=model.RulePermission.ALLOW)
for p in random.choices(ENV_IDENTIFIERS.properties,
k=random.randint(0, len(ENV_IDENTIFIERS.properties)))],
incoming=[model.FirewallRule(p, permission=model.RulePermission.ALLOW)
for p in random.choices(ENV_IDENTIFIERS.properties,
k=random.randint(0, len(ENV_IDENTIFIERS.properties)))])
# temporary info for a single node network
SINGLE_NODE = {
'a': model.NodeInfo(
services=[model.ListeningService("RDP"),
model.ListeningService("HTTP"),
model.ListeningService("HTTPS")],
value=70,
properties=list(["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"]),
firewall=sample_random_firewall_configuration(),
agent_installed=False)}
# temporary info for a small test network:
# a is a Windows web server, b is a Linux SQL server, c is a Windows workstation,
# and dc is a domain controller
NODES = {
'a': model.NodeInfo(
services=[model.ListeningService("RDP"),
model.ListeningService("HTTP"),
model.ListeningService("HTTPS")],
value=70,
properties=list(["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"]),
vulnerabilities=dict(
ListNeighbors=model.VulnerabilityInfo(
description="reveal other nodes",
type=model.VulnerabilityType.LOCAL,
outcome=model.LeakedNodesId(nodes=['b', 'c', 'dc'])),
DumpCreds=model.VulnerabilityInfo(
description="leaking some creds",
type=model.VulnerabilityType.LOCAL,
outcome=model.LeakedCredentials([model.CachedCredential('Sharepoint', "HTTPS", "ADPrincipalCreds"),
model.CachedCredential('Sharepoint', "HTTPS", "cred")])
)
),
agent_installed=True),
'b': model.NodeInfo(
services=[model.ListeningService("SSH"),
model.ListeningService("SQL")],
value=80,
properties=list(["Linux", "PortSSHOpen", "PortSQLOpen"]),
agent_installed=False),
'c': model.NodeInfo(
services=[model.ListeningService("RDP"),
model.ListeningService("HTTP"),
model.ListeningService("HTTPS")],
value=40,
properties=list(["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"]),
agent_installed=True),
'dc': model.NodeInfo(
services=[model.ListeningService("RDP"),
model.ListeningService("WMI")],
value=100, properties=list(["Windows", "Win10", "PortRDPOpen", "PortWMIOpen"]),
agent_installed=False),
'Sharepoint': model.NodeInfo(
services=[model.ListeningService("HTTPS", allowedCredentials=["ADPrincipalCreds"])], value=100,
properties=["SharepointLeakingPassword"],
firewall=model.FirewallConfiguration(
incoming=[model.FirewallRule(port="SSH", permission=model.RulePermission.ALLOW),
model.FirewallRule(port="HTTPS", permission=model.RulePermission.ALLOW),
model.FirewallRule(port="HTTP", permission=model.RulePermission.ALLOW),
model.FirewallRule(port="RDP", permission=model.RulePermission.BLOCK)
],
outgoing=[]),
vulnerabilities=dict(
ScanSharepointParentDirectory=model.VulnerabilityInfo(
description="Navigate to SharePoint site, browse parent "
"directory",
type=model.VulnerabilityType.REMOTE,
outcome=model.LeakedCredentials(credentials=[
model.CachedCredential(node="AzureResourceManager",
port="HTTPS",
credential="ADPrincipalCreds")]),
rates=model.Rates(successRate=1.0),
cost=1.0)
)),
}
# Define an environment from this graph
ENV = model.Environment(
network=model.create_network(NODES),
vulnerability_library=dict([]),
identifiers=ENV_IDENTIFIERS,
creationTime=datetime.utcnow(),
lastModified=datetime.utcnow(),
)
@pytest.fixture
def actions_on_empty_environment() -> actions.AgentActions:
"""
Test fixture providing an empty environment
(fixtures reduce per-test setup overhead).
"""
egraph = nx.empty_graph(0, create_using=nx.DiGraph())
env = model.Environment(network=egraph,
version=model.VERSION_TAG,
vulnerability_library=SAMPLE_VULNERABILITIES,
identifiers=ENV_IDENTIFIERS,
creationTime=datetime.utcnow(),
lastModified=datetime.utcnow())
return actions.AgentActions(env)
@pytest.fixture
def actions_on_single_node_environment() -> actions.AgentActions:
"""
This fixture will provide us with a single node environment
"""
env = model.Environment(network=model.create_network(SINGLE_NODE),
version=model.VERSION_TAG,
vulnerability_library=SAMPLE_VULNERABILITIES,
identifiers=ENV_IDENTIFIERS,
creationTime=datetime.utcnow(),
lastModified=datetime.utcnow())
return actions.AgentActions(env)
@pytest.fixture
def actions_on_simple_environment() -> actions.AgentActions:
"""
This fixture will provide us with a 4-node environment
simulating three workstations connected to a single server.
"""
env = model.Environment(network=model.create_network(NODES),
version=model.VERSION_TAG,
vulnerability_library=SAMPLE_VULNERABILITIES,
identifiers=ENV_IDENTIFIERS,
creationTime=datetime.utcnow(),
lastModified=datetime.utcnow())
return actions.AgentActions(env)
def test_list_vulnerabilities_function(actions_on_single_node_environment: Fixture,
actions_on_simple_environment: Fixture) -> None:
"""
This function will test the list_vulnerabilities function from the
AgentActions class in actions.py
"""
# test on an environment with a single node
single_node_results: List[model.VulnerabilityID] = []
single_node_results = actions_on_single_node_environment.list_vulnerabilities_in_target('a')
assert len(single_node_results) == 3
simple_graph_results: List[model.VulnerabilityID] = []
simple_graph_results = actions_on_simple_environment.list_vulnerabilities_in_target('dc')
assert len(simple_graph_results) == 3
def test_exploit_remote_vulnerability(actions_on_simple_environment: Fixture) -> None:
"""
This function will test the exploit_remote_vulnerability function from the
AgentActions class in actions.py
"""
actions_on_simple_environment.exploit_local_vulnerability('a', "ListNeighbors")
# test with invalid source node
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
actions_on_simple_environment.exploit_remote_vulnerability('z', 'b', "RDPBF")
# test with invalid destination node
with pytest.raises(ValueError, match=r"invalid target node id '.*'"):
actions_on_simple_environment.exploit_remote_vulnerability('a', 'z', "RDPBF")
# test with a local vulnerability
with pytest.raises(ValueError, match=r"vulnerability id '.*' is for an attack of type .*"):
actions_on_simple_environment.exploit_remote_vulnerability('a', 'c', "MimikatzLogonpasswords")
# test with an invalid vulnerability (one not there)
result = actions_on_simple_environment.exploit_remote_vulnerability('a', 'c', "HackTheGibson")
assert result.outcome is None and result.reward <= 0
# add RDP brute force to the target node by mutating the environment graph directly
# (a test-only hack; not to be used in normal code)
graph: nx.graph.Graph = actions_on_simple_environment._environment.network
node: model.NodeInfo = graph.nodes['c']['data']
node.vulnerabilities = SAMPLE_VULNERABILITIES
# test a valid and functional one.
result = actions_on_simple_environment.exploit_remote_vulnerability('a', 'c', "RDPBF")
assert isinstance(result.outcome, model.LateralMove)
assert result.reward <= actions.SUCCEEDED_ATTACK_REWARD - 1
def test_exploit_local_vulnerability(actions_on_simple_environment: Fixture) -> None:
"""
This function will test the exploit_local_vulnerability function from the
AgentActions class in actions.py
"""
# check one with invalid prerequisites
result: actions.ActionResult = actions_on_simple_environment.\
exploit_local_vulnerability('a', "MimikatzLogonpasswords")
assert isinstance(result.outcome, model.ExploitFailed)
# test admin privilege escalation
# exploit_local_vulnerability(node_id, vulnerability_id)
result = actions_on_simple_environment.exploit_local_vulnerability('a', "UACME61")
assert isinstance(result.outcome, model.AdminEscalation)
node: model.NodeInfo = actions_on_simple_environment._environment.network.nodes['a']['data']
assert model.AdminEscalation().tag in node.properties
# test system privilege escalation
result = actions_on_simple_environment.exploit_local_vulnerability('c', "UACME67")
assert isinstance(result.outcome, model.SystemEscalation)
node = actions_on_simple_environment._environment.network.nodes['c']['data']
assert model.SystemEscalation().tag in node.properties
# test dump credentials
result = actions_on_simple_environment.\
exploit_local_vulnerability('a', "MimikatzLogonpasswords")
assert isinstance(result.outcome, model.LeakedCredentials)
def test_connect_to_remote_machine(actions_on_empty_environment: Fixture,
actions_on_single_node_environment: Fixture,
actions_on_simple_environment: Fixture) -> None:
"""
This function will test the connect_to_remote_machine function from the
AgentActions class in actions.py
"""
actions_on_simple_environment.exploit_local_vulnerability('a', "ListNeighbors")
actions_on_simple_environment.exploit_local_vulnerability('a', "DumpCreds")
# test connect to remote machine on an empty environment
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
actions_on_empty_environment.connect_to_remote_machine("a", "b", "RDP", "cred")
# test connect to remote machine on an environment with 1 node
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
actions_on_single_node_environment.connect_to_remote_machine("a", "b", "RDP", "cred")
graph: nx.graph.Graph = actions_on_simple_environment._environment.network
# test connect to remote machine on an environment with multiple nodes
# test with valid source node and invalid destination node
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
actions_on_simple_environment.\
connect_to_remote_machine("a", "f", "RDP", "cred")
# test with an invalid source node and valid destination node
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
actions_on_simple_environment.connect_to_remote_machine("f", "dc", "RDP", "cred")
# test with both nodes invalid
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
actions_on_simple_environment.connect_to_remote_machine("f", "z", "RDP", "cred")
# test with invalid protocol
result = actions_on_simple_environment.connect_to_remote_machine("a", "dc", "TCPIP", "cred")
assert result.reward <= 0 and result.outcome is None
# test with invalid credentials
result2 = actions_on_simple_environment.connect_to_remote_machine("a", "dc", "RDP", "cred")
assert result2.outcome is None and result2.reward <= 0
# test blocking firewall rule
ret_val = actions_on_simple_environment.connect_to_remote_machine("a", 'Sharepoint', "RDP", "ADPrincipalCreds")
assert ret_val.reward < 0
# test with valid nodes
ret_val = actions_on_simple_environment.connect_to_remote_machine("a", 'Sharepoint', "HTTPS", "ADPrincipalCreds")
assert ret_val.reward == 100
assert graph.has_edge("a", "dc")
def test_check_prerequisites(actions_on_simple_environment: Fixture) -> None:
"""
This function will test the _checkPrerequisites function
It's marked as a private function but still needs to be tested before use
"""
# testing on a node/vuln combo which should give us a negative result
result = actions_on_simple_environment._check_prerequisites('dc', SAMPLE_VULNERABILITIES["MimikatzLogonpasswords"])
assert not result
# testing on a node/vuln combo which should give us a positive result.
result = actions_on_simple_environment._check_prerequisites('dc', SAMPLE_VULNERABILITIES["UACME61"])
assert result

View file

@ -0,0 +1,276 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""A 'Command & control'-like interface exposing to a human player
the attacker view and actions of the game.
This includes commands to visualize the parts of the environment
explored so far and, for each node where the attacker client
is installed, to execute actions on the machine.
"""
import networkx as nx
from typing import List, Optional, Dict, Union, Tuple
import plotly.graph_objects as go
from . import model, actions
class CommandControl:
""" The Command and Control interface to the simulation.
This represents a server that centralizes information and secrets
retrieved from the individual clients running on the network nodes.
"""
# Global list aggregating all credentials gathered so far, from any node in the network
__gathered_credentials: List[model.CachedCredential] = []
_actuator: actions.AgentActions
__environment: model.Environment
__total_reward: float
def __init__(self, environment_or_actuator: Union[model.Environment, actions.AgentActions]):
if isinstance(environment_or_actuator, model.Environment):
self.__environment = environment_or_actuator
self._actuator = actions.AgentActions(self.__environment)
elif isinstance(environment_or_actuator, actions.AgentActions):
self.__environment = environment_or_actuator._environment
self._actuator = environment_or_actuator
else:
raise ValueError(
"Invalid type: expecting Union[model.Environment, actions.AgentActions]")
self.__gathered_credentials = []
self.__total_reward = 0
def __save_credentials(self, outcome: model.VulnerabilityOutcome) -> None:
"""Save credentials obtained from exploiting a vulnerability"""
if isinstance(outcome, model.LeakedCredentials):
self.__gathered_credentials.extend(outcome.credentials)
return
def __accumulate_reward(self, reward: actions.Reward) -> None:
"""Accumulate new reward"""
self.__total_reward += reward
def total_reward(self) -> actions.Reward:
"""Return the current accumulated reward"""
return self.__total_reward
def list_nodes(self) -> List[actions.DiscoveredNodeInfo]:
"""Returns the list of nodes ID that were discovered or owned by the attacker."""
return self._actuator.list_nodes()
def get_node_color(self, node_info: model.NodeInfo) -> str:
if node_info.agent_installed:
return 'red'
else:
return 'green'
def plot_nodes(self) -> None:
"""Plot the sub-graph of nodes either so far
discovered (their ID is knowned by the agent)
or owned (i.e. where the attacker client is installed)."""
discovered_nodes = [node_id for node_id, _ in self._actuator.discovered_nodes()]
sub_graph = self.__environment.network.subgraph(discovered_nodes)
nx.draw(sub_graph,
with_labels=True,
node_color=[self.get_node_color(self.__environment.get_node(i)) for i in sub_graph.nodes])
def known_vulnerabilities(self) -> model.VulnerabilityLibrary:
"""Return the global list of known vulnerability."""
return self.__environment.vulnerability_library
def list_remote_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
"""Return list of all remote attacks that the Command&Control may
execute onto the specified node."""
return self._actuator.list_remote_attacks(node_id)
def list_local_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
"""Return list of all local attacks that the Command&Control may
execute onto the specified node."""
return self._actuator.list_local_attacks(node_id)
def list_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
"""Return list of all attacks that the Command&Control may
execute on the specified node."""
return self._actuator.list_attacks(node_id)
def list_all_attacks(self) -> List[Dict[str, object]]:
"""List all possible attacks from all the nodes currently owned by the attacker"""
return self._actuator.list_all_attacks()
def print_all_attacks(self) -> None:
"""Pretty print list of all possible attacks from all the nodes currently owned by the attacker"""
return self._actuator.print_all_attacks()
def run_attack(self,
node_id: model.NodeID,
vulnerability_id: model.VulnerabilityID
) -> Optional[model.VulnerabilityOutcome]:
"""Run an attack and attempt to exploit a vulnerability on the specified node."""
result = self._actuator.exploit_local_vulnerability(node_id, vulnerability_id)
if result.outcome is not None:
self.__save_credentials(result.outcome)
self.__accumulate_reward(result.reward)
return result.outcome
def run_remote_attack(self, node_id: model.NodeID,
target_node_id: model.NodeID,
vulnerability_id: model.VulnerabilityID
) -> Optional[model.VulnerabilityOutcome]:
"""Run a remote attack from the specified node to exploit a remote vulnerability
in the specified target node"""
result = self._actuator.exploit_remote_vulnerability(
node_id, target_node_id, vulnerability_id)
if result.outcome is not None:
self.__save_credentials(result.outcome)
self.__accumulate_reward(result.reward)
return result.outcome
def connect_and_infect(self, source_node_id: model.NodeID,
target_node_id: model.NodeID,
port_name: model.PortName,
credentials: model.CredentialID) -> bool:
"""Install the agent on a remote machine using the
provided credentials"""
result = self._actuator.connect_to_remote_machine(source_node_id, target_node_id, port_name,
credentials)
self.__accumulate_reward(result.reward)
return result.outcome is not None
@property
def credentials_gathered_so_far(self) -> List[model.CachedCredential]:
"""Returns the list of credentials gathered so far by the
attacker (from any node)"""
return self.__gathered_credentials
def get_outcome_first_credential(outcome: Optional[model.VulnerabilityOutcome]) -> model.CredentialID:
"""Return the first credential found in a given vulnerability exploit outcome"""
if outcome is not None and isinstance(outcome, model.LeakedCredentials):
return outcome.credentials[0].credential
else:
raise ValueError('Vulnerability outcome does not contain any credential')
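# Editorial usage sketch (not part of the original release): a typical C&C
# session against some environment `env`. The node and vulnerability IDs here
# are placeholders for whatever the given environment actually defines.
def _example_c2_session(env: model.Environment) -> None:
    c2 = CommandControl(env)
    outcome = c2.run_attack('client', 'SomeLocalVulnerability')
    if isinstance(outcome, model.LeakedCredentials):
        # use the first leaked credential to pivot to another machine
        cred = get_outcome_first_credential(outcome)
        c2.connect_and_infect('client', 'SomeServer', 'HTTPS', cred)
    print(c2.total_reward())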
class EnvironmentDebugging:
"""Provides debugging feature exposing internals of the environment
that are not normally revealed to an attacker agent according to
the rules of the simulation.
"""
__environment: model.Environment
__actuator: actions.AgentActions
def __init__(self, actuator_or_c2: Union[actions.AgentActions, CommandControl]):
if isinstance(actuator_or_c2, actions.AgentActions):
self.__actuator = actuator_or_c2
elif isinstance(actuator_or_c2, CommandControl):
self.__actuator = actuator_or_c2._actuator
else:
raise ValueError("Invalid type: expecting Union[actions.AgentActions, CommandControl])")
self.__environment = self.__actuator._environment
def network_as_plotly_traces(self, xref: str = "x", yref: str = "y") -> Tuple[List[go.Scatter], dict]:
known_nodes = [node_id for node_id, _ in self.__actuator.discovered_nodes()]
subgraph = self.__environment.network.subgraph(known_nodes)
# pos = nx.fruchterman_reingold_layout(subgraph)
pos = nx.shell_layout(subgraph, [[known_nodes[0]], known_nodes[1:]])
def edge_text(source: model.NodeID, target: model.NodeID) -> str:
data = self.__environment.network.get_edge_data(source, target)
name: str = data['kind'].name
return name
color_map = {actions.EdgeAnnotation.LATERAL_MOVE: 'red',
actions.EdgeAnnotation.REMOTE_EXPLOIT: 'orange',
actions.EdgeAnnotation.KNOWS: 'gray'}
def edge_color(source: model.NodeID, target: model.NodeID) -> str:
data = self.__environment.network.get_edge_data(source, target)
if 'kind' in data:
return color_map[data['kind']]
return 'black'
layout: dict = dict(title="CyberBattle simulation", font=dict(size=10), showlegend=True,
autosize=False, width=800, height=400,
margin=go.layout.Margin(l=2, r=2, b=15, t=35),
hovermode='closest',
annotations=[dict(
ax=pos[source][0],
ay=pos[source][1], axref=xref, ayref=yref,
x=pos[target][0],
y=pos[target][1], xref=xref, yref=yref,
arrowcolor=edge_color(source, target),
hovertext=edge_text(source, target),
showarrow=True,
arrowhead=1,
arrowsize=1,
arrowwidth=1,
startstandoff=10,
standoff=10,
align='center',
opacity=1
) for (source, target) in subgraph.edges]
)
owned_nodes_coordinates = [(i, c) for i, c in pos.items()
if self.get_node_information(i).agent_installed]
discovered_nodes_coordinates = [(i, c)
for i, c in pos.items()
if not self.get_node_information(i).agent_installed]
trace_owned_nodes = go.Scatter(
x=[c[0] for i, c in owned_nodes_coordinates],
y=[c[1] for i, c in owned_nodes_coordinates],
mode='markers+text',
name='owned',
marker=dict(symbol='circle-dot',
size=5,
# green #0e9d00
color='#D32F2E', # red
line=dict(color='rgb(255,0,0)', width=8)
),
text=[i for i, c in owned_nodes_coordinates],
hoverinfo='text',
textposition="bottom center"
)
trace_discovered_nodes = go.Scatter(
x=[c[0] for i, c in discovered_nodes_coordinates],
y=[c[1] for i, c in discovered_nodes_coordinates],
mode='markers+text',
name='discovered',
marker=dict(symbol='circle-dot',
size=5,
color='#0e9d00', # green
line=dict(color='rgb(0,255,0)', width=8)
),
text=[i for i, c in discovered_nodes_coordinates],
hoverinfo='text',
textposition="bottom center"
)
dummy_scatter_for_edge_legend = [
go.Scatter(
x=[0], y=[0], mode="lines",
line=dict(color=color_map[a]),
name=a.name
) for a in actions.EdgeAnnotation]
all_scatters = dummy_scatter_for_edge_legend + [trace_owned_nodes, trace_discovered_nodes]
return (all_scatters, layout)
def plot_discovered_network(self) -> None:
"""Plot the network graph with plotly"""
fig = go.Figure()
traces, layout = self.network_as_plotly_traces()
for t in traces:
fig.add_trace(t)
fig.update_layout(layout)
fig.show()
def get_node_information(self, node_id: model.NodeID) -> model.NodeInfo:
"""Print node information"""
return self.__environment.get_node(node_id)
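# Editorial usage sketch (not part of the original release): pair the C&C
# attacker view with the debugging helper to render everything discovered so far.
def _example_debug_plot(env: model.Environment) -> None:
    c2 = CommandControl(env)
    dbg = EnvironmentDebugging(c2)
    dbg.plot_discovered_network()  # plotly figure of owned vs discovered nodes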

View file

@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Unit tests for commandcontrol.py.
"""
# pylint: disable=missing-function-docstring
from . import model, commandcontrol
from ..samples.toyctf import toy_ctf as ctf
def test_toyctf() -> None:
# Use the C&C to exploit remote and local vulnerabilities in the toy CTF game
network = model.create_network(ctf.nodes)
env = model.Environment(network=network,
vulnerability_library=dict([]),
identifiers=ctf.ENV_IDENTIFIERS)
command = commandcontrol.CommandControl(env)
leak_website = command.run_attack('client', 'SearchEdgeHistory')
assert leak_website
github = command.run_remote_attack('client', 'Website', 'ScanPageContent')
leaked_sas_url_outcome = command.run_remote_attack('client', 'GitHubProject', 'CredScanGitHistory')
leaked_sas_url = commandcontrol.get_outcome_first_credential(leaked_sas_url_outcome)
blobwithflag = command.connect_and_infect('client', 'AzureStorage', 'HTTPS', leaked_sas_url)
assert blobwithflag is not False
browsable_directory = command.run_remote_attack('client', 'Website', 'ScanPageSource')
assert browsable_directory
outcome_mysqlleak = command.run_remote_attack('client', 'Website.Directory', 'NavigateWebDirectoryFurther')
mysql_credential = commandcontrol.get_outcome_first_credential(outcome_mysqlleak)
sharepoint_url = command.run_remote_attack('client', 'Website.Directory', 'NavigateWebDirectory')
assert sharepoint_url
outcome_azure_ad = command.run_remote_attack('client', 'Sharepoint', 'ScanSharepointParentDirectory')
azure_ad_credentials = commandcontrol.get_outcome_first_credential(outcome_azure_ad)
azure_vm_info = command.connect_and_infect('client', 'AzureResourceManager', 'HTTPS', azure_ad_credentials)
assert azure_vm_info is not False
azure_resources = command.run_remote_attack('client', 'AzureResourceManager', 'ListAzureResources')
assert azure_resources
directly_ssh_connected = command.connect_and_infect('client', 'AzureVM', 'SSH', mysql_credential)
assert not directly_ssh_connected
sshd = command.connect_and_infect('client', 'Website', 'SSH', mysql_credential)
assert sshd is not False
outcome = command.run_attack('Website', 'CredScanBashHistory')
monitor_bash_creds = commandcontrol.get_outcome_first_credential(outcome)
connected_as_monitor = command.connect_and_infect('Website', 'Website[user=monitor]', 'sudo', monitor_bash_creds)
assert not connected_as_monitor
connected_as_monitor_from_client = command.connect_and_infect(
'client', 'Website[user=monitor]', 'SSH', monitor_bash_creds)
assert not connected_as_monitor_from_client
flag = command.connect_and_infect('Website', 'Website[user=monitor]', 'su', monitor_bash_creds)
assert flag is not False
outcome_azuread = command.run_attack('Website[user=monitor]', 'CredScan-HomeDirectory')
azure_ad_user_credential = commandcontrol.get_outcome_first_credential(outcome_azuread)
secrets = command.connect_and_infect('client', 'AzureResourceManager', 'HTTPS',
azure_ad_user_credential)
assert secrets is not False
reward = command.total_reward()
print('Total reward ' + str(reward))
assert reward == 759.0
assert github is not None

View file

@ -0,0 +1,300 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
environment_generation.py: generates a semi-random environment for
the loonshot simulation v0.
"""
from typing import List, Dict, Set
import random
import re
import networkx as nx
from . import model
# These two dictionaries hold the potential vulnerabilities. They are split into Linux
# and Windows vulnerabilities so that generated nodes only receive
# vulnerabilities matching their operating system.
ADMINTAG = model.AdminEscalation().tag
SYSTEMTAG = model.SystemEscalation().tag
potential_windows_vulns = {
"UACME43":
model.VulnerabilityInfo(
description="UACME UAC bypass #43",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&(Win10|Win7)&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 0.2, 1.0)),
"UACME45":
model.VulnerabilityInfo(
description="UACME UAC bypass #45",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 0.2, 1.0)),
"UACME52":
model.VulnerabilityInfo(
description="UACME UAC bypass #52",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&(Win10|Win7)&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 0.2, 1.0)),
"UACME55":
model.VulnerabilityInfo(
description="UACME UAC bypass #55",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&(Win10|Win7)&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 0.2, 1.0)),
"UACME61":
model.VulnerabilityInfo(
description="UACME UAC bypass #61",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 0.2, 1.0)),
"MimikatzLogonpasswords":
model.VulnerabilityInfo(
description="Mimikatz sekurlsa::logonpasswords.",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/gentilkiwi/mimikatz",
precondition=model.Precondition(f"Windows&({ADMINTAG}|{SYSTEMTAG})"),
outcome=model.LeakedCredentials([]),
rates=model.Rates(0, 1.0, 1.0)),
"MimikatzKerberosExport":
model.VulnerabilityInfo(
description="Mimikatz Kerberos::list /export."
"Exports .kirbi files to be used with pass the ticket",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/gentilkiwi/mimikatz",
precondition=model.Precondition(f"Windows&DomainJoined&({ADMINTAG}|{SYSTEMTAG})"),
outcome=model.LeakedCredentials([]),
rates=model.Rates(0, 1.0, 1.0)),
"PassTheTicket":
model.VulnerabilityInfo(
description="Mimikatz Kerberos::ptt /export."
"Exports .kirbi files to be used with pass the ticket",
type=model.VulnerabilityType.REMOTE,
URL="https://github.com/gentilkiwi/mimikatz",
precondition=model.Precondition(f"Windows&DomainJoined&KerberosTicketsDumped"
f"&({ADMINTAG}|{SYSTEMTAG})"),
outcome=model.LeakedCredentials([]),
rates=model.Rates(0, 1.0, 1.0)),
"RDPBF":
model.VulnerabilityInfo(
description="RDP Brute Force",
type=model.VulnerabilityType.REMOTE,
URL="https://attack.mitre.org/techniques/T1110/",
precondition=model.Precondition("Windows&PortRDPOpen"),
outcome=model.LateralMove(),
rates=model.Rates(0, 0.2, 1.0)),
"SMBBF":
model.VulnerabilityInfo(
description="SSH Brute Force",
type=model.VulnerabilityType.REMOTE,
URL="https://attack.mitre.org/techniques/T1110/",
precondition=model.Precondition("(Windows|Linux)&PortSMBOpen"),
outcome=model.LateralMove(),
rates=model.Rates(0, 0.2, 1.0))
}
potential_linux_vulns = {
"SudoCaching":
model.VulnerabilityInfo(
description="Escalating privileges from poorly configured sudo on linux/unix machines",
type=model.VulnerabilityType.REMOTE,
URL="https://attack.mitre.org/techniques/T1206/",
precondition=model.Precondition(f"Linux&(~{ADMINTAG})"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 1.0, 1.0)),
"SSHBF":
model.VulnerabilityInfo(
description="SSH Brute Force",
type=model.VulnerabilityType.REMOTE,
URL="https://attack.mitre.org/techniques/T1110/",
precondition=model.Precondition("Linux&PortSSHOpen"),
outcome=model.LateralMove(),
rates=model.Rates(0, 0.2, 1.0)),
"SMBBF":
model.VulnerabilityInfo(
description="SSH Brute Force",
type=model.VulnerabilityType.REMOTE,
URL="https://attack.mitre.org/techniques/T1110/",
precondition=model.Precondition("(Windows|Linux)&PortSMBOpen"),
outcome=model.LateralMove(),
rates=model.Rates(0, 0.2, 1.0))
}
# These are the potential endpoints that can be open in a game. To add more
# endpoints, simply add the protocol name to this list.
# Note that a protocol name abstracts over the actual port numbers, since some protocols
# (like SMB) have multiple official ports.
potential_ports: List[model.PortName] = ["RDP", "SSH", "HTTP", "HTTPs",
"SMB", "SQL", "FTP", "WMI"]
# These two lists are potential node states. They are split into Linux states and Windows
# states so that generated nodes have properties consistent with their operating system.
potential_linux_node_states: List[model.PropertyName] = ["Linux", ADMINTAG,
"PortRDPOpen",
"PortHTTPOpen", "PortHTTPsOpen",
"PortSSHOpen", "PortSMBOpen",
"PortFTPOpen", "DomainJoined"]
potential_windows_node_states: List[model.PropertyName] = ["Windows", "Win10", "PortRDPOpen",
"PortHTTPOpen", "PortHTTPsOpen",
"PortSSHOpen", "PortSMBOpen",
"PortFTPOpen", "BITSEnabled",
"Win7", "DomainJoined"]
ENV_IDENTIFIERS = model.Identifiers(
ports=potential_ports,
properties=potential_linux_node_states + potential_windows_node_states,
local_vulnerabilities=list(potential_windows_vulns.keys()),
remote_vulnerabilities=list(potential_windows_vulns.keys())
)
def create_random_environment(name: str, size: int) -> model.Environment:
"""
This is the create random environment function. It takes a string for the name
of the environment and an int for the size. It returns a randomly generated
environment.
Note this does not currently support generating credentials.
"""
if not name:
raise ValueError("Please supply a non empty string for the name")
if size < 1:
raise ValueError("Please supply a positive non zero positive"
"integer for the size of the environment")
graph = nx.DiGraph()
nodes: Dict[str, model.NodeInfo] = {}
# append the linux and windows vulnerability dictionaries
local_vuln_lib: Dict[model.VulnerabilityID, model.VulnerabilityInfo] = \
{**potential_windows_vulns, **potential_linux_vulns}
os_types: List[str] = ["Linux", "Windows"]
for i in range(size):
rand_os: str = os_types[random.randint(0, 1)]
nodes[str(i)] = create_random_node(rand_os, potential_ports)
graph.add_nodes_from([(k, {'data': v}) for (k, v) in list(nodes.items())])
return model.Environment(network=graph, vulnerability_library=local_vuln_lib, identifiers=ENV_IDENTIFIERS)
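# Editorial usage sketch (not part of the original release): build a small
# random environment and list each generated node's properties.
def _example_random_environment() -> None:
    env = create_random_environment("demo", 5)
    for node_id, node_info in env.nodes():
        print(node_id, node_info.properties)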
def create_random_node(os_type: str, end_points: List[model.PortName]) \
-> model.NodeInfo:
"""
This is the create random node function.
It takes a string for the OS type and a list of potential ports,
and returns a randomly configured NodeInfo object.
Valid options for the OS type are currently Linux and Windows.
"""
if not end_points:
raise ValueError("No endpoints supplied")
if os_type not in ("Windows", "Linux"):
raise ValueError("Unsupported OS Type please enter Linux or Windows")
# select vulnerabilities from the dictionary for the relevant OS
vulnerabilities: model.VulnerabilityLibrary = dict([])
if os_type == "Linux":
vulnerabilities = \
select_random_vulnerabilities(os_type, random.randint(1, len(potential_linux_vulns)))
else:
vulnerabilities = \
select_random_vulnerabilities(os_type, random.randint(1, len(potential_windows_vulns)))
firewall: model.FirewallConfiguration = create_firewall_rules(end_points)
properties: List[model.PropertyName] = \
get_properties_from_vulnerabilities(os_type, vulnerabilities)
return model.NodeInfo(services=[model.ListeningService(name=p) for p in end_points],
vulnerabilities=vulnerabilities,
value=int(random.random()),
properties=properties,
firewall=firewall,
agent_installed=False)
def select_random_vulnerabilities(os_type: str, num_vulns: int) \
-> Dict[str, model.VulnerabilityInfo]:
"""
It takes a string for the OS type, and an int for the number of
vulnerabilities to select.
It selects num_vulns vulnerabilities from the global list of vulnerabilities for that
specific operating system. It returns a dictionary of VulnerabilityInfo objects to
the caller.
"""
if num_vulns < 1:
raise ValueError("Expected a positive value for num_vulns in select_random_vulnerabilities")
ret_val: Dict[str, model.VulnerabilityInfo] = {}
keys: List[str]
if os_type == "Linux":
keys = random.sample(list(potential_linux_vulns.keys()), num_vulns)
ret_val = {k: potential_linux_vulns[k] for k in keys}
elif os_type == "Windows":
keys = random.sample(list(potential_windows_vulns.keys()), num_vulns)
ret_val = {k: potential_windows_vulns[k] for k in keys}
else:
raise ValueError("Invalid Operating System supplied to select_random_vulnerabilities")
return ret_val
def get_properties_from_vulnerabilities(os_type: str,
vulns: Dict[model.NodeID, model.VulnerabilityInfo]) \
-> List[model.PropertyName]:
"""
get_properties_from_vulnerabilities function.
This function takes a string for os_type and a dictionary of vulnerabilities,
and returns the list of PropertyName objects referenced by their preconditions
"""
ret_val: Set[model.PropertyName] = set()
properties: List[model.PropertyName] = []
if os_type == "Linux":
properties = potential_linux_node_states
elif os_type == "Windows":
properties = potential_windows_node_states
for prop in properties:
for vuln_id, vuln in vulns.items():
if re.search(prop, str(vuln.precondition.expression)):
ret_val.add(prop)
return list(ret_val)
def create_firewall_rules(end_points: List[model.PortName]) -> model.FirewallConfiguration:
"""
This function takes a List of endpoints and returns a FirewallConfiguration
It iterates through the list of potential ports and if they're in the list passed
to the function it adds a firewall rule allowing that port.
Otherwise it adds a rule blocking that port.
"""
ret_val: model.FirewallConfiguration = model.FirewallConfiguration()
ret_val.incoming.clear()
ret_val.outgoing.clear()
for protocol in potential_ports:
if protocol in end_points:
ret_val.incoming.append(model.FirewallRule(protocol, model.RulePermission.ALLOW))
ret_val.outgoing.append(model.FirewallRule(protocol, model.RulePermission.ALLOW))
else:
ret_val.incoming.append(model.FirewallRule(protocol, model.RulePermission.BLOCK))
ret_val.outgoing.append(model.FirewallRule(protocol, model.RulePermission.BLOCK))
return ret_val
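# Editorial usage sketch (not part of the original release): ports present in
# the supplied list get ALLOW rules, every other known port gets a BLOCK rule.
def _example_firewall_rules() -> None:
    config = create_firewall_rules(["SSH", "HTTPs"])
    allowed = [rule.port for rule in config.incoming
               if rule.permission == model.RulePermission.ALLOW]
    print(allowed)  # expected: ['SSH', 'HTTPs']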

View file

@ -0,0 +1,107 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The unit tests for the environment_generation functions
"""
from collections import Counter
from typing import List, Dict
import pytest
from . import environment_generation
from . import model
windows_vulns: Dict[str, model.VulnerabilityInfo] = environment_generation.potential_windows_vulns
linux_vulns: Dict[str, model.VulnerabilityInfo] = environment_generation.potential_linux_vulns
windows_node_states: List[model.PropertyName] = environment_generation.potential_windows_node_states
linux_node_states: List[model.PropertyName] = environment_generation.potential_linux_node_states
potential_ports: List[model.PortName] = environment_generation.potential_ports
def test_create_random_environment() -> None:
"""
The unit tests for create_random_environment function
"""
with pytest.raises(ValueError, match=r"Please supply a non empty string for the name"):
environment_generation.create_random_environment("", 2)
with pytest.raises(ValueError, match=r"Please supply a positive non zero positive"
r"integer for the size of the environment"):
environment_generation.create_random_environment("Test_environment", -5)
result: model.Environment = environment_generation.\
create_random_environment("Test_environment 2", 4)
assert isinstance(result, model.Environment)
def test_create_random_node() -> None:
"""
The unit tests for create_random_node() function
"""
# check that the correct exceptions are generated
with pytest.raises(ValueError, match=r"No endpoints supplied"):
environment_generation.create_random_node("Linux", [])
with pytest.raises(ValueError, match=r"Unsupported OS Type please enter Linux or Windows"):
environment_generation.create_random_node("Solaris", potential_ports)
test_node: model.NodeInfo = environment_generation.create_random_node("Linux", potential_ports)
assert isinstance(test_node, model.NodeInfo)
def test_get_properties_from_vulnerabilities() -> None:
"""
This function tests the get_properties_from_vulnerabilities function
It takes nothing and returns nothing.
"""
# testing on linux vulns
props: List[model.PropertyName] = environment_generation.\
get_properties_from_vulnerabilities("Linux", linux_vulns)
assert "Linux" in props
assert "PortSSHOpen" in props
assert "PortSMBOpen" in props
# testing on Windows vulns
windows_props: List[model.PropertyName] = environment_generation.get_properties_from_vulnerabilities(
"Windows", windows_vulns)
assert "Windows" in windows_props
assert "PortRDPOpen" in windows_props
assert "PortSMBOpen" in windows_props
assert "DomainJoined" in windows_props
assert "Win10" in windows_props
assert "Win7" in windows_props
def test_create_firewall_rules() -> None:
"""
This function tests the create_firewall_rules function.
It takes nothing and returns nothing.
"""
empty_ports: List[model.PortName] = []
potential_port_list: List[model.PortName] = ["RDP", "SSH", "HTTP", "HTTPs",
"SMB", "SQL", "FTP", "WMI"]
half_ports: List[model.PortName] = ["SSH", "HTTPs", "SQL", "FTP", "WMI"]
all_blocked: List[model.FirewallRule] = [model.FirewallRule(
port, model.RulePermission.BLOCK) for port in potential_port_list]
all_allowed: List[model.FirewallRule] = [model.FirewallRule(
port, model.RulePermission.ALLOW) for port in potential_port_list]
half_allowed: List[model.FirewallRule] = [model.FirewallRule(port, model.RulePermission.ALLOW)
if port in half_ports else model.FirewallRule(
port, model.RulePermission.BLOCK) for
port in potential_port_list]
# testing on an empty list should lead to all ports being blocked
results: model.FirewallConfiguration = environment_generation.create_firewall_rules(empty_ports)
assert Counter(results.incoming) == Counter(all_blocked)
assert Counter(results.outgoing) == Counter(all_blocked)
# testing on the list of all supported ports
results = environment_generation.create_firewall_rules(potential_ports)
assert Counter(results.incoming) == Counter(all_allowed)
assert Counter(results.outgoing) == Counter(all_allowed)
results = environment_generation.create_firewall_rules(half_ports)
assert Counter(results.incoming) == Counter(half_allowed)
assert Counter(results.outgoing) == Counter(half_allowed)

View file

@ -0,0 +1,287 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
""" Generating random graphs"""
from cyberbattle.simulation.model import Identifiers, NodeID, CredentialID, PortName
import numpy as np
import networkx as nx
from cyberbattle.simulation import model as m
import random
from typing import List, Optional, Tuple, DefaultDict
from collections import defaultdict
ENV_IDENTIFIERS = Identifiers(
properties=[
'breach_node'
],
ports=['SMB', 'HTTP', 'RDP'],
local_vulnerabilities=[
'ScanWindowsCredentialManagerForRDP',
'ScanWindowsExplorerRecentFiles',
'ScanWindowsCredentialManagerForSMB'
],
remote_vulnerabilities=[
'Traceroute'
]
)
def generate_random_traffic_network(
n_clients: int = 200,
n_servers={
"SMB": 1,
"HTTP": 1,
"RDP": 1,
},
seed: int = 0,
tolerance: np.float32 = np.float32(1e-3),
alpha=np.array([(0.1, 0.3), (0.18, 0.09)], dtype=float),
beta=np.array([(100, 10), (10, 100)], dtype=float),
) -> nx.DiGraph:
"""
Randomly generate a directed multi-edge network graph representing
fictitious SMB, HTTP, and RDP traffic.
Arguments:
n_clients: number of workstation nodes that can initiate sessions with server nodes
n_servers: dictionary indicating the number of server nodes listening on each protocol
seed: seed for the pseudo-random number generator
tolerance: absolute tolerance for bounding the edge probabilities in [tolerance, 1-tolerance]
alpha: beta distribution parameters alpha such that E(edge prob) = alpha / (alpha + beta)
beta: beta distribution parameters beta such that E(edge prob) = alpha / (alpha + beta)
Returns:
(nx.DiGraph): the randomly generated network from the hierarchical block model
"""
edges_labels = defaultdict(set) # set backed multidict
for protocol in list(n_servers.keys()):
sizes = [n_clients, n_servers[protocol]]
# sample edge probabilities from a beta distribution
np.random.seed(seed)
probs: np.ndarray = np.random.beta(a=alpha, b=beta, size=(2, 2))
# don't allow probs too close to zero or one
probs = np.clip(probs, a_min=tolerance, a_max=np.float32(1.0 - tolerance))
# scale by edge type
if protocol == "SMB":
probs = 3 * probs
if protocol == "RDP":
probs = 4 * probs
# sample edges using block models given edge probabilities
di_graph_for_protocol = nx.stochastic_block_model(
sizes=sizes, p=probs, directed=True, seed=seed)
for edge in di_graph_for_protocol.edges:
edges_labels[edge].add(protocol)
digraph = nx.DiGraph()
for (u, v), port in list(edges_labels.items()):
digraph.add_edge(u, v, protocol=port)
return digraph
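# Editorial usage sketch (not part of the original release): generate a small
# traffic graph and count how many edges carry each protocol label.
def _example_traffic_graph() -> None:
    graph = generate_random_traffic_network(n_clients=20, seed=3)
    counts = {protocol: 0 for protocol in ("SMB", "HTTP", "RDP")}
    for _, _, data in graph.edges(data=True):
        # each edge carries the set of protocols observed between the two nodes
        for protocol in data['protocol']:
            counts[protocol] += 1
    print(counts)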
def cyberbattle_model_from_traffic_graph(
traffic_graph: nx.DiGraph,
cached_smb_password_probability=0.75,
cached_rdp_password_probability=0.8,
cached_accessed_network_shares_probability=0.6,
cached_password_has_changed_probability=0.1,
traceroute_discovery_probability=0.5,
probability_two_nodes_use_same_password_to_access_given_resource=0.8
) -> nx.graph.Graph:
"""Generate a random CyberBattle network model from a specified traffic (directed multi) graph.
The input graph can for instance be generated with `generate_random_traffic_network`.
Each edge of the input graph indicates that a communication took place
between the two nodes with the protocol specified in the edge label.
Returns a CyberBattle network with the same nodes and implanted vulnerabilities
to be used to instantiate a CyberBattleSim gym.
Arguments:
cached_smb_password_probability, cached_rdp_password_probability:
probability that a password used for authenticated traffic was cached by the OS for SMB and RDP
cached_accessed_network_shares_probability:
probability that a network share accessed by the system was cached by the OS
cached_password_has_changed_probability:
probability that a given password cached on a node has been rotated on the target node
(typically low, as people tend to change their passwords infrequently)
probability_two_nodes_use_same_password_to_access_given_resource:
probability that two distinct nodes re-use the same password to access a given resource
traceroute_discovery_probability:
probability that a target node of an SMB/RDP connection gets exposed by a traceroute attack
"""
# convert node IDs to string
graph = nx.relabel_nodes(traffic_graph, {i: str(i) for i in traffic_graph.nodes})
password_counter: int = 0
def generate_password() -> CredentialID:
nonlocal password_counter
password_counter = password_counter + 1
return f'unique_pwd{password_counter}'
def traffic_targets(source_node: NodeID, protocol: str) -> List[NodeID]:
neighbors = [t for (s, t) in graph.edges()
if s == source_node and protocol in graph.edges[(s, t)]['protocol']]
return neighbors
# Map (node, port name) -> assigned pwd
assigned_passwords: DefaultDict[Tuple[NodeID, PortName],
List[CredentialID]] = defaultdict(list)
def assign_new_valid_password(node: NodeID, port: PortName) -> CredentialID:
pwd = generate_password()
assigned_passwords[node, port].append(pwd)
return pwd
def reuse_valid_password(node: NodeID, port: PortName) -> CredentialID:
"""Reuse a password already assigned to that node an port, if none is already
assigned create and assign a new valid password"""
if (node, port) not in assigned_passwords:
return assign_new_valid_password(node, port)
# reuse any of the existing assigned valid passwords for that node/port
return random.choice(assigned_passwords[node, port])
def create_cached_credential(node: NodeID, port: PortName) -> CredentialID:
if random.random() < cached_password_has_changed_probability:
# generate a new invalid password
return generate_password()
else:
if random.random() < probability_two_nodes_use_same_password_to_access_given_resource:
return reuse_valid_password(node, port)
else:
return assign_new_valid_password(node, port)
def add_leak_neighbors_vulnerability(
node_id: m.NodeID,
library: Optional[m.VulnerabilityLibrary] = None) -> m.VulnerabilityLibrary:
"""Create random vulnerabilities
that reveal the immediate traffic neighbors of a given node"""
if library is None:
# use a fresh dictionary per call: a shared mutable default argument
# would accumulate every node's vulnerabilities in one library
library = {}
rdp_neighbors = traffic_targets(node_id, 'RDP')
if len(rdp_neighbors) > 0:
library['ScanWindowsCredentialManagerForRDP'] = m.VulnerabilityInfo(
description="Look for RDP credentials in the Windows Credential Manager",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node=target_node, port='RDP',
credential=create_cached_credential(target_node, 'RDP'))
for target_node in rdp_neighbors
if random.random() < cached_rdp_password_probability
]),
reward_string="Discovered creds in the Windows Credential Manager",
cost=2.0
)
smb_neighbors = traffic_targets(node_id, 'SMB')
if len(smb_neighbors) > 0:
library['ScanWindowsExplorerRecentFiles'] = m.VulnerabilityInfo(
description="Look for network shares in the Windows Explorer Recent files",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedNodesId(
[target_node
for target_node in smb_neighbors
if random.random() < cached_accessed_network_shares_probability
]
),
reward_string="Windows Explorer Recent Files revealed network shares",
cost=1.0
)
library['ScanWindowsCredentialManagerForSMB'] = m.VulnerabilityInfo(
description="Look for network credentials in the Windows Credential Manager",
type=m.VulnerabilityType.LOCAL,
outcome=m.LeakedCredentials(credentials=[
m.CachedCredential(node=target_node, port='SMB',
credential=create_cached_credential(target_node, 'SMB'))
for target_node in smb_neighbors
if random.random() < cached_smb_password_probability
]),
reward_string="Discovered SMB creds in the Windows Credential Manager",
cost=2.0
)
if len(smb_neighbors) > 0 and len(rdp_neighbors) > 0:
library['Traceroute'] = m.VulnerabilityInfo(
description="Attempt to discvover network nodes using Traceroute",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedNodesId(
[target_node
for target_node in smb_neighbors + rdp_neighbors
if random.random() < traceroute_discovery_probability
]
),
reward_string="Discovered new network nodes via traceroute",
cost=5.0
)
return library
def create_vulnerabilities_from_traffic_data(node_id: m.NodeID):
return add_leak_neighbors_vulnerability(node_id=node_id)
# Pick a random node as the agent entry node
entry_node_index = random.randrange(len(graph.nodes))
entry_node_id, entry_node_data = list(graph.nodes(data=True))[entry_node_index]
graph.nodes[entry_node_id].clear()
graph.nodes[entry_node_id].update(
{'data': m.NodeInfo(services=[],
value=0,
properties=["breach_node"],
vulnerabilities=create_vulnerabilities_from_traffic_data(entry_node_id),
agent_installed=True,
reimagable=False)})
def create_node_data(node_id: m.NodeID):
return m.NodeInfo(
services=[m.ListeningService(name=port, allowedCredentials=assigned_passwords[(target_node, port)])
for (target_node, port) in assigned_passwords.keys()
if target_node == node_id
],
value=random.randint(0, 100),
vulnerabilities=create_vulnerabilities_from_traffic_data(node_id),
agent_installed=False)
for node in list(graph.nodes):
if node != entry_node_id:
graph.nodes[node].clear()
graph.nodes[node].update({'data': create_node_data(node)})
return graph
def new_environment():
"""Create a new simulation environment based on
a randomly generated network topology.
NOTE: the probabilities and parameter values used
here for the statistical generative model
were arbitrarily picked. We recommend exploring different values for those parameters.
"""
traffic = generate_random_traffic_network(seed=1,
n_clients=50,
n_servers={
"SMB": 15,
"HTTP": 15,
"RDP": 15,
},
alpha=[(1, 1), (0.2, 0.5)],
beta=[(1000, 10), (10, 100)])
network = cyberbattle_model_from_traffic_graph(
traffic,
cached_rdp_password_probability=0.8,
cached_smb_password_probability=0.7,
cached_accessed_network_shares_probability=0.8,
cached_password_has_changed_probability=0.01,
probability_two_nodes_use_same_password_to_access_given_resource=0.9)
return m.Environment(network=network,
vulnerability_library=dict([]),
identifiers=ENV_IDENTIFIERS)
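# Editorial usage sketch (not part of the original release): instantiate the
# random environment and locate the entry node, i.e. the unique node where the
# attacker agent is pre-installed.
def _example_entry_node() -> None:
    env = new_environment()
    entry = [node_id for node_id, node_data in env.nodes()
             if node_data.agent_installed]
    print('entry node:', entry)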

View file

@ -0,0 +1,587 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Data model for the simulation environment.
The simulation environment is given by the directed graph
formally defined by:
Node := NodeID x ListeningService[] x Value x Vulnerability[] x FirewallConfig
Edge := NodeID x NodeID x PortName
where:
- NodeID: string
- ListeningService : Name x AllowedCredentials
- AllowedCredentials : string[] # credential pair represented by just a
string ID
- Value : [0...100] # Intrinsic value of reaching this node
- Vulnerability : VulnerabilityID x Type x Precondition x Outcome x Rates
- VulnerabilityID : string
- Rates : ProbingDetectionRate x ExploitDetectionRate x SuccessRate
- FirewallConfig: {
outgoing : FirewallRule[]
incoming : FirewallRule[] }
- FirewallRule: PortName x { ALLOW, BLOCK }
"""
from datetime import datetime, time
from typing import NamedTuple, List, Dict, Optional, Union, Tuple, Iterator
import dataclasses
from dataclasses import dataclass
import matplotlib.pyplot as plt # type:ignore
from enum import Enum, IntEnum
import boolean
import networkx as nx
import yaml
import random
VERSION_TAG = "0.1.0"
ALGEBRA = boolean.BooleanAlgebra()
# These two lines define True as the dual of False and vice versa
# it's necessary in order to make sure the simplify function in boolean.py
# works correctly. See https://github.com/bastikr/boolean.py/issues/82
ALGEBRA.TRUE.dual = type(ALGEBRA.FALSE)
ALGEBRA.FALSE.dual = type(ALGEBRA.TRUE)
# Type alias for identifiers
NodeID = str
# A unique identifier
ID = str
# a (login,password/token) credential pair is abstracted as just a unique
# string identifier
CredentialID = str
# Intrinsic value of reaching a given node, in [0,100]
NodeValue = int
PortName = str
@dataclass
class ListeningService:
"""A service port on a given node accepting connection initiated
with the specified allowed credentials """
# Name of the port the service is listening to
name: PortName
# credential allowed to authenticate with the service
allowedCredentials: List[CredentialID] = dataclasses.field(default_factory=list)
# whether the service is running or stopped
running: bool = True
# Weight used to evaluate the cost of not running the service
sla_weight = 1.0
VulnerabilityID = str
# Probability rate
Probability = float
# The name of a node property indicating the presence of a
# service, component, feature or vulnerability on a given node.
PropertyName = str
class Rates(NamedTuple):
"""Probabilities associated with a given vulnerability"""
probingDetectionRate: Probability = 0.0
exploitDetectionRate: Probability = 0.0
successRate: Probability = 1.0
class VulnerabilityType(Enum):
"""Is the vulnerability exploitable locally or remotely?"""
LOCAL = 1
REMOTE = 2
class PrivilegeLevel(IntEnum):
"""Access privilege level on a given node"""
NoAccess = 0
LocalUser = 1
Admin = 2
System = 3
MAXIMUM = 3
def escalate(current_level: PrivilegeLevel, escalation_level: PrivilegeLevel) -> PrivilegeLevel:
return PrivilegeLevel(max(int(current_level), int(escalation_level)))
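# For example (editorial note): escalate(PrivilegeLevel.Admin, PrivilegeLevel.LocalUser)
# returns PrivilegeLevel.Admin: escalation is monotone and never lowers the current level.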
class VulnerabilityOutcome:
"""Outcome of exploiting a given vulnerability"""
class LateralMove(VulnerabilityOutcome):
"""Lateral movement to the target node"""
success: bool
class CustomerData(VulnerabilityOutcome):
"""Access customer data on target node"""
class PrivilegeEscalation(VulnerabilityOutcome):
"""Privilege escalation outcome"""
def __init__(self, level: PrivilegeLevel):
self.level = level
@property
def tag(self):
"""Escalation tag that gets added to node properties when
the escalation level is reached for that node"""
return f"privilege_{self.level}"
class SystemEscalation(PrivilegeEscalation):
"""Escalation to SYSTEM privileges"""
def __init__(self):
super().__init__(PrivilegeLevel.System)
class AdminEscalation(PrivilegeEscalation):
"""Escalation to local administrator privileges"""
def __init__(self):
super().__init__(PrivilegeLevel.Admin)
class ProbeSucceeded(VulnerabilityOutcome):
"""Probing succeeded"""
def __init__(self, discovered_properties: List[PropertyName]):
self.discovered_properties = discovered_properties
class ProbeFailed(VulnerabilityOutcome):
"""Probing failed"""
class ExploitFailed(VulnerabilityOutcome):
"""This is for situations where the exploit fails """
class CachedCredential(NamedTuple):
"""Encodes a machine-port-credential triplet"""
node: NodeID
port: PortName
credential: CredentialID
class LeakedCredentials(VulnerabilityOutcome):
"""A set of credentials obtained by exploiting a vulnerability"""
credentials: List[CachedCredential]
def __init__(self, credentials: List[CachedCredential]):
self.credentials = credentials
class LeakedNodesId(VulnerabilityOutcome):
"""A set of node IDs obtained by exploiting a vulnerability"""
def __init__(self, nodes: List[NodeID]):
self.nodes = nodes
VulnerabilityOutcomes = Union[
LeakedCredentials, LeakedNodesId, PrivilegeEscalation, AdminEscalation,
SystemEscalation, CustomerData, LateralMove, ExploitFailed]
class AttackResult:
"""The result of attempting a specific attack (either local or remote)"""
success: bool
expected_outcome: Union[VulnerabilityOutcomes, None]
class Precondition:
""" A predicate logic expression defining the condition under which a given
feature or vulnerability is present or not.
The symbols used in the expression refer to properties associated with
the corresponding node.
E.g. 'Win7', 'Server', 'IISInstalled', 'SQLServerInstalled',
'AntivirusInstalled' ...
"""
expression: boolean.Expression
def __init__(self, expression: Union[boolean.Expression, str]):
if isinstance(expression, boolean.Expression):
self.expression = expression
else:
self.expression = ALGEBRA.parse(expression)
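# Editorial usage sketch (not part of the original release): preconditions are
# parsed with boolean.py into an expression over node property symbols.
def _example_precondition() -> None:
    p = Precondition("Windows & (Win10 | Win7)")
    print(p.expression)  # parsed boolean expression tree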
class VulnerabilityInfo(NamedTuple):
"""Definition of a known vulnerability"""
# an optional description of what the vulnerability is
description: str
# type of vulnerability
type: VulnerabilityType
# what happens when successfully exploiting the vulnerability
outcome: VulnerabilityOutcome
# a boolean expression over a node's properties determining if the
# vulnerability is present or not
precondition: Precondition = Precondition("true")
# rates of success/failure associated with this vulnerability
rates: Rates = Rates()
# points to information about the vulnerability
URL: str = ""
# some cost associated with exploiting this vulnerability (e.g.
# brute force more costly than dumping credentials)
cost: float = 1.0
# a string displayed when the vulnerability is successfully exploited
reward_string: str = ""
# A dictionary storing information about all supported vulnerabilities
# or features supported by the simulation.
# This is to be used as a global dictionary pre-populated before
# starting the simulation and estimated from real-world data.
VulnerabilityLibrary = Dict[VulnerabilityID, VulnerabilityInfo]
class RulePermission(Enum):
"""Determine if a rule is blocks or allows traffic"""
ALLOW = 0
BLOCK = 1
class FirewallRule(NamedTuple):
"""A firewall rule"""
# A port name
port: PortName
# permission on this port
permission: RulePermission
# An optional reason for the block/allow rule
reason: str = ""
class FirewallConfiguration(NamedTuple):
"""Firewall configuration on a given node.
Determine if traffic should be allowed or specifically blocked
on a given port for outgoing and incoming traffic.
The rules are processed in order: the first rule matching a given
port is applied and the rest are ignored.
Ports that are not listed in the configuration
are assumed to be blocked. (Adding an explicit block rule
can still be useful to give a reason for the block.)
"""
outgoing: List[FirewallRule] = [
FirewallRule("RDP", RulePermission.ALLOW),
FirewallRule("SSH", RulePermission.ALLOW),
FirewallRule("HTTPS", RulePermission.ALLOW),
FirewallRule("HTTP", RulePermission.ALLOW)]
incoming: List[FirewallRule] = [
FirewallRule("RDP", RulePermission.ALLOW),
FirewallRule("SSH", RulePermission.ALLOW),
FirewallRule("HTTPS", RulePermission.ALLOW),
FirewallRule("HTTP", RulePermission.ALLOW)]
class MachineStatus(Enum):
"""Machine running status"""
Stopped = 0
Running = 1
Imaging = 2
@dataclass
class NodeInfo:
"""A computer node in the enterprise network"""
# List of port/protocol the node is listening to
services: List[ListeningService]
# List of known vulnerabilities for the node
vulnerabilities: VulnerabilityLibrary = dataclasses.field(default_factory=dict)
# Intrinsic value of the node (translates into a reward if the node gets owned)
value: NodeValue = 0
# Properties of the nodes, some of which can imply further vulnerabilities
properties: List[PropertyName] = dataclasses.field(default_factory=list)
# Firewall configuration of the node
firewall: FirewallConfiguration = FirewallConfiguration()
# Attacker agent installed on the node? (aka the node is 'pwned')
agent_installed: bool = False
# Escalation level
privilege_level: PrivilegeLevel = PrivilegeLevel.NoAccess
# Can the node be re-imaged by a defender agent?
reimagable: bool = True
# Last time the node was reimaged
last_reimaging: Optional[time] = None
# String displayed when the node gets owned
owned_string: str = ""
# Machine status: running or stopped
status = MachineStatus.Running
# Relative node weight used to calculate the cost of stopping this machine
# or its services
sla_weight: float = 1.0
class Identifiers(NamedTuple):
"""Define the global set of identifiers used
in the definition of a given environment.
Such a set defines a common vocabulary possibly
shared across multiple environments, thus
ensuring a consistent numbering convention
that a machine learning model can learn from."""
# Array of all possible node property identifiers
properties: List[PropertyName] = []
# Array of all possible port names
ports: List[PortName] = []
# Array of all possible local vulnerabilities names
local_vulnerabilities: List[VulnerabilityID] = []
# Array of all possible remote vulnerabilities names
remote_vulnerabilities: List[VulnerabilityID] = []
def iterate_network_nodes(network: nx.graph.Graph) -> Iterator[Tuple[NodeID, NodeInfo]]:
"""Iterates over the nodes in the network"""
for nodeid, nodevalue in network.nodes.items():
node_data: NodeInfo = nodevalue['data']
yield nodeid, node_data
class Environment(NamedTuple):
""" The static graph defining the network of computers """
network: nx.graph.Graph
vulnerability_library: VulnerabilityLibrary
identifiers: Identifiers
creationTime: datetime = datetime.utcnow()
lastModified: datetime = datetime.utcnow()
# a version tag indicating the environment schema version
version: str = VERSION_TAG
def nodes(self) -> Iterator[Tuple[NodeID, NodeInfo]]:
"""Iterates over the nodes in the network"""
return iterate_network_nodes(self.network)
def get_node(self, node_id: NodeID) -> NodeInfo:
"""Retrieve info for the node with the specified ID"""
node_info: NodeInfo = self.network.nodes[node_id]['data']
return node_info
def plot_environment_graph(self) -> None:
"""Plot the full environment graph"""
nx.draw(self.network,
with_labels=True,
node_color=[n['data'].value
for i, n in
self.network.nodes.items()],
cmap=plt.cm.Oranges) # type:ignore
def create_network(nodes: Dict[NodeID, NodeInfo]) -> nx.DiGraph:
"""Create a network with a set of nodes and no edges"""
graph = nx.DiGraph()
graph.add_nodes_from([(k, {'data': v}) for (k, v) in list(nodes.items())])
return graph
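# Editorial usage sketch (not part of the original release): a minimal
# two-node network; the attacker starts on 'start' and 'target' listens on RDP.
def _example_two_node_network() -> nx.DiGraph:
    nodes = {
        'start': NodeInfo(services=[], agent_installed=True),
        'target': NodeInfo(services=[ListeningService(name='RDP')], value=50)}
    return create_network(nodes)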
# Helpers to infer constants from an environment
def collect_ports_from_vuln(vuln: VulnerabilityInfo) -> List[PortName]:
"""Returns all the port named referenced in a given vulnerability"""
if isinstance(vuln.outcome, LeakedCredentials):
return [c.port for c in vuln.outcome.credentials]
else:
return []
def collect_vulnerability_ids_from_nodes_bytype(
nodes: Iterator[Tuple[NodeID, NodeInfo]],
global_vulnerabilities: VulnerabilityLibrary,
type: VulnerabilityType) -> List[VulnerabilityID]:
"""Collect and return all IDs of all vulnerability of the specified type
that are referenced in a given set of nodes and vulnerability library
"""
return sorted(list({
id
for _, node_info in nodes
for id, v in node_info.vulnerabilities.items()
if v.type == type
}.union(
id
for id, v in global_vulnerabilities.items()
if v.type == type
)))
def collect_properties_from_nodes(nodes: Iterator[Tuple[NodeID, NodeInfo]]) -> List[PropertyName]:
"""Collect and return sorted list of all property names used in a given set of nodes"""
return sorted({
p
for _, node_info in nodes
for p in node_info.properties
})
def collect_ports_from_nodes(
nodes: Iterator[Tuple[NodeID, NodeInfo]],
vulnerability_library: VulnerabilityLibrary) -> List[PortName]:
"""Collect and return all port names used in a given set of nodes
and global vulnerability library"""
return sorted(list({
port
for _, v in vulnerability_library.items()
for port in collect_ports_from_vuln(v)
}.union({
port
for _, node_info in nodes
for _, v in node_info.vulnerabilities.items()
for port in collect_ports_from_vuln(v)
}.union(
{service.name
for _, node_info in nodes
for service in node_info.services}))))
def collect_ports_from_environment(environment: Environment) -> List[PortName]:
"""Collect and return all port names used in a given environment"""
return collect_ports_from_nodes(environment.nodes(), environment.vulnerability_library)
def infer_constants_from_nodes(
nodes: Iterator[Tuple[NodeID, NodeInfo]],
vulnerabilities: Dict[VulnerabilityID, VulnerabilityInfo]) -> Identifiers:
"""Infer global environment constants from a given network"""
# materialize the iterator once: it is consumed by several collectors below,
# and a plain generator would be exhausted after the first pass
node_list = list(nodes)
return Identifiers(
properties=collect_properties_from_nodes(iter(node_list)),
ports=collect_ports_from_nodes(iter(node_list), vulnerabilities),
local_vulnerabilities=collect_vulnerability_ids_from_nodes_bytype(
iter(node_list), vulnerabilities, VulnerabilityType.LOCAL),
remote_vulnerabilities=collect_vulnerability_ids_from_nodes_bytype(
iter(node_list), vulnerabilities, VulnerabilityType.REMOTE)
)
def infer_constants_from_network(
network: nx.Graph,
vulnerabilities: Dict[VulnerabilityID, VulnerabilityInfo]) -> Identifiers:
"""Infer global environment constants from a given network"""
return infer_constants_from_nodes(iterate_network_nodes(network), vulnerabilities)
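# Editorial usage sketch (not part of the original release): identifiers can
# be inferred from a finished network instead of being enumerated by hand.
def _example_infer_identifiers() -> Identifiers:
    network = create_network({'a': NodeInfo(services=[ListeningService(name='SSH')])})
    return infer_constants_from_network(network, {})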
# Network creation
# A sample set of environment constants
SAMPLE_IDENTIFIERS = Identifiers(
ports=['RDP', 'SSH', 'SMB', 'HTTP', 'HTTPS', 'WMI', 'SQL'],
properties=[
'Windows', 'Linux', 'HyperV-VM', 'Azure-VM', 'Win7', 'Win10',
'PortRDPOpen', 'GuestAccountEnabled']
)
def assign_random_labels(
graph: nx.Graph,
vulnerabilities: VulnerabilityLibrary = dict([]),
identifiers: Identifiers = SAMPLE_IDENTIFIERS) -> nx.Graph:
"""Create an envrionment network by randomly assigning node information
(properties, firewall configuration, vulnerabilities)
to the nodes of a given graph structure"""
# convert node IDs to string
graph = nx.relabel_nodes(graph, {i: str(i) for i in graph.nodes})
    def create_random_firewall_configuration() -> FirewallConfiguration:
        # Sample ALLOW rules from the known port names (not from node properties)
        return FirewallConfiguration(
            outgoing=[
                FirewallRule(port=p, permission=RulePermission.ALLOW)
                for p in random.sample(
                    identifiers.ports,
                    k=random.randint(0, len(identifiers.ports)))],
            incoming=[
                FirewallRule(port=p, permission=RulePermission.ALLOW)
                for p in random.sample(
                    identifiers.ports,
                    k=random.randint(0, len(identifiers.ports)))])
def create_random_properties() -> List[PropertyName]:
return list(random.sample(
identifiers.properties,
k=random.randint(0, len(identifiers.properties))))
    def pick_random_global_vulnerabilities() -> VulnerabilityLibrary:
        # Keep each global vulnerability with a randomly drawn probability
        threshold = random.random()
        return {k: v for (k, v) in vulnerabilities.items() if random.random() > threshold}
def add_leak_neighbors_vulnerability(library: VulnerabilityLibrary, node_id: NodeID) -> None:
"""Create a vulnerability for each node that reveals its immediate neighbors"""
neighbors = {t for (s, t) in graph.edges() if s == node_id}
if len(neighbors) > 0:
library['RecentlyAccessedMachines'] = VulnerabilityInfo(
description="AzureVM info, including public IP address",
type=VulnerabilityType.LOCAL,
outcome=LeakedNodesId(list(neighbors)))
def create_random_vulnerabilities(node_id: NodeID) -> VulnerabilityLibrary:
library = pick_random_global_vulnerabilities()
add_leak_neighbors_vulnerability(library, node_id)
return library
# Pick a random node as the agent entry node
entry_node_index = random.randrange(len(graph.nodes))
entry_node_id, entry_node_data = list(graph.nodes(data=True))[entry_node_index]
graph.nodes[entry_node_id].clear()
node_data = NodeInfo(services=[],
value=0,
properties=create_random_properties(),
vulnerabilities=create_random_vulnerabilities(entry_node_id),
firewall=create_random_firewall_configuration(),
agent_installed=True,
reimagable=False,
privilege_level=PrivilegeLevel.Admin)
graph.nodes[entry_node_id].update({'data': node_data})
def create_random_node_data(node_id: NodeID) -> NodeInfo:
return NodeInfo(
services=[],
value=random.randint(0, 100),
properties=create_random_properties(),
vulnerabilities=create_random_vulnerabilities(node_id),
firewall=create_random_firewall_configuration(),
agent_installed=False,
privilege_level=PrivilegeLevel.NoAccess)
for node in list(graph.nodes):
if node != entry_node_id:
graph.nodes[node].clear()
graph.nodes[node].update({'data': create_random_node_data(node)})
return graph
# Serialization
def setup_yaml_serializer() -> None:
"""Setup a clean YAML formatter for object of type Environment.
"""
yaml.add_representer(Precondition,
lambda dumper, data: dumper.represent_scalar('!BooleanExpression',
str(data.expression))) # type: ignore
yaml.SafeLoader.add_constructor('!BooleanExpression',
lambda loader, expression: Precondition(
loader.construct_scalar(expression))) # type: ignore
yaml.add_constructor('!BooleanExpression',
lambda loader, expression:
Precondition(loader.construct_scalar(expression))) # type: ignore
yaml.add_representer(VulnerabilityType,
lambda dumper, data: dumper.represent_scalar('!VulnerabilityType',
str(data.name))) # type: ignore
yaml.SafeLoader.add_constructor('!VulnerabilityType',
lambda loader, expression: VulnerabilityType[
loader.construct_scalar(expression)]) # type: ignore
yaml.add_constructor('!VulnerabilityType',
lambda loader, expression: VulnerabilityType[
loader.construct_scalar(expression)]) # type: ignore
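# Minimal usage sketch (hedged; `nodes` stands for any Dict[NodeID, NodeInfo]
# and is not defined here):
#
#   setup_yaml_serializer()
#   network = create_network(nodes)
#   env = Environment(network=network,
#                     vulnerability_library=dict([]),
#                     identifiers=infer_constants_from_network(network, dict([])))
#   serialized = yaml.dump(env)
#   restored = yaml.load(serialized, yaml.Loader)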

@@ -0,0 +1,132 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Unit tests for model.py.
Note that model.py mainly provides the data modelling for the simulation,
that is, plain data types with no operations attached. There is therefore not much
relevant unit testing that can be implemented at this stage.
Once we add operations to generate and modify environments there
will be more room for unit-testing.
"""
# pylint: disable=missing-function-docstring
from cyberbattle.simulation.model import AdminEscalation, Identifiers, SystemEscalation
import yaml
from datetime import datetime
import networkx as nx
from . import model
ADMINTAG = AdminEscalation().tag
SYSTEMTAG = SystemEscalation().tag
vulnerabilities = {
"UACME61":
model.VulnerabilityInfo(
description="UACME UAC bypass #61",
type=model.VulnerabilityType.LOCAL,
URL="https://github.com/hfiref0x/UACME",
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
outcome=model.AdminEscalation(),
rates=model.Rates(0, 0.2, 1.0))}
ENV_IDENTIFIERS = Identifiers(
properties=[],
ports=[],
local_vulnerabilities=["UACME61"],
remote_vulnerabilities=[]
)
# Verify that exactly one node is initially infected (i.e. has the attacker
# installed) in a randomly generated graph
def test_single_infected_node_initially() -> None:
# create a random environment
graph = nx.cubical_graph()
graph = model.assign_random_labels(graph)
env = model.Environment(network=graph,
vulnerability_library=dict([]),
identifiers=ENV_IDENTIFIERS)
count = sum(1 for i in graph.nodes
if env.get_node(i).agent_installed)
assert count == 1
# Ensure that an environment can successfully be serialized to YAML
def test_environment_is_serializable() -> None:
# create a random environment
env = model.Environment(
network=model.assign_random_labels(nx.cubical_graph()),
version=model.VERSION_TAG,
vulnerability_library=dict([]),
identifiers=ENV_IDENTIFIERS,
creationTime=datetime.utcnow(),
lastModified=datetime.utcnow(),
)
# Dump the environment as Yaml
_ = yaml.dump(env)
assert True
# Test random graph get_node_information
def test_create_random_environment() -> None:
graph = nx.cubical_graph()
graph = model.assign_random_labels(graph)
env = model.Environment(
network=graph,
vulnerability_library=vulnerabilities,
identifiers=ENV_IDENTIFIERS
)
assert env
def check_reserializing(object_to_serialize: object) -> None:
"""Helper function to check that deserializing and serializing are inverse of each other"""
serialized = yaml.dump(object_to_serialize)
# print('Serialized: ' + serialized)
deserialized = yaml.load(serialized, yaml.Loader)
re_serialized = yaml.dump(deserialized)
assert (serialized == re_serialized)
def test_yaml_serialization_environment() -> None:
"""Test Yaml serialization and deserialization for type Environment"""
network = model.assign_random_labels(nx.cubical_graph())
env = model.Environment(
network=network,
vulnerability_library=vulnerabilities,
identifiers=model.infer_constants_from_network(network, vulnerabilities))
model.setup_yaml_serializer()
serialized = yaml.dump(env)
assert (len(serialized) > 100)
check_reserializing(env)
def test_yaml_serialization_precondition() -> None:
"""Test Yaml serialization and deserialization for type Precondition"""
model.setup_yaml_serializer()
precondition = model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))")
check_reserializing(precondition)
deserialized = yaml.safe_load(yaml.dump(precondition))
assert (precondition.expression == deserialized.expression)
def test_yaml_serialization_vulnerabilitytype() -> None:
"""Test Yaml serialization and deserialization for type VulnerabilityType"""
model.setup_yaml_serializer()
object_to_serialize = model.VulnerabilityType.LOCAL
check_reserializing(object_to_serialize)

Binary files added (not displayed):

docs/.attachments/image-11d24066-875d-43ac-87cb-e91453688028.png (56 KiB)
docs/.attachments/image-377114ff-cdb7-4bee-88da-cac09640f661.png (45 KiB)
docs/.attachments/image-41f45aa6-0af8-4c67-afec-24b1adc910c2.png (67 KiB)
docs/.attachments/image-4fd79f98-36b2-45ae-82e4-2631aacda090.png (192 KiB)
docs/.attachments/image-54d83b7b-65d1-4d6a-b0f6-d41b31460c81.png (103 KiB)
docs/.attachments/image-8cfbbc68-6db1-42f2-867d-5502ff56c4b3.png (18 KiB)
docs/.attachments/image-97b85206-f37d-4798-acc9-12b347808202.png (32 KiB)
docs/.attachments/image-9b0b1506-880b-43e5-91b0-75f2ce3fb032.png (62 KiB)
docs/.attachments/image-9f950a75-2c63-457a-b109-56091f84711a.png (92 KiB)
docs/.attachments/image-cdb2b5e1-92f9-4a9e-af9f-b1a9bcae96a5.png (205 KiB)
docs/.attachments/image-daf58e3d-a4a8-4810-8a5e-1c976a24b266.png (57 KiB)
docs/.attachments/image-f8f00fe7-466f-4d2b-aaee-dd20720854db.png (81 KiB)

12
docs/.gitignore vendored Normal file
@@ -0,0 +1,12 @@
.packages
PackageRoot/*
**/TestResults/*
**/.vs/*
**/bin/*
**/obj/*
msbuild.log
*.user
*.trace
**/.nupkg/*
**/.pkg/*
**/packages/*

39
docs/benchmark.md Normal file
@@ -0,0 +1,39 @@
# Benchmark on baseline agent implementations
Results obtained on the environment `CyberBattleChain10` using several algorithms: purely random search, Q-learning with epsilon-greedy exploration, and exploiting the learned Q-matrix only.
## Time to full ownership on `chain` network environment
Training plot showing duration taken to take over the entire network (Y-axis) across successive episodes (X-axis) for different attacker agent implementations (Tabular Q-learning, epsilon-greedy, Deep Q-Learning, Exploiting Q-function learnt from DQL).
Lower number of iterations is better (the game is set to terminate when the attacker owns the entire network).
The best attacker so far is the Deep Q-learning agent.
![image.png](.attachments/image-54d83b7b-65d1-4d6a-b0f6-d41b31460c81.png)
The next plot shows the cumulative reward function. Note that once all the network nodes are owned there may still be additional rewards to be obtained by exploiting vulnerabilities on the owned nodes,
but in this experiment we terminate the game as soon as the agent owns all the nodes. This explains why
the DQL agent, which optimizes for network ownership, does not get to reach the maximum possible reward despite
beating all the other agents. The gym's `done` condition can easily be re-configured to target a specific reward instead (as sketched below), in which case the DQL agent also beats the other agents.
![image.png](.attachments/image-f8f00fe7-466f-4d2b-aaee-dd20720854db.png)
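For instance, here is a hedged sketch of how the chain environment can be configured to terminate on a target cumulative reward rather than full ownership (the `attacker_goal` parameter appears in the gym wrapper, but treat the exact field name as an assumption):

```python
import gym
import cyberbattle._env.cyberbattle_env as cyberbattle_env  # registers the gym environments

# Terminate the episode once the attacker accumulates 2000 reward points
# instead of when it owns the entire network (field name assumed)
gym_env = gym.make('CyberBattleChain-v0',
                   size=10,
                   attacker_goal=cyberbattle_env.AttackerGoal(reward=2000))
```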
### Choice of features
With Q-learning, the best results were obtained when training on features that are specific to the network size, such as the number of discovered nodes; good results are also obtained with features that do not depend on the network size (see the table below).
| | Size-agnostic features | Size-dependent features |
|--|--|--|
| Tabular Q vs Random | ![image.png](.attachments/image-9b0b1506-880b-43e5-91b0-75f2ce3fb032.png) | ![image.png](.attachments/image-41f45aa6-0af8-4c67-afec-24b1adc910c2.png) |
## Transfer Learning on `chain` environment
This benchmark aims to measure the ability to learn a strategy from one environment and
apply it to similar environments of a different size. We train the agent on an environment of size $x$ and evaluate it on an environment of size $y>x$.
As expected, using features that are proportional to the size of the environment (such as the number of nodes or the number of credentials) did not provide the best results. The agent fared better when using temporal features such as a sliding window of recently discovered ports and node properties (see the sketch after the table below).
| | Trained on size 4, evaluated on size 10 | Trained on size 10, evaluated on size 4 |
|---|---|---|
| Tabular Q vs Random | ![image.png](.attachments/image-11d24066-875d-43ac-87cb-e91453688028.png) | ![image.png](.attachments/image-daf58e3d-a4a8-4810-8a5e-1c976a24b266.png) |
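To illustrate what such a size-agnostic temporal feature could look like, here is a hedged sketch (the class and names below are hypothetical, not the repository's actual feature encoder): a multi-hot encoding over the port names discovered in the last few steps, whose length depends only on the number of known ports and therefore transfers across network sizes.

```python
from collections import deque
import numpy as np

# Known port identifiers (taken from the sample identifiers in this project)
PORTS = ['RDP', 'SSH', 'SMB', 'HTTP', 'HTTPS', 'WMI', 'SQL']

class SlidingWindowPortFeature:
    """Multi-hot encoding of the ports discovered in the last `window` steps."""

    def __init__(self, window: int = 5):
        self.recent = deque(maxlen=window)

    def observe(self, port: str) -> None:
        self.recent.append(port)

    def encode(self) -> np.ndarray:
        # 1.0 for each known port seen within the window, 0.0 otherwise
        return np.array([1.0 if p in self.recent else 0.0 for p in PORTS])

feature = SlidingWindowPortFeature(window=3)
feature.observe('SSH')
feature.observe('HTTP')
print(feature.encode())  # -> [0. 1. 0. 1. 0. 0. 0.]
```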

5
docs/gendoc.sh Normal file
@@ -0,0 +1,5 @@
pandoc -o docs.tex quickintro.md benchmark.md ../README.md -s
pdflatex docs.tex
# \DeclareUnicodeCharacter{2500}{─}
# \DeclareUnicodeCharacter{03F5}{ϵ}
# pandoc -o docs.html quickintro.md benchmark.md ../README.md

190
docs/quickintro.md Normal file
@@ -0,0 +1,190 @@
# Quick introduction to CyberBattleSim
## What is it?
A high-level parameterizable model of enterprise networks that simulates the execution of attacker and defender cyber-agents.
A network topology and a set of pre-defined vulnerabilities define the arena on which the simulation is played.
The attacker evolves in the network via lateral movements by exploiting existing vulnerabilities.
The defender attempts to contain the attacker and evict it from the network.
CyberBattleSim offers an OpenAI Gym interface to its simulation to facilitate experimentation with Reinforcement Learning algorithms.
## Why a simulation environment?
Emulation runtime environments provide high fidelity and control: you can take existing code or binaries and run them directly in virtual machines running full-fledged Operating Systems and connected over a virtualized network, while giving access to the full system state. This comes at a performance cost, however.
While simulated environments suffer from a lack of realism, they tend to be lightweight, fast, abstract, and more controllable, which
makes them more amenable to experimentation with Reinforcement Learning.
Advantages of simulation include:
- Higher-level abstraction: we can model the aspects of the system that matter to us, such as application-level network communication versus packet-level network simulation. We can ignore low-level details if deemed unnecessary (e.g., file system, registry).
- Flexibility: Defining new machine sensors is straightforward (e.g., does not require low-level/driver code changes); we can restrict the action space to a manageable and relevant subset.
- The global state is efficiently capturable, simplifying debugging and diagnosis.
- A lightweight runtime footprint: run in memory on single machine/process.
We may explore the use of emulation technologies in the future, in particular for benchmarking purposes, as it would provide a more realistic assessment of the agents' performance.
A design principle adopted is for the simulation to model just enough _complexity_ to represent attack techniques from the [MITRE matrix](https://attack.mitre.org/) while maintaining the _simplicity_ required to efficiently train an agent using Reinforcement Learning techniques.
On the attack side, the current simulation focuses in particular on the `Lateral movement` techniques, which are intrinsic to all post-breach attacks.
## How the simulation works
Let us go through a toy example and introduce how the simulation works
using RL terminology.
Our network __environment__ is given by a directed annotated graph where
nodes represent computers and edges represent knowledge of other nodes or communication taking place between nodes.
![image.png](.attachments/image-377114ff-cdb7-4bee-88da-cac09640f661.png)
Here you can see a toy example of a network
with machines running different OSes and software.
Each machine has properties, a value, and pre-assigned vulnerabilities.
Blue edges represent traffic running between nodes and
are labelled by the communication protocol.
![image.png](.attachments/image-9f950a75-2c63-457a-b109-56091f84711a.png)
There is a **single agent**: the attacker.
Initially, one node is infected (post-breach assumption).
Its **goal** is to maximize reward by discovering and 'owning' nodes
in the network.
The environment is **partially observable**: the agent does not get to see all the
nodes and edges of the network graph in advance.
Instead, the attacker takes actions to gradually observe the environment. There are **three kinds of actions**,
offering a mix of exploitation and exploration capabilities
to the agent (see the sketch after this list):
- perform a local attack,
- perform a remote attack,
- connect to other nodes.
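To make these concrete, here is a hedged sketch of how the gym wrapper encodes the three action kinds (the dictionary keys and index layouts below are assumptions; each index points into the environment's inferred identifier lists):

```python
import gym
import cyberbattle._env.cyberbattle_env as cyberbattle_env  # registers the gym environments

gym_env = gym.make('CyberBattleToyCtf-v0')

# [source node, local vulnerability]
local_attack = {'local_vulnerability': [0, 1]}
# [source node, target node, remote vulnerability]
remote_attack = {'remote_vulnerability': [0, 3, 1]}
# [source node, target node, port, credential index]
connect = {'connect': [0, 3, 2, 0]}

# A safe way to pick an action without guessing valid indices
observation, reward, done, info = gym_env.step(gym_env.sample_valid_action())
```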
The **reward** is a float representing the intrinsic value
of a node (e.g., a SQL server has greater value than a test machine).
The attacker breaches into the network from the Win7 node
on the left, pointed to by the bold orange arrow,
- then proceeds with a lateral move to the Win8 node
by exploiting a vulnerability in SMB,
- then uses a cached credential to log into a Win7 machine,
- exploits an IIS remote vulnerability to own the IIS server,
- and finally uses leaked connection strings to get to the SQL DB.
## What kind of vulnerabilities and attacks are supported by the simulation?
The simulation gym environment is parameterized by the network definition, which consists of the underlying network graph itself together with a description of the supported vulnerabilities and the nodes where they are present.
Since the simulation does not run any code, there is no way to actually implement vulnerabilities and exploits. Instead, we model each vulnerability abstractly by defining: a pre-condition determining whether the vulnerability is active on a given node; the probability that it can be successfully exploited by an attacker; and the side-effects of a successful exploit. Each node has a set of assigned named properties. The pre-condition is then expressed as a boolean expression over the set of possible node properties (or flags).
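For example, here is how a local privilege-escalation vulnerability is modelled in the project's unit tests, with a pre-condition requiring a Windows 10 node that has not already been escalated:

```python
import cyberbattle.simulation.model as m

ADMINTAG = m.AdminEscalation().tag
SYSTEMTAG = m.SystemEscalation().tag

uac_bypass = m.VulnerabilityInfo(
    description="UACME UAC bypass #61",
    type=m.VulnerabilityType.LOCAL,
    URL="https://github.com/hfiref0x/UACME",
    # Boolean pre-condition over node properties and escalation flags
    precondition=m.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
    # Side-effect of a successful exploit
    outcome=m.AdminEscalation(),
    rates=m.Rates(0, 0.2, 1.0))
```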
### Vulnerability outcomes
Each vulnerability has a pre-defined outcome which may include:
- A leaked set of credentials;
- A leaked reference to another node in the network;
- Leaked information about a node (node properties);
- Ownership of a node;
- Privilege escalation on the node.
Example of _remote_ vulnerabilities include:
- A SharePoint site exposing `ssh` credentials (but not necessarily the ID of the remote machine);
- An `ssh` vulnerability granting access to the machine;
- A github project leaking credentials in commit history;
- A SharePoint site with a file containing a SAS token to a storage account.
Examples of _local_ vulnerabilities:
- Extracting authentication token or credentials from a system cache;
- Escalating to SYSTEM privileges;
- Escalating to Administrator privileges.
Vulnerabilities can either be defined in place at the node level, or can be defined globally and activated by the pre-condition boolean expression, as in the sketch below.
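Concretely (a sketch reusing `uac_bypass` from above; `network` is any graph built with `m.create_network`), a globally defined vulnerability is registered in the environment's vulnerability library and becomes active on every node whose properties satisfy its pre-condition:

```python
env = m.Environment(
    network=network,
    # Global library: activated per-node by each pre-condition
    vulnerability_library={"UACME61": uac_bypass},
    identifiers=m.infer_constants_from_network(
        network, {"UACME61": uac_bypass}))
```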
## Toy Example
Consider, as a toy example, the 'Capture The Flag' game played on the computer system depicted below:
![image.png](.attachments/image-8cfbbc68-6db1-42f2-867d-5502ff56c4b3.png)
Each graph node is a computing resource with implanted security flaws and vulnerabilities such as reused or insecure passwords, leaked access tokens, misconfigured access control, browsable directories, and so on. The goal of the attacker is to take ownership of critical nodes in the graph (e.g., the Azure and SharePoint resources). For simplicity, we assume that no defender is present and that the game is fully static (no external events between two actions of the attacker).
We formally define this network in Python code at [toy_ctf.py](../cyberbattle/samples/toyctf/toy_ctf.py).
Here is a snippet of the code showing how we define the node `Website` with its properties, firewall configuration and implanted vulnerabilities:
```python
nodes = {
"Website": m.NodeInfo(
services=[
m.ListeningService("HTTPS"),
m.ListeningService("SSH", allowedCredentials=[
"ReusedMySqlCred-web"])],
firewall=m.FirewallConfiguration(
incoming=default_allow_rules,
outgoing=default_allow_rules
+ [
m.FirewallRule("su", m.RulePermission.ALLOW),
m.FirewallRule("sudo", m.RulePermission.ALLOW)]),
value=100,
properties=["MySql", "Ubuntu", "nginx/1.10.3"],
owned_string="FLAG: Login using insecure SSH user/password",
vulnerabilities=dict(
ScanPageContent=m.VulnerabilityInfo(
description="Website page content shows a link to GitHub repo",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedNodesId(["GitHubProject"]),
reward_string="page content has a link to a Github project",
cost=1.0
),
ScanPageSource=m.VulnerabilityInfo(
description="Website page source contains refrence to"
"browseable relative web directory",
type=m.VulnerabilityType.REMOTE,
outcome=m.LeakedNodesId(["Website.Directory"]),
reward_string="Viewing the web page source reveals a URL to"
"a .txt file and directory on the website",
cost=1.0
),
...
)
),
...
```
### Interactive play with ToyCTF
You can play the simulation interactively using the Jupyter notebook located at [toyctf-blank.ipynb](notebooks/toyctf-blank.ipynb). Try the following commands:
```python
env.plot_environment_graph()
plot()
c2.run_attack('client', 'SearchEdgeHistory')
plot()
c2.run_remote_attack('client', 'Website', 'ScanPageContent')
plot()
c2.run_remote_attack('client', 'Website', 'ScanPageSource')
plot()
c2.run_remote_attack('client', 'Website.Directory', 'NavigateWebDirectoryFurther')
plot()
c2.connect_and_infect('client', 'Website', 'SSH', 'ReusedMySqlCred-web')
plot()
```
The plot function displays the subset of the network explored so far. After a few attempts the explored network should look like this:
![image.png](.attachments/image-4fd79f98-36b2-45ae-82e4-2631aacda090.png)
### Solution
The fully solved game is provided in the notebook [toyctf-solved.ipynb](notebooks/toyctf-solved.ipynb).
### A random agent playing ToyCTF
The notebook [toyctf-random.ipynb](notebooks/toyctf-random.ipynb)
runs a random agent on the ToyCTF environment: at each step, the agent
picks an action at random from the action space.
Here is the result after a couple of thousand iterations:
![image.png](.attachments/image-cdb2b5e1-92f9-4a9e-af9f-b1a9bcae96a5.png)

36
getpythonpath.sh Normal file
@@ -0,0 +1,36 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Look for the required version of python, return its path in $PYTHON
# and version in $PYTHONVER.
#
# Usage: source this file to set the correct default version of Python in the PATH:
# source getpythonpath.sh
#
PYTHON=`which python3.8`
if [ -z "$PYTHON" ]; then
PYTHON=`which python3`
fi
if [ -z "$PYTHON" ]; then
PYTHON=`which python`
fi
if [ -z "$PYTHON" ]; then
echo "Could not located python interpreter: '$PYTHON'" >&2
exit -1
fi
PYTHONVER=`$PYTHON --version | cut -d' ' -f2`
if [[ ! "$PYTHONVER" == "3.8."* ]]; then
    echo 'Python version 3.8.x is required' >&2
    exit 1
else
echo "Compatible version $PYTHONVER of Python detected at $PYTHON"
fi
PYTHONPATH=$(dirname $PYTHON)
export PATH=$PYTHONPATH:$PATH

29
init.ps1 Normal file
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#.SYNOPSIS
# Initialize a dev environment by installing all python dependencies on Windows.
# No longer supported: on Windows, use WSL-based Linux instead.
param($installJupyterExtensions)
# Install pip
$pipversion = $(py -m pip --version)
if ($pipversion) {
Write-Host "pip already installed"
} else {
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
py.exe get-pip.py
}
# Install virtualenv
py -m pip install --user virtualenv
# Create a virtual environment
py.exe -m venv venv
# Install all pip dependencies in the virtual environment
& venv/Scripts/python.exe -m pip install -e .
& venv/Scripts/python.exe -m pip install -e .[dev]
# Setup pre-commit to check every git commit
& venv/Scripts/pre-commit.exe install -t pre-push

129
init.sh Executable file
@@ -0,0 +1,129 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
pushd "$(dirname "$0")"
UBUNTU_VERSION=$(lsb_release -rs)
verlte() {
[ "$1" = "`echo -e "$1\n$2" | sort -V | head -n1`" ]
}
verlte $UBUNTU_VERSION '20' && OLDER_UBUNTU=1 || OLDER_UBUNTU=0
if [ $OLDER_UBUNTU == 1 ]; then
echo "Old version of Ubuntu detected ($UBUNTU_VERSION), will register an additional apt repository to install latest version of Python"
ADD_PYTHON38_APTREPO=1
else
ADD_PYTHON38_APTREPO=0
fi
if [ ""$AML_CloudName"" != "" ]; then
echo "Running on AML machine: skipping venv creation by default"
CREATE_VENV=0
else
CREATE_VENV=1
fi
while getopts "nr" opt; do
case $opt in
n)
echo "skipping venv creation. Parameter: $OPTARG" >&2
CREATE_VENV=0
;;
r)
echo "Will add apt repo to install latest version of Python"
ADD_PYTHON38_APTREPO=1
;;
\?)
echo "Syntax: init.sh [-n]" >&2
echo " -n skip creation of virtual environment" >&2
echo " -r register required apt repository to install latest version of Python for older versions of Ubuntu (e.g. 16)" >&2
exit 1
;;
esac
done
SUDO=''
if (( $EUID != 0 )); then
SUDO='sudo -E'
fi
if [ ! -z "${VIRTUAL_ENV}" ]; then
echo 'Running under virtual environment, skipping installation of global packages';
else
if [ "${ADD_PYTHON38_APTREPO}" == "1" ]; then
echo 'Adding APT repo ppa:deadsnakes/ppa'
$SUDO apt install software-properties-common -y
$SUDO add-apt-repository ppa:deadsnakes/ppa -y
fi
$SUDO apt update
# install nodejs 12.0 (required by pyright typechecker)
curl -sL https://deb.nodesource.com/setup_12.x | $SUDO bash -
# install all apt-get dependencies
if [ $OLDER_UBUNTU == 1 ]; then
# exclude package not available on older ubuntu
cat apt-requirements.txt | grep -v python3-distutils | $SUDO xargs apt-get install -y
else
cat apt-requirements.txt | $SUDO xargs apt-get install -y
fi
# $SUDO npm -g upgrade node
$SUDO npm install -g --unsafe-perm=true --allow-root npm
# Make sure that the desired version of python is used
# in the rest of the script and when calling pyright to
# generate stubs
$SUDO update-alternatives --install /usr/bin/python python /usr/bin/python3.8 2
$SUDO update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 2
export PATH="/usr/bin:${PATH}"
# install pyright
./pyright.sh --version
# installing orca required to export images with plotly
./install-orca.sh
# install pip
if [ ! -f "/tmp/get-pip.py" ]; then
curl https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py
fi
python --version
python /tmp/get-pip.py
if [ "${CREATE_VENV}" == "1" ]; then
# Install virtualenv
python -m pip install --user virtualenv
# Create a virtual environment
python -m venv venv
source venv/bin/activate
fi
fi
./install-pythonpackages.sh
if [ "${CREATE_VENV}" == "1" ]; then
# Add venv to jupyter notebook
python -m ipykernel install --user --name=venv
fi
# setup checks on every `git push`
pre-commit install -t pre-push
./createstubs.sh
popd

16
install-orca.sh Executable file
@@ -0,0 +1,16 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
echo 'Installing the Plotly orca dependency for plotly figure export'
SUDO=''
if (( $EUID != 0 )); then
SUDO='sudo -E'
fi
xargs -a apt-requirements-orca.txt $SUDO apt-get install -y
$SUDO npm install -g --unsafe-perm=true --allow-root electron@6.1.4 orca

14
install-pythonpackages.sh Executable file
@@ -0,0 +1,14 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
set -e
. ./getpythonpath.sh
# Install python packages
$PYTHON -m pip install --upgrade pip
$PYTHON -m pip install wheel
$PYTHON -m pip install -e .
$PYTHON -m pip install -e .[dev]

File diff not shown because one or more lines are too long

File diff not shown because one or more lines are too long

@@ -0,0 +1,219 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.\n",
"\n",
"# Chain network CyberBattle Gym played by a random agent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gym random agent attacking a chain-like network\n",
"\n",
"## Chain network\n",
"We consider a computer network of Windows and Linux machines where each machine has vulnerability \n",
"granting access to another machine as per the following pattern:\n",
"\n",
" Start ---> (Linux ---> Windows ---> ... Linux ---> Windows)* ---> Linux[Flag]\n",
"\n",
"The network is parameterized by the length of the central Linux-Windows chain.\n",
"The start node leaks the credentials to connect to all other nodes:\n",
"\n",
"For each `XXX ---> Windows` section, the XXX node has:\n",
" - a local vulnerability exposing the RDP password to the Windows machine\n",
" - a bunch of other trap vulnerabilities (high cost with no outcome)\n",
"For each `XXX ---> Linux` section,\n",
" - the Windows node has a local vulnerability exposing the SSH password to the Linux machine\n",
" - a bunch of other trap vulnerabilities (high cost with no outcome)\n",
"\n",
"The chain is terminated by one node with a flag (reward)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark\n",
"The following plot shows the average and one standard deviation cumulative reward over time as a random agent attacks the network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%HTML\n",
"<img src=\"random_plot.png\" width=\"300\">"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import logging\n",
"import gym\n",
"from gym import spaces\n",
"import numpy as np\n",
"import networkx as nx\n",
"import cyberbattle.simulation.actions as actions\n",
"import cyberbattle._env.cyberbattle_env as cyberbattle_env\n",
"import cyberbattle.agents.random_agent as random_agent\n",
"import cyberbattle.samples.chainpattern.chainpattern as chainpattern\n",
"import importlib\n",
"importlib.reload(actions)\n",
"importlib.reload(cyberbattle_env)\n",
"importlib.reload(chainpattern)\n",
"\n",
"logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format=\"%(levelname)s: %(message)s\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# chainpattern.create_network_chain_link(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"gym_env = gym.make('CyberBattleChain-v0', size=10, attacker_goal=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym_env.environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"gym_env.environment.network.nodes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"gym_env.action_space"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym_env.action_space.sample()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"gym_env.observation_space.sample()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"outputPrepend"
]
},
"outputs": [],
"source": [
"for i in range(100) : gym_env.sample_valid_action()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false,
"tags": [
"outputPrepend"
]
},
"outputs": [],
"source": [
"random_agent.run_random_agent(1, 10000, gym_env)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"o,r,d,i = gym_env.step(gym_env.sample_valid_action())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"o"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Binary file added (not displayed): notebooks/random_plot.png (51 KiB)

15754
notebooks/randomnetwork.ipynb Normal file

File diff not shown because one or more lines are too long

@@ -0,0 +1,102 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.\n",
"\n",
"# Capture the Flag Toy Example - Interactive (Human player)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a blank instantiaion of the Capture The Flag network to be play interactively by a human player (not via the gym envrionment).\n",
"The interface exposed to the attacker is given by the following commands:\n",
" - c2.print_all_attacks()\n",
" - c2.run_attack(node, attack_id)\n",
" - c2.run_remote_attack(source_node, target_node, attack_id)\n",
" - c2.connect_and_infect(source_node, target_node, port_name, credential_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys, logging\n",
"import cyberbattle.simulation.model as model\n",
"import cyberbattle.simulation.commandcontrol as commandcontrol\n",
"import cyberbattle.samples.toyctf.toy_ctf as ctf\n",
"import plotly.offline as plo\n",
"plo.init_notebook_mode(connected=True)\n",
"logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=\"%(levelname)s: %(message)s\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"network = model.create_network(ctf.nodes)\n",
"env = model.Environment(network=network, vulnerability_library=dict([]), identifiers=ctf.ENV_IDENTIFIERS)\n",
"env"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.plot_environment_graph()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c2 = commandcontrol.CommandControl(env)\n",
"dbg = commandcontrol.EnvironmentDebugging(c2)\n",
"def plot():\n",
" dbg.plot_discovered_network()\n",
" c2.print_all_attacks()\n",
"plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2-final"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@@ -0,0 +1,162 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.\n",
"\n",
"# Random agent playing the Capture The Flag toy environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import logging\n",
"import gym\n",
"from gym import spaces\n",
"import numpy as np\n",
"import networkx as nx\n",
"import cyberbattle.simulation.actions as actions\n",
"import cyberbattle.simulation.commandcontrol as commandcontrol\n",
"import cyberbattle._env.cyberbattle_env as cyberbattle_env\n",
"import importlib\n",
"importlib.reload(actions)\n",
"importlib.reload(cyberbattle_env)\n",
"importlib.reload(commandcontrol)\n",
"\n",
"logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=\"%(levelname)s: %(message)s\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CyberBattle simulation\n",
"- **Environment**: a network of nodes with assigned vulnerabilities/functionalities, value, and firewall configuration\n",
"- **Action space**: local attack | remote attack | authenticated connection\n",
"- **Observation**: effects of action on environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"gym_env = gym.make('CyberBattleToyCtf-v0')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym_env.environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym_env.action_space"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym_env.action_space.sample()"
]
},
{
"cell_type": "markdown",
"metadata": {
"scrolled": false,
"tags": [
"outputPrepend"
]
},
"source": [
"## A random agent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false,
"tags": [
"outputPrepend"
]
},
"outputs": [],
"source": [
"for i_episode in range(1):\n",
" observation = gym_env.reset()\n",
"\n",
" total_reward = 0\n",
"\n",
" for t in range(5600):\n",
" action = gym_env.sample_valid_action()\n",
"\n",
" observation, reward, done, info = gym_env.step(action)\n",
" \n",
" total_reward += reward\n",
" \n",
" if reward>0:\n",
" print('####### rewarded action: {action}')\n",
" print(f'total_reward={total_reward} reward={reward}')\n",
" gym_env.render()\n",
" \n",
" if done:\n",
" print(\"Episode finished after {} timesteps\".format(t+1))\n",
" break\n",
"\n",
" gym_env.render()\n",
"\n",
"gym_env.close()\n",
"print(\"simulation ended\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End of simulation"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

21895
notebooks/toyctf-solved.ipynb Normal file

File diff not shown because one or more lines are too long

80
pyright.sh Executable file
@@ -0,0 +1,80 @@
#!/bin/bash
# https://github.com/microsoft/pyright/blob/master/docs/ci-integration.md
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
PATH_TO_PYRIGHT=`which pyright`
ARGS=$@
vercomp () {
if [[ $1 == $2 ]]
then
return 0
fi
local IFS=.
local i ver1=($1) ver2=($2)
# fill empty fields in ver1 with zeros
for ((i=${#ver1[@]}; i<${#ver2[@]}; i++))
do
ver1[i]=0
done
for ((i=0; i<${#ver1[@]}; i++))
do
if [[ -z ${ver2[i]} ]]
then
# fill empty fields in ver2 with zeros
ver2[i]=0
fi
if ((10#${ver1[i]} > 10#${ver2[i]}))
then
return 1
fi
if ((10#${ver1[i]} < 10#${ver2[i]}))
then
return 2
fi
done
return 0
}
# Node version check
echo "Checking node version..."
NODE_VERSION=`node -v | cut -d'v' -f2`
MIN_NODE_VERSION="10.15.2"
vercomp $MIN_NODE_VERSION $NODE_VERSION
# 1 == gt
if [[ $? -eq 1 ]]; then
echo "Node version ${NODE_VERSION} too old, min expected is ${MIN_NODE_VERSION}, run:"
echo " npm -g upgrade node"
exit -1
fi
# Do we need to sudo?
echo "Checking node_modules dir..."
NODE_MODULES=`npm -g root`
SUDO="sudo"
if [ -w "$NODE_MODULES" ]; then
SUDO="" #nop
fi
# If we can't find pyright, install it.
echo "Checking pyright exists..."
if [ -z "$PATH_TO_PYRIGHT" ]; then
echo "...installing pyright"
${SUDO} npm install -g pyright
else
# already installed, upgrade to make sure it's current
# this avoids a sudo on launch if we're already current
echo "Checking pyright version..."
CURRENT=`pyright --version | cut -d' ' -f2`
REMOTE=`npm info pyright version`
if [ "$CURRENT" != "$REMOTE" ]; then
echo "...new version of pyright found, upgrading."
${SUDO} npm upgrade -g pyright
fi
fi
echo "done."
pyright $ARGS

22
pyrightconfig.json Normal file
@@ -0,0 +1,22 @@
{
"exclude": [
"**/.ipynb_checpoints",
"**/__pycache__",
"typings",
"venv"
],
"ignore": [
"cyberbattle/__init__.py",
"cyberbattle/agents/__init__.py"
],
"reportMissingImports": true,
"reportMissingTypeStubs": true,
"enableTypeIgnoreComments": true,
"reportUnusedFunction": "warning",
"reportUnusedImport": "warning",
"reportDuplicateImport": "warning",
"reportUnnecessaryCast": "warning",
"reportUndefinedVariable": "error",
"strictParameterNoneValue": true,
"strict": []
}

35
requirements.dev.txt Normal file
@@ -0,0 +1,35 @@
# resolve incompatibility with pytest
attrs==19.1.0
# for vscode intellisense
jedi
pandas
jupyter>=1.0.0
torch===1.5.1
# linter and formatter
flake8
pep8-naming
autopep8
pre-commit
pytest~=5.4.2
jupytext~=1.6.0
psutil==5.7.2
### type stubs
numpy_stubs @ git+https://github.com/numpy/numpy-stubs/@master#egg=numpy_stubs
# Unfortunately, the official `data-science-types` package overwrites
# the numpy stubs above, and their stubs are not as reliable.
# So we use instead one of the project fork that disables the numpy stubs
# See https://github.com/predictive-analytics-lab/data-science-types/issues/132
# and https://github.com/microsoft/pyright/issues/861
#data-science-types~=0.2.20
data-science-types @ git+https://github.com/blumu/data-science-types/@blumu-patch-1
asciichart
asciichartpy

11
requirements.txt Normal file
@@ -0,0 +1,11 @@
gym~=0.17.3
numpy==1.19.4
boolean.py~=3.7
networkx==2.4
pyyaml~=5.3.1
setuptools~=49.2.1
matplotlib~=3.2.1
plotly~=4.11.0
tabulate~=0.8.7
ordered_set==4.0.2
progressbar2==3.51.4

55
scripts/run.py Normal file
@@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""A sample run of the CyberBattle simulation"""
import gym
import logging
import sys
def main() -> int:
'''Entry point if called as an executable'''
root = logging.getLogger()
root.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
root.addHandler(handler)
env = gym.make('CyberBattleToyCtf-v0')
logging.info(env.action_space.sample())
logging.info(env.observation_space.sample())
for i_episode in range(1):
observation = env.reset()
action_mask = env.compute_action_mask()
total_reward = 0
for t in range(500):
env.render()
# sample a valid action
action = env.action_space.sample()
while not env.apply_mask(action_mask, action):
action = env.action_space.sample()
            print('action: ' + str(action))
observation, reward, done, info = env.step(action)
action_mask = observation['action_mask']
total_reward = total_reward + reward
# print(observation)
print('total_reward=' + str(total_reward))
if done:
print("Episode finished after {} timesteps".format(t + 1))
break
env.close()
return 0
if __name__ == '__main__':
main()

6
setup.cfg Normal file
@@ -0,0 +1,6 @@
[flake8]
ignore = W504,W503,E501,N813,N812,E741
max-line-length = 200
max-doc-length = 200
exclude = typings, venv
per-file-ignores = ./cyberbattle/simulation/model.py:N815 ./cyberbattle/agents/baseline/agent_wrapper.py:N801

49
setup.py Normal file
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""setup CyberBattle simulator module"""
import os
import setuptools
from typing import List
pwd = os.path.dirname(__file__)
def get_install_requires(requirements_txt) -> List[str]:
"""get the list of requried packages"""
install_requires = []
with open(os.path.join(pwd, requirements_txt)) as file:
for line in file:
line = line.strip()
if line and not line.startswith('#'):
install_requires.append(line)
return install_requires
# main setup kw args
setup_kwargs = {
'name': 'loonshot-sim',
'version': '0.1',
'description': "The simulation and RL code for the S+C loonshot project",
'author': 'S+C Loonshot Team',
'author_email': 'scloonshot@Microsoft.com',
'install_requires': get_install_requires("requirements.txt"),
'classifiers': [
'Environment :: Other Environment',
'Intended Audience :: Science/Research',
'Natural Language :: English',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
'zip_safe': True,
'packages': setuptools.find_packages(exclude=['test_*.py', '*_test.py']),
'extras_require': {
'dev': get_install_requires('requirements.dev.txt')
}
}
if __name__ == '__main__':
setuptools.setup(**setup_kwargs)