зеркало из https://github.com/microsoft/torchgeo.git
0. Added notebook for visualization
1. Converted some of the features in _load_features from strs to ints
This commit is contained in:
Родитель
e1a889c4c0
Коммит
6a0e3aacd2
|
@ -0,0 +1,141 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "indirect-delivery",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"import sys\n",
|
||||
"sys.path.append(\"..\")\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"from torchgeo.datasets import TropicalCycloneWindEstimation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "advanced-gauge",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ROOT_DIR = os.path.expanduser(\"~/mount/data/\")\n",
|
||||
"API_KEY = \"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "tired-malta",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Benchmarking"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bigger-archive",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"# takes ~3 minutes and 20 seconds to download train split to SSD\n",
|
||||
"# takes ~9 seconds to verify checksum\n",
|
||||
"train_dataset = TropicalCycloneWindEstimation(\n",
|
||||
" ROOT_DIR,\n",
|
||||
" split=\"train\",\n",
|
||||
" download=True,\n",
|
||||
" api_key=API_KEY,\n",
|
||||
" checksum=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "reverse-bridges",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"test_dataset = TropicalCycloneWindEstimation(\n",
|
||||
" ROOT_DIR,\n",
|
||||
" split=\"test\",\n",
|
||||
" download=True,\n",
|
||||
" api_key=API_KEY,\n",
|
||||
" checksum=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "invisible-aurora",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"len(train_dataset), len(test_dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "banner-consultancy",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Visualization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "swiss-gates",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in np.random.choice(len(train_dataset), size=10):\n",
|
||||
" \n",
|
||||
" sample = train_dataset[i]\n",
|
||||
" img = sample[\"image\"]\n",
|
||||
" wind_speed = sample[\"wind_speed\"]\n",
|
||||
" \n",
|
||||
" plt.figure(figsize=(5,5))\n",
|
||||
" plt.imshow(img, cmap=\"Greys_r\", vmin=0, vmax=255)\n",
|
||||
" plt.axis(\"off\")\n",
|
||||
" plt.title(\"Windspeed: %d\" % (wind_speed))\n",
|
||||
" plt.show()\n",
|
||||
" plt.close()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "landcover",
|
||||
"language": "python",
|
||||
"name": "conda-env-landcover-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
|
@ -140,14 +140,13 @@ class TropicalCycloneWindEstimation(VisionDataset):
|
|||
Returns:
|
||||
the image
|
||||
"""
|
||||
print(directory)
|
||||
filename = os.path.join(directory.format("source"), "image.jpg")
|
||||
with Image.open(filename) as img:
|
||||
array = np.array(img)
|
||||
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
def _load_features(self, directory: str) -> Dict[str, str]:
|
||||
def _load_features(self, directory: str) -> Dict[str, Any]:
|
||||
"""Load features for a single image.
|
||||
|
||||
Parameters:
|
||||
|
@ -158,12 +157,16 @@ class TropicalCycloneWindEstimation(VisionDataset):
|
|||
"""
|
||||
filename = os.path.join(directory.format("source"), "features.json")
|
||||
with open(filename) as f:
|
||||
features: Dict[str, str] = json.load(f)
|
||||
features: Dict[str, Any] = json.load(f)
|
||||
|
||||
filename = os.path.join(directory.format("labels"), "labels.json")
|
||||
with open(filename) as f:
|
||||
features.update(json.load(f))
|
||||
|
||||
features["relative_time"] = int(features["relative_time"])
|
||||
features["ocean"] = int(features["ocean"])
|
||||
features["wind_speed"] = int(features["wind_speed"])
|
||||
|
||||
return features
|
||||
|
||||
def _check_integrity(self) -> bool:
|
||||
|
|
Загрузка…
Ссылка в новой задаче