зеркало из https://github.com/Azure/counterfit.git
create test framework (#67)
Co-authored-by: Shiven Chawla <shivenchawla@microsoft.com>
This commit is contained in:
Родитель
411737e9f5
Коммит
e2acf81734
Двоичный файл не отображается.
|
@ -46,7 +46,6 @@ pip-delete-this-directory.txt
|
|||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
|
|
28
conftest.py
28
conftest.py
|
@ -0,0 +1,28 @@
|
|||
from collections import OrderedDict, namedtuple
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
import tests.utils.helpers as hp
|
||||
|
||||
@pytest.fixture(scope='function', autouse=True)
|
||||
def test_data(request, load_module_test_data):
|
||||
"""gets only required portion from module test data dictionary as test object"""
|
||||
return hp.extract_test_data(request.node.originalname, load_module_test_data)
|
||||
|
||||
@pytest.fixture(scope='module', autouse=True)
|
||||
def load_module_test_data(request, load_session_test_data):
|
||||
"""gets only required portion from session test data dictionary"""
|
||||
key = request.module.__name__
|
||||
if key not in load_session_test_data:
|
||||
return {}
|
||||
return load_session_test_data[request.module.__name__]
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
def load_session_test_data():
|
||||
"""loads test data from json as dictionary"""
|
||||
folder_path = os.path.abspath(os.path.dirname(__file__))
|
||||
folder = os.path.join(folder_path, *["tests", "test_data"])
|
||||
file = os.path.join(folder, "config.json")
|
||||
with open(file) as fp:
|
||||
data = json.load(fp)
|
||||
return data
|
|
@ -1,5 +0,0 @@
|
|||
target:
|
||||
creditfraud
|
||||
attacks:
|
||||
attack_1:
|
||||
attack_name: hop_skip_jump
|
|
@ -12,13 +12,12 @@ from counterfit.targets import CreditFraud, Digits, DigitKeras, SatelliteImages
|
|||
|
||||
from counterfit import Counterfit
|
||||
|
||||
|
||||
@pytest.fixture(params=[CreditFraud, Digits, DigitKeras, SatelliteImages])
|
||||
def target(request):
|
||||
yield request.param
|
||||
|
||||
# Evasion
|
||||
@pytest.fixture(params=['boundary', 'hop_skip_jump'])
|
||||
@pytest.fixture(params=['hop_skip_jump'])
|
||||
def attack(request):
|
||||
yield request.param
|
||||
|
||||
|
@ -29,16 +28,6 @@ def build_attack(target_obj: 'CFTarget', attack: str):
|
|||
return Counterfit.build_attack(target, attack)
|
||||
|
||||
|
||||
def test_boundary_credit():
|
||||
attack = build_attack(CreditFraud, 'boundary')
|
||||
# attack = build_attack(CreditFraud, 'hop_skip_jump')
|
||||
# >>>> run() attack estimator: BlackBoxClassifier(model=None, clip_values=None, preprocessing=StandardisationMeanStd(mean=0.0, std=1.0, apply_fit=True, apply_predict=True), preprocessing_defences=None, postprocessing_defences=None, preprocessing_operations=[StandardisationMeanStd(mean=0.0, std=1.0, apply_fit=True, apply_predict=True)], nb_classes=2, predict_fn=<bound method CFTarget.predict_wrapper of <counterfit.targets.creditfraud.CreditFraud object at 0x7f75f3fc22e0>>, input_shape=(30,))
|
||||
attack_did_succeed = Counterfit.run_attack(attack)
|
||||
|
||||
assert attack
|
||||
assert attack_did_succeed
|
||||
|
||||
|
||||
def test_attack(target, attack):
|
||||
cfattack = build_attack(target, attack)
|
||||
assert cfattack
|
||||
|
|
|
@ -10,15 +10,15 @@ warnings.filterwarnings("ignore", category=FutureWarning)
|
|||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
from targets import CreditFraud
|
||||
from targets import Digits
|
||||
from targets import DigitKeras
|
||||
from targets import SatelliteImages
|
||||
from tests.mocks.targets import CreditFraud
|
||||
from tests.mocks.targets import Digits
|
||||
from tests.mocks.targets import DigitKeras
|
||||
from tests.mocks.targets import SatelliteImages
|
||||
|
||||
from counterfit import Counterfit
|
||||
|
||||
|
||||
@pytest.fixture(params=[CreditFraud, Digits, DigitKeras])
|
||||
@pytest.fixture(params=[CreditFraud, DigitKeras])
|
||||
def target(request):
|
||||
yield request.param
|
||||
|
||||
|
|
|
@ -8,13 +8,13 @@ warnings.filterwarnings("ignore", category=FutureWarning)
|
|||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
from targets import CreditFraud
|
||||
from targets import Digits
|
||||
from targets import DigitKeras
|
||||
from tests.mocks.targets import CreditFraud
|
||||
from tests.mocks.targets import Digits
|
||||
from tests.mocks.targets import DigitKeras
|
||||
from counterfit import Counterfit
|
||||
|
||||
|
||||
@pytest.fixture(params=[CreditFraud, Digits, DigitKeras])
|
||||
@pytest.fixture(params=[DigitKeras])
|
||||
def target(request):
|
||||
yield request.param
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ warnings.filterwarnings("ignore", category=FutureWarning)
|
|||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
from targets import DigitKeras
|
||||
from tests.mocks.targets import DigitKeras
|
||||
|
||||
from counterfit import Counterfit
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@ warnings.filterwarnings('ignore')
|
|||
|
||||
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
|
||||
|
||||
from targets import Digits
|
||||
from targets import DigitKeras
|
||||
from tests.mocks.targets import Digits
|
||||
from tests.mocks.targets import DigitKeras
|
||||
from counterfit import Counterfit
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
from counterfit.core.targets import CFTarget
|
||||
|
||||
from .creditfraud import CreditFraud
|
||||
from .digits_mlp import Digits
|
||||
from .digits_keras import DigitKeras
|
||||
from .movie_reviews import MovieReviewsTarget
|
||||
from .satellite import SatelliteImages
|
|
@ -15,36 +15,35 @@ def test_frameworks():
|
|||
frameworks = Counterfit.get_frameworks()
|
||||
assert dict(frameworks)
|
||||
|
||||
def test_build_target():
|
||||
def test_build_target(test_data):
|
||||
target = Counterfit.build_target(
|
||||
data_type="images",
|
||||
endpoint="http://locahost/score",
|
||||
output_classes=["Cat", "NotACat"],
|
||||
classifier="closed-box",
|
||||
data_type=test_data.data_type,
|
||||
endpoint=test_data.endpoint,
|
||||
output_classes=test_data.output_classes,
|
||||
classifier=test_data.classifier,
|
||||
input_shape=(1,),
|
||||
load_func=load,
|
||||
predict_func=predict,
|
||||
X = [[1,0]]
|
||||
X = test_data.X
|
||||
)
|
||||
|
||||
assert isinstance(target, CFTarget)
|
||||
|
||||
|
||||
def test_build_attack():
|
||||
def test_build_attack(test_data):
|
||||
target = Counterfit.build_target(
|
||||
data_type="image",
|
||||
endpoint="http://locahost/score",
|
||||
output_classes=["Cat", "NotACat"],
|
||||
classifier="closed-box",
|
||||
data_type=test_data.data_type,
|
||||
endpoint=test_data.endpoint,
|
||||
output_classes=test_data.output_classes,
|
||||
classifier=test_data.classifier,
|
||||
input_shape=(1,),
|
||||
load_func=load,
|
||||
predict_func=predict,
|
||||
X = [[1,0]]
|
||||
X = test_data.X
|
||||
)
|
||||
|
||||
cfattack = Counterfit.build_attack(
|
||||
target=target,
|
||||
attack="hop_skip_jump"
|
||||
attack=test_data.attack
|
||||
)
|
||||
|
||||
assert isinstance(cfattack, CFAttack)
|
|
@ -0,0 +1,7 @@
|
|||
from counterfit.core.targets import CFTarget
|
||||
|
||||
from tests.mocks.targets.creditfraud import CreditFraud
|
||||
from tests.mocks.targets.digits_mlp import Digits
|
||||
from tests.mocks.targets.digits_keras import DigitKeras
|
||||
from tests.mocks.targets.movie_reviews import MovieReviewsTarget
|
||||
from tests.mocks.targets.satellite import SatelliteImages
|
|
@ -1,2 +0,0 @@
|
|||
def test_run():
|
||||
assert False
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"test_core": {
|
||||
"test_build_target": {
|
||||
"data_type": "images",
|
||||
"endpoint": "http://locahost/score",
|
||||
"output_classes": ["Cat", "NotACat"],
|
||||
"classifier": "closed-box",
|
||||
"X": [[1,0]]
|
||||
},
|
||||
"test_build_attack": {
|
||||
"data_type": "image",
|
||||
"endpoint": "http://locahost/score",
|
||||
"output_classes": ["Cat", "NotACat"],
|
||||
"classifier": "closed-box",
|
||||
"X": [[1,0]],
|
||||
"attack": "hop_skip_jump"
|
||||
}
|
||||
},
|
||||
"test_hop_skip_jump": {
|
||||
"test_attack": {
|
||||
"attack_targets": [
|
||||
{
|
||||
"attack": "hop_skip_jump",
|
||||
"target": "CreditFraud"
|
||||
},
|
||||
{
|
||||
"attack": "hop_skip_jump",
|
||||
"target": "DigitKeras"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
from collections import OrderedDict, namedtuple
|
||||
from tests.mocks.targets import CreditFraud
|
||||
from tests.mocks.targets import Digits
|
||||
from tests.mocks.targets import DigitKeras
|
||||
from tests.mocks.targets import SatelliteImages
|
||||
|
||||
def get_target_class(target: str):
|
||||
"""loads attack-target combinations from json as dictionary"""
|
||||
target_classes = {
|
||||
"CreditFraud": CreditFraud(),
|
||||
"Digits": Digits(),
|
||||
"DigitKeras": DigitKeras(),
|
||||
"SatelliteImages": SatelliteImages()
|
||||
}
|
||||
return target_classes[target]
|
||||
|
||||
def extract_test_data(key: str, load_module_test_data):
|
||||
"""gets only required portion from module test data dictionary as test object"""
|
||||
if key not in load_module_test_data:
|
||||
return
|
||||
return create_namedtuple_from_dict(load_module_test_data[key])
|
||||
|
||||
|
||||
def create_namedtuple_from_dict(obj):
|
||||
"""converts given list or dict to named tuples, generic alternative to dataclass"""
|
||||
if isinstance(obj, dict):
|
||||
fields = sorted(obj.keys())
|
||||
namedtuple_type = namedtuple(
|
||||
typename='test_data',
|
||||
field_names=fields,
|
||||
rename=True,
|
||||
)
|
||||
field_value_pairs = OrderedDict(
|
||||
(str(field), create_namedtuple_from_dict(obj[field]))
|
||||
for field in fields
|
||||
)
|
||||
try:
|
||||
return namedtuple_type(**field_value_pairs)
|
||||
except TypeError:
|
||||
# Cannot create namedtuple instance so fallback to dict (invalid attribute names)
|
||||
return dict(**field_value_pairs)
|
||||
elif isinstance(obj, (list, set, tuple, frozenset)):
|
||||
return [create_namedtuple_from_dict(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
Загрузка…
Ссылка в новой задаче