benchmark + parameter_sweeper module (#55)
* benchmark code
* benchmark script update
* update fastai code
* benchmark code
* benchmark script
* flake8 req
* flake8 req
* train schedule
* update
* benchmarking nb
* benchmark
* benchmark stable
* gitignore
* experiments + test
* delete benchmark.py
* reformat msg in notebook
* benchmark script
* type error
* fixes
* fixes
* fixes
* lxml to env.yml
Parent: 63db5b7565
Commit: b01067535b
.flake8
@@ -7,6 +7,10 @@
 # E501 Line too long (82 > 79 characters)
 # W503 Line break occurred before a binary operator
 # F403 'from module import *' used; unable to detect undefined names
+# F405 '<function>' may be undefined, or defined from star imports
+# E402 module level import not at top of file
+# E731 do not assign a lambda expression, use a def
+# F821 undefined name 'get_ipython' --> from generated python files using nbconvert

-ignore = E203, E266, E501, W503, F403
+ignore = E203, E266, E501, W503, F403, F405, E402, E731, F821
 max-line-length = 79
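(Context, not part of the diff: the new F821 and E402 entries exist mainly for the .py files generated from the notebooks with nbconvert. A converted notebook typically starts with lines like the following sketch, which flake8 would otherwise flag.)

```python
# Illustrative nbconvert output only; the notebook it comes from is hypothetical.
get_ipython().run_line_magic("matplotlib", "inline")  # F821: `get_ipython` is never imported

import sys

sys.path.append("../")
import fastai  # E402: module-level import not at top of file
```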
@@ -114,3 +114,6 @@ image_classification/data/*
 # don't save .swp files
 *.swp
+
+# don't save .csv files
+*.csv
@@ -35,6 +35,7 @@ dependencies:
 - azureml-sdk[notebooks,contrib]==1.0.10
 - azure-storage>=0.36.0
 - black>=18.6b4
+- lxml>=4.3.2
 - torchvision
 - memory-profiler>=0.54.0
 - nvidia-ml-py3>=7.352.0
@@ -54,14 +54,17 @@
 "source": [
 "import sys\n",
 "sys.path.append(\"../\")\n",
-"import io, time, urllib.request\n",
+"import io\n",
+"import time\n",
+"import urllib.request\n",
 "import fastai\n",
 "from fastai.vision import *\n",
 "from ipywebrtc import CameraStream, ImageRecorder\n",
 "from ipywidgets import HBox, Label, Layout, Widget\n",
 "from torch.cuda import get_device_name\n",
 "from utils_ic.constants import IMAGENET_IM_SIZE\n",
-"from utils_ic.datasets import imagenet_labels, data_path\n",
+"from utils_ic.datasets import imagenet_labels\n",
+"from utils_ic.common import data_path\n",
 "from utils_ic.imagenet_models import model_to_learner\n",
 "\n",
 "\n",

@@ -76,9 +76,9 @@
 "import sys\n",
 "sys.path.append(\"../\")\n",
 "from pathlib import Path\n",
-"from utils_ic.datasets import Urls, unzip_url, data_path\n",
+"from utils_ic.datasets import Urls, unzip_url\n",
 "from fastai.vision import *\n",
-"from fastai.metrics import error_rate, accuracy"
+"from fastai.metrics import accuracy"
 ]
 },
 {
Some file diffs are hidden because one or more lines are too long.
@@ -2,42 +2,46 @@
 # coding: utf-8

 # <i>Copyright (c) Microsoft Corporation. All rights reserved.</i>
 #
 # <i>Licensed under the MIT License.</i>

 # # WebCam Image Classification Quickstart Notebook
 #
 # <br>
 #
 # Image classification is a classical problem in computer vision that determines whether or not the image data contains some specific object, feature, or activity. It is regarded as a mature research area
 # and currently the best models are based on [convolutional neural networks (CNNs)](https://en.wikipedia.org/wiki/Convolutional_neural_network). Such models with weights trained on millions of images and hundreds of object classes in the [ImageNet dataset](http://www.image-net.org/) are available from major deep neural network frameworks such as [CNTK](https://www.microsoft.com/en-us/cognitive-toolkit/features/model-gallery/), [fast.ai](https://docs.fast.ai/vision.models.html#Computer-Vision-models-zoo), [Keras](https://keras.io/applications/), [PyTorch](https://pytorch.org/docs/stable/torchvision/models.html), and [TensorFlow](https://tfhub.dev/s?module-type=image-classification).
 #
 #
 # This notebook shows a simple example of how to load a pretrained model and run it on a webcam stream. Here, we use the [ResNet](https://arxiv.org/abs/1512.03385) model from the `fastai.vision` package.
 #
 # > For more details about image classification tasks including transfer learning (aka fine-tuning), please see our [training introduction notebook](01_training_introduction.ipynb).

 # In[1]:


-get_ipython().run_line_magic('reload_ext', 'autoreload')
-get_ipython().run_line_magic('autoreload', '2')
-get_ipython().run_line_magic('matplotlib', 'inline')
+get_ipython().run_line_magic("reload_ext", "autoreload")
+get_ipython().run_line_magic("autoreload", "2")
+get_ipython().run_line_magic("matplotlib", "inline")


 # In[2]:


 import sys

 sys.path.append("../")
-import io, time, urllib.request
+import io
+import time
+import urllib.request
 import fastai
 from fastai.vision import *
 from ipywebrtc import CameraStream, ImageRecorder
 from ipywidgets import HBox, Label, Layout, Widget
 from torch.cuda import get_device_name
 from utils_ic.constants import IMAGENET_IM_SIZE
-from utils_ic.datasets import imagenet_labels, data_path
+from utils_ic.datasets import imagenet_labels
+from utils_ic.common import data_path
 from utils_ic.imagenet_models import model_to_learner
@@ -46,13 +50,13 @@ print(get_device_name(0))


 # ## 1. Load Pretrained Model
 #
 # We use ResNet18, which is relatively small and fast compared to other CNN models. The [reported error rate](https://pytorch-zh.readthedocs.io/en/latest/torchvision/models.html) of the model on ImageNet is 30.24% for top-1 and 10.92% for top-5<sup>*</sup>.
 #
 # The pretrained model expects input images normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225], which is defined in `fastai.vision.imagenet_stats`.
 #
 # The output of the model is the probability distribution of the classes in ImageNet. To convert them into human-readable labels, we utilize the label json file used by [Keras](https://github.com/keras-team/keras/blob/master/keras/applications/imagenet_utils.py).
 #
 # > \* top-n: *n* labels considered most probable by the model

 # In[3]:
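(Side note, not part of the diff: the `imagenet_stats` normalization mentioned above is just a per-channel shift and scale. A minimal sketch of the operation, assuming a float image tensor `x` with values in [0, 1]:)

```python
import torch

# ImageNet channel statistics quoted in the notebook above.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def normalize(x: torch.Tensor) -> torch.Tensor:
    """Per-channel normalization: subtract the mean, divide by the std."""
    return (x - mean) / std
```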
@@ -66,12 +70,12 @@ print(f"{', '.join(labels[:5])}, ...")
 # In[4]:


 # Convert a pretrained imagenet model into Learner for prediction.
 learn = model_to_learner(models.resnet18(pretrained=True), IMAGENET_IM_SIZE)


 # ## 2. Classify Images
 #
 # ### 2.1 Image file
 # First, we prepare a coffee mug image to show an example of how to score a single image by using the model.

@@ -82,7 +86,7 @@ learn = model_to_learner(models.resnet18(pretrained=True), IMAGENET_IM_SIZE)
 IM_URL = "https://cvbp.blob.core.windows.net/public/images/cvbp_cup.jpg"
 urllib.request.urlretrieve(IM_URL, os.path.join(data_path(), "example.jpg"))

-im = open_image(os.path.join(data_path(), "example.jpg"), convert_mode='RGB')
+im = open_image(os.path.join(data_path(), "example.jpg"), convert_mode="RGB")
 im

@@ -100,8 +104,8 @@ print(f"Took {time.time()-start_time} sec")


 # ### 2.2 WebCam Stream
 #
 # Now, let's use the webcam stream for image classification. We use `ipywebrtc` to start a webcam and feed the video stream into the notebook's widget. For details about `ipywebrtc`, see [this link](https://ipywebrtc.readthedocs.io/en/latest/).

 # In[7]:

@@ -109,38 +113,42 @@ print(f"Took {time.time()-start_time} sec")
 # Webcam
 w_cam = CameraStream(
     constraints={
-        'facing_mode': 'user',
-        'audio': False,
-        'video': { 'width': IMAGENET_IM_SIZE, 'height': IMAGENET_IM_SIZE }
+        "facing_mode": "user",
+        "audio": False,
+        "video": {"width": IMAGENET_IM_SIZE, "height": IMAGENET_IM_SIZE},
     },
-    layout=Layout(width=f'{IMAGENET_IM_SIZE}px')
+    layout=Layout(width=f"{IMAGENET_IM_SIZE}px"),
 )
 # Image recorder for taking a snapshot
-w_imrecorder = ImageRecorder(stream=w_cam, layout=Layout(padding='0 0 0 50px'))
+w_imrecorder = ImageRecorder(stream=w_cam, layout=Layout(padding="0 0 0 50px"))
 # Label widget to show our classification results
-w_label = Label(layout=Layout(padding='0 0 0 50px'))
+w_label = Label(layout=Layout(padding="0 0 0 50px"))


 def classify_frame(_):
     """ Classify an image snapshot by using a pretrained model
     """
     # Once capturing started, remove the capture widget since we don't need it anymore
-    if w_imrecorder.layout.display != 'none':
-        w_imrecorder.layout.display = 'none'
+    if w_imrecorder.layout.display != "none":
+        w_imrecorder.layout.display = "none"

     try:
-        im = open_image(io.BytesIO(w_imrecorder.image.value), convert_mode='RGB')
+        im = open_image(
+            io.BytesIO(w_imrecorder.image.value), convert_mode="RGB"
+        )
         _, ind, prob = learn.predict(im)
         # Show result label and confidence
         w_label.value = f"{labels[ind]} ({prob[ind]:.2f})"
     except OSError:
         # If im_recorder doesn't have valid image data, skip it.
         pass

     # Taking the next snapshot programmatically
     w_imrecorder.recording = True

-# Register classify_frame as a callback. Will be called whenever image.value changes.
-w_imrecorder.image.observe(classify_frame, 'value')
+# Register classify_frame as a callback. Will be called whenever image.value changes.
+w_imrecorder.image.observe(classify_frame, "value")


 # In[8]:

@@ -156,20 +164,16 @@ HBox([w_cam, w_imrecorder, w_label])
 # <img src="https://cvbp.blob.core.windows.net/public/images/cvbp_webcam.png" style="width: 400px;"/>
 # <i>Webcam image classification example</i>
 # </center>
 #
 # <br>
 #
 # In this notebook, we have shown a quickstart example of using a pretrained model to classify images. The model, however, is not able to predict object labels that are not part of ImageNet. Our [training introduction notebook](01_training_introduction.ipynb) shows how to fine-tune the model to address such problems.

 # In[9]:


 # Stop the model and webcam
 Widget.close_all()


 # In[ ]:
@@ -34,9 +34,9 @@ import sys

 sys.path.append("../")
 from pathlib import Path
-from utils_ic.datasets import Urls, unzip_url, data_path
+from utils_ic.datasets import Urls, unzip_url
 from fastai.vision import *
-from fastai.metrics import error_rate, accuracy
+from fastai.metrics import accuracy


 # Set some parameters. We'll use the `unzip_url` helper function to download and unzip our data.
@@ -0,0 +1,182 @@ (new file)
#!/usr/bin/env python
# coding: utf-8

# # Testing different Hyperparameters

# Let's say we want to learn more about __how different learning rates and different image sizes affect our model's accuracy when restricted to 10 epochs__, and we want to build an experiment to test out these hyperparameters.
#
# In this notebook, we'll walk through how to use our Parameter Sweeper module to:
#
# - use python to perform this experiment
# - use the CLI to perform this experiment
# - evaluate the results using Pandas

# ---

# In[1]:


import sys

sys.path.append("../")
import os

from utils_ic.common import ic_root_path
from utils_ic.datasets import unzip_url, Urls
from utils_ic.parameter_sweeper import *


# Let's download some data that we want to test on. To use the Parameter Sweeper tool for single-label classification, we'll need to make sure that the data is stored such that images are sorted into their classes inside a subfolder. In this notebook, we'll use the Fridge Objects dataset provided in `utils_ic.datasets.Urls`, which is stored in the correct format.

# In[2]:


input_data = unzip_url(Urls.fridge_objects_path, exist_ok=True)


# ## Using Python

# We start by creating the Parameter Sweeper object:

# In[3]:


sweeper = ParameterSweeper()


# Before we start testing, it's a good idea to see what the default parameters are. We can use the `parameters` property to see those default values.

# In[4]:


sweeper.parameters


# Now that we know the defaults, we can pass it the parameters we want to test.
#
# In this notebook, we want to see the effect of different learning rates across different image sizes using only 10 epochs (the default number of epochs is 15). To do so, we run the `update_parameters` function as follows:
#
# ```python
# sweeper.update_parameters(learning_rate=[1e-3, 1e-4, 1e-5], im_size=[299, 499], epochs=[10])
# ```
#
# Notice that all parameters must be passed in as a list, including single values such as the number of epochs.
#
# These parameters will be used to calculate the number of permutations to run. In this case, we've passed in three options for learning rates, two for image sizes, and one for the number of epochs. This will result in 3 x 2 x 1 total permutations (in other words, 6 permutations).

# In[5]:


sweeper.update_parameters(
    learning_rate=[1e-3, 1e-4, 1e-5], im_size=[299, 499], epochs=[10]
)


# Now that we have our parameters defined, we call the `run()` function with the dataset to test on.
#
# We can also optionally pass in:
# - the number of repetitions to run each permutation (default is 3)
# - whether or not we want the training to stop early if the metric (accuracy) doesn't improve by 0.01 (1%) over 3 epochs (default is False)
#
# The `run` function returns a multi-index dataframe which we can work with right away.

# In[6]:


df = sweeper.run(datasets=[input_data], reps=3)
df


# ## Using the CLI

# Instead of using python to run this experiment, we may want to test from the CLI. We can do so by using the `scripts/sweep.py` file.
#
# First we move up to the `/image_classification` directory.

# In[7]:


os.chdir(ic_root_path())


# To run a comparable test from the CLI (different learning rates across different image sizes), with the same settings (3 repetitions, no early stopping), we can run the following:
#
# ```sh
# python scripts/sweep.py
#    --learning-rate 1e-3 1e-4 1e-5
#    --im-size 99 299
#    --epoch 5
#    --repeat 3
#    --inputs <my-data-dir>
#    --output lr_bs_test.csv
# ```
#
# Additionally, we've added an output parameter, which will automatically dump our dataframe into a csv file.
#
# To simplify the command, we can use the short forms of the params. Early stopping is off by default, so we don't need to pass `--early-stopping`.
#
# ```sh
# python scripts/sweep.py -lr 1e-3 1e-4 1e-5 -is 99 299 -e 5 -i <my-data-dir> -o lr_bs_test.csv
# ```

# In[8]:


# use {sys.executable} instead of just running `python` to ensure the command is executed using the environment cvbp
get_ipython().system(
    "{sys.executable} scripts/sweep.py -lr 1e-3 1e-4 1e-5 -is 99 299 -e 5 -i {input_data} -o data/lr_bs_test.csv"
)


# Once the script completes, load the csv into a dataframe to explore its contents. We'll want to specify `index_col=[0, 1, 2]` since it is a multi-index dataframe.
#
# ```python
# df = pd.read_csv("data/lr_bs_test.csv", index_col=[0, 1, 2])
# ```

# HINT: You can learn more about how to use the script with the `--help` flag.

# In[14]:


get_ipython().system("{sys.executable} scripts/sweep.py --help")


# ---

# ## Visualizing our results

# When we read in our multi-index dataframe, index 0 represents the run number, index 1 represents a single permutation of parameters, and index 2 represents the dataset.

# To see the results, show the df using the `clean_df` helper function. This will display all the hyperparameters in a nice, readable way.

# In[15]:


df = clean_df(df)
df


# Since we've run our benchmarking over 3 repetitions, we may want to just look at the averages across the different __run numbers__.

# In[16]:


df.mean(level=(1, 2)).T


# Additionally, we may simply want to see which set of hyperparameters performs the best across the different __datasets__. We can do that by averaging the results over the different datasets. (The results of this step will look similar to the above since we're only passing in one dataset.)

# In[17]:


df.mean(level=(1)).T


# To make it easier to see which permutation did the best, we can plot the results using the `plot_df` helper function. This plot will help us easily see which parameters offer the highest accuracies.

# In[18]:


plot_df(df.mean(level=(1)), sort_by="accuracy")
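(Aside, not part of the diff: the permutation count described in the notebook above, 3 learning rates x 2 image sizes x 1 epoch setting = 6, is a plain Cartesian product. A minimal sketch of the same idea with `itertools`, independent of the ParameterSweeper class:)

```python
import itertools

learning_rates = [1e-3, 1e-4, 1e-5]
im_sizes = [299, 499]
epochs = [10]

# Every combination of the three lists; 3 * 2 * 1 == 6 permutations.
permutations = list(itertools.product(learning_rates, im_sizes, epochs))
print(len(permutations))  # 6
```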
@@ -0,0 +1,262 @@ (new file)
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
import time
import shutil

from utils_ic.parameter_sweeper import *
from utils_ic.datasets import data_path
from argparse import RawTextHelpFormatter, Namespace
from pathlib import Path

argparse_desc_msg = """
This script benchmarks different hyperparameters for image classification.

This script will run all permutations of the parameters that are passed in.

This script will run these tests on either:
- an input dataset defined by --input
- a set of benchmarking datasets defined by --benchmark, which will create a
  temporary data directory with all benchmarking datasets loaded into it, and delete it at the end.

This script uses accuracy as the evaluation metric.

Use [-W ignore] to ignore warning messages when running the script.
""".format

argparse_epilog_msg = """
Example usage:
{default_params}

# Test the effect of 3 learning rates on 3 batch sizes
$ python sweep.py -lr 1e-3 1e-4 1e-5 -bs 8 16 32 -i <input_data> -o learning_rate_batch_size.csv

# Test the effect of one cycle policy without using discriminative learning rates over 5 runs
$ python sweep.py -dl False -ocp True False -r 5 -i <input_data> -o ocp_dl.csv

# Test different architectures and image sizes
$ python sweep.py -a squeezenet1_1 resnet18 resnet50 -is 299 499 -i <input_data> -o arch_im_sizes.csv

# Test different training schedules over 3 runs on the benchmark dataset
$ python sweep.py -ts body_only head_first_then_body -r 3 --benchmark -o training_schedule.csv

---

To view results, we recommend using pandas dataframes:

```
import pandas as pd
df = pd.read_csv("results.csv", index_col=[0, 1, 2])
```

""".format

time_msg = """Total Time elapsed: {time} seconds.""".format

output_msg = """Output has been saved to '{output_path}'.""".format


def _str_to_bool(string: str) -> bool:
    """ Convert string to bool. """
    if string.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif string.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def _get_parser(default_params: Dict[str, List[Any]]) -> Namespace:
    """ Get parser for this script. """
    parser = argparse.ArgumentParser(
        description=argparse_desc_msg(),
        epilog=argparse_epilog_msg(default_params=default_params),
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument(
        "--learning-rate",
        "-lr",
        dest="learning_rates",
        nargs="+",
        help="Learning rate - recommended options: [1e-3, 1e-4, 1e-5] ",
        type=float,
    )
    parser.add_argument(
        "--epoch",
        "-e",
        dest="epochs",
        nargs="+",
        help="Epochs - recommended options: [3, 5, 10, 15]",
        type=int,
    )
    parser.add_argument(
        "--batch-size",
        "-bs",
        dest="batch_sizes",
        nargs="+",
        help="Batch sizes - recommended options: [8, 16, 32, 64]",
        type=int,
    )
    parser.add_argument(
        "--im-size",
        "-is",
        dest="im_sizes",
        nargs="+",
        help="Image sizes - recommended options: [299, 499]",
        type=int,
    )
    parser.add_argument(
        "--architecture",
        "-a",
        dest="architectures",
        nargs="+",
        choices=["squeezenet1_1", "resnet18", "resnet34", "resnet50"],
        help="Choose an architecture.",
        type=str,
    )
    parser.add_argument(
        "--transform",
        "-t",
        dest="transforms",
        nargs="+",
        help="Transform (data augmentation) - options: [True, False]",
        type=_str_to_bool,
    )
    parser.add_argument(
        "--dropout",
        "-d",
        dest="dropouts",
        nargs="+",
        help="Dropout - recommended options: [0.5]",
        type=float,
    )
    parser.add_argument(
        "--weight-decay",
        "-wd",
        dest="weight_decays",
        nargs="+",
        help="Weight decay - recommended options: [0.01]",
        type=float,
    )
    parser.add_argument(
        "--training-schedule",
        "-ts",
        dest="training_schedules",
        nargs="+",
        choices=["head_only", "body_only", "head_first_then_body"],
        help="Choose a training schedule",
        type=str,
    )
    parser.add_argument(
        "--discriminative-lr",
        "-dl",
        dest="discriminative_lrs",
        nargs="+",
        help="Discriminative learning rate - options: [True, False]. To use discriminative learning rates, training schedule must not be 'head_only'",
        choices=[True, False],  # compare against the parsed booleans, not the raw strings
        type=_str_to_bool,
    )
    parser.add_argument(
        "--one-cycle-policy",
        "-ocp",
        dest="one_cycle_policies",
        nargs="+",
        help="one cycle policy - options: [True, False]",
        type=_str_to_bool,
    )
    i_parser = parser.add_mutually_exclusive_group(required=True)
    i_parser.add_argument(
        "--inputs",
        "-i",
        dest="inputs",
        nargs="+",
        help="A list of data paths to run the tests on. The datasets must be structured so that each class is in a separate folder. <--benchmark> must be False",
        type=str,
    )
    i_parser.add_argument(
        "--benchmark",
        dest="benchmark",
        action="store_true",
        help="Whether or not to use curated benchmark datasets to test. <--input> must be empty",
    )
    parser.add_argument(
        "--early-stopping",
        dest="early_stopping",
        action="store_true",
        help="Stop training early if possible",
    )
    parser.add_argument(
        "--repeat",
        "-r",
        dest="repeat",
        help="The number of times to repeat each permutation",
        type=int,
    )
    parser.add_argument(
        "--output", "-o", dest="output", help="The path of the output file."
    )
    parser.set_defaults(
        repeat=3, early_stopping=False, inputs=None, benchmark=False
    )
    args = parser.parse_args()

    # if discriminative lr is on, we cannot have a 'head_only'
    # training_schedule
    if args.discriminative_lrs is not None and True in args.discriminative_lrs:
        assert "head_only" not in args.training_schedules

    # get mapping of architecture enum: ex. "resnet34" -->
    # Architecture.resnet34 --> models.resnet34
    if args.architectures is not None:
        args.architectures = [Architecture[a] for a in args.architectures]

    # get mapping of training enum: ex. "head_only" -->
    # TrainingSchedule.head_only --> 0
    if args.training_schedules is not None:
        args.training_schedules = [
            TrainingSchedule[t] for t in args.training_schedules
        ]

    return args


if __name__ == "__main__":

    start = time.time()
    sweeper = ParameterSweeper()
    args = _get_parser(sweeper.parameters)

    sweeper.update_parameters(
        learning_rate=args.learning_rates,
        epochs=args.epochs,
        batch_size=args.batch_sizes,
        im_size=args.im_sizes,
        architecture=args.architectures,
        transform=args.transforms,
        dropout=args.dropouts,
        weight_decay=args.weight_decays,
        training_schedule=args.training_schedules,
        discriminative_lr=args.discriminative_lrs,
        one_cycle_policy=args.one_cycle_policies,
    )

    data = args.inputs
    if not data:
        data = ParameterSweeper.download_benchmark_datasets(
            Path(data_path()) / "benchmark_data"
        )

    df = sweeper.run(
        datasets=data, reps=args.repeat, early_stopping=args.early_stopping
    )
    df.to_csv(args.output)

    if args.benchmark:
        # clean up the temporary benchmark data that was downloaded above
        for path in data:
            shutil.rmtree(path)

    end = time.time()
    print(time_msg(time=round(end - start, 1)))
    print(output_msg(output_path=os.path.realpath(args.output)))
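(Aside, not part of the diff: a short sketch of invoking the script above programmatically and loading its output, in the same spirit as the notebook. The dataset and CSV paths are placeholders, not files shipped with this commit.)

```python
import subprocess
import sys

import pandas as pd

# Run a tiny sweep; the flags are the ones defined by _get_parser() above.
subprocess.run(
    [sys.executable, "scripts/sweep.py", "-lr", "1e-4", "-e", "5",
     "-i", "data/fridgeObjects", "-o", "data/sweep_results.csv"],
    check=True,
)

# The script writes a multi-index dataframe, hence index_col=[0, 1, 2].
df = pd.read_csv("data/sweep_results.csv", index_col=[0, 1, 2])
```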
@@ -0,0 +1,3 @@ (new file)
from pathlib import Path

TEMP_DIR = Path("tmp_data")
@@ -0,0 +1,99 @@ (new file)
import os
import pytest
import shutil
import pandas as pd
from pathlib import Path
from utils_ic.datasets import Urls, unzip_url
from utils_ic.parameter_sweeper import *
from constants import TEMP_DIR


def cleanup_data():
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)


@pytest.fixture(scope="module")
def setup_all_datasets(request):
    """ Sets up all available datasets for testing on. """
    ParameterSweeper.download_benchmark_datasets(TEMP_DIR)
    request.addfinalizer(cleanup_data)


@pytest.fixture(scope="module")
def setup_a_dataset(request):
    """ Sets up a dataset for testing on. """
    os.makedirs(TEMP_DIR)
    unzip_url(Urls.fridge_objects_path, TEMP_DIR, exist_ok=True)
    request.addfinalizer(cleanup_data)


def _test_sweeper_run(df: pd.DataFrame, df_length: int):
    """ Performs basic tests that all df should pass.
    Args:
        df (pd.DataFrame): the df to check
        df_length (int): to assert the len(df) == df_length
    """
    # assert len
    assert len(df) == df_length
    # assert df is a multi-index dataframe
    assert isinstance(df.index, pd.core.index.MultiIndex)
    # assert clean_df works
    df = clean_df(df)
    assert isinstance(df.index, pd.core.index.MultiIndex)
    # assert no error when calling plot_df function
    plot_df(df)


def test_default_sweeper_single_dataset(setup_a_dataset):
    """ Test default sweeper on a single dataset. """
    fridge_objects_path = TEMP_DIR / "fridgeObjects"
    sweeper = ParameterSweeper()
    df = sweeper.run([fridge_objects_path])
    _test_sweeper_run(df, df_length=3)

    # assert accuracy over 3 runs is > 85%
    assert df.mean(level=(1))["accuracy"][0] > 0.85


def test_default_sweeper_benchmark_dataset(setup_all_datasets):
    """
    Test default sweeper on benchmark dataset.
    WARNING: This test can take a while to execute since we run the sweeper
    across all benchmark datasets.
    """
    datasets = [Path(d) for d in os.scandir(TEMP_DIR) if os.path.isdir(d)]
    sweeper = ParameterSweeper()
    df = sweeper.run(datasets, reps=1)
    _test_sweeper_run(df, df_length=len(datasets))

    # assert min accuracy for each dataset
    assert df.mean(level=(2)).loc["fridgeObjects", "accuracy"] > 0.85
    assert df.mean(level=(2)).loc["food101Subset", "accuracy"] > 0.75
    assert df.mean(level=(2)).loc["fashionTexture", "accuracy"] > 0.70
    assert df.mean(level=(2)).loc["flickrLogos32Subset", "accuracy"] > 0.75
    assert df.mean(level=(2)).loc["lettuce", "accuracy"] > 0.70
    assert df.mean(level=(2)).loc["recycle_v3", "accuracy"] > 0.85


def test_update_parameters_01(setup_a_dataset):
    """ Tests updating parameters. """
    fridge_objects_path = TEMP_DIR / "fridgeObjects"
    sweeper = ParameterSweeper()

    # at this point there should only be 1 permutation of the default params
    assert len(sweeper.permutations) == 1
    sweeper.update_parameters(
        learning_rate=[1e-3, 1e-4, 1e-5], im_size=[299, 499], epochs=[5]
    )
    # assert that there are now 6 permutations
    assert len(sweeper.permutations) == 6
    df = sweeper.run([fridge_objects_path])
    _test_sweeper_run(df, df_length=18)


def test_update_parameters_02(setup_a_dataset):
    """ Tests exception when updating parameters. """
    sweeper = ParameterSweeper()
    with pytest.raises(Exception):
        sweeper.update_parameters(bad_key=[1e-3, 1e-4, 1e-5])
@@ -21,7 +21,7 @@ def make_temp_data_dir(request):


 def _test_url_data(url: str, path: Union[Path, str], dir_name: str):
-    data_path = unzip_url(url, fpath=path, dest=path, overwrite=True)
+    data_path = unzip_url(url, fpath=path, dest=path, exist_ok=True)
     # assert zip file exists
     assert os.path.exists(os.path.join(path, f"{dir_name}.zip"))
     # assert unzipped file (titled {dir_name}) exists

@@ -48,25 +48,6 @@ def test_unzip_url_abs_path(make_temp_data_dir):
     _test_url_data(Urls.recycle_path, abs_path, "recycle_v3")


-def test_unzip_url_overwrite(make_temp_data_dir):
-    """ Test if overwrite is true and file exists """
-
-    # test overwrite=True
-    os.makedirs(TEMP_DIR / "fridgeObjects")
-    fridge_objects_path = unzip_url(
-        Urls.fridge_objects_path, TEMP_DIR, overwrite=True
-    )
-    assert os.path.realpath(TEMP_DIR / "fridgeObjects") == os.path.realpath(
-        fridge_objects_path
-    )
-    assert len(os.listdir(fridge_objects_path)) >= 0
-
-    # test file exists error when overwrite=False
-    os.makedirs(TEMP_DIR / "lettuce")
-    with pytest.raises(FileExistsError):
-        unzip_url(Urls.lettuce_path, TEMP_DIR, overwrite=False)
-
-
 def test_unzip_url_exist_ok(make_temp_data_dir):
     """
     Test if exist_ok is true and (file exists, file does not exist)
@@ -0,0 +1,14 @@ (new file)
import os
from pathlib import Path


def ic_root_path() -> Path:
    """Get the image classification root path"""
    return os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir))


def data_path() -> Path:
    """Get the data directory path"""
    return os.path.realpath(
        os.path.join(os.path.dirname(__file__), os.pardir, "data")
    )
@@ -1,6 +1,6 @@
 import os
 import requests
-import shutil
+from .common import data_path
 from pathlib import Path
 from typing import List, Union
 from urllib.parse import urljoin, urlparse

@@ -19,6 +19,7 @@ class Urls:
     # datasets
     fridge_objects_path = urljoin(base, "fridgeObjects.zip")
     food_101_subset_path = urljoin(base, "food101Subset.zip")
+    fashion_texture_path = urljoin(base, "fashionTexture.zip")
     flickr_logos_32_subset_path = urljoin(base, "flickrLogos32Subset.zip")
     lettuce_path = urljoin(base, "lettuce.zip")
     recycle_path = urljoin(base, "recycle_v3.zip")

@@ -38,15 +39,8 @@ def imagenet_labels() -> list:
     return [labels[str(k)][1] for k in range(len(labels))]


-def data_path() -> Path:
-    """Get the data path"""
-    return os.path.realpath(
-        os.path.join(os.path.dirname(__file__), os.pardir, "data")
-    )
-
-
 def _get_file_name(url: str) -> str:
-    """Get a file name based on url"""
+    """ Get a file name based on url. """
     return urlparse(url).path.split("/")[-1]


@@ -55,15 +49,21 @@ def unzip_url(
     fpath: Union[Path, str] = data_path(),
     dest: Union[Path, str] = data_path(),
     exist_ok: bool = False,
-    overwrite: bool = False,
 ) -> Path:
-    """
-    Download file from URL to {fpath} and unzip to {dest}.
-    {fpath} and {dest} must be directories
-    Params:
-        exist_ok: if exist_ok, then skip if exists, otherwise throw error
-        overwrite: if overwrite, remove zipped file and unziped dir
-    Returns path of {dest}
-    """
+    """ Download file from URL to {fpath} and unzip to {dest}.
+    {fpath} and {dest} must be directories
+
+    Args:
+        url (str): url to download from
+        fpath (Union[Path, str]): The location to save the url zip file to
+        dest (Union[Path, str]): The destination to unzip {fpath}
+        exist_ok (bool): if exist_ok, then skip if exists, otherwise throw error
+
+    Raises:
+        FileExistsError: if file exists
+
+    Returns:
+        Path of {dest}
+    """

     def _raise_file_exists_error(path: Union[Path, str]) -> None:

@@ -78,16 +78,6 @@ def unzip_url(
     zip_file = Path(os.path.join(fpath, fname))
     unzipped_dir = Path(os.path.join(fpath, fname_without_extension))

-    if overwrite:
-        try:
-            os.remove(zip_file)
-        except OSError as e:
-            pass
-        try:
-            shutil.rmtree(unzipped_dir)
-        except OSError as e:
-            pass
-
     # download zipfile if zipfile not exists
     if zip_file.is_file():
         _raise_file_exists_error(zip_file)

@@ -106,3 +96,20 @@ def unzip_url(
     z.close()

     return os.path.realpath(os.path.join(fpath, fname_without_extension))
+
+
+def unzip_urls(
+    urls: List[Url], dest: Union[Path, str] = data_path()
+) -> List[Path]:
+    """ Download and unzip all datasets in Urls to dest """
+
+    # make dir if not exist
+    if not Path(dest).is_dir():
+        os.makedirs(dest)
+
+    # download all data urls
+    paths = list()
+    for url in urls:
+        paths.append(unzip_url(url, dest, exist_ok=True))
+
+    return paths
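(Aside, not part of the diff: a small usage sketch of the new `unzip_urls` helper added above. The destination directory is an arbitrary example path.)

```python
from utils_ic.datasets import Urls, unzip_urls

# Download and extract two of the datasets defined in Urls; returns the unzipped paths.
paths = unzip_urls([Urls.fridge_objects_path, Urls.lettuce_path], dest="data/example_dir")
```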
@@ -0,0 +1,436 @@ (new file)
import itertools
import pandas as pd
import re
import time

from utils_ic.datasets import Urls, data_path, unzip_urls
from collections import OrderedDict
from fastai.vision import *
from fastai.callbacks import EarlyStoppingCallback
from fastai.metrics import accuracy
from functools import partial
from matplotlib.axes import Axes
from typing import Union, List, Any, Dict
from pathlib import Path

Time = float
parameter_flag = "PARAMETERS"


class TrainingSchedule(Enum):
    head_only = ("head_only",)
    body_only = ("body_only",)
    head_first_then_body = "head_first_then_body"


class Architecture(Enum):
    resnet18 = partial(models.resnet18)
    resnet34 = partial(models.resnet34)
    resnet50 = partial(models.resnet50)
    squeezenet1_1 = partial(models.squeezenet1_1)


def clean_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Cleans up experiment parameter strings in {df} by removing all experiment
    parameters that held constant through each experiment. This method uses a
    variable <parameter_flag> to search for strings.
    Args:
        df (pd.DataFrame): dataframe to clean up
    Return:
        pd.DataFrame: df with renamed experiment parameter strings
    """
    text = df.to_html()
    text = re.findall(fr">\s{{0,1}}{parameter_flag}\s{{0,1}}(.*?)</th>", text)

    sets = [set(t.split("|")) for t in text]
    intersection = sets[0].intersection(*sets)

    html = df.to_html()
    for i in intersection:
        html = html.replace(i, "")
    html = html.replace("PARAMETERS", "P:")
    html = html.replace("|", " ")

    return pd.read_html(html, index_col=[0, 1, 2])[0]


def plot_df(
    df: pd.DataFrame,
    sort_by: str = "accuracy",
    figsize: Tuple[int, int] = (12, 8),
) -> None:
    """
    Visualize graph from {df}, which must contain columns "accuracy" and
    "duration".
    Args:
        df (pd.DataFrame): the dataframe to visualize.
        sort_by (str): whether to sort visualization by accuracy or duration.
        figsize (Tuple[int, int]): as defined in matplotlib.
    Raises:
        ValueError: if {sort_by} is an invalid value.
    """
    if sort_by not in ("accuracy", "duration"):
        raise ValueError(f"{sort_by} must equal 'accuracy' or 'duration'")

    def add_value_labels(
        ax: Axes, spacing: int = 5, percentage: bool = False
    ) -> None:
        """
        Add labels to the end of each bar in a bar chart.
        Args:
            ax (Axes): The matplotlib object containing the axes of the plot to annotate.
            spacing (int): The distance between the labels and the bars.
            percentage (bool): if y-value is a percentage
        """
        for rect in ax.patches:
            y_value = rect.get_height()
            x_value = rect.get_x() + rect.get_width() / 2

            label = (
                "{:.2f}%".format(y_value * 100)
                if percentage
                else "{:.1f}".format(y_value)
            )

            ax.annotate(
                label,
                (x_value, y_value),
                xytext=(0, spacing),  # Vertically shift label by `space`
                textcoords="offset points",  # Interpret `xytext` as offset in points
                ha="center",  # Horizontally center label
                va="bottom",  # Vertically align label
            )

    top_accuracy = df["accuracy"].max()
    top_duration = df["duration"].max()
    ax1, ax2 = df.sort_values(by=sort_by).plot.bar(
        rot=90, subplots=True, legend=False, figsize=figsize
    )
    ax1.set_title("Duration (seconds)")
    ax2.set_title("Accuracy (%)")
    ax1.set_ylabel("seconds")
    ax2.set_ylabel("%")
    ax1.set_ylim(top=top_duration * 1.2)
    ax2.set_ylim(top=top_accuracy * 1.2)
    add_value_labels(ax2, percentage=True)
    add_value_labels(ax1)


class ParameterSweeper:
    """ Test different permutations of a set of parameters.

    Attributes:
        param_order <Tuple[str]>: A fixed ordering of parameters (to match the ordering of <params>)
        default_params <Dict[str, Any]>: A dict of default parameters
        params <Dict[str, List[Any]]>: The parameters to run experiments on
    """

    default_params = dict(
        learning_rate=1e-4,
        epoch=15,
        batch_size=16,
        im_size=299,
        architecture=Architecture.resnet18,
        transform=True,
        dropout=0.5,
        weight_decay=0.01,
        training_schedule=TrainingSchedule.head_first_then_body,
        discriminative_lr=False,
        one_cycle_policy=True,
    )

    def __init__(self, **kwargs) -> None:
        """
        Initialize class with default params if kwargs is empty.
        Otherwise, initialize params with kwargs.
        """
        self.params = OrderedDict(
            learning_rate=[self.default_params.get("learning_rate")],
            epochs=[self.default_params.get("epoch")],
            batch_size=[self.default_params.get("batch_size")],
            im_size=[self.default_params.get("im_size")],
            architecture=[self.default_params.get("architecture")],
            transform=[self.default_params.get("transform")],
            dropout=[self.default_params.get("dropout")],
            weight_decay=[self.default_params.get("weight_decay")],
            training_schedule=[self.default_params.get("training_schedule")],
            discriminative_lr=[self.default_params.get("discriminative_lr")],
            one_cycle_policy=[self.default_params.get("one_cycle_policy")],
        )

        self.param_order = tuple(self.params.keys())
        self.update_parameters(**kwargs)

    @property
    def parameters(self) -> Dict[str, Any]:
        """ Returns parameters to test on if run() is called. """
        return self.params

    @property
    def permutations(self) -> List[Tuple[Any]]:
        """ Returns a list of all permutations, expressed in tuples. """
        params = tuple([self.params[k] for k in self.param_order])
        permutations = list(itertools.product(*params))
        return permutations

    @staticmethod
    def _get_data_bunch(
        path: Union[Path, str], transform: bool, im_size: int, bs: int
    ) -> ImageDataBunch:
        """
        Create ImageDataBunch and return it. TODO in future version is to allow
        users to pass in their own image bunch or their own Transformation
        objects (instead of using fastai's <get_transforms>)

        Args:
            path (Union[Path, str]): path to data to create databunch with
            transform (bool): a flag to set fastai default transformations (get_transforms())
            im_size (int): image size of databunch
            bs (int): batch size of databunch
        Returns:
            ImageDataBunch
        """
        path = path if type(path) is Path else Path(path)
        tfms = get_transforms() if transform else None
        return (
            ImageList.from_folder(path)
            .split_by_rand_pct(valid_pct=0.33)
            .label_from_folder()
            .transform(tfms=tfms, size=im_size)
            .databunch(bs=bs)
            .normalize(imagenet_stats)
        )

    @staticmethod
    def _early_stopping_callback(
        metric: str = "accuracy", min_delta: float = 0.01, patience: int = 3
    ) -> partial:
        """ Returns an early stopping callback. """
        return partial(
            EarlyStoppingCallback,
            monitor="accuracy",
            min_delta=0.01,  # conservative
            patience=3,
        )

    @staticmethod
    def _serialize_permutations(p: Tuple[Any]) -> str:
        """ Serializes all parameters as a string that uses {parameter_flag}. """
        p = iter(p)
        return (
            f"{parameter_flag} "
            f"[learning_rate: {next(p)}]|[epochs: {next(p)}]|[batch_size: {next(p)}]|"
            f"[im_size: {next(p)}]|[arch: {next(p).name}]|"
            f"[transforms: {next(p)}]|[dropout: {next(p)}]|"
            f"[weight_decay: {next(p)}]|[training_schedule: {next(p).name}]|"
            f"[discriminative_lr: {next(p)}]|[one_cycle_policy: {next(p)}]"
        )

    @staticmethod
    def _make_df_from_dict(
        results: Dict[Any, Dict[Any, Dict[Any, Dict[Any, Any]]]]
    ) -> pd.DataFrame:
        """ Converts a 4-times-nested dictionary into a multi-index dataframe. """
        return pd.DataFrame.from_dict(
            {
                (i, j, k): results[i][j][k]
                for i in results.keys()
                for j in results[i].keys()
                for k in results[i][j].keys()
            },
            orient="index",
        )

    def _param_tuple_to_dict(self, params: Tuple[Any]) -> Dict[str, Any]:
        """ Converts a tuple of parameters to a Dict. """
        return dict(
            learning_rate=params[self.param_order.index("learning_rate")],
            batch_size=params[self.param_order.index("batch_size")],
            transform=params[self.param_order.index("transform")],
            im_size=params[self.param_order.index("im_size")],
            epochs=params[self.param_order.index("epochs")],
            architecture=params[self.param_order.index("architecture")],
            dropout=params[self.param_order.index("dropout")],
            weight_decay=params[self.param_order.index("weight_decay")],
            discriminative_lr=params[
                self.param_order.index("discriminative_lr")
            ],
            training_schedule=params[
                self.param_order.index("training_schedule")
            ],
            one_cycle_policy=params[
                self.param_order.index("one_cycle_policy")
            ],
        )

    @classmethod
    def download_benchmark_datasets(
        cls, dest: Union[Path, str] = data_path()
    ) -> List[Path]:
        """ Download benchmark datasets to {dest}. """
        benchmark_urls = [
            Urls.fridge_objects_path,
            Urls.fashion_texture_path,
            Urls.flickr_logos_32_subset_path,
            Urls.food_101_subset_path,
            Urls.lettuce_path,
            Urls.recycle_path,
        ]
        return unzip_urls(benchmark_urls, dest)

    def _learn(
        self, data_path: Path, params: Tuple[Any], stop_early: bool
    ) -> Tuple[Learner, Time]:
        """
        Given a set of permutations, create a learner to train and validate on
        the dataset.
        Args:
            data_path (Path): The location of the data to use
            params (Tuple[Any]): The set of parameters to train and validate on
            stop_early (bool): Whether or not to stop early if the evaluation
            metric does not improve
        Returns:
            Tuple[Learner, Time]: Learn object from Fastai and the duration in
            seconds it took.
        """
        start = time.time()
        params = self._param_tuple_to_dict(params)

        transform = params["transform"]
        im_size = params["im_size"]
        epochs = params["epochs"]
        batch_size = params["batch_size"]
        architecture = params["architecture"]
        dropout = params["dropout"]
        learning_rate = params["learning_rate"]
        discriminative_lr = params["discriminative_lr"]
        training_schedule = params["training_schedule"]
        one_cycle_policy = params["one_cycle_policy"]
        weight_decay = params["weight_decay"]

        data = self._get_data_bunch(data_path, transform, im_size, batch_size)

        callbacks = list()
        if stop_early:
            callbacks.append(self._early_stopping_callback())

        learn = cnn_learner(
            data,
            architecture.value,
            metrics=accuracy,
            ps=dropout,
            callback_fns=callbacks,
        )

        head_learning_rate = learning_rate
        body_learning_rate = (
            slice(learning_rate, 3e-3) if discriminative_lr else learning_rate
        )

        def fit(
            learn: Learner, e: int, lr: Union[slice, float], wd=float
        ) -> partial:
            """ Returns a partial func for either fit_one_cycle or fit
            depending on <one_cycle_policy> """
            return (
                partial(learn.fit_one_cycle, cyc_len=e, max_lr=lr, wd=wd)
                if one_cycle_policy
                else partial(learn.fit, epochs=e, lr=lr, wd=wd)
            )

        if training_schedule is TrainingSchedule.head_only:
            if discriminative_lr:
                raise Exception(
                    "Cannot run discriminative_lr if training schedule is head_only."
                )
            else:
                fit(learn, epochs, body_learning_rate, weight_decay)()

        elif training_schedule is TrainingSchedule.body_only:
            learn.unfreeze()
            fit(learn, epochs, body_learning_rate, weight_decay)()

        elif training_schedule is TrainingSchedule.head_first_then_body:
            head_epochs = epochs // 4
            fit(learn, head_epochs, head_learning_rate, weight_decay)()
            learn.unfreeze()
            fit(
                learn, epochs - head_epochs, body_learning_rate, weight_decay
            )()

        end = time.time()
        duration = end - start

        return learn, duration

    def update_parameters(self, **kwargs) -> None:
        """ Update the class object's parameters.
        If kwarg key is not in an existing param key, then raise exception.
        If the kwarg value is None, pass.
        Otherwise overwrite the corresponding self.params key.
        """
        for k, v in kwargs.items():
            if k not in self.params.keys():
                raise Exception(f"Parameter {k} is invalid.")
            if v is None:
                continue
            self.params[k] = v

    def run(
        self, datasets: List[Path], reps: int = 3, early_stopping: bool = False
    ) -> pd.DataFrame:
        """ Performs the experiment.
        Iterates through the number of specified <reps>, the list of permutations
        as defined in this class, and the <datasets> to calculate evaluation
        metrics and duration for each run.

        WARNING: this method can take a long time depending on your experiment
        definition.

        Args:
            datasets (List[Path]): A list of datasets to iterate over.
            reps (int): The number of runs to loop over.
            early_stopping (bool): Whether we want to perform early stopping.
        Returns:
            pd.DataFrame: a multi-index dataframe with the results stored in it.
        """

        res = dict()
        for rep in range(reps):

            res[rep] = dict()
            for i, permutation in enumerate(self.permutations):
                print(
                    f"Running {i+1} of {len(self.permutations)} permutations. "
                    f"Repeat {rep+1} of {reps}."
                )

                stringified_permutation = self._serialize_permutations(
                    permutation
                )
                res[rep][stringified_permutation] = dict()
                for dataset in datasets:

                    data_name = os.path.basename(dataset)

                    res[rep][stringified_permutation][data_name] = dict()

                    learn, duration = self._learn(
                        dataset, permutation, early_stopping
                    )

                    _, metric = learn.validate(
                        learn.data.valid_dl, metrics=[accuracy]
                    )

                    res[rep][stringified_permutation][data_name][
                        "duration"
                    ] = duration
                    res[rep][stringified_permutation][data_name][
                        "accuracy"
                    ] = float(metric)

                    learn.destroy()

        return self._make_df_from_dict(res)