0. Added a notebook for benchmarking and visualizing the TropicalCycloneWindEstimation dataset

1. Converted the relative_time, ocean, and wind_speed features in _load_features from strings to ints
This commit is contained in:
Caleb Robinson 2021-06-23 23:32:11 +00:00 коммит произвёл Adam J. Stewart
Родитель e1a889c4c0
Коммит 6a0e3aacd2
2 изменённых файлов: 147 добавлений и 3 удалений

Просмотреть файл

@ -0,0 +1,141 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "indirect-delivery",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"import sys\n",
"sys.path.append(\"..\")\n",
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from torchgeo.datasets import TropicalCycloneWindEstimation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "advanced-gauge",
"metadata": {},
"outputs": [],
"source": [
"ROOT_DIR = os.path.expanduser(\"~/mount/data/\")\n",
"API_KEY = \"\""
]
},
{
"cell_type": "markdown",
"id": "tired-malta",
"metadata": {},
"source": [
"## Benchmarking"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bigger-archive",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"# takes ~3 minutes and 20 seconds to download train split to SSD\n",
"# takes ~9 seconds to verify checksum\n",
"train_dataset = TropicalCycloneWindEstimation(\n",
" ROOT_DIR,\n",
" split=\"train\",\n",
" download=True,\n",
" api_key=API_KEY,\n",
" checksum=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reverse-bridges",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"test_dataset = TropicalCycloneWindEstimation(\n",
" ROOT_DIR,\n",
" split=\"test\",\n",
" download=True,\n",
" api_key=API_KEY,\n",
" checksum=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "invisible-aurora",
"metadata": {},
"outputs": [],
"source": [
"len(train_dataset), len(test_dataset)"
]
},
{
"cell_type": "markdown",
"id": "banner-consultancy",
"metadata": {},
"source": [
"## Visualization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "swiss-gates",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"for i in np.random.choice(len(train_dataset), size=10):\n",
" \n",
" sample = train_dataset[i]\n",
" img = sample[\"image\"]\n",
" wind_speed = sample[\"wind_speed\"]\n",
" \n",
" plt.figure(figsize=(5,5))\n",
" plt.imshow(img, cmap=\"Greys_r\", vmin=0, vmax=255)\n",
" plt.axis(\"off\")\n",
" plt.title(\"Windspeed: %d\" % (wind_speed))\n",
" plt.show()\n",
" plt.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "landcover",
"language": "python",
"name": "conda-env-landcover-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

Просмотреть файл

@ -140,14 +140,13 @@ class TropicalCycloneWindEstimation(VisionDataset):
Returns:
the image
"""
print(directory)
filename = os.path.join(directory.format("source"), "image.jpg")
with Image.open(filename) as img:
array = np.array(img)
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor
def _load_features(self, directory: str) -> Dict[str, str]:
def _load_features(self, directory: str) -> Dict[str, Any]:
"""Load features for a single image.
Parameters:
@ -158,12 +157,16 @@ class TropicalCycloneWindEstimation(VisionDataset):
"""
filename = os.path.join(directory.format("source"), "features.json")
with open(filename) as f:
features: Dict[str, str] = json.load(f)
features: Dict[str, Any] = json.load(f)
filename = os.path.join(directory.format("labels"), "labels.json")
with open(filename) as f:
features.update(json.load(f))
features["relative_time"] = int(features["relative_time"])
features["ocean"] = int(features["ocean"])
features["wind_speed"] = int(features["wind_speed"])
return features
def _check_integrity(self) -> bool: