зеркало из https://github.com/microsoft/torchgeo.git
Remove dependency on cartopy (#51)
* Remove dependency on cartopy * Remove extent and bbox
This commit is contained in:
Родитель
f507b08252
Коммит
72d667687b
|
@ -43,7 +43,6 @@ jobs:
|
||||||
python-version: 3.9
|
python-version: 3.9
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get install libgeos-dev libproj-dev # needed for cartopy
|
|
||||||
pip install cython numpy # needed for pycocotools
|
pip install cython numpy # needed for pycocotools
|
||||||
pip install --pre -r requirements.txt
|
pip install --pre -r requirements.txt
|
||||||
- name: Run sphinx checks
|
- name: Run sphinx checks
|
||||||
|
|
|
@ -29,7 +29,6 @@ jobs:
|
||||||
python-version: 3.9
|
python-version: 3.9
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get install libgeos-dev libproj-dev # needed for cartopy
|
|
||||||
pip install cython numpy # needed for pycocotools
|
pip install cython numpy # needed for pycocotools
|
||||||
pip install --pre -r requirements.txt
|
pip install --pre -r requirements.txt
|
||||||
- name: Run mypy checks
|
- name: Run mypy checks
|
||||||
|
@ -37,8 +36,6 @@ jobs:
|
||||||
pytest:
|
pytest:
|
||||||
name: pytest
|
name: pytest
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
env:
|
|
||||||
PKG_CONFIG_PATH: /usr/local/opt/proj@7/lib/pkgconfig
|
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||||
|
@ -67,10 +64,10 @@ jobs:
|
||||||
key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}
|
key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}
|
||||||
restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-
|
restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-
|
||||||
- name: Install apt dependencies
|
- name: Install apt dependencies
|
||||||
run: sudo apt-get install libgeos-dev libproj-dev unrar
|
run: sudo apt-get install unrar
|
||||||
if: ${{ runner.os == 'Linux' }}
|
if: ${{ runner.os == 'Linux' }}
|
||||||
- name: Install brew dependencies
|
- name: Install brew dependencies
|
||||||
run: brew install geos proj@7 rar
|
run: brew install rar
|
||||||
if: ${{ runner.os == 'macOS' }}
|
if: ${{ runner.os == 'macOS' }}
|
||||||
- name: Install conda
|
- name: Install conda
|
||||||
uses: conda-incubator/setup-miniconda@v2
|
uses: conda-incubator/setup-miniconda@v2
|
||||||
|
@ -78,7 +75,7 @@ jobs:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
if: ${{ runner.os == 'Windows' }}
|
if: ${{ runner.os == 'Windows' }}
|
||||||
- name: Install conda dependencies
|
- name: Install conda dependencies
|
||||||
run: conda install -c conda-forge cartopy h5py 'rasterio>=1.0'
|
run: conda install -c conda-forge h5py 'rasterio>=1.0'
|
||||||
if: ${{ runner.os == 'Windows' }}
|
if: ${{ runner.os == 'Windows' }}
|
||||||
- name: Install pip dependencies
|
- name: Install pip dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -51,7 +51,6 @@ nitpick_ignore = [
|
||||||
# https://github.com/sphinx-doc/sphinx/issues/8127
|
# https://github.com/sphinx-doc/sphinx/issues/8127
|
||||||
("py:class", ".."),
|
("py:class", ".."),
|
||||||
# TODO: can't figure out why this isn't found
|
# TODO: can't figure out why this isn't found
|
||||||
("py:class", "cartopy._crs.CRS"),
|
|
||||||
("py:class", "LightningDataModule"),
|
("py:class", "LightningDataModule"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -88,7 +87,6 @@ autodoc_typehints = "description"
|
||||||
|
|
||||||
# sphinx.ext.intersphinx
|
# sphinx.ext.intersphinx
|
||||||
intersphinx_mapping = {
|
intersphinx_mapping = {
|
||||||
"cartopy": ("https://scitools.org.uk/cartopy/docs/latest/", None),
|
|
||||||
"python": ("https://docs.python.org/3", None),
|
"python": ("https://docs.python.org/3", None),
|
||||||
"pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
|
"pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
|
||||||
"rasterio": ("https://rasterio.readthedocs.io/en/latest/", None),
|
"rasterio": ("https://rasterio.readthedocs.io/en/latest/", None),
|
||||||
|
|
|
@ -66,7 +66,6 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"source": [
|
||||||
"import cartopy.crs as ccrs\r\n",
|
|
||||||
"from rasterio.crs import CRS\r\n",
|
"from rasterio.crs import CRS\r\n",
|
||||||
"from torch.utils.data import DataLoader\r\n",
|
"from torch.utils.data import DataLoader\r\n",
|
||||||
"\r\n",
|
"\r\n",
|
||||||
|
@ -112,8 +111,6 @@
|
||||||
"ROOT = \"/mnt/blobfuse/adam-scratch\"\r\n",
|
"ROOT = \"/mnt/blobfuse/adam-scratch\"\r\n",
|
||||||
"\r\n",
|
"\r\n",
|
||||||
"crs = CRS.from_epsg(32616)\r\n",
|
"crs = CRS.from_epsg(32616)\r\n",
|
||||||
"projection = ccrs.UTM(16)\r\n",
|
|
||||||
"transform = ccrs.UTM(16)\r\n",
|
|
||||||
"\r\n",
|
"\r\n",
|
||||||
"landsat = Landsat8(root=ROOT, crs=crs, bands=[\"B4\", \"B3\", \"B2\"])\r\n",
|
"landsat = Landsat8(root=ROOT, crs=crs, bands=[\"B4\", \"B3\", \"B2\"])\r\n",
|
||||||
"cdl = CDL(root=ROOT, crs=crs)"
|
"cdl = CDL(root=ROOT, crs=crs)"
|
||||||
|
@ -293,10 +290,9 @@
|
||||||
"for sample in dataloader:\r\n",
|
"for sample in dataloader:\r\n",
|
||||||
" image = sample[\"image\"][0]\r\n",
|
" image = sample[\"image\"][0]\r\n",
|
||||||
" label = sample[\"masks\"][0]\r\n",
|
" label = sample[\"masks\"][0]\r\n",
|
||||||
" bbox = sample[\"bbox\"][0]\r\n",
|
|
||||||
"\r\n",
|
"\r\n",
|
||||||
" landsat.plot(image, bbox, projection, transform)\r\n",
|
" landsat.plot(image)\r\n",
|
||||||
" cdl.plot(label, bbox, projection, transform)"
|
" cdl.plot(label)"
|
||||||
],
|
],
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -552,4 +548,4 @@
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 2
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
name: torchgeo
|
name: torchgeo
|
||||||
dependencies:
|
dependencies:
|
||||||
- cartopy
|
|
||||||
- cudatoolkit=11.1
|
- cudatoolkit=11.1
|
||||||
- h5py
|
- h5py
|
||||||
- matplotlib
|
- matplotlib
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
affine
|
affine
|
||||||
black[colorama]>=21b
|
black[colorama]>=21b
|
||||||
cartopy
|
|
||||||
flake8
|
flake8
|
||||||
h5py
|
h5py
|
||||||
isort[colors]>=4.3.5
|
isort[colors]>=4.3.5
|
||||||
|
|
|
@ -27,7 +27,6 @@ setup_requires =
|
||||||
setuptools>=42
|
setuptools>=42
|
||||||
install_requires =
|
install_requires =
|
||||||
affine
|
affine
|
||||||
cartopy
|
|
||||||
matplotlib
|
matplotlib
|
||||||
numpy
|
numpy
|
||||||
pillow
|
pillow
|
||||||
|
|
|
@ -5,7 +5,6 @@ spack:
|
||||||
- "python@3.6:+bz2"
|
- "python@3.6:+bz2"
|
||||||
- py-affine
|
- py-affine
|
||||||
- "py-black@21:+colorama"
|
- "py-black@21:+colorama"
|
||||||
- py-cartopy
|
|
||||||
- py-flake8
|
- py-flake8
|
||||||
- py-h5py
|
- py-h5py
|
||||||
- "py-isort@4.3.5:+colors"
|
- "py-isort@4.3.5:+colors"
|
||||||
|
|
|
@ -61,7 +61,7 @@ class TestCDL:
|
||||||
def test_plot(self, dataset: CDL) -> None:
|
def test_plot(self, dataset: CDL) -> None:
|
||||||
query = dataset.bounds
|
query = dataset.bounds
|
||||||
x = dataset[query]
|
x = dataset[query]
|
||||||
dataset.plot(x["masks"], query)
|
dataset.plot(x["masks"])
|
||||||
|
|
||||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||||
|
|
|
@ -37,7 +37,7 @@ class TestLandsat8:
|
||||||
def test_plot(self, dataset: Landsat8) -> None:
|
def test_plot(self, dataset: Landsat8) -> None:
|
||||||
query = dataset.bounds
|
query = dataset.bounds
|
||||||
x = dataset[query]
|
x = dataset[query]
|
||||||
dataset.plot(x["image"], query)
|
dataset.plot(x["image"])
|
||||||
|
|
||||||
def test_no_data(self, tmp_path: Path) -> None:
|
def test_no_data(self, tmp_path: Path) -> None:
|
||||||
with pytest.raises(FileNotFoundError, match="No Landsat data was found in "):
|
with pytest.raises(FileNotFoundError, match="No Landsat data was found in "):
|
||||||
|
|
|
@ -5,13 +5,11 @@ import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
import cartopy.crs as ccrs
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import rasterio
|
import rasterio
|
||||||
import torch
|
import torch
|
||||||
from cartopy.crs import CRS as CCRS
|
from rasterio.crs import CRS
|
||||||
from rasterio.crs import CRS as RCRS
|
|
||||||
from rasterio.vrt import WarpedVRT
|
from rasterio.vrt import WarpedVRT
|
||||||
from rtree.index import Index, Property
|
from rtree.index import Index, Property
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
@ -19,8 +17,7 @@ from torch import Tensor
|
||||||
from .geo import GeoDataset
|
from .geo import GeoDataset
|
||||||
from .utils import BoundingBox, check_integrity, download_and_extract_archive
|
from .utils import BoundingBox, check_integrity, download_and_extract_archive
|
||||||
|
|
||||||
_ccrs = ccrs.AlbersEqualArea(-96, 23, 0, 0, (29.5, 45.5))
|
_crs = CRS.from_wkt(
|
||||||
_rcrs = RCRS.from_wkt(
|
|
||||||
"""
|
"""
|
||||||
PROJCS["Albers Conical Equal Area",
|
PROJCS["Albers Conical Equal Area",
|
||||||
GEOGCS["NAD83",
|
GEOGCS["NAD83",
|
||||||
|
@ -86,7 +83,7 @@ class CDL(GeoDataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: str = "data",
|
root: str = "data",
|
||||||
crs: RCRS = _rcrs,
|
crs: CRS = _crs,
|
||||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
checksum: bool = False,
|
checksum: bool = False,
|
||||||
|
@ -198,26 +195,18 @@ class CDL(GeoDataset):
|
||||||
md5=md5 if self.checksum else None,
|
md5=md5 if self.checksum else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def plot(
|
def plot(self, image: Tensor) -> None:
|
||||||
self,
|
|
||||||
image: Tensor,
|
|
||||||
bbox: BoundingBox,
|
|
||||||
projection: CCRS = _ccrs,
|
|
||||||
transform: CCRS = _ccrs,
|
|
||||||
) -> None:
|
|
||||||
"""Plot an image on a map.
|
"""Plot an image on a map.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: the image to plot
|
image: the image to plot
|
||||||
bbox: the bounding box of the image
|
|
||||||
projection: :term:`projection` of map
|
|
||||||
transform: :term:`coordinate reference system (CRS)` of data
|
|
||||||
"""
|
"""
|
||||||
# Convert from class labels to RGBA values
|
# Convert from class labels to RGBA values
|
||||||
array = image.squeeze().numpy()
|
array = image.squeeze().numpy()
|
||||||
array = self.cmap[array]
|
array = self.cmap[array]
|
||||||
|
|
||||||
# Plot the image
|
# Plot the image
|
||||||
ax = plt.axes(projection=projection)
|
ax = plt.axes()
|
||||||
ax.imshow(array, origin="lower", extent=bbox[:4], transform=transform)
|
ax.imshow(array, origin="lower")
|
||||||
|
ax.axis("off")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
|
@ -6,13 +6,11 @@ import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, Optional, Sequence
|
from typing import Any, Callable, Dict, Optional, Sequence
|
||||||
|
|
||||||
import cartopy.crs as ccrs
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import rasterio
|
import rasterio
|
||||||
import torch
|
import torch
|
||||||
from cartopy.crs import CRS as CCRS
|
from rasterio.crs import CRS
|
||||||
from rasterio.crs import CRS as RCRS
|
|
||||||
from rasterio.vrt import WarpedVRT
|
from rasterio.vrt import WarpedVRT
|
||||||
from rtree.index import Index, Property
|
from rtree.index import Index, Property
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
@ -20,8 +18,7 @@ from torch import Tensor
|
||||||
from .geo import GeoDataset
|
from .geo import GeoDataset
|
||||||
from .utils import BoundingBox
|
from .utils import BoundingBox
|
||||||
|
|
||||||
_ccrs = ccrs.UTM(16)
|
_crs = CRS.from_epsg(32616)
|
||||||
_rcrs = RCRS.from_epsg(32616)
|
|
||||||
|
|
||||||
|
|
||||||
class Landsat(GeoDataset, abc.ABC):
|
class Landsat(GeoDataset, abc.ABC):
|
||||||
|
@ -52,7 +49,7 @@ class Landsat(GeoDataset, abc.ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: str = "data",
|
root: str = "data",
|
||||||
crs: RCRS = _rcrs,
|
crs: CRS = _crs,
|
||||||
bands: Sequence[str] = [],
|
bands: Sequence[str] = [],
|
||||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -139,20 +136,11 @@ class Landsat(GeoDataset, abc.ABC):
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def plot(
|
def plot(self, image: Tensor) -> None:
|
||||||
self,
|
|
||||||
image: Tensor,
|
|
||||||
bbox: BoundingBox,
|
|
||||||
projection: CCRS = _ccrs,
|
|
||||||
transform: CCRS = _ccrs,
|
|
||||||
) -> None:
|
|
||||||
"""Plot an image on a map.
|
"""Plot an image on a map.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: the image to plot
|
image: the image to plot
|
||||||
bbox: the bounding box of the image
|
|
||||||
projection: :term:`projection` of map
|
|
||||||
transform: :term:`coordinate reference system (CRS)` of data
|
|
||||||
"""
|
"""
|
||||||
# Convert from CxHxW to HxWxC
|
# Convert from CxHxW to HxWxC
|
||||||
image = image.permute((1, 2, 0))
|
image = image.permute((1, 2, 0))
|
||||||
|
@ -165,8 +153,9 @@ class Landsat(GeoDataset, abc.ABC):
|
||||||
array = np.clip(array, 0, 1)
|
array = np.clip(array, 0, 1)
|
||||||
|
|
||||||
# Plot the image
|
# Plot the image
|
||||||
ax = plt.axes(projection=projection)
|
ax = plt.axes()
|
||||||
ax.imshow(array, origin="lower", extent=bbox[:4], transform=transform)
|
ax.imshow(array, origin="lower")
|
||||||
|
ax.axis("off")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче