Add backward compatibility loss functions for new error and strict imitation under the TensorFlow implementation, together with example notebooks (#97)

* Add backward compatibility loss functions for new error and strict imitation under the TensorFlow implementation, together with example notebooks

* Add parameters to clip values before taking the log, to prevent gradients from becoming NaNs (a sketch of this pattern follows the commit metadata below)

* Add docstrings and clean up the implementation

* Bump the TensorFlow version and related requirements
ilmarinen 2020-12-21 12:39:49 -08:00 committed by GitHub
Parent 9765a40111
Commit 17ceb24d36
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 2542 additions and 4 deletions
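As context for the NaN-prevention bullet above: softmax probabilities can underflow to exactly 0.0, and log(0.0) is -inf with a NaN gradient. A minimal, standalone sketch of the clip-before-log pattern used by the loss classes in this commit (the probabilities are made up; the clip bounds mirror the defaults below):

```python
import tensorflow.compat.v2 as tf

tf.enable_v2_behavior()

# A probability of exactly 0.0 would make log() return -inf and its
# gradient NaN; clipping into a safe range first keeps both finite.
probs = tf.constant([1.0, 0.5, 0.0])
clipped = tf.clip_by_value(probs, clip_value_min=1e-10, clip_value_max=4.0)
nll = -tf.reduce_mean(tf.math.log(clipped))

print(nll.numpy())  # finite, since log(1e-10) is large in magnitude but not -inf
```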

View file

@@ -1,7 +1,121 @@
import tensorflow.compat.v2 as tf


class BCNLLLoss(object):
    """
    Backward Compatibility New Error Negative Log Likelihood Loss

    This class implements the backward compatibility loss function
    with the underlying loss function being the Negative Log Likelihood
    loss.

    Note that the final layer of each model is assumed to have a
    softmax output.

    Example usage:
        h1 = MyModel()
        ... train h1 ...
        h1.trainable = False
        lambda_c = 0.5  # regularization parameter
        h2 = MyNewModel()  # this may be the same model type as MyModel
        bc_loss = BCNLLLoss(h1, h2, lambda_c)
        optimizer = tf.keras.optimizers.SGD(0.01)
        tf_helpers.bc_fit(
            h2,
            training_set=ds_train,
            testing_set=ds_test,
            epochs=6,
            bc_loss=bc_loss,
            optimizer=optimizer)

    Args:
        h1: Our reference model, which we would like to be compatible with.
        h2: Our new model, which will be the updated model.
        lambda_c: A float between 0.0 and 1.0, which is a regularization
            parameter that determines how much we want to penalize model h2
            for being incompatible with h1. Lower values penalize less and
            higher values penalize more.
    """

    def __init__(self, h1, h2, lambda_c, clip_value_min=1e-10, clip_value_max=4.0):
        self.h1 = h1
        self.h2 = h2
        self.lambda_c = lambda_c
        self.clip_value_min = clip_value_min
        self.clip_value_max = clip_value_max
        self.__name__ = "BCNLLLoss"

    def nll_loss(self, target_labels, model_output):
        # Pick the model output probabilities corresponding to the ground truth labels
        model_outputs_for_targets = tf.gather(
            model_output, tf.dtypes.cast(target_labels, tf.int32), axis=1)
        # Clip the probability values so that they do not produce NaNs
        # once we take the logarithm
        model_outputs_for_targets = tf.clip_by_value(
            model_outputs_for_targets,
            clip_value_min=self.clip_value_min,
            clip_value_max=self.clip_value_max)
        loss = -1 * tf.reduce_mean(tf.math.log(model_outputs_for_targets))
        return loss

    def dissonance(self, h2_output, target_labels):
        nll_loss = self.nll_loss(target_labels, h2_output)
        return nll_loss

    def __call__(self, x, y):
        h1_output = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)
        h1_diff = h1_output - y
        h1_correct = (h1_diff == 0)
        _, x_support = tf.dynamic_partition(x, tf.dtypes.cast(h1_correct, tf.int32), 2)
        _, y_support = tf.dynamic_partition(y, tf.dtypes.cast(h1_correct, tf.int32), 2)
        h2_support_output = self.h2(x_support)
        dissonance = self.dissonance(h2_support_output, y_support)
        new_error_loss = self.nll_loss(y, h2_output) + self.lambda_c * dissonance
        return new_error_loss

class BCCrossEntropyLoss(object):
    """
    Backward Compatibility New Error Cross Entropy Loss

    This class implements the backward compatibility loss function
    with the underlying loss function being the Cross Entropy
    loss.

    Note that the final layer of each model is assumed to have a
    softmax output.

    Example usage:
        h1 = MyModel()
        ... train h1 ...
        h1.trainable = False
        lambda_c = 0.5  # regularization parameter
        h2 = MyNewModel()  # this may be the same model type as MyModel
        bc_loss = BCCrossEntropyLoss(h1, h2, lambda_c)
        optimizer = tf.keras.optimizers.SGD(0.01)
        tf_helpers.bc_fit(
            h2,
            training_set=ds_train,
            testing_set=ds_test,
            epochs=6,
            bc_loss=bc_loss,
            optimizer=optimizer)

    Args:
        h1: Our reference model, which we would like to be compatible with.
        h2: Our new model, which will be the updated model.
        lambda_c: A float between 0.0 and 1.0, which is a regularization
            parameter that determines how much we want to penalize model h2
            for being incompatible with h1. Lower values penalize less and
            higher values penalize more.
    """

    def __init__(self, h1, h2, lambda_c):
        self.h1 = h1
@@ -27,3 +141,131 @@ class BCCrossEntropyLoss(object):
        new_error_loss = self.cce_loss(y, h2_output) + self.lambda_c * dissonance
        return tf.reduce_sum(new_error_loss)


class BCBinaryCrossEntropyLoss(object):
    """
    Backward Compatibility New Error Binary Cross Entropy Loss

    This class implements the backward compatibility loss function
    with the underlying loss function being the Binary Cross Entropy
    loss.

    Note that the final layer of each model is assumed to have a
    softmax output.

    Example usage:
        h1 = MyModel()
        ... train h1 ...
        h1.trainable = False
        lambda_c = 0.5  # regularization parameter
        h2 = MyNewModel()  # this may be the same model type as MyModel
        bc_loss = BCBinaryCrossEntropyLoss(h1, h2, lambda_c)
        optimizer = tf.keras.optimizers.SGD(0.01)
        tf_helpers.bc_fit(
            h2,
            training_set=ds_train,
            testing_set=ds_test,
            epochs=6,
            bc_loss=bc_loss,
            optimizer=optimizer)

    Args:
        h1: Our reference model, which we would like to be compatible with.
        h2: Our new model, which will be the updated model.
        lambda_c: A float between 0.0 and 1.0, which is a regularization
            parameter that determines how much we want to penalize model h2
            for being incompatible with h1. Lower values penalize less and
            higher values penalize more.
    """

    def __init__(self, h1, h2, lambda_c):
        self.h1 = h1
        self.h2 = h2
        self.lambda_c = lambda_c
        self.__name__ = "BCBinaryCrossEntropyLoss"
        self.bce_loss = tf.keras.losses.BinaryCrossentropy(
            reduction=tf.keras.losses.Reduction.SUM)

    def dissonance(self, h2_output, target_labels):
        cross_entropy_loss = self.bce_loss(target_labels, h2_output)
        return cross_entropy_loss

    def __call__(self, x, y):
        h1_output = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)
        h1_diff = h1_output - tf.argmax(y, axis=1)
        h1_correct = (h1_diff == 0)
        _, x_support = tf.dynamic_partition(x, tf.dtypes.cast(h1_correct, tf.int32), 2)
        _, y_support = tf.dynamic_partition(y, tf.dtypes.cast(h1_correct, tf.int32), 2)
        h2_support_output = self.h2(x_support)
        dissonance = self.dissonance(h2_support_output, y_support)
        new_error_loss = self.bce_loss(y, h2_output) + self.lambda_c * dissonance
        return tf.reduce_sum(new_error_loss)


class BCKLDivLoss(object):
    """
    Backward Compatibility New Error Kullback-Leibler Divergence Loss

    This class implements the backward compatibility loss function
    with the underlying loss function being the Kullback-Leibler
    Divergence loss.

    Note that the final layer of each model is assumed to have a
    softmax output.

    Example usage:
        h1 = MyModel()
        ... train h1 ...
        h1.trainable = False
        lambda_c = 0.5  # regularization parameter
        h2 = MyNewModel()  # this may be the same model type as MyModel
        bc_loss = BCKLDivLoss(h1, h2, lambda_c)
        optimizer = tf.keras.optimizers.SGD(0.01)
        tf_helpers.bc_fit(
            h2,
            training_set=ds_train,
            testing_set=ds_test,
            epochs=6,
            bc_loss=bc_loss,
            optimizer=optimizer)

    Args:
        h1: Our reference model, which we would like to be compatible with.
        h2: Our new model, which will be the updated model.
        lambda_c: A float between 0.0 and 1.0, which is a regularization
            parameter that determines how much we want to penalize model h2
            for being incompatible with h1. Lower values penalize less and
            higher values penalize more.
    """

    def __init__(self, h1, h2, lambda_c):
        self.h1 = h1
        self.h2 = h2
        self.lambda_c = lambda_c
        self.__name__ = "BCKLDivLoss"
        self.kldiv_loss = tf.keras.losses.KLDivergence(
            reduction=tf.keras.losses.Reduction.SUM)

    def dissonance(self, h2_output, target_labels):
        kldiv_loss = self.kldiv_loss(target_labels, h2_output)
        return kldiv_loss

    def __call__(self, x, y):
        h1_output = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)
        h1_diff = h1_output - tf.argmax(y, axis=1)
        h1_correct = (h1_diff == 0)
        _, x_support = tf.dynamic_partition(x, tf.dtypes.cast(h1_correct, tf.int32), 2)
        _, y_support = tf.dynamic_partition(y, tf.dtypes.cast(h1_correct, tf.int32), 2)
        h2_support_output = self.h2(x_support)
        dissonance = self.dissonance(h2_support_output, y_support)
        new_error_loss = self.kldiv_loss(y, h2_output) + self.lambda_c * dissonance
        return tf.reduce_sum(new_error_loss)
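All four classes above restrict the dissonance term to the "support" region, the examples h1 already classifies correctly, using the same `tf.dynamic_partition` idiom. A minimal standalone sketch of that selection step, with made-up batch data:

```python
import tensorflow.compat.v2 as tf

tf.enable_v2_behavior()

# Hypothetical batch: h1's predicted labels and the ground truth labels.
h1_predictions = tf.constant([1, 0, 2, 2])
ground_truth = tf.constant([1, 1, 2, 0])

# Partition value 1 marks examples h1 got right; dynamic_partition splits
# the batch into [incorrect, correct] and we keep partition 1 as the support.
h1_correct = tf.dtypes.cast(h1_predictions == ground_truth, tf.int32)
_, y_support = tf.dynamic_partition(ground_truth, h1_correct, 2)

print(y_support.numpy())  # [1 2] -- the labels on which h1 was correct
```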

View file

@@ -0,0 +1,272 @@
import tensorflow.compat.v2 as tf
import tensorflow.compat.v1 as tf1


class BCStrictImitationNLLLoss(object):
    """
    Strict Imitation Negative Log Likelihood Loss

    This class implements the strict imitation loss function
    with the underlying loss function being the Negative Log Likelihood
    loss.

    Note that the final layer of each model is assumed to have a
    softmax output.

    Example usage:
        h1 = MyModel()
        ... train h1 ...
        h1.trainable = False
        lambda_c = 0.5  # regularization parameter
        h2 = MyNewModel()  # this may be the same model type as MyModel
        bc_loss = BCStrictImitationNLLLoss(h1, h2, lambda_c)
        optimizer = tf.keras.optimizers.SGD(0.01)
        tf_helpers.bc_fit(
            h2,
            training_set=ds_train,
            testing_set=ds_test,
            epochs=6,
            bc_loss=bc_loss,
            optimizer=optimizer)

    Args:
        h1: Our reference model, which we would like to be compatible with.
        h2: Our new model, which will be the updated model.
        lambda_c: A float between 0.0 and 1.0, which is a regularization
            parameter that determines how much we want to penalize model h2
            for being incompatible with h1. Lower values penalize less and
            higher values penalize more.
    """

    def __init__(self, h1, h2, lambda_c, clip_value_min=1e-10, clip_value_max=4.0):
        self.h1 = h1
        self.h2 = h2
        self.lambda_c = lambda_c
        self.clip_value_min = clip_value_min
        self.clip_value_max = clip_value_max
        self.__name__ = "BCStrictImitationNLLLoss"

    def nll_loss(self, target_labels, model_output):
        # Pick the model output probabilities corresponding to the ground truth labels
        _, model_outputs_for_targets = tf.dynamic_partition(
            model_output, tf.dtypes.cast(target_labels, tf.int32), 2)
        # Clip the probability values so that they do not produce NaNs
        # once we take the logarithm
        model_outputs_for_targets = tf.clip_by_value(
            model_outputs_for_targets,
            clip_value_min=self.clip_value_min,
            clip_value_max=self.clip_value_max)
        loss = -1 * tf.reduce_mean(tf.math.log(model_outputs_for_targets))
        return loss

    def dissonance(self, h2_output, target_labels):
        log_loss = tf1.losses.log_loss(target_labels, h2_output, epsilon=1e-07)
        return log_loss

    def __call__(self, x, y):
        h1_output = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)
        h1_diff = h1_output - tf.argmax(y, axis=1)
        h1_correct = (h1_diff == 0)
        _, x_support = tf.dynamic_partition(x, tf.dtypes.cast(h1_correct, tf.int32), 2)
        _, y_support = tf.dynamic_partition(y, tf.dtypes.cast(h1_correct, tf.int32), 2)
        h2_support_output = self.h2(x_support)
        strict_imitation_dissonance = self.dissonance(h2_support_output, y_support)
        strict_imitation_loss = self.nll_loss(y, h2_output) + self.lambda_c * strict_imitation_dissonance
        return tf.reduce_sum(strict_imitation_loss)

class BCStrictImitationCrossEntropyLoss(object):
    """
    Strict Imitation Cross Entropy Loss

    This class implements the strict imitation loss function
    with the underlying loss function being the Cross Entropy
    loss.

    Note that the final layer of each model is assumed to have a
    softmax output.

    Example usage:
        h1 = MyModel()
        ... train h1 ...
        h1.trainable = False
        lambda_c = 0.5  # regularization parameter
        h2 = MyNewModel()  # this may be the same model type as MyModel
        bc_loss = BCStrictImitationCrossEntropyLoss(h1, h2, lambda_c)
        optimizer = tf.keras.optimizers.SGD(0.01)
        tf_helpers.bc_fit(
            h2,
            training_set=ds_train,
            testing_set=ds_test,
            epochs=6,
            bc_loss=bc_loss,
            optimizer=optimizer)

    Args:
        h1: Our reference model, which we would like to be compatible with.
        h2: Our new model, which will be the updated model.
        lambda_c: A float between 0.0 and 1.0, which is a regularization
            parameter that determines how much we want to penalize model h2
            for being incompatible with h1. Lower values penalize less and
            higher values penalize more.
    """

    def __init__(self, h1, h2, lambda_c):
        self.h1 = h1
        self.h2 = h2
        self.lambda_c = lambda_c
        self.__name__ = "BCStrictImitationCrossEntropyLoss"
        self.cce_loss = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.SUM)

    def dissonance(self, h2_output, target_labels):
        log_loss = tf1.losses.log_loss(target_labels, h2_output, epsilon=1e-07)
        return log_loss

    def __call__(self, x, y):
        h1_output = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)
        h1_diff = h1_output - tf.argmax(y, axis=1)
        h1_correct = (h1_diff == 0)
        _, x_support = tf.dynamic_partition(x, tf.dtypes.cast(h1_correct, tf.int32), 2)
        _, y_support = tf.dynamic_partition(y, tf.dtypes.cast(h1_correct, tf.int32), 2)
        h2_support_output = self.h2(x_support)
        strict_imitation_dissonance = self.dissonance(h2_support_output, y_support)
        strict_imitation_loss = self.cce_loss(tf.argmax(y, axis=1), h2_output) + self.lambda_c * strict_imitation_dissonance
        return tf.reduce_sum(strict_imitation_loss)

class BCStrictImitationBinaryCrossEntropyLoss(object):
    """
    Strict Imitation Binary Cross Entropy Loss

    This class implements the strict imitation loss function
    with the underlying loss function being the Binary Cross Entropy
    loss.

    Note that the final layer of each model is assumed to have a
    softmax output.

    Example usage:
        h1 = MyModel()
        ... train h1 ...
        h1.trainable = False
        lambda_c = 0.5  # regularization parameter
        h2 = MyNewModel()  # this may be the same model type as MyModel
        bc_loss = BCStrictImitationBinaryCrossEntropyLoss(h1, h2, lambda_c)
        optimizer = tf.keras.optimizers.SGD(0.01)
        tf_helpers.bc_fit(
            h2,
            training_set=ds_train,
            testing_set=ds_test,
            epochs=6,
            bc_loss=bc_loss,
            optimizer=optimizer)

    Args:
        h1: Our reference model, which we would like to be compatible with.
        h2: Our new model, which will be the updated model.
        lambda_c: A float between 0.0 and 1.0, which is a regularization
            parameter that determines how much we want to penalize model h2
            for being incompatible with h1. Lower values penalize less and
            higher values penalize more.
    """

    def __init__(self, h1, h2, lambda_c):
        self.h1 = h1
        self.h2 = h2
        self.lambda_c = lambda_c
        self.__name__ = "BCStrictImitationBinaryCrossEntropyLoss"
        self.bce_loss = tf.keras.losses.BinaryCrossentropy(
            reduction=tf.keras.losses.Reduction.SUM)

    def dissonance(self, h2_output, target_labels):
        log_loss = tf1.losses.log_loss(target_labels, h2_output, epsilon=1e-07)
        return log_loss

    def __call__(self, x, y):
        h1_output = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)
        h1_diff = h1_output - tf.argmax(y, axis=1)
        h1_correct = (h1_diff == 0)
        _, x_support = tf.dynamic_partition(x, tf.dtypes.cast(h1_correct, tf.int32), 2)
        _, y_support = tf.dynamic_partition(y, tf.dtypes.cast(h1_correct, tf.int32), 2)
        h2_support_output = self.h2(x_support)
        strict_imitation_dissonance = self.dissonance(h2_support_output, y_support)
        strict_imitation_loss = self.bce_loss(y, h2_output) + self.lambda_c * strict_imitation_dissonance
        return strict_imitation_loss

class BCStrictImitationKLDivLoss(object):
    """
    Strict Imitation Kullback-Leibler Divergence Loss

    This class implements the strict imitation loss function
    with the underlying loss function being the Kullback-Leibler
    Divergence loss.

    Note that the final layer of each model is assumed to have a
    softmax output.

    Example usage:
        h1 = MyModel()
        ... train h1 ...
        h1.trainable = False
        lambda_c = 0.5  # regularization parameter
        h2 = MyNewModel()  # this may be the same model type as MyModel
        bc_loss = BCStrictImitationKLDivLoss(h1, h2, lambda_c)
        optimizer = tf.keras.optimizers.SGD(0.01)
        tf_helpers.bc_fit(
            h2,
            training_set=ds_train,
            testing_set=ds_test,
            epochs=6,
            bc_loss=bc_loss,
            optimizer=optimizer)

    Args:
        h1: Our reference model, which we would like to be compatible with.
        h2: Our new model, which will be the updated model.
        lambda_c: A float between 0.0 and 1.0, which is a regularization
            parameter that determines how much we want to penalize model h2
            for being incompatible with h1. Lower values penalize less and
            higher values penalize more.
    """

    def __init__(self, h1, h2, lambda_c):
        self.h1 = h1
        self.h2 = h2
        self.lambda_c = lambda_c
        self.__name__ = "BCStrictImitationKLDivLoss"
        self.kldiv_loss = tf.keras.losses.KLDivergence(
            reduction=tf.keras.losses.Reduction.SUM)

    def dissonance(self, h2_output, target_labels):
        log_loss = tf1.losses.log_loss(target_labels, h2_output, epsilon=1e-07)
        return log_loss

    def __call__(self, x, y):
        h1_output = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)
        h1_diff = h1_output - tf.argmax(y, axis=1)
        h1_correct = (h1_diff == 0)
        _, x_support = tf.dynamic_partition(x, tf.dtypes.cast(h1_correct, tf.int32), 2)
        _, y_support = tf.dynamic_partition(y, tf.dtypes.cast(h1_correct, tf.int32), 2)
        h2_support_output = self.h2(x_support)
        dissonance = self.dissonance(h2_support_output, y_support)
        new_error_loss = self.kldiv_loss(y, h2_output) + self.lambda_c * dissonance
        return tf.reduce_sum(new_error_loss)
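Both files implement the same two-term objective; written out as a paraphrase of the code above (not text from the commit), with the underlying loss written as l:

```latex
% Combined backward-compatibility objective, as implemented above:
% the underlying loss on the full batch plus lambda_c times a dissonance
% term evaluated only on the support set (examples h1 classifies correctly).
\[
  \mathcal{L}(x, y) =
      \ell\bigl(h_2(x),\, y\bigr)
      + \lambda_c \cdot
      \ell_{\text{dissonance}}\bigl(h_2(x_{\text{support}}),\, y_{\text{support}}\bigr)
\]
```

In the new-error classes the dissonance term reuses the underlying loss itself; in the strict-imitation classes it is a log loss (`tf1.losses.log_loss`) on the support set. `lambda_c` trades off raw accuracy against compatibility with h1.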

View file

@@ -0,0 +1,258 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow.keras.backend as kb\n",
"from backwardcompatibilityml import scores\n",
"from backwardcompatibilityml.tensorflow import helpers as tf_helpers\n",
"from backwardcompatibilityml.tensorflow.loss.new_error import BCBinaryCrossEntropyLoss\n",
"import copy\n",
"\n",
"tf.enable_v2_behavior()\n",
"tf.random.set_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" if label == 3:\n",
" label = 1\n",
" else:\n",
" label = 0\n",
" \n",
" label = tf.one_hot(label, 2)\n",
"\n",
" return tf.cast(image, tf.float32) / 255., label\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"469/469 [==============================] - 1s 2ms/step - loss: 0.0746 - accuracy: 0.9762 - val_loss: 0.0324 - val_accuracy: 0.9906\n",
"Epoch 2/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.0311 - accuracy: 0.9903 - val_loss: 0.0223 - val_accuracy: 0.9931\n",
"Epoch 3/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.0200 - accuracy: 0.9936 - val_loss: 0.0228 - val_accuracy: 0.9927\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f7ed853c978>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"binary_cross_entropy_loss = tf.keras.losses.BinaryCrossentropy()\n",
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(2, activation='softmax')\n",
"])\n",
"model.compile(\n",
" loss=binary_cross_entropy_loss,\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" metrics=['accuracy'],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=3,\n",
" validation_data=ds_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lambda_c = 0.9\n",
"model.trainable = False\n",
"\n",
"h2 = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(2, activation='softmax')\n",
"])\n",
"\n",
"bc_loss = BCBinaryCrossEntropyLoss(model, h2, lambda_c)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(0.001)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"=============================================== Training loss: 9.1773\n",
"Epoch 2/6\n",
"=============================================== Training loss: 14.9186\n",
"Epoch 3/6\n",
"=============================================== Training loss: 3.8693\n",
"Epoch 4/6\n",
"=============================================== Training loss: 0.1928\n",
"Epoch 5/6\n",
"=============================================== Training loss: 0.2528\n",
"Epoch 6/6\n",
"=============================================== Training loss: 1.9271\n",
"Training done.\n"
]
}
],
"source": [
"tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model.trainable = False\n",
"h2.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"h1_predicted_labels = []\n",
"h2_predicted_labels = []\n",
"ground_truth_labels = []\n",
"for x_batch_test, y_batch_test in ds_test:\n",
" h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)\n",
" h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)\n",
" h1_predicted_labels += h1_batch_predictions.numpy().tolist()\n",
" h2_predicted_labels += h2_batch_predictions.numpy().tolist()\n",
" ground_truth_labels += y_batch_test.numpy().tolist()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lambda_c: 0.9\n",
"BTC: 0\n",
"BEC: 1.0\n"
]
}
],
"source": [
"btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"\n",
"print(f\"lambda_c: {lambda_c}\")\n",
"print(f\"BTC: {btc}\")\n",
"print(f\"BEC: {bec}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
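The notebook above (and those that follow) evaluate compatibility with `scores.trust_compatibility_score` (BTC) and `scores.error_compatibility_score` (BEC). The `scores` module is not part of this diff, so the following is only a sketch under the usual definitions from the backward-compatibility literature, with hypothetical helper names that mirror the calls in the notebooks:

```python
def trust_compatibility_score(h1_labels, h2_labels, truth):
    # BTC (assumed definition, not from this diff): of the examples h1
    # classified correctly, the fraction that h2 also classifies correctly.
    both_right = sum(1 for p1, p2, t in zip(h1_labels, h2_labels, truth)
                     if p1 == t and p2 == t)
    h1_right = sum(1 for p1, t in zip(h1_labels, truth) if p1 == t)
    return both_right / h1_right if h1_right else 0


def error_compatibility_score(h1_labels, h2_labels, truth):
    # BEC (assumed definition): of the examples h1 got wrong, the fraction
    # that h2 also gets wrong, i.e. h2 introduces no new errors there.
    both_wrong = sum(1 for p1, p2, t in zip(h1_labels, h2_labels, truth)
                     if p1 != t and p2 != t)
    h1_wrong = sum(1 for p1, t in zip(h1_labels, truth) if p1 != t)
    return both_wrong / h1_wrong if h1_wrong else 0
```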

View file

@@ -0,0 +1,250 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow.keras.backend as kb\n",
"from backwardcompatibilityml import scores\n",
"from backwardcompatibilityml.tensorflow import helpers as tf_helpers\n",
"from backwardcompatibilityml.tensorflow.loss.new_error import BCCrossEntropyLoss\n",
"import copy\n",
"\n",
"tf.enable_v2_behavior()\n",
"tf.random.set_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" return tf.cast(image, tf.float32) / 255., label\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"469/469 [==============================] - 1s 2ms/step - loss: 0.3585 - accuracy: 0.9000 - val_loss: 0.1850 - val_accuracy: 0.9462\n",
"Epoch 2/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.1583 - accuracy: 0.9549 - val_loss: 0.1412 - val_accuracy: 0.9575\n",
"Epoch 3/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.1148 - accuracy: 0.9670 - val_loss: 0.1075 - val_accuracy: 0.9684\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f9f702112b0>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"model.compile(\n",
" loss=tf.keras.losses.sparse_categorical_crossentropy,\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" metrics=['accuracy'],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=3,\n",
" validation_data=ds_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lambda_c = 0.9\n",
"model.trainable = False\n",
"\n",
"h2 = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"\n",
"bc_loss = BCCrossEntropyLoss(model, h2, lambda_c)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(0.001)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"=============================================== Training loss: 32.5727\n",
"Epoch 2/6\n",
"=============================================== Training loss: 19.1671\n",
"Epoch 3/6\n",
"=============================================== Training loss: 9.7318\n",
"Epoch 4/6\n",
"=============================================== Training loss: 8.3924\n",
"Epoch 5/6\n",
"=============================================== Training loss: 8.9723\n",
"Epoch 6/6\n",
"=============================================== Training loss: 5.0370\n",
"Training done.\n"
]
}
],
"source": [
"tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model.trainable = False\n",
"h2.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"h1_predicted_labels = []\n",
"h2_predicted_labels = []\n",
"ground_truth_labels = []\n",
"for x_batch_test, y_batch_test in ds_test:\n",
" h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)\n",
" h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)\n",
" h1_predicted_labels += h1_batch_predictions.numpy().tolist()\n",
" h2_predicted_labels += h2_batch_predictions.numpy().tolist()\n",
" ground_truth_labels += y_batch_test.numpy().tolist()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lambda_c: 0.9\n",
"BTC: 0.9938042131350682\n",
"BEC: 0.6487341772151899\n"
]
}
],
"source": [
"btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"\n",
"print(f\"lambda_c: {lambda_c}\")\n",
"print(f\"BTC: {btc}\")\n",
"print(f\"BEC: {bec}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View file

@@ -0,0 +1,252 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow.keras.backend as kb\n",
"from backwardcompatibilityml import scores\n",
"from backwardcompatibilityml.tensorflow import helpers as tf_helpers\n",
"from backwardcompatibilityml.tensorflow.loss.new_error import BCKLDivLoss\n",
"import copy\n",
"\n",
"tf.enable_v2_behavior()\n",
"tf.random.set_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" label_one_hot = tf.one_hot(label, 10)\n",
" return tf.cast(image, tf.float32) / 255., label_one_hot\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"469/469 [==============================] - 1s 2ms/step - loss: 0.3543 - accuracy: 0.9021 - val_loss: 0.1878 - val_accuracy: 0.9449\n",
"Epoch 2/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.1587 - accuracy: 0.9544 - val_loss: 0.1401 - val_accuracy: 0.9572\n",
"Epoch 3/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.1139 - accuracy: 0.9682 - val_loss: 0.1031 - val_accuracy: 0.9705\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f63a05edc18>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kldiv_loss = tf.keras.losses.KLDivergence()\n",
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"model.compile(\n",
" loss=kldiv_loss,\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" metrics=['accuracy'],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=3,\n",
" validation_data=ds_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"lambda_c = 0.9\n",
"model.trainable = False\n",
"\n",
"h2 = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"\n",
"bc_loss = BCKLDivLoss(model, h2, lambda_c)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(0.001)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"=============================================== Training loss: 24.6628\n",
"Epoch 2/6\n",
"=============================================== Training loss: 24.0618\n",
"Epoch 3/6\n",
"=============================================== Training loss: 10.0755\n",
"Epoch 4/6\n",
"=============================================== Training loss: 12.3139\n",
"Epoch 5/6\n",
"=============================================== Training loss: 13.2446\n",
"Epoch 6/6\n",
"=============================================== Training loss: 14.1427\n",
"Training done.\n"
]
}
],
"source": [
"tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"model.trainable = False\n",
"h2.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"h1_predicted_labels = []\n",
"h2_predicted_labels = []\n",
"ground_truth_labels = []\n",
"for x_batch_test, y_batch_test in ds_test:\n",
" h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)\n",
" h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)\n",
" h1_predicted_labels += h1_batch_predictions.numpy().tolist()\n",
" h2_predicted_labels += h2_batch_predictions.numpy().tolist()\n",
" ground_truth_labels += y_batch_test.numpy().tolist()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lambda_c: 0.9\n",
"BTC: 0\n",
"BEC: 1.0\n"
]
}
],
"source": [
"btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"\n",
"print(f\"lambda_c: {lambda_c}\")\n",
"print(f\"BTC: {btc}\")\n",
"print(f\"BEC: {bec}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View file

@@ -0,0 +1,250 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow.keras.backend as kb\n",
"from backwardcompatibilityml import scores\n",
"from backwardcompatibilityml.tensorflow import helpers as tf_helpers\n",
"from backwardcompatibilityml.tensorflow.loss.new_error import BCNLLLoss\n",
"import copy\n",
"\n",
"tf.enable_v2_behavior()\n",
"tf.random.set_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" return tf.cast(image, tf.float32) / 255., label\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"469/469 [==============================] - 1s 2ms/step - loss: 0.3601 - accuracy: 0.9003 - val_loss: 0.1981 - val_accuracy: 0.9427\n",
"Epoch 2/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.1667 - accuracy: 0.9524 - val_loss: 0.1353 - val_accuracy: 0.9593\n",
"Epoch 3/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.1188 - accuracy: 0.9662 - val_loss: 0.1101 - val_accuracy: 0.9665\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f68204c5668>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"model.compile(\n",
" loss=tf.keras.losses.sparse_categorical_crossentropy,\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" metrics=['accuracy'],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=3,\n",
" validation_data=ds_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lambda_c = 0.9\n",
"model.trainable = False\n",
"\n",
"h2 = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"\n",
"bc_loss = BCNLLLoss(model, h2, lambda_c)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(0.001)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"=============================================== Training loss: 4.3735\n",
"Epoch 2/6\n",
"=============================================== Training loss: 4.3665\n",
"Epoch 3/6\n",
"=============================================== Training loss: 4.3861\n",
"Epoch 4/6\n",
"=============================================== Training loss: 4.3799\n",
"Epoch 5/6\n",
"=============================================== Training loss: 4.3679\n",
"Epoch 6/6\n",
"=============================================== Training loss: 4.3695\n",
"Training done.\n"
]
}
],
"source": [
"tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model.trainable = False\n",
"h2.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"h1_predicted_labels = []\n",
"h2_predicted_labels = []\n",
"ground_truth_labels = []\n",
"for x_batch_test, y_batch_test in ds_test:\n",
" h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)\n",
" h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)\n",
" h1_predicted_labels += h1_batch_predictions.numpy().tolist()\n",
" h2_predicted_labels += h2_batch_predictions.numpy().tolist()\n",
" ground_truth_labels += y_batch_test.numpy().tolist()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lambda_c: 0.9\n",
"BTC: 0.11598551474392137\n",
"BEC: 0.9582089552238806\n"
]
}
],
"source": [
"btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"\n",
"print(f\"lambda_c: {lambda_c}\")\n",
"print(f\"BTC: {btc}\")\n",
"print(f\"BEC: {bec}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View file

@@ -0,0 +1,258 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow.keras.backend as kb\n",
"from backwardcompatibilityml import scores\n",
"from backwardcompatibilityml.tensorflow import helpers as tf_helpers\n",
"from backwardcompatibilityml.tensorflow.loss.strict_imitation import BCStrictImitationBinaryCrossEntropyLoss\n",
"import copy\n",
"\n",
"tf.enable_v2_behavior()\n",
"tf.random.set_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" if label == 3:\n",
" label = 1\n",
" else:\n",
" label = 0\n",
" \n",
" label = tf.one_hot(label, 2)\n",
"\n",
" return tf.cast(image, tf.float32) / 255., label\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"469/469 [==============================] - 1s 2ms/step - loss: 0.0787 - accuracy: 0.9752 - val_loss: 0.0329 - val_accuracy: 0.9900\n",
"Epoch 2/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.0319 - accuracy: 0.9900 - val_loss: 0.0213 - val_accuracy: 0.9934\n",
"Epoch 3/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.0215 - accuracy: 0.9929 - val_loss: 0.0204 - val_accuracy: 0.9935\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f91c04e5a58>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"binary_cross_entropy_loss = tf.keras.losses.BinaryCrossentropy()\n",
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(2, activation='softmax')\n",
"])\n",
"model.compile(\n",
" loss=binary_cross_entropy_loss,\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" metrics=['accuracy'],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=3,\n",
" validation_data=ds_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lambda_c = 0.9\n",
"model.trainable = False\n",
"\n",
"h2 = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(2, activation='softmax')\n",
"])\n",
"\n",
"bc_loss = BCStrictImitationBinaryCrossEntropyLoss(model, h2, lambda_c)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(0.001)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"=============================================== Training loss: 7.9496\n",
"Epoch 2/6\n",
"=============================================== Training loss: 6.8961\n",
"Epoch 3/6\n",
"=============================================== Training loss: 1.5612\n",
"Epoch 4/6\n",
"=============================================== Training loss: 0.4459\n",
"Epoch 5/6\n",
"=============================================== Training loss: 2.2741\n",
"Epoch 6/6\n",
"=============================================== Training loss: 0.3391\n",
"Training done.\n"
]
}
],
"source": [
"tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model.trainable = False\n",
"h2.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"h1_predicted_labels = []\n",
"h2_predicted_labels = []\n",
"ground_truth_labels = []\n",
"for x_batch_test, y_batch_test in ds_test:\n",
" h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)\n",
" h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)\n",
" h1_predicted_labels += h1_batch_predictions.numpy().tolist()\n",
" h2_predicted_labels += h2_batch_predictions.numpy().tolist()\n",
" ground_truth_labels += y_batch_test.numpy().tolist()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lambda_c: 0.9\n",
"BTC: 0\n",
"BEC: 1.0\n"
]
}
],
"source": [
"btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"\n",
"print(f\"lambda_c: {lambda_c}\")\n",
"print(f\"BTC: {btc}\")\n",
"print(f\"BEC: {bec}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View file

@@ -0,0 +1,252 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow.compat.v1 as tf1\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow.keras.backend as kb\n",
"from backwardcompatibilityml import scores\n",
"from backwardcompatibilityml.tensorflow import helpers as tf_helpers\n",
"from backwardcompatibilityml.tensorflow.loss.strict_imitation import BCStrictImitationCrossEntropyLoss\n",
"import copy\n",
"\n",
"tf.enable_v2_behavior()\n",
"tf.random.set_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" label = tf.one_hot(label, 10)\n",
" return tf.cast(image, tf.float32) / 255., label\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"469/469 [==============================] - 1s 2ms/step - loss: 0.0582 - accuracy: 0.9013 - val_loss: 0.0314 - val_accuracy: 0.9470\n",
"Epoch 2/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.0278 - accuracy: 0.9544 - val_loss: 0.0232 - val_accuracy: 0.9623\n",
"Epoch 3/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.0203 - accuracy: 0.9673 - val_loss: 0.0191 - val_accuracy: 0.9677\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f5d80334080>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"model.compile(\n",
" loss=tf1.losses.log_loss,\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" metrics=['accuracy'],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=3,\n",
" validation_data=ds_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lambda_c = 0.9\n",
"model.trainable = False\n",
"\n",
"h2 = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"\n",
"bc_loss = BCStrictImitationCrossEntropyLoss(model, h2, lambda_c)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(0.001)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"=============================================== Training loss: 27.0147\n",
"Epoch 2/6\n",
"=============================================== Training loss: 7.9487\n",
"Epoch 3/6\n",
"=============================================== Training loss: 8.6261\n",
"Epoch 4/6\n",
"=============================================== Training loss: 10.2663\n",
"Epoch 5/6\n",
"=============================================== Training loss: 1.8691\n",
"Epoch 6/6\n",
"=============================================== Training loss: 8.0041\n",
"Training done.\n"
]
}
],
"source": [
"tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model.trainable = False\n",
"h2.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"h1_predicted_labels = []\n",
"h2_predicted_labels = []\n",
"ground_truth_labels = []\n",
"for x_batch_test, y_batch_test in ds_test:\n",
" h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)\n",
" h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)\n",
" h1_predicted_labels += h1_batch_predictions.numpy().tolist()\n",
" h2_predicted_labels += h2_batch_predictions.numpy().tolist()\n",
" ground_truth_labels += y_batch_test.numpy().tolist()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lambda_c: 0.9\n",
"BTC: 0\n",
"BEC: 1.0\n"
]
}
],
"source": [
"btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"\n",
"print(f\"lambda_c: {lambda_c}\")\n",
"print(f\"BTC: {btc}\")\n",
"print(f\"BEC: {bec}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View file

@@ -0,0 +1,252 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow.keras.backend as kb\n",
"from backwardcompatibilityml import scores\n",
"from backwardcompatibilityml.tensorflow import helpers as tf_helpers\n",
"from backwardcompatibilityml.tensorflow.loss.strict_imitation import BCStrictImitationKLDivLoss\n",
"import copy\n",
"\n",
"tf.enable_v2_behavior()\n",
"tf.random.set_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" label_one_hot = tf.one_hot(label, 10)\n",
" return tf.cast(image, tf.float32) / 255., label_one_hot\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"469/469 [==============================] - 1s 2ms/step - loss: 0.3579 - accuracy: 0.8997 - val_loss: 0.1928 - val_accuracy: 0.9450\n",
"Epoch 2/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.1602 - accuracy: 0.9545 - val_loss: 0.1318 - val_accuracy: 0.9615\n",
"Epoch 3/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.1132 - accuracy: 0.9672 - val_loss: 0.1045 - val_accuracy: 0.9689\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f0adc08fd68>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kldiv_loss = tf.keras.losses.KLDivergence()\n",
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"model.compile(\n",
" loss=kldiv_loss,\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" metrics=['accuracy'],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=3,\n",
" validation_data=ds_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lambda_c = 0.9\n",
"model.trainable = False\n",
"\n",
"h2 = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"\n",
"bc_loss = BCStrictImitationKLDivLoss(model, h2, lambda_c)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(0.001)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"=============================================== Training loss: 33.4308\n",
"Epoch 2/6\n",
"=============================================== Training loss: 11.4675\n",
"Epoch 3/6\n",
"=============================================== Training loss: 5.5857\n",
"Epoch 4/6\n",
"=============================================== Training loss: 5.1022\n",
"Epoch 5/6\n",
"=============================================== Training loss: 2.4461\n",
"Epoch 6/6\n",
"=============================================== Training loss: 1.6310\n",
"Training done.\n"
]
}
],
"source": [
"tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model.trainable = False\n",
"h2.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"h1_predicted_labels = []\n",
"h2_predicted_labels = []\n",
"ground_truth_labels = []\n",
"for x_batch_test, y_batch_test in ds_test:\n",
" h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)\n",
" h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)\n",
" h1_predicted_labels += h1_batch_predictions.numpy().tolist()\n",
" h2_predicted_labels += h2_batch_predictions.numpy().tolist()\n",
" ground_truth_labels += y_batch_test.numpy().tolist()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lambda_c: 0.9\n",
"BTC: 0\n",
"BEC: 1.0\n"
]
}
],
"source": [
"btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"\n",
"print(f\"lambda_c: {lambda_c}\")\n",
"print(f\"BTC: {btc}\")\n",
"print(f\"BEC: {bec}\")"
]
},
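{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the two scores have simple interpretations: BTC (backward trust compatibility) is the fraction of examples `model` got right that `h2` also gets right, and BEC (backward error compatibility) is the fraction of `model`'s errors that `h2` reproduces. The sketch below restates these standard definitions; the `scores` module is the authoritative implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the standard BTC/BEC definitions; use the scores module in practice.\n",
"def btc_sketch(h1_preds, h2_preds, truth):\n",
"    h1_correct = sum(1 for p1, t in zip(h1_preds, truth) if p1 == t)\n",
"    both_correct = sum(1 for p1, p2, t in zip(h1_preds, h2_preds, truth) if p1 == t and p2 == t)\n",
"    return both_correct / h1_correct if h1_correct else 0.0\n",
"\n",
"def bec_sketch(h1_preds, h2_preds, truth):\n",
"    h1_wrong = sum(1 for p1, t in zip(h1_preds, truth) if p1 != t)\n",
"    same_error = sum(1 for p1, p2, t in zip(h1_preds, h2_preds, truth) if p1 != t and p1 == p2)\n",
"    return same_error / h1_wrong if h1_wrong else 1.0"
]
},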
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -0,0 +1,252 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow.compat.v1 as tf1\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow.keras.backend as kb\n",
"from backwardcompatibilityml import scores\n",
"from backwardcompatibilityml.tensorflow import helpers as tf_helpers\n",
"from backwardcompatibilityml.tensorflow.loss.strict_imitation import BCStrictImitationNLLLoss\n",
"import copy\n",
"\n",
"tf.enable_v2_behavior()\n",
"tf.random.set_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" label = tf.one_hot(label, 10)\n",
" return tf.cast(image, tf.float32) / 255., label\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"469/469 [==============================] - 1s 2ms/step - loss: 0.0578 - accuracy: 0.9014 - val_loss: 0.0317 - val_accuracy: 0.9469\n",
"Epoch 2/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.0277 - accuracy: 0.9544 - val_loss: 0.0230 - val_accuracy: 0.9599\n",
"Epoch 3/3\n",
"469/469 [==============================] - 1s 1ms/step - loss: 0.0203 - accuracy: 0.9666 - val_loss: 0.0189 - val_accuracy: 0.9682\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7fb0f85384a8>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"model.compile(\n",
" loss=tf1.losses.log_loss,\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" metrics=['accuracy'],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=3,\n",
" validation_data=ds_test,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lambda_c = 0.9\n",
"model.trainable = False\n",
"\n",
"h2 = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
" tf.keras.layers.Dense(128,activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')\n",
"])\n",
"\n",
"bc_loss = BCStrictImitationNLLLoss(model, h2, lambda_c)"
]
},
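{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dissonance term inside `BCStrictImitationNLLLoss` can be pictured as follows: on the support set where `model` is correct, `h2` is asked to assign high probability to `model`'s predictions, with the probabilities clipped before taking the logarithm so that gradients stay finite. The next cell is an assumed sketch of that computation, not the library code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Assumed sketch of a strict-imitation NLL dissonance on one batch.\n",
"for x_batch, y_batch in ds_test.take(1):\n",
"    labels = tf.argmax(y_batch, axis=1)\n",
"    h1_pred = tf.argmax(model(x_batch), axis=1)\n",
"    support = h1_pred == labels\n",
"    h2_out = h2(tf.boolean_mask(x_batch, support))\n",
"    h1_support_pred = tf.boolean_mask(h1_pred, support)\n",
"    # probability h2 assigns to h1's (correct) predictions\n",
"    probs = tf.gather(h2_out, h1_support_pred, axis=1, batch_dims=1)\n",
"    probs = tf.clip_by_value(probs, 1e-10, 1.0)  # guard the log against NaNs\n",
"    print(\"dissonance:\", float(-tf.reduce_mean(tf.math.log(probs))))"
]
},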
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(0.01)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"=============================================== Training loss: 0.7825\n",
"Epoch 2/6\n",
"=============================================== Training loss: 0.5823\n",
"Epoch 3/6\n",
"=============================================== Training loss: 0.6761\n",
"Epoch 4/6\n",
"=============================================== Training loss: 0.4231\n",
"Epoch 5/6\n",
"=============================================== Training loss: 0.5634\n",
"Epoch 6/6\n",
"=============================================== Training loss: 0.4668\n",
"Training done.\n"
]
}
],
"source": [
"tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)"
]
},
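{
"cell_type": "markdown",
"metadata": {},
"source": [
"`bc_fit` hides the training loop. A minimal sketch of the kind of loop it presumably runs, an assumption for exposition since the real helper in `tf_helpers` may differ, looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Assumed sketch of a bc_fit-style loop; tf_helpers.bc_fit is the real entry point.\n",
"def bc_fit_sketch(h2_model, training_set, epochs, bc_loss, optimizer):\n",
"    for epoch in range(epochs):\n",
"        for x_batch, y_batch in training_set:\n",
"            batch_labels = tf.argmax(y_batch, axis=1)  # undo the one-hot encoding\n",
"            with tf.GradientTape() as tape:\n",
"                loss = bc_loss(x_batch, batch_labels)\n",
"            grads = tape.gradient(loss, h2_model.trainable_variables)\n",
"            optimizer.apply_gradients(zip(grads, h2_model.trainable_variables))"
]
},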
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model.trainable = False\n",
"h2.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"h1_predicted_labels = []\n",
"h2_predicted_labels = []\n",
"ground_truth_labels = []\n",
"for x_batch_test, y_batch_test in ds_test:\n",
" h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)\n",
" h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)\n",
" h1_predicted_labels += h1_batch_predictions.numpy().tolist()\n",
" h2_predicted_labels += h2_batch_predictions.numpy().tolist()\n",
" ground_truth_labels += y_batch_test.numpy().tolist()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lambda_c: 0.9\n",
"BTC: 0\n",
"BEC: 1.0\n"
]
}
],
"source": [
"btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)\n",
"\n",
"print(f\"lambda_c: {lambda_c}\")\n",
"print(f\"BTC: {btc}\")\n",
"print(f\"BEC: {bec}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -1,13 +1,13 @@
 torch==1.5.1
 Jinja2==2.11.2
-numpy==1.19.0
+numpy==1.19.4
 scikit-learn==0.23.1
 rai_core_flask==0.0.2
-tensorboard==2.3.0
+tensorboard==2.4.0
 tensorboard-plugin-wit==1.7.0
-tensorflow==2.3.1
+tensorflow==2.4.0
 tensorflow-datasets==4.1.0
-tensorflow-estimator==2.3.0
+tensorflow-estimator==2.4.0
 tensorflow-metadata==0.25.0
 Pillow==7.2.0
 mlflow==1.12.1