Initial CyberBattleSim release
|
@ -0,0 +1,5 @@
|
|||
/venv/**
|
||||
/.cache/**
|
||||
**/*.pyc
|
||||
**/__pycache__
|
||||
**/log.txt
|
|
@ -0,0 +1,13 @@
|
|||
root=true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
|
||||
[*.{py,pyx}]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
charset = utf-8
|
||||
file_type_emacs = python
|
||||
trim_trailing_whitespace = true
|
||||
max_line_length = 120
|
|
@ -0,0 +1 @@
|
|||
* -text
|
|
@ -0,0 +1,8 @@
|
|||
[branch "master"]
|
||||
rebase = true
|
||||
[branch]
|
||||
autosetuprebase = always
|
||||
[push]
|
||||
default = simple
|
||||
[core]
|
||||
whitespace = cr-at-eol,-trailing-space
|
|
@ -0,0 +1,60 @@
|
|||
# python
|
||||
build/
|
||||
dist/
|
||||
__pycache__*
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info
|
||||
|
||||
# mypy type checking
|
||||
/.mypy_cache/
|
||||
/.pytest_cache/
|
||||
**/*.pyc
|
||||
|
||||
# jupyter / ipython
|
||||
.ipynb_checkpoints/
|
||||
.init
|
||||
Untitled*.ipynb
|
||||
data/
|
||||
|
||||
# tests
|
||||
.coverage
|
||||
.hypothesis
|
||||
|
||||
# doc
|
||||
doc/build/
|
||||
|
||||
# sublime
|
||||
*.sublime-workspace
|
||||
sftp*-config.json
|
||||
|
||||
# line profiler
|
||||
*.lprof
|
||||
|
||||
# visual studio
|
||||
.vs
|
||||
*.pyproj
|
||||
*.pyperf
|
||||
*.sln
|
||||
|
||||
# vscode
|
||||
.ropeproject
|
||||
|
||||
# pycharm
|
||||
.idea
|
||||
|
||||
# data
|
||||
*.zip
|
||||
/venv/**
|
||||
/.cache/**
|
||||
|
||||
/.dmypy.json
|
||||
src/.dmypy.json
|
||||
|
||||
notebooks/untracked/**
|
||||
cyberbattle.code-workspace
|
||||
typings/**
|
||||
|
||||
cyberbattle/agents/baseline/notebooks/images/*.png
|
||||
cyberbattle/agents/baseline/notebooks/images/*.gif
|
||||
log.txt
|
|
@ -0,0 +1,38 @@
|
|||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
|
||||
exclude: 'typings/.*'
|
||||
|
||||
repos:
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pyright
|
||||
name: pyright
|
||||
entry: pyright
|
||||
language: node
|
||||
pass_filenames: false
|
||||
types: [python]
|
||||
additional_dependencies: ['pyright@1.1.64']
|
||||
|
||||
- id: flake8
|
||||
name: flake8
|
||||
description: '`flake8` is a command-line utility for enforcing style consistency across Python projects.'
|
||||
entry: flake8
|
||||
language: python
|
||||
types: [python]
|
||||
require_serial: true
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-autopep8
|
||||
rev: v1.5.2 # Use the sha / tag you want to point at
|
||||
hooks:
|
||||
- id: autopep8
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v2.4.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
exclude: '.*\.ipynb$'
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
|
@ -0,0 +1,550 @@
|
|||
[MASTER]
|
||||
|
||||
# A comma-separated list of package or module names from where C extensions may
|
||||
# be loaded. Extensions are loading into the active Python interpreter and may
|
||||
# run arbitrary code
|
||||
extension-pkg-whitelist=
|
||||
|
||||
# Add files or directories to the blacklist. They should be base names, not
|
||||
# paths.
|
||||
ignore=CVS,.ipynb_checkpoints
|
||||
|
||||
# Add files or directories matching the regex patterns to the blacklist. The
|
||||
# regex matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Python code to execute, usually for sys.path manipulation such as
|
||||
# pygtk.require().
|
||||
#init-hook=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=1
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=yes
|
||||
|
||||
# Specify a configuration file.
|
||||
#rcfile=
|
||||
|
||||
# When enabled, pylint would attempt to guess common misconfiguration and emit
|
||||
# user-friendly hints instead of false-positive error messages
|
||||
suggestion-mode=yes
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=HIGH
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once).You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=print-statement,
|
||||
parameter-unpacking,
|
||||
unpacking-in-except,
|
||||
old-raise-syntax,
|
||||
backtick,
|
||||
long-suffix,
|
||||
old-ne-operator,
|
||||
old-octal-literal,
|
||||
import-star-module-level,
|
||||
non-ascii-bytes-literal,
|
||||
raw-checker-failed,
|
||||
bad-inline-option,
|
||||
locally-disabled,
|
||||
locally-enabled,
|
||||
file-ignored,
|
||||
suppressed-message,
|
||||
useless-suppression,
|
||||
deprecated-pragma,
|
||||
apply-builtin,
|
||||
basestring-builtin,
|
||||
buffer-builtin,
|
||||
cmp-builtin,
|
||||
coerce-builtin,
|
||||
execfile-builtin,
|
||||
file-builtin,
|
||||
long-builtin,
|
||||
raw_input-builtin,
|
||||
reduce-builtin,
|
||||
standarderror-builtin,
|
||||
unicode-builtin,
|
||||
xrange-builtin,
|
||||
coerce-method,
|
||||
delslice-method,
|
||||
getslice-method,
|
||||
setslice-method,
|
||||
no-absolute-import,
|
||||
old-division,
|
||||
dict-iter-method,
|
||||
dict-view-method,
|
||||
next-method-called,
|
||||
metaclass-assignment,
|
||||
indexing-exception,
|
||||
raising-string,
|
||||
reload-builtin,
|
||||
oct-method,
|
||||
hex-method,
|
||||
nonzero-method,
|
||||
cmp-method,
|
||||
input-builtin,
|
||||
round-builtin,
|
||||
intern-builtin,
|
||||
unichr-builtin,
|
||||
map-builtin-not-iterating,
|
||||
zip-builtin-not-iterating,
|
||||
range-builtin-not-iterating,
|
||||
filter-builtin-not-iterating,
|
||||
using-cmp-argument,
|
||||
eq-without-hash,
|
||||
div-method,
|
||||
idiv-method,
|
||||
rdiv-method,
|
||||
exception-message-attribute,
|
||||
invalid-str-codec,
|
||||
sys-max-int,
|
||||
bad-python3-import,
|
||||
deprecated-string-function,
|
||||
deprecated-str-translate-call,
|
||||
deprecated-itertools-function,
|
||||
deprecated-types-field,
|
||||
next-method-defined,
|
||||
dict-items-not-iterating,
|
||||
dict-keys-not-iterating,
|
||||
dict-values-not-iterating
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
enable=c-extension-no-member
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, json
|
||||
# and msvs (visual studio).You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Activate the evaluation score.
|
||||
score=yes
|
||||
|
||||
|
||||
[REFACTORING]
|
||||
|
||||
# Maximum number of nested blocks for function / method body
|
||||
max-nested-blocks=5
|
||||
|
||||
# Complete name of functions that never returns. When checking for
|
||||
# inconsistent-return-statements if a never returning function is called then
|
||||
# it will be considered as an explicit return statement and no message will be
|
||||
# printed.
|
||||
never-returning-functions=optparse.Values,sys.exit
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Naming style matching correct argument names
|
||||
argument-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct argument names. Overrides argument-
|
||||
# naming-style
|
||||
#argument-rgx=
|
||||
|
||||
# Naming style matching correct attribute names
|
||||
attr-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct attribute names. Overrides attr-naming-
|
||||
# style
|
||||
#attr-rgx=
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=foo,
|
||||
bar,
|
||||
baz,
|
||||
toto,
|
||||
tutu,
|
||||
tata
|
||||
|
||||
# Naming style matching correct class attribute names
|
||||
class-attribute-naming-style=any
|
||||
|
||||
# Regular expression matching correct class attribute names. Overrides class-
|
||||
# attribute-naming-style
|
||||
#class-attribute-rgx=
|
||||
|
||||
# Naming style matching correct class names
|
||||
class-naming-style=PascalCase
|
||||
|
||||
# Regular expression matching correct class names. Overrides class-naming-style
|
||||
#class-rgx=
|
||||
|
||||
# Naming style matching correct constant names
|
||||
const-naming-style=UPPER_CASE
|
||||
|
||||
# Regular expression matching correct constant names. Overrides const-naming-
|
||||
# style
|
||||
#const-rgx=
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=-1
|
||||
|
||||
# Naming style matching correct function names
|
||||
function-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct function names. Overrides function-
|
||||
# naming-style
|
||||
#function-rgx=
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
good-names=i,
|
||||
j,
|
||||
k,
|
||||
ex,
|
||||
Run,
|
||||
t,
|
||||
u,
|
||||
v,
|
||||
n,
|
||||
x,
|
||||
y,
|
||||
o,
|
||||
h,
|
||||
r,
|
||||
c,
|
||||
Q,
|
||||
_
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# Naming style matching correct inline iteration names
|
||||
inlinevar-naming-style=any
|
||||
|
||||
# Regular expression matching correct inline iteration names. Overrides
|
||||
# inlinevar-naming-style
|
||||
#inlinevar-rgx=
|
||||
|
||||
# Naming style matching correct method names
|
||||
method-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct method names. Overrides method-naming-
|
||||
# style
|
||||
#method-rgx=
|
||||
|
||||
# Naming style matching correct module names
|
||||
module-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct module names. Overrides module-naming-
|
||||
# style
|
||||
#module-rgx=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=^_
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
property-classes=abc.abstractproperty
|
||||
|
||||
# Naming style matching correct variable names
|
||||
variable-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct variable names. Overrides variable-
|
||||
# naming-style
|
||||
#variable-rgx=
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
||||
# tab).
|
||||
indent-string=' '
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=0
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=1000
|
||||
|
||||
# List of optional constructs for which whitespace checking is disabled. `dict-
|
||||
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
|
||||
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
|
||||
# `empty-line` allows space-only lines.
|
||||
no-space-check=trailing-comma,
|
||||
dict-separator
|
||||
|
||||
# Allow the body of a class to be on the same line as the declaration if body
|
||||
# contains single statement.
|
||||
single-line-class-stmt=no
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=no
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=FIXME,
|
||||
XXX,
|
||||
TODO
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Limits count of emitted suggestions for spelling mistakes
|
||||
max-spelling-suggestions=4
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||
# install python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# This flag controls whether pylint should warn about no-member and similar
|
||||
# checks whenever an opaque object is returned when inferring. The inference
|
||||
# can return multiple potential results while evaluating a Python object, but
|
||||
# some branches might not be evaluated, which results in partial inference. In
|
||||
# that case, it might be useful to still emit no-member and other checks for
|
||||
# the rest of the inferred objects.
|
||||
ignore-on-opaque-inference=yes
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# Show a hint with possible names when a member name was not found. The aspect
|
||||
# of finding the hint is based on edit distance.
|
||||
missing-member-hint=yes
|
||||
|
||||
# The minimum edit distance a name should have in order to be considered a
|
||||
# similar match for a missing member name.
|
||||
missing-member-hint-distance=1
|
||||
|
||||
# The total number of similar names that should be taken in consideration when
|
||||
# showing a hint for a missing member.
|
||||
missing-member-max-choices=1
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid to define new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# Tells whether unused global variables should be treated as a violation.
|
||||
allow-global-unused-variables=yes
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,
|
||||
_cb
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
|
||||
|
||||
# Argument names that match this expression will be ignored. Default to name
|
||||
# with leading underscore
|
||||
ignored-argument-names=_.*|^ignored_|^unused_
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six.moves,past.builtins,future.builtins
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,
|
||||
_fields,
|
||||
_replace,
|
||||
_source,
|
||||
_make
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
|
||||
[DESIGN]
|
||||
|
||||
# Maximum number of arguments for function / method
|
||||
max-args=5
|
||||
|
||||
# Maximum number of attributes for a class (see R0902).
|
||||
max-attributes=7
|
||||
|
||||
# Maximum number of boolean expressions in a if statement
|
||||
max-bool-expr=5
|
||||
|
||||
# Maximum number of branch for function / method body
|
||||
max-branches=12
|
||||
|
||||
# Maximum number of locals for function / method body
|
||||
max-locals=15
|
||||
|
||||
# Maximum number of parents for a class (see R0901).
|
||||
max-parents=7
|
||||
|
||||
# Maximum number of public methods for a class (see R0904).
|
||||
max-public-methods=20
|
||||
|
||||
# Maximum number of return / yield for function / method body
|
||||
max-returns=6
|
||||
|
||||
# Maximum number of statements in function / method body
|
||||
max-statements=50
|
||||
|
||||
# Minimum number of public methods for a class (see R0903).
|
||||
min-public-methods=2
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Allow wildcard imports from modules that define __all__.
|
||||
allow-wildcard-with-all=no
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=optparse,tkinter.tix
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "Exception"
|
||||
overgeneral-exceptions=Exception
|
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: cyberbattle_gym",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "src/cyberbattle_gym/samples/run.py",
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Python: Current File",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
{
|
||||
"python.pythonPath": "/usr/bin/python",
|
||||
"python.venvPath": "venv",
|
||||
"python.linting.pylintEnabled": false,
|
||||
"python.linting.mypyEnabled": false,
|
||||
"python.linting.flake8Enabled": true,
|
||||
"python.linting.enabled": true,
|
||||
"python.linting.mypyPath": "venv\\Scripts\\mypy.exe",
|
||||
"python.linting.mypyArgs": [
|
||||
"--config-file",
|
||||
"setup.cfg"
|
||||
],
|
||||
"python.testing.pytestArgs": [
|
||||
"cyberbattle"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.nosetestsEnabled": false,
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.linting.lintOnSave": true,
|
||||
"python.formatting.provider": "autopep8",
|
||||
"editor.formatOnPaste": true,
|
||||
"editor.formatOnSave": true,
|
||||
"extensions.ignoreRecommendations": true,
|
||||
"files.autoSave": "off",
|
||||
"debug.allowBreakpointsEverywhere": true,
|
||||
"cSpell.enableFiletypes": [
|
||||
"!yaml",
|
||||
"!python",
|
||||
"!json"
|
||||
],
|
||||
"python.languageServer": "Pylance",
|
||||
"python.analysis.typeCheckingMode": "basic",
|
||||
"search.useGlobalIgnoreFiles": true,
|
||||
"files.watcherExclude": {
|
||||
"typings/**": true,
|
||||
"venv/**": true
|
||||
},
|
||||
"files.exclude": {
|
||||
"venv/": true
|
||||
},
|
||||
"jupyter.jupyterServerType": "local",
|
||||
"files.trimFinalNewlines": true
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
# Microsoft Open Source Code of Conduct
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
|
||||
Resources:
|
||||
|
||||
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
||||
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
FROM mcr.microsoft.com/azureml/onnxruntime:latest-cuda
|
||||
WORKDIR /root
|
||||
ADD *.sh ./
|
||||
ADD *.txt ./
|
||||
ADD *.py ./
|
||||
RUN export TERM=dumb && ./init.sh -n
|
||||
# Override conda python 3.7 install
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.8 2
|
||||
ENV PATH="/usr/bin:${PATH}"
|
||||
|
||||
COPY . .
|
||||
|
||||
# To build the docker image:
|
||||
# docker build -t cyberbattle:1.1 .
|
||||
#
|
||||
# To run
|
||||
# docker run -it --rm cyberbattle:1.1 bash
|
||||
#
|
||||
# Pushing to private repository
|
||||
# docker login -u spinshot-team-token-writer --password-stdin spinshot.azurecr.io
|
||||
# docker tag cyberbattle:1.1 spinshot.azurecr.io/cyberbattle:1.1
|
||||
# docker push spinshot.azurecr.io/cyberbattle:1.1
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
|
@ -0,0 +1,246 @@
|
|||
# CyberBattleSim
|
||||
|
||||
|
||||
CyberBattleSim is an experimentation research platform to investigate the interaction
|
||||
of automated agents operating in a simulated abstract enterprise network environment.
|
||||
The simulation provides a high-level abstraction of computer networks
|
||||
and cyber security concepts.
|
||||
Its Python-based OpenAI Gym interface allows for training of
|
||||
automated agents using reinforcement learning algorithms.
|
||||
|
||||
The simulation environment is parameterized by a fixed network topology
|
||||
with an associated set of vulnerabilities that attacker agents can utilize
|
||||
to move laterally in the network.
|
||||
The goal of the attacker is to take ownership of a portion of the network by exploiting
|
||||
vulnerabilities that are planted in the computer nodes.
|
||||
While the attacker attempts to spread throughout the network,
|
||||
a defender agent watches the network activity and tries to contain the attack.
|
||||
|
||||
We provide a basic stochastic defender that detects
|
||||
and mitigate ongoing attacks based on pre-defined probabilities of success.
|
||||
Mitigation consists in re-imaging the infected nodes, a process
|
||||
abstractly modeled as an operation spanning over multiple simulation steps.
|
||||
|
||||
To compare performance of the agents we look at two metrics: the number of simulation steps taken to
|
||||
attain their goal and the cumulative rewards over simulation steps across training epochs.
|
||||
|
||||
## Project goals
|
||||
|
||||
We view this project as an experimentation platform to conduct research on the interaction of automated agents in abstract simulated network environments. By open sourcing it we hope to encourage the research community investigate how cyber-agents interact and evolve in such network environments.
|
||||
|
||||
The simulation we provide is admittedly simplistic, but this has advantages. Its highly abstract nature prohibits direct application to real-world systems thus providing a safeguard against potential nefarious use of automated agents trained with it.
|
||||
At the same time, its simplicity allows us to focus on specific security aspects we aim to study and quickly experiment with recent machine learning and AI algorithms.
|
||||
|
||||
For instance, the current implementation focuses on
|
||||
the lateral movement cyber-attacks techniques, with the hope to understand how the network topology and configuration affects them. With this goal in mind, we felt that modeling actual network traffic was not necessary. This is just one example of a significant limitation in our system that future contributions might want to address.
|
||||
|
||||
On the algorithmic side, we provide some basic agents as starting points but we
|
||||
would be curious to find out how state-of-the art reinforcement learning algorithms compare to them. We found that the large action space
|
||||
intrinsic to any computer system is a particular challenge for
|
||||
Reinforcement Learning, in contrasts to other applications such as video games or robot control. Training agents that can store and retrieve credentials is another challenge we faced when applying RL techniques
|
||||
where agents typically do not feature internal memory. These are other areas of research where the provided simulation can perhaps be used for benchmarking purpose.
|
||||
|
||||
Other areas of interests include the responsible and ethical use of autonomous cyber-security systems:
|
||||
How to design an enterprise network that gives an intrinsic advantage to defender agents?
|
||||
How to conduct safe research aimed at defending enterprises against autonomous cyber-attacks
|
||||
while preventing nefarious use of such technology?
|
||||
|
||||
## Documentation
|
||||
|
||||
Read the [Quick introduction](/docs/quickintro.md) to the project.
|
||||
|
||||
## Build status
|
||||
|
||||
| Type | Branch | Status |
|
||||
| --- | ------ | ------ |
|
||||
| CI | master | [![Build Status](https://mscodehub.visualstudio.com/Asterope/_apis/build/status/CyberBattle-ContinuousIntegration?branchName=master)](https://mscodehub.visualstudio.com/Asterope/_build/latest?definitionId=1359&branchName=master) |
|
||||
| Docker image | master | [![Build Status](https://mscodehub.visualstudio.com/Asterope/_apis/build/status/CyberBattle-Docker?branchName=master)](https://mscodehub.visualstudio.com/Asterope/_build/latest?definitionId=1454&branchName=master) |
|
||||
|
||||
|
||||
|
||||
|
||||
## Benchmark
|
||||
|
||||
See [Benchmark](/docs/benchmark.md).
|
||||
|
||||
|
||||
## Setting up a dev environment
|
||||
|
||||
It is strongly recommended to work under a Linux environment, either directly or via WSL on Windows.
|
||||
Running Python on Windows directly should work but is not supported anymore.
|
||||
|
||||
Start by checking out the repository:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/microsoft/CyberBattleSim.git
|
||||
```
|
||||
|
||||
### On Linux or WSL
|
||||
|
||||
The instructions were tested on a Linux Ubuntu distribution (both native and via WSL). Run the following command to set-up your dev environment and install all the required dependencies (apt and pip packages):
|
||||
|
||||
|
||||
```bash
|
||||
./init.sh
|
||||
```
|
||||
|
||||
The script installs python3.8 if not present. If you are running a version of Ubuntu older than 20 it will automatically add an additional apt repository to install python3.8.
|
||||
|
||||
The script will create a [virtual Python environment](https://docs.python.org/3/library/venv.html) under a `venv` subdirectory, you can then
|
||||
run Python with `venv/bin/python`.
|
||||
|
||||
> Note: If you prefer Python from a global installation instead of a virtual environment then you can skip the creation of the virtual envrionment by running the script with `./init.sh -n`. This will instead install all the Python packages on a system-wide installation of Python 3.8.
|
||||
|
||||
#### Windows Subsystem for Linux
|
||||
|
||||
The supported dev environment on Windows is via WSL.
|
||||
You first need to install an Ubuntu WSL distribution on your Windows machine,
|
||||
and then proceed with the Linux instructions (next section).
|
||||
|
||||
#### Git authentication from WSL
|
||||
|
||||
To authenticate with Git you can either use SSH-based authentication, or
|
||||
alternatively use the credential-helper trick to automatically generate a
|
||||
PAT token. The latter can be done by running the following commmand under WSL
|
||||
([more info here](https://docs.microsoft.com/en-us/windows/wsl/tutorials/wsl-git)):
|
||||
|
||||
```ps
|
||||
git config --global credential.helper "/mnt/c/Program\ Files/Git/mingw64/libexec/git-core/git-credential-manager.exe"
|
||||
```
|
||||
|
||||
#### Docker on WSL
|
||||
|
||||
To run your environment within a docker container, we recommend running `docker` via Windows Subsystem on Linux (WSL) using the following instructions:
|
||||
[Installing Docker on Windows under WSL](https://docs.docker.com/docker-for-windows/wsl-tech-preview/)).
|
||||
|
||||
|
||||
### Windows (unsupported)
|
||||
|
||||
This method is not supported anymore, please prefer instead running under
|
||||
a WSL subsystem Linux environment.
|
||||
But if you insist you want to start by installing [Python 3.8](https://www.python.org/downloads/windows/) then in a Powershell prompt run the `./init.ps1` script.
|
||||
|
||||
|
||||
|
||||
## Getting started quickly using Docker (internal only at this stage)
|
||||
|
||||
> NOTE: We do not currently redistribute build artifacts or Docker containers externally for this project. We provide the Dockerfile and CI yaml files if you need to recreate those artifacts.
|
||||
|
||||
The quickest way to get up and running is to use the Docker container.
|
||||
Note: you first need to request access to the Docker registry `spinshot.azurecr.io`. (Not publicly available.)
|
||||
|
||||
```bash
|
||||
docker login spinshot.azurecr.io
|
||||
docker pull spinshot.azurecr.io/cyberbattle:157884
|
||||
docker run -it spinshot.azurecr.io/cyberbattle:157884 cyberbattle/agents/baseline/run.py
|
||||
```
|
||||
|
||||
## Check your environment
|
||||
|
||||
Run the following command to run a simulation with a baseline RL agent:
|
||||
|
||||
```
|
||||
python cyberbattle/agents/baseline/run.py --training_episode_count 1 --eval_episode_count 1 --iteration_count 10 --rewardplot_with 80 --chain_size=20 --ownership_goal 1.0
|
||||
```
|
||||
|
||||
If everything is setup correctly you should get an output that looks like this:
|
||||
|
||||
```bash
|
||||
torch cuda available=True
|
||||
###### DQL
|
||||
Learning with: episode_count=1,iteration_count=10,ϵ=0.9,ϵ_min=0.1, ϵ_expdecay=5000,γ=0.015, lr=0.01, replaymemory=10000,
|
||||
batch=512, target_update=10
|
||||
## Episode: 1/1 'DQL' ϵ=0.9000, γ=0.015, lr=0.01, replaymemory=10000,
|
||||
batch=512, target_update=10
|
||||
Episode 1|Iteration 10|reward: 139.0|Elapsed Time: 0:00:00|###################################################################|
|
||||
###### Random search
|
||||
Learning with: episode_count=1,iteration_count=10,ϵ=1.0,ϵ_min=0.0,
|
||||
## Episode: 1/1 'Random search' ϵ=1.0000,
|
||||
Episode 1|Iteration 10|reward: 194.0|Elapsed Time: 0:00:00|###################################################################|
|
||||
simulation ended
|
||||
Episode duration -- DQN=Red, Random=Green
|
||||
10.00 ┼
|
||||
Cumulative rewards -- DQN=Red, Random=Green
|
||||
194.00 ┼ ╭──╴
|
||||
174.60 ┤ │
|
||||
155.20 ┤╭─────╯
|
||||
135.80 ┤│ ╭──╴
|
||||
116.40 ┤│ │
|
||||
97.00 ┤│ ╭╯
|
||||
77.60 ┤│ │
|
||||
58.20 ┤╯ ╭──╯
|
||||
38.80 ┤ │
|
||||
19.40 ┤ │
|
||||
0.00 ┼──╯
|
||||
```
|
||||
|
||||
## Jupyter notebooks
|
||||
|
||||
To quickly get familiar with the project you can open one the
|
||||
the provided Juptyer notebooks to play interactively with
|
||||
the gym environments. Just start jupyter with `jupyter notebook`, or
|
||||
`venv/bin/jupyter notebook` if you are using a virtual environment setup.
|
||||
|
||||
- Capture The Flag Toy environment notebooks:
|
||||
- [Random agent](notebooks/toyctf-random.ipynb)
|
||||
- [Interactive session for a human player](notebooks/toyctf-blank.ipynb)
|
||||
- [Interactive session - fully solved](notebooks/toyctf-solved.ipynb)
|
||||
|
||||
- Chain environment notebooks:
|
||||
- [Random agent](notebooks/chainnetwork-random.ipynb)
|
||||
|
||||
- Other environments:
|
||||
- [Interactive session with a randomly generated environment](notebooks/randomnetwork.ipynb)
|
||||
- [Random agent playing on randomly generated networks](notebooks/c2_interactive_interface.ipynb)
|
||||
|
||||
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
|
||||
### Ideas for contributions
|
||||
|
||||
Here are some ideas on how to contribute: enhance the simulation (event-based, refined the simulation, …), train an RL algorithm on the existing simulation,
|
||||
implement benchmark to evaluate and compare novelty of agents, add more network generative modes to train RL-agent on, contribute to the doc, fix bugs.
|
||||
|
||||
See also the [wiki for more ideas](https://github.com/microsoft/CyberBattleGym/wiki/Possible-contributions).
|
||||
|
||||
## Citing this project
|
||||
|
||||
```bibtex
|
||||
@misc{msft:cyberbattlesim,
|
||||
Author = {Brandon Marken and Christian Seifert and Emily Goren and Haoran Wei and James Bono and Joshua Neil and Jugal Parikh and Justin Grana and Kate Farris and Kristian Holsheimer and Michael Betser and Nicole Nichols and William Blum},
|
||||
Publisher = {GitHub},
|
||||
Howpublished = {\url{https://github.com/microsoft/cyberbattlesim}},
|
||||
Title = {CyberBattleSim},
|
||||
Year = {2021}
|
||||
}
|
||||
```
|
||||
|
||||
## Note on privacy
|
||||
This project does not include any customer data.
|
||||
The provided models and network topologies are purely fictitious.
|
||||
Users of the provided code provide all the input to the simulation
|
||||
and must have the necessary permissions to use any provided data.
|
||||
|
||||
|
||||
## Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
|
@ -0,0 +1,41 @@
|
|||
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
|
||||
|
||||
## Security
|
||||
|
||||
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
||||
|
||||
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
|
||||
|
||||
## Reporting Security Issues
|
||||
|
||||
**Please do not report security vulnerabilities through public GitHub issues.**
|
||||
|
||||
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
|
||||
|
||||
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
|
||||
|
||||
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
||||
|
||||
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
||||
|
||||
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
||||
* Full paths of source file(s) related to the manifestation of the issue
|
||||
* The location of the affected source code (tag/branch/commit or direct URL)
|
||||
* Any special configuration required to reproduce the issue
|
||||
* Step-by-step instructions to reproduce the issue
|
||||
* Proof-of-concept or exploit code (if possible)
|
||||
* Impact of the issue, including how an attacker might exploit the issue
|
||||
|
||||
This information will help us triage your report more quickly.
|
||||
|
||||
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
|
||||
|
||||
## Preferred Languages
|
||||
|
||||
We prefer all communications to be in English.
|
||||
|
||||
## Policy
|
||||
|
||||
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
|
||||
|
||||
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
|
@ -0,0 +1,11 @@
|
|||
# Support
|
||||
|
||||
## How to file issues and get help
|
||||
|
||||
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
||||
issues before filing new issues to avoid duplicates. For new issues, file your bug or
|
||||
feature request as a new Issue.
|
||||
|
||||
## Microsoft Support Policy
|
||||
|
||||
Support for **CyberBattle** is limited to the resources listed above.
|
|
@ -0,0 +1,39 @@
|
|||
# Docker
|
||||
# Build and push an image to Azure Container Registry
|
||||
# https://docs.microsoft.com/azure/devops/pipelines/languages/docker
|
||||
|
||||
trigger:
|
||||
- master
|
||||
|
||||
resources:
|
||||
- repo: self
|
||||
|
||||
variables:
|
||||
# Container registry service connection established during pipeline creation
|
||||
dockerRegistryServiceConnection: 'ac2df822-b01f-4588-a9bd-8195602c3995'
|
||||
imageRepository: 'cyberbattle'
|
||||
containerRegistry: 'spinshot.azurecr.io'
|
||||
dockerfilePath: '$(Build.SourcesDirectory)/Dockerfile'
|
||||
tag: '$(Build.BuildId)'
|
||||
|
||||
# Agent VM image name
|
||||
vmImageName: 'ubuntu-latest'
|
||||
|
||||
stages:
|
||||
- stage: Build
|
||||
displayName: Build and push stage
|
||||
jobs:
|
||||
- job: Build
|
||||
displayName: Build
|
||||
pool:
|
||||
vmImage: $(vmImageName)
|
||||
steps:
|
||||
- task: Docker@2
|
||||
displayName: Build and push an image to container registry
|
||||
inputs:
|
||||
command: buildAndPush
|
||||
repository: $(imageRepository)
|
||||
dockerfile: $(dockerfilePath)
|
||||
containerRegistry: $(dockerRegistryServiceConnection)
|
||||
tags: |
|
||||
$(tag)
|
|
@ -0,0 +1,79 @@
|
|||
# Python package
|
||||
# Create and test a Python package on multiple Python versions.
|
||||
# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more:
|
||||
# https://docs.microsoft.com/azure/devops/pipelines/languages/python
|
||||
#
|
||||
# Adapted from the template available here https://github.com/microsoft/azure-pipelines-yaml/blob/master/templates/python-package.yml
|
||||
trigger:
|
||||
- master
|
||||
|
||||
pool:
|
||||
vmImage: "ubuntu-latest"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
Python38:
|
||||
python.version: "3.8"
|
||||
|
||||
steps:
|
||||
- checkout: self
|
||||
submodules: false # does not work event though it's supposed to (authentication issue)
|
||||
persistCredentials: false
|
||||
|
||||
- task: UsePythonVersion@0
|
||||
name: pythonver
|
||||
inputs:
|
||||
versionSpec: "$(python.version)"
|
||||
displayName: "Use Python $(python.version)"
|
||||
|
||||
- task: NodeTool@0
|
||||
inputs:
|
||||
versionSpec: '12.x'
|
||||
#checkLatest: false # Optional
|
||||
displayName: "Use node tools"
|
||||
|
||||
- script: |
|
||||
cat apt-requirements.txt | xargs sudo apt install
|
||||
displayName: "Install apt dependencies"
|
||||
|
||||
- script: |
|
||||
python -m pip install flake8
|
||||
flake8 --benchmark
|
||||
displayName: "Lint with flake8"
|
||||
|
||||
- task: Cache@2
|
||||
displayName: "Pull pip packages from cache"
|
||||
inputs:
|
||||
key: 'pip | "$(Agent.OS)" | requirements.txt | requirements.dev.txt | setup.py'
|
||||
restoreKeys: |
|
||||
pip | "$(Agent.OS)"
|
||||
path: $(pythonver.pythonLocation)/lib/python3.8/site-packages
|
||||
|
||||
- script: |
|
||||
./install-pythonpackages.sh --nocoax
|
||||
displayName: "Pull pip dependencies"
|
||||
|
||||
- script: |
|
||||
npm install -g pyright
|
||||
displayName: "Install pyright"
|
||||
|
||||
- task: Cache@2
|
||||
displayName: "Pull typing stubs from cache"
|
||||
inputs:
|
||||
key: 'typingstubs | "$(Agent.OS)" | createstubs.sh'
|
||||
restoreKeys: |
|
||||
typingstubs | "$(Agent.OS)" | createstubs.sh
|
||||
path: typings/
|
||||
|
||||
- script: |
|
||||
./createstubs.sh
|
||||
displayName: "create type stubs"
|
||||
|
||||
- script: |
|
||||
./pyright.sh
|
||||
displayName: "Typecheck with pyright"
|
||||
|
||||
- script: |
|
||||
pip install pytest-azurepipelines
|
||||
python -m pytest -v cyberbattle
|
||||
displayName: "Test with pytest"
|
|
@ -0,0 +1,5 @@
|
|||
libgtk-3-0
|
||||
xvfb
|
||||
chromium-browser
|
||||
libgconf-2-4
|
||||
npm
|
|
@ -0,0 +1,16 @@
|
|||
software-properties-common
|
||||
python3.8
|
||||
python3.8-venv
|
||||
python3-distutils
|
||||
python3-pytest
|
||||
python3-opengl
|
||||
python3-dev
|
||||
python3.8-dev
|
||||
build-essential
|
||||
nodejs
|
||||
curl
|
||||
dirmngr
|
||||
apt-transport-https
|
||||
lsb-release
|
||||
ca-certificates
|
||||
python-distutils-extra
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
set -e
|
||||
|
||||
. ./getpythonpath.sh
|
||||
|
||||
echo "$(tput setaf 2)Creating type stubs$(tput sgr0)"
|
||||
createstub() {
|
||||
local name=$1
|
||||
if [ ! -d "typings/$name" ]; then
|
||||
pyright --createstub $name
|
||||
else
|
||||
echo stub $name already created
|
||||
fi
|
||||
}
|
||||
|
||||
createstub networkx
|
||||
createstub pandas
|
||||
createstub plotly
|
||||
createstub progressbar
|
||||
createstub pytest
|
||||
createstub setuptools
|
||||
createstub ordered_set
|
||||
createstub asciichartpy
|
||||
|
||||
if [ ! -d "typings/gym" ]; then
|
||||
pyright --createstub gym
|
||||
# Patch gym stubs
|
||||
echo ' spaces = ...' >> typings/gym/spaces/dict.pyi
|
||||
echo ' nvec = ...' >> typings/gym/spaces/space.pyi
|
||||
else
|
||||
echo stub gym already created
|
||||
fi
|
||||
|
||||
if [ ! -d "typings/IPython" ]; then
|
||||
pyright --createstub IPython.core.display
|
||||
else
|
||||
echo stub 'IPython' already created
|
||||
fi
|
||||
|
||||
if [ ! -d "boolean" ]; then
|
||||
pyright --createstub boolean
|
||||
sed -i '/class BooleanAlgebra(object):/a\ TRUE = ...\n FALSE = ...' typings/boolean/boolean.pyi
|
||||
else
|
||||
echo stub 'boolean' already created
|
||||
fi
|
||||
|
||||
# Stubs that needed manual patching and that
|
||||
# were instead checked-in in git
|
||||
# pyright --createstub boolean
|
||||
# pyright --createstub gym
|
|
@ -0,0 +1,44 @@
|
|||
NOTICES
|
||||
|
||||
This repository incorporates material as listed below or described in the code.
|
||||
|
||||
|
||||
Component: PyTorch Reinforcement Learning Tutorial
|
||||
|
||||
https://github.com/pytorch/tutorials/blob/master/intermediate_source/reinforcement_q_learning.py
|
||||
|
||||
Open Source License/Copyright Notice.
|
||||
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2017, Pytorch contributors
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
Additional Attribution.
|
||||
|
||||
Adam Paszke https://github.com/apaszke
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Initialize CyberBattleSim module"""
|
||||
from gym.envs.registration import registry, EnvSpec
|
||||
from gym.error import Error
|
||||
|
||||
from . import simulation
|
||||
from . import agents
|
||||
from ._env.cyberbattle_env import AttackerGoal, DefenderGoal
|
||||
from .samples.chainpattern import chainpattern
|
||||
from .samples.toyctf import toy_ctf
|
||||
from .simulation import generate_network, model
|
||||
|
||||
__all__ = (
|
||||
'simulation',
|
||||
'agents',
|
||||
)
|
||||
|
||||
|
||||
def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
|
||||
""" same as gym.envs.registry.register, but adds CyberBattle specs to env.spec """
|
||||
if id in registry.env_specs:
|
||||
raise Error('Cannot re-register id: {}'.format(id))
|
||||
spec = EnvSpec(id, **kwargs)
|
||||
# Map from port number to port names : List[model.PortName]
|
||||
spec.ports = cyberbattle_env_identifiers.ports
|
||||
# Array of all possible node properties (not necessarily all used in the network) : List[model.PropertyName]
|
||||
spec.properties = cyberbattle_env_identifiers.properties
|
||||
# Array defining an index for every possible local vulnerability name : List[model.VulnerabilityID]
|
||||
spec.local_vulnerabilities = cyberbattle_env_identifiers.local_vulnerabilities
|
||||
# Array defining an index for every possible remote vulnerability name : List[model.VulnerabilityID]
|
||||
spec.remote_vulnerabilities = cyberbattle_env_identifiers.remote_vulnerabilities
|
||||
|
||||
registry.env_specs[id] = spec
|
||||
|
||||
|
||||
if 'CyberBattleToyCtf-v0' in registry.env_specs:
|
||||
del registry.env_specs['CyberBattleToyCtf-v0']
|
||||
|
||||
register(
|
||||
id='CyberBattleToyCtf-v0',
|
||||
cyberbattle_env_identifiers=toy_ctf.ENV_IDENTIFIERS,
|
||||
entry_point='cyberbattle._env.cyberbattle_toyctf:CyberBattleToyCtf',
|
||||
kwargs={'defender_agent': None,
|
||||
'attacker_goal': AttackerGoal(reward=889),
|
||||
'defender_goal': DefenderGoal(eviction=True)
|
||||
},
|
||||
# max_episode_steps=2600,
|
||||
)
|
||||
|
||||
if 'CyberBattleRandom-v0' in registry.env_specs:
|
||||
del registry.env_specs['CyberBattleRandom-v0']
|
||||
|
||||
register(
|
||||
id='CyberBattleRandom-v0',
|
||||
cyberbattle_env_identifiers=generate_network.ENV_IDENTIFIERS,
|
||||
entry_point='cyberbattle._env.cyberbattle_random:CyberBattleRandom',
|
||||
)
|
||||
|
||||
if 'CyberBattleChain-v0' in registry.env_specs:
|
||||
del registry.env_specs['CyberBattleChain-v0']
|
||||
|
||||
register(
|
||||
id='CyberBattleChain-v0',
|
||||
cyberbattle_env_identifiers=chainpattern.ENV_IDENTIFIERS,
|
||||
entry_point='cyberbattle._env.cyberbattle_chain:CyberBattleChain',
|
||||
kwargs={'size': 4,
|
||||
'defender_agent': None,
|
||||
'attacker_goal': AttackerGoal(reward=2200),
|
||||
'defender_goal': DefenderGoal(eviction=True),
|
||||
'winning_reward': 5000.0,
|
||||
'losing_reward': 0.0
|
||||
},
|
||||
reward_threshold=2200,
|
||||
)
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""CyberBattle environment based on a simple chain network structure"""
|
||||
|
||||
from ..samples.chainpattern import chainpattern
|
||||
from . import cyberbattle_env
|
||||
|
||||
|
||||
class CyberBattleChain(cyberbattle_env.CyberBattleEnv):
|
||||
"""CyberBattle environment based on a simple chain network structure"""
|
||||
|
||||
def __init__(self, size, **kwargs):
|
||||
self.size = size
|
||||
super().__init__(
|
||||
initial_environment=chainpattern.new_environment(size),
|
||||
**kwargs)
|
||||
|
||||
@ property
|
||||
def name(self) -> str:
|
||||
return f"CyberBattleChain-{self.size}"
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Test the CyberBattle Gym environment"""
|
||||
|
||||
import pytest
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from .cyberbattle_env import AttackerGoal
|
||||
|
||||
|
||||
def test_few_gym_iterations() -> None:
|
||||
"""Run a few iterations of the gym environment"""
|
||||
env = gym.make('CyberBattleToyCtf-v0')
|
||||
|
||||
for _ in range(2):
|
||||
env.reset()
|
||||
action_mask = env.compute_action_mask()
|
||||
assert action_mask
|
||||
for t in range(12):
|
||||
# env.render()
|
||||
|
||||
# sample a valid action
|
||||
action = env.sample_valid_action()
|
||||
|
||||
observation, reward, done, info = env.step(action)
|
||||
if done:
|
||||
print("Episode finished after {} timesteps".format(t + 1))
|
||||
break
|
||||
|
||||
env.close()
|
||||
pass
|
||||
|
||||
|
||||
def test_step_after_done() -> None:
|
||||
actions = [
|
||||
{'local_vulnerability': np.array([0, 1])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([0, 1, 0])}, # done=False, r=45.0
|
||||
{'connect': np.array([0, 1, 2, 0])}, # done=False, r=100.0
|
||||
{'local_vulnerability': np.array([1, 3])}, # done=False, r=49.0
|
||||
{'connect': np.array([0, 2, 3, 1])}, # done=False, r=100.0
|
||||
{'remote_vulnerability': np.array([1, 2, 1])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([1, 2, 0])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([2, 1, 1])}, # done=False, r=45.0
|
||||
{'local_vulnerability': np.array([1, 0])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([1, 1])}, # done=False, r=40.0
|
||||
{'local_vulnerability': np.array([2, 1])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([2, 3, 0])}, # done=False, r=45.0
|
||||
{'local_vulnerability': np.array([2, 4])}, # done=False, r=49.0
|
||||
{'connect': np.array([0, 3, 2, 2])}, # done=False, r=100.0
|
||||
{'local_vulnerability': np.array([3, 3])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([3, 0])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([0, 4, 1])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([3, 1])}, # done=False, r=40.0
|
||||
{'connect': np.array([2, 4, 3, 3])}, # done=False, r=100.0
|
||||
{'remote_vulnerability': np.array([1, 3, 1])}, # done=False, r=45.0
|
||||
{'remote_vulnerability': np.array([1, 4, 0])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([4, 1])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([0, 5, 0])}, # done=False, r=45.0
|
||||
{'local_vulnerability': np.array([4, 4])}, # done=False, r=49.0
|
||||
{'connect': np.array([3, 5, 2, 4])}, # done=False, r=100.0
|
||||
{'remote_vulnerability': np.array([2, 5, 1])}, # done=False, r=45.0
|
||||
{'local_vulnerability': np.array([5, 3])}, # done=False, r=49.0
|
||||
{'connect': np.array([2, 6, 3, 5])}, # done=False, r=100.0
|
||||
{'remote_vulnerability': np.array([4, 6, 1])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([5, 0])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([4, 6, 0])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([5, 1])}, # done=False, r=40.0
|
||||
{'local_vulnerability': np.array([6, 1])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([6, 7, 0])}, # done=False, r=45.0
|
||||
{'remote_vulnerability': np.array([0, 7, 1])}, # done=False, r=45.0
|
||||
{'local_vulnerability': np.array([6, 4])}, # done=False, r=49.0
|
||||
{'connect': np.array([4, 7, 2, 6])}, # done=False, r=100.0
|
||||
{'local_vulnerability': np.array([7, 3])}, # done=False, r=49.0
|
||||
{'connect': np.array([0, 8, 3, 7])}, # done=False, r=100.0
|
||||
{'remote_vulnerability': np.array([0, 8, 0])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([7, 0])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([8, 4])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([3, 9, 1])}, # done=False, r=45.0
|
||||
{'connect': np.array([3, 9, 2, 8])}, # done=False, r=100.0
|
||||
{'remote_vulnerability': np.array([4, 9, 0])}, # done=False, r=45.0
|
||||
{'local_vulnerability': np.array([9, 0])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([3, 8, 1])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([6, 10, 0])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([9, 1])}, # done=False, r=40.0
|
||||
{'local_vulnerability': np.array([9, 3])}, # done=False, r=49.0
|
||||
{'remote_vulnerability': np.array([8, 10, 1])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([7, 1])}, # done=False, r=40.0
|
||||
{'connect': np.array([8, 10, 3, 9])}, # done=False, r=100.0
|
||||
{'local_vulnerability': np.array([10, 4])}, # done=False, r=49.0
|
||||
{'local_vulnerability': np.array([8, 1])}, # done=False, r=49.0
|
||||
{'connect': np.array([7, 11, 2, 10])}, # done=False, r=1000.0
|
||||
{'connect': np.array([5, 10, 3, 9])}, # done=True, r=5000.0
|
||||
|
||||
# this is one too many (after done)
|
||||
{'connect': np.array([10, 5, 2, 4])}, # done=True, r=5000.0
|
||||
]
|
||||
|
||||
env = gym.make('CyberBattleChain-v0', size=10, attacker_goal=AttackerGoal(reward=4000))
|
||||
|
||||
for a in actions[:-1]:
|
||||
env.step(a)
|
||||
|
||||
with pytest.raises(RuntimeError, match=r'new episode must be started with env\.reset\(\)'):
|
||||
env.step(actions[-1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('env_name', ['CyberBattleToyCtf-v0', 'CyberBattleRandom-v0', 'CyberBattleChain-v0'])
|
||||
def test_wrap_spec(env_name) -> None:
|
||||
env = gym.make(env_name)
|
||||
|
||||
class DummyWrapper(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
assert hasattr(self, 'spec')
|
||||
self.spec.dummy = 7
|
||||
|
||||
assert hasattr(env.spec, 'properties')
|
||||
assert hasattr(env.spec, 'ports')
|
||||
assert hasattr(env.spec, 'local_vulnerabilities')
|
||||
assert hasattr(env.spec, 'remote_vulnerabilities')
|
||||
|
||||
env = DummyWrapper(env)
|
||||
|
||||
assert hasattr(env.spec, 'properties')
|
||||
assert hasattr(env.spec, 'ports')
|
||||
assert hasattr(env.spec, 'local_vulnerabilities')
|
||||
assert hasattr(env.spec, 'remote_vulnerabilities')
|
||||
assert hasattr(env.spec, 'dummy')
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""A CyberBattle simulation over a randomly generated network"""
|
||||
|
||||
from ..simulation import generate_network
|
||||
from . import cyberbattle_env
|
||||
|
||||
|
||||
class CyberBattleRandom(cyberbattle_env.CyberBattleEnv):
|
||||
"""A sample CyberBattle environment"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(initial_environment=generate_network.new_environment())
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from ..samples.toyctf import toy_ctf
|
||||
from . import cyberbattle_env
|
||||
|
||||
|
||||
class CyberBattleToyCtf(cyberbattle_env.CyberBattleEnv):
|
||||
"""CyberBattle simulation based on a toy CTF exercise"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(
|
||||
initial_environment=toy_ctf.new_environment(),
|
||||
**kwargs)
|
|
@ -0,0 +1,151 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Defines stock defender agents for the CyberBattle simulation.
|
||||
"""
|
||||
import random
|
||||
import numpy
|
||||
from abc import abstractmethod
|
||||
from cyberbattle.simulation.model import Environment
|
||||
from cyberbattle.simulation.actions import DefenderAgentActions
|
||||
from ..simulation import model
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class DefenderAgent:
|
||||
"""Define the step function for a defender agent.
|
||||
Gets called after each step executed by the attacker agent."""
|
||||
|
||||
@abstractmethod
|
||||
def step(self, environment: Environment, actions: DefenderAgentActions, t: int):
|
||||
None
|
||||
|
||||
|
||||
class ScanAndReimageCompromisedMachines(DefenderAgent):
|
||||
"""A defender agent that scans a subset of network nodes
|
||||
detects presence of an attacker on a given node with
|
||||
some fixed probability and if detected re-image the compromised node.
|
||||
|
||||
probability -- probability that an attacker agent is detected when scanned given that the attacker agent is present
|
||||
scan_capacity -- maxium number of machine that a defender agent can scan in one simulation step
|
||||
scan_frequency -- frequencey of the scan in simulation steps
|
||||
"""
|
||||
|
||||
def __init__(self, probability: float, scan_capacity: int, scan_frequency: int):
|
||||
self.probability = probability
|
||||
self.scan_capacity = scan_capacity
|
||||
self.scan_frequency = scan_frequency
|
||||
|
||||
def step(self, environment: Environment, actions: DefenderAgentActions, t: int):
|
||||
if t % self.scan_frequency == 0:
|
||||
# scan nodes at random
|
||||
scanned_nodes = random.choices(list(environment.network.nodes), k=self.scan_capacity)
|
||||
for node_id in scanned_nodes:
|
||||
node_info = environment.get_node(node_id)
|
||||
if node_info.status == model.MachineStatus.Running and \
|
||||
node_info.agent_installed:
|
||||
is_malware_detected = numpy.random.random() <= self.probability
|
||||
if is_malware_detected:
|
||||
if node_info.reimagable:
|
||||
logging.info(f"Defender detected malware, reimaging node {node_id}")
|
||||
actions.reimage_node(node_id)
|
||||
else:
|
||||
logging.info(f"Defender detected malware, but node cannot be reimaged {node_id}")
|
||||
|
||||
|
||||
class ExternalRandomEvents(DefenderAgent):
|
||||
"""A 'defender' that randomly alters network node configuration"""
|
||||
|
||||
def step(self, environment: Environment, actions: DefenderAgentActions, t: int):
|
||||
self.patch_vulnerabilities_at_random(environment)
|
||||
self.stop_service_at_random(environment, actions)
|
||||
self.plant_vulnerabilities_at_random(environment)
|
||||
self.firewall_change_remove(environment)
|
||||
self.firewall_change_add(environment)
|
||||
|
||||
def patch_vulnerabilities_at_random(self, environment: Environment, probability: float = 0.1) -> None:
|
||||
# Iterate through every node.
|
||||
for node_id, node_data in environment.nodes():
|
||||
# Have a boolean remove_vulnerability decide if we will remove one.
|
||||
remove_vulnerability = numpy.random.random() <= probability
|
||||
if remove_vulnerability and len(node_data.vulnerabilities) > 0:
|
||||
choice = random.choice(list(node_data.vulnerabilities))
|
||||
node_data.vulnerabilities.pop(choice)
|
||||
|
||||
def stop_service_at_random(self, environment: Environment, actions: DefenderAgentActions, probability: float = 0.1) -> None:
|
||||
for node_id, node_data in environment.nodes():
|
||||
remove_service = numpy.random.random() <= probability
|
||||
if remove_service and len(node_data.services) > 0:
|
||||
service = random.choice(node_data.services)
|
||||
actions.stop_service(node_id, service.name)
|
||||
|
||||
def plant_vulnerabilities_at_random(self, environment: Environment, probability: float = 0.1) -> None:
|
||||
for node_id, node_data in environment.nodes():
|
||||
add_vulnerability = numpy.random.random() <= probability
|
||||
# See all differences between current node vulnerabilities and global ones.
|
||||
new_vulnerabilities = numpy.setdiff1d(
|
||||
list(environment.vulnerability_library.keys()), list(node_data.vulnerabilities.keys()))
|
||||
# If we have decided that we will add a vulnerability and there are new vulnerabilities not already
|
||||
# on the node, then add them.
|
||||
if add_vulnerability and len(new_vulnerabilities) > 0:
|
||||
new_vulnerability = random.choice(new_vulnerabilities)
|
||||
node_data.vulnerabilities[new_vulnerability] = \
|
||||
environment.vulnerability_library[new_vulnerability]
|
||||
|
||||
"""
|
||||
TODO: Not sure how to access global (environment) services.
|
||||
def serviceChangeAdd(self, probability: float) -> None:
|
||||
# Iterate through every node.
|
||||
for node_id, node_data in self.__environment.nodes():
|
||||
# Have a boolean addService decide if we will add one.
|
||||
addService = numpy.random.random() <= probability
|
||||
# List all new services we can add.
|
||||
newServices = numpy.setdiff1d(self.__environment.services, node_data.services)
|
||||
# If we have decided to add a service and there are new services to add, go ahead and add them.
|
||||
if addService and len(newServices) > 0:
|
||||
newService = random.choice(newServices)
|
||||
node_data.services.append(newService)
|
||||
return None
|
||||
"""
|
||||
|
||||
def firewall_change_remove(self, environment: Environment, probability: float = 0.1) -> None:
|
||||
# Iterate through every node.
|
||||
for node_id, node_data in environment.nodes():
|
||||
# Have a boolean remove_rule decide if we will remove one.
|
||||
remove_rule = numpy.random.random() <= probability
|
||||
# The following logic sees if there are both incoming and outgoing rules.
|
||||
# If there are, we remove one randomly.
|
||||
if remove_rule and len(node_data.firewall.outgoing) > 0 and len(node_data.firewall.incoming) > 0:
|
||||
incoming = numpy.random.random() <= 0.5
|
||||
if incoming:
|
||||
rule_to_remove = random.choice(node_data.firewall.incoming)
|
||||
node_data.firewall.incoming.remove(rule_to_remove)
|
||||
else:
|
||||
rule_to_remove = random.choice(node_data.firewall.outgoing)
|
||||
node_data.firewall.outgoing.remove(rule_to_remove)
|
||||
# If there are only outgoing rules, we remove one random outgoing rule.
|
||||
elif remove_rule and len(node_data.firewall.outgoing) > 0:
|
||||
rule_to_remove = random.choice(node_data.firewall.outgoing)
|
||||
node_data.firewall.outgoing.remove(rule_to_remove)
|
||||
# If there are only incoming rules, we remove one random incoming rule.
|
||||
elif remove_rule and len(node_data.firewall.incoming) > 0:
|
||||
rule_to_remove = random.choice(node_data.firewall.incoming)
|
||||
node_data.firewall.incoming.remove(rule_to_remove)
|
||||
|
||||
def firewall_change_add(self, environment: Environment, probability: float = 0.1) -> None:
|
||||
# Iterate through every node.
|
||||
for node_id, node_data in environment.nodes():
|
||||
# Have a boolean rule_to_add decide if we will add one.
|
||||
add_rule = numpy.random.random() <= probability
|
||||
if add_rule:
|
||||
# 0 For allow, 1 for block.
|
||||
rule_to_add = model.FirewallRule(port=random.choice(model.SAMPLE_IDENTIFIERS.ports),
|
||||
permission=model.RulePermission.ALLOW)
|
||||
# Randomly decide if we will add an incoming or outgoing rule.
|
||||
incoming = numpy.random.random() <= 0.5
|
||||
if incoming and rule_to_add not in node_data.firewall.incoming:
|
||||
node_data.firewall.incoming.append(rule_to_add)
|
||||
elif not incoming and rule_to_add not in node_data.firewall.incoming:
|
||||
node_data.firewall.outgoing.append(rule_to_add)
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""A discriminated union space for Gym"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Mapping, Union, List
|
||||
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
|
||||
|
||||
class DiscriminatedUnion(spaces.Dict): # type: ignore
|
||||
"""
|
||||
A discriminated union of simpler spaces.
|
||||
|
||||
Example usage:
|
||||
|
||||
self.observation_space = discriminatedunion.DiscriminatedUnion(
|
||||
{"foo": spaces.Discrete(2), "Bar": spaces.Discrete(3)})
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
spaces: Union[None, List[spaces.Space], Mapping[str, spaces.Space]] = None,
|
||||
**spaces_kwargs: spaces.Space) -> None:
|
||||
"""Create a discriminated union space"""
|
||||
if spaces is None:
|
||||
super().__init__(spaces_kwargs)
|
||||
else:
|
||||
super().__init__(spaces=spaces)
|
||||
|
||||
def seed(self, seed: Union[None, int] = None) -> None:
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
super().seed(seed)
|
||||
|
||||
def sample(self) -> object:
|
||||
space_count = len(self.spaces.items())
|
||||
index_k = self.np_random.randint(space_count)
|
||||
kth_key, kth_space = list(self.spaces.items())[index_k]
|
||||
return OrderedDict([(kth_key, kth_space.sample())])
|
||||
|
||||
def contains(self, candidate: object) -> bool:
|
||||
if not isinstance(candidate, dict) or len(candidate) != 1:
|
||||
return False
|
||||
k, space = list(candidate)[0]
|
||||
return k in self.spaces.keys()
|
||||
|
||||
@classmethod
|
||||
def is_of_kind(cls, key: str, sample_n: Mapping[str, object]) -> bool:
|
||||
"""Returns true if a given sample is of the specified discriminated kind"""
|
||||
return key in sample_n.keys()
|
||||
|
||||
@classmethod
|
||||
def kind(cls, sample_n: Mapping[str, object]) -> str:
|
||||
"""Returns the discriminated kind of a given sample"""
|
||||
keys = sample_n.keys()
|
||||
assert len(keys) == 1
|
||||
return list(keys)[0]
|
||||
|
||||
def __getitem__(self, key: str) -> spaces.Space:
|
||||
return self.spaces[key]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__class__.__name__ + "(" + ", ". join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
|
||||
|
||||
def to_jsonable(self, sample_n: object) -> object:
|
||||
return super().to_jsonable(sample_n)
|
||||
|
||||
def from_jsonable(self, sample_n: object) -> object:
|
||||
ret = super().from_jsonable(sample_n)
|
||||
assert len(ret) == 1
|
||||
return ret
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, DiscriminatedUnion) and self.spaces == other.spaces
|
||||
|
||||
|
||||
def test_sampling() -> None:
|
||||
"""Simple sampling test"""
|
||||
union = DiscriminatedUnion(spaces={"foo": spaces.Discrete(8), "Bar": spaces.Discrete(3)})
|
||||
[union.sample() for i in range(100)]
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import networkx as nx
|
||||
from gym.spaces import Space, Dict
|
||||
|
||||
|
||||
class BaseGraph(Space):
|
||||
_nx_class: type
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_nodes: int,
|
||||
node_property_space: Optional[Dict] = None,
|
||||
edge_property_space: Optional[Dict] = None):
|
||||
|
||||
self.max_num_nodes = max_num_nodes
|
||||
self.node_property_space = Dict() if node_property_space is None else node_property_space
|
||||
self.edge_property_space = Dict() if edge_property_space is None else edge_property_space
|
||||
super().__init__()
|
||||
|
||||
def sample(self):
|
||||
num_nodes = self.np_random.randint(self.max_num_nodes + 1)
|
||||
graph = self._nx_class()
|
||||
|
||||
# add nodes with properties
|
||||
for node_id in range(num_nodes):
|
||||
node_properties = {k: s.sample() for k, s in self.node_property_space.spaces.items()}
|
||||
graph.add_node(node_id, **node_properties)
|
||||
|
||||
if num_nodes < 2:
|
||||
return graph
|
||||
|
||||
# add some edges with properties
|
||||
seen, unseen = [], list(range(num_nodes)) # init
|
||||
self.__pop_random(unseen, seen) # pop one node before we start
|
||||
while unseen:
|
||||
node_id_from, node_id_to = self.__sample_random(seen), self.__pop_random(unseen, seen)
|
||||
edge_properties = {k: s.sample() for k, s in self.edge_property_space.spaces.items()}
|
||||
graph.add_edge(node_id_from, node_id_to, **edge_properties)
|
||||
|
||||
return graph
|
||||
|
||||
def __pop_random(self, unseen: list, seen: list):
|
||||
i = self.np_random.choice(len(unseen))
|
||||
x = unseen[i]
|
||||
seen.append(x)
|
||||
del unseen[i]
|
||||
return x
|
||||
|
||||
def __sample_random(self, seen: list):
|
||||
i = self.np_random.choice(len(seen))
|
||||
return seen[i]
|
||||
|
||||
def contains(self, x):
|
||||
return (
|
||||
isinstance(x, self._nx_class)
|
||||
and all(node_property in self.node_property_space for node_property in x.nodes.values())
|
||||
and all(edge_property in self.edge_property_space for edge_property in x.edges.values())
|
||||
)
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
_nx_class = nx.Graph
|
||||
|
||||
|
||||
class DiGraph(BaseGraph):
|
||||
_nx_class = nx.DiGraph
|
||||
|
||||
|
||||
class MultiGraph(BaseGraph):
|
||||
_nx_class = nx.MultiGraph
|
||||
|
||||
|
||||
class MultiDiGraph(BaseGraph):
|
||||
_nx_class = nx.MultiDiGraph
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from gym.spaces import Box, Discrete
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
space = DiGraph(
|
||||
max_num_nodes=10,
|
||||
node_property_space=Dict({'vector': Box(0, 1, (3,)), 'category': Discrete(7)}),
|
||||
edge_property_space=Dict({'weight': Box(0, 1, ())}))
|
||||
|
||||
space.seed(42)
|
||||
graph = space.sample()
|
||||
assert graph in space
|
||||
|
||||
for node_id, node_properties in graph.nodes.items():
|
||||
print(f"node_id: {node_id}, node_properties: {node_properties}")
|
||||
|
||||
for (node_id_from, node_id_to), edge_properties in graph.edges.items():
|
||||
print(f"node_id_from: {node_id_from}, node_id_to: {node_id_to}, "
|
||||
f"edge_properties: {edge_properties}")
|
||||
|
||||
pos = nx.spring_layout(graph)
|
||||
nx.draw_networkx_nodes(graph, pos)
|
||||
nx.draw_networkx_edges(graph, pos)
|
||||
nx.draw_networkx_labels(graph, pos)
|
||||
# nx.draw_networkx_labels(graph, pos, graph.nodes)
|
||||
plt.show()
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as onp
|
||||
import networkx as nx
|
||||
|
||||
from .graph_spaces import DiGraph
|
||||
|
||||
|
||||
Action = Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]]
|
||||
|
||||
|
||||
class CyberBattleGraph(gym.Wrapper):
|
||||
"""
|
||||
|
||||
A wrapper for CyberBattleSim that maintains the agent's
|
||||
knowledge graph containing information about the subset
|
||||
of the network that was explored so far.
|
||||
|
||||
Currently the nodes of this graph are a subset of the environment nodes.
|
||||
Eventually we will add new node types to represent various entities
|
||||
like credentials and users. Edges will represent relationships between those entities
|
||||
(e.g. user X is authenticated with machine Y using credential Z).
|
||||
|
||||
Actions
|
||||
-------
|
||||
|
||||
Actions are of the form:
|
||||
|
||||
.. code:: python
|
||||
|
||||
(kind, *indicators)
|
||||
|
||||
|
||||
The ``kind`` which is one of
|
||||
|
||||
.. code:: python
|
||||
|
||||
# kind
|
||||
0: Local Vulnerability
|
||||
1: Remote Vulnerability
|
||||
2: Connect
|
||||
|
||||
The indicators vary in meaning and length, depending on the ``kind``:
|
||||
|
||||
.. code:: python
|
||||
|
||||
# kind=0 (Local Vulnerability)
|
||||
indicators = (node_id, local_vulnerability_id)
|
||||
|
||||
# kind=1 (Remote Vulnerability)
|
||||
indicators = (from_node_id, to_node_id, remote_vulnerability_id)
|
||||
|
||||
# kind=2 (Connect)
|
||||
indicators = (from_node_id, to_node_id, port_id, credential_id)
|
||||
|
||||
The node ids can be obtained from the graph, e.g.
|
||||
|
||||
.. code:: python
|
||||
|
||||
node_ids = observation['graph'].keys()
|
||||
|
||||
The other indicators are listed below.
|
||||
|
||||
.. code:: python
|
||||
|
||||
# local_vulnerability_ids
|
||||
0: ScanBashHistory
|
||||
1: ScanExplorerRecentFiles
|
||||
2: SudoAttempt
|
||||
3: CrackKeepPassX
|
||||
4: CrackKeepPass
|
||||
|
||||
# remote_vulnerability_ids
|
||||
0: ProbeLinux
|
||||
1: ProbeWindows
|
||||
|
||||
# port_ids
|
||||
0: HTTPS
|
||||
1: GIT
|
||||
2: SSH
|
||||
3: RDP
|
||||
4: PING
|
||||
5: MySQL
|
||||
6: SSH-key
|
||||
7: su
|
||||
|
||||
Examples
|
||||
~~~~~~~~
|
||||
Here are some example actions:
|
||||
|
||||
.. code:: python
|
||||
|
||||
a = (0, 5, 3) # try local vulnerability "CrackKeepPassX" on node 5
|
||||
a = (1, 5, 7, 1) # try remote vulnerability "ProbeWindows" from node 5 to node 7
|
||||
a = (2, 5, 7, 3, 2) # try to connect from node 5 to node 7 using credential 2 over RDP port
|
||||
|
||||
|
||||
Observations
|
||||
------------
|
||||
|
||||
Observations are graphs of the nodes that have been discovered so far. Each node is annotated
|
||||
with a dict of properties of the form:
|
||||
|
||||
.. code:: python
|
||||
|
||||
node_properties = {
|
||||
'name': 'FooServer', # human-readable identifier
|
||||
'privilege_level': 1, # 0: not owned, 1: admin, 2: system
|
||||
'flags': array(-1, 0, 1, 0, 0, ..., 0]), # 1: set, -1: unset, 0: unknown
|
||||
'credentials': array([-1, 5, -1, ..., -1]), # array of ports (-1 means no cred)
|
||||
'has_leaked_creds': True, # whether node has leaked any credentials so far
|
||||
}
|
||||
|
||||
# flag_ids
|
||||
0: Windows
|
||||
1: Linux
|
||||
2: ApacheWebSite
|
||||
3: IIS_2019
|
||||
4: IIS_2020_patched
|
||||
5: MySql
|
||||
6: Ubuntu
|
||||
7: nginx/1.10.3
|
||||
8: SMB_vuln
|
||||
9: SMB_vuln_patched
|
||||
10: SQLServer
|
||||
11: Win10
|
||||
12: Win10Patched
|
||||
13: FLAG:Linux
|
||||
|
||||
Note that the **position** of a non-trivial port number in ``'credentials'`` corresponds to the
|
||||
credential id. Therefore, for the node in the example above, we have a known credential on
|
||||
:code:`port_id=5` with :code:`credential_id=1` (the position in the array).
|
||||
|
||||
|
||||
"""
|
||||
__kinds = ('local_vulnerability', 'remote_vulnerability', 'connect')
|
||||
|
||||
def __init__(self, env, maximum_total_credentials=22, maximum_node_count=22):
|
||||
super().__init__(env)
|
||||
self._bounds = self.env._bounds
|
||||
self.__graph = None
|
||||
self.observation_space = DiGraph(self._bounds.maximum_node_count)
|
||||
|
||||
def reset(self):
|
||||
observation = self.env.reset()
|
||||
self.__graph = nx.DiGraph()
|
||||
self.__add_node(observation)
|
||||
self.__update_nodes(observation)
|
||||
return self.__graph
|
||||
|
||||
def step(self, action: Action):
|
||||
"""
|
||||
|
||||
Take a step in the MDP.
|
||||
|
||||
Args:
|
||||
action: An *abstract* action.
|
||||
|
||||
Returns:
|
||||
observation: The next-step observation.
|
||||
reward: The reward associated with the given action (and previous observation).
|
||||
done: Whether the next-step observation is a terminal state.
|
||||
info: Some additional info.
|
||||
|
||||
"""
|
||||
kind_id, *indicators = action
|
||||
observation, reward, done, info = self.env.step({self.__kinds[kind_id]: indicators})
|
||||
for _ in range(observation['newly_discovered_nodes_count']):
|
||||
self.__add_node(observation)
|
||||
if True: # TODO: do we need to update edges and nodes every time?
|
||||
self.__update_edges(observation)
|
||||
self.__update_nodes(observation)
|
||||
return self.__graph, reward, done, info
|
||||
|
||||
def __add_node(self, observation):
|
||||
while self.__graph.number_of_nodes() < observation['discovered_node_count']:
|
||||
node_index = self.__graph.number_of_nodes()
|
||||
creds = onp.full(self._bounds.maximum_total_credentials, -1, dtype=onp.int8)
|
||||
self.__graph.add_node(
|
||||
node_index,
|
||||
name=observation['discovered_nodes'][node_index],
|
||||
privilege_level=None, flags=None, # these are set by __update_nodes()
|
||||
credentials=creds,
|
||||
has_leaked_creds=False,
|
||||
)
|
||||
|
||||
def __update_edges(self, observation):
|
||||
g_orig = observation['explored_network']
|
||||
node_ids = {n: i for i, n in enumerate(observation['discovered_nodes'])}
|
||||
for (from_name, to_name), edge_properties in g_orig.edges.items():
|
||||
self.__graph.add_edge(node_ids[from_name], node_ids[to_name], **edge_properties)
|
||||
|
||||
def __update_nodes(self, observation):
|
||||
node_properties = zip(
|
||||
observation['nodes_privilegelevel'],
|
||||
observation['discovered_nodes_properties'],
|
||||
)
|
||||
for node_id, (privilege_level, flags) in enumerate(node_properties):
|
||||
# This value is already provided in self.__graph.nodes[node_id]['data'].privilege_level
|
||||
self.__graph.nodes[node_id]['privilege_level'] = privilege_level
|
||||
# This value is already provided in self.__graph.nodes[node_id]['data'].properties
|
||||
self.__graph.nodes[node_id]['flags'] = flags
|
||||
|
||||
for cred_id, (node_id, port_id) in enumerate(observation['credential_cache_matrix']):
|
||||
node_id, port_id = int(node_id), int(port_id)
|
||||
# NOTE: this code ignores situations where the same cred_id is
|
||||
# used for two different ports (This can be the case, even on the same node for two different ports.)
|
||||
self.__graph.nodes[node_id]['credentials'][cred_id] = port_id
|
||||
# Mark the node has leaking credentials
|
||||
self.__graph.nodes[node_id]['has_leaked_creds'] = True
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import gym
|
||||
from gym.spaces import Space, Discrete, Tuple
|
||||
import numpy as onp
|
||||
|
||||
|
||||
class Env(NamedTuple):
|
||||
observation_space: Space
|
||||
action_space: Space
|
||||
|
||||
|
||||
def context_spaces(observation_space, action_space):
|
||||
K = 3 # noqa: N806
|
||||
N, L = action_space.spaces['local_vulnerability'].nvec # noqa: N806
|
||||
N, N, R = action_space.spaces['remote_vulnerability'].nvec # noqa: N806
|
||||
N, N, P, C = action_space.spaces['connect'].nvec # noqa: N806
|
||||
return {
|
||||
'kind': Env(observation_space, Discrete(K)),
|
||||
'local_node_id': Env(Tuple((observation_space, Discrete(K))), Discrete(N)),
|
||||
'local_vuln_id': Env(Tuple((observation_space, Discrete(N))), Discrete(L)),
|
||||
'remote_node_id': Env(Tuple((observation_space, Discrete(K), Discrete(N))), Discrete(N)),
|
||||
'remote_vuln_id': Env(Tuple((observation_space, Discrete(N), Discrete(N))), Discrete(R)),
|
||||
'cred_id': Env(observation_space, Discrete(C)),
|
||||
}
|
||||
|
||||
|
||||
class ContextWrapper(gym.Wrapper):
|
||||
__kinds = ('local_vulnerability', 'remote_vulnerability', 'connect')
|
||||
|
||||
def __init__(self, env, options):
|
||||
|
||||
super().__init__(env)
|
||||
assert isinstance(options, dict) and set(options) == {
|
||||
'kind', 'local_node_id', 'local_vuln_id', 'remote_node_id', 'remote_vuln_id', 'cred_id'}
|
||||
self._options = options
|
||||
self._bounds = self.env._bounds
|
||||
self._action_context = []
|
||||
|
||||
def reset(self):
|
||||
self._action_context = onp.full(5, -1, dtype=onp.int32)
|
||||
self._observation = self.env.reset()
|
||||
return self._observation
|
||||
|
||||
def step(self, dummy=None):
|
||||
obs = self._observation
|
||||
kind = self._options['kind'](obs)
|
||||
local_node_id = self._options['local_node_id']((obs, kind))
|
||||
if kind == 0:
|
||||
local_vuln_id = self._options['local_vuln_id']((obs, local_node_id))
|
||||
a = {self.__kinds[kind]: onp.array([local_node_id, local_vuln_id])}
|
||||
else:
|
||||
remote_node_id = self._options['remote_node_id']((obs, kind, local_node_id))
|
||||
if kind == 1:
|
||||
remote_vuln_id = \
|
||||
self._options['remote_vuln_id']((obs, local_node_id, remote_node_id))
|
||||
a = {self.__kinds[kind]: onp.array([local_node_id, remote_node_id, remote_vuln_id])}
|
||||
else:
|
||||
cred_id = self._options['cred_id'](obs)
|
||||
assert cred_id < obs['credential_cache_length']
|
||||
node_id, port_id = obs['credential_cache_matrix'][cred_id].astype('int32')
|
||||
a = {self.__kinds[kind]: onp.array([local_node_id, node_id, port_id, cred_id])}
|
||||
|
||||
self._observation, reward, done, info = self.env.step(a)
|
||||
return self._observation, reward, done, {**info, 'action': a}
|
||||
|
||||
|
||||
# --- random option policies --------------------------------------------------------------------- #
|
||||
|
||||
def pi_kind(s):
|
||||
kinds = ('local_vulnerability', 'remote_vulnerability', 'connect')
|
||||
masked = onp.array([i for i, k in enumerate(kinds) if onp.any(s['action_mask'][k])])
|
||||
return onp.random.choice(masked)
|
||||
|
||||
|
||||
def pi_local_node_id(s):
|
||||
s, k = s
|
||||
if k == 0:
|
||||
local_node_ids, _ = onp.argwhere(s['action_mask']['local_vulnerability']).T
|
||||
elif k == 1:
|
||||
local_node_ids, _, _ = onp.argwhere(s['action_mask']['remote_vulnerability']).T
|
||||
else:
|
||||
local_node_ids, _, _, _ = onp.argwhere(s['action_mask']['connect']).T
|
||||
return onp.random.choice(local_node_ids)
|
||||
|
||||
|
||||
def pi_local_vuln_id(s):
|
||||
s, local_node_id = s
|
||||
local_node_ids, local_vuln_ids = onp.argwhere(s['action_mask']['local_vulnerability']).T
|
||||
masked = local_vuln_ids[local_node_ids == local_node_id]
|
||||
return onp.random.choice(masked)
|
||||
|
||||
|
||||
def pi_remote_node_id(s):
|
||||
s, k, local_node_id = s
|
||||
assert k != 0
|
||||
if k == 1:
|
||||
local_node_ids, remote_node_ids, _ = onp.argwhere(s['action_mask']['remote_vulnerability']).T
|
||||
else:
|
||||
local_node_ids, remote_node_ids, _, _ = onp.argwhere(s['action_mask']['connect']).T
|
||||
return onp.random.choice(remote_node_ids[local_node_ids == local_node_id])
|
||||
|
||||
|
||||
def pi_remote_vuln_id(s):
|
||||
s, local_node_id, remote_node_id = s
|
||||
local_node_ids, remote_node_ids, remote_vuln_ids = \
|
||||
onp.argwhere(s['action_mask']['remote_vulnerability']).T
|
||||
mask = (local_node_ids == local_node_id) & (remote_node_ids == remote_node_id)
|
||||
return onp.random.choice(remote_vuln_ids[mask])
|
||||
|
||||
|
||||
def pi_cred_id(s):
|
||||
return onp.random.choice(s['credential_cache_length'])
|
||||
|
||||
|
||||
random_options = {
|
||||
'kind': pi_kind,
|
||||
'local_node_id': pi_local_node_id,
|
||||
'local_vuln_id': pi_local_vuln_id,
|
||||
'remote_node_id': pi_remote_node_id,
|
||||
'remote_vuln_id': pi_remote_vuln_id,
|
||||
'cred_id': pi_cred_id,
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This module contains all the agents to be used as baselines on the CyberBattle env.
|
||||
|
||||
"""
|
||||
|
||||
from .baseline.learner import Learner, AgentWrapper, EnvironmentBounds
|
||||
|
||||
__all__ = (
|
||||
'Learner',
|
||||
'AgentWrapper',
|
||||
'EnvironmentBounds'
|
||||
)
|
|
@ -0,0 +1,489 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Function DeepQLearnerPolicy.optimize_model:
|
||||
# Copyright (c) 2017, Pytorch contributors
|
||||
# All rights reserved.
|
||||
# https://github.com/pytorch/tutorials/blob/master/LICENSE
|
||||
|
||||
"""Deep Q-learning agent applied to chain network (notebook)
|
||||
This notebooks can be run directly from VSCode, to generate a
|
||||
traditional Jupyter Notebook to open in your browser
|
||||
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
|
||||
|
||||
Requirements:
|
||||
Nvidia CUDA drivers for WSL2: https://docs.nvidia.com/cuda/wsl-user-guide/index.html
|
||||
PyTorch
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
# %% [markdown]
|
||||
# # Chain network CyberBattle Gym played by a Deeo Q-learning agent
|
||||
|
||||
# %%
|
||||
from numpy import ndarray
|
||||
from cyberbattle._env import cyberbattle_env
|
||||
import numpy as np
|
||||
from typing import List, NamedTuple, Optional, Tuple, Union
|
||||
import random
|
||||
|
||||
# deep learning packages
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.cuda
|
||||
|
||||
from .learner import Learner
|
||||
from .agent_wrapper import EnvironmentBounds
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
from .agent_randomcredlookup import CredentialCacheExploiter
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
class CyberBattleStateActionModel:
|
||||
""" Define an abstraction of the state and action space
|
||||
for a CyberBattle environment, to be used to train a Q-function.
|
||||
"""
|
||||
|
||||
def __init__(self, ep: EnvironmentBounds):
|
||||
self.ep = ep
|
||||
|
||||
self.global_features = w.ConcatFeatures(ep, [
|
||||
# w.Feature_discovered_node_count(ep),
|
||||
# w.Feature_owned_node_count(ep),
|
||||
w.Feature_discovered_notowned_node_count(ep, None)
|
||||
|
||||
# w.Feature_discovered_ports(ep),
|
||||
# w.Feature_discovered_ports_counts(ep),
|
||||
# w.Feature_discovered_ports_sliding(ep),
|
||||
# w.Feature_discovered_credential_count(ep),
|
||||
# w.Feature_discovered_nodeproperties_sliding(ep),
|
||||
])
|
||||
|
||||
self.node_specific_features = w.ConcatFeatures(ep, [
|
||||
# w.Feature_actions_tried_at_node(ep),
|
||||
w.Feature_success_actions_at_node(ep),
|
||||
w.Feature_failed_actions_at_node(ep),
|
||||
w.Feature_active_node_properties(ep),
|
||||
w.Feature_active_node_age(ep)
|
||||
# w.Feature_active_node_id(ep)
|
||||
])
|
||||
|
||||
self.state_space = w.ConcatFeatures(ep, self.global_features.feature_selection +
|
||||
self.node_specific_features.feature_selection)
|
||||
|
||||
self.action_space = w.AbstractAction(ep)
|
||||
|
||||
def get_state_astensor(self, state: w.StateAugmentation):
|
||||
state_vector = self.state_space.get(state, node=None)
|
||||
state_vector_float = np.array(state_vector, dtype=np.float32)
|
||||
state_tensor = torch.from_numpy(state_vector_float).unsqueeze(0)
|
||||
return state_tensor
|
||||
|
||||
def implement_action(
|
||||
self,
|
||||
wrapped_env: w.AgentWrapper,
|
||||
actor_features: ndarray,
|
||||
abstract_action: np.int32) -> Tuple[str, Optional[cyberbattle_env.Action], Optional[int]]:
|
||||
"""Specialize an abstract model action into a CyberBattle gym action.
|
||||
|
||||
actor_features -- the desired features of the actor to use (source CyberBattle node)
|
||||
abstract_action -- the desired type of attack (connect, local, remote).
|
||||
|
||||
Returns a gym environment implementing the desired attack at a node with the desired embedding.
|
||||
"""
|
||||
|
||||
observation = wrapped_env.state.observation
|
||||
|
||||
# Pick source node at random (owned and with the desired feature encoding)
|
||||
potential_source_nodes = [
|
||||
from_node
|
||||
for from_node in w.owned_nodes(observation)
|
||||
if np.all(actor_features == self.node_specific_features.get(wrapped_env.state, from_node))
|
||||
]
|
||||
|
||||
if len(potential_source_nodes) > 0:
|
||||
source_node = np.random.choice(potential_source_nodes)
|
||||
|
||||
gym_action = self.action_space.specialize_to_gymaction(
|
||||
source_node, observation, np.int32(abstract_action))
|
||||
|
||||
if not gym_action:
|
||||
return "exploit[undefined]->explore", None, None
|
||||
|
||||
elif wrapped_env.env.is_action_valid(gym_action, observation['action_mask']):
|
||||
return "exploit", gym_action, source_node
|
||||
else:
|
||||
return "exploit[invalid]->explore", None, None
|
||||
else:
|
||||
return "exploit[no_actor]->explore", None, None
|
||||
|
||||
# %%
|
||||
|
||||
# Deep Q-learning
|
||||
|
||||
|
||||
class Transition(NamedTuple):
|
||||
"""One taken transition and its outcome"""
|
||||
state: Union[Tuple[Tensor], List[Tensor]]
|
||||
action: Union[Tuple[Tensor], List[Tensor]]
|
||||
next_state: Union[Tuple[Tensor], List[Tensor]]
|
||||
reward: Union[Tuple[Tensor], List[Tensor]]
|
||||
|
||||
|
||||
class ReplayMemory(object):
|
||||
"""Transition replay memory"""
|
||||
|
||||
def __init__(self, capacity):
|
||||
self.capacity = capacity
|
||||
self.memory = []
|
||||
self.position = 0
|
||||
|
||||
def push(self, *args):
|
||||
"""Saves a transition."""
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(None)
|
||||
self.memory[self.position] = Transition(*args)
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
|
||||
def sample(self, batch_size):
|
||||
return random.sample(self.memory, batch_size)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
||||
|
||||
class DQN(nn.Module):
|
||||
"""The Deep Neural Network used to estimate the Q function"""
|
||||
|
||||
def __init__(self, ep: EnvironmentBounds):
|
||||
super(DQN, self).__init__()
|
||||
|
||||
model = CyberBattleStateActionModel(ep)
|
||||
linear_input_size = len(model.state_space.dim_sizes)
|
||||
output_size = model.action_space.flat_size()
|
||||
|
||||
self.hidden_layer1 = nn.Linear(linear_input_size, 1024)
|
||||
# self.bn1 = nn.BatchNorm1d(256)
|
||||
self.hidden_layer2 = nn.Linear(1024, 512)
|
||||
self.hidden_layer3 = nn.Linear(512, 128)
|
||||
# self.hidden_layer4 = nn.Linear(128, 64)
|
||||
self.head = nn.Linear(128, output_size)
|
||||
|
||||
# Called with either one element to determine next action, or a batch
|
||||
# during optimization. Returns tensor([[left0exp,right0exp]...]).
|
||||
def forward(self, x):
|
||||
x = F.relu(self.hidden_layer1(x))
|
||||
# x = F.dropout(x, p=0.5, training=self.training)
|
||||
x = F.relu(self.hidden_layer2(x))
|
||||
# x = F.dropout(x, p=0.5, training=self.training)
|
||||
x = F.relu(self.hidden_layer3(x))
|
||||
# x = F.relu(self.hidden_layer4(x))
|
||||
return self.head(x.view(x.size(0), -1))
|
||||
|
||||
|
||||
def random_argmax(array):
|
||||
"""Just like `argmax` but if there are multiple elements with the max
|
||||
return a random index to break ties instead of returning the first one."""
|
||||
max_value = np.max(array)
|
||||
max_index = np.where(array == max_value)[0]
|
||||
|
||||
if max_index.shape[0] > 1:
|
||||
max_index = int(np.random.choice(max_index, size=1))
|
||||
else:
|
||||
max_index = int(max_index)
|
||||
|
||||
return max_value, max_index
|
||||
|
||||
|
||||
class ChosenActionMetadata(NamedTuple):
|
||||
"""Additonal info about the action chosen by the DQN-induced policy"""
|
||||
abstract_action: np.int32
|
||||
actor_node: int
|
||||
actor_features: ndarray
|
||||
actor_state: ndarray
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"[abstract_action={self.abstract_action}, actor={self.actor_node}, state={self.actor_state}]"
|
||||
|
||||
|
||||
class DeepQLearnerPolicy(Learner):
|
||||
"""Deep Q-Learning on CyberBattle environments
|
||||
|
||||
Parameters
|
||||
==========
|
||||
ep -- global parameters of the environment
|
||||
model -- define a state and action abstraction for the gym environment
|
||||
gamma -- Q discount factor
|
||||
replay_memory_size -- size of the replay memory
|
||||
batch_size -- Deep Q-learning batch
|
||||
target_update -- Deep Q-learning replay frequency (in number of episodes)
|
||||
learning_rate -- the learning rate
|
||||
|
||||
Parameters from DeepDoubleQ paper
|
||||
- learning_rate = 0.00025
|
||||
- linear epsilon decay
|
||||
- gamma = 0.99
|
||||
|
||||
Pytorch code from tutorial at
|
||||
https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ep: EnvironmentBounds,
|
||||
gamma: float,
|
||||
replay_memory_size: int,
|
||||
target_update: int,
|
||||
batch_size: int,
|
||||
learning_rate: float
|
||||
):
|
||||
|
||||
self.stateaction_model = CyberBattleStateActionModel(ep)
|
||||
self.batch_size = batch_size
|
||||
self.gamma = gamma
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
self.policy_net = DQN(ep).to(device)
|
||||
self.target_net = DQN(ep).to(device)
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
self.target_net.eval()
|
||||
self.target_update = target_update
|
||||
|
||||
self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=learning_rate)
|
||||
self.memory = ReplayMemory(replay_memory_size)
|
||||
|
||||
self.credcache_policy = CredentialCacheExploiter()
|
||||
|
||||
def parameters_as_string(self):
|
||||
return f'γ={self.gamma}, lr={self.learning_rate}, replaymemory={self.memory.capacity},\n' \
|
||||
f'batch={self.batch_size}, target_update={self.target_update}'
|
||||
|
||||
def all_parameters_as_string(self) -> str:
|
||||
model = self.stateaction_model
|
||||
return f'{self.parameters_as_string()}\n' \
|
||||
f'dimension={model.state_space.flat_size()}x{model.action_space.flat_size()}, ' \
|
||||
f'Q={[f.name() for f in model.state_space.feature_selection]} ' \
|
||||
f"-> 'abstract_action'"
|
||||
|
||||
def optimize_model(self, norm_clipping=False):
|
||||
if len(self.memory) < self.batch_size:
|
||||
return
|
||||
|
||||
transitions = self.memory.sample(self.batch_size)
|
||||
# converts batch-array of Transitions to Transition of batch-arrays.
|
||||
batch = Transition(*zip(*transitions))
|
||||
|
||||
# Compute a mask of non-final states and concatenate the batch elements
|
||||
# (a final state would've been the one after which simulation ended)
|
||||
non_final_mask = torch.tensor(tuple(map((lambda s: s is not None), batch.next_state)),
|
||||
device=device, dtype=torch.bool)
|
||||
non_final_next_states = torch.cat([s for s in batch.next_state
|
||||
if s is not None])
|
||||
state_batch = torch.cat(batch.state)
|
||||
action_batch = torch.cat(batch.action)
|
||||
reward_batch = torch.cat(batch.reward)
|
||||
|
||||
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
|
||||
# columns of actions taken. These are the actions which would've been taken
|
||||
# for each batch state according to policy_net
|
||||
# print(f'state_batch={state_batch.shape} input={len(self.stateaction_model.state_space.dim_sizes)}')
|
||||
output = self.policy_net(state_batch)
|
||||
|
||||
# print(f'output={output.shape} batch.action={transitions[0].action.shape} action_batch={action_batch.shape}')
|
||||
state_action_values = output.gather(1, action_batch)
|
||||
|
||||
# Compute V(s_{t+1}) for all next states.
|
||||
# Expected values of actions for non_final_next_states are computed based
|
||||
# on the "older" target_net; selecting their best reward with max(1)[0].
|
||||
# This is merged based on the mask, such that we'll have either the expected
|
||||
# state value or 0 in case the state was final.
|
||||
next_state_values = torch.zeros(self.batch_size, device=device)
|
||||
next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
|
||||
# Compute the expected Q values
|
||||
expected_state_action_values = (next_state_values * self.gamma) + reward_batch
|
||||
|
||||
# Compute Huber loss
|
||||
loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
|
||||
|
||||
# Optimize the model
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Gradient clipping
|
||||
if norm_clipping:
|
||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
||||
else:
|
||||
for param in self.policy_net.parameters():
|
||||
param.grad.data.clamp_(-1, 1)
|
||||
self.optimizer.step()
|
||||
|
||||
def get_actor_state_vector(self, global_state: ndarray, actor_features: ndarray) -> ndarray:
|
||||
return np.concatenate((np.array(global_state, dtype=np.float32),
|
||||
np.array(actor_features, dtype=np.float32)))
|
||||
|
||||
def update_q_function(self,
|
||||
reward: float,
|
||||
actor_state: ndarray,
|
||||
abstract_action: np.int32,
|
||||
next_actor_state: Optional[ndarray]):
|
||||
# store the transition in memory
|
||||
reward_tensor = torch.tensor([reward], device=device, dtype=torch.float)
|
||||
action_tensor = torch.tensor([[np.long(abstract_action)]], device=device, dtype=torch.long)
|
||||
current_state_tensor = torch.as_tensor(actor_state, dtype=torch.float, device=device).unsqueeze(0)
|
||||
if next_actor_state is None:
|
||||
next_state_tensor = None
|
||||
else:
|
||||
next_state_tensor = torch.as_tensor(next_actor_state, dtype=torch.float, device=device).unsqueeze(0)
|
||||
self.memory.push(current_state_tensor, action_tensor, next_state_tensor, reward_tensor)
|
||||
|
||||
# optimize the target network
|
||||
self.optimize_model()
|
||||
|
||||
def on_step(self, wrapped_env: w.AgentWrapper,
|
||||
observation, reward: float, done: bool, info, action_metadata):
|
||||
agent_state = wrapped_env.state
|
||||
if done:
|
||||
self.update_q_function(reward,
|
||||
actor_state=action_metadata.actor_state,
|
||||
abstract_action=action_metadata.abstract_action,
|
||||
next_actor_state=None)
|
||||
else:
|
||||
next_global_state = self.stateaction_model.global_features.get(agent_state, node=None)
|
||||
next_actor_features = self.stateaction_model.node_specific_features.get(
|
||||
agent_state, action_metadata.actor_node)
|
||||
next_actor_state = self.get_actor_state_vector(next_global_state, next_actor_features)
|
||||
|
||||
self.update_q_function(reward,
|
||||
actor_state=action_metadata.actor_state,
|
||||
abstract_action=action_metadata.abstract_action,
|
||||
next_actor_state=next_actor_state)
|
||||
|
||||
def end_of_episode(self, i_episode, t):
|
||||
# Update the target network, copying all weights and biases in DQN
|
||||
if i_episode % self.target_update == 0:
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
def lookup_dqn(self, states_to_consider: List[ndarray]) -> Tuple[List[np.int32], List[np.int32]]:
|
||||
""" Given a set of possible current states return:
|
||||
- index, in the provided list, of the state that would yield the best possible outcome
|
||||
- the best action to take in such a state"""
|
||||
with torch.no_grad():
|
||||
# t.max(1) will return largest column value of each row.
|
||||
# second column on max result is index of where max element was
|
||||
# found, so we pick action with the larger expected reward.
|
||||
# action: np.int32 = self.policy_net(states_to_consider).max(1)[1].view(1, 1).item()
|
||||
|
||||
state_batch = torch.tensor(states_to_consider).to(device)
|
||||
dnn_output = self.policy_net(state_batch).max(1)
|
||||
action_lookups = dnn_output[1].tolist()
|
||||
expectedq_lookups = dnn_output[0].tolist()
|
||||
|
||||
return action_lookups, expectedq_lookups
|
||||
|
||||
def metadata_from_gymaction(self, wrapped_env, gym_action):
|
||||
current_global_state = self.stateaction_model.global_features.get(wrapped_env.state, node=None)
|
||||
actor_node = cyberbattle_env.sourcenode_of_action(gym_action)
|
||||
actor_features = self.stateaction_model.node_specific_features.get(wrapped_env.state, actor_node)
|
||||
abstract_action = self.stateaction_model.action_space.abstract_from_gymaction(gym_action)
|
||||
return ChosenActionMetadata(
|
||||
abstract_action=abstract_action,
|
||||
actor_node=actor_node,
|
||||
actor_features=actor_features,
|
||||
actor_state=self.get_actor_state_vector(current_global_state, actor_features))
|
||||
|
||||
def explore(self, wrapped_env: w.AgentWrapper
|
||||
) -> Tuple[str, cyberbattle_env.Action, object]:
|
||||
"""Random exploration that avoids repeating actions previously taken in the same state"""
|
||||
# sample local and remote actions only (excludes connect action)
|
||||
gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])
|
||||
metadata = self.metadata_from_gymaction(wrapped_env, gym_action)
|
||||
return "explore", gym_action, metadata
|
||||
|
||||
def try_exploit_at_candidate_actor_states(
|
||||
self,
|
||||
wrapped_env,
|
||||
current_global_state,
|
||||
actor_features,
|
||||
abstract_action):
|
||||
|
||||
actor_state = self.get_actor_state_vector(current_global_state, actor_features)
|
||||
|
||||
action_style, gym_action, actor_node = self.stateaction_model.implement_action(
|
||||
wrapped_env, actor_features, abstract_action)
|
||||
|
||||
if gym_action:
|
||||
assert actor_node is not None, 'actor_node should be set together with gym_action'
|
||||
|
||||
return action_style, gym_action, ChosenActionMetadata(
|
||||
abstract_action=abstract_action,
|
||||
actor_node=actor_node,
|
||||
actor_features=actor_features,
|
||||
actor_state=actor_state)
|
||||
else:
|
||||
# learn the failed exploit attempt in the current state
|
||||
self.update_q_function(reward=0.0,
|
||||
actor_state=actor_state,
|
||||
next_actor_state=actor_state,
|
||||
abstract_action=abstract_action)
|
||||
|
||||
return "exploit[undefined]->explore", None, None
|
||||
|
||||
def exploit(self,
|
||||
wrapped_env,
|
||||
observation
|
||||
) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
|
||||
|
||||
# first, attempt to exploit the credential cache
|
||||
# using the crecache_policy
|
||||
# action_style, gym_action, _ = self.credcache_policy.exploit(wrapped_env, observation)
|
||||
# if gym_action:
|
||||
# return action_style, gym_action, self.metadata_from_gymaction(wrapped_env, gym_action)
|
||||
|
||||
# Otherwise on exploit learnt Q-function
|
||||
|
||||
current_global_state = self.stateaction_model.global_features.get(wrapped_env.state, node=None)
|
||||
|
||||
# Gather the features of all the current active actors (i.e. owned nodes)
|
||||
active_actors_features: List[ndarray] = [
|
||||
self.stateaction_model.node_specific_features.get(wrapped_env.state, from_node)
|
||||
for from_node in w.owned_nodes(observation)
|
||||
]
|
||||
|
||||
unique_active_actors_features: List[ndarray] = np.unique(active_actors_features, axis=0)
|
||||
|
||||
# array of actor state vector for every possible set of node features
|
||||
candidate_actor_state_vector: List[ndarray] = [
|
||||
self.get_actor_state_vector(current_global_state, node_features)
|
||||
for node_features in unique_active_actors_features]
|
||||
|
||||
remaining_action_lookups, remaining_expectedq_lookups = self.lookup_dqn(candidate_actor_state_vector)
|
||||
remaining_candidate_indices = list(range(len(candidate_actor_state_vector)))
|
||||
|
||||
while remaining_candidate_indices:
|
||||
_, remaining_candidate_index = random_argmax(remaining_expectedq_lookups)
|
||||
actor_index = remaining_candidate_indices[remaining_candidate_index]
|
||||
abstract_action = remaining_action_lookups[remaining_candidate_index]
|
||||
|
||||
actor_features = unique_active_actors_features[actor_index]
|
||||
|
||||
action_style, gym_action, metadata = self.try_exploit_at_candidate_actor_states(
|
||||
wrapped_env,
|
||||
current_global_state,
|
||||
actor_features,
|
||||
abstract_action)
|
||||
|
||||
if gym_action:
|
||||
return action_style, gym_action, metadata
|
||||
|
||||
remaining_candidate_indices.pop(remaining_candidate_index)
|
||||
remaining_expectedq_lookups.pop(remaining_candidate_index)
|
||||
remaining_action_lookups.pop(remaining_candidate_index)
|
||||
|
||||
return "exploit[undefined]->explore", None, None
|
||||
|
||||
def stateaction_as_string(self, action_metadata) -> str:
|
||||
return ''
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Random agent with credential lookup (notebook)
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from .agent_wrapper import AgentWrapper
|
||||
from .learner import Learner
|
||||
from typing import Optional
|
||||
import cyberbattle._env.cyberbattle_env as cyberbattle_env
|
||||
import numpy as np
|
||||
import logging
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
|
||||
|
||||
def exploit_credentialcache(observation) -> Optional[cyberbattle_env.Action]:
|
||||
"""Exploit the credential cache to connect to
|
||||
a node not owned yet."""
|
||||
|
||||
# Pick source node at random (owned and with the desired feature encoding)
|
||||
potential_source_nodes = w.owned_nodes(observation)
|
||||
if len(potential_source_nodes) == 0:
|
||||
return None
|
||||
|
||||
source_node = np.random.choice(potential_source_nodes)
|
||||
|
||||
discovered_credentials = np.array(observation['credential_cache_matrix'])
|
||||
n_discovered_creds = len(discovered_credentials)
|
||||
if n_discovered_creds <= 0:
|
||||
# no credential available in the cache: cannot poduce a valid connect action
|
||||
return None
|
||||
|
||||
nodes_not_owned = w.discovered_nodes_notowned(observation)
|
||||
|
||||
match_port__target_notowned = [c for c in range(n_discovered_creds)
|
||||
if discovered_credentials[c, 0] in nodes_not_owned]
|
||||
|
||||
if match_port__target_notowned:
|
||||
logging.debug('found matching cred in the credential cache')
|
||||
cred = np.int32(np.random.choice(match_port__target_notowned))
|
||||
target = np.int32(discovered_credentials[cred, 0])
|
||||
port = np.int32(discovered_credentials[cred, 1])
|
||||
return {'connect': np.array([source_node, target, port, cred], dtype=np.int32)}
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class CredentialCacheExploiter(Learner):
|
||||
"""A learner that just exploits the credential cache"""
|
||||
|
||||
def parameters_as_string(self):
|
||||
return ''
|
||||
|
||||
def explore(self, wrapped_env: AgentWrapper):
|
||||
return "explore", wrapped_env.env.sample_valid_action([0, 1]), None
|
||||
|
||||
def exploit(self, wrapped_env: AgentWrapper, observation):
|
||||
gym_action = exploit_credentialcache(observation)
|
||||
if gym_action:
|
||||
if wrapped_env.env.is_action_valid(gym_action, observation['action_mask']):
|
||||
return 'exploit', gym_action, None
|
||||
else:
|
||||
# fallback on random exploration
|
||||
return 'exploit[invalid]->explore', None, None
|
||||
else:
|
||||
return 'exploit[undefined]->explore', None, None
|
||||
|
||||
def stateaction_as_string(self, actionmetadata):
|
||||
return ''
|
||||
|
||||
def on_step(self, wrapped_env: AgentWrapper, observation, reward, done, info, action_metadata):
|
||||
None
|
||||
|
||||
def end_of_iteration(self, t, done):
|
||||
None
|
||||
|
||||
def end_of_episode(self, i_episode, t):
|
||||
None
|
||||
|
||||
def loss_as_string(self):
|
||||
return ''
|
||||
|
||||
def new_episode(self):
|
||||
None
|
|
@ -0,0 +1,422 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Q-learning agent applied to chain network (notebook)
|
||||
This notebooks can be run directly from VSCode, to generate a
|
||||
traditional Jupyter Notebook to open in your browser
|
||||
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from typing import NamedTuple, Optional, Tuple
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
from cyberbattle._env import cyberbattle_env
|
||||
from .agent_wrapper import EnvironmentBounds
|
||||
from .agent_randomcredlookup import CredentialCacheExploiter
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
|
||||
|
||||
def random_argmax(array):
|
||||
"""Just like `argmax` but if there are multiple elements with the max
|
||||
return a random index to break ties instead of returning the first one."""
|
||||
max_value = np.max(array)
|
||||
max_index = np.where(array == max_value)[0]
|
||||
|
||||
if max_index.shape[0] > 1:
|
||||
max_index = int(np.random.choice(max_index, size=1))
|
||||
else:
|
||||
max_index = int(max_index)
|
||||
|
||||
return max_value, max_index
|
||||
|
||||
|
||||
def random_argtop_percentile(array, percentile):
|
||||
"""Just like `argmax` but if there are multiple elements with the max
|
||||
return a random index to break ties instead of returning the first one."""
|
||||
top_percentile = np.percentile(array, percentile)
|
||||
indices = np.where(array >= top_percentile)[0]
|
||||
|
||||
if len(indices) == 0:
|
||||
return random_argmax(array)
|
||||
elif indices.shape[0] > 1:
|
||||
max_index = int(np.random.choice(indices, size=1))
|
||||
else:
|
||||
max_index = int(indices)
|
||||
|
||||
return top_percentile, max_index
|
||||
|
||||
|
||||
class QMatrix:
|
||||
"""Q-Learning matrix for a given state and action space
|
||||
state_space - Features defining the state space
|
||||
action_space - Features defining the action space
|
||||
qm - Optional: initialization values for the Q matrix
|
||||
"""
|
||||
# The Quality matrix
|
||||
qm: np.ndarray
|
||||
|
||||
def __init__(self, name,
|
||||
state_space: w.Feature,
|
||||
action_space: w.Feature,
|
||||
qm: Optional[np.ndarray] = None):
|
||||
"""Initialize the Q-matrix"""
|
||||
|
||||
self.name = name
|
||||
self.state_space = state_space
|
||||
self.action_space = action_space
|
||||
self.statedim = state_space.flat_size()
|
||||
self.actiondim = action_space.flat_size()
|
||||
self.qm = self.clear() if qm is None else qm
|
||||
|
||||
# error calculated for the last update to the Q-matrix
|
||||
self.last_error = 0
|
||||
|
||||
def shape(self):
|
||||
return (self.statedim, self.actiondim)
|
||||
|
||||
def clear(self):
|
||||
"""Re-initialize the Q-matrix to 0"""
|
||||
self.qm = np.zeros(shape=self.shape())
|
||||
# self.qm = np.random.rand(*self.shape()) / 100
|
||||
return self.qm
|
||||
|
||||
def print(self):
|
||||
print(f"[{self.name}]\n"
|
||||
f"state: {self.state_space}\n"
|
||||
f"action: {self.action_space}\n"
|
||||
f"shape = {self.shape()}")
|
||||
|
||||
def update(self, current_state: int, action: int, next_state: int, reward, gamma, learning_rate):
|
||||
"""Update the Q matrix after taking `action` in state 'current_State'
|
||||
and obtaining reward=R[current_state, action]"""
|
||||
|
||||
maxq_atnext, max_index = random_argmax(self.qm[next_state, ])
|
||||
|
||||
# bellman equation for Q-learning
|
||||
temporal_difference = reward + gamma * maxq_atnext - self.qm[current_state, action]
|
||||
self.qm[current_state, action] += learning_rate * temporal_difference
|
||||
|
||||
# The loss is calculated using the squared difference between
|
||||
# target Q-Value and predicted Q-Value
|
||||
square_error = temporal_difference * temporal_difference
|
||||
self.last_error = square_error
|
||||
|
||||
return self.qm[current_state, action]
|
||||
|
||||
def exploit(self, features, percentile) -> Tuple[int, float]:
|
||||
"""exploit: leverage the Q-matrix.
|
||||
Returns the expected Q value and the chosen action."""
|
||||
expected_q, action = random_argtop_percentile(self.qm[features, :], percentile)
|
||||
return int(action), expected_q
|
||||
|
||||
|
||||
class QLearnAttackSource(QMatrix):
|
||||
""" Top-level Q matrix to pick the attack
|
||||
State space: global state info
|
||||
Action space: feature encodings of suggested nodes
|
||||
"""
|
||||
|
||||
def __init__(self, ep: EnvironmentBounds, qm: Optional[np.ndarray] = None):
|
||||
self.ep = ep
|
||||
|
||||
self.state_space = w.HashEncoding(ep, [
|
||||
# Feature_discovered_node_count(),
|
||||
# Feature_discovered_credential_count(),
|
||||
w.Feature_discovered_ports_sliding(ep),
|
||||
w.Feature_discovered_nodeproperties_sliding(ep),
|
||||
w.Feature_discovered_notowned_node_count(ep, 3)
|
||||
], 5000) # should not be too small, pick something big to avoid collision
|
||||
|
||||
self.action_space = w.RavelEncoding(ep, [
|
||||
w.Feature_active_node_properties(ep)])
|
||||
|
||||
super().__init__("attack_source", self.state_space, self.action_space, qm)
|
||||
|
||||
|
||||
class QLearnBestAttackAtSource(QMatrix):
|
||||
""" Top-level Q matrix to pick the attack from a pre-chosen source node
|
||||
State space: feature encodings of suggested node states
|
||||
Action space: a SimpleAbstract action
|
||||
"""
|
||||
|
||||
def __init__(self, ep: EnvironmentBounds, qm: Optional[np.ndarray] = None) -> None:
|
||||
|
||||
self.state_space = w.HashEncoding(ep, [
|
||||
w.Feature_active_node_properties(ep),
|
||||
w.Feature_active_node_age(ep)
|
||||
# w.Feature_actions_tried_at_node(ep)
|
||||
], 7000)
|
||||
|
||||
# NOTE: For debugging purpose it's convenient instead to use
|
||||
# Ravel encoding for node properties
|
||||
self.state_space_debugging = w.RavelEncoding(ep, [
|
||||
w.HashEncoding(ep, [
|
||||
# Feature_discovered_node_count(),
|
||||
# Feature_discovered_credential_count(),
|
||||
w.Feature_discovered_ports_sliding(ep),
|
||||
w.Feature_discovered_nodeproperties_sliding(ep),
|
||||
w.Feature_discovered_notowned_node_count(ep, 3),
|
||||
], 100),
|
||||
w.Feature_active_node_properties(ep)
|
||||
])
|
||||
|
||||
self.action_space = w.AbstractAction(ep)
|
||||
|
||||
super().__init__("attack_at_source", self.state_space, self.action_space, qm)
|
||||
|
||||
|
||||
# TODO: We should try scipy for sparse matrices and OpenBLAS (MKL Intel version of BLAS, faster than openBLAS) for numpy
|
||||
|
||||
# %%
|
||||
class LossEval:
|
||||
"""Loss evaluation for a Q-Learner,
|
||||
learner -- The Q learner
|
||||
"""
|
||||
|
||||
def __init__(self, qmatrix: QMatrix):
|
||||
self.qmatrix = qmatrix
|
||||
self.this_episode = []
|
||||
self.all_episodes = []
|
||||
|
||||
def new_episode(self):
|
||||
self.this_episode = []
|
||||
|
||||
def end_of_iteration(self, t, done):
|
||||
self.this_episode.append(self.qmatrix.last_error)
|
||||
|
||||
def current_episode_loss(self):
|
||||
return np.average(self.this_episode)
|
||||
|
||||
def end_of_episode(self, i_episode, t):
|
||||
"""Average out the overall loss for this episode"""
|
||||
self.all_episodes.append(self.current_episode_loss())
|
||||
|
||||
|
||||
class ChosenActionMetadata(NamedTuple):
|
||||
"""Additional information associated with the action chosen by the agent"""
|
||||
Q_source_state: int
|
||||
Q_source_expectedq: float
|
||||
Q_attack_expectedq: float
|
||||
source_node: int
|
||||
source_node_encoding: int
|
||||
abstract_action: np.int32
|
||||
Q_attack_state: int
|
||||
|
||||
|
||||
class QTabularLearner(learner.Learner):
|
||||
"""Tabular Q-learning
|
||||
|
||||
Parameters
|
||||
==========
|
||||
gamma -- discount factor
|
||||
|
||||
learning_rate -- learning rate
|
||||
|
||||
ep -- environment global properties
|
||||
|
||||
trained -- another QTabularLearner that is pretrained to initialize the Q matrices from (referenced, not copied)
|
||||
|
||||
exploit_percentile -- (experimental) Randomly pick actions above this percentile in the Q-matrix.
|
||||
Setting 100 gives the argmax as in standard Q-learning.
|
||||
|
||||
The idea is that a value less than 100 helps compensate for the
|
||||
approximation made when updating the Q-matrix caused by
|
||||
the abstraction of the action space (attack parameters are abstracted away
|
||||
in the Q-matrix, and when an abstract action is picked, it
|
||||
gets specialized via a random process.)
|
||||
When running in non-learning mode (lr=0), setting this value too close to 100
|
||||
may lead to get stuck, being more permissive (e.g. in the 80-90 range)
|
||||
typically gives better results.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ep: EnvironmentBounds,
|
||||
gamma: float,
|
||||
learning_rate: float,
|
||||
exploit_percentile: float,
|
||||
trained=None, # : Optional[QTabularLearner]
|
||||
):
|
||||
if trained:
|
||||
self.qsource = trained.qsource
|
||||
self.qattack = trained.qattack
|
||||
else:
|
||||
self.qsource = QLearnAttackSource(ep)
|
||||
self.qattack = QLearnBestAttackAtSource(ep)
|
||||
|
||||
self.loss_qsource = LossEval(self.qsource)
|
||||
self.loss_qattack = LossEval(self.qattack)
|
||||
self.gamma = gamma
|
||||
self.learning_rate = learning_rate
|
||||
self.exploit_percentile = exploit_percentile
|
||||
self.credcache_policy = CredentialCacheExploiter()
|
||||
|
||||
def on_step(self, wrapped_env: w.AgentWrapper, observation, reward, done, info, action_metadata: ChosenActionMetadata):
|
||||
|
||||
agent_state = wrapped_env.state
|
||||
|
||||
# Update the top-level Q matrix for the state of the selected source node
|
||||
after_toplevel_state = self.qsource.state_space.encode(agent_state)
|
||||
self.qsource.update(action_metadata.Q_source_state,
|
||||
action_metadata.source_node_encoding,
|
||||
after_toplevel_state,
|
||||
reward, self.gamma, self.learning_rate)
|
||||
|
||||
# Update the second Q matrix for the abstract action chosen
|
||||
qattack_state_after = self.qattack.state_space.encode_at(agent_state, action_metadata.source_node)
|
||||
self.qattack.update(action_metadata.Q_attack_state,
|
||||
int(action_metadata.abstract_action),
|
||||
qattack_state_after,
|
||||
reward, self.gamma, self.learning_rate)
|
||||
|
||||
def end_of_iteration(self, t, done):
|
||||
self.loss_qsource.end_of_iteration(t, done)
|
||||
self.loss_qattack.end_of_iteration(t, done)
|
||||
|
||||
def end_of_episode(self, i_episode, t):
|
||||
self.loss_qsource.end_of_episode(i_episode, t)
|
||||
self.loss_qattack.end_of_episode(i_episode, t)
|
||||
|
||||
def loss_as_string(self):
|
||||
return f"[loss_source={self.loss_qsource.current_episode_loss():0.3f}" \
|
||||
f" loss_attack={self.loss_qattack.current_episode_loss():0.3f}]"
|
||||
|
||||
def new_episode(self):
|
||||
self.loss_qsource.new_episode()
|
||||
self.loss_qattack.new_episode()
|
||||
|
||||
def exploit(self, wrapped_env: w.AgentWrapper, observation):
|
||||
|
||||
agent_state = wrapped_env.state
|
||||
|
||||
qsource_state = self.qsource.state_space.encode(agent_state)
|
||||
|
||||
#############
|
||||
# first, attempt to exploit the credential cache
|
||||
# using the crecache_policy
|
||||
action_style, gym_action, _ = self.credcache_policy.exploit(wrapped_env, observation)
|
||||
if gym_action:
|
||||
source_node = cyberbattle_env.sourcenode_of_action(gym_action)
|
||||
return action_style, gym_action, ChosenActionMetadata(
|
||||
Q_source_state=qsource_state,
|
||||
Q_source_expectedq=-1,
|
||||
Q_attack_expectedq=-1,
|
||||
source_node=source_node,
|
||||
source_node_encoding=self.qsource.action_space.encode_at(
|
||||
agent_state, source_node),
|
||||
abstract_action=np.int32(self.qattack.action_space.abstract_from_gymaction(gym_action)),
|
||||
Q_attack_state=self.qattack.state_space.encode_at(agent_state, source_node)
|
||||
)
|
||||
#############
|
||||
|
||||
# Pick action: pick random source state among the ones with the maximum Q-value
|
||||
action_style = "exploit"
|
||||
source_node_encoding, qsource_expectedq = self.qsource.exploit(
|
||||
qsource_state, percentile=100)
|
||||
|
||||
# Pick source node at random (owned and with the desired feature encoding)
|
||||
potential_source_nodes = [
|
||||
from_node
|
||||
for from_node in w.owned_nodes(observation)
|
||||
if source_node_encoding == self.qsource.action_space.encode_at(agent_state, from_node)
|
||||
]
|
||||
|
||||
if len(potential_source_nodes) == 0:
|
||||
logging.debug(f'No node with encoding {source_node_encoding}, fallback on explore')
|
||||
# NOTE: we should make sure that it does not happen too often,
|
||||
# the penalty should be much smaller than typical rewards, small nudge
|
||||
# not a new feedback signal.
|
||||
|
||||
# Learn the lack of node availability
|
||||
self.qsource.update(qsource_state,
|
||||
source_node_encoding,
|
||||
qsource_state,
|
||||
reward=0, gamma=self.gamma, learning_rate=self.learning_rate)
|
||||
|
||||
return "exploit-1->explore", None, None
|
||||
else:
|
||||
source_node = np.random.choice(potential_source_nodes)
|
||||
|
||||
qattack_state = self.qattack.state_space.encode_at(agent_state, source_node)
|
||||
|
||||
abstract_action, qattack_expectedq = self.qattack.exploit(
|
||||
qattack_state, percentile=self.exploit_percentile)
|
||||
|
||||
gym_action = self.qattack.action_space.specialize_to_gymaction(
|
||||
source_node, observation, np.int32(abstract_action))
|
||||
|
||||
assert int(abstract_action) < self.qattack.action_space.flat_size(), \
|
||||
f'abstract_action={abstract_action} gym_action={gym_action}'
|
||||
|
||||
if gym_action and wrapped_env.env.is_action_valid(gym_action, observation['action_mask']):
|
||||
logging.debug(f' exploit gym_action={gym_action} source_node_encoding={source_node_encoding}')
|
||||
return action_style, gym_action, ChosenActionMetadata(
|
||||
Q_source_state=qsource_state,
|
||||
Q_source_expectedq=qsource_expectedq,
|
||||
Q_attack_expectedq=qsource_expectedq,
|
||||
source_node=source_node,
|
||||
source_node_encoding=source_node_encoding,
|
||||
abstract_action=np.int32(abstract_action),
|
||||
Q_attack_state=qattack_state
|
||||
)
|
||||
else:
|
||||
# NOTE: We should make the penalty reward smaller than
|
||||
# the average/typical non-zero reward of the env (e.g. 1/1000 smaller)
|
||||
# The idea of weighing the learning_rate when taking a chance is
|
||||
# related to "Inverse propensity weighting"
|
||||
|
||||
# Learn the non-validity of the action
|
||||
self.qsource.update(qsource_state,
|
||||
source_node_encoding,
|
||||
qsource_state,
|
||||
reward=0, gamma=self.gamma, learning_rate=self.learning_rate)
|
||||
|
||||
self.qattack.update(qattack_state,
|
||||
int(abstract_action),
|
||||
qattack_state,
|
||||
reward=0, gamma=self.gamma, learning_rate=self.learning_rate)
|
||||
|
||||
# fallback on random exploration
|
||||
return ('exploit[invalid]->explore' if gym_action else 'exploit[undefined]->explore'), None, None
|
||||
|
||||
def explore(self, wrapped_env: w.AgentWrapper):
|
||||
agent_state = wrapped_env.state
|
||||
gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])
|
||||
abstract_action = self.qattack.action_space.abstract_from_gymaction(gym_action)
|
||||
|
||||
assert int(abstract_action) < self.qattack.action_space.flat_size(
|
||||
), f'Q_attack_action={abstract_action} gym_action={gym_action}'
|
||||
|
||||
source_node = cyberbattle_env.sourcenode_of_action(gym_action)
|
||||
|
||||
return "explore", gym_action, ChosenActionMetadata(
|
||||
Q_source_state=self.qsource.state_space.encode(agent_state),
|
||||
Q_source_expectedq=-1,
|
||||
Q_attack_expectedq=-1,
|
||||
source_node=source_node,
|
||||
source_node_encoding=self.qsource.action_space.encode_at(agent_state, source_node),
|
||||
abstract_action=abstract_action,
|
||||
Q_attack_state=self.qattack.state_space.encode_at(agent_state, source_node)
|
||||
)
|
||||
|
||||
def stateaction_as_string(self, actionmetadata) -> str:
|
||||
return f"Qsource[state={actionmetadata.Q_source_state} err={self.qsource.last_error:0.2f}"\
|
||||
f"Q={actionmetadata.Q_source_expectedq:.2f}] " \
|
||||
f"Qattack[state={actionmetadata.Q_attack_state} err={self.qattack.last_error:0.2f} "\
|
||||
f"Q={actionmetadata.Q_attack_expectedq:.2f}] "
|
||||
|
||||
def parameters_as_string(self) -> str:
|
||||
return f"γ={self.gamma}," \
|
||||
f"learning_rate={self.learning_rate},"\
|
||||
f"Q%={self.exploit_percentile}"
|
||||
|
||||
def all_parameters_as_string(self) -> str:
|
||||
return f' dimension={self.qsource.state_space.flat_size()}x{self.qsource.action_space.flat_size()},' \
|
||||
f'{self.qattack.state_space.flat_size()}x{self.qattack.action_space.flat_size()}\n' \
|
||||
f'Q1={[f.name() for f in self.qsource.state_space.feature_selection]}' \
|
||||
f' -> {[f.name() for f in self.qsource.action_space.feature_selection]}\n' \
|
||||
f"Q2={[f.name() for f in self.qattack.state_space.feature_selection]} -> 'action'"
|
|
@ -0,0 +1,507 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Agent wrapper for CyberBattle envrionments exposing additional
|
||||
features extracted from the environment observations"""
|
||||
|
||||
from cyberbattle._env.cyberbattle_env import EnvironmentBounds
|
||||
from typing import Optional, List
|
||||
import enum
|
||||
import numpy as np
|
||||
from gym import spaces, Wrapper
|
||||
from numpy import ndarray
|
||||
import cyberbattle._env.cyberbattle_env as cyberbattle_env
|
||||
import logging
|
||||
|
||||
|
||||
class StateAugmentation:
|
||||
"""Default agent state augmentation, consisting of the gym environment
|
||||
observation itself and nothing more."""
|
||||
|
||||
def __init__(self, observation: Optional[cyberbattle_env.Observation] = None):
|
||||
self.observation = observation
|
||||
|
||||
def on_step(self, action: cyberbattle_env.Action, reward: float, done: bool, observation: cyberbattle_env.Observation):
|
||||
self.observation = observation
|
||||
|
||||
def on_reset(self, observation: cyberbattle_env.Observation):
|
||||
self.observation = observation
|
||||
|
||||
|
||||
class Feature(spaces.MultiDiscrete):
|
||||
"""
|
||||
Feature consisting of multiple discrete dimensions.
|
||||
Parameters:
|
||||
nvec: is a vector defining the number of possible values
|
||||
for each discrete space.
|
||||
"""
|
||||
|
||||
def __init__(self, env_properties: EnvironmentBounds, nvec):
|
||||
self.env_properties = env_properties
|
||||
super().__init__(nvec)
|
||||
|
||||
def flat_size(self):
|
||||
return np.prod(self.nvec)
|
||||
|
||||
def name(self):
|
||||
"""Return the name of the feature"""
|
||||
p = len(type(Feature(self.env_properties, [])).__name__) + 1
|
||||
return type(self).__name__[p:]
|
||||
|
||||
def get(self, a: StateAugmentation, node: Optional[int]) -> np.ndarray:
|
||||
"""Compute the current value of a feature value at
|
||||
the current observation and specific node"""
|
||||
raise NotImplementedError
|
||||
|
||||
def pretty_print(self, v):
|
||||
return v
|
||||
|
||||
|
||||
class Feature_active_node_properties(Feature):
|
||||
"""Bitmask of all properties set for the active node"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [2] * p.property_count)
|
||||
|
||||
def get(self, a: StateAugmentation, node) -> ndarray:
|
||||
assert node is not None, 'feature only valid in the context of a node'
|
||||
|
||||
node_prop = a.observation['discovered_nodes_properties']
|
||||
|
||||
# list of all properties set/unset on the node
|
||||
# Remap to get rid of unknown value 0: 1 -> 1, and -1 -> 0 (and 0-> 0)
|
||||
assert node < len(node_prop), f'invalid node index {node} (not discovered yet)'
|
||||
remapped = np.array((1 + node_prop[node]) / 2, dtype=np.int)
|
||||
return remapped
|
||||
|
||||
|
||||
class Feature_active_node_age(Feature):
|
||||
"""How recently was this node discovered?
|
||||
(measured by reverse position in the list of discovered nodes)"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [p.maximum_node_count])
|
||||
|
||||
def get(self, a: StateAugmentation, node) -> ndarray:
|
||||
assert node is not None, 'feature only valid in the context of a node'
|
||||
|
||||
discovered_node_count = len(a.observation['discovered_nodes_properties'])
|
||||
|
||||
assert node < discovered_node_count, f'invalid node index {node} (not discovered yet)'
|
||||
|
||||
return np.array([discovered_node_count - node - 1], dtype=np.int)
|
||||
|
||||
|
||||
class Feature_active_node_id(Feature):
|
||||
"""Return the node id itself"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [p.maximum_node_count] * 1)
|
||||
|
||||
def get(self, a: StateAugmentation, node) -> ndarray:
|
||||
return np.array([node], dtype=np.int)
|
||||
|
||||
|
||||
class Feature_discovered_nodeproperties_sliding(Feature):
|
||||
"""Bitmask indicating node properties seen in last few cache entries"""
|
||||
window_size = 3
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [2] * p.property_count)
|
||||
|
||||
def get(self, a: StateAugmentation, node) -> ndarray:
|
||||
node_prop = np.array(a.observation['discovered_nodes_properties'])
|
||||
|
||||
# keep last window of entries
|
||||
node_prop_window = node_prop[-self.window_size:, :]
|
||||
|
||||
# Remap to get rid of unknown value 0: 1 -> 1, and -1 -> 0 (and 0-> 0)
|
||||
node_prop_window_remapped = np.int32((1 + node_prop_window) / 2)
|
||||
|
||||
countby = np.sum(node_prop_window_remapped, axis=0)
|
||||
|
||||
bitmask = (countby > 0) * 1
|
||||
return bitmask
|
||||
|
||||
|
||||
class Feature_discovered_ports(Feature):
|
||||
"""Bitmask vector indicating each port seen so far in discovered credentials"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [2] * p.port_count)
|
||||
|
||||
def get(self, a: StateAugmentation, node):
|
||||
ccm = a.observation['credential_cache_matrix']
|
||||
known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32)
|
||||
known_credports[np.int32(ccm[:, 1])] = 1
|
||||
return known_credports
|
||||
|
||||
|
||||
class Feature_discovered_ports_sliding(Feature):
|
||||
"""Bitmask indicating port seen in last few cache entries"""
|
||||
window_size = 3
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [2] * p.port_count)
|
||||
|
||||
def get(self, a: StateAugmentation, node):
|
||||
ccm = a.observation['credential_cache_matrix']
|
||||
known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32)
|
||||
known_credports[np.int32(ccm[-self.window_size:, 1])] = 1
|
||||
return known_credports
|
||||
|
||||
|
||||
class Feature_discovered_ports_counts(Feature):
|
||||
"""Count of each port seen so far in discovered credentials"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [p.maximum_total_credentials + 1] * p.port_count)
|
||||
|
||||
def get(self, a: StateAugmentation, node):
|
||||
ccm = a.observation['credential_cache_matrix']
|
||||
return np.bincount(np.int32(ccm[:, 1]), minlength=self.env_properties.port_count)
|
||||
|
||||
|
||||
class Feature_discovered_credential_count(Feature):
|
||||
"""number of credentials discovered so far"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [p.maximum_total_credentials + 1])
|
||||
|
||||
def get(self, a: StateAugmentation, node):
|
||||
return [len(a.observation['credential_cache_matrix'])]
|
||||
|
||||
|
||||
class Feature_discovered_node_count(Feature):
|
||||
"""number of nodes discovered so far"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [p.maximum_node_count + 1])
|
||||
|
||||
def get(self, a: StateAugmentation, node):
|
||||
return [len(a.observation['discovered_nodes_properties'])]
|
||||
|
||||
|
||||
class Feature_discovered_notowned_node_count(Feature):
|
||||
"""number of nodes discovered that are not owned yet (optionally clipped)"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds, clip: Optional[int]):
|
||||
self.clip = p.maximum_node_count if clip is None else clip
|
||||
super().__init__(p, [self.clip + 1])
|
||||
|
||||
def get(self, a: StateAugmentation, node):
|
||||
node_props = a.observation['discovered_nodes_properties']
|
||||
discovered = len(node_props)
|
||||
owned = len(np.all(node_props != 0, axis=1))
|
||||
diff = discovered - owned
|
||||
return [max(diff, self.clip)]
|
||||
|
||||
|
||||
class Feature_owned_node_count(Feature):
|
||||
"""number of owned nodes so far"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [p.maximum_node_count + 1])
|
||||
|
||||
def get(self, a: StateAugmentation, node):
|
||||
levels = a.observation['nodes_privilegelevel']
|
||||
owned_nodes_indices = np.where(levels > 0)[0]
|
||||
return [len(owned_nodes_indices)]
|
||||
|
||||
|
||||
class ConcatFeatures(Feature):
|
||||
""" Concatenate a list of features into a single feature
|
||||
Parameters:
|
||||
feature_selection - a selection of features to combine
|
||||
"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature]):
|
||||
self.feature_selection = feature_selection
|
||||
self.dim_sizes = np.concatenate([f.nvec for f in feature_selection])
|
||||
super().__init__(p, [self.dim_sizes])
|
||||
|
||||
def pretty_print(self, v):
|
||||
return v
|
||||
|
||||
def get(self, a: StateAugmentation, node=None) -> np.ndarray:
|
||||
"""Return the feature vector"""
|
||||
feature_vector = [f.get(a, node) for f in self.feature_selection]
|
||||
return np.concatenate(feature_vector)
|
||||
|
||||
|
||||
class FeatureEncoder(Feature):
|
||||
""" Encode a list of featues as a unique index
|
||||
"""
|
||||
|
||||
feature_selection: List[Feature]
|
||||
|
||||
def vector_to_index(self, feature_vector: np.ndarray) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def feature_vector_of_observation_at(self, a: StateAugmentation, node: Optional[int]) -> np.ndarray:
|
||||
"""Return the current feature vector"""
|
||||
feature_vector = [f.get(a, node) for f in self.feature_selection]
|
||||
# print(f'feature_vector={feature_vector} self.feature_selection={self.feature_selection}')
|
||||
return np.concatenate(feature_vector)
|
||||
|
||||
def feature_vector_of_observation(self, a: StateAugmentation):
|
||||
return self.feature_vector_of_observation_at(a, None)
|
||||
|
||||
def encode(self, a: StateAugmentation, node=None) -> int:
|
||||
"""Return the index encoding of the feature"""
|
||||
feature_vector_concat = self.feature_vector_of_observation_at(a, node)
|
||||
return self.vector_to_index(feature_vector_concat)
|
||||
|
||||
def encode_at(self, a: StateAugmentation, node) -> int:
|
||||
"""Return the current feature vector encoding with a node context"""
|
||||
feature_vector_concat = self.feature_vector_of_observation_at(a, node)
|
||||
return self.vector_to_index(feature_vector_concat)
|
||||
|
||||
def get(self, a: StateAugmentation, node=None) -> np.ndarray:
|
||||
"""Return the feature vector"""
|
||||
return np.array([self.encode(a, node)])
|
||||
|
||||
def name(self):
|
||||
"""Return a name for the feature encoding"""
|
||||
n = ', '.join([f.name() for f in self.feature_selection])
|
||||
return f'[{n}]'
|
||||
|
||||
|
||||
class HashEncoding(FeatureEncoder):
|
||||
""" Feature defined as a hash of another feature
|
||||
Parameters:
|
||||
feature_selection: a selection of features to combine
|
||||
hash_dim: dimension after hashing with hash(str(feature_vector)) or -1 for no hashing
|
||||
"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature], hash_size: int):
|
||||
self.feature_selection = feature_selection
|
||||
self.hash_size = hash_size
|
||||
super().__init__(p, [hash_size])
|
||||
|
||||
def flat_size(self):
|
||||
return self.hash_size
|
||||
|
||||
def vector_to_index(self, feature_vector) -> int:
|
||||
"""Hash the state vector"""
|
||||
return hash(str(feature_vector)) % self.hash_size
|
||||
|
||||
def pretty_print(self, index):
|
||||
return f'#{index}'
|
||||
|
||||
|
||||
class RavelEncoding(FeatureEncoder):
|
||||
""" Combine a set of features into a single feature with a unique index
|
||||
(calculated by raveling the original indices)
|
||||
Parameters:
|
||||
feature_selection - a selection of features to combine
|
||||
"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature]):
|
||||
self.feature_selection = feature_selection
|
||||
self.dim_sizes = np.concatenate([f.nvec for f in feature_selection])
|
||||
self.ravelled_size: int = np.prod(self.dim_sizes)
|
||||
assert np.shape(self.ravelled_size) == (), f'! {np.shape(self.ravelled_size)}'
|
||||
super().__init__(p, [self.ravelled_size])
|
||||
|
||||
def vector_to_index(self, feature_vector):
|
||||
assert len(self.dim_sizes) == len(feature_vector), \
|
||||
f'feature vector of size {len(feature_vector)}, ' \
|
||||
f'expecting {len(self.dim_sizes)}: {feature_vector} -- {self.dim_sizes}'
|
||||
index: np.int32 = np.ravel_multi_index(feature_vector, self.dim_sizes)
|
||||
assert index < self.ravelled_size, \
|
||||
f'feature vector out of bound ({feature_vector}, dim={self.dim_sizes}) ' \
|
||||
f'-> index={index}, max_index={self.ravelled_size-1})'
|
||||
return index
|
||||
|
||||
def unravel_index(self, index) -> np.ndarray:
|
||||
return np.unravel_index(index, self.dim_sizes)
|
||||
|
||||
def pretty_print(self, index):
|
||||
return self.unravel_index(index)
|
||||
|
||||
|
||||
def owned_nodes(observation):
|
||||
"""Return the list of owned nodes"""
|
||||
return np.nonzero(observation['nodes_privilegelevel'])[0]
|
||||
|
||||
|
||||
def discovered_nodes_notowned(observation):
|
||||
"""Return the list of discovered nodes that are not owned yet"""
|
||||
return np.nonzero(observation['nodes_privilegelevel'] == 0)[0]
|
||||
|
||||
|
||||
class AbstractAction(Feature):
|
||||
"""An abstraction of the gym state space that reduces
|
||||
the space dimension for learning use to just
|
||||
- local_attack(vulnid) (source_node provided)
|
||||
- remote_attack(vulnid) (source_node provided, target_node forgotten)
|
||||
- connect(port) (source_node provided, target_node forgotten, credentials infered from cache)
|
||||
"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
self.n_local_actions = p.local_attacks_count
|
||||
self.n_remote_actions = p.remote_attacks_count
|
||||
self.n_connect_actions = p.port_count
|
||||
self.n_actions = self.n_local_actions + self.n_remote_actions + self.n_connect_actions
|
||||
super().__init__(p, [self.n_actions])
|
||||
|
||||
def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_action_index: np.int32
|
||||
) -> Optional[cyberbattle_env.Action]:
|
||||
"""Specialize an abstract "q"-action into a gym action.
|
||||
Return an adjustement weight (1.0 if the choice was deterministic, 1/n if a choice was made out of n)
|
||||
and the gym action"""
|
||||
|
||||
abstract_action_index_int = int(abstract_action_index)
|
||||
|
||||
node_prop = np.array(observation['discovered_nodes_properties'])
|
||||
|
||||
if abstract_action_index_int < self.n_local_actions:
|
||||
vuln = abstract_action_index_int
|
||||
return {'local_vulnerability': np.array([source_node, vuln])}
|
||||
|
||||
abstract_action_index_int -= self.n_local_actions
|
||||
if abstract_action_index_int < self.n_remote_actions:
|
||||
vuln = abstract_action_index_int
|
||||
|
||||
# NOTE: We can do better here than random pick: ultimately this
|
||||
# should be learnt from target node properties
|
||||
|
||||
# pick any node from the discovered ones
|
||||
target = np.random.choice(len(node_prop))
|
||||
|
||||
return {'remote_vulnerability': np.array([source_node, target, vuln])}
|
||||
|
||||
abstract_action_index_int -= self.n_remote_actions
|
||||
port = np.int32(abstract_action_index_int)
|
||||
|
||||
discovered_credentials = np.array(observation['credential_cache_matrix'])
|
||||
n_discovered_creds = len(discovered_credentials)
|
||||
if n_discovered_creds <= 0:
|
||||
# no credential available in the cache: cannot poduce a valid connect action
|
||||
return None
|
||||
|
||||
# Pick a matching cred from the discovered_cred matrix
|
||||
# (at random if more than one exist for this target port)
|
||||
match_port = discovered_credentials[:, 1] == port
|
||||
match_port_indices = np.where(match_port)[0]
|
||||
|
||||
nodes_not_owned = discovered_nodes_notowned(observation)
|
||||
|
||||
match_port__target_notowned = [c for c in match_port_indices
|
||||
if discovered_credentials[c, 0] in nodes_not_owned]
|
||||
|
||||
if match_port__target_notowned:
|
||||
logging.debug('found matching cred in the credential cache')
|
||||
cred = np.int32(np.random.choice(match_port__target_notowned))
|
||||
target = np.int32(discovered_credentials[cred, 0])
|
||||
|
||||
return {'connect': np.array([source_node, target, port, cred], dtype=np.int32)}
|
||||
else:
|
||||
logging.debug('no cred match')
|
||||
return None
|
||||
|
||||
def abstract_from_gymaction(self, gym_action: cyberbattle_env.Action) -> np.int32:
|
||||
"""Abstract a gym action into an action to be index in the Q-matrix"""
|
||||
if 'local_vulnerability' in gym_action:
|
||||
return gym_action['local_vulnerability'][1]
|
||||
elif 'remote_vulnerability' in gym_action:
|
||||
r = gym_action['remote_vulnerability']
|
||||
return self.n_local_actions + r[2]
|
||||
|
||||
assert 'connect' in gym_action
|
||||
c = gym_action['connect']
|
||||
|
||||
a = self.n_local_actions + self.n_remote_actions + c[2]
|
||||
assert a < self.n_actions
|
||||
return np.int32(a)
|
||||
|
||||
|
||||
class ActionTrackingStateAugmentation(StateAugmentation):
|
||||
"""An agent state augmentation consisting of
|
||||
the environment observation augmented with the following dynamic information:
|
||||
- success_action_count: count of action taken and succeeded at the current node
|
||||
- failed_action_count: count of action taken and failed at the current node
|
||||
"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds, observation: Optional[cyberbattle_env.Observation] = None):
|
||||
self.aa = AbstractAction(p)
|
||||
self.success_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
|
||||
self.failed_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
|
||||
self.env_properties = p
|
||||
super().__init__(observation)
|
||||
|
||||
def on_step(self, action: cyberbattle_env.Action, reward: float, done: bool, observation: cyberbattle_env.Observation):
|
||||
node = cyberbattle_env.sourcenode_of_action(action)
|
||||
abstract_action = self.aa.abstract_from_gymaction(action)
|
||||
if reward > 0:
|
||||
self.success_action_count[node, abstract_action] += 1
|
||||
else:
|
||||
self.failed_action_count[node, abstract_action] += 1
|
||||
super().on_step(action, reward, done, observation)
|
||||
|
||||
def on_reset(self, observation: cyberbattle_env.Observation):
|
||||
p = self.env_properties
|
||||
self.success_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
|
||||
self.failed_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
|
||||
super().on_reset(observation)
|
||||
|
||||
|
||||
class Feature_actions_tried_at_node(Feature):
|
||||
"""A bit mask indicating which actions were already tried
|
||||
a the current node: 0 no tried, 1 tried"""
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [2] * AbstractAction(p).n_actions)
|
||||
|
||||
def get(self, a: ActionTrackingStateAugmentation, node: int):
|
||||
return ((a.failed_action_count[node, :] + a.success_action_count[node, :]) != 0) * 1
|
||||
|
||||
|
||||
class Feature_success_actions_at_node(Feature):
|
||||
"""number of time each action succeeded at a given node"""
|
||||
|
||||
max_action_count = 100
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [self.max_action_count] * AbstractAction(p).n_actions)
|
||||
|
||||
def get(self, a: ActionTrackingStateAugmentation, node: int):
|
||||
return np.minimum(a.success_action_count[node, :], self.max_action_count - 1)
|
||||
|
||||
|
||||
class Feature_failed_actions_at_node(Feature):
|
||||
"""number of time each action failed at a given node"""
|
||||
|
||||
max_action_count = 100
|
||||
|
||||
def __init__(self, p: EnvironmentBounds):
|
||||
super().__init__(p, [self.max_action_count] * AbstractAction(p).n_actions)
|
||||
|
||||
def get(self, a: ActionTrackingStateAugmentation, node: int):
|
||||
return np.minimum(a.failed_action_count[node, :], self.max_action_count - 1)
|
||||
|
||||
|
||||
class Verbosity(enum.Enum):
|
||||
"""Verbosity of the learning function"""
|
||||
Quiet = 0
|
||||
Normal = 1
|
||||
Verbose = 2
|
||||
|
||||
|
||||
class AgentWrapper(Wrapper):
|
||||
"""Gym wrapper to update the agent state on every step"""
|
||||
|
||||
def __init__(self, env: cyberbattle_env.CyberBattleEnv, state: StateAugmentation):
|
||||
super().__init__(env)
|
||||
self.state = state
|
||||
|
||||
def step(self, action: cyberbattle_env.Action):
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
self.state.on_step(action, reward, done, observation)
|
||||
return observation, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
observation = self.env.reset()
|
||||
self.state.on_reset(observation)
|
||||
return observation
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Test training of baseline agents. """
|
||||
import torch
|
||||
import gym
|
||||
import logging
|
||||
import sys
|
||||
import cyberbattle._env.cyberbattle_env as cyberbattle_env
|
||||
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
|
||||
import cyberbattle.agents.baseline.agent_dql as dqla
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
print(f"torch cuda available={torch.cuda.is_available()}")
|
||||
|
||||
cyberbattlechain = gym.make('CyberBattleChain-v0',
|
||||
size=4,
|
||||
attacker_goal=cyberbattle_env.AttackerGoal(
|
||||
own_atleast_percent=1.0,
|
||||
reward=100))
|
||||
|
||||
ep = w.EnvironmentBounds.of_identifiers(
|
||||
maximum_total_credentials=10,
|
||||
maximum_node_count=10,
|
||||
identifiers=cyberbattlechain.identifiers
|
||||
)
|
||||
|
||||
training_episode_count = 2
|
||||
iteration_count = 5
|
||||
|
||||
|
||||
def test_agent_training() -> None:
|
||||
dqn_learning_run = learner.epsilon_greedy_search(
|
||||
cyberbattle_gym_env=cyberbattlechain,
|
||||
environment_properties=ep,
|
||||
learner=dqla.DeepQLearnerPolicy(
|
||||
ep=ep,
|
||||
gamma=0.015,
|
||||
replay_memory_size=10000,
|
||||
target_update=10,
|
||||
batch_size=512,
|
||||
learning_rate=0.01), # torch default is 1e-2
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
# epsilon_multdecay=0.75, # 0.999,
|
||||
epsilon_exponential_decay=5000, # 10000
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="DQL"
|
||||
)
|
||||
assert dqn_learning_run
|
||||
|
||||
random_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain,
|
||||
ep,
|
||||
learner=learner.RandomPolicy(),
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=1.0, # purely random
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Random search"
|
||||
)
|
||||
|
||||
assert random_run
|
|
@ -0,0 +1,385 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Learner helpers and epsilon greedy search"""
|
||||
import math
|
||||
import sys
|
||||
|
||||
from .plotting import PlotTraining, plot_averaged_cummulative_rewards
|
||||
from .agent_wrapper import AgentWrapper, EnvironmentBounds, Verbosity, ActionTrackingStateAugmentation
|
||||
import logging
|
||||
import numpy as np
|
||||
from cyberbattle._env import cyberbattle_env
|
||||
from typing import Tuple, Optional, TypedDict, List
|
||||
import progressbar
|
||||
import abc
|
||||
|
||||
|
||||
class Learner(abc.ABC):
|
||||
"""Interface to be implemented by an epsilon-greedy learner"""
|
||||
|
||||
def new_episode(self) -> None:
|
||||
return None
|
||||
|
||||
def end_of_episode(self, i_episode, t) -> None:
|
||||
return None
|
||||
|
||||
def end_of_iteration(self, t, done) -> None:
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def explore(self, wrapped_env: AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:
|
||||
"""Exploration function.
|
||||
Returns (action_type, gym_action, action_metadata) where
|
||||
action_metadata is a custom object that gets passed to the on_step callback function"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exploit(self, wrapped_env: AgentWrapper, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
|
||||
"""Exploit function.
|
||||
Returns (action_type, gym_action, action_metadata) where
|
||||
action_metadata is a custom object that gets passed to the on_step callback function"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_step(self, wrapped_env: AgentWrapper, observation, reward, done, info, action_metadata) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def parameters_as_string(self) -> str:
|
||||
return ''
|
||||
|
||||
def all_parameters_as_string(self) -> str:
|
||||
return ''
|
||||
|
||||
def loss_as_string(self) -> str:
|
||||
return ''
|
||||
|
||||
def stateaction_as_string(self, action_metadata) -> str:
|
||||
return ''
|
||||
|
||||
|
||||
class RandomPolicy(Learner):
|
||||
"""A policy that does not learn and only explore"""
|
||||
|
||||
def explore(self, wrapped_env: AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:
|
||||
gym_action = wrapped_env.env.sample_valid_action()
|
||||
return "explore", gym_action, None
|
||||
|
||||
def exploit(self, wrapped_env: AgentWrapper, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
|
||||
raise NotImplementedError
|
||||
|
||||
def on_step(self, wrapped_env: AgentWrapper, observation, reward, done, info, action_metadata):
|
||||
return None
|
||||
|
||||
|
||||
Breakdown = TypedDict('Breakdown', {
|
||||
'local': int,
|
||||
'remote': int,
|
||||
'connect': int
|
||||
})
|
||||
|
||||
Outcomes = TypedDict('Outcomes', {
|
||||
'reward': Breakdown,
|
||||
'noreward': Breakdown
|
||||
})
|
||||
|
||||
Stats = TypedDict('Stats', {
|
||||
'exploit': Outcomes,
|
||||
'explore': Outcomes,
|
||||
'exploit_deflected_to_explore': int
|
||||
})
|
||||
|
||||
TrainedLearner = TypedDict('TrainedLearner', {
|
||||
'all_episodes_rewards': List[List[float]],
|
||||
'all_episodes_availability': List[List[float]],
|
||||
'learner': Learner,
|
||||
'trained_on': str,
|
||||
'title': str
|
||||
})
|
||||
|
||||
|
||||
def print_stats(stats):
|
||||
"""Print learning statistics"""
|
||||
def print_breakdown(stats, actiontype: str):
|
||||
def ratio(kind: str) -> str:
|
||||
x, y = stats[actiontype]['reward'][kind], stats[actiontype]['noreward'][kind]
|
||||
sum = x + y
|
||||
if sum == 0:
|
||||
return 'NaN'
|
||||
else:
|
||||
return f"{(x / sum):.2f}"
|
||||
|
||||
def print_kind(kind: str):
|
||||
print(
|
||||
f" {actiontype}-{kind}: {stats[actiontype]['reward'][kind]}/{stats[actiontype]['noreward'][kind]} "
|
||||
f"({ratio(kind)})")
|
||||
print_kind('local')
|
||||
print_kind('remote')
|
||||
print_kind('connect')
|
||||
|
||||
print(" Breakdown [Reward/NoReward (Success rate)]")
|
||||
print_breakdown(stats, 'explore')
|
||||
print_breakdown(stats, 'exploit')
|
||||
print(f" exploit deflected to exploration: {stats['exploit_deflected_to_explore']}")
|
||||
|
||||
|
||||
def epsilon_greedy_search(
|
||||
cyberbattle_gym_env: cyberbattle_env.CyberBattleEnv,
|
||||
environment_properties: EnvironmentBounds,
|
||||
learner: Learner,
|
||||
title: str,
|
||||
episode_count: int,
|
||||
iteration_count: int,
|
||||
epsilon: float,
|
||||
epsilon_minimum=0.0,
|
||||
epsilon_multdecay: Optional[float] = None,
|
||||
epsilon_exponential_decay: Optional[int] = None,
|
||||
render=True,
|
||||
render_last_episode_rewards_to: Optional[str] = None,
|
||||
verbosity: Verbosity = Verbosity.Normal
|
||||
) -> TrainedLearner:
|
||||
"""Epsilon greedy search for CyberBattle gym environments
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
- cyberbattle_gym_env -- the CyberBattle environment to train on
|
||||
|
||||
- learner --- the policy learner/exploiter
|
||||
|
||||
- episode_count -- Number of training episodes
|
||||
|
||||
- iteration_count -- Maximum number of iterations in each episode
|
||||
|
||||
- epsilon -- explore vs exploit
|
||||
- 0.0 to exploit the learnt policy only without exploration
|
||||
- 1.0 to explore purely randomly
|
||||
|
||||
- epsilon_minimum -- epsilon decay clipped at this value.
|
||||
Setting this value too close to 0 may leed the search to get stuck.
|
||||
|
||||
- epsilon_decay -- epsilon gets multiplied by this value after each episode
|
||||
|
||||
- epsilon_exponential_decay - if set use exponential decay. The bigger the value
|
||||
is, the slower it takes to get from the initial `epsilon` to `epsilon_minimum`.
|
||||
|
||||
- verbosity -- verbosity of the `print` logging
|
||||
|
||||
- render -- render the environment interactively after each episode
|
||||
|
||||
- render_last_episode_rewards_to -- render the environment to the specified file path
|
||||
with an index appended to it each time there is a positive reward
|
||||
for the last episode only
|
||||
|
||||
Note on convergence
|
||||
===================
|
||||
|
||||
Setting 'minimum_espilon' to 0 with an exponential decay <1
|
||||
makes the learning converge quickly (loss function getting to 0),
|
||||
but that's just a forced convergence, however, since when
|
||||
epsilon approaches 0, only the q-values that were explored so
|
||||
far get updated and so only that subset of cells from
|
||||
the Q-matrix converges.
|
||||
|
||||
"""
|
||||
|
||||
print(f"###### {title}\n"
|
||||
f"Learning with: episode_count={episode_count},"
|
||||
f"iteration_count={iteration_count},"
|
||||
f"ϵ={epsilon},"
|
||||
f'ϵ_min={epsilon_minimum}, '
|
||||
+ (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else '')
|
||||
+ (f"ϵ_expdecay={epsilon_exponential_decay}," if epsilon_exponential_decay else '') +
|
||||
f"{learner.parameters_as_string()}")
|
||||
|
||||
initial_epsilon = epsilon
|
||||
|
||||
all_episodes_rewards = []
|
||||
all_episodes_availability = []
|
||||
|
||||
wrapped_env = AgentWrapper(cyberbattle_gym_env,
|
||||
ActionTrackingStateAugmentation(environment_properties))
|
||||
steps_done = 0
|
||||
plot_title = f"{title} (epochs={episode_count}, ϵ={initial_epsilon}, ϵ_min={epsilon_minimum}," \
|
||||
+ (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else '') \
|
||||
+ (f"ϵ_expdecay={epsilon_exponential_decay}," if epsilon_exponential_decay else '') \
|
||||
+ learner.parameters_as_string()
|
||||
plottraining = PlotTraining(title=plot_title, render_each_episode=render)
|
||||
|
||||
render_file_index = 1
|
||||
|
||||
for i_episode in range(1, episode_count + 1):
|
||||
|
||||
print(f" ## Episode: {i_episode}/{episode_count} '{title}' "
|
||||
f"ϵ={epsilon:.4f}, "
|
||||
f"{learner.parameters_as_string()}")
|
||||
|
||||
observation = wrapped_env.reset()
|
||||
total_reward = 0.0
|
||||
all_rewards = []
|
||||
all_availability = []
|
||||
learner.new_episode()
|
||||
|
||||
stats = Stats(exploit=Outcomes(reward=Breakdown(local=0, remote=0, connect=0),
|
||||
noreward=Breakdown(local=0, remote=0, connect=0)),
|
||||
explore=Outcomes(reward=Breakdown(local=0, remote=0, connect=0),
|
||||
noreward=Breakdown(local=0, remote=0, connect=0)),
|
||||
exploit_deflected_to_explore=0
|
||||
)
|
||||
|
||||
episode_ended_at = None
|
||||
sys.stdout.flush()
|
||||
|
||||
bar = progressbar.ProgressBar(
|
||||
widgets=[
|
||||
'Episode ',
|
||||
f'{i_episode}',
|
||||
'|Iteration ',
|
||||
progressbar.Counter(),
|
||||
'|',
|
||||
progressbar.Variable(name='reward', width=6, precision=10),
|
||||
'|',
|
||||
progressbar.Timer(),
|
||||
progressbar.Bar()
|
||||
],
|
||||
redirect_stdout=False)
|
||||
|
||||
for t in bar(range(1, 1 + iteration_count)):
|
||||
|
||||
if epsilon_exponential_decay:
|
||||
epsilon = epsilon_minimum + math.exp(-1. * steps_done /
|
||||
epsilon_exponential_decay) * (initial_epsilon - epsilon_minimum)
|
||||
|
||||
steps_done += 1
|
||||
|
||||
x = np.random.rand()
|
||||
if x <= epsilon:
|
||||
action_style, gym_action, action_metadata = learner.explore(wrapped_env)
|
||||
else:
|
||||
action_style, gym_action, action_metadata = learner.exploit(wrapped_env, observation)
|
||||
if not gym_action:
|
||||
stats['exploit_deflected_to_explore'] += 1
|
||||
_, gym_action, action_metadata = learner.explore(wrapped_env)
|
||||
|
||||
# Take the step
|
||||
logging.debug(f"gym_action={gym_action}, action_metadata={action_metadata}")
|
||||
observation, reward, done, info = wrapped_env.step(gym_action)
|
||||
|
||||
action_type = 'exploit' if action_style == 'exploit' else 'explore'
|
||||
outcome = 'reward' if reward > 0 else 'noreward'
|
||||
if 'local_vulnerability' in gym_action:
|
||||
stats[action_type][outcome]['local'] += 1
|
||||
elif 'remote_vulnerability' in gym_action:
|
||||
stats[action_type][outcome]['remote'] += 1
|
||||
else:
|
||||
stats[action_type][outcome]['connect'] += 1
|
||||
|
||||
learner.on_step(wrapped_env, observation, reward, done, info, action_metadata)
|
||||
assert np.shape(reward) == ()
|
||||
|
||||
all_rewards.append(reward)
|
||||
all_availability.append(info['network_availability'])
|
||||
total_reward += reward
|
||||
bar.update(t, reward=total_reward)
|
||||
|
||||
if verbosity == Verbosity.Verbose or (verbosity == Verbosity.Normal and reward > 0):
|
||||
sign = ['-', '+'][reward > 0]
|
||||
|
||||
print(f" {sign} t={t} {action_style} r={reward} cum_reward:{total_reward} "
|
||||
f"a={action_metadata}-{gym_action} "
|
||||
f"creds={len(observation['credential_cache_matrix'])} "
|
||||
f" {learner.stateaction_as_string(action_metadata)}")
|
||||
|
||||
if i_episode == episode_count \
|
||||
and render_last_episode_rewards_to is not None \
|
||||
and reward > 0:
|
||||
fig = cyberbattle_gym_env.render_as_fig()
|
||||
fig.write_image(f"{render_last_episode_rewards_to}-e{i_episode}-{render_file_index}.png")
|
||||
render_file_index += 1
|
||||
|
||||
learner.end_of_iteration(t, done)
|
||||
|
||||
if done:
|
||||
episode_ended_at = t
|
||||
bar.finish(dirty=True)
|
||||
break
|
||||
|
||||
sys.stdout.flush()
|
||||
|
||||
loss_string = learner.loss_as_string()
|
||||
if loss_string:
|
||||
loss_string = "loss={loss_string}"
|
||||
|
||||
if episode_ended_at:
|
||||
print(f" Episode {i_episode} ended at t={episode_ended_at} {loss_string}")
|
||||
else:
|
||||
print(f" Episode {i_episode} stopped at t={iteration_count} {loss_string}")
|
||||
|
||||
print_stats(stats)
|
||||
|
||||
all_episodes_rewards.append(all_rewards)
|
||||
all_episodes_availability.append(all_availability)
|
||||
|
||||
length = episode_ended_at if episode_ended_at else iteration_count
|
||||
learner.end_of_episode(i_episode=i_episode, t=length)
|
||||
plottraining.episode_done(length)
|
||||
if render:
|
||||
wrapped_env.render()
|
||||
|
||||
if epsilon_multdecay:
|
||||
epsilon = max(epsilon_minimum, epsilon * epsilon_multdecay)
|
||||
|
||||
wrapped_env.close()
|
||||
print("simulation ended")
|
||||
plottraining.plot_end()
|
||||
|
||||
return TrainedLearner(
|
||||
all_episodes_rewards=all_episodes_rewards,
|
||||
all_episodes_availability=all_episodes_availability,
|
||||
learner=learner,
|
||||
trained_on=cyberbattle_gym_env.name,
|
||||
title=plot_title
|
||||
)
|
||||
|
||||
|
||||
def transfer_learning_evaluation(
|
||||
environment_properties: EnvironmentBounds,
|
||||
trained_learner: TrainedLearner,
|
||||
eval_env: cyberbattle_env.CyberBattleEnv,
|
||||
eval_epsilon: float,
|
||||
eval_episode_count: int,
|
||||
iteration_count: int,
|
||||
benchmark_policy=RandomPolicy(),
|
||||
benchmark_training_args=dict(title="Benchmark", epsilon=1.0)
|
||||
):
|
||||
"""Evaluated a trained agent on another environment of different size"""
|
||||
|
||||
eval_oneshot_all = epsilon_greedy_search(
|
||||
eval_env,
|
||||
environment_properties,
|
||||
learner=trained_learner['learner'],
|
||||
episode_count=eval_episode_count, # one shot from learnt Q matric
|
||||
iteration_count=iteration_count,
|
||||
epsilon=eval_epsilon,
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title=f"One shot on {eval_env.name} - Trained on {trained_learner['trained_on']}"
|
||||
)
|
||||
|
||||
eval_random = epsilon_greedy_search(
|
||||
eval_env,
|
||||
environment_properties,
|
||||
learner=benchmark_policy,
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
**benchmark_training_args
|
||||
)
|
||||
|
||||
plot_averaged_cummulative_rewards(
|
||||
all_runs=[eval_oneshot_all, eval_random],
|
||||
title=f"Transfer learning {trained_learner['trained_on']}->{eval_env.name} "
|
||||
f'-- max_nodes={environment_properties.maximum_node_count}, '
|
||||
f'episodes={eval_episode_count},\n'
|
||||
f"{trained_learner['learner'].all_parameters_as_string()}")
|
|
@ -0,0 +1,2 @@
|
|||
ffmpeg -y -r 2 -i chain10-e10-%d.png -vf "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" chain10-dql.gif
|
||||
ffmpeg -y -r 2 -i chain10-e10-%d.png -vf "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse,crop=480:400:320:0" chain10-dql-network.gif
|
|
@ -0,0 +1,215 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Random agent with credential lookup (notebook)
|
||||
|
||||
This notebooks can be run directly from VSCode, to generate a
|
||||
traditional Jupyter Notebook to open in your browser
|
||||
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
# %% [markdown]
|
||||
# # Chain network CyberBattle Gym played by a random agent with credential cache lookup
|
||||
|
||||
# %%
|
||||
from cyberbattle._env import cyberbattle_env
|
||||
import gym
|
||||
import logging
|
||||
import sys
|
||||
import cyberbattle.agents.baseline.plotting as p
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
|
||||
import cyberbattle.agents.baseline.agent_dql as dqla
|
||||
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
|
||||
|
||||
import importlib
|
||||
importlib.reload(tqa)
|
||||
importlib.reload(dqla)
|
||||
importlib.reload(learner)
|
||||
importlib.reload(cyberbattle_env)
|
||||
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# # Gym environment: chain-like network
|
||||
# See Jupyer notebook `chainenetwork-random` for an introduction to this network environment.
|
||||
cyberbattlechain_10 = gym.make(
|
||||
'CyberBattleChain-v0',
|
||||
size=10,
|
||||
attacker_goal=cyberbattle_env.AttackerGoal(reward=4000, own_atleast_percent=1.0)
|
||||
)
|
||||
|
||||
cyberbattlechain_10.environment
|
||||
# training_env.environment.plot_environment_graph()
|
||||
cyberbattlechain_10.environment.network.nodes
|
||||
cyberbattlechain_10.action_space
|
||||
cyberbattlechain_10.action_space.sample()
|
||||
cyberbattlechain_10.observation_space.sample()
|
||||
o0 = cyberbattlechain_10.reset()
|
||||
o_test, r, d, i = cyberbattlechain_10.step(cyberbattlechain_10.sample_valid_action())
|
||||
o0 = cyberbattlechain_10.reset()
|
||||
|
||||
o0.keys()
|
||||
|
||||
# %%
|
||||
ep = w.EnvironmentBounds.of_identifiers(
|
||||
maximum_node_count=22,
|
||||
maximum_total_credentials=22,
|
||||
identifiers=cyberbattlechain_10.identifiers
|
||||
)
|
||||
|
||||
print(f"port_count = {ep.port_count}, property_count = {ep.property_count}")
|
||||
|
||||
fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])
|
||||
a = w.StateAugmentation(o0)
|
||||
w.Feature_discovered_ports(ep).get(a, None)
|
||||
fe_example.encode_at(a, 0)
|
||||
|
||||
|
||||
iteration_count = 9000
|
||||
training_episode_count = 50
|
||||
eval_episode_count = 5
|
||||
|
||||
|
||||
# %%
|
||||
random_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=learner.RandomPolicy(),
|
||||
episode_count=10, # training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=1.0,
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Random"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
credlookup_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=rca.CredentialCacheExploiter(),
|
||||
episode_count=10,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
epsilon_exponential_decay=10000,
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Credential lookups (ϵ-greedy)"
|
||||
)
|
||||
|
||||
# %%
|
||||
|
||||
tabularq_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=tqa.QTabularLearner(
|
||||
ep,
|
||||
gamma=0.10, learning_rate=0.90, exploit_percentile=100),
|
||||
render=False,
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
epsilon_exponential_decay=10000,
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Tabular Q-learning"
|
||||
)
|
||||
|
||||
# %%
|
||||
tabularq_exploit_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=tqa.QTabularLearner(
|
||||
ep,
|
||||
trained=tabularq_run['learner'],
|
||||
gamma=0.0,
|
||||
learning_rate=0.0,
|
||||
exploit_percentile=90),
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.0,
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Exploiting Q-matrix"
|
||||
)
|
||||
|
||||
# %%
|
||||
dql_run = learner.epsilon_greedy_search(
|
||||
cyberbattle_gym_env=cyberbattlechain_10,
|
||||
environment_properties=ep,
|
||||
learner=dqla.DeepQLearnerPolicy(
|
||||
ep=ep,
|
||||
gamma=0.015,
|
||||
replay_memory_size=10000,
|
||||
target_update=10,
|
||||
batch_size=512,
|
||||
learning_rate=0.01
|
||||
),
|
||||
episode_count=15,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
epsilon_exponential_decay=5000,
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="DQL"
|
||||
)
|
||||
|
||||
# %%
|
||||
dql_exploit_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=dql_run['learner'],
|
||||
episode_count=50,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.00,
|
||||
epsilon_minimum=0.00,
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Exploiting DQL"
|
||||
)
|
||||
|
||||
# %%
|
||||
all_runs = [
|
||||
random_run,
|
||||
credlookup_run,
|
||||
tabularq_run,
|
||||
tabularq_exploit_run,
|
||||
dql_run,
|
||||
dql_exploit_run
|
||||
]
|
||||
|
||||
p.plot_episodes_length(all_runs)
|
||||
p.plot_averaged_cummulative_rewards(
|
||||
title=f'Agent Benchmark\n'
|
||||
f'max_nodes:{ep.maximum_node_count}\n',
|
||||
all_runs=all_runs)
|
||||
|
||||
# %%
|
||||
contenders = [
|
||||
credlookup_run,
|
||||
tabularq_run,
|
||||
dql_run,
|
||||
dql_exploit_run
|
||||
]
|
||||
p.plot_episodes_length(contenders)
|
||||
p.plot_averaged_cummulative_rewards(
|
||||
title=f'Agent Benchmark top contenders\n'
|
||||
f'max_nodes:{ep.maximum_node_count}\n',
|
||||
all_runs=contenders)
|
||||
|
||||
|
||||
# %%
|
||||
for r in contenders:
|
||||
p.plot_all_episodes(r)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Tabular Q-learning agent (notebook)
|
||||
|
||||
This notebooks can be run directly from VSCode, to generate a
|
||||
traditional Jupyter Notebook to open in your browser
|
||||
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
# %%
|
||||
import sys
|
||||
import logging
|
||||
import gym
|
||||
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
import cyberbattle.agents.baseline.plotting as p
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
import cyberbattle.agents.baseline.agent_dql as dqla
|
||||
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
|
||||
import importlib
|
||||
|
||||
importlib.reload(learner)
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
# %%
|
||||
ctf_env = gym.make('CyberBattleToyCtf-v0')
|
||||
|
||||
ep = w.EnvironmentBounds.of_identifiers(
|
||||
maximum_node_count=22,
|
||||
maximum_total_credentials=22,
|
||||
identifiers=ctf_env.identifiers
|
||||
)
|
||||
|
||||
|
||||
iteration_count = 2000
|
||||
training_episode_count = 10
|
||||
eval_episode_count = 10
|
||||
|
||||
# %%
|
||||
# Run Deep Q-learning
|
||||
# 0.015
|
||||
best_dqn_learning_run_10 = learner.epsilon_greedy_search(
|
||||
cyberbattle_gym_env=ctf_env,
|
||||
environment_properties=ep,
|
||||
learner=dqla.DeepQLearnerPolicy(
|
||||
ep=ep,
|
||||
gamma=0.015,
|
||||
replay_memory_size=10000,
|
||||
target_update=10,
|
||||
batch_size=512,
|
||||
learning_rate=0.01), # torch default is 1e-2
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
# epsilon_multdecay=0.75, # 0.999,
|
||||
epsilon_exponential_decay=5000, # 10000
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="DQL"
|
||||
)
|
||||
|
||||
# %% Plot episode length
|
||||
|
||||
p.plot_episodes_length([best_dqn_learning_run_10])
|
||||
|
||||
|
||||
# %%
|
||||
dql_exploit_run = learner.epsilon_greedy_search(
|
||||
ctf_env,
|
||||
ep,
|
||||
learner=best_dqn_learning_run_10['learner'],
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.0, # 0.35,
|
||||
render=False,
|
||||
title="Exploiting DQL",
|
||||
verbosity=Verbosity.Quiet
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
random_run = learner.epsilon_greedy_search(
|
||||
ctf_env,
|
||||
ep,
|
||||
learner=learner.RandomPolicy(),
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=1.0, # purely random
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Random search"
|
||||
)
|
||||
|
||||
# %%
|
||||
# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit
|
||||
themodel = dqla.CyberBattleStateActionModel(ep)
|
||||
p.plot_averaged_cummulative_rewards(
|
||||
all_runs=[
|
||||
best_dqn_learning_run_10,
|
||||
random_run,
|
||||
dql_exploit_run
|
||||
],
|
||||
title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n'
|
||||
f'State: {[f.name() for f in themodel.state_space.feature_selection]} '
|
||||
f'({len(themodel.state_space.feature_selection)}\n'
|
||||
f"Action: abstract_action ({themodel.action_space.flat_size()})")
|
||||
|
||||
|
||||
# %%
|
||||
# plot cumulative rewards for all episodes
|
||||
p.plot_all_episodes(best_dqn_learning_run_10)
|
|
@ -0,0 +1,211 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
# %%
|
||||
"""Deep Q-learning agent (notebook)
|
||||
|
||||
This notebooks can be run directly from VSCode, to generate a
|
||||
traditional Jupyter Notebook to open in your browser
|
||||
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
|
||||
"""
|
||||
|
||||
# %%
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import gym
|
||||
import torch
|
||||
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
import cyberbattle.agents.baseline.plotting as p
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
import cyberbattle.agents.baseline.agent_dql as dqla
|
||||
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
|
||||
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
|
||||
import importlib
|
||||
import cyberbattle._env.cyberbattle_env as cyberbattle_env
|
||||
import cyberbattle._env.cyberbattle_chain as cyberbattle_chain
|
||||
|
||||
importlib.reload(learner)
|
||||
importlib.reload(cyberbattle_env)
|
||||
importlib.reload(cyberbattle_chain)
|
||||
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
# %%
|
||||
torch.cuda.is_available()
|
||||
|
||||
# %%
|
||||
# To run once
|
||||
# import plotly.io as pio
|
||||
# pio.orca.config.use_xvfb = True
|
||||
# pio.orca.config.save()
|
||||
# %%
|
||||
cyberbattlechain_4 = gym.make('CyberBattleChain-v0', size=4, attacker_goal=cyberbattle_env.AttackerGoal(reward=2180))
|
||||
cyberbattlechain_10 = gym.make('CyberBattleChain-v0', size=10, attacker_goal=cyberbattle_env.AttackerGoal(reward=4000))
|
||||
cyberbattlechain_20 = gym.make('CyberBattleChain-v0', size=20, attacker_goal=cyberbattle_env.AttackerGoal(reward=7000))
|
||||
|
||||
ep = w.EnvironmentBounds.of_identifiers(
|
||||
maximum_total_credentials=22,
|
||||
maximum_node_count=22,
|
||||
identifiers=cyberbattlechain_10.identifiers
|
||||
)
|
||||
|
||||
iteration_count = 9000
|
||||
training_episode_count = 50
|
||||
eval_episode_count = 10
|
||||
|
||||
# %%
|
||||
# Run Deep Q-learning
|
||||
# 0.015
|
||||
best_dqn_learning_run_10 = learner.epsilon_greedy_search(
|
||||
cyberbattle_gym_env=cyberbattlechain_10,
|
||||
environment_properties=ep,
|
||||
learner=dqla.DeepQLearnerPolicy(
|
||||
ep=ep,
|
||||
gamma=0.015,
|
||||
replay_memory_size=10000,
|
||||
target_update=10,
|
||||
batch_size=512,
|
||||
learning_rate=0.01), # torch default is 1e-2
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
# epsilon_multdecay=0.75, # 0.999,
|
||||
epsilon_exponential_decay=5000, # 10000
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="DQL"
|
||||
)
|
||||
|
||||
# %% Plot episode length
|
||||
|
||||
p.plot_episodes_length([best_dqn_learning_run_10])
|
||||
|
||||
|
||||
# %%
|
||||
if not os.path.exists("images"):
|
||||
os.mkdir("images")
|
||||
|
||||
# %%
|
||||
dql_exploit_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=best_dqn_learning_run_10['learner'],
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.0, # 0.35,
|
||||
render=False,
|
||||
render_last_episode_rewards_to='images/chain10',
|
||||
title="Exploiting DQL",
|
||||
verbosity=Verbosity.Quiet
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
random_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=learner.RandomPolicy(),
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=1.0, # purely random
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Random search"
|
||||
)
|
||||
|
||||
# %%
|
||||
# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit
|
||||
themodel = dqla.CyberBattleStateActionModel(ep)
|
||||
p.plot_averaged_cummulative_rewards(
|
||||
all_runs=[
|
||||
best_dqn_learning_run_10,
|
||||
random_run,
|
||||
dql_exploit_run
|
||||
],
|
||||
title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n'
|
||||
f'State: {[f.name() for f in themodel.state_space.feature_selection]} '
|
||||
f'({len(themodel.state_space.feature_selection)}\n'
|
||||
f"Action: abstract_action ({themodel.action_space.flat_size()})")
|
||||
|
||||
|
||||
# %%
|
||||
# plot cumulative rewards for all episodes
|
||||
p.plot_all_episodes(best_dqn_learning_run_10)
|
||||
|
||||
|
||||
##################################################
|
||||
# %%
|
||||
|
||||
# %%
|
||||
best_dqn_4 = learner.epsilon_greedy_search(
|
||||
cyberbattle_gym_env=cyberbattlechain_4,
|
||||
environment_properties=ep,
|
||||
learner=dqla.DeepQLearnerPolicy(
|
||||
ep=ep,
|
||||
gamma=0.15,
|
||||
replay_memory_size=10000,
|
||||
target_update=5,
|
||||
batch_size=256,
|
||||
learning_rate=0.01),
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
epsilon_exponential_decay=5000,
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="DQL"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
learner.transfer_learning_evaluation(
|
||||
environment_properties=ep,
|
||||
trained_learner=best_dqn_learning_run_10,
|
||||
eval_env=cyberbattlechain_20,
|
||||
eval_epsilon=0.0, # alternate with exploration to help generalization to bigger network
|
||||
eval_episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
benchmark_policy=rca.CredentialCacheExploiter(),
|
||||
benchmark_training_args={'epsilon': 0.90,
|
||||
'epsilon_exponential_decay': 10000,
|
||||
'epsilon_minimum': 0.10,
|
||||
'title': 'Credential lookups (ϵ-greedy)'}
|
||||
)
|
||||
# %%
|
||||
learner.transfer_learning_evaluation(
|
||||
environment_properties=ep,
|
||||
trained_learner=best_dqn_4,
|
||||
eval_env=cyberbattlechain_10,
|
||||
eval_epsilon=0.0, # exploit Q-matrix only
|
||||
eval_episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
benchmark_policy=rca.CredentialCacheExploiter(),
|
||||
benchmark_training_args={'epsilon': 0.90,
|
||||
'epsilon_exponential_decay': 10000,
|
||||
'epsilon_minimum': 0.10,
|
||||
'title': 'Credential lookups (ϵ-greedy)'}
|
||||
)
|
||||
|
||||
# %%
|
||||
|
||||
learner.transfer_learning_evaluation(
|
||||
environment_properties=ep,
|
||||
trained_learner=best_dqn_4,
|
||||
eval_env=cyberbattlechain_20,
|
||||
eval_epsilon=0.0, # exploit Q-matrix only
|
||||
eval_episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
benchmark_policy=rca.CredentialCacheExploiter(),
|
||||
benchmark_training_args={'epsilon': 0.90,
|
||||
'epsilon_exponential_decay': 10000,
|
||||
'epsilon_minimum': 0.10,
|
||||
'title': 'Credential lookups (ϵ-greedy)'}
|
||||
)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Random exploration with credential lookup exploitation (notebook)
|
||||
|
||||
This notebooks can be run directly from VSCode, to generate a
|
||||
traditional Jupyter Notebook to open in your browser
|
||||
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
# %%
|
||||
from cyberbattle._env.cyberbattle_env import AttackerGoal
|
||||
from cyberbattle.agents.baseline.agent_randomcredlookup import CredentialCacheExploiter
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
import gym
|
||||
import logging
|
||||
import sys
|
||||
import cyberbattle.agents.baseline.plotting as p
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
|
||||
|
||||
# %%
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
|
||||
# %%
|
||||
cyberbattlechain_10 = gym.make('CyberBattleChain-v0', size=10,
|
||||
attacker_goal=AttackerGoal(own_atleast_percent=1.0))
|
||||
|
||||
|
||||
# %%
|
||||
ep = w.EnvironmentBounds.of_identifiers(
|
||||
maximum_total_credentials=12,
|
||||
maximum_node_count=12,
|
||||
identifiers=cyberbattlechain_10.identifiers
|
||||
)
|
||||
|
||||
iteration_count = 9000
|
||||
training_episode_count = 50
|
||||
eval_episode_count = 5
|
||||
|
||||
# %%
|
||||
credexplot = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
learner=CredentialCacheExploiter(),
|
||||
environment_properties=ep,
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
epsilon_multdecay=0.75, # 0.999,
|
||||
epsilon_minimum=0.01,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Random+CredLookup"
|
||||
)
|
||||
|
||||
# %%
|
||||
randomlearning_results = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
environment_properties=ep,
|
||||
learner=CredentialCacheExploiter(),
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=1.0, # purely random
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Random search"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
p.plot_episodes_length([credexplot])
|
||||
|
||||
p.plot_all_episodes(credexplot)
|
||||
|
||||
all_runs = [credexplot,
|
||||
randomlearning_results
|
||||
]
|
||||
p.plot_averaged_cummulative_rewards(
|
||||
title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n',
|
||||
all_runs=all_runs)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,226 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Tabular Q-learning agent (notebook)
|
||||
|
||||
This notebooks can be run directly from VSCode, to generate a
|
||||
traditional Jupyter Notebook to open in your browser
|
||||
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
# %%
|
||||
import sys
|
||||
import logging
|
||||
from typing import cast
|
||||
import gym
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from cyberbattle.agents.baseline.learner import TrainedLearner
|
||||
import cyberbattle.agents.baseline.plotting as p
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
import cyberbattle.agents.baseline.agent_tabularqlearning as a
|
||||
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
from cyberbattle._env.cyberbattle_env import AttackerGoal
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
# %%
|
||||
# Benchmark parameters:
|
||||
# Parameters from DeepDoubleQ paper
|
||||
# - learning_rate = 0.00025
|
||||
# - linear epsilon decay
|
||||
# - gamma = 0.99
|
||||
# Eliminated gamma_values
|
||||
# 0.0,
|
||||
# 0.0015, # too small
|
||||
# 0.15, # too big
|
||||
# 0.25, # too big
|
||||
# 0.35, # too big
|
||||
#
|
||||
# NOTE: Given the relatively low number of training episodes (50,
|
||||
# a high learning rate of .99 gives better result
|
||||
# than a lower learning rate of 0.25 (i.e. maximal rewards reached faster on average).
|
||||
# Ideally we should decay the learning rate just like gamma and train over a
|
||||
# much larger number of episodes
|
||||
|
||||
cyberbattlechain_10 = gym.make('CyberBattleChain-v0', size=10, attacker_goal=AttackerGoal(own_atleast_percent=1.0))
|
||||
|
||||
|
||||
ep = w.EnvironmentBounds.of_identifiers(
|
||||
maximum_node_count=12,
|
||||
maximum_total_credentials=12,
|
||||
identifiers=cyberbattlechain_10.identifiers
|
||||
)
|
||||
|
||||
iteration_count = 9000
|
||||
training_episode_count = 5
|
||||
eval_episode_count = 5
|
||||
gamma_sweep = [
|
||||
0.015, # about right
|
||||
]
|
||||
|
||||
|
||||
def qlearning_run(gamma, gym_env):
|
||||
"""Execute one run of the q-learning algorithm for the
|
||||
specified gamma value"""
|
||||
return learner.epsilon_greedy_search(
|
||||
gym_env,
|
||||
ep,
|
||||
a.QTabularLearner(ep, gamma=gamma, learning_rate=0.90, exploit_percentile=100),
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
epsilon_multdecay=0.75, # 0.999,
|
||||
epsilon_minimum=0.01,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Q-learning"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Run Q-learning with gamma-sweep
|
||||
qlearning_results = [qlearning_run(gamma, cyberbattlechain_10) for gamma in gamma_sweep]
|
||||
|
||||
qlearning_bestrun_10 = qlearning_results[0]
|
||||
|
||||
# %%
|
||||
|
||||
p.new_plot_loss()
|
||||
for results in qlearning_results:
|
||||
p.plot_all_episodes_loss(cast(a.QTabularLearner, results['learner']).loss_qsource.all_episodes, 'Q_source', results['title'])
|
||||
p.plot_all_episodes_loss(cast(a.QTabularLearner, results['learner']).loss_qattack.all_episodes, 'Q_attack', results['title'])
|
||||
plt.legend(loc="upper right")
|
||||
plt.show()
|
||||
|
||||
# %% Plot episode length
|
||||
|
||||
p.plot_episodes_length(qlearning_results)
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
|
||||
nolearning_results = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=a.QTabularLearner(ep, trained=qlearning_bestrun_10['learner'],
|
||||
gamma=0.0, learning_rate=0.0, exploit_percentile=100),
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.30, # 0.35,
|
||||
render=False,
|
||||
title="Exploiting Q-matrix",
|
||||
verbosity=Verbosity.Quiet
|
||||
)
|
||||
|
||||
# %%
|
||||
randomlearning_results = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_10,
|
||||
ep,
|
||||
learner=a.QTabularLearner(ep, trained=qlearning_bestrun_10['learner'],
|
||||
gamma=0.0, learning_rate=0.0, exploit_percentile=100),
|
||||
episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=1.0, # purely random
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Random search"
|
||||
)
|
||||
|
||||
# %%
|
||||
# Plot averaged cumulative rewards for Q-learning vs Random vs Q-Exploit
|
||||
all_runs = [*qlearning_results,
|
||||
randomlearning_results,
|
||||
nolearning_results
|
||||
]
|
||||
|
||||
Q_source_10 = cast(a.QTabularLearner, qlearning_bestrun_10['learner']).qsource
|
||||
Q_attack_10 = cast(a.QTabularLearner, qlearning_bestrun_10['learner']).qattack
|
||||
|
||||
p.plot_averaged_cummulative_rewards(
|
||||
all_runs=all_runs,
|
||||
title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n'
|
||||
f'dimension={Q_source_10.state_space.flat_size()}x{Q_source_10.action_space.flat_size()}, '
|
||||
f'{Q_attack_10.state_space.flat_size()}x{Q_attack_10.action_space.flat_size()}\n'
|
||||
f'Q1={[f.name() for f in Q_source_10.state_space.feature_selection]} '
|
||||
f'-> {[f.name() for f in Q_source_10.action_space.feature_selection]})\n'
|
||||
f"Q2={[f.name() for f in Q_attack_10.state_space.feature_selection]} -> 'action'")
|
||||
|
||||
|
||||
# %%
|
||||
# plot cumulative rewards for all episodes
|
||||
p.plot_all_episodes(qlearning_results[0])
|
||||
|
||||
|
||||
# %%
|
||||
# Plot the Q-matrices
|
||||
|
||||
# %%
|
||||
# Print non-zero coordinate in the Q matrix Q_source
|
||||
i = np.where(Q_source_10.qm)
|
||||
q = Q_source_10.qm[i]
|
||||
list(zip(np.array([Q_source_10.state_space.pretty_print(i) for i in i[0]]),
|
||||
np.array([Q_source_10.action_space.pretty_print(i) for i in i[1]]), q))
|
||||
|
||||
# %%
|
||||
# Print non-zero coordinate in the Q matrix Q_attack
|
||||
i2 = np.where(Q_attack_10.qm)
|
||||
q2 = Q_attack_10.qm[i2]
|
||||
list(zip([Q_attack_10.state_space.pretty_print(i) for i in i2[0]],
|
||||
[Q_attack_10.action_space.pretty_print(i) for i in i2[1]], q2))
|
||||
|
||||
|
||||
##################################################
|
||||
|
||||
# %% [markdown]
|
||||
# ## Transfer learning from size 4 to size 10
|
||||
# Exploiting Q-matrix learned from a different network.
|
||||
|
||||
# %%
|
||||
# Train Q-matrix on CyberBattle network of size 4
|
||||
cyberbattlechain_4 = gym.make('CyberBattleChain-v0', size=4,
|
||||
attacker_goal=AttackerGoal(own_atleast_percent=1.0)
|
||||
)
|
||||
|
||||
qlearning_bestrun_4 = qlearning_run(0.015, gym_env=cyberbattlechain_4)
|
||||
|
||||
|
||||
def stop_learning(trained_learner):
|
||||
return TrainedLearner(
|
||||
learner=a.QTabularLearner(
|
||||
ep,
|
||||
gamma=0.0,
|
||||
learning_rate=0.0,
|
||||
exploit_percentile=0,
|
||||
trained=trained_learner['learner']
|
||||
),
|
||||
title=trained_learner['title'],
|
||||
trained_on=trained_learner['trained_on'],
|
||||
all_episodes_rewards=trained_learner['all_episodes_rewards'],
|
||||
all_episodes_availability=trained_learner['all_episodes_availability']
|
||||
)
|
||||
|
||||
|
||||
learner.transfer_learning_evaluation(
|
||||
environment_properties=ep,
|
||||
trained_learner=stop_learning(qlearning_bestrun_4),
|
||||
eval_env=cyberbattlechain_10,
|
||||
eval_epsilon=0.5, # alternate with exploration to help generalization to bigger network
|
||||
eval_episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count
|
||||
)
|
||||
|
||||
learner.transfer_learning_evaluation(
|
||||
environment_properties=ep,
|
||||
trained_learner=stop_learning(qlearning_bestrun_10),
|
||||
eval_env=cyberbattlechain_4,
|
||||
eval_epsilon=0.5,
|
||||
eval_episode_count=eval_episode_count,
|
||||
iteration_count=iteration_count
|
||||
)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Attacker agent benchmark comparison in presence of a basic defender
|
||||
|
||||
This notebooks can be run directly from VSCode, to generate a
|
||||
traditional Jupyter Notebook to open in your browser
|
||||
you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.
|
||||
"""
|
||||
|
||||
# %%
|
||||
import sys
|
||||
import logging
|
||||
import gym
|
||||
import importlib
|
||||
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
import cyberbattle.agents.baseline.plotting as p
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
import cyberbattle.agents.baseline.agent_dql as dqla
|
||||
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
|
||||
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
|
||||
from cyberbattle._env.defender import ScanAndReimageCompromisedMachines
|
||||
from cyberbattle._env.cyberbattle_env import AttackerGoal, DefenderConstraint
|
||||
|
||||
importlib.reload(learner)
|
||||
importlib.reload(p)
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
|
||||
cyberbattlechain_defender = gym.make('CyberBattleChain-v0',
|
||||
size=10,
|
||||
attacker_goal=AttackerGoal(
|
||||
# reward=2180,
|
||||
own_atleast=0,
|
||||
own_atleast_percent=1.0
|
||||
),
|
||||
defender_constraint=DefenderConstraint(
|
||||
maintain_sla=0.80
|
||||
),
|
||||
defender_agent=ScanAndReimageCompromisedMachines(
|
||||
probability=0.6,
|
||||
scan_capacity=2,
|
||||
scan_frequency=5))
|
||||
|
||||
|
||||
ep = w.EnvironmentBounds.of_identifiers(
|
||||
maximum_total_credentials=22,
|
||||
maximum_node_count=22,
|
||||
identifiers=cyberbattlechain_defender.identifiers
|
||||
)
|
||||
|
||||
iteration_count = 600
|
||||
training_episode_count = 10
|
||||
|
||||
|
||||
# %%
|
||||
dqn_with_defender = learner.epsilon_greedy_search(
|
||||
cyberbattle_gym_env=cyberbattlechain_defender,
|
||||
environment_properties=ep,
|
||||
learner=dqla.DeepQLearnerPolicy(
|
||||
ep=ep,
|
||||
gamma=0.15,
|
||||
replay_memory_size=10000,
|
||||
target_update=5,
|
||||
batch_size=256,
|
||||
learning_rate=0.01),
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
epsilon_exponential_decay=5000,
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="DQL"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
dql_exploit_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_defender,
|
||||
ep,
|
||||
learner=dqn_with_defender['learner'],
|
||||
episode_count=training_episode_count,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.0, # 0.35,
|
||||
render=False,
|
||||
# render_last_episode_rewards_to='images/chain10',
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Exploiting DQL"
|
||||
)
|
||||
|
||||
# %%
|
||||
credlookup_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain_defender,
|
||||
ep,
|
||||
learner=rca.CredentialCacheExploiter(),
|
||||
episode_count=10,
|
||||
iteration_count=iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
epsilon_exponential_decay=10000,
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Credential lookups (ϵ-greedy)"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Plots
|
||||
all_runs = [
|
||||
credlookup_run,
|
||||
dqn_with_defender,
|
||||
dql_exploit_run
|
||||
]
|
||||
p.plot_averaged_cummulative_rewards(
|
||||
all_runs=all_runs,
|
||||
title=f'Attacker agents vs Basic Defender -- rewards\n env={cyberbattlechain_defender.name}, episodes={training_episode_count}'
|
||||
)
|
||||
|
||||
# p.plot_episodes_length(all_runs)
|
||||
p.plot_averaged_availability(title=f"Attacker agents vs Basic Defender -- availability\n env={cyberbattlechain_defender.name}, episodes={training_episode_count}", all_runs=all_runs)
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Plotting helpers for agent banchmarking"""
|
||||
|
||||
import matplotlib.pyplot as plt # type:ignore
|
||||
import numpy as np
|
||||
|
||||
|
||||
def new_plot(title):
|
||||
"""Prepare a new plot of cumulative rewards"""
|
||||
plt.figure(figsize=(10, 8))
|
||||
plt.ylabel('cumulative reward', fontsize=20)
|
||||
plt.xlabel('step', fontsize=20)
|
||||
plt.xticks(size=20)
|
||||
plt.yticks(size=20)
|
||||
plt.title(title, fontsize=12)
|
||||
|
||||
|
||||
def pad(array, length):
|
||||
"""Pad an array with 0s to make it of desired length"""
|
||||
padding = np.zeros((length,))
|
||||
padding[:len(array)] = array
|
||||
return padding
|
||||
|
||||
|
||||
def plot_episodes_rewards_averaged(results):
|
||||
"""Plot cumulative rewards for a given set of specified episodes"""
|
||||
max_iteration_count = np.max([len(r) for r in results['all_episodes_rewards']])
|
||||
|
||||
all_episodes_rewards_padded = [pad(rewards, max_iteration_count) for rewards in results['all_episodes_rewards']]
|
||||
cumrewards = np.cumsum(all_episodes_rewards_padded, axis=1)
|
||||
avg = np.average(cumrewards, axis=0)
|
||||
std = np.std(cumrewards, axis=0)
|
||||
x = [i for i in range(len(std))]
|
||||
plt.plot(x, avg, label=results['title'])
|
||||
plt.fill_between(x, avg - std, avg + std, alpha=0.5)
|
||||
|
||||
|
||||
def fill_with_latest_value(array, length):
|
||||
pad = length - len(array)
|
||||
if pad > 0:
|
||||
return np.pad(array, (0, pad), mode='edge')
|
||||
else:
|
||||
return array
|
||||
|
||||
|
||||
def plot_episodes_availability_averaged(results):
|
||||
"""Plot availability for a given set of specified episodes"""
|
||||
data = results['all_episodes_availability']
|
||||
longest_episode_length = np.max([len(r) for r in data])
|
||||
|
||||
all_episodes_padded = [fill_with_latest_value(av, longest_episode_length) for av in data]
|
||||
avg = np.average(all_episodes_padded, axis=0)
|
||||
std = np.std(all_episodes_padded, axis=0)
|
||||
x = [i for i in range(len(std))]
|
||||
plt.plot(x, avg, label=results['title'])
|
||||
plt.fill_between(x, avg - std, avg + std, alpha=0.5)
|
||||
|
||||
|
||||
def plot_episodes_length(learning_results):
|
||||
"""Plot length of every episode"""
|
||||
plt.figure(figsize=(10, 8))
|
||||
plt.ylabel('#iterations', fontsize=20)
|
||||
plt.xlabel('episode', fontsize=20)
|
||||
plt.xticks(size=20)
|
||||
plt.yticks(size=20)
|
||||
plt.title("Length of each episode", fontsize=12)
|
||||
|
||||
for results in learning_results:
|
||||
iterations = [len(e) for e in results['all_episodes_rewards']]
|
||||
episode = [i for i in range(len(results['all_episodes_rewards']))]
|
||||
plt.plot(episode, iterations, label=f"{results['title']}")
|
||||
|
||||
plt.legend(loc="upper right")
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_each_episode(results):
|
||||
"""Plot cumulative rewards for each episode"""
|
||||
for i, episode in enumerate(results['all_episodes_rewards']):
|
||||
cumrewards = np.cumsum(episode)
|
||||
x = [i for i in range(len(cumrewards))]
|
||||
plt.plot(x, cumrewards, label=f'Episode {i}')
|
||||
|
||||
|
||||
def plot_all_episodes(r):
|
||||
"""Plot cumulative rewards for every episode"""
|
||||
new_plot(r['title'])
|
||||
plot_each_episode(r)
|
||||
plt.legend(loc="lower right")
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_averaged_cummulative_rewards(title, all_runs):
|
||||
"""Plot averaged cumulative rewards"""
|
||||
new_plot(title)
|
||||
for r in all_runs:
|
||||
plot_episodes_rewards_averaged(r)
|
||||
plt.legend(loc="lower right")
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_averaged_availability(title, all_runs):
|
||||
"""Plot averaged network availability"""
|
||||
plt.figure(figsize=(10, 8))
|
||||
plt.ylabel('network availability', fontsize=20)
|
||||
plt.xlabel('step', fontsize=20)
|
||||
plt.xticks(size=20)
|
||||
plt.yticks(size=20)
|
||||
plt.title(title, fontsize=12)
|
||||
for r in all_runs:
|
||||
plot_episodes_availability_averaged(r)
|
||||
plt.legend(loc="lower right")
|
||||
plt.show()
|
||||
|
||||
|
||||
def new_plot_loss():
|
||||
"""Plot MSE loss averaged over all episodes"""
|
||||
plt.figure(figsize=(10, 8))
|
||||
plt.ylabel('loss', fontsize=20)
|
||||
plt.xlabel('episodes', fontsize=20)
|
||||
plt.xticks(size=12)
|
||||
plt.yticks(size=20)
|
||||
plt.title("Loss", fontsize=12)
|
||||
|
||||
|
||||
def plot_all_episodes_loss(all_episodes_losses, name, label):
|
||||
"""Plot loss for one learning episode"""
|
||||
x = [i for i in range(len(all_episodes_losses))]
|
||||
plt.plot(x, all_episodes_losses, label=f'{name} {label}')
|
||||
|
||||
|
||||
def running_mean(x, size):
|
||||
"""return moving average of x for a window of lenght 'size'"""
|
||||
cumsum = np.cumsum(np.insert(x, 0, 0))
|
||||
return (cumsum[size:] - cumsum[:-size]) / float(size)
|
||||
|
||||
|
||||
class PlotTraining:
|
||||
"""Plot training-related stats"""
|
||||
|
||||
def __init__(self, title, render_each_episode):
|
||||
self.episode_durations = []
|
||||
self.title = title
|
||||
self.render_each_episode = render_each_episode
|
||||
|
||||
def plot_durations(self, average_window=5):
|
||||
# plt.figure(2)
|
||||
plt.figure()
|
||||
# plt.clf()
|
||||
durations_t = np.array(self.episode_durations, dtype=np.float32)
|
||||
plt.title('Training...')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Duration')
|
||||
plt.title(self.title, fontsize=12)
|
||||
|
||||
episodes = [i + 1 for i in range(len(self.episode_durations))]
|
||||
plt.plot(episodes, durations_t)
|
||||
# plot episode running averages
|
||||
if len(durations_t) >= average_window:
|
||||
means = running_mean(durations_t, average_window)
|
||||
means = np.concatenate((np.zeros(average_window - 1), means))
|
||||
plt.plot(episodes, means)
|
||||
|
||||
# display.display(plt.gcf())
|
||||
plt.show()
|
||||
|
||||
def episode_done(self, length):
|
||||
self.episode_durations.append(length)
|
||||
if self.render_each_episode:
|
||||
self.plot_durations()
|
||||
|
||||
def plot_end(self):
|
||||
self.plot_durations()
|
||||
plt.ioff() # type: ignore
|
||||
# plt.show()
|
||||
|
||||
|
||||
def length_of_all_episodes(run):
|
||||
"""Get the length of every episode"""
|
||||
return [len(e) for e in run['all_episodes_rewards']]
|
||||
|
||||
|
||||
def reduce(x, desired_width):
|
||||
return [np.average(c) for c in np.array_split(x, desired_width)]
|
||||
|
||||
|
||||
def episodes_rewards_averaged(run):
|
||||
"""Plot cumulative rewards for a given set of specified episodes"""
|
||||
max_iteration_count = np.max([len(r) for r in run['all_episodes_rewards']])
|
||||
all_episodes_rewards_padded = [pad(rewards, max_iteration_count) for rewards in run['all_episodes_rewards']]
|
||||
cumrewards = np.cumsum(all_episodes_rewards_padded, axis=1)
|
||||
avg = np.average(cumrewards, axis=0)
|
||||
return list(avg)
|
||||
|
||||
|
||||
def episodes_lengths_for_all_runs(all_runs):
|
||||
return [length_of_all_episodes(run) for run in all_runs]
|
||||
|
||||
|
||||
def averaged_cummulative_rewards(all_runs, width):
|
||||
return [reduce(episodes_rewards_averaged(run), width) for run in all_runs]
|
|
@ -0,0 +1,121 @@
|
|||
#!/usr/bin/python3.8
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLI to run the baseline Deep Q-learning and Random agents
|
||||
on a sample CyberBattle gym environment and plot the respective
|
||||
cummulative rewards in the terminal.
|
||||
|
||||
Example usage:
|
||||
|
||||
python3.8 -m run --training_episode_count 50 --iteration_count 9000 --rewardplot_with 80 --chain_size=20 --ownership_goal 1.0
|
||||
|
||||
"""
|
||||
import torch
|
||||
import gym
|
||||
import logging
|
||||
import sys
|
||||
import asciichartpy
|
||||
import argparse
|
||||
import cyberbattle._env.cyberbattle_env as cyberbattle_env
|
||||
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
|
||||
import cyberbattle.agents.baseline.agent_dql as dqla
|
||||
import cyberbattle.agents.baseline.agent_wrapper as w
|
||||
import cyberbattle.agents.baseline.plotting as p
|
||||
import cyberbattle.agents.baseline.learner as learner
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run simulation with DQL baseline agent.')
|
||||
|
||||
parser.add_argument('--training_episode_count', default=50, type=int,
|
||||
help='number of training epochs')
|
||||
|
||||
parser.add_argument('--eval_episode_count', default=10, type=int,
|
||||
help='number of evaluation epochs')
|
||||
|
||||
parser.add_argument('--iteration_count', default=9000, type=int,
|
||||
help='number of simulation iterations for each epoch')
|
||||
|
||||
parser.add_argument('--reward_goal', default=2180, type=int,
|
||||
help='minimum target rewards to reach for the attacker to reach its goal')
|
||||
|
||||
parser.add_argument('--ownership_goal', default=1.0, type=float,
|
||||
help='percentage of network nodes to own for the attacker to reach its goal')
|
||||
|
||||
parser.add_argument('--rewardplot_with', default=80, type=int,
|
||||
help='width of the reward plot (values are averaged across iterations to fit in the desired width)')
|
||||
|
||||
parser.add_argument('--chain_size', default=4, type=int,
|
||||
help='size of the chain of the CyberBattleChain sample environment')
|
||||
|
||||
parser.add_argument('--random_agent', dest='run_random_agent', action='store_true', help='run the random agent as a baseline for comparison')
|
||||
parser.add_argument('--no-random_agent', dest='run_random_agent', action='store_false', help='do not run the random agent as a baseline for comparison')
|
||||
parser.set_defaults(run_random_agent=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
|
||||
|
||||
print(f"torch cuda available={torch.cuda.is_available()}")
|
||||
|
||||
cyberbattlechain = gym.make('CyberBattleChain-v0',
|
||||
size=args.chain_size,
|
||||
attacker_goal=cyberbattle_env.AttackerGoal(
|
||||
own_atleast_percent=args.ownership_goal,
|
||||
reward=args.reward_goal))
|
||||
|
||||
ep = w.EnvironmentBounds.of_identifiers(
|
||||
maximum_total_credentials=22,
|
||||
maximum_node_count=22,
|
||||
identifiers=cyberbattlechain.identifiers
|
||||
)
|
||||
|
||||
all_runs = []
|
||||
|
||||
# Run Deep Q-learning
|
||||
dqn_learning_run = learner.epsilon_greedy_search(
|
||||
cyberbattle_gym_env=cyberbattlechain,
|
||||
environment_properties=ep,
|
||||
learner=dqla.DeepQLearnerPolicy(
|
||||
ep=ep,
|
||||
gamma=0.015,
|
||||
replay_memory_size=10000,
|
||||
target_update=10,
|
||||
batch_size=512,
|
||||
learning_rate=0.01), # torch default is 1e-2
|
||||
episode_count=args.training_episode_count,
|
||||
iteration_count=args.iteration_count,
|
||||
epsilon=0.90,
|
||||
render=False,
|
||||
# epsilon_multdecay=0.75, # 0.999,
|
||||
epsilon_exponential_decay=5000, # 10000
|
||||
epsilon_minimum=0.10,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="DQL"
|
||||
)
|
||||
|
||||
all_runs.append(dqn_learning_run)
|
||||
|
||||
if args.run_random_agent:
|
||||
random_run = learner.epsilon_greedy_search(
|
||||
cyberbattlechain,
|
||||
ep,
|
||||
learner=learner.RandomPolicy(),
|
||||
episode_count=args.eval_episode_count,
|
||||
iteration_count=args.iteration_count,
|
||||
epsilon=1.0, # purely random
|
||||
render=False,
|
||||
verbosity=Verbosity.Quiet,
|
||||
title="Random search"
|
||||
)
|
||||
all_runs.append(random_run)
|
||||
|
||||
colors = [asciichartpy.red, asciichartpy.green, asciichartpy.yellow, asciichartpy.blue]
|
||||
|
||||
print("Episode duration -- DQN=Red, Random=Green")
|
||||
print(asciichartpy.plot(p.episodes_lengths_for_all_runs(all_runs), {'height': 30, 'colors': colors}))
|
||||
|
||||
print("Cumulative rewards -- DQN=Red, Random=Green")
|
||||
c = p.averaged_cummulative_rewards(all_runs, args.rewardplot_with)
|
||||
print(asciichartpy.plot(c, {'height': 10, 'colors': colors}))
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Helper to run the random agent from a jupyter notebook"""
|
||||
|
||||
import cyberbattle._env.cyberbattle_env as cyberbattle_env
|
||||
import logging
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_random_agent(episode_count: int, iteration_count: int, gym_env: cyberbattle_env.CyberBattleEnv):
|
||||
"""Run a simple random agent on the specified gym environment and
|
||||
plot exploration graph and reward function
|
||||
"""
|
||||
|
||||
for i_episode in range(episode_count):
|
||||
observation = gym_env.reset()
|
||||
|
||||
total_reward = 0.0
|
||||
|
||||
for t in range(iteration_count):
|
||||
action = gym_env.sample_valid_action()
|
||||
|
||||
LOGGER.debug(f"action={action}")
|
||||
observation, reward, done, info = gym_env.step(action)
|
||||
|
||||
total_reward += reward
|
||||
|
||||
if reward > 0:
|
||||
print(f'+ rewarded action: {action} total_reward={total_reward} reward={reward} @t={t}')
|
||||
gym_env.render()
|
||||
|
||||
if done:
|
||||
print(f"Episode finished after {t+1} timesteps")
|
||||
break
|
||||
|
||||
gym_env.render()
|
||||
|
||||
gym_env.close()
|
||||
print("simulation ended")
|
|
@ -0,0 +1,286 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Defines a set of networks following a speficic pattern
|
||||
learnable from the properties associated with the nodes.
|
||||
|
||||
The network pattern is:
|
||||
Start ---> (Linux ---> Windows ---> ... Linux ---> Windows)* ---> Linux[Flag]
|
||||
|
||||
The network is parameterized by the length of the central Linux-Windows chain.
|
||||
The start node leaks the credentials to connect to all other nodes:
|
||||
|
||||
For each `XXX ---> Windows` section, the XXX node has:
|
||||
- a local vulnerability exposing the RDP password to the Windows machine
|
||||
- a bunch of other trap vulnerabilities (high cost with no outcome)
|
||||
For each `XXX ---> Linux` section,
|
||||
- the Windows node has a local vulnerability exposing the SSH password to the Linux machine
|
||||
- a bunch of other trap vulnerabilities (high cost with no outcome)
|
||||
|
||||
The chain is terminated by one node with a flag (reward).
|
||||
|
||||
A Node-Property matrix would be three-valued (0,1,?) and look like this:
|
||||
|
||||
===== Initial state
|
||||
Properties
|
||||
Nodes L W SQL
|
||||
1 1 0 0
|
||||
2 ? ? ?
|
||||
3 ? ? ?
|
||||
...
|
||||
10
|
||||
======= After discovering node 2
|
||||
Properties
|
||||
Nodes L W SQL
|
||||
1 1 0 0
|
||||
2 0 1 1
|
||||
3 ? ? ?
|
||||
...
|
||||
10
|
||||
===========================
|
||||
|
||||
"""
|
||||
from cyberbattle.simulation.model import Identifiers, NodeID, NodeInfo
|
||||
from ...simulation import model as m
|
||||
from typing import Dict
|
||||
|
||||
DEFAULT_ALLOW_RULES = [
|
||||
m.FirewallRule("RDP", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("SSH", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("HTTPS", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("HTTP", m.RulePermission.ALLOW)]
|
||||
|
||||
# Environment constants used for all instances of the chain network
|
||||
ENV_IDENTIFIERS = Identifiers(
|
||||
properties=[
|
||||
'Windows',
|
||||
'Linux',
|
||||
'ApacheWebSite',
|
||||
'IIS_2019',
|
||||
'IIS_2020_patched',
|
||||
'MySql',
|
||||
'Ubuntu',
|
||||
'nginx/1.10.3',
|
||||
'SMB_vuln',
|
||||
'SMB_vuln_patched',
|
||||
'SQLServer',
|
||||
'Win10',
|
||||
'Win10Patched',
|
||||
'FLAG:Linux'
|
||||
],
|
||||
ports=[
|
||||
'HTTPS',
|
||||
'GIT',
|
||||
'SSH',
|
||||
'RDP',
|
||||
'PING',
|
||||
'MySQL',
|
||||
'SSH-key',
|
||||
'su'
|
||||
],
|
||||
local_vulnerabilities=[
|
||||
'ScanBashHistory',
|
||||
'ScanExplorerRecentFiles',
|
||||
'SudoAttempt',
|
||||
'CrackKeepPassX',
|
||||
'CrackKeepPass'
|
||||
],
|
||||
remote_vulnerabilities=[
|
||||
'ProbeLinux',
|
||||
'ProbeWindows'
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def prefix(x: int, name: str):
|
||||
"""Prefix node name with an instance"""
|
||||
return f"{x}_{name}"
|
||||
|
||||
|
||||
def rdp_password(index):
|
||||
"""Generate RDP password for the specified chain link"""
|
||||
return f"WindowsPassword!{index}"
|
||||
|
||||
|
||||
def ssh_password(index):
|
||||
"""Generate SSH password for the specified chain link"""
|
||||
return f"LinuxPassword!{index}"
|
||||
|
||||
|
||||
def create_network_chain_link(n: int) -> Dict[NodeID, NodeInfo]:
|
||||
"""Instantiate one link of the network chain with associated index n"""
|
||||
|
||||
def current(name):
|
||||
return prefix(n, name)
|
||||
|
||||
def next(name):
|
||||
return prefix(n + 1, name)
|
||||
|
||||
return {
|
||||
current("LinuxNode"): m.NodeInfo(
|
||||
services=[m.ListeningService("HTTPS"),
|
||||
m.ListeningService("SSH", allowedCredentials=[ssh_password(n)])],
|
||||
firewall=m.FirewallConfiguration(incoming=DEFAULT_ALLOW_RULES,
|
||||
outgoing=DEFAULT_ALLOW_RULES),
|
||||
value=100,
|
||||
properties=["MySql", "Ubuntu", "nginx/1.10.3"],
|
||||
owned_string="Intermediate chain node owned, no intinsic value",
|
||||
vulnerabilities=dict(
|
||||
ProbeLinux=m.VulnerabilityInfo(
|
||||
description="Probe to check if the node runs Linux",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.ProbeSucceeded(["Ubuntu"]),
|
||||
reward_string="Remote machine is running Linux",
|
||||
cost=5.0
|
||||
),
|
||||
ProbeWindows=m.VulnerabilityInfo(
|
||||
description="Probe to check if the node runs Windows",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.ProbeFailed(),
|
||||
reward_string="Remote machine is not running Windows",
|
||||
cost=5.0
|
||||
),
|
||||
ScanBashHistory=m.VulnerabilityInfo(
|
||||
description="Scan bash history for possible references to other machines",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedNodesId([next("WindowsNode")]),
|
||||
reward_string="Found a reference to a remote Windows node in bash history",
|
||||
cost=1.0
|
||||
),
|
||||
ScanExplorerRecentFiles=m.VulnerabilityInfo(
|
||||
description="Scan Windows Explorer recent files for possible references to other machines",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.ExploitFailed(),
|
||||
reward_string="Trap: feature not supported on Linux",
|
||||
cost=10.0
|
||||
),
|
||||
SudoAttempt=m.VulnerabilityInfo(
|
||||
description="Attempt to sudo into admin user",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.ExploitFailed(),
|
||||
reward_string="Trap: suspicious attempt to run sudo",
|
||||
cost=100.0
|
||||
),
|
||||
CrackKeepPassX=m.VulnerabilityInfo(
|
||||
description="Attempt to crack KeepPassX and look for credentials",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node=next("WindowsNode"), port="RDP",
|
||||
credential=rdp_password(n + 1))]),
|
||||
reward_string=f"Discovered password to Windows machine {n+1}",
|
||||
cost=1.0
|
||||
))),
|
||||
|
||||
next("WindowsNode"): m.NodeInfo(
|
||||
services=[m.ListeningService("HTTPS"),
|
||||
m.ListeningService("RDP", allowedCredentials=[rdp_password(n + 1)])],
|
||||
value=100,
|
||||
properties=["Windows", "Win10", "Win10Patched"],
|
||||
vulnerabilities=dict(
|
||||
ProbeLinux=m.VulnerabilityInfo(
|
||||
description="Probe to check if the node runs Linux",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.ProbeFailed(),
|
||||
reward_string="Remote machine is not running Linux",
|
||||
cost=1.0
|
||||
),
|
||||
ProbeWindows=m.VulnerabilityInfo(
|
||||
description="Probe to check if the node runs Windows",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.ProbeSucceeded(["Windows"]),
|
||||
reward_string="Remote machine is running Windows",
|
||||
cost=1.0
|
||||
),
|
||||
ScanBashHistory=m.VulnerabilityInfo(
|
||||
description="Scan bash history for possible references to other machines",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.ExploitFailed(),
|
||||
reward_string="Trap: feature not supported on Windows!",
|
||||
cost=100.0
|
||||
),
|
||||
ScanExplorerRecentFiles=m.VulnerabilityInfo(
|
||||
description="Scan Windows Explorer recent files for possible references to other machines",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedNodesId([prefix(n + 2, "LinuxNode")]),
|
||||
reward_string="Found a reference to a remote Linux node in bash history",
|
||||
cost=1.0
|
||||
),
|
||||
SudoAttempt=m.VulnerabilityInfo(
|
||||
description="Attempt to sudo into admin user",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.ExploitFailed(),
|
||||
reward_string="Trap: feature not supported on Windows!",
|
||||
cost=100.0
|
||||
),
|
||||
CrackKeepPassX=m.VulnerabilityInfo(
|
||||
description="Attempt to crack KeepPassX and look for credentials",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.ExploitFailed(),
|
||||
reward_string="Trap: feature not supported on Windows!",
|
||||
cost=100.0
|
||||
),
|
||||
CrackKeepPass=m.VulnerabilityInfo(
|
||||
description="Attempt to crack KeepPass and look for credentials",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node=prefix(n + 2, "LinuxNode"), port="SSH",
|
||||
credential=ssh_password(n + 2))]),
|
||||
reward_string=f"Discovered password to Linux machine {n+2}",
|
||||
cost=1.0
|
||||
)
|
||||
))
|
||||
}
|
||||
|
||||
|
||||
def create_chain_network(size: int) -> Dict[NodeID, NodeInfo]:
|
||||
"""Create a chain network with the chain section of specified size.
|
||||
Size must be an even number
|
||||
The number of nodes in the network is `size + 2` to account for the start node (0)
|
||||
and final node (size + 1).
|
||||
"""
|
||||
|
||||
if size % 2 == 1:
|
||||
raise ValueError(f"Chain size must be even: {size}")
|
||||
|
||||
final_node_index = size + 1
|
||||
|
||||
nodes = {
|
||||
'start': m.NodeInfo(
|
||||
services=[],
|
||||
value=0,
|
||||
vulnerabilities=dict(
|
||||
ScanExplorerRecentFiles=m.VulnerabilityInfo(
|
||||
description="Scan Windows Explorer recent files for possible references to other machines",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node=prefix(1, "LinuxNode"), port="SSH",
|
||||
credential=ssh_password(1))]),
|
||||
reward_string="Found a reference to a remote Linux node in bash history",
|
||||
cost=1.0
|
||||
)),
|
||||
agent_installed=True,
|
||||
reimagable=False),
|
||||
|
||||
prefix(final_node_index, "LinuxNode"): m.NodeInfo(
|
||||
services=[m.ListeningService("HTTPS"),
|
||||
m.ListeningService("SSH", allowedCredentials=[ssh_password(final_node_index)])],
|
||||
value=1000,
|
||||
owned_string="FLAG: flag discovered!",
|
||||
properties=["MySql", "Ubuntu", "nginx/1.10.3", "FLAG:Linux"],
|
||||
vulnerabilities=dict()
|
||||
)
|
||||
}
|
||||
|
||||
# Add chain links
|
||||
for i in range(1, size, 2):
|
||||
nodes.update(create_network_chain_link(i))
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def new_environment(size) -> m.Environment:
|
||||
return m.Environment(
|
||||
network=m.create_network(create_chain_network(size)),
|
||||
vulnerability_library=dict([]),
|
||||
identifiers=ENV_IDENTIFIERS
|
||||
)
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""A simple test sandbox to play with creation of simulation environments"""
|
||||
import networkx as nx
|
||||
import yaml
|
||||
|
||||
from cyberbattle.simulation import model, model_test, actions_test
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Simple environment sandbox"""
|
||||
|
||||
# Create a toy graph
|
||||
graph = nx.DiGraph()
|
||||
graph.add_edges_from([('a', 'b'), ('b', 'c')])
|
||||
print(graph)
|
||||
|
||||
# create a random graph
|
||||
graph = nx.cubical_graph()
|
||||
|
||||
graph = model.assign_random_labels(graph)
|
||||
|
||||
vulnerabilities = actions_test.SAMPLE_VULNERABILITIES
|
||||
|
||||
model.setup_yaml_serializer()
|
||||
|
||||
# Define an environment from this graph
|
||||
env = model.Environment(
|
||||
network=graph,
|
||||
vulnerability_library=vulnerabilities,
|
||||
identifiers=actions_test.ENV_IDENTIFIERS
|
||||
)
|
||||
|
||||
model_test.check_reserializing(env)
|
||||
|
||||
model_test.check_reserializing(vulnerabilities)
|
||||
|
||||
# Save the environment to file as Yaml
|
||||
with open('./simpleenv.yaml', 'w') as file:
|
||||
yaml.dump(env, file)
|
||||
print(yaml.dump(env))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,240 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Model a toy Capture the flag exercise
|
||||
|
||||
See Jupyter notebook toyctf-simulation.ipynb for an example of
|
||||
game played on this simulation.
|
||||
"""
|
||||
from cyberbattle.simulation import model as m
|
||||
from cyberbattle.simulation.model import NodeID, NodeInfo, VulnerabilityID, VulnerabilityInfo
|
||||
from typing import Dict, Iterator, cast, Tuple
|
||||
|
||||
default_allow_rules = [
|
||||
m.FirewallRule("RDP", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("SSH", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("HTTPS", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("HTTP", m.RulePermission.ALLOW)]
|
||||
|
||||
# Network nodes involved in the Capture the flag game
|
||||
nodes = {
|
||||
"Website": m.NodeInfo(
|
||||
services=[m.ListeningService("HTTPS"),
|
||||
m.ListeningService("SSH", allowedCredentials=[
|
||||
"ReusedMySqlCred-web"])],
|
||||
firewall=m.FirewallConfiguration(incoming=default_allow_rules,
|
||||
outgoing=default_allow_rules + [
|
||||
m.FirewallRule("su", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("sudo", m.RulePermission.ALLOW)]),
|
||||
value=100,
|
||||
# If can SSH into server then gets FLAG "Shared credentials with
|
||||
# database user"
|
||||
properties=["MySql", "Ubuntu", "nginx/1.10.3"],
|
||||
owned_string="FLAG: Login using insecure SSH user/password",
|
||||
vulnerabilities=dict(
|
||||
ScanPageContent=m.VulnerabilityInfo(
|
||||
description="LeakedGitHubProjectUrl: Website page content shows a link to GitHub "
|
||||
"repo",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedNodesId(["GitHubProject"]),
|
||||
reward_string="WEBSITE page content has a link to github -> Github project discovered!",
|
||||
cost=1.0
|
||||
),
|
||||
ScanPageSource=m.VulnerabilityInfo(
|
||||
description="Website page source contains refrence to browseable "
|
||||
"relative web directory",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedNodesId(["Website.Directory"]),
|
||||
reward_string="Viewing the web page source reveals a URL to a .txt file and directory on the website",
|
||||
cost=1.0
|
||||
),
|
||||
CredScanBashHistory=m.VulnerabilityInfo(
|
||||
description="bash history leaking creds - FLAG Stealing "
|
||||
"credentials for the monitoring user",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node="Website[user=monitor]", port="SSH",
|
||||
credential="monitorBashCreds")]),
|
||||
reward_string="FLAG: SSH history revealed credentials for the monitoring user (monitor)",
|
||||
cost=1.0
|
||||
))),
|
||||
|
||||
"Website.Directory": m.NodeInfo(
|
||||
services=[m.ListeningService("HTTPS")],
|
||||
value=50,
|
||||
properties=["Ubuntu", "nginx/1.10.3",
|
||||
"CTFFLAG:Readme.txt-Discover secret data"
|
||||
],
|
||||
vulnerabilities=dict(
|
||||
NavigateWebDirectoryFurther=m.VulnerabilityInfo(
|
||||
description="Discover MYSQL credentials MySql for user "
|
||||
"'web' in (getting-started.txt)",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node="Website", port="MySQL",
|
||||
credential="ReusedMySqlCred-web")]),
|
||||
reward_string="Discover browseable web directory: Navigating to parent URL revealed file `readme.txt`"
|
||||
"with secret data (aflag); and `getting-started.txt` with MYSQL credentials",
|
||||
cost=1.0
|
||||
),
|
||||
NavigateWebDirectory=m.VulnerabilityInfo(
|
||||
description="Discover URL to external sharepoint website "
|
||||
"(in deprecation-checklist.txt)",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedNodesId(["Sharepoint"]),
|
||||
reward_string="Navigating to parent URL revealed file `deprecation-checklist.txt` containing"
|
||||
"a URL to an external sharepoint website",
|
||||
cost=1.0
|
||||
)
|
||||
)),
|
||||
|
||||
"Website[user=monitor]": m.NodeInfo(
|
||||
services=[m.ListeningService("SSH", allowedCredentials=[]),
|
||||
m.ListeningService("SSH-key", allowedCredentials=["unkownkey"]),
|
||||
m.ListeningService("su", allowedCredentials=["monitorBashCreds"])],
|
||||
value=100,
|
||||
properties=["MySql", "Ubuntu", "nginx/1.10.3"],
|
||||
owned_string="FLAG User escalation by stealing credentials from bash history",
|
||||
firewall=m.FirewallConfiguration(
|
||||
outgoing=default_allow_rules,
|
||||
incoming=[m.FirewallRule("SSH", m.RulePermission.BLOCK,
|
||||
reason="password authentication disabled! SSH needs private key to authenticate."),
|
||||
m.FirewallRule("sudo", m.RulePermission.BLOCK,
|
||||
reason="`sudo -u monitor` failed. User 'monitor' not sudoable."
|
||||
"This warning will be reported!"),
|
||||
m.FirewallRule("su", m.RulePermission.ALLOW)] + default_allow_rules
|
||||
),
|
||||
vulnerabilities={
|
||||
"CredScan-HomeDirectory":
|
||||
m.VulnerabilityInfo(
|
||||
description="azurecredential.txt file in home directory",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(
|
||||
node="AzureResourceManager[user=monitor]",
|
||||
port="HTTPS",
|
||||
credential="azuread_user_credentials")]),
|
||||
reward_string="SSH: cat ~/azurecreds.txt (running as monitor) revealed Azure user credential!",
|
||||
cost=1.0),
|
||||
}),
|
||||
|
||||
"GitHubProject": m.NodeInfo(
|
||||
services=[m.ListeningService("GIT")],
|
||||
value=10,
|
||||
properties=["GitHub", "SasUrlInCommit"],
|
||||
vulnerabilities=dict(
|
||||
CredScanGitHistory=m.VulnerabilityInfo(
|
||||
description="Some secure access token (SAS) leaked in a "
|
||||
"reverted git commit",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
precondition=m.Precondition('SasUrlInCommit&GitHub'),
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node="AzureStorage",
|
||||
port="HTTPS",
|
||||
credential="SASTOKEN1")]),
|
||||
rates=m.Rates(probingDetectionRate=0.0,
|
||||
exploitDetectionRate=0.0,
|
||||
successRate=1.0),
|
||||
reward_string="CredScan success: Some secure access token (SAS) was leaked in a reverted git commit",
|
||||
cost=1.0
|
||||
))),
|
||||
|
||||
"AzureStorage": m.NodeInfo(
|
||||
services=[
|
||||
m.ListeningService("HTTPS", allowedCredentials=["SASTOKEN1"])],
|
||||
value=50,
|
||||
properties=["CTFFLAG:LeakedCustomerData"],
|
||||
vulnerabilities=dict(
|
||||
AccessDataWithSASToken=m.VulnerabilityInfo(
|
||||
description="Stealing secrets using a publicly shared "
|
||||
"SAS token",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.CustomerData(),
|
||||
rates=m.Rates(successRate=1.0),
|
||||
reward_string="Stole data using a publicly shared SAS token",
|
||||
cost=1.0
|
||||
)
|
||||
)),
|
||||
|
||||
'Sharepoint': m.NodeInfo(
|
||||
services=[m.ListeningService("HTTPS")],
|
||||
value=100,
|
||||
properties=["SharepointLeakingPassword"],
|
||||
firewall=m.FirewallConfiguration(incoming=[m.FirewallRule("SSH", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("HTTP", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("HTTPS", m.RulePermission.ALLOW)],
|
||||
outgoing=[]),
|
||||
vulnerabilities=dict(
|
||||
ScanSharepointParentDirectory=m.VulnerabilityInfo(
|
||||
description="Navigate to SharePoint site, browse parent "
|
||||
"directory",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node="AzureResourceManager",
|
||||
port="HTTPS",
|
||||
credential="ADPrincipalCreds")]),
|
||||
rates=m.Rates(successRate=1.0),
|
||||
reward_string="Navigating to the Sharepoint site revealed AD Service Principal Credentials",
|
||||
cost=1.0)
|
||||
)),
|
||||
|
||||
'AzureResourceManager': m.NodeInfo(
|
||||
services=[m.ListeningService("HTTPS", allowedCredentials=["ADPrincipalCreds", "azuread_user_credentials"])],
|
||||
owned_string="FLAG: Shared credentials with database user - Obtained secrets hidden in Azure Managed Resources",
|
||||
value=50,
|
||||
properties=["CTFFLAG:LeakedCustomerData2"],
|
||||
vulnerabilities=dict(
|
||||
ListAzureResources=m.VulnerabilityInfo(
|
||||
description="AzureVM info, including public IP address",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedNodesId(["AzureVM"]),
|
||||
reward_string="Obtained Azure VM and public IP information",
|
||||
cost=1.0
|
||||
))),
|
||||
|
||||
'AzureResourceManager[user=monitor]': m.NodeInfo(
|
||||
services=[m.ListeningService("HTTPS", allowedCredentials=["azuread_user_credentials"])],
|
||||
owned_string="More secrets stolen when logged as interactive `monitor` user in Azure with `az`",
|
||||
value=50,
|
||||
properties=[],
|
||||
),
|
||||
|
||||
'AzureVM': m.NodeInfo(
|
||||
services=[m.ListeningService("PING"),
|
||||
m.ListeningService("SSH")],
|
||||
value=100,
|
||||
properties=["CTFFLAG:VMPRIVATEINFO"],
|
||||
firewall=m.FirewallConfiguration(
|
||||
incoming=[m.FirewallRule("SSH", m.RulePermission.BLOCK,
|
||||
reason="internet incoming traffic blocked on the VM by NSG firewall")],
|
||||
outgoing=[])),
|
||||
|
||||
'client': m.NodeInfo(
|
||||
services=[],
|
||||
value=0,
|
||||
vulnerabilities=dict(
|
||||
SearchEdgeHistory=m.VulnerabilityInfo(
|
||||
description="Search web history for list of accessed websites",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedNodesId(["Website"]),
|
||||
reward_string="Web browser history revealed website URL of interest",
|
||||
cost=1.0
|
||||
)),
|
||||
agent_installed=True,
|
||||
reimagable=False),
|
||||
}
|
||||
|
||||
global_vulnerability_library: Dict[VulnerabilityID, VulnerabilityInfo] = dict([])
|
||||
|
||||
# Environment constants
|
||||
ENV_IDENTIFIERS = m.infer_constants_from_nodes(
|
||||
cast(Iterator[Tuple[NodeID, NodeInfo]], list(nodes.items())),
|
||||
global_vulnerability_library)
|
||||
|
||||
|
||||
def new_environment() -> m.Environment:
|
||||
return m.Environment(
|
||||
network=m.create_network(nodes),
|
||||
vulnerability_library=global_vulnerability_library,
|
||||
identifiers=ENV_IDENTIFIERS
|
||||
)
|
|
@ -0,0 +1,649 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
actions.py
|
||||
This file contains the class and associated methods for the AgentActions
|
||||
class which interacts directly with the environment. It is the class
|
||||
which both the user and RL agents should manipulate the environment.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import dataclasses
|
||||
from datetime import time
|
||||
import boolean
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Iterator, List, NamedTuple, Optional, Set, Tuple, Dict, TypedDict, cast
|
||||
import IPython.core.display as d
|
||||
import pandas as pd
|
||||
|
||||
from cyberbattle.simulation.model import MachineStatus, PrivilegeLevel, PropertyName, VulnerabilityID, VulnerabilityType
|
||||
from . import model
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Reward = float
|
||||
|
||||
DiscoveredNodeInfo = TypedDict('DiscoveredNodeInfo', {
|
||||
'id': model.NodeID,
|
||||
'status': str
|
||||
})
|
||||
|
||||
|
||||
class Penalty:
|
||||
"""Penalties (=negative reward) returned for some actions taken in the simulation"""
|
||||
# penalty for generic suspiciousness
|
||||
SUPSPICIOUSNESS = -5.0
|
||||
|
||||
# penalty for attempting a connection to a port that was not open
|
||||
SCANNING_UNOPEN_PORT = -10.0
|
||||
|
||||
# penalty for repeating the same exploit attempt
|
||||
REPEAT = -1
|
||||
|
||||
LOCAL_EXPLOIT_FAILED = -20
|
||||
FAILED_REMOTE_EXPLOIT = -50
|
||||
|
||||
# penalty for attempting to connect or execute an action on a node that's not in running state
|
||||
MACHINE_NOT_RUNNING = 0
|
||||
|
||||
# penalty for attempting a connection with an invalid password
|
||||
WRONG_PASSWORD = -10
|
||||
|
||||
# traffice blocked by outoing rule in a local firewall
|
||||
BLOCKED_BY_LOCAL_FIREWALL = -10
|
||||
|
||||
# traffice blocked by incoming rule in a remote firewall
|
||||
BLOCKED_BY_REMOTE_FIREWALL = -10
|
||||
|
||||
|
||||
# Reward for any successfully executed local or remote attack
|
||||
# (the attack cost gets substracted from this reward)
|
||||
SUCCEEDED_ATTACK_REWARD = 50
|
||||
|
||||
|
||||
class EdgeAnnotation(Enum):
|
||||
"""Annotation added to the network edges created as the simulation is played"""
|
||||
KNOWS = 0
|
||||
REMOTE_EXPLOIT = 1
|
||||
LATERAL_MOVE = 2
|
||||
|
||||
|
||||
class ActionResult(NamedTuple):
|
||||
"""Result from executing an action"""
|
||||
reward: Reward
|
||||
outcome: Optional[model.VulnerabilityOutcome]
|
||||
|
||||
|
||||
ALGEBRA = boolean.BooleanAlgebra()
|
||||
ALGEBRA.TRUE.dual = type(ALGEBRA.FALSE)
|
||||
ALGEBRA.FALSE.dual = type(ALGEBRA.TRUE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeTrackingInformation:
|
||||
"""Track information about nodes gathered throughout the simulation"""
|
||||
# Map (vulnid, local_or_remote) to time of last attack.
|
||||
# local_or_remote is true for local, false for remote
|
||||
last_attack: Dict[Tuple[model.VulnerabilityID, bool], time] = dataclasses.field(default_factory=dict)
|
||||
# Last time another node connected to this node
|
||||
last_connection: Optional[time] = None
|
||||
# All node properties discovered so far
|
||||
discovered_properties: Set[int] = dataclasses.field(default_factory=set)
|
||||
|
||||
|
||||
class AgentActions:
|
||||
"""
|
||||
This is the AgentActions class. It interacts with and makes changes to the environment.
|
||||
"""
|
||||
|
||||
def __init__(self, environment: model.Environment):
|
||||
"""
|
||||
AgentActions Constructor
|
||||
"""
|
||||
self._environment = environment
|
||||
self._gathered_credentials: Set[model.CredentialID] = set()
|
||||
self._discovered_nodes: "OrderedDict[model.NodeID, NodeTrackingInformation]" = OrderedDict()
|
||||
|
||||
# List of all special tags indicating a privilege level reached on a node
|
||||
self.privilege_tags = [model.PrivilegeEscalation(p).tag for p in list(PrivilegeLevel)]
|
||||
|
||||
# Mark all owned nodes as discovered
|
||||
for i, node in environment.nodes():
|
||||
if node.agent_installed:
|
||||
self.__mark_node_as_owned(i, PrivilegeLevel.LocalUser)
|
||||
|
||||
def discovered_nodes(self) -> Iterator[Tuple[model.NodeID, model.NodeInfo]]:
|
||||
for node_id in self._discovered_nodes:
|
||||
yield (node_id, self._environment.get_node(node_id))
|
||||
|
||||
def _check_prerequisites(self, target: model.NodeID, vulnerability: model.VulnerabilityInfo) -> bool:
|
||||
"""
|
||||
This is a quick helper function to check the prerequisites to see if
|
||||
they match the ones supplied.
|
||||
"""
|
||||
node: model.NodeInfo = self._environment.network.nodes[target]['data']
|
||||
node_flags = node.properties
|
||||
expr = vulnerability.precondition.expression
|
||||
|
||||
# this line seems redundant but it is necessary to declare the symbols used in the mapping
|
||||
# pylint: disable=unused-variable
|
||||
|
||||
mapping = {i: ALGEBRA.TRUE if str(i) in node_flags else ALGEBRA.FALSE
|
||||
for i in expr.get_symbols()}
|
||||
is_true: bool = cast(boolean.Expression, expr.subs(mapping)).simplify() == ALGEBRA.TRUE
|
||||
return is_true
|
||||
|
||||
def list_vulnerabilities_in_target(
|
||||
self,
|
||||
target: model.NodeID,
|
||||
type_filter: Optional[model.VulnerabilityType] = None) -> List[model.VulnerabilityID]:
|
||||
"""
|
||||
This function takes a model.NodeID for the target to be scanned
|
||||
and returns a list of vulnerability IDs.
|
||||
It checks each vulnerability in the library against the the properties of a given node
|
||||
and determines which vulnerabilities it has.
|
||||
"""
|
||||
if not self._environment.network.has_node(target):
|
||||
raise ValueError(f"invalid node id '{target}'")
|
||||
|
||||
target_node_data: model.NodeInfo = self._environment.get_node(target)
|
||||
|
||||
global_vuln: Set[model.VulnerabilityID] = {
|
||||
vuln_id
|
||||
for vuln_id, vulnerability in self._environment.vulnerability_library.items()
|
||||
if (type_filter is None or vulnerability.type == type_filter)
|
||||
and self._check_prerequisites(target, vulnerability)
|
||||
}
|
||||
|
||||
local_vuln: Set[model.VulnerabilityID] = {
|
||||
vuln_id
|
||||
for vuln_id, vulnerability in target_node_data.vulnerabilities.items()
|
||||
if (type_filter is None or vulnerability.type == type_filter)
|
||||
and self._check_prerequisites(target, vulnerability)
|
||||
}
|
||||
|
||||
return list(global_vuln.union(local_vuln))
|
||||
|
||||
def __annotate_edge(self, source_node_id: model.NodeID,
|
||||
target_node_id: model.NodeID,
|
||||
new_annotation: EdgeAnnotation) -> None:
|
||||
"""Create the edge if it does not already exist, and annotate with the maximum
|
||||
of the existing annotation and a specified new annotation"""
|
||||
edge_annotation = self._environment.network.get_edge_data(source_node_id, target_node_id)
|
||||
if edge_annotation is not None:
|
||||
if 'kind' in edge_annotation:
|
||||
new_annotation = EdgeAnnotation(max(edge_annotation['kind'].value, new_annotation.value))
|
||||
else:
|
||||
new_annotation = new_annotation.value
|
||||
self._environment.network.add_edge(source_node_id, target_node_id, kind=new_annotation, kind_as_float=float(new_annotation.value))
|
||||
|
||||
def get_discovered_properties(self, node_id: model.NodeID) -> Set[int]:
|
||||
return self._discovered_nodes[node_id].discovered_properties
|
||||
|
||||
def __mark_node_as_discovered(self, node_id: model.NodeID) -> None:
|
||||
logger.info('discovered node: ' + node_id)
|
||||
if node_id not in self._discovered_nodes:
|
||||
self._discovered_nodes[node_id] = NodeTrackingInformation()
|
||||
|
||||
def __mark_nodeproperties_as_discovered(self, node_id: model.NodeID, properties: List[PropertyName]):
|
||||
properties_indices = [self._environment.identifiers.properties.index(p)
|
||||
for p in properties
|
||||
if p not in self.privilege_tags]
|
||||
|
||||
if node_id in self._discovered_nodes:
|
||||
self._discovered_nodes[node_id].discovered_properties = self._discovered_nodes[node_id].discovered_properties.union(properties_indices)
|
||||
else:
|
||||
self._discovered_nodes[node_id] = NodeTrackingInformation(discovered_properties=set(properties_indices))
|
||||
|
||||
def __mark_allnodeproperties_as_discovered(self, node_id: model.NodeID):
|
||||
node_info: model.NodeInfo = self._environment.network.nodes[node_id]['data']
|
||||
self.__mark_nodeproperties_as_discovered(node_id, node_info.properties)
|
||||
|
||||
def __mark_node_as_owned(self,
|
||||
node_id: model.NodeID,
|
||||
privilege: PrivilegeLevel = model.PrivilegeLevel.LocalUser) -> None:
|
||||
if node_id not in self._discovered_nodes:
|
||||
self._discovered_nodes[node_id] = NodeTrackingInformation()
|
||||
node_info = self._environment.get_node(node_id)
|
||||
node_info.agent_installed = True
|
||||
node_info.privilege_level = model.escalate(node_info.privilege_level, privilege)
|
||||
self._environment.network.nodes[node_id].update({'data': node_info})
|
||||
|
||||
self.__mark_allnodeproperties_as_discovered(node_id)
|
||||
|
||||
def __mark_discovered_entities(self, reference_node: model.NodeID, outcome: model.VulnerabilityOutcome) -> None:
|
||||
if isinstance(outcome, model.LeakedCredentials):
|
||||
for credential in outcome.credentials:
|
||||
self.__mark_node_as_discovered(credential.node)
|
||||
self._gathered_credentials.add(credential.credential)
|
||||
logger.info('discovered credential: ' + str(credential))
|
||||
self.__annotate_edge(reference_node, credential.node, EdgeAnnotation.KNOWS)
|
||||
|
||||
elif isinstance(outcome, model.LeakedNodesId):
|
||||
for node_id in outcome.nodes:
|
||||
self.__mark_node_as_discovered(node_id)
|
||||
self.__annotate_edge(reference_node, node_id, EdgeAnnotation.KNOWS)
|
||||
|
||||
def get_node_privilegelevel(self, node_id: model.NodeID) -> model.PrivilegeLevel:
|
||||
"""Return the last recorded privilege level of the specified node"""
|
||||
node_info = self._environment.get_node(node_id)
|
||||
return node_info.privilege_level
|
||||
|
||||
def get_nodes_with_atleast_privilegelevel(self, level: PrivilegeLevel) -> List[model.NodeID]:
|
||||
"""Return all nodes with at least the specified privilege level"""
|
||||
return [n for n, info in self._environment.nodes() if info.privilege_level >= level]
|
||||
|
||||
def is_node_discovered(self, node_id: model.NodeID) -> bool:
|
||||
"""Returns true if previous actions have revealed the specified node ID"""
|
||||
return node_id in self._discovered_nodes
|
||||
|
||||
def __process_outcome(self,
|
||||
expected_type: VulnerabilityType,
|
||||
vulnerability_id: VulnerabilityID,
|
||||
node_id: model.NodeID,
|
||||
node_info: model.NodeInfo,
|
||||
local_or_remote: bool,
|
||||
failed_penalty: float,
|
||||
throw_if_vulnerability_not_present: bool
|
||||
) -> Tuple[bool, ActionResult]:
|
||||
|
||||
if node_info.status != model.MachineStatus.Running:
|
||||
logger.info("target machine not in running state")
|
||||
return False, ActionResult(reward=Penalty.MACHINE_NOT_RUNNING,
|
||||
outcome=None)
|
||||
|
||||
is_global_vulnerability = vulnerability_id in self._environment.vulnerability_library
|
||||
is_inplace_vulnerability = vulnerability_id in node_info.vulnerabilities
|
||||
|
||||
if is_global_vulnerability:
|
||||
vulnerabilities = self._environment.vulnerability_library
|
||||
elif is_inplace_vulnerability:
|
||||
vulnerabilities = node_info.vulnerabilities
|
||||
else:
|
||||
if throw_if_vulnerability_not_present:
|
||||
raise ValueError(f"Vulnerability '{vulnerability_id}' not supported by node='{node_id}'")
|
||||
else:
|
||||
logger.info(f"Vulnerability '{vulnerability_id}' not supported by node '{node_id}'")
|
||||
return False, ActionResult(reward=Penalty.SUPSPICIOUSNESS, outcome=None)
|
||||
|
||||
vulnerability = vulnerabilities[vulnerability_id]
|
||||
|
||||
outcome = vulnerability.outcome
|
||||
|
||||
if vulnerability.type != expected_type:
|
||||
raise ValueError(f"vulnerability id '{vulnerability_id}' is for an attack of type {vulnerability.type}, expecting: {expected_type}")
|
||||
|
||||
# check vulnerability prerequisites
|
||||
if not self._check_prerequisites(node_id, vulnerability):
|
||||
return False, ActionResult(reward=failed_penalty, outcome=model.ExploitFailed())
|
||||
|
||||
# if the vulnerability type is a privilege escalation
|
||||
# and if the escalation level is not already reached on that node,
|
||||
# then add the escalation tag to the node properties
|
||||
if isinstance(outcome, model.PrivilegeEscalation):
|
||||
if outcome.tag in node_info.properties:
|
||||
return False, ActionResult(reward=Penalty.REPEAT, outcome=outcome)
|
||||
|
||||
self.__mark_node_as_owned(node_id, outcome.level)
|
||||
|
||||
node_info.properties.append(outcome.tag)
|
||||
|
||||
elif isinstance(outcome, model.ProbeSucceeded):
|
||||
for p in outcome.discovered_properties:
|
||||
assert p in node_info.properties, \
|
||||
f'Discovered property {p} must belong to the set of properties associated with the node.'
|
||||
|
||||
self.__mark_nodeproperties_as_discovered(node_id, outcome.discovered_properties)
|
||||
|
||||
if node_id not in self._discovered_nodes:
|
||||
self._discovered_nodes[node_id] = NodeTrackingInformation()
|
||||
|
||||
lookup_key = (vulnerability_id, local_or_remote)
|
||||
|
||||
already_executed = lookup_key in self._discovered_nodes[node_id].last_attack
|
||||
|
||||
if already_executed:
|
||||
last_time = self._discovered_nodes[node_id].last_attack[lookup_key]
|
||||
if node_info.last_reimaging is None or last_time >= node_info.last_reimaging:
|
||||
return False, ActionResult(reward=Penalty.REPEAT, outcome=outcome)
|
||||
|
||||
self._discovered_nodes[node_id].last_attack[lookup_key] = time()
|
||||
|
||||
self.__mark_discovered_entities(node_id, outcome)
|
||||
logger.info("GOT REWARD: " + vulnerability.reward_string)
|
||||
return True, ActionResult(reward=0.0 if already_executed else SUCCEEDED_ATTACK_REWARD - vulnerability.cost,
|
||||
outcome=vulnerability.outcome)
|
||||
|
||||
def exploit_remote_vulnerability(self,
|
||||
node_id: model.NodeID,
|
||||
target_node_id: model.NodeID,
|
||||
vulnerability_id: model.VulnerabilityID
|
||||
) -> ActionResult:
|
||||
"""
|
||||
Attempt to exploit a remote vulnerability
|
||||
from a source node to another node using the specified
|
||||
vulnerability.
|
||||
"""
|
||||
if node_id not in self._environment.network.nodes:
|
||||
raise ValueError(f"invalid node id '{node_id}'")
|
||||
if target_node_id not in self._environment.network.nodes:
|
||||
raise ValueError(f"invalid target node id '{target_node_id}'")
|
||||
|
||||
source_node_info: model.NodeInfo = self._environment.get_node(node_id)
|
||||
target_node_info: model.NodeInfo = self._environment.get_node(target_node_id)
|
||||
|
||||
if not source_node_info.agent_installed:
|
||||
raise ValueError("Agent does not owned the source node '" + node_id + "'")
|
||||
|
||||
if target_node_id not in self._discovered_nodes:
|
||||
raise ValueError("Agent has not discovered the target node '" + target_node_id + "'")
|
||||
|
||||
succeeded, result = self.__process_outcome(
|
||||
model.VulnerabilityType.REMOTE,
|
||||
vulnerability_id,
|
||||
target_node_id,
|
||||
target_node_info,
|
||||
local_or_remote=False,
|
||||
failed_penalty=Penalty.FAILED_REMOTE_EXPLOIT,
|
||||
# We do not throw if the vulnerability is missing in order to
|
||||
# allow agent attempts to explore potential remote vulnerabilities
|
||||
throw_if_vulnerability_not_present=False
|
||||
)
|
||||
|
||||
if succeeded:
|
||||
self.__annotate_edge(node_id, target_node_id, EdgeAnnotation.REMOTE_EXPLOIT)
|
||||
|
||||
return result
|
||||
|
||||
def exploit_local_vulnerability(self, node_id: model.NodeID,
|
||||
vulnerability_id: model.VulnerabilityID) -> ActionResult:
|
||||
"""
|
||||
This function exploits a local vulnerability on a node
|
||||
it takes a nodeID for the target and a vulnerability ID.
|
||||
|
||||
It returns either a vulnerabilityoutcome object or None
|
||||
"""
|
||||
graph = self._environment.network
|
||||
if node_id not in graph.nodes:
|
||||
raise ValueError(f"invalid node id '{node_id}'")
|
||||
|
||||
node_info = self._environment.get_node(node_id)
|
||||
|
||||
if not node_info.agent_installed:
|
||||
raise ValueError(f"Agent does not owned the node '{node_id}'")
|
||||
|
||||
succeeded, result = self.__process_outcome(
|
||||
model.VulnerabilityType.LOCAL,
|
||||
vulnerability_id,
|
||||
node_id, node_info,
|
||||
local_or_remote=True,
|
||||
failed_penalty=Penalty.LOCAL_EXPLOIT_FAILED,
|
||||
throw_if_vulnerability_not_present=True)
|
||||
|
||||
return result
|
||||
|
||||
def __is_passing_firewall_rules(self, rules: List[model.FirewallRule], port_name: model.PortName) -> bool:
|
||||
"""Determine if traffic on the specified port is permitted by the specified sets of firewall rules"""
|
||||
for rule in rules:
|
||||
if rule.port == port_name:
|
||||
if rule.permission == model.RulePermission.ALLOW:
|
||||
return True
|
||||
else:
|
||||
logger.debug(f'BLOCKED TRAFFIC - PORT \'{port_name}\' Reason: ' + rule.reason)
|
||||
return False
|
||||
|
||||
logger.debug(f"BLOCKED TRAFFIC - PORT '{port_name}' - Reason: no rule defined for this port.")
|
||||
return False
|
||||
|
||||
def connect_to_remote_machine(
|
||||
self,
|
||||
source_node_id: model.NodeID,
|
||||
target_node_id: model.NodeID,
|
||||
port_name: model.PortName,
|
||||
credential: model.CredentialID) -> ActionResult:
|
||||
"""
|
||||
This function connects to a remote machine with credential as opposed to via an exploit.
|
||||
It takes a NodeId for the source machine, a NodeID for the target Machine, and a credential object
|
||||
for the credential.
|
||||
"""
|
||||
graph = self._environment.network
|
||||
if source_node_id not in graph.nodes:
|
||||
raise ValueError(f"invalid node id '{source_node_id}'")
|
||||
if target_node_id not in graph.nodes:
|
||||
raise ValueError(f"invalid node id '{target_node_id}''")
|
||||
|
||||
target_node = self._environment.get_node(target_node_id)
|
||||
source_node = self._environment.get_node(source_node_id)
|
||||
# ensures that the source node is owned by the agent
|
||||
# and that the target node is discovered
|
||||
|
||||
if not source_node.agent_installed:
|
||||
raise ValueError(f"Agent does not owned the source node '{source_node_id}'")
|
||||
|
||||
if target_node_id not in self._discovered_nodes:
|
||||
raise ValueError(f"Agent has not discovered the target node '{target_node_id}'")
|
||||
|
||||
if credential not in self._gathered_credentials:
|
||||
raise ValueError(f"Agent has not discovered credential '{credential}'")
|
||||
|
||||
if not self.__is_passing_firewall_rules(source_node.firewall.outgoing, port_name):
|
||||
logger.info(f"BLOCKED TRAFFIC: source node '{source_node_id}'" +
|
||||
f" is blocking outgoing traffic on port '{port_name}'")
|
||||
return ActionResult(reward=Penalty.BLOCKED_BY_LOCAL_FIREWALL,
|
||||
outcome=None)
|
||||
|
||||
if not self.__is_passing_firewall_rules(target_node.firewall.incoming, port_name):
|
||||
logger.info(f"BLOCKED TRAFFIC: target node '{target_node_id}'" +
|
||||
f" is blocking outgoing traffic on port '{port_name}'")
|
||||
return ActionResult(reward=Penalty.BLOCKED_BY_REMOTE_FIREWALL,
|
||||
outcome=None)
|
||||
|
||||
target_node_is_listening = port_name in [i.name for i in target_node.services]
|
||||
if not target_node_is_listening:
|
||||
logger.info(f"target node '{target_node_id}' not listening on port '{port_name}'")
|
||||
return ActionResult(reward=Penalty.SCANNING_UNOPEN_PORT,
|
||||
outcome=None)
|
||||
else:
|
||||
target_node_data: model.NodeInfo = self._environment.get_node(target_node_id)
|
||||
|
||||
if target_node_data.status != model.MachineStatus.Running:
|
||||
logger.info("target machine not in running state")
|
||||
return ActionResult(reward=Penalty.MACHINE_NOT_RUNNING,
|
||||
outcome=None)
|
||||
|
||||
# check the credentials before connecting
|
||||
if not self._check_service_running_and_authorized(target_node_data, port_name, credential):
|
||||
logger.info("invalid credentials supplied")
|
||||
return ActionResult(reward=Penalty.WRONG_PASSWORD,
|
||||
outcome=None)
|
||||
|
||||
is_already_owned = target_node_data.agent_installed
|
||||
if is_already_owned:
|
||||
return ActionResult(reward=Penalty.REPEAT,
|
||||
outcome=model.LateralMove())
|
||||
|
||||
if target_node_id not in self._discovered_nodes:
|
||||
self._discovered_nodes[target_node_id] = NodeTrackingInformation()
|
||||
|
||||
was_previously_owned_at = self._discovered_nodes[target_node_id].last_connection
|
||||
self._discovered_nodes[target_node_id].last_connection = time()
|
||||
|
||||
if was_previously_owned_at is not None and \
|
||||
target_node_data.last_reimaging is not None and \
|
||||
was_previously_owned_at >= target_node_data.last_reimaging:
|
||||
return ActionResult(reward=Penalty.REPEAT, outcome=model.LateralMove())
|
||||
|
||||
self.__annotate_edge(source_node_id, target_node_id, EdgeAnnotation.LATERAL_MOVE)
|
||||
self.__mark_node_as_owned(target_node_id)
|
||||
logger.info(f"Infected node '{target_node_id}' from '{source_node_id}'" +
|
||||
f" via {port_name} with credential '{credential}'")
|
||||
if target_node.owned_string:
|
||||
logger.info("Owned message: " + target_node.owned_string)
|
||||
|
||||
return ActionResult(reward=float(target_node_data.value) if was_previously_owned_at is None else 0.0,
|
||||
outcome=model.LateralMove())
|
||||
|
||||
def _check_service_running_and_authorized(self,
|
||||
target_node_data: model.NodeInfo,
|
||||
port_name: model.PortName,
|
||||
credential: model.CredentialID) -> bool:
|
||||
"""
|
||||
This is a quick helper function to check the prerequisites to see if
|
||||
they match the ones supplied.
|
||||
"""
|
||||
for service in target_node_data.services:
|
||||
if service.running and service.name == port_name and credential in service.allowedCredentials:
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_nodes(self) -> List[DiscoveredNodeInfo]:
|
||||
"""Returns the list of nodes ID that were discovered or owned by the attacker."""
|
||||
return [cast(DiscoveredNodeInfo, {'id': node_id,
|
||||
'status': 'owned' if node_info.agent_installed else 'discovered'
|
||||
})
|
||||
for node_id, node_info in self.discovered_nodes()
|
||||
]
|
||||
|
||||
def list_remote_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
|
||||
"""Return list of all remote attacks that may be executed onto the specified node."""
|
||||
attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(
|
||||
node_id, model.VulnerabilityType.REMOTE)
|
||||
return attacks
|
||||
|
||||
def list_local_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
|
||||
"""Return list of all local attacks that may be executed onto the specified node."""
|
||||
attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(
|
||||
node_id, model.VulnerabilityType.LOCAL)
|
||||
return attacks
|
||||
|
||||
def list_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
|
||||
"""Return list of all attacks that may be executed on the specified node."""
|
||||
attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(
|
||||
node_id)
|
||||
return attacks
|
||||
|
||||
def list_all_attacks(self) -> List[Dict[str, object]]:
|
||||
"""List all possible attacks from all the nodes currently owned by the attacker"""
|
||||
on_owned_nodes: List[Dict[str, object]] = [
|
||||
{'id': n['id'],
|
||||
'status': n['status'],
|
||||
'properties': self._environment.get_node(n['id']).properties,
|
||||
'local_attacks': self.list_local_attacks(n['id']),
|
||||
'remote_attacks': self.list_remote_attacks(n['id'])
|
||||
}
|
||||
for n in self.list_nodes() if n['status'] == 'owned']
|
||||
on_discovered_nodes: List[Dict[str, object]] = [{'id': n['id'],
|
||||
'status': n['status'],
|
||||
'local_attacks': None,
|
||||
'remote_attacks': self.list_remote_attacks(n['id'])}
|
||||
for n in self.list_nodes() if n['status'] != 'owned']
|
||||
return on_owned_nodes + on_discovered_nodes
|
||||
|
||||
def print_all_attacks(self) -> None:
|
||||
"""Pretty print list of all possible attacks from all the nodes currently owned by the attacker"""
|
||||
d.display(pd.DataFrame.from_dict(self.list_all_attacks())) # type: ignore
|
||||
|
||||
|
||||
class DefenderAgentActions:
|
||||
"""Actions reserved to defender agents"""
|
||||
|
||||
# Number of steps it takes to completely reimage a node
|
||||
REIMAGING_DURATION = 15
|
||||
|
||||
def __init__(self, environment: model.Environment):
|
||||
# map nodes being reimaged to the remaining number of steps to completion
|
||||
self.node_reimaging_progress: Dict[model.NodeID, int] = dict()
|
||||
|
||||
# Last calculated availability of the network
|
||||
self.__network_availability: float = 1.0
|
||||
|
||||
self._environment = environment
|
||||
|
||||
@property
|
||||
def network_availability(self):
|
||||
return self.__network_availability
|
||||
|
||||
def reimage_node(self, node_id: model.NodeID):
|
||||
"""Re-image a computer node"""
|
||||
# Mark the node for re-imaging and make it unavailable until re-imaging completes
|
||||
self.node_reimaging_progress[node_id] = self.REIMAGING_DURATION
|
||||
|
||||
node_info = self._environment.get_node(node_id)
|
||||
assert node_info.reimagable, f'Node {node_id} is not re-imageable'
|
||||
|
||||
node_info.agent_installed = False
|
||||
node_info.privilege_level = model.PrivilegeLevel.NoAccess
|
||||
node_info.status = model.MachineStatus.Imaging
|
||||
node_info.last_reimaging = time()
|
||||
self._environment.network.nodes[node_id].update({'data': node_info})
|
||||
|
||||
def on_attacker_step_taken(self):
|
||||
"""Function to be called each time a step is take in the simulation"""
|
||||
for node_id in list(self.node_reimaging_progress.keys()):
|
||||
remaining_steps = self.node_reimaging_progress[node_id]
|
||||
if remaining_steps > 0:
|
||||
self.node_reimaging_progress[node_id] -= 1
|
||||
else:
|
||||
logger.info(f"Machine re-imaging completed: {node_id}")
|
||||
node_data = self._environment.get_node(node_id)
|
||||
node_data.status = model.MachineStatus.Running
|
||||
self.node_reimaging_progress.pop(node_id)
|
||||
|
||||
# Calculate the network availability metric based on machines
|
||||
# and services that are running
|
||||
total_node_weights = 0
|
||||
network_node_availability = 0
|
||||
for node_id, node_info in self._environment.nodes():
|
||||
total_service_weights = 0
|
||||
running_service_weights = 0
|
||||
for service in node_info.services:
|
||||
total_service_weights += service.sla_weight
|
||||
running_service_weights += service.sla_weight * int(service.running)
|
||||
|
||||
if node_info.status == MachineStatus.Running:
|
||||
adjusted_node_availability = (1 + running_service_weights) / (1 + total_service_weights)
|
||||
else:
|
||||
adjusted_node_availability = 0.0
|
||||
|
||||
total_node_weights += node_info.sla_weight
|
||||
network_node_availability += adjusted_node_availability * node_info.sla_weight
|
||||
|
||||
self.__network_availability = network_node_availability / total_node_weights
|
||||
assert(self.__network_availability <= 1.0 and self.__network_availability >= 0.0)
|
||||
|
||||
def override_firewall_rule(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool, permission: model.RulePermission):
|
||||
node_data = self._environment.get_node(node_id)
|
||||
rules = node_data.firewall.incoming if incoming else node_data.firewall.outgoing
|
||||
matching_rules = [r for r in rules if r.port == port_name]
|
||||
if matching_rules:
|
||||
for r in matching_rules:
|
||||
r.permission = permission
|
||||
else:
|
||||
new_rule = model.FirewallRule(port_name, permission)
|
||||
if incoming:
|
||||
node_data.firewall.incoming = [new_rule] + node_data.firewall.incoming
|
||||
else:
|
||||
node_data.firewall.outgoing = [new_rule] + node_data.firewall.outgoing
|
||||
|
||||
def block_traffic(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool):
|
||||
return self.override_firewall_rule(node_id, port_name, incoming, permission=model.RulePermission.BLOCK)
|
||||
|
||||
def allow_traffic(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool):
|
||||
return self.override_firewall_rule(node_id, port_name, incoming, permission=model.RulePermission.ALLOW)
|
||||
|
||||
def stop_service(self, node_id: model.NodeID, port_name: model.PortName):
|
||||
node_data = self._environment.get_node(node_id)
|
||||
assert node_data.status == model.MachineStatus.Running, "Machine must be running to stop a service"
|
||||
for service in node_data.services:
|
||||
if service.name == port_name:
|
||||
service.running = False
|
||||
|
||||
def start_service(self, node_id: model.NodeID, port_name: model.PortName):
|
||||
node_data = self._environment.get_node(node_id)
|
||||
assert node_data.status == model.MachineStatus.Running, "Machine must be running to start a service"
|
||||
for service in node_data.services:
|
||||
if service.name == port_name:
|
||||
service.running = True
|
|
@ -0,0 +1,379 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This is the set of tests for actions.py which implements the actions an agent can take
|
||||
in this simulation.
|
||||
"""
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Union, Dict, List
|
||||
|
||||
import pytest
|
||||
import networkx as nx
|
||||
|
||||
from . import model, actions
|
||||
|
||||
ADMINTAG = model.AdminEscalation().tag
|
||||
SYSTEMTAG = model.SystemEscalation().tag
|
||||
|
||||
# pylint: disable=redefined-outer-name, protected-access
|
||||
# define fixtures as a type so mypy will shut up
|
||||
Fixture = Union[actions.AgentActions]
|
||||
|
||||
empty_vuln_dict: Dict[model.VulnerabilityID, model.VulnerabilityInfo] = {}
|
||||
SINGLE_VULNERABILITIES = {
|
||||
"UACME61":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #61",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0))}
|
||||
|
||||
# temporary vuln dictionary for development purposes only.
|
||||
# Remove once the full list of vulnerabilities is put together
|
||||
# here we'll have 1 UAC bypass, 1 credential dump, and 1 remote infection vulnerability
|
||||
SAMPLE_VULNERABILITIES = {
|
||||
"UACME61":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #61",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
"UACME67":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #67 (fake system escalation) ",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.SystemEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
"MimikatzLogonpasswords":
|
||||
model.VulnerabilityInfo(
|
||||
description="Mimikatz sekurlsa::logonpasswords.",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/gentilkiwi/mimikatz",
|
||||
precondition=model.Precondition(f"Windows&({ADMINTAG}|{SYSTEMTAG})"),
|
||||
outcome=model.LeakedCredentials([]),
|
||||
rates=model.Rates(0, 1.0, 1.0)),
|
||||
"RDPBF":
|
||||
model.VulnerabilityInfo(
|
||||
description="RDP Brute Force",
|
||||
type=model.VulnerabilityType.REMOTE,
|
||||
URL="https://attack.mitre.org/techniques/T1110/",
|
||||
precondition=model.Precondition("Windows&PortRDPOpen"),
|
||||
outcome=model.LateralMove(),
|
||||
rates=model.Rates(0, 0.2, 1.0),
|
||||
cost=1.0)
|
||||
}
|
||||
|
||||
ENV_IDENTIFIERS = model.Identifiers(
|
||||
local_vulnerabilities=['UACME61', 'UACME67', 'MimikatzLogonpasswords', 'UACME61'],
|
||||
remote_vulnerabilities=['RDPBF'],
|
||||
ports=['RDP', 'HTTP', 'HTTPS', 'SSH'],
|
||||
properties=[
|
||||
"Linux", "PortSSHOpen", "PortSQLOpen",
|
||||
"Windows", "Win10", "PortRDPOpen",
|
||||
"PortHTTPOpen", "PortHTTPsOpen",
|
||||
"SharepointLeakingPassword"]
|
||||
)
|
||||
|
||||
|
||||
def sample_random_firwall_configuration() -> model.FirewallConfiguration:
|
||||
"""Sample a random firewall set of rules"""
|
||||
return model.FirewallConfiguration(
|
||||
outgoing=[model.FirewallRule(p, permission=model.RulePermission.ALLOW)
|
||||
for p in random.choices(ENV_IDENTIFIERS.properties,
|
||||
k=random.randint(0, len(ENV_IDENTIFIERS.properties)))],
|
||||
incoming=[model.FirewallRule(p, permission=model.RulePermission.ALLOW)
|
||||
for p in random.choices(ENV_IDENTIFIERS.properties,
|
||||
k=random.randint(0, len(ENV_IDENTIFIERS.properties)))])
|
||||
|
||||
|
||||
# temporary info for a single node network
|
||||
SINGLE_NODE = {
|
||||
'a': model.NodeInfo(
|
||||
services=[model.ListeningService("RDP"),
|
||||
model.ListeningService("HTTP"),
|
||||
model.ListeningService("HTTPS")],
|
||||
value=70,
|
||||
properties=list(["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"]),
|
||||
firewall=sample_random_firwall_configuration(),
|
||||
agent_installed=False)}
|
||||
|
||||
# temporary info for 4 nodes
|
||||
# a is a windows web server, b is linux SQL server, c is a windows workstation,
|
||||
# and dc is a domain controller
|
||||
NODES = {
|
||||
'a': model.NodeInfo(
|
||||
services=[model.ListeningService("RDP"),
|
||||
model.ListeningService("HTTP"),
|
||||
model.ListeningService("HTTPS")],
|
||||
value=70,
|
||||
properties=list(["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"]),
|
||||
vulnerabilities=dict(
|
||||
ListNeighbors=model.VulnerabilityInfo(
|
||||
description="reveal other nodes",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
outcome=model.LeakedNodesId(nodes=['b', 'c', 'dc'])),
|
||||
|
||||
DumpCreds=model.VulnerabilityInfo(
|
||||
description="leaking some creds",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
outcome=model.LeakedCredentials([model.CachedCredential('Sharepoint', "HTTPS", "ADPrincipalCreds"),
|
||||
model.CachedCredential('Sharepoint', "HTTPS", "cred")])
|
||||
)
|
||||
),
|
||||
agent_installed=True),
|
||||
'b': model.NodeInfo(
|
||||
services=[model.ListeningService("SSH"),
|
||||
model.ListeningService("SQL")],
|
||||
value=80,
|
||||
properties=list(["Linux", "PortSSHOpen", "PortSQLOpen"]),
|
||||
agent_installed=False),
|
||||
'c': model.NodeInfo(
|
||||
services=[model.ListeningService("RDP"),
|
||||
model.ListeningService("HTTP"),
|
||||
model.ListeningService("HTTPS")],
|
||||
value=40,
|
||||
properties=list(["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"]),
|
||||
agent_installed=True),
|
||||
'dc': model.NodeInfo(
|
||||
services=[model.ListeningService("RDP"),
|
||||
model.ListeningService("WMI")],
|
||||
value=100, properties=list(["Windows", "Win10", "PortRDPOpen", "PortWMIOpen"]),
|
||||
agent_installed=False),
|
||||
'Sharepoint': model.NodeInfo(
|
||||
services=[model.ListeningService("HTTPS", allowedCredentials=["ADPrincipalCreds"])], value=100,
|
||||
properties=["SharepointLeakingPassword"],
|
||||
firewall=model.FirewallConfiguration(
|
||||
incoming=[model.FirewallRule(port="SSH", permission=model.RulePermission.ALLOW),
|
||||
model.FirewallRule(port="HTTPS", permission=model.RulePermission.ALLOW),
|
||||
model.FirewallRule(port="HTTP", permission=model.RulePermission.ALLOW),
|
||||
model.FirewallRule(port="RDP", permission=model.RulePermission.BLOCK)
|
||||
],
|
||||
outgoing=[]),
|
||||
vulnerabilities=dict(
|
||||
ScanSharepointParentDirectory=model.VulnerabilityInfo(
|
||||
description="Navigate to SharePoint site, browse parent "
|
||||
"directory",
|
||||
type=model.VulnerabilityType.REMOTE,
|
||||
outcome=model.LeakedCredentials(credentials=[
|
||||
model.CachedCredential(node="AzureResourceManager",
|
||||
port="HTTPS",
|
||||
credential="ADPrincipalCreds")]),
|
||||
rates=model.Rates(successRate=1.0),
|
||||
cost=1.0)
|
||||
)),
|
||||
|
||||
}
|
||||
|
||||
|
||||
# Define an environment from this graph
|
||||
ENV = model.Environment(
|
||||
network=model.create_network(NODES),
|
||||
vulnerability_library=dict([]),
|
||||
identifiers=ENV_IDENTIFIERS,
|
||||
creationTime=datetime.utcnow(),
|
||||
lastModified=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@ pytest.fixture
|
||||
def actions_on_empty_environment() -> actions.AgentActions:
|
||||
"""
|
||||
the test fixtures to reduce the amount of overhead
|
||||
This fixture will provide us with an empty environment.
|
||||
"""
|
||||
egraph = nx.empty_graph(0, create_using=nx.DiGraph())
|
||||
env = model.Environment(network=egraph,
|
||||
version=model.VERSION_TAG,
|
||||
vulnerability_library=SAMPLE_VULNERABILITIES,
|
||||
identifiers=ENV_IDENTIFIERS,
|
||||
creationTime=datetime.utcnow(),
|
||||
lastModified=datetime.utcnow())
|
||||
return actions.AgentActions(env)
|
||||
|
||||
|
||||
@ pytest.fixture
|
||||
def actions_on_single_node_environment() -> actions.AgentActions:
|
||||
"""
|
||||
This fixture will provide us with a single node environment
|
||||
"""
|
||||
env = model.Environment(network=model.create_network(SINGLE_NODE),
|
||||
version=model.VERSION_TAG,
|
||||
vulnerability_library=SAMPLE_VULNERABILITIES,
|
||||
identifiers=ENV_IDENTIFIERS,
|
||||
creationTime=datetime.utcnow(),
|
||||
lastModified=datetime.utcnow())
|
||||
return actions.AgentActions(env)
|
||||
|
||||
|
||||
@ pytest.fixture
|
||||
def actions_on_simple_environment() -> actions.AgentActions:
|
||||
"""
|
||||
This fixture will provide us with a 4 node environment environment.
|
||||
simulating three workstations connected to a single server
|
||||
"""
|
||||
env = model.Environment(network=model.create_network(NODES),
|
||||
version=model.VERSION_TAG,
|
||||
vulnerability_library=SAMPLE_VULNERABILITIES,
|
||||
identifiers=ENV_IDENTIFIERS,
|
||||
creationTime=datetime.utcnow(),
|
||||
lastModified=datetime.utcnow())
|
||||
return actions.AgentActions(env)
|
||||
|
||||
|
||||
def test_list_vulnerabilities_function(actions_on_single_node_environment: Fixture,
|
||||
actions_on_simple_environment: Fixture) -> None:
|
||||
"""
|
||||
This function will test the list_vulnerabilities function from the
|
||||
AgentActions class in actions.py
|
||||
"""
|
||||
# test on an environment with a single node
|
||||
single_node_results: List[model.VulnerabilityID] = []
|
||||
single_node_results = actions_on_single_node_environment.list_vulnerabilities_in_target('a')
|
||||
assert len(single_node_results) == 3
|
||||
|
||||
simple_graph_results: List[model.VulnerabilityID] = []
|
||||
simple_graph_results = actions_on_simple_environment.list_vulnerabilities_in_target('dc')
|
||||
assert len(simple_graph_results) == 3
|
||||
|
||||
|
||||
def test_exploit_remote_vulnerability(actions_on_simple_environment: Fixture) -> None:
|
||||
"""
|
||||
This function will test the exploit_remote_vulnerability function from the
|
||||
AgentActions class in actions.py
|
||||
"""
|
||||
|
||||
actions_on_simple_environment.exploit_local_vulnerability('a', "ListNeighbors")
|
||||
|
||||
# test with invalid source node
|
||||
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
|
||||
actions_on_simple_environment.exploit_remote_vulnerability('z', 'b', "RDPBF")
|
||||
|
||||
# test with invalid destination node
|
||||
with pytest.raises(ValueError, match=r"invalid target node id '.*'"):
|
||||
actions_on_simple_environment.exploit_remote_vulnerability('a', 'z', "RDPBF")
|
||||
|
||||
# test with a local vulnerability
|
||||
with pytest.raises(ValueError, match=r"vulnerability id '.*' is for an attack of type .*"):
|
||||
actions_on_simple_environment.exploit_remote_vulnerability('a', 'c', "MimikatzLogonpasswords")
|
||||
|
||||
# test with an invalid vulnerability (one not there)
|
||||
result = actions_on_simple_environment.exploit_remote_vulnerability('a', 'c', "HackTheGibson")
|
||||
assert result.outcome is None and result.reward <= 0
|
||||
|
||||
# add RDP brute force to the target node
|
||||
# very hacky not to be used normally.
|
||||
graph: nx.graph.Graph = actions_on_simple_environment._environment.network
|
||||
node: model.NodeInfo = graph.nodes['c']['data']
|
||||
node.vulnerabilities = SAMPLE_VULNERABILITIES
|
||||
|
||||
# test a valid and functional one.
|
||||
result = actions_on_simple_environment.exploit_remote_vulnerability('a', 'c', "RDPBF")
|
||||
assert isinstance(result.outcome, model.LateralMove)
|
||||
assert result.reward <= actions.SUCCEEDED_ATTACK_REWARD - 1
|
||||
|
||||
|
||||
def test_exploit_local_vulnerability(actions_on_simple_environment: Fixture) -> None:
|
||||
"""
|
||||
This function will test the exploit_local_vulnerability function from the
|
||||
AgentActions class in actions.py
|
||||
"""
|
||||
|
||||
# check one with invalid prerequisites
|
||||
result: actions.ActionResult = actions_on_simple_environment.\
|
||||
exploit_local_vulnerability('a', "MimikatzLogonpasswords")
|
||||
assert isinstance(result.outcome, model.ExploitFailed)
|
||||
|
||||
# test admin privilege escalation
|
||||
# exploit_local_vulnerability(node_id, vulnerability_id)
|
||||
result = actions_on_simple_environment.exploit_local_vulnerability('a', "UACME61")
|
||||
assert isinstance(result.outcome, model.AdminEscalation)
|
||||
node: model.NodeInfo = actions_on_simple_environment._environment.network.nodes['a']['data']
|
||||
assert model.AdminEscalation().tag in node.properties
|
||||
|
||||
# test system privilege escalation
|
||||
result = actions_on_simple_environment.exploit_local_vulnerability('c', "UACME67")
|
||||
assert isinstance(result.outcome, model.SystemEscalation)
|
||||
node = actions_on_simple_environment._environment.network.nodes['c']['data']
|
||||
assert model.SystemEscalation().tag in node.properties
|
||||
|
||||
# test dump credentials
|
||||
result = actions_on_simple_environment.\
|
||||
exploit_local_vulnerability('a', "MimikatzLogonpasswords")
|
||||
assert isinstance(result.outcome, model.LeakedCredentials)
|
||||
|
||||
|
||||
def test_connect_to_remote_machine(actions_on_empty_environment: Fixture,
|
||||
actions_on_single_node_environment: Fixture,
|
||||
actions_on_simple_environment: Fixture) -> None:
|
||||
"""
|
||||
This function will test the connect_to_remote_machine function from the
|
||||
AgentActions class in actions.py
|
||||
"""
|
||||
actions_on_simple_environment.exploit_local_vulnerability('a', "ListNeighbors")
|
||||
actions_on_simple_environment.exploit_local_vulnerability('a', "DumpCreds")
|
||||
|
||||
# test connect to remote machine on an empty environment
|
||||
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
|
||||
actions_on_empty_environment.connect_to_remote_machine("a", "b", "RDP", "cred")
|
||||
|
||||
# test connect to remote machine on an environment with 1 node
|
||||
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
|
||||
actions_on_single_node_environment.connect_to_remote_machine("a", "b", "RDP", "cred")
|
||||
|
||||
graph: nx.graph.Graph = actions_on_simple_environment._environment.network
|
||||
|
||||
# test connect to remote machine on an environment with multiple nodes
|
||||
# test with valid source node and invalid destination node
|
||||
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
|
||||
actions_on_simple_environment.\
|
||||
connect_to_remote_machine("a", "f", "RDP", "cred")
|
||||
|
||||
# test with an invalid source node and valid destination node
|
||||
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
|
||||
actions_on_simple_environment.connect_to_remote_machine("f", "dc", "RDP", "cred")
|
||||
|
||||
# test with both nodes invalid
|
||||
with pytest.raises(ValueError, match=r"invalid node id '.*'"):
|
||||
actions_on_simple_environment.connect_to_remote_machine("f", "z", "RDP", "cred")
|
||||
|
||||
# test with invalid protocol
|
||||
result = actions_on_simple_environment.connect_to_remote_machine("a", "dc", "TCPIP", "cred")
|
||||
assert result.reward <= 0 and result.outcome is None
|
||||
|
||||
# test with invalid credentials
|
||||
result2 = actions_on_simple_environment.connect_to_remote_machine("a", "dc", "RDP", "cred")
|
||||
assert result2.outcome is None and result2.reward <= 0
|
||||
|
||||
# test blocking firewall rule
|
||||
ret_val = actions_on_simple_environment.connect_to_remote_machine("a", 'Sharepoint', "RDP", "ADPrincipalCreds")
|
||||
assert ret_val.reward < 0
|
||||
|
||||
# test with valid nodes
|
||||
ret_val = actions_on_simple_environment.connect_to_remote_machine("a", 'Sharepoint', "HTTPS", "ADPrincipalCreds")
|
||||
|
||||
assert ret_val.reward == 100
|
||||
|
||||
assert graph.has_edge("a", "dc")
|
||||
|
||||
|
||||
def test_check_prerequisites(actions_on_simple_environment: Fixture) -> None:
|
||||
"""
|
||||
This function will test the _checkPrerequisites function
|
||||
It's marked as a private function but still needs to be tested before use
|
||||
|
||||
"""
|
||||
# testing on a node/vuln combo which should give us a negative result
|
||||
result = actions_on_simple_environment._check_prerequisites('dc', SAMPLE_VULNERABILITIES["MimikatzLogonpasswords"])
|
||||
assert not result
|
||||
|
||||
# testing on a node/vuln combo which should give us a positive reuslt.
|
||||
result = actions_on_simple_environment._check_prerequisites('dc', SAMPLE_VULNERABILITIES["UACME61"])
|
||||
assert result
|
|
@ -0,0 +1,276 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""A 'Command & control'-like interface exposing to a human player
|
||||
the attacker view and actions of the game.
|
||||
This includes commands to visualize the part of the environment
|
||||
that were explored so far, and for each node where the attacker client
|
||||
is installed, execute actions on the machine.
|
||||
"""
|
||||
import networkx as nx
|
||||
from typing import List, Optional, Dict, Union, Tuple
|
||||
import plotly.graph_objects as go
|
||||
|
||||
from . import model, actions
|
||||
|
||||
|
||||
class CommandControl:
|
||||
""" The Command and Control interface to the simulation.
|
||||
|
||||
This represents a server that centralize information and secrets
|
||||
retrieved from the individual clients running on the network nodes.
|
||||
"""
|
||||
|
||||
# Global list aggregating all credentials gathered so far, from any node in the network
|
||||
__gathered_credentials: List[model.CachedCredential] = []
|
||||
_actuator: actions.AgentActions
|
||||
__environment: model.Environment
|
||||
__total_reward: float
|
||||
|
||||
def __init__(self, environment_or_actuator: Union[model.Environment, actions.AgentActions]):
|
||||
if isinstance(environment_or_actuator, model.Environment):
|
||||
self.__environment = environment_or_actuator
|
||||
self._actuator = actions.AgentActions(self.__environment)
|
||||
elif isinstance(environment_or_actuator, actions.AgentActions):
|
||||
self.__environment = environment_or_actuator._environment
|
||||
self._actuator = environment_or_actuator
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid type: expecting Union[model.Environment, actions.AgentActions])")
|
||||
|
||||
self.__gathered_credentials = []
|
||||
self.__total_reward = 0
|
||||
|
||||
def __save_credentials(self, outcome: model.VulnerabilityOutcome) -> None:
|
||||
"""Save credentials obtained from exploiting a vulnerability"""
|
||||
if isinstance(outcome, model.LeakedCredentials):
|
||||
self.__gathered_credentials.extend(outcome.credentials)
|
||||
return
|
||||
|
||||
def __accumulate_reward(self, reward: actions.Reward) -> None:
|
||||
"""Accumulate new reward"""
|
||||
self.__total_reward += reward
|
||||
|
||||
def total_reward(self) -> actions.Reward:
|
||||
"""Return the current accumulated reward"""
|
||||
return self.__total_reward
|
||||
|
||||
def list_nodes(self) -> List[actions.DiscoveredNodeInfo]:
|
||||
"""Returns the list of nodes ID that were discovered or owned by the attacker."""
|
||||
return self._actuator.list_nodes()
|
||||
|
||||
def get_node_color(self, node_info: model.NodeInfo) -> str:
|
||||
if node_info.agent_installed:
|
||||
return 'red'
|
||||
else:
|
||||
return 'green'
|
||||
|
||||
def plot_nodes(self) -> None:
|
||||
"""Plot the sub-graph of nodes either so far
|
||||
discovered (their ID is knowned by the agent)
|
||||
or owned (i.e. where the attacker client is installed)."""
|
||||
discovered_nodes = [node_id for node_id, _ in self._actuator.discovered_nodes()]
|
||||
sub_graph = self.__environment.network.subgraph(discovered_nodes)
|
||||
nx.draw(sub_graph,
|
||||
with_labels=True,
|
||||
node_color=[self.get_node_color(self.__environment.get_node(i)) for i in sub_graph.nodes])
|
||||
|
||||
def known_vulnerabilities(self) -> model.VulnerabilityLibrary:
|
||||
"""Return the global list of known vulnerability."""
|
||||
return self.__environment.vulnerability_library
|
||||
|
||||
def list_remote_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
|
||||
"""Return list of all remote attacks that the Command&Control may
|
||||
execute onto the specified node."""
|
||||
return self._actuator.list_remote_attacks(node_id)
|
||||
|
||||
def list_local_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
|
||||
"""Return list of all local attacks that the Command&Control may
|
||||
execute onto the specified node."""
|
||||
return self._actuator.list_local_attacks(node_id)
|
||||
|
||||
def list_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
|
||||
"""Return list of all attacks that the Command&Control may
|
||||
execute on the specified node."""
|
||||
return self._actuator.list_attacks(node_id)
|
||||
|
||||
def list_all_attacks(self) -> List[Dict[str, object]]:
|
||||
"""List all possible attacks from all the nodes currently owned by the attacker"""
|
||||
return self._actuator.list_all_attacks()
|
||||
|
||||
def print_all_attacks(self) -> None:
|
||||
"""Pretty print list of all possible attacks from all the nodes currently owned by the attacker"""
|
||||
return self._actuator.print_all_attacks()
|
||||
|
||||
def run_attack(self,
|
||||
node_id: model.NodeID,
|
||||
vulnerability_id: model.VulnerabilityID
|
||||
) -> Optional[model.VulnerabilityOutcome]:
|
||||
"""Run an attack and attempt to exploit a vulnerability on the specified node."""
|
||||
result = self._actuator.exploit_local_vulnerability(node_id, vulnerability_id)
|
||||
if result.outcome is not None:
|
||||
self.__save_credentials(result.outcome)
|
||||
self.__accumulate_reward(result.reward)
|
||||
return result.outcome
|
||||
|
||||
def run_remote_attack(self, node_id: model.NodeID,
|
||||
target_node_id: model.NodeID,
|
||||
vulnerability_id: model.VulnerabilityID
|
||||
) -> Optional[model.VulnerabilityOutcome]:
|
||||
"""Run a remote attack from the specified node to exploit a remote vulnerability
|
||||
in the specified target node"""
|
||||
|
||||
result = self._actuator.exploit_remote_vulnerability(
|
||||
node_id, target_node_id, vulnerability_id)
|
||||
if result.outcome is not None:
|
||||
self.__save_credentials(result.outcome)
|
||||
self.__accumulate_reward(result.reward)
|
||||
return result.outcome
|
||||
|
||||
def connect_and_infect(self, source_node_id: model.NodeID,
|
||||
target_node_id: model.NodeID,
|
||||
port_name: model.PortName,
|
||||
credentials: model.CredentialID) -> bool:
|
||||
"""Install the agent on a remote machine using the
|
||||
provided credentials"""
|
||||
result = self._actuator.connect_to_remote_machine(source_node_id, target_node_id, port_name,
|
||||
credentials)
|
||||
self.__accumulate_reward(result.reward)
|
||||
return result.outcome is not None
|
||||
|
||||
@property
|
||||
def credentials_gathered_so_far(self) -> List[model.CachedCredential]:
|
||||
"""Returns the list of credentials gathered so far by the
|
||||
attacker (from any node)"""
|
||||
return self.__gathered_credentials
|
||||
|
||||
|
||||
def get_outcome_first_credential(outcome: Optional[model.VulnerabilityOutcome]) -> model.CredentialID:
|
||||
"""Return the first credential found in a given vulnerability exploit outcome"""
|
||||
if outcome is not None and isinstance(outcome, model.LeakedCredentials):
|
||||
return outcome.credentials[0].credential
|
||||
else:
|
||||
raise ValueError('Vulnerability outcome does not contain any credential')
|
||||
|
||||
|
||||
class EnvironmentDebugging:
|
||||
"""Provides debugging feature exposing internals of the environment
|
||||
that are not normally revealed to an attacker agent according to
|
||||
the rules of the simulation.
|
||||
"""
|
||||
__environment: model.Environment
|
||||
__actuator: actions.AgentActions
|
||||
|
||||
def __init__(self, actuator_or_c2: Union[actions.AgentActions, CommandControl]):
|
||||
if isinstance(actuator_or_c2, actions.AgentActions):
|
||||
self.__actuator = actuator_or_c2
|
||||
elif isinstance(actuator_or_c2, CommandControl):
|
||||
self.__actuator = actuator_or_c2._actuator
|
||||
else:
|
||||
raise ValueError("Invalid type: expecting Union[actions.AgentActions, CommandControl])")
|
||||
|
||||
self.__environment = self.__actuator._environment
|
||||
|
||||
def network_as_plotly_traces(self, xref: str = "x", yref: str = "y") -> Tuple[List[go.Scatter], dict]:
|
||||
known_nodes = [node_id for node_id, _ in self.__actuator.discovered_nodes()]
|
||||
|
||||
subgraph = self.__environment.network.subgraph(known_nodes)
|
||||
# pos = nx.fruchterman_reingold_layout(subgraph)
|
||||
pos = nx.shell_layout(subgraph, [[known_nodes[0]], known_nodes[1:]])
|
||||
|
||||
def edge_text(source: model.NodeID, target: model.NodeID) -> str:
|
||||
data = self.__environment.network.get_edge_data(source, target)
|
||||
name: str = data['kind'].name
|
||||
return name
|
||||
|
||||
color_map = {actions.EdgeAnnotation.LATERAL_MOVE: 'red',
|
||||
actions.EdgeAnnotation.REMOTE_EXPLOIT: 'orange',
|
||||
actions.EdgeAnnotation.KNOWS: 'gray'}
|
||||
|
||||
def edge_color(source: model.NodeID, target: model.NodeID) -> str:
|
||||
data = self.__environment.network.get_edge_data(source, target)
|
||||
if 'kind' in data:
|
||||
return color_map[data['kind']]
|
||||
return 'black'
|
||||
|
||||
layout: dict = dict(title="CyberBattle simulation", font=dict(size=10), showlegend=True,
|
||||
autosize=False, width=800, height=400,
|
||||
margin=go.layout.Margin(l=2, r=2, b=15, t=35),
|
||||
hovermode='closest',
|
||||
annotations=[dict(
|
||||
ax=pos[source][0],
|
||||
ay=pos[source][1], axref=xref, ayref=yref,
|
||||
x=pos[target][0],
|
||||
y=pos[target][1], xref=xref, yref=yref,
|
||||
arrowcolor=edge_color(source, target),
|
||||
hovertext=edge_text(source, target),
|
||||
showarrow=True,
|
||||
arrowhead=1,
|
||||
arrowsize=1,
|
||||
arrowwidth=1,
|
||||
startstandoff=10,
|
||||
standoff=10,
|
||||
align='center',
|
||||
opacity=1
|
||||
) for (source, target) in subgraph.edges]
|
||||
)
|
||||
|
||||
owned_nodes_coordinates = [(i, c) for i, c in pos.items()
|
||||
if self.get_node_information(i).agent_installed]
|
||||
discovered_nodes_coordinates = [(i, c)
|
||||
for i, c in pos.items()
|
||||
if not self.get_node_information(i).agent_installed]
|
||||
|
||||
trace_owned_nodes = go.Scatter(
|
||||
x=[c[0] for i, c in owned_nodes_coordinates],
|
||||
y=[c[1] for i, c in owned_nodes_coordinates],
|
||||
mode='markers+text',
|
||||
name='owned',
|
||||
marker=dict(symbol='circle-dot',
|
||||
size=5,
|
||||
# green #0e9d00
|
||||
color='#D32F2E', # red
|
||||
line=dict(color='rgb(255,0,0)', width=8)
|
||||
),
|
||||
text=[i for i, c in owned_nodes_coordinates],
|
||||
hoverinfo='text',
|
||||
textposition="bottom center"
|
||||
)
|
||||
|
||||
trace_discovered_nodes = go.Scatter(
|
||||
x=[c[0] for i, c in discovered_nodes_coordinates],
|
||||
y=[c[1] for i, c in discovered_nodes_coordinates],
|
||||
mode='markers+text',
|
||||
name='discovered',
|
||||
marker=dict(symbol='circle-dot',
|
||||
size=5,
|
||||
color='#0e9d00', # green
|
||||
line=dict(color='rgb(0,255,0)', width=8)
|
||||
),
|
||||
text=[i for i, c in discovered_nodes_coordinates],
|
||||
hoverinfo='text',
|
||||
textposition="bottom center"
|
||||
)
|
||||
|
||||
dummy_scatter_for_edge_legend = [
|
||||
go.Scatter(
|
||||
x=[0], y=[0], mode="lines",
|
||||
line=dict(color=color_map[a]),
|
||||
name=a.name
|
||||
) for a in actions.EdgeAnnotation]
|
||||
|
||||
all_scatters = dummy_scatter_for_edge_legend + [trace_owned_nodes, trace_discovered_nodes]
|
||||
return (all_scatters, layout)
|
||||
|
||||
def plot_discovered_network(self) -> None:
|
||||
"""Plot the network graph with plotly"""
|
||||
fig = go.Figure()
|
||||
traces, layout = self.network_as_plotly_traces()
|
||||
for t in traces:
|
||||
fig.add_trace(t)
|
||||
fig.update_layout(layout)
|
||||
fig.show()
|
||||
|
||||
def get_node_information(self, node_id: model.NodeID) -> model.NodeInfo:
|
||||
"""Print node information"""
|
||||
return self.__environment.get_node(node_id)
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Unit tests for commandcontrol.py.
|
||||
|
||||
"""
|
||||
# pylint: disable=missing-function-docstring
|
||||
|
||||
from . import model, commandcontrol
|
||||
from ..samples.toyctf import toy_ctf as ctf
|
||||
|
||||
|
||||
def test_toyctf() -> None:
|
||||
# Use the C&C to exploit remote and local vulnerabilities in the toy CTF game
|
||||
network = model.create_network(ctf.nodes)
|
||||
env = model.Environment(network=network,
|
||||
vulnerability_library=dict([]),
|
||||
identifiers=ctf.ENV_IDENTIFIERS)
|
||||
command = commandcontrol.CommandControl(env)
|
||||
leak_website = command.run_attack('client', 'SearchEdgeHistory')
|
||||
assert leak_website
|
||||
github = command.run_remote_attack('client', 'Website', 'ScanPageContent')
|
||||
leaked_sas_url_outcome = command.run_remote_attack('client', 'GitHubProject', 'CredScanGitHistory')
|
||||
leaked_sas_url = commandcontrol.get_outcome_first_credential(leaked_sas_url_outcome)
|
||||
|
||||
blobwithflag = command.connect_and_infect('client', 'AzureStorage', 'HTTPS', leaked_sas_url)
|
||||
assert(blobwithflag is not False)
|
||||
|
||||
browsable_directory = command.run_remote_attack('client', 'Website', 'ScanPageSource')
|
||||
assert browsable_directory
|
||||
|
||||
outcome_mysqlleak = command.run_remote_attack('client', 'Website.Directory', 'NavigateWebDirectoryFurther')
|
||||
mysql_credential = commandcontrol.get_outcome_first_credential(outcome_mysqlleak)
|
||||
sharepoint_url = command.run_remote_attack('client', 'Website.Directory', 'NavigateWebDirectory')
|
||||
assert sharepoint_url
|
||||
|
||||
outcome_azure_ad = command.run_remote_attack('client', 'Sharepoint', 'ScanSharepointParentDirectory')
|
||||
azure_ad_credentials = commandcontrol.get_outcome_first_credential(outcome_azure_ad)
|
||||
|
||||
azure_vm_info = command.connect_and_infect('client', 'AzureResourceManager', 'HTTPS', azure_ad_credentials)
|
||||
assert(azure_vm_info is not False)
|
||||
|
||||
azure_resources = command.run_remote_attack('client', 'AzureResourceManager', 'ListAzureResources')
|
||||
assert azure_resources
|
||||
|
||||
directly_ssh_connected = command.connect_and_infect('client', 'AzureVM', 'SSH', mysql_credential)
|
||||
assert not directly_ssh_connected
|
||||
|
||||
sshd = command.connect_and_infect('client', 'Website', 'SSH', mysql_credential)
|
||||
assert sshd is not False
|
||||
|
||||
outcome = command.run_attack('Website', 'CredScanBashHistory')
|
||||
monitor_bash_breds = commandcontrol.get_outcome_first_credential(outcome)
|
||||
|
||||
connected_as_monitor = command.connect_and_infect('Website', 'Website[user=monitor]', 'sudo', monitor_bash_breds)
|
||||
assert not connected_as_monitor
|
||||
|
||||
connected_as_monitor_from_client = command.connect_and_infect(
|
||||
'client', 'Website[user=monitor]', 'SSH', monitor_bash_breds)
|
||||
assert not connected_as_monitor_from_client
|
||||
|
||||
flag = command.connect_and_infect('Website', 'Website[user=monitor]', 'su', monitor_bash_breds)
|
||||
assert flag is not False
|
||||
|
||||
outcome_azuread = command.run_attack('Website[user=monitor]', 'CredScan-HomeDirectory')
|
||||
azure_ad_user_credential = commandcontrol.get_outcome_first_credential(outcome_azuread)
|
||||
|
||||
secrets = command.connect_and_infect('client', 'AzureResourceManager', 'HTTPS',
|
||||
azure_ad_user_credential)
|
||||
assert secrets is not False
|
||||
|
||||
reward = command.total_reward()
|
||||
print('Total reward ' + str(reward))
|
||||
assert reward == 759.0
|
||||
assert github is not None
|
||||
pass
|
|
@ -0,0 +1,300 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
environment_generation.py this function generates a semi random environment for
|
||||
the loonshot simulation v0.
|
||||
"""
|
||||
from typing import List, Dict, Set
|
||||
import random
|
||||
import re
|
||||
import networkx as nx
|
||||
from . import model
|
||||
|
||||
# These two lists are lists of potential vulnerabilities. They are split into linux vulnerabilities
|
||||
# and Windows vulnerabilities so i can
|
||||
|
||||
ADMINTAG = model.AdminEscalation().tag
|
||||
SYSTEMTAG = model.SystemEscalation().tag
|
||||
|
||||
potential_windows_vulns = {
|
||||
"UACME43":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #43",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&(Win10|Win7)&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
"UACME45":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #45",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
"UACME52":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #52",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&(Win10|Win7)&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
"UACME55":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #55",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&(Win10|Win7)&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
"UACME61":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #61",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
"MimikatzLogonpasswords":
|
||||
model.VulnerabilityInfo(
|
||||
description="Mimikatz sekurlsa::logonpasswords.",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/gentilkiwi/mimikatz",
|
||||
precondition=model.Precondition(f"Windows&({ADMINTAG}|{SYSTEMTAG})"),
|
||||
outcome=model.LeakedCredentials([]),
|
||||
rates=model.Rates(0, 1.0, 1.0)),
|
||||
"MimikatzKerberosExport":
|
||||
model.VulnerabilityInfo(
|
||||
description="Mimikatz Kerberos::list /export."
|
||||
"Exports .kirbi files to be used with pass the ticket",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/gentilkiwi/mimikatz",
|
||||
precondition=model.Precondition(f"Windows&DomainJoined&({ADMINTAG}|{SYSTEMTAG})"),
|
||||
outcome=model.LeakedCredentials([]),
|
||||
rates=model.Rates(0, 1.0, 1.0)),
|
||||
"PassTheTicket":
|
||||
model.VulnerabilityInfo(
|
||||
description="Mimikatz Kerberos::ptt /export."
|
||||
"Exports .kirbi files to be used with pass the ticket",
|
||||
type=model.VulnerabilityType.REMOTE,
|
||||
URL="https://github.com/gentilkiwi/mimikatz",
|
||||
precondition=model.Precondition(f"Windows&DomainJoined&KerberosTicketsDumped"
|
||||
f"&({ADMINTAG}|{SYSTEMTAG})"),
|
||||
outcome=model.LeakedCredentials([]),
|
||||
rates=model.Rates(0, 1.0, 1.0)),
|
||||
"RDPBF":
|
||||
model.VulnerabilityInfo(
|
||||
description="RDP Brute Force",
|
||||
type=model.VulnerabilityType.REMOTE,
|
||||
URL="https://attack.mitre.org/techniques/T1110/",
|
||||
precondition=model.Precondition("Windows&PortRDPOpen"),
|
||||
outcome=model.LateralMove(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
|
||||
"SMBBF":
|
||||
model.VulnerabilityInfo(
|
||||
description="SSH Brute Force",
|
||||
type=model.VulnerabilityType.REMOTE,
|
||||
URL="https://attack.mitre.org/techniques/T1110/",
|
||||
precondition=model.Precondition("(Windows|Linux)&PortSMBOpen"),
|
||||
outcome=model.LateralMove(),
|
||||
rates=model.Rates(0, 0.2, 1.0))
|
||||
}
|
||||
|
||||
potential_linux_vulns = {
|
||||
"SudoCaching":
|
||||
model.VulnerabilityInfo(
|
||||
description="Escalating privileges from poorly configured sudo on linux/unix machines",
|
||||
type=model.VulnerabilityType.REMOTE,
|
||||
URL="https://attack.mitre.org/techniques/T1206/",
|
||||
precondition=model.Precondition(f"Linux&(~{ADMINTAG})"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 1.0, 1.0)),
|
||||
"SSHBF":
|
||||
model.VulnerabilityInfo(
|
||||
description="SSH Brute Force",
|
||||
type=model.VulnerabilityType.REMOTE,
|
||||
URL="https://attack.mitre.org/techniques/T1110/",
|
||||
precondition=model.Precondition("Linux&PortSSHOpen"),
|
||||
outcome=model.LateralMove(),
|
||||
rates=model.Rates(0, 0.2, 1.0)),
|
||||
"SMBBF":
|
||||
model.VulnerabilityInfo(
|
||||
description="SSH Brute Force",
|
||||
type=model.VulnerabilityType.REMOTE,
|
||||
URL="https://attack.mitre.org/techniques/T1110/",
|
||||
precondition=model.Precondition("(Windows|Linux)&PortSMBOpen"),
|
||||
outcome=model.LateralMove(),
|
||||
rates=model.Rates(0, 0.2, 1.0))
|
||||
}
|
||||
|
||||
# These are potential endpoints that can be open in a game. Note to add any more endpoints simply
|
||||
# add the protocol name to this list.
|
||||
# further note that ports are stored in a tuple. This is because some protoocls
|
||||
# (like SMB) have multiple official ports.
|
||||
potential_ports: List[model.PortName] = ["RDP", "SSH", "HTTP", "HTTPs",
|
||||
"SMB", "SQL", "FTP", "WMI"]
|
||||
|
||||
# These two lists are potential node states. They are split into linux states and windows
|
||||
# states so that we can generate real graphs that aren't just totally random.
|
||||
potential_linux_node_states: List[model.PropertyName] = ["Linux", ADMINTAG,
|
||||
"PortRDPOpen",
|
||||
"PortHTTPOpen", "PortHTTPsOpen",
|
||||
"PortSSHOpen", "PortSMBOpen",
|
||||
"PortFTPOpen", "DomainJoined"]
|
||||
potential_windows_node_states: List[model.PropertyName] = ["Windows", "Win10", "PortRDPOpen",
|
||||
"PortHTTPOpen", "PortHTTPsOpen",
|
||||
"PortSSHOpen", "PortSMBOpen",
|
||||
"PortFTPOpen", "BITSEnabled",
|
||||
"Win7", "DomainJoined"]
|
||||
|
||||
ENV_IDENTIFIERS = model.Identifiers(
|
||||
ports=potential_ports,
|
||||
properties=potential_linux_node_states + potential_windows_node_states,
|
||||
local_vulnerabilities=list(potential_windows_vulns.keys()),
|
||||
remote_vulnerabilities=list(potential_windows_vulns.keys())
|
||||
)
|
||||
|
||||
|
||||
def create_random_environment(name: str, size: int) -> model.Environment:
|
||||
"""
|
||||
This is the create random environment function. It takes a string for the name
|
||||
of the environment and an int for the size. It returns a randomly genernated
|
||||
environment.
|
||||
|
||||
Note this does not currently support generating credentials.
|
||||
"""
|
||||
if not name:
|
||||
raise ValueError("Please supply a non empty string for the name")
|
||||
|
||||
if size < 1:
|
||||
raise ValueError("Please supply a positive non zero positive"
|
||||
"integer for the size of the environment")
|
||||
graph = nx.DiGraph()
|
||||
nodes: Dict[str, model.NodeInfo] = {}
|
||||
|
||||
# append the linux and windows vulnerability dictionaries
|
||||
local_vuln_lib: Dict[model.VulnerabilityID, model.VulnerabilityInfo] = \
|
||||
{**potential_windows_vulns, **potential_linux_vulns}
|
||||
|
||||
os_types: List[str] = ["Linux", "Windows"]
|
||||
for i in range(size):
|
||||
rand_os: str = os_types[random.randint(0, 1)]
|
||||
nodes[str(i)] = create_random_node(rand_os, potential_ports)
|
||||
|
||||
graph.add_nodes_from([(k, {'data': v}) for (k, v) in list(nodes.items())])
|
||||
|
||||
return model.Environment(network=graph, vulnerability_library=local_vuln_lib, identifiers=ENV_IDENTIFIERS)
|
||||
|
||||
|
||||
def create_random_node(os_type: str, end_points: List[model.PortName]) \
|
||||
-> model.NodeInfo:
|
||||
"""
|
||||
This is the create random node function.
|
||||
Currently it takes a string for the OS type and returns a NodeInfo object
|
||||
Options for OS type are currently Linux or Windows,
|
||||
Options for the role are Server or Workstation
|
||||
"""
|
||||
|
||||
if not end_points:
|
||||
raise ValueError("No endpoints supplied")
|
||||
|
||||
if os_type not in ("Windows", "Linux"):
|
||||
raise ValueError("Unsupported OS Type please enter Linux or Windows")
|
||||
|
||||
# get the vulnerability dictionary for the important OS
|
||||
vulnerabilities: model.VulnerabilityLibrary = dict([])
|
||||
if os_type == "Linux":
|
||||
vulnerabilities = \
|
||||
select_random_vulnerabilities(os_type, random.randint(1, len(potential_linux_vulns)))
|
||||
else:
|
||||
vulnerabilities = \
|
||||
select_random_vulnerabilities(os_type, random.randint(1, len(potential_windows_vulns)))
|
||||
|
||||
firewall: model.FirewallConfiguration = create_firewall_rules(end_points)
|
||||
properties: List[model.PropertyName] = \
|
||||
get_properties_from_vulnerabilities(os_type, vulnerabilities)
|
||||
|
||||
return model.NodeInfo(services=[model.ListeningService(name=p) for p in end_points],
|
||||
vulnerabilities=vulnerabilities,
|
||||
value=int(random.random()),
|
||||
properties=properties,
|
||||
firewall=firewall,
|
||||
agent_installed=False)
|
||||
|
||||
|
||||
def select_random_vulnerabilities(os_type: str, num_vulns: int) \
|
||||
-> Dict[str, model.VulnerabilityInfo]:
|
||||
"""
|
||||
It takes an a string for the OS type, and an int for the number of
|
||||
vulnerabilities to select.
|
||||
|
||||
It selects num_vulns vulnerabilities from the global list of vulnerabilities for that
|
||||
specific operating system. It returns a dictionary of VulnerabilityInfo objects to
|
||||
the caller.
|
||||
"""
|
||||
|
||||
if num_vulns < 1:
|
||||
raise ValueError("Expected a positive value for num_vulns in select_random_vulnerabilities")
|
||||
|
||||
ret_val: Dict[str, model.VulnerabilityInfo] = {}
|
||||
keys: List[str]
|
||||
if os_type == "Linux":
|
||||
keys = random.sample(potential_linux_vulns.keys(), num_vulns)
|
||||
ret_val = {k: potential_linux_vulns[k] for k in keys}
|
||||
elif os_type == "Windows":
|
||||
keys = random.sample(potential_windows_vulns.keys(), num_vulns)
|
||||
ret_val = {k: potential_windows_vulns[k] for k in keys}
|
||||
else:
|
||||
raise ValueError("Invalid Operating System supplied to select_random_vulnerabilities")
|
||||
return ret_val
|
||||
|
||||
|
||||
def get_properties_from_vulnerabilities(os_type: str,
|
||||
vulns: Dict[model.NodeID, model.VulnerabilityInfo]) \
|
||||
-> List[model.PropertyName]:
|
||||
"""
|
||||
get_properties_from_vulnerabilities function.
|
||||
This function takes a string for os_type and returns a list of PropertyName objects
|
||||
"""
|
||||
ret_val: Set[model.PropertyName] = set()
|
||||
properties: List[model.PropertyName] = []
|
||||
|
||||
if os_type == "Linux":
|
||||
properties = potential_linux_node_states
|
||||
elif os_type == "Windows":
|
||||
properties = potential_windows_node_states
|
||||
|
||||
for prop in properties:
|
||||
|
||||
for vuln_id, vuln in vulns.items():
|
||||
if re.search(prop, str(vuln.precondition.expression)):
|
||||
ret_val.add(prop)
|
||||
|
||||
return list(ret_val)
|
||||
|
||||
|
||||
def create_firewall_rules(end_points: List[model.PortName]) -> model.FirewallConfiguration:
|
||||
"""
|
||||
This function takes a List of endpoints and returns a FirewallConfiguration
|
||||
|
||||
It iterates through the list of potential ports and if they're in the list passed
|
||||
to the function it adds a firewall rule allowing that port.
|
||||
Otherwise it adds a rule blocking that port.
|
||||
"""
|
||||
|
||||
ret_val: model.FirewallConfiguration = model.FirewallConfiguration()
|
||||
ret_val.incoming.clear()
|
||||
ret_val.outgoing.clear()
|
||||
for protocol in potential_ports:
|
||||
if protocol in end_points:
|
||||
ret_val.incoming.append(model.FirewallRule(protocol, model.RulePermission.ALLOW))
|
||||
ret_val.outgoing.append(model.FirewallRule(protocol, model.RulePermission.ALLOW))
|
||||
else:
|
||||
ret_val.incoming.append(model.FirewallRule(protocol, model.RulePermission.BLOCK))
|
||||
ret_val.outgoing.append(model.FirewallRule(protocol, model.RulePermission.BLOCK))
|
||||
|
||||
return ret_val
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
The unit tests for the environment_generation functions
|
||||
"""
|
||||
from collections import Counter
|
||||
from typing import List, Dict
|
||||
import pytest
|
||||
from . import environment_generation
|
||||
from . import model
|
||||
|
||||
windows_vulns: Dict[str, model.VulnerabilityInfo] = environment_generation.potential_windows_vulns
|
||||
linux_vulns: Dict[str, model.VulnerabilityInfo] = environment_generation.potential_linux_vulns
|
||||
|
||||
windows_node_states: List[model.PropertyName] = environment_generation.potential_linux_node_states
|
||||
linux_node_states: List[model.PropertyName] = environment_generation.potential_linux_node_states
|
||||
|
||||
potential_ports: List[model.PortName] = environment_generation.potential_ports
|
||||
|
||||
|
||||
def test_create_random_environment() -> None:
|
||||
"""
|
||||
The unit tests for create_random_environment function
|
||||
"""
|
||||
with pytest.raises(ValueError, match=r"Please supply a non empty string for the name"):
|
||||
environment_generation.create_random_environment("", 2)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Please supply a positive non zero positive"
|
||||
r"integer for the size of the environment"):
|
||||
environment_generation.create_random_environment("Test_environment", -5)
|
||||
|
||||
result: model.Environment = environment_generation.\
|
||||
create_random_environment("Test_environment 2", 4)
|
||||
assert isinstance(result, model.Environment)
|
||||
|
||||
|
||||
def test_create_random_node() -> None:
|
||||
"""
|
||||
The unit tests for create_random_node() function
|
||||
"""
|
||||
|
||||
# check that the correct exceptions are generated
|
||||
with pytest.raises(ValueError, match=r"No endpoints supplied"):
|
||||
environment_generation.create_random_node("Linux", [])
|
||||
|
||||
with pytest.raises(ValueError, match=r"Unsupported OS Type please enter Linux or Windows"):
|
||||
environment_generation.create_random_node("Solaris", potential_ports)
|
||||
|
||||
test_node: model.NodeInfo = environment_generation.create_random_node("Linux", potential_ports)
|
||||
|
||||
assert isinstance(test_node, model.NodeInfo)
|
||||
|
||||
|
||||
def test_get_properties_from_vulnerabilities() -> None:
|
||||
"""
|
||||
This function tests the get_properties_from_vulnerabilities function
|
||||
It takes nothing and returns nothing.
|
||||
"""
|
||||
# testing on linux vulns
|
||||
props: List[model.PropertyName] = environment_generation.\
|
||||
get_properties_from_vulnerabilities("Linux", linux_vulns)
|
||||
assert "Linux" in props
|
||||
assert "PortSSHOpen" in props
|
||||
assert "PortSMBOpen" in props
|
||||
|
||||
# testing on Windows vulns
|
||||
windows_props: List[model.PropertyName] = environment_generation.get_properties_from_vulnerabilities(
|
||||
"Windows", windows_vulns)
|
||||
assert "Windows" in windows_props
|
||||
assert "PortRDPOpen" in windows_props
|
||||
assert "PortSMBOpen" in windows_props
|
||||
assert "DomainJoined" in windows_props
|
||||
assert "Win10" in windows_props
|
||||
assert "Win7" in windows_props
|
||||
|
||||
|
||||
def test_create_firewall_rules() -> None:
|
||||
"""
|
||||
This function tests the create_firewall_rules function.
|
||||
It takes nothing and returns nothing.
|
||||
"""
|
||||
empty_ports: List[model.PortName] = []
|
||||
potential_port_list: List[model.PortName] = ["RDP", "SSH", "HTTP", "HTTPs",
|
||||
"SMB", "SQL", "FTP", "WMI"]
|
||||
half_ports: List[model.PortName] = ["SSH", "HTTPs", "SQL", "FTP", "WMI"]
|
||||
all_blocked: List[model.FirewallRule] = [model.FirewallRule(
|
||||
port, model.RulePermission.BLOCK) for port in potential_port_list]
|
||||
all_allowed: List[model.FirewallRule] = [model.FirewallRule(
|
||||
port, model.RulePermission.ALLOW) for port in potential_port_list]
|
||||
half_allowed: List[model.FirewallRule] = [model.FirewallRule(port, model.RulePermission.ALLOW)
|
||||
if port in half_ports else model.FirewallRule(
|
||||
port, model.RulePermission.BLOCK) for
|
||||
port in potential_port_list]
|
||||
|
||||
# testing on an empty list should lead to
|
||||
results: model.FirewallConfiguration = environment_generation.create_firewall_rules(empty_ports)
|
||||
assert Counter(results.incoming) == Counter(all_blocked)
|
||||
assert Counter(results.outgoing) == Counter(all_blocked)
|
||||
# testing on a the list supported ports
|
||||
results = environment_generation.create_firewall_rules(potential_ports)
|
||||
assert Counter(results.incoming) == Counter(all_allowed)
|
||||
assert Counter(results.outgoing) == Counter(all_allowed)
|
||||
|
||||
results = environment_generation.create_firewall_rules(half_ports)
|
||||
assert Counter(results.incoming) == Counter(half_allowed)
|
||||
assert Counter(results.outgoing) == Counter(half_allowed)
|
|
@ -0,0 +1,287 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
""" Generating random graphs"""
|
||||
from cyberbattle.simulation.model import Identifiers, NodeID, CredentialID, PortName
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from cyberbattle.simulation import model as m
|
||||
import random
|
||||
from typing import List, Tuple, DefaultDict
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
ENV_IDENTIFIERS = Identifiers(
|
||||
properties=[
|
||||
'breach_node'
|
||||
],
|
||||
ports=['SMB', 'HTTP', 'RDP'],
|
||||
local_vulnerabilities=[
|
||||
'ScanWindowsCredentialManagerForRDP',
|
||||
'ScanWindowsExplorerRecentFiles',
|
||||
'ScanWindowsCredentialManagerForSMB'
|
||||
],
|
||||
remote_vulnerabilities=[
|
||||
'Traceroute'
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def generate_random_traffic_network(
|
||||
n_clients: int = 200,
|
||||
n_servers={
|
||||
"SMB": 1,
|
||||
"HTTP": 1,
|
||||
"RDP": 1,
|
||||
},
|
||||
seed: int = 0,
|
||||
tolerance: np.float32 = np.float32(1e-3),
|
||||
alpha=np.array([(0.1, 0.3), (0.18, 0.09)], dtype=float),
|
||||
beta=np.array([(100, 10), (10, 100)], dtype=float),
|
||||
) -> nx.DiGraph:
|
||||
"""
|
||||
Randomly generate a directed multi-edge network graph representing
|
||||
fictitious SMB, HTTP, and RDP traffic.
|
||||
|
||||
Arguments:
|
||||
n_clients: number of workstation nodes that can initiate sessions with server nodes
|
||||
n_servers: dictionary indicatin the numbers of each nodes listening to each protocol
|
||||
seed: seed for the psuedo-random number generator
|
||||
tolerance: absolute tolerance for bounding the edge probabilities in [tolerance, 1-tolerance]
|
||||
alpha: beta distribution parameters alpha such that E(edge prob) = alpha / beta
|
||||
beta: beta distribution parameters beta such that E(edge prob) = alpha / beta
|
||||
|
||||
Returns:
|
||||
(nx.classes.multidigraph.MultiDiGraph): the randomly generated network from the hierarchical block model
|
||||
"""
|
||||
edges_labels = defaultdict(set) # set backed multidict
|
||||
|
||||
for protocol in list(n_servers.keys()):
|
||||
sizes = [n_clients, n_servers[protocol]]
|
||||
# sample edge probabilities from a beta distribution
|
||||
np.random.seed(seed)
|
||||
probs: np.ndarray = np.random.beta(a=alpha, b=beta, size=(2, 2))
|
||||
# don't allow probs too close to zero or one
|
||||
probs = np.clip(probs, a_min=tolerance, a_max=np.float32(1.0 - tolerance))
|
||||
|
||||
# scale by edge type
|
||||
if protocol == "SMB":
|
||||
probs = 3 * probs
|
||||
if protocol == "RDP":
|
||||
probs = 4 * probs
|
||||
|
||||
# sample edges using block models given edge probabilities
|
||||
di_graph_for_protocol = nx.stochastic_block_model(
|
||||
sizes=sizes, p=probs, directed=True, seed=seed)
|
||||
|
||||
for edge in di_graph_for_protocol.edges:
|
||||
edges_labels[edge].add(protocol)
|
||||
|
||||
digraph = nx.DiGraph()
|
||||
for (u, v), port in list(edges_labels.items()):
|
||||
digraph.add_edge(u, v, protocol=port)
|
||||
return digraph
|
||||
|
||||
|
||||
def cyberbattle_model_from_traffic_graph(
|
||||
traffic_graph: nx.DiGraph,
|
||||
cached_smb_password_probability=0.75,
|
||||
cached_rdp_password_probability=0.8,
|
||||
cached_accessed_network_shares_probability=0.6,
|
||||
cached_password_has_changed_probability=0.1,
|
||||
traceroute_discovery_probability=0.5,
|
||||
probability_two_nodes_use_same_password_to_access_given_resource=0.8
|
||||
) -> nx.graph.Graph:
|
||||
"""Generate a random CyberBattle network model from a specified traffic (directed multi) graph.
|
||||
|
||||
The input graph can for instance be generated with `generate_random_traffic_network`.
|
||||
Each edge of the input graph indicates that a communication took place
|
||||
between the two nodes with the protocol specified in the edge label.
|
||||
|
||||
Returns a CyberBattle network with the same nodes and implanted vulnerabilities
|
||||
to be used to instantiate a CyverBattleSim gym.
|
||||
|
||||
Arguments:
|
||||
|
||||
cached_smb_password_probability, cached_rdp_password_probability:
|
||||
probability that a password used for authenticated traffic was cached by the OS for SMB and RDP
|
||||
cached_accessed_network_shares_probability:
|
||||
probability that a network share accessed by the system was cached by the OS
|
||||
cached_password_has_changed_probability:
|
||||
probability that a given password cached on a node has been rotated on the target node
|
||||
(typically low has people tend to change their password infrequently)
|
||||
probability_two_nodes_use_same_password_to_access_given_resource:
|
||||
as the variable name says
|
||||
traceroute_discovery_probability:
|
||||
probability that a target node of an SMB/RDP connection get exposed by a traceroute attack
|
||||
"""
|
||||
# convert node IDs to string
|
||||
graph = nx.relabel_nodes(traffic_graph, {i: str(i) for i in traffic_graph.nodes})
|
||||
|
||||
password_counter: int = 0
|
||||
|
||||
def generate_password() -> CredentialID:
|
||||
nonlocal password_counter
|
||||
password_counter = password_counter + 1
|
||||
return f'unique_pwd{password_counter}'
|
||||
|
||||
def traffic_targets(source_node: NodeID, protocol: str) -> List[NodeID]:
|
||||
neighbors = [t for (s, t) in graph.edges()
|
||||
if s == source_node and protocol in graph.edges[(s, t)]['protocol']]
|
||||
return neighbors
|
||||
|
||||
# Map (node, port name) -> assigned pwd
|
||||
assigned_passwords: DefaultDict[Tuple[NodeID, PortName],
|
||||
List[CredentialID]] = defaultdict(list)
|
||||
|
||||
def assign_new_valid_password(node: NodeID, port: PortName) -> CredentialID:
|
||||
pwd = generate_password()
|
||||
assigned_passwords[node, port].append(pwd)
|
||||
return pwd
|
||||
|
||||
def reuse_valid_password(node: NodeID, port: PortName) -> CredentialID:
|
||||
"""Reuse a password already assigned to that node an port, if none is already
|
||||
assigned create and assign a new valid password"""
|
||||
if (node, port) not in assigned_passwords:
|
||||
return assign_new_valid_password(node, port)
|
||||
|
||||
# reuse any of the existing assigne valid password for that node/port
|
||||
return random.choice(assigned_passwords[node, port])
|
||||
|
||||
def create_cached_credential(node: NodeID, port: PortName) -> CredentialID:
|
||||
if random.random() < cached_password_has_changed_probability:
|
||||
# generate a new invalid password
|
||||
return generate_password()
|
||||
else:
|
||||
if random.random() < probability_two_nodes_use_same_password_to_access_given_resource:
|
||||
return reuse_valid_password(node, port)
|
||||
else:
|
||||
return assign_new_valid_password(node, port)
|
||||
|
||||
def add_leak_neighbors_vulnerability(
|
||||
node_id: m.NodeID,
|
||||
library: m.VulnerabilityLibrary = {}) -> m.VulnerabilityLibrary:
|
||||
"""Create random vulnerabilities
|
||||
that reveals immediate traffic neighbors from a given node"""
|
||||
|
||||
rdp_neighbors = traffic_targets(node_id, 'RDP')
|
||||
|
||||
if len(rdp_neighbors) > 0:
|
||||
library['ScanWindowsCredentialManagerForRDP'] = m.VulnerabilityInfo(
|
||||
description="Look for RDP credentials in the Windows Credential Manager",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node=target_node, port='RDP',
|
||||
credential=create_cached_credential(target_node, 'RDP'))
|
||||
for target_node in rdp_neighbors
|
||||
if random.random() < cached_rdp_password_probability
|
||||
]),
|
||||
reward_string="Discovered creds in the Windows Credential Manager",
|
||||
cost=2.0
|
||||
)
|
||||
|
||||
smb_neighbors = traffic_targets(node_id, 'SMB')
|
||||
|
||||
if len(smb_neighbors) > 0:
|
||||
library['ScanWindowsExplorerRecentFiles'] = m.VulnerabilityInfo(
|
||||
description="Look for network shares in the Windows Explorer Recent files",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedNodesId(
|
||||
[target_node
|
||||
for target_node in smb_neighbors
|
||||
if random.random() < cached_accessed_network_shares_probability
|
||||
]
|
||||
),
|
||||
reward_string="Windows Explorer Recent Files revealed network shares",
|
||||
cost=1.0
|
||||
)
|
||||
|
||||
library['ScanWindowsCredentialManagerForSMB'] = m.VulnerabilityInfo(
|
||||
description="Look for network credentials in the Windows Credential Manager",
|
||||
type=m.VulnerabilityType.LOCAL,
|
||||
outcome=m.LeakedCredentials(credentials=[
|
||||
m.CachedCredential(node=target_node, port='SMB',
|
||||
credential=create_cached_credential(target_node, 'SMB'))
|
||||
for target_node in smb_neighbors
|
||||
if random.random() < cached_smb_password_probability
|
||||
]),
|
||||
reward_string="Discovered SMB creds in the Windows Credential Manager",
|
||||
cost=2.0
|
||||
)
|
||||
|
||||
if len(smb_neighbors) > 0 and len(rdp_neighbors) > 0:
|
||||
library['Traceroute'] = m.VulnerabilityInfo(
|
||||
description="Attempt to discvover network nodes using Traceroute",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedNodesId(
|
||||
[target_node
|
||||
for target_node in smb_neighbors or rdp_neighbors
|
||||
if random.random() < traceroute_discovery_probability
|
||||
]
|
||||
),
|
||||
reward_string="Discovered new network nodes via traceroute",
|
||||
cost=5.0
|
||||
)
|
||||
|
||||
return library
|
||||
|
||||
def create_vulnerabilities_from_traffic_data(node_id: m.NodeID):
|
||||
return add_leak_neighbors_vulnerability(node_id=node_id)
|
||||
|
||||
# Pick a random node as the agent entry node
|
||||
entry_node_index = random.randrange(len(graph.nodes))
|
||||
entry_node_id, entry_node_data = list(graph.nodes(data=True))[entry_node_index]
|
||||
graph.nodes[entry_node_id].clear()
|
||||
graph.nodes[entry_node_id].update(
|
||||
{'data': m.NodeInfo(services=[],
|
||||
value=0,
|
||||
properties=["breach_node"],
|
||||
vulnerabilities=create_vulnerabilities_from_traffic_data(entry_node_id),
|
||||
agent_installed=True,
|
||||
reimagable=False)})
|
||||
|
||||
def create_node_data(node_id: m.NodeID):
|
||||
return m.NodeInfo(
|
||||
services=[m.ListeningService(name=port, allowedCredentials=assigned_passwords[(target_node, port)])
|
||||
for (target_node, port) in assigned_passwords.keys()
|
||||
if target_node == node_id
|
||||
],
|
||||
value=random.randint(0, 100),
|
||||
vulnerabilities=create_vulnerabilities_from_traffic_data(node_id),
|
||||
agent_installed=False)
|
||||
|
||||
for node in list(graph.nodes):
|
||||
if node != entry_node_id:
|
||||
graph.nodes[node].clear()
|
||||
graph.nodes[node].update({'data': create_node_data(node)})
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def new_environment():
|
||||
"""Create a new simulation environment based on
|
||||
a randomly generated network topology.
|
||||
|
||||
NOTE: the probabilities and parameter values used
|
||||
here for the statistical generative model
|
||||
were arbirarily picked. We recommend exploring different values for those parameters.
|
||||
"""
|
||||
traffic = generate_random_traffic_network(seed=1,
|
||||
n_clients=50,
|
||||
n_servers={
|
||||
"SMB": 15,
|
||||
"HTTP": 15,
|
||||
"RDP": 15,
|
||||
},
|
||||
alpha=[(1, 1), (0.2, 0.5)],
|
||||
beta=[(1000, 10), (10, 100)])
|
||||
network = cyberbattle_model_from_traffic_graph(
|
||||
traffic,
|
||||
cached_rdp_password_probability=0.8,
|
||||
cached_smb_password_probability=0.7,
|
||||
cached_accessed_network_shares_probability=0.8,
|
||||
cached_password_has_changed_probability=0.01,
|
||||
probability_two_nodes_use_same_password_to_access_given_resource=0.9)
|
||||
return m.Environment(network=network,
|
||||
vulnerability_library=dict([]),
|
||||
identifiers=ENV_IDENTIFIERS)
|
|
@ -0,0 +1,587 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Data model for the simulation environment.
|
||||
|
||||
The simulation environment is given by the directed graph
|
||||
formally defined by:
|
||||
|
||||
Node := NodeID x ListeningService[] x Value x Vulnerability[] x FirewallConfig
|
||||
Edge := NodeID x NodeID x PortName
|
||||
|
||||
where:
|
||||
- NodeID: string
|
||||
- ListeningService : Name x AllowedCredentials
|
||||
- AllowedCredentials : string[] # credential pair represented by just a
|
||||
string ID
|
||||
- Value : [0...100] # Intrinsic value of reaching this node
|
||||
- Vulnerability : VulnerabilityID x Type x Precondition x Outcome x Rates
|
||||
- VulnerabilityID : string
|
||||
- Rates : ProbingDetectionRate x ExploitDetectionRate x SuccessRate
|
||||
- FirewallConfig: {
|
||||
outgoing : FirwallRule[]
|
||||
incoming : FirwallRule [] }
|
||||
- FirewallRule: PortName x { ALLOW, BLOCK }
|
||||
"""
|
||||
|
||||
from datetime import datetime, time
|
||||
from typing import NamedTuple, List, Dict, Optional, Union, Tuple, Iterator
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
import matplotlib.pyplot as plt # type:ignore
|
||||
from enum import Enum, IntEnum
|
||||
import boolean
|
||||
import networkx as nx
|
||||
import yaml
|
||||
import random
|
||||
|
||||
VERSION_TAG = "0.1.0"
|
||||
|
||||
ALGEBRA = boolean.BooleanAlgebra()
|
||||
# These two lines define True as the dual of False and vice versa
|
||||
# it's necessary in order to make sure the simplify function in boolean.py
|
||||
# works correctly. See https://github.com/bastikr/boolean.py/issues/82
|
||||
ALGEBRA.TRUE.dual = type(ALGEBRA.FALSE)
|
||||
ALGEBRA.FALSE.dual = type(ALGEBRA.TRUE)
|
||||
|
||||
# Type alias for identifiers
|
||||
NodeID = str
|
||||
|
||||
# A unique identifier
|
||||
ID = str
|
||||
|
||||
# a (login,password/token) credential pair is abstracted as just a unique
|
||||
# string identifier
|
||||
CredentialID = str
|
||||
|
||||
# Intrinsic value of a reaching a given node in [0,100]
|
||||
NodeValue = int
|
||||
|
||||
|
||||
PortName = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListeningService:
|
||||
"""A service port on a given node accepting connection initiated
|
||||
with the specified allowed credentials """
|
||||
# Name of the port the service is listening to
|
||||
name: PortName
|
||||
# credential allowed to authenticate with the service
|
||||
allowedCredentials: List[CredentialID] = dataclasses.field(default_factory=list)
|
||||
# whether the service is running or stopped
|
||||
running: bool = True
|
||||
# Weight used to evaluate the cost of not running the service
|
||||
sla_weight = 1.0
|
||||
|
||||
|
||||
x = ListeningService(name='d')
|
||||
VulnerabilityID = str
|
||||
|
||||
# Probability rate
|
||||
Probability = float
|
||||
|
||||
# The name of a node property indicating the presence of a
|
||||
# service, component, feature or vulnerability on a given node.
|
||||
PropertyName = str
|
||||
|
||||
|
||||
class Rates(NamedTuple):
|
||||
"""Probabilities associated with a given vulnerability"""
|
||||
probingDetectionRate: Probability = 0.0
|
||||
exploitDetectionRate: Probability = 0.0
|
||||
successRate: Probability = 1.0
|
||||
|
||||
|
||||
class VulnerabilityType(Enum):
|
||||
"""Is the vulnerability exploitable locally or remotely?"""
|
||||
LOCAL = 1
|
||||
REMOTE = 2
|
||||
|
||||
|
||||
class PrivilegeLevel(IntEnum):
|
||||
"""Access privilege level on a given node"""
|
||||
NoAccess = 0
|
||||
LocalUser = 1
|
||||
Admin = 2
|
||||
System = 3
|
||||
MAXIMUM = 3
|
||||
|
||||
|
||||
def escalate(current_level, escalation_level: PrivilegeLevel) -> PrivilegeLevel:
|
||||
return PrivilegeLevel(max(int(current_level), int(escalation_level)))
|
||||
|
||||
|
||||
class VulnerabilityOutcome:
|
||||
"""Outcome of exploiting a given vulnerability"""
|
||||
|
||||
|
||||
class LateralMove(VulnerabilityOutcome):
|
||||
"""Lateral movement to the target node"""
|
||||
success: bool
|
||||
|
||||
|
||||
class CustomerData(VulnerabilityOutcome):
|
||||
"""Access customer data on target node"""
|
||||
|
||||
|
||||
class PrivilegeEscalation(VulnerabilityOutcome):
|
||||
"""Privilege escalation outcome"""
|
||||
|
||||
def __init__(self, level: PrivilegeLevel):
|
||||
self.level = level
|
||||
|
||||
@property
|
||||
def tag(self):
|
||||
"""Escalation tag that gets added to node properties when
|
||||
the escalation level is reached for that node"""
|
||||
return f"privilege_{self.level}"
|
||||
|
||||
|
||||
class SystemEscalation(PrivilegeEscalation):
|
||||
"""Escalation to SYSTEM privileges"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(PrivilegeLevel.System)
|
||||
|
||||
|
||||
class AdminEscalation(PrivilegeEscalation):
|
||||
"""Escalation to local administrator privileges"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(PrivilegeLevel.Admin)
|
||||
|
||||
|
||||
class ProbeSucceeded(VulnerabilityOutcome):
|
||||
"""Probing succeeded"""
|
||||
|
||||
def __init__(self, discovered_properties: List[PropertyName]):
|
||||
self.discovered_properties = discovered_properties
|
||||
|
||||
|
||||
class ProbeFailed(VulnerabilityOutcome):
|
||||
"""Probing failed"""
|
||||
|
||||
|
||||
class ExploitFailed(VulnerabilityOutcome):
|
||||
"""This is for situations where the exploit fails """
|
||||
|
||||
|
||||
class CachedCredential(NamedTuple):
|
||||
"""Encodes a machine-port-credential triplet"""
|
||||
node: NodeID
|
||||
port: PortName
|
||||
credential: CredentialID
|
||||
|
||||
|
||||
class LeakedCredentials(VulnerabilityOutcome):
|
||||
"""A set of credentials obtained by exploiting a vulnerability"""
|
||||
|
||||
credentials: List[CachedCredential]
|
||||
|
||||
def __init__(self, credentials: List[CachedCredential]):
|
||||
self.credentials = credentials
|
||||
|
||||
|
||||
class LeakedNodesId(VulnerabilityOutcome):
|
||||
"""A set of node IDs obtained by exploiting a vulnerability"""
|
||||
|
||||
def __init__(self, nodes: List[NodeID]):
|
||||
self.nodes = nodes
|
||||
|
||||
|
||||
VulnerabilityOutcomes = Union[
|
||||
LeakedCredentials, LeakedNodesId, PrivilegeEscalation, AdminEscalation,
|
||||
SystemEscalation, CustomerData, LateralMove, ExploitFailed]
|
||||
|
||||
|
||||
class AttackResult():
|
||||
"""The result of attempting a specific attack (either local or remote)"""
|
||||
success: bool
|
||||
expected_outcome: Union[VulnerabilityOutcomes, None]
|
||||
|
||||
|
||||
class Precondition:
|
||||
""" A predicate logic expression defining the condition under which a given
|
||||
feature or vulnerability is present or not.
|
||||
The symbols used in the expression refer to properties associated with
|
||||
the corresponding node.
|
||||
E.g. 'Win7', 'Server', 'IISInstalled', 'SQLServerInstalled',
|
||||
'AntivirusInstalled' ...
|
||||
"""
|
||||
|
||||
expression: boolean.Expression
|
||||
|
||||
def __init__(self, expression: Union[boolean.Expression, str]):
|
||||
if isinstance(expression, boolean.Expression):
|
||||
self.expression = expression
|
||||
else:
|
||||
self.expression = ALGEBRA.parse(expression)
|
||||
|
||||
|
||||
class VulnerabilityInfo(NamedTuple):
|
||||
"""Definition of a known vulnerability"""
|
||||
# an optional description of what the vulnerability is
|
||||
description: str
|
||||
# type of vulnerability
|
||||
type: VulnerabilityType
|
||||
# what happens when successfully exploiting the vulnerability
|
||||
outcome: VulnerabilityOutcome
|
||||
# a boolean expression over a node's properties determining if the
|
||||
# vulnerability is present or not
|
||||
precondition: Precondition = Precondition("true")
|
||||
# rates of success/failure associated with this vulnerability
|
||||
rates: Rates = Rates()
|
||||
# points to information about the vulnerability
|
||||
URL: str = ""
|
||||
# some cost associated with exploiting this vulnerability (e.g.
|
||||
# brute force more costly than dumping credentials)
|
||||
cost: float = 1.0
|
||||
# a string displayed when the vulnerability is successfully exploited
|
||||
reward_string: str = ""
|
||||
|
||||
|
||||
# A dictionary storing information about all supported vulnerabilities
|
||||
# or features supported by the simulation.
|
||||
# This is to be used as a global dictionary pre-populated before
|
||||
# starting the simulation and estimated from real-world data.
|
||||
VulnerabilityLibrary = Dict[VulnerabilityID, VulnerabilityInfo]
|
||||
|
||||
|
||||
class RulePermission(Enum):
|
||||
"""Determine if a rule is blocks or allows traffic"""
|
||||
ALLOW = 0
|
||||
BLOCK = 1
|
||||
|
||||
|
||||
class FirewallRule(NamedTuple):
|
||||
"""A firewall rule"""
|
||||
# A port name
|
||||
port: PortName
|
||||
# permission on this port
|
||||
permission: RulePermission
|
||||
# An optional reason for the block/allow rule
|
||||
reason: str = ""
|
||||
|
||||
|
||||
class FirewallConfiguration(NamedTuple):
|
||||
"""Firewall configuration on a given node.
|
||||
Determine if traffic should be allowed or specifically blocked
|
||||
on a given port for outgoing and incoming traffic.
|
||||
The rules are process in order: the first rule matching a given
|
||||
port is applied and the rest are ignored.
|
||||
|
||||
Port that are not listed in the configuration
|
||||
are assumed to be blocked. (Adding an explicit block rule
|
||||
can still be useful to give a reason for the block.)
|
||||
"""
|
||||
outgoing: List[FirewallRule] = [
|
||||
FirewallRule("RDP", RulePermission.ALLOW),
|
||||
FirewallRule("SSH", RulePermission.ALLOW),
|
||||
FirewallRule("HTTPS", RulePermission.ALLOW),
|
||||
FirewallRule("HTTP", RulePermission.ALLOW)]
|
||||
incoming: List[FirewallRule] = [
|
||||
FirewallRule("RDP", RulePermission.ALLOW),
|
||||
FirewallRule("SSH", RulePermission.ALLOW),
|
||||
FirewallRule("HTTPS", RulePermission.ALLOW),
|
||||
FirewallRule("HTTP", RulePermission.ALLOW)]
|
||||
|
||||
|
||||
class MachineStatus(Enum):
|
||||
"""Machine running status"""
|
||||
Stopped = 0
|
||||
Running = 1
|
||||
Imaging = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeInfo:
|
||||
"""A computer node in the enterprise network"""
|
||||
# List of port/protocol the node is listening to
|
||||
services: List[ListeningService]
|
||||
# List of known vulnerabilities for the node
|
||||
vulnerabilities: VulnerabilityLibrary = dataclasses.field(default_factory=dict)
|
||||
# Intrinsic value of the node (translates into a reward if the node gets owned)
|
||||
value: NodeValue = 0
|
||||
# Properties of the nodes, some of which can imply further vulnerabilities
|
||||
properties: List[PropertyName] = dataclasses.field(default_factory=list)
|
||||
# Fireall configuration of the node
|
||||
firewall: FirewallConfiguration = FirewallConfiguration()
|
||||
# Attacker agent installed on the node? (aka the node is 'pwned')
|
||||
agent_installed: bool = False
|
||||
# Esclation level
|
||||
privilege_level: PrivilegeLevel = PrivilegeLevel.NoAccess
|
||||
# Can the node be re-imaged by a defender agent?
|
||||
reimagable: bool = True
|
||||
# Last time the node was reimaged
|
||||
last_reimaging: Optional[time] = None
|
||||
# String displayed when the node gets owned
|
||||
owned_string: str = ""
|
||||
# Machine status: running or stopped
|
||||
status = MachineStatus.Running
|
||||
# Relative node weight used to calculate the cost of stopping this machine
|
||||
# or its services
|
||||
sla_weight: float = 1.0
|
||||
|
||||
|
||||
class Identifiers(NamedTuple):
|
||||
"""Define the global set of identifiers used
|
||||
in the definition of a given environment.
|
||||
Such set defines a common vocabulary possibly
|
||||
shared across multiple environments, thus
|
||||
ensuring a consistent numbering convention
|
||||
that a machine learniong model can learn from."""
|
||||
# Array of all possible node property identifiers
|
||||
properties: List[PropertyName] = []
|
||||
# Array of all possible port names
|
||||
ports: List[PortName] = []
|
||||
# Array of all possible local vulnerabilities names
|
||||
local_vulnerabilities: List[VulnerabilityID] = []
|
||||
# Array of all possible remote vulnerabilities names
|
||||
remote_vulnerabilities: List[VulnerabilityID] = []
|
||||
|
||||
|
||||
def iterate_network_nodes(network: nx.graph.Graph) -> Iterator[Tuple[NodeID, NodeInfo]]:
|
||||
"""Iterates over the nodes in the network"""
|
||||
for nodeid, nodevalue in network.nodes.items():
|
||||
node_data: NodeInfo = nodevalue['data']
|
||||
yield nodeid, node_data
|
||||
|
||||
|
||||
class Environment(NamedTuple):
|
||||
""" The static graph defining the network of computers """
|
||||
network: nx.graph.Graph
|
||||
vulnerability_library: VulnerabilityLibrary
|
||||
identifiers: Identifiers
|
||||
creationTime: datetime = datetime.utcnow()
|
||||
lastModified: datetime = datetime.utcnow()
|
||||
# a version tag indicating the environment schema version
|
||||
version: str = VERSION_TAG
|
||||
|
||||
def nodes(self) -> Iterator[Tuple[NodeID, NodeInfo]]:
|
||||
"""Iterates over the nodes in the network"""
|
||||
return iterate_network_nodes(self.network)
|
||||
|
||||
def get_node(self, node_id: NodeID) -> NodeInfo:
|
||||
"""Retrieve info for the node with the specified ID"""
|
||||
node_info: NodeInfo = self.network.nodes[node_id]['data']
|
||||
return node_info
|
||||
|
||||
def plot_environment_graph(self) -> None:
|
||||
"""Plot the full environment graph"""
|
||||
nx.draw(self.network,
|
||||
with_labels=True,
|
||||
node_color=[n['data'].value
|
||||
for i, n in
|
||||
self.network.nodes.items()],
|
||||
cmap=plt.cm.Oranges) # type:ignore
|
||||
|
||||
|
||||
def create_network(nodes: Dict[NodeID, NodeInfo]) -> nx.DiGraph:
|
||||
"""Create a network with a set of nodes and no edges"""
|
||||
graph = nx.DiGraph()
|
||||
graph.add_nodes_from([(k, {'data': v}) for (k, v) in list(nodes.items())])
|
||||
return graph
|
||||
|
||||
# Helpers to infer constants from an environment
|
||||
|
||||
|
||||
def collect_ports_from_vuln(vuln: VulnerabilityInfo) -> List[PortName]:
|
||||
"""Returns all the port named referenced in a given vulnerability"""
|
||||
if isinstance(vuln.outcome, LeakedCredentials):
|
||||
return [c.port for c in vuln.outcome.credentials]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def collect_vulnerability_ids_from_nodes_bytype(
|
||||
nodes: Iterator[Tuple[NodeID, NodeInfo]],
|
||||
global_vulnerabilities: VulnerabilityLibrary,
|
||||
type: VulnerabilityType) -> List[VulnerabilityID]:
|
||||
"""Collect and return all IDs of all vulnerability of the specified type
|
||||
that are referenced in a given set of nodes and vulnerability library
|
||||
"""
|
||||
return sorted(list({
|
||||
id
|
||||
for _, node_info in nodes
|
||||
for id, v in node_info.vulnerabilities.items()
|
||||
if v.type == type
|
||||
}.union(
|
||||
id
|
||||
for id, v in global_vulnerabilities.items()
|
||||
if v.type == type
|
||||
)))
|
||||
|
||||
|
||||
def collect_properties_from_nodes(nodes: Iterator[Tuple[NodeID, NodeInfo]]) -> List[PropertyName]:
|
||||
"""Collect and return sorted list of all property names used in a given set of nodes"""
|
||||
return sorted({
|
||||
p
|
||||
for _, node_info in nodes
|
||||
for p in node_info.properties
|
||||
})
|
||||
|
||||
|
||||
def collect_ports_from_nodes(
|
||||
nodes: Iterator[Tuple[NodeID, NodeInfo]],
|
||||
vulnerability_library: VulnerabilityLibrary) -> List[PortName]:
|
||||
"""Collect and return all port names used in a given set of nodes
|
||||
and global vulnerability library"""
|
||||
return sorted(list({
|
||||
port
|
||||
for _, v in vulnerability_library.items()
|
||||
for port in collect_ports_from_vuln(v)
|
||||
}.union({
|
||||
port
|
||||
for _, node_info in nodes
|
||||
for _, v in node_info.vulnerabilities.items()
|
||||
for port in collect_ports_from_vuln(v)
|
||||
}.union(
|
||||
{service.name
|
||||
for _, node_info in nodes
|
||||
for service in node_info.services}))))
|
||||
|
||||
|
||||
def collect_ports_from_environment(environment: Environment) -> List[PortName]:
|
||||
"""Collect and return all port names used in a given environment"""
|
||||
return collect_ports_from_nodes(environment.nodes(), environment.vulnerability_library)
|
||||
|
||||
|
||||
def infer_constants_from_nodes(
|
||||
nodes: Iterator[Tuple[NodeID, NodeInfo]],
|
||||
vulnerabilities: Dict[VulnerabilityID, VulnerabilityInfo]) -> Identifiers:
|
||||
"""Infer global environment constants from a given network"""
|
||||
return Identifiers(
|
||||
properties=collect_properties_from_nodes(nodes),
|
||||
ports=collect_ports_from_nodes(nodes, vulnerabilities),
|
||||
local_vulnerabilities=collect_vulnerability_ids_from_nodes_bytype(
|
||||
nodes, vulnerabilities, VulnerabilityType.LOCAL),
|
||||
remote_vulnerabilities=collect_vulnerability_ids_from_nodes_bytype(
|
||||
nodes, vulnerabilities, VulnerabilityType.REMOTE)
|
||||
)
|
||||
|
||||
|
||||
def infer_constants_from_network(
|
||||
network: nx.Graph,
|
||||
vulnerabilities: Dict[VulnerabilityID, VulnerabilityInfo]) -> Identifiers:
|
||||
"""Infer global environment constants from a given network"""
|
||||
return infer_constants_from_nodes(iterate_network_nodes(network), vulnerabilities)
|
||||
|
||||
|
||||
# Network creation
|
||||
|
||||
# A sample set of envrionment constants
|
||||
SAMPLE_IDENTIFIERS = Identifiers(
|
||||
ports=['RDP', 'SSH', 'SMB', 'HTTP', 'HTTPS', 'WMI', 'SQL'],
|
||||
properties=[
|
||||
'Windows', 'Linux', 'HyperV-VM', 'Azure-VM', 'Win7', 'Win10',
|
||||
'PortRDPOpen', 'GuestAccountEnabled']
|
||||
)
|
||||
|
||||
|
||||
def assign_random_labels(
|
||||
graph: nx.Graph,
|
||||
vulnerabilities: VulnerabilityLibrary = dict([]),
|
||||
identifiers: Identifiers = SAMPLE_IDENTIFIERS) -> nx.Graph:
|
||||
"""Create an envrionment network by randomly assigning node information
|
||||
(properties, firewall configuration, vulnerabilities)
|
||||
to the nodes of a given graph structure"""
|
||||
|
||||
# convert node IDs to string
|
||||
graph = nx.relabel_nodes(graph, {i: str(i) for i in graph.nodes})
|
||||
|
||||
def create_random_firewall_configuration() -> FirewallConfiguration:
|
||||
return FirewallConfiguration(
|
||||
outgoing=[
|
||||
FirewallRule(port=p, permission=RulePermission.ALLOW)
|
||||
for p in
|
||||
random.sample(
|
||||
identifiers.properties,
|
||||
k=random.randint(0, len(identifiers.properties)))],
|
||||
incoming=[
|
||||
FirewallRule(port=p, permission=RulePermission.ALLOW)
|
||||
for p in random.sample(
|
||||
identifiers.properties,
|
||||
k=random.randint(0, len(identifiers.properties)))])
|
||||
|
||||
def create_random_properties() -> List[PropertyName]:
|
||||
return list(random.sample(
|
||||
identifiers.properties,
|
||||
k=random.randint(0, len(identifiers.properties))))
|
||||
|
||||
def pick_random_global_vulnerabilities() -> VulnerabilityLibrary:
|
||||
count = random.random()
|
||||
return {k: v for (k, v) in vulnerabilities.items() if random.random() > count}
|
||||
|
||||
def add_leak_neighbors_vulnerability(library: VulnerabilityLibrary, node_id: NodeID) -> None:
|
||||
"""Create a vulnerability for each node that reveals its immediate neighbors"""
|
||||
neighbors = {t for (s, t) in graph.edges() if s == node_id}
|
||||
if len(neighbors) > 0:
|
||||
library['RecentlyAccessedMachines'] = VulnerabilityInfo(
|
||||
description="AzureVM info, including public IP address",
|
||||
type=VulnerabilityType.LOCAL,
|
||||
outcome=LeakedNodesId(list(neighbors)))
|
||||
|
||||
def create_random_vulnerabilities(node_id: NodeID) -> VulnerabilityLibrary:
|
||||
library = pick_random_global_vulnerabilities()
|
||||
add_leak_neighbors_vulnerability(library, node_id)
|
||||
return library
|
||||
|
||||
# Pick a random node as the agent entry node
|
||||
entry_node_index = random.randrange(len(graph.nodes))
|
||||
entry_node_id, entry_node_data = list(graph.nodes(data=True))[entry_node_index]
|
||||
graph.nodes[entry_node_id].clear()
|
||||
node_data = NodeInfo(services=[],
|
||||
value=0,
|
||||
properties=create_random_properties(),
|
||||
vulnerabilities=create_random_vulnerabilities(entry_node_id),
|
||||
firewall=create_random_firewall_configuration(),
|
||||
agent_installed=True,
|
||||
reimagable=False,
|
||||
privilege_level=PrivilegeLevel.Admin)
|
||||
graph.nodes[entry_node_id].update({'data': node_data})
|
||||
|
||||
def create_random_node_data(node_id: NodeID) -> NodeInfo:
|
||||
return NodeInfo(
|
||||
services=[],
|
||||
value=random.randint(0, 100),
|
||||
properties=create_random_properties(),
|
||||
vulnerabilities=create_random_vulnerabilities(node_id),
|
||||
firewall=create_random_firewall_configuration(),
|
||||
agent_installed=False,
|
||||
privilege_level=PrivilegeLevel.NoAccess)
|
||||
|
||||
for node in list(graph.nodes):
|
||||
if node != entry_node_id:
|
||||
graph.nodes[node].clear()
|
||||
graph.nodes[node].update({'data': create_random_node_data(node)})
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# Serialization
|
||||
|
||||
def setup_yaml_serializer() -> None:
|
||||
"""Setup a clean YAML formatter for object of type Environment.
|
||||
"""
|
||||
|
||||
yaml.add_representer(Precondition,
|
||||
lambda dumper, data: dumper.represent_scalar('!BooleanExpression',
|
||||
str(data.expression))) # type: ignore
|
||||
yaml.SafeLoader.add_constructor('!BooleanExpression',
|
||||
lambda loader, expression: Precondition(
|
||||
loader.construct_scalar(expression))) # type: ignore
|
||||
yaml.add_constructor('!BooleanExpression',
|
||||
lambda loader, expression:
|
||||
Precondition(loader.construct_scalar(expression))) # type: ignore
|
||||
|
||||
yaml.add_representer(VulnerabilityType,
|
||||
lambda dumper, data: dumper.represent_scalar('!VulnerabilityType',
|
||||
str(data.name))) # type: ignore
|
||||
|
||||
yaml.SafeLoader.add_constructor('!VulnerabilityType',
|
||||
lambda loader, expression: VulnerabilityType[
|
||||
loader.construct_scalar(expression)]) # type: ignore
|
||||
yaml.add_constructor('!VulnerabilityType',
|
||||
lambda loader, expression: VulnerabilityType[
|
||||
loader.construct_scalar(expression)]) # type: ignore
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Unit tests for model.py.
|
||||
|
||||
Note that model.py mainly provides the data modelling for the simulation,
|
||||
that is naked data types without members. There is therefore not much
|
||||
relevant unit testing that can be implemented at this stage.
|
||||
Once we add operations to generate and modify environments there
|
||||
will be more room for unit-testing.
|
||||
|
||||
"""
|
||||
# pylint: disable=missing-function-docstring
|
||||
|
||||
from cyberbattle.simulation.model import AdminEscalation, Identifiers, SystemEscalation
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from . import model
|
||||
|
||||
ADMINTAG = AdminEscalation().tag
|
||||
SYSTEMTAG = SystemEscalation().tag
|
||||
|
||||
vulnerabilities = {
|
||||
"UACME61":
|
||||
model.VulnerabilityInfo(
|
||||
description="UACME UAC bypass #61",
|
||||
type=model.VulnerabilityType.LOCAL,
|
||||
URL="https://github.com/hfiref0x/UACME",
|
||||
precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
|
||||
outcome=model.AdminEscalation(),
|
||||
rates=model.Rates(0, 0.2, 1.0))}
|
||||
|
||||
ENV_IDENTIFIERS = Identifiers(
|
||||
properties=[],
|
||||
ports=[],
|
||||
local_vulnerabilities=["UACME61"],
|
||||
remote_vulnerabilities=[]
|
||||
)
|
||||
|
||||
|
||||
# Verify that there is a unique node injected with the
|
||||
# attacker in a randomly generated graph
|
||||
def test_single_infected_node_initially() -> None:
|
||||
# create a random environment
|
||||
graph = nx.cubical_graph()
|
||||
graph = model.assign_random_labels(graph)
|
||||
env = model.Environment(network=graph,
|
||||
vulnerability_library=dict([]),
|
||||
identifiers=ENV_IDENTIFIERS)
|
||||
count = sum(1 for i in graph.nodes
|
||||
if env.get_node(i).agent_installed)
|
||||
assert count == 1
|
||||
return
|
||||
|
||||
|
||||
# ensures that an environment can successfully be serialized as yaml
|
||||
def test_environment_is_serializable() -> None:
|
||||
# create a random environment
|
||||
env = model.Environment(
|
||||
network=model.assign_random_labels(nx.cubical_graph()),
|
||||
version=model.VERSION_TAG,
|
||||
vulnerability_library=dict([]),
|
||||
identifiers=ENV_IDENTIFIERS,
|
||||
creationTime=datetime.utcnow(),
|
||||
lastModified=datetime.utcnow(),
|
||||
)
|
||||
# Dump the environment as Yaml
|
||||
_ = yaml.dump(env)
|
||||
assert True
|
||||
|
||||
|
||||
# Test random graph get_node_information
|
||||
def test_create_random_environment() -> None:
|
||||
graph = nx.cubical_graph()
|
||||
|
||||
graph = model.assign_random_labels(graph)
|
||||
|
||||
env = model.Environment(
|
||||
network=graph,
|
||||
vulnerability_library=vulnerabilities,
|
||||
identifiers=ENV_IDENTIFIERS
|
||||
)
|
||||
assert env
|
||||
pass
|
||||
|
||||
|
||||
def check_reserializing(object_to_serialize: object) -> None:
|
||||
"""Helper function to check that deserializing and serializing are inverse of each other"""
|
||||
serialized = yaml.dump(object_to_serialize)
|
||||
# print('Serialized: ' + serialized)
|
||||
deserialized = yaml.load(serialized, yaml.Loader)
|
||||
re_serialized = yaml.dump(deserialized)
|
||||
assert (serialized == re_serialized)
|
||||
|
||||
|
||||
def test_yaml_serialization_environment() -> None:
|
||||
"""Test Yaml serialization and deserialization for type Environment"""
|
||||
network = model.assign_random_labels(nx.cubical_graph())
|
||||
env = model.Environment(
|
||||
network=network,
|
||||
vulnerability_library=vulnerabilities,
|
||||
identifiers=model.infer_constants_from_network(network, vulnerabilities))
|
||||
|
||||
model.setup_yaml_serializer()
|
||||
|
||||
serialized = yaml.dump(env)
|
||||
assert (len(serialized) > 100)
|
||||
|
||||
check_reserializing(env)
|
||||
|
||||
|
||||
def test_yaml_serialization_precondition() -> None:
|
||||
"""Test Yaml serialization and deserialization for type Precondition"""
|
||||
model.setup_yaml_serializer()
|
||||
|
||||
precondition = model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))")
|
||||
check_reserializing(precondition)
|
||||
|
||||
deserialized = yaml.safe_load(yaml.dump(precondition))
|
||||
assert (precondition.expression == deserialized.expression)
|
||||
|
||||
|
||||
def test_yaml_serialization_vulnerabilitytype() -> None:
|
||||
"""Test Yaml serialization and deserialization for type VulnerabilityType"""
|
||||
model.setup_yaml_serializer()
|
||||
|
||||
object_to_serialize = model.VulnerabilityType.LOCAL
|
||||
check_reserializing(object_to_serialize)
|
После Ширина: | Высота: | Размер: 56 KiB |
После Ширина: | Высота: | Размер: 45 KiB |
После Ширина: | Высота: | Размер: 67 KiB |
После Ширина: | Высота: | Размер: 192 KiB |
После Ширина: | Высота: | Размер: 103 KiB |
После Ширина: | Высота: | Размер: 18 KiB |
После Ширина: | Высота: | Размер: 32 KiB |
После Ширина: | Высота: | Размер: 62 KiB |
После Ширина: | Высота: | Размер: 92 KiB |
После Ширина: | Высота: | Размер: 205 KiB |
После Ширина: | Высота: | Размер: 57 KiB |
После Ширина: | Высота: | Размер: 81 KiB |
|
@ -0,0 +1,12 @@
|
|||
.packages
|
||||
PackageRoot/*
|
||||
**/TestResults/*
|
||||
**/.vs/*
|
||||
**/bin/*
|
||||
**/obj/*
|
||||
msbuild.log
|
||||
*.user
|
||||
*.trace
|
||||
**/.nupkg/*
|
||||
**/.pkg/*
|
||||
**/packages/*
|
|
@ -0,0 +1,39 @@
|
|||
# Benchmark on baseline agent implementations
|
||||
|
||||
Results obtained on envrionment `CyberBattleChain10` using several algorithms: purely random search, Q-learning with epsilon-greedy, exploiting learned matrix only.
|
||||
|
||||
## Time to full ownership on `chain` network environment
|
||||
|
||||
Training plot showing duration taken to take over the entire network (Y-axis) across successive episodes (X-axis) for different attacker agent implementations (Tabular Q-learning, epsilon-greedy, Deep Q-Learning, Exploiting Q-function learnt from DQL).
|
||||
Lower number of iterations is better (the game is set to terminate when the attacker owns the entire network).
|
||||
Best attacker so far is the Deep Q-Learning
|
||||
|
||||
![image.png](.attachments/image-54d83b7b-65d1-4d6a-b0f6-d41b31460c81.png)
|
||||
|
||||
The next plot shows the cumulative reward function. Note that when once all the network nodes are owned there may still be additional rewards to be obtained by exploiting vulnerabilities on the owned nodes,
|
||||
but on this experiment we terminate the game as soon as the agent owns all the nodes. This explains why
|
||||
the DQL agent, which optimizes for network ownership, does not get to reach the maximum possible reward despite
|
||||
beating all the other agents. The gym's `done` function can easily re-configured to target a specific reward instead, in which case the DQL also beats the other agents.
|
||||
|
||||
![image.png](.attachments/image-f8f00fe7-466f-4d2b-aaee-dd20720854db.png)
|
||||
|
||||
|
||||
### Choice of features
|
||||
|
||||
With Q-learning, the best results were obtained when training on features that include network size-specific features, such as the number of discovered nodes (left). Good results are also obtained when using features that are not dependent on the network size (right).
|
||||
|
||||
| | Size agnostic features | Size-dependant features|
|
||||
|--|--|--|
|
||||
| Tabular Q vs Random | ![image.png](.attachments/image-9b0b1506-880b-43e5-91b0-75f2ce3fb032.png) | ![image.png](.attachments/image-41f45aa6-0af8-4c67-afec-24b1adc910c2.png) |
|
||||
|
||||
|
||||
## Transfer Learning on `chain` environment
|
||||
|
||||
This benchmark aims to measure the ability to learn a strategy from one environment and
|
||||
apply it to similar envrionments of a different size. We train the agent on an environment of size $x$ and evaluate it on an environment of size $y>x$.
|
||||
|
||||
As expected, using features that are proportional to the size of the environment (such as the number of nodes or number of number of credentials) did not provide best results. The agent fared better instead when using temporal features like a sliding-window of ports and node properties recently discovered.
|
||||
|
||||
| | Train on size 4 , evaluated on size 10 | - Train on size 10, evaluated on size 4 |
|
||||
|---|---|---|
|
||||
Tabular Q vs Random | ![image.png](.attachments/image-11d24066-875d-43ac-87cb-e91453688028.png) | ![image.png](.attachments/image-daf58e3d-a4a8-4810-8a5e-1c976a24b266.png) |
|
|
@ -0,0 +1,5 @@
|
|||
pandoc -o docs.tex quickintro.md benchmark.md ../README.md -s
|
||||
pdflatex docs.tex
|
||||
# \DeclareUnicodeCharacter{2500}{─}
|
||||
# \DeclareUnicodeCharacter{03F5}{ϵ}
|
||||
# pandoc -o docs.html quickintro.md benchmark.md ../README.md
|
|
@ -0,0 +1,190 @@
|
|||
# Quick introduction to CyberBattleSim
|
||||
|
||||
## What is it?
|
||||
|
||||
A high-level parameterizable model of enterprise networks that simulates the execution of attack and defense cyber-agent.
|
||||
A network topology and set of pre-defined vulneraiblities defines the arena on which the simulation is played.
|
||||
The attacker evolves in the network via lateral movements by exploiting existing vulneratiblities.
|
||||
The defender attempts to contain the attacker and evict it from the nework.
|
||||
CyberBattleSim offers an OpenAI Gym interface to its simulation to facilitate experimentation with Reinforcement Learning algorithms.
|
||||
|
||||
## Why a simulation environment?
|
||||
|
||||
Emulation runtime environments provide high fidelity and control: you can take existing code or binaries and run them directly in virtual machines running full-fledged Operating Systems and connected over a virtualized network, while giving access to the full system state. This comes at a performance cost, however.
|
||||
While simulated environments suffers from lack of realism, they tend to be Lightweight, Fast, Abstract, and more Controllable, which
|
||||
makes them more amenable to experimentations with Reinforcement Learning.
|
||||
Advantages of simulation include:
|
||||
|
||||
- Higher-level abstraction: we can modeled aspects of the system that matters to us, like application-level network communication versus packet-level network simulation. We can ignore low-level details if deemed uncessary (E.g., file system, registry).
|
||||
- Flexibility: Defining new machine sensors is straightforward (e.g., does not require low-level/driver code changes); we can restrict the action space to a manageable and relevant subset.
|
||||
- The global state is efficiently capturable, simplifying debugging and diagnosis.
|
||||
- A lightweight runtime footprint: run in memory on single machine/process.
|
||||
|
||||
We may explore the use of emulation technologies in the future, in particular for benchmarking purpose, as it would provide a more realistic assessment of the agents performances.
|
||||
|
||||
A design principle adopted is for the simulation to model just enough _complexity_ to represent attack techniques from the [MITRE matrix](https://attack.mitre.org/) while maintaining the _simplicity_ required to efficiently train an agent using Reinforcement Learning techniques.
|
||||
|
||||
On the attack-side, the current simulation focuses more particularly on the `Lateral movement` technique which are intrinsic to all post-breach attacks.
|
||||
|
||||
## How the simulation works
|
||||
|
||||
Let us go through a toy example and introduce how simulation works
|
||||
using RL terminology.
|
||||
Our network __environment__ is given by a directed annotated graph where
|
||||
nodes represent computers and edges represent knowledge of other nodes or communication taking place between nodes.
|
||||
![image.png](.attachments/image-377114ff-cdb7-4bee-88da-cac09640f661.png)
|
||||
Here you can see a toy example of network
|
||||
with machines running different OSes, software.
|
||||
Each machine has properties, a value, and suffers from pre-assigned vulnerabilities.
|
||||
Blue edges represent traffic running between nodes and
|
||||
are labelled by the communication protocol.
|
||||
![image.png](.attachments/image-9f950a75-2c63-457a-b109-56091f84711a.png)
|
||||
|
||||
There is a **single agent**: the attacker.
|
||||
|
||||
Initially, one node is infected (post-breach assumption)
|
||||
|
||||
Its **goal** is to maximize reward by discovering and 'owning' nodes
|
||||
in the network.
|
||||
|
||||
The environment is **partially observable**: the agent does not get to see all the
|
||||
nodes and edges of the network graph in advance.
|
||||
|
||||
Instead the attacker takes actions to gradually observe the environment. There are **three kinds of actions**
|
||||
offering a mix of exploitation and exploration capabilities
|
||||
to the agent:
|
||||
- perform a local attack,
|
||||
- perform a remote attack,
|
||||
- connect to other nodes.
|
||||
|
||||
The **reward** is a float represents the intrinsic value
|
||||
of a node (e.g., a SQL server has greater value than a test machine).
|
||||
|
||||
The attacker breaches into the network from the Win7 node
|
||||
on the left pointed by the fat orange arrow,
|
||||
- then proceeds with a lateral move to the Win8 node
|
||||
by exploiting a vulnerability in SMB,
|
||||
- then uses some cached credential to log into a Win7 machine,
|
||||
- exploits an IIS remote vulnerability to own the IIS server,
|
||||
- and finally uses leaked connection strings to get to the SQL DB.
|
||||
|
||||
|
||||
## What kind of vulnerabilities and attacks are supported by the simulation?
|
||||
|
||||
The simulation gym environment is parameterized by the network definition which consists of the underlying network graph itself together with description of supported vulnerabilities and the nodes were they are present.
|
||||
Since the simulation does not run any code, there is no way to actually implement vulnerabilities and exploits. Instead, we model each vulnerability abstractly by defining: a pre-condition determining if the a vulnerability is active on a given node; the probability that it can be successfully exploited by an attacker; and the side-effects of a successfful exploit. Each node has a set of assigned named-properties. The pre-condition is then expressed as a boolean expression over the set of possible node properties (or flags).
|
||||
|
||||
### Vulnerability outcomes
|
||||
Each vulnerability has a pre-defined outcome which may include:
|
||||
|
||||
- A leaked set of credentials;
|
||||
- A leaked reference to another node in the network;
|
||||
- Leaked information about a node (node properties);
|
||||
- Ownerhsip to a node;
|
||||
- Privilege escalation on the node.
|
||||
|
||||
Example of _remote_ vulnerabilities include:
|
||||
|
||||
- A SharePoint site exposing `ssh` credentials (but not necessarily the ID of the remote machine);
|
||||
- An `ssh` vulnerability granting access to the machine;
|
||||
- A github project leaking credentials in commit history;
|
||||
- A SharePoint site with file containing SAS token to storage account;
|
||||
|
||||
Examples of _local_ vulnerabilities:
|
||||
|
||||
- Extracting authentication token or credentials from a system cache;
|
||||
- Escalating to SYSTEM privileges;
|
||||
- Escalating to Administrator privileges.
|
||||
|
||||
Vulnerabilities can either be defined in-placed at the node level, or can be defined globally and activated by the pre-condition boolean expression.
|
||||
|
||||
|
||||
## Toy Example
|
||||
|
||||
Consider as a toy example, the 'Capture the flag' game played on the computer system depicted below:
|
||||
|
||||
![image.png](.attachments/image-8cfbbc68-6db1-42f2-867d-5502ff56c4b3.png)
|
||||
|
||||
Each graph node is a computing resource with implanted security flaws and vulnerabilities such as reused password, insucure passwords, leaked access tokens, misconfigured Access control, browsable direcotries, and so on. The goal of the attacker is to take ownership of critical nodes in the graph (e.g., Azure and Sharepoint resources). For simplicity we assume that no defender is present and that the game is fully static (no external events between two action of the attacker).
|
||||
|
||||
We formally defined this network in Python code at [toy_ctf.py](../cyberbattle/samples/toyctf/toy_ctf.py).
|
||||
Here is a snippet of the code showing how we define the noe `Website` with its properties, firewall configuration and implanted vulnerabilities:
|
||||
|
||||
```python
|
||||
nodes = {
|
||||
"Website": m.NodeInfo(
|
||||
services=[
|
||||
m.ListeningService("HTTPS"),
|
||||
m.ListeningService("SSH", allowedCredentials=[
|
||||
"ReusedMySqlCred-web"])],
|
||||
firewall=m.FirewallConfiguration(
|
||||
incoming=default_allow_rules,
|
||||
outgoing=default_allow_rules
|
||||
+ [
|
||||
m.FirewallRule("su", m.RulePermission.ALLOW),
|
||||
m.FirewallRule("sudo", m.RulePermission.ALLOW)]),
|
||||
value=100,
|
||||
properties=["MySql", "Ubuntu", "nginx/1.10.3"],
|
||||
owned_string="FLAG: Login using insecure SSH user/password",
|
||||
vulnerabilities=dict(
|
||||
ScanPageContent=m.VulnerabilityInfo(
|
||||
description="Website page content shows a link to GitHub repo",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedNodesId(["GitHubProject"]),
|
||||
reward_string="page content has a link to a Github project",
|
||||
cost=1.0
|
||||
),
|
||||
ScanPageSource=m.VulnerabilityInfo(
|
||||
description="Website page source contains refrence to"
|
||||
"browseable relative web directory",
|
||||
type=m.VulnerabilityType.REMOTE,
|
||||
outcome=m.LeakedNodesId(["Website.Directory"]),
|
||||
reward_string="Viewing the web page source reveals a URL to"
|
||||
"a .txt file and directory on the website",
|
||||
cost=1.0
|
||||
),
|
||||
...
|
||||
)
|
||||
),
|
||||
...
|
||||
```
|
||||
|
||||
### Interactive play with ToyCTF
|
||||
|
||||
You can play the simulation interactively using the Jupyter notebook located at [toyctf-blank.ipynb](notebooks/toyctf-blank.ipynb). Try the following commands:
|
||||
|
||||
```python
|
||||
env.plot_environment_graph()
|
||||
plot()
|
||||
|
||||
c2.run_attack('client', 'SearchEdgeHistory')
|
||||
plot()
|
||||
|
||||
c2.run_remote_attack('client', 'Website', 'ScanPageContent')
|
||||
plot()
|
||||
|
||||
c2.run_remote_attack('client', 'Website', 'ScanPageSource')
|
||||
plot()
|
||||
|
||||
c2.run_remote_attack('client', 'Website.Directory', 'NavigateWebDirectoryFurther')
|
||||
plot()
|
||||
|
||||
c2.connect_and_infect('client', 'Website', 'SSH', 'ReusedMySqlCred-web')
|
||||
plot()
|
||||
```
|
||||
The plot function displays the subset of the network explorsed so far. After a few attempts the explored network should looks like this:
|
||||
|
||||
![image.png](.attachments/image-4fd79f98-36b2-45ae-82e4-2631aacda090.png)
|
||||
|
||||
### Solution
|
||||
The fully solved game is provided in the notebook [toyctf-solved.ipynb](notebooks/toyctf-solved.ipynb).
|
||||
|
||||
### A random agent playing ToyCTF
|
||||
|
||||
The notebook [toyctf-random.ipynb](notebooks/toyctf-random.ipynb)
|
||||
runs a random agent on the ToyCTF enviornment: at each step, the agent
|
||||
picks an action at random from the action space.
|
||||
|
||||
Here is the result after a couple of thousands iterations:
|
||||
|
||||
![image.png](.attachments/image-cdb2b5e1-92f9-4a9e-af9f-b1a9bcae96a5.png)
|
|
@ -0,0 +1,36 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Look for the the required version of python, return its path in $PYTHON
|
||||
# and version in $PYTHONVER.
|
||||
#
|
||||
# Usage: source this file to set the correct default version of Python in the PATH:
|
||||
# source getpythonpath.sh
|
||||
#
|
||||
|
||||
PYTHON=`which python3.8`
|
||||
if [ -z "$PYTHON" ]; then
|
||||
PYTHON=`which python3`
|
||||
fi
|
||||
if [ -z "$PYTHON" ]; then
|
||||
PYTHON=`which python`
|
||||
fi
|
||||
|
||||
if [ -z "$PYTHON" ]; then
|
||||
echo "Could not located python interpreter: '$PYTHON'" >&2
|
||||
exit -1
|
||||
fi
|
||||
|
||||
PYTHONVER=`$PYTHON --version | cut -d' ' -f2`
|
||||
if [[ ! "$PYTHONVER" == "3.8."* ]]; then
|
||||
echo 'Version >=3.8 of Python is required' >&2
|
||||
exit
|
||||
else
|
||||
echo "Compatible version $PYTHONVER of Python detected at $PYTHON"
|
||||
fi
|
||||
|
||||
PYTHONPATH=$(dirname $PYTHON)
|
||||
|
||||
export PATH=$PYTHONPATH:$PATH
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
#.SYNOPSIS
|
||||
# Initialize a dev environment by installing all python dependencies on Windows.
|
||||
# Not supported anymore: use WSL-based Linux instead on Windows.
|
||||
param($installJupyterExtensions)
|
||||
|
||||
# Install pip
|
||||
$pipversion = $(py -m pip --version)
|
||||
if ($pipversion) {
|
||||
Write-Host "pip already installed"
|
||||
} else {
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
py.exe get-pip.py
|
||||
}
|
||||
|
||||
# Install virtualenv
|
||||
py -m pip install --user virtualenv
|
||||
|
||||
# Create a virtual environment
|
||||
py.exe -m venv venv
|
||||
|
||||
# Install all pip dependencies in the virtual environment
|
||||
& venv/Scripts/python.exe -m pip install -e .
|
||||
& venv/Scripts/python.exe -m pip install -e .[dev]
|
||||
|
||||
# Setup pre-commit to check every git commit
|
||||
& venv/Scripts/pre-commit.exe install -t pre-push
|
|
@ -0,0 +1,129 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
pushd "$(dirname "$0")"
|
||||
|
||||
UBUNTU_VERSION=$(lsb_release -rs)
|
||||
|
||||
verlte() {
|
||||
[ "$1" = "`echo -e "$1\n$2" | sort -V | head -n1`" ]
|
||||
}
|
||||
|
||||
verlte $UBUNTU_VERSION '20' && OLDER_UBUNTU=1 || OLDER_UBUNTU=0
|
||||
|
||||
if [ $OLDER_UBUNTU == 1 ]; then
|
||||
echo "Old version of Ubuntu detected ($UBUNTU_VERSION), will register an additional apt repository to install latest version of Python"
|
||||
ADD_PYTHON38_APTREPO=1
|
||||
else
|
||||
ADD_PYTHON38_APTREPO=0
|
||||
fi
|
||||
|
||||
if [ ""$AML_CloudName"" != "" ]; then
|
||||
echo "Running on AML machine: skipping venv creation by default"
|
||||
CREATE_VENV=0
|
||||
else
|
||||
CREATE_VENV=1
|
||||
fi
|
||||
|
||||
|
||||
while getopts "nr" opt; do
|
||||
case $opt in
|
||||
n)
|
||||
echo "skipping venv creation. Parameter: $OPTARG" >&2
|
||||
CREATE_VENV=0
|
||||
;;
|
||||
r)
|
||||
echo "Will add apt repo to install latest version of Python"
|
||||
ADD_PYTHON38_APTREPO=1
|
||||
;;
|
||||
\?)
|
||||
echo "Syntax: init.sh [-n]" >&2
|
||||
echo " -n skip creation of virtual environment" >&2
|
||||
echo " -r register required apt repository to install latest version of Python for older versions of Ubuntu (e.g. 16)" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
SUDO=''
|
||||
if (( $EUID != 0 )); then
|
||||
SUDO='sudo -E'
|
||||
fi
|
||||
|
||||
if [ ! -z "${VIRTUAL_ENV}" ]; then
|
||||
echo 'Running under virtual environment, skipping installation of global packages';
|
||||
else
|
||||
|
||||
|
||||
if [ "${ADD_PYTHON38_APTREPO}" == "1" ]; then
|
||||
echo 'Adding APT repo ppa:deadsnakes/ppa'
|
||||
$SUDO apt install software-properties-common -y
|
||||
$SUDO add-apt-repository ppa:deadsnakes/ppa -y
|
||||
fi
|
||||
|
||||
$SUDO apt update
|
||||
|
||||
# install nodejs 12.0 (required by pyright typechecker)
|
||||
curl -sL https://deb.nodesource.com/setup_12.x | $SUDO bash -
|
||||
|
||||
# install all apt-get dependencies
|
||||
if [ $OLDER_UBUNTU == 1 ]; then
|
||||
# exclude package not available on older ubuntu
|
||||
cat apt-requirements.txt | grep -v python3-distutils | $SUDO xargs apt-get install -y
|
||||
else
|
||||
cat apt-requirements.txt | $SUDO xargs apt-get install -y
|
||||
fi
|
||||
|
||||
# $SUDO npm -g upgrade node
|
||||
$SUDO npm install -g --unsafe-perm=true --allow-root npm
|
||||
|
||||
# Make sure that the desired version of python is used
|
||||
# in the rest of the script and when calling pyright to
|
||||
# generate stubs
|
||||
$SUDO update-alternatives --install /usr/bin/python python /usr/bin/python3.8 2
|
||||
$SUDO update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 2
|
||||
export PATH="/usr/bin:${PATH}"
|
||||
|
||||
# install pyright
|
||||
./pyright.sh --version
|
||||
|
||||
# installing orca required to export images with plotly
|
||||
./install-orca.sh
|
||||
|
||||
# install pip
|
||||
if [ ! -f "/tmp/get-pip.py" ]; then
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py
|
||||
fi
|
||||
|
||||
python --version
|
||||
|
||||
python /tmp/get-pip.py
|
||||
|
||||
if [ "${CREATE_VENV}" == "1" ]; then
|
||||
# Install virtualenv
|
||||
python -m pip install --user virtualenv
|
||||
|
||||
# Create a virtual environment
|
||||
python -m venv venv
|
||||
|
||||
source venv/bin/activate
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
./install-pythonpackages.sh
|
||||
|
||||
if [ "${CREATE_VENV}" == "1" ]; then
|
||||
# Add venv to jupyter notebook
|
||||
python -m ipykernel install --user --name=venv
|
||||
fi
|
||||
|
||||
# setup checks on every `git push`
|
||||
pre-commit install -t pre-push
|
||||
|
||||
./createstubs.sh
|
||||
|
||||
popd
|
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
echo 'Installing the Plotly orca dependency for plotly figure export'
|
||||
|
||||
SUDO=''
|
||||
if (( $EUID != 0 )); then
|
||||
SUDO='sudo -E'
|
||||
fi
|
||||
|
||||
|
||||
xargs -a apt-requirements-orca.txt $SUDO apt-get install
|
||||
|
||||
$SUDO npm install -g --unsafe-perm=true --allow-root electron@6.1.4 orca
|
|
@ -0,0 +1,14 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
set -e
|
||||
|
||||
. ./getpythonpath.sh
|
||||
|
||||
# Install python packages
|
||||
$PYTHON -m pip install --upgrade pip
|
||||
$PYTHON -m pip install wheel
|
||||
$PYTHON -m pip install -e .
|
||||
$PYTHON -m pip install -e .[dev]
|
|
@ -0,0 +1,219 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.\n",
|
||||
"\n",
|
||||
"# Chain network CyberBattle Gym played by a random agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Gym random agent attacking a chain-like network\n",
|
||||
"\n",
|
||||
"## Chain network\n",
|
||||
"We consider a computer network of Windows and Linux machines where each machine has vulnerability \n",
|
||||
"granting access to another machine as per the following pattern:\n",
|
||||
"\n",
|
||||
" Start ---> (Linux ---> Windows ---> ... Linux ---> Windows)* ---> Linux[Flag]\n",
|
||||
"\n",
|
||||
"The network is parameterized by the length of the central Linux-Windows chain.\n",
|
||||
"The start node leaks the credentials to connect to all other nodes:\n",
|
||||
"\n",
|
||||
"For each `XXX ---> Windows` section, the XXX node has:\n",
|
||||
" - a local vulnerability exposing the RDP password to the Windows machine\n",
|
||||
" - a bunch of other trap vulnerabilities (high cost with no outcome)\n",
|
||||
"For each `XXX ---> Linux` section,\n",
|
||||
" - the Windows node has a local vulnerability exposing the SSH password to the Linux machine\n",
|
||||
" - a bunch of other trap vulnerabilities (high cost with no outcome)\n",
|
||||
"\n",
|
||||
"The chain is terminated by one node with a flag (reward)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Benchmark\n",
|
||||
"The following plot shows the average and one standard deviation cumulative reward over time as a random agent attacks the network."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%HTML\n",
|
||||
"<img src=\"random_plot.png\" width=\"300\">"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import logging\n",
|
||||
"import gym\n",
|
||||
"from gym import spaces\n",
|
||||
"import numpy as np\n",
|
||||
"import networkx as nx\n",
|
||||
"import cyberbattle.simulation.actions as actions\n",
|
||||
"import cyberbattle._env.cyberbattle_env as cyberbattle_env\n",
|
||||
"import cyberbattle.agents.random_agent as random_agent\n",
|
||||
"import cyberbattle.samples.chainpattern.chainpattern as chainpattern\n",
|
||||
"import importlib\n",
|
||||
"importlib.reload(actions)\n",
|
||||
"importlib.reload(cyberbattle_env)\n",
|
||||
"importlib.reload(chainpattern)\n",
|
||||
"\n",
|
||||
"logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format=\"%(levelname)s: %(message)s\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# chainpattern.create_network_chain_link(2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env = gym.make('CyberBattleChain-v0', size=10, attacker_goal=None)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env.environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env.environment.network.nodes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env.action_space"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env.action_space.sample()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env.observation_space.sample()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"outputPrepend"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in range(100) : gym_env.sample_valid_action()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false,
|
||||
"tags": [
|
||||
"outputPrepend"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"random_agent.run_random_agent(1, 10000, gym_env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"o,r,d,i = gym_env.step(gym_env.sample_valid_action())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"o"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
После Ширина: | Высота: | Размер: 51 KiB |
|
@ -0,0 +1,102 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.\n",
|
||||
"\n",
|
||||
"# Capture the Flag Toy Example - Interactive (Human player)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This is a blank instantiaion of the Capture The Flag network to be play interactively by a human player (not via the gym envrionment).\n",
|
||||
"The interface exposed to the attacker is given by the following commands:\n",
|
||||
" - c2.print_all_attacks()\n",
|
||||
" - c2.run_attack(node, attack_id)\n",
|
||||
" - c2.run_remote_attack(source_node, target_node, attack_id)\n",
|
||||
" - c2.connect_and_infect(source_node, target_node, port_name, credential_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys, logging\n",
|
||||
"import cyberbattle.simulation.model as model\n",
|
||||
"import cyberbattle.simulation.commandcontrol as commandcontrol\n",
|
||||
"import cyberbattle.samples.toyctf.toy_ctf as ctf\n",
|
||||
"import plotly.offline as plo\n",
|
||||
"plo.init_notebook_mode(connected=True)\n",
|
||||
"logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=\"%(levelname)s: %(message)s\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network = model.create_network(ctf.nodes)\n",
|
||||
"env = model.Environment(network=network, vulnerability_library=dict([]), identifiers=ctf.ENV_IDENTIFIERS)\n",
|
||||
"env"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.plot_environment_graph()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"c2 = commandcontrol.CommandControl(env)\n",
|
||||
"dbg = commandcontrol.EnvironmentDebugging(c2)\n",
|
||||
"def plot():\n",
|
||||
" dbg.plot_discovered_network()\n",
|
||||
" c2.print_all_attacks()\n",
|
||||
"plot()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.2-final"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.\n",
|
||||
"\n",
|
||||
"# Random agent playing the Capture The Flag toy environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import logging\n",
|
||||
"import gym\n",
|
||||
"from gym import spaces\n",
|
||||
"import numpy as np\n",
|
||||
"import networkx as nx\n",
|
||||
"import cyberbattle.simulation.actions as actions\n",
|
||||
"import cyberbattle.simulation.commandcontrol as commandcontrol\n",
|
||||
"import cyberbattle._env.cyberbattle_env as cyberbattle_env\n",
|
||||
"import importlib\n",
|
||||
"importlib.reload(actions)\n",
|
||||
"importlib.reload(cyberbattle_env)\n",
|
||||
"importlib.reload(commandcontrol)\n",
|
||||
"\n",
|
||||
"logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=\"%(levelname)s: %(message)s\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### CyberBattle simulation\n",
|
||||
"- **Environment**: a network of nodes with assigned vulnerabilities/functionalities, value, and firewall configuration\n",
|
||||
"- **Action space**: local attack | remote attack | authenticated connection\n",
|
||||
"- **Observation**: effects of action on environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env = gym.make('CyberBattleToyCtf-v0')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env.environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env.action_space"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym_env.action_space.sample()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"scrolled": false,
|
||||
"tags": [
|
||||
"outputPrepend"
|
||||
]
|
||||
},
|
||||
"source": [
|
||||
"## A random agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false,
|
||||
"tags": [
|
||||
"outputPrepend"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i_episode in range(1):\n",
|
||||
" observation = gym_env.reset()\n",
|
||||
"\n",
|
||||
" total_reward = 0\n",
|
||||
"\n",
|
||||
" for t in range(5600):\n",
|
||||
" action = gym_env.sample_valid_action()\n",
|
||||
"\n",
|
||||
" observation, reward, done, info = gym_env.step(action)\n",
|
||||
" \n",
|
||||
" total_reward += reward\n",
|
||||
" \n",
|
||||
" if reward>0:\n",
|
||||
" print('####### rewarded action: {action}')\n",
|
||||
" print(f'total_reward={total_reward} reward={reward}')\n",
|
||||
" gym_env.render()\n",
|
||||
" \n",
|
||||
" if done:\n",
|
||||
" print(\"Episode finished after {} timesteps\".format(t+1))\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" gym_env.render()\n",
|
||||
"\n",
|
||||
"gym_env.close()\n",
|
||||
"print(\"simulation ended\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### End of simulation"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
#!/bin/bash
|
||||
# https://github.com/microsoft/pyright/blob/master/docs/ci-integration.md
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
PATH_TO_PYRIGHT=`which pyright`
|
||||
|
||||
ARGS=$@
|
||||
|
||||
vercomp () {
|
||||
if [[ $1 == $2 ]]
|
||||
then
|
||||
return 0
|
||||
fi
|
||||
local IFS=.
|
||||
local i ver1=($1) ver2=($2)
|
||||
# fill empty fields in ver1 with zeros
|
||||
for ((i=${#ver1[@]}; i<${#ver2[@]}; i++))
|
||||
do
|
||||
ver1[i]=0
|
||||
done
|
||||
for ((i=0; i<${#ver1[@]}; i++))
|
||||
do
|
||||
if [[ -z ${ver2[i]} ]]
|
||||
then
|
||||
# fill empty fields in ver2 with zeros
|
||||
ver2[i]=0
|
||||
fi
|
||||
if ((10#${ver1[i]} > 10#${ver2[i]}))
|
||||
then
|
||||
return 1
|
||||
fi
|
||||
if ((10#${ver1[i]} < 10#${ver2[i]}))
|
||||
then
|
||||
return 2
|
||||
fi
|
||||
done
|
||||
return 0
|
||||
}
|
||||
|
||||
# Node version check
|
||||
echo "Checking node version..."
|
||||
NODE_VERSION=`node -v | cut -d'v' -f2`
|
||||
MIN_NODE_VERSION="10.15.2"
|
||||
vercomp $MIN_NODE_VERSION $NODE_VERSION
|
||||
# 1 == gt
|
||||
if [[ $? -eq 1 ]]; then
|
||||
echo "Node version ${NODE_VERSION} too old, min expected is ${MIN_NODE_VERSION}, run:"
|
||||
echo " npm -g upgrade node"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
# Do we need to sudo?
|
||||
echo "Checking node_modules dir..."
|
||||
NODE_MODULES=`npm -g root`
|
||||
SUDO="sudo"
|
||||
if [ -w "$NODE_MODULES" ]; then
|
||||
SUDO="" #nop
|
||||
fi
|
||||
|
||||
# If we can't find pyright, install it.
|
||||
echo "Checking pyright exists..."
|
||||
if [ -z "$PATH_TO_PYRIGHT" ]; then
|
||||
echo "...installing pyright"
|
||||
${SUDO} npm install -g pyright
|
||||
else
|
||||
# already installed, upgrade to make sure it's current
|
||||
# this avoids a sudo on launch if we're already current
|
||||
echo "Checking pyright version..."
|
||||
CURRENT=`pyright --version | cut -d' ' -f2`
|
||||
REMOTE=`npm info pyright version`
|
||||
if [ "$CURRENT" != "$REMOTE" ]; then
|
||||
echo "...new version of pyright found, upgrading."
|
||||
${SUDO} npm upgrade -g pyright
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "done."
|
||||
pyright $ARGS
|
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
"exclude": [
|
||||
"**/.ipynb_checpoints",
|
||||
"**/__pycache__",
|
||||
"typings",
|
||||
"venv"
|
||||
],
|
||||
"ignore": [
|
||||
"cyberbattle/__init__.py",
|
||||
"cyberbattle/agents/__init__.py"
|
||||
],
|
||||
"reportMissingImports": true,
|
||||
"reportMissingTypeStubs": true,
|
||||
"enableTypeIgnoreComments": true,
|
||||
"reportUnusedFunction": "warning",
|
||||
"reportUnusedImport": "warning",
|
||||
"reportDuplicateImport": "warning",
|
||||
"reportUnnecessaryCast": "warning",
|
||||
"reportUndefinedVariable": "error",
|
||||
"strictParameterNoneValue": true,
|
||||
"strict": []
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
# resolve incompatibility with pytest
|
||||
attrs==19.1.0
|
||||
|
||||
# for vscode intellisense
|
||||
jedi
|
||||
pandas
|
||||
jupyter>=1.0.0
|
||||
|
||||
torch===1.5.1
|
||||
|
||||
# linter and formatter
|
||||
flake8
|
||||
pep8-naming
|
||||
autopep8
|
||||
pre-commit
|
||||
|
||||
pytest~=5.4.2
|
||||
|
||||
jupytext~=1.6.0
|
||||
|
||||
psutil==5.7.2
|
||||
|
||||
### type stubs
|
||||
numpy_stubs @ git+https://github.com/numpy/numpy-stubs/@master#egg=numpy_stubs
|
||||
|
||||
# Unfortunately, the official `data-science-types` package overwrites
|
||||
# the numpy stubs above, and their stubs are not as reliable.
|
||||
# So we use instead one of the project fork that disables the numpy stubs
|
||||
# See https://github.com/predictive-analytics-lab/data-science-types/issues/132
|
||||
# and https://github.com/microsoft/pyright/issues/861
|
||||
#data-science-types~=0.2.20
|
||||
data-science-types @ git+https://github.com/blumu/data-science-types/@blumu-patch-1
|
||||
|
||||
asciichart
|
||||
asciichartpy
|
|
@ -0,0 +1,11 @@
|
|||
gym~=0.17.3
|
||||
numpy==1.19.4
|
||||
boolean.py~=3.7
|
||||
networkx==2.4
|
||||
pyyaml~=5.3.1
|
||||
setuptools~=49.2.1
|
||||
matplotlib~=3.2.1
|
||||
plotly~=4.11.0
|
||||
tabulate~=0.8.7
|
||||
ordered_set==4.0.2
|
||||
progressbar2==3.51.4
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""A sample run of the CyberBattle simulation"""
|
||||
|
||||
import gym
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def main() -> int:
|
||||
'''Entry point if called as an executable'''
|
||||
|
||||
root = logging.getLogger()
|
||||
root.setLevel(logging.DEBUG)
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
handler.setFormatter(formatter)
|
||||
root.addHandler(handler)
|
||||
|
||||
env = gym.make('CyberBattleToyCtf-v0')
|
||||
|
||||
logging.info(env.action_space.sample())
|
||||
logging.info(env.observation_space.sample())
|
||||
|
||||
for i_episode in range(1):
|
||||
observation = env.reset()
|
||||
action_mask = env.compute_action_mask()
|
||||
total_reward = 0
|
||||
for t in range(500):
|
||||
env.render()
|
||||
|
||||
# sample a valid action
|
||||
action = env.action_space.sample()
|
||||
while not env.apply_mask(action_mask, action):
|
||||
action = env.action_space.sample()
|
||||
|
||||
print('action' + str(action))
|
||||
observation, reward, done, info = env.step(action)
|
||||
action_mask = observation['action_mask']
|
||||
total_reward = total_reward + reward
|
||||
# print(observation)
|
||||
print('total_reward=' + str(total_reward))
|
||||
if done:
|
||||
print("Episode finished after {} timesteps".format(t + 1))
|
||||
break
|
||||
|
||||
env.close()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,6 @@
|
|||
[flake8]
|
||||
ignore = W504,W503,E501,N813,N812,E741
|
||||
max-line-length = 200
|
||||
max-doc-length = 200
|
||||
exclude = typings, venv
|
||||
per-file-ignores = ./cyberbattle/simulation/model.py:N815 ./cyberbattle/agents/baseline/agent_wrapper.py:N801
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""setup CyberBattle simulator module"""
|
||||
import os
|
||||
import setuptools
|
||||
from typing import List
|
||||
|
||||
pwd = os.path.dirname(__file__)
|
||||
|
||||
|
||||
def get_install_requires(requirements_txt) -> List[str]:
|
||||
"""get the list of requried packages"""
|
||||
install_requires = []
|
||||
with open(os.path.join(pwd, requirements_txt)) as file:
|
||||
for line in file:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
install_requires.append(line)
|
||||
return install_requires
|
||||
|
||||
|
||||
# main setup kw args
|
||||
setup_kwargs = {
|
||||
'name': 'loonshot-sim',
|
||||
'version': '0.1',
|
||||
'description': "The simulation and RL code for the S+C loonshot project",
|
||||
'author': 'S+C Loonshot Team',
|
||||
'author_email': 'scloonshot@Microsoft.com',
|
||||
'install_requires': get_install_requires("requirements.txt"),
|
||||
'classifiers': [
|
||||
'Environment :: Other Environment',
|
||||
'Intended Audience :: Science/Research',
|
||||
'Natural Language :: English',
|
||||
'Operating System :: OS Independent',
|
||||
'Programming Language :: Python',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
],
|
||||
'zip_safe': True,
|
||||
'packages': setuptools.find_packages(exclude=['test_*.py', '*_test.py']),
|
||||
'extras_require': {
|
||||
'dev': get_install_requires('requirements.dev.txt')
|
||||
}
|
||||
}
|
||||
|
||||
if __name__ == '__main__':
|
||||
setuptools.setup(**setup_kwargs)
|