зеркало из https://github.com/microsoft/torchgeo.git
Remove dependency on cartopy (#51)
* Remove dependency on cartopy * Remove extent and bbox
This commit is contained in:
Родитель
f507b08252
Коммит
72d667687b
|
@ -43,7 +43,6 @@ jobs:
|
|||
python-version: 3.9
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get install libgeos-dev libproj-dev # needed for cartopy
|
||||
pip install cython numpy # needed for pycocotools
|
||||
pip install --pre -r requirements.txt
|
||||
- name: Run sphinx checks
|
||||
|
|
|
@ -29,7 +29,6 @@ jobs:
|
|||
python-version: 3.9
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get install libgeos-dev libproj-dev # needed for cartopy
|
||||
pip install cython numpy # needed for pycocotools
|
||||
pip install --pre -r requirements.txt
|
||||
- name: Run mypy checks
|
||||
|
@ -37,8 +36,6 @@ jobs:
|
|||
pytest:
|
||||
name: pytest
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
PKG_CONFIG_PATH: /usr/local/opt/proj@7/lib/pkgconfig
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
|
@ -67,10 +64,10 @@ jobs:
|
|||
key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-python-${{ matrix.python-version }}-
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get install libgeos-dev libproj-dev unrar
|
||||
run: sudo apt-get install unrar
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
- name: Install brew dependencies
|
||||
run: brew install geos proj@7 rar
|
||||
run: brew install rar
|
||||
if: ${{ runner.os == 'macOS' }}
|
||||
- name: Install conda
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
|
@ -78,7 +75,7 @@ jobs:
|
|||
python-version: ${{ matrix.python-version }}
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
- name: Install conda dependencies
|
||||
run: conda install -c conda-forge cartopy h5py 'rasterio>=1.0'
|
||||
run: conda install -c conda-forge h5py 'rasterio>=1.0'
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
- name: Install pip dependencies
|
||||
run: |
|
||||
|
|
|
@ -51,7 +51,6 @@ nitpick_ignore = [
|
|||
# https://github.com/sphinx-doc/sphinx/issues/8127
|
||||
("py:class", ".."),
|
||||
# TODO: can't figure out why this isn't found
|
||||
("py:class", "cartopy._crs.CRS"),
|
||||
("py:class", "LightningDataModule"),
|
||||
]
|
||||
|
||||
|
@ -88,7 +87,6 @@ autodoc_typehints = "description"
|
|||
|
||||
# sphinx.ext.intersphinx
|
||||
intersphinx_mapping = {
|
||||
"cartopy": ("https://scitools.org.uk/cartopy/docs/latest/", None),
|
||||
"python": ("https://docs.python.org/3", None),
|
||||
"pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
|
||||
"rasterio": ("https://rasterio.readthedocs.io/en/latest/", None),
|
||||
|
|
|
@ -66,7 +66,6 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import cartopy.crs as ccrs\r\n",
|
||||
"from rasterio.crs import CRS\r\n",
|
||||
"from torch.utils.data import DataLoader\r\n",
|
||||
"\r\n",
|
||||
|
@ -112,8 +111,6 @@
|
|||
"ROOT = \"/mnt/blobfuse/adam-scratch\"\r\n",
|
||||
"\r\n",
|
||||
"crs = CRS.from_epsg(32616)\r\n",
|
||||
"projection = ccrs.UTM(16)\r\n",
|
||||
"transform = ccrs.UTM(16)\r\n",
|
||||
"\r\n",
|
||||
"landsat = Landsat8(root=ROOT, crs=crs, bands=[\"B4\", \"B3\", \"B2\"])\r\n",
|
||||
"cdl = CDL(root=ROOT, crs=crs)"
|
||||
|
@ -293,10 +290,9 @@
|
|||
"for sample in dataloader:\r\n",
|
||||
" image = sample[\"image\"][0]\r\n",
|
||||
" label = sample[\"masks\"][0]\r\n",
|
||||
" bbox = sample[\"bbox\"][0]\r\n",
|
||||
"\r\n",
|
||||
" landsat.plot(image, bbox, projection, transform)\r\n",
|
||||
" cdl.plot(label, bbox, projection, transform)"
|
||||
" landsat.plot(image)\r\n",
|
||||
" cdl.plot(label)"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
name: torchgeo
|
||||
dependencies:
|
||||
- cartopy
|
||||
- cudatoolkit=11.1
|
||||
- h5py
|
||||
- matplotlib
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
affine
|
||||
black[colorama]>=21b
|
||||
cartopy
|
||||
flake8
|
||||
h5py
|
||||
isort[colors]>=4.3.5
|
||||
|
|
|
@ -27,7 +27,6 @@ setup_requires =
|
|||
setuptools>=42
|
||||
install_requires =
|
||||
affine
|
||||
cartopy
|
||||
matplotlib
|
||||
numpy
|
||||
pillow
|
||||
|
|
|
@ -5,7 +5,6 @@ spack:
|
|||
- "python@3.6:+bz2"
|
||||
- py-affine
|
||||
- "py-black@21:+colorama"
|
||||
- py-cartopy
|
||||
- py-flake8
|
||||
- py-h5py
|
||||
- "py-isort@4.3.5:+colors"
|
||||
|
|
|
@ -61,7 +61,7 @@ class TestCDL:
|
|||
def test_plot(self, dataset: CDL) -> None:
|
||||
query = dataset.bounds
|
||||
x = dataset[query]
|
||||
dataset.plot(x["masks"], query)
|
||||
dataset.plot(x["masks"])
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
|
|
|
@ -37,7 +37,7 @@ class TestLandsat8:
|
|||
def test_plot(self, dataset: Landsat8) -> None:
|
||||
query = dataset.bounds
|
||||
x = dataset[query]
|
||||
dataset.plot(x["image"], query)
|
||||
dataset.plot(x["image"])
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No Landsat data was found in "):
|
||||
|
|
|
@ -5,13 +5,11 @@ import os
|
|||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import cartopy.crs as ccrs
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from cartopy.crs import CRS as CCRS
|
||||
from rasterio.crs import CRS as RCRS
|
||||
from rasterio.crs import CRS
|
||||
from rasterio.vrt import WarpedVRT
|
||||
from rtree.index import Index, Property
|
||||
from torch import Tensor
|
||||
|
@ -19,8 +17,7 @@ from torch import Tensor
|
|||
from .geo import GeoDataset
|
||||
from .utils import BoundingBox, check_integrity, download_and_extract_archive
|
||||
|
||||
_ccrs = ccrs.AlbersEqualArea(-96, 23, 0, 0, (29.5, 45.5))
|
||||
_rcrs = RCRS.from_wkt(
|
||||
_crs = CRS.from_wkt(
|
||||
"""
|
||||
PROJCS["Albers Conical Equal Area",
|
||||
GEOGCS["NAD83",
|
||||
|
@ -86,7 +83,7 @@ class CDL(GeoDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
crs: RCRS = _rcrs,
|
||||
crs: CRS = _crs,
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -198,26 +195,18 @@ class CDL(GeoDataset):
|
|||
md5=md5 if self.checksum else None,
|
||||
)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
image: Tensor,
|
||||
bbox: BoundingBox,
|
||||
projection: CCRS = _ccrs,
|
||||
transform: CCRS = _ccrs,
|
||||
) -> None:
|
||||
def plot(self, image: Tensor) -> None:
|
||||
"""Plot an image on a map.
|
||||
|
||||
Args:
|
||||
image: the image to plot
|
||||
bbox: the bounding box of the image
|
||||
projection: :term:`projection` of map
|
||||
transform: :term:`coordinate reference system (CRS)` of data
|
||||
"""
|
||||
# Convert from class labels to RGBA values
|
||||
array = image.squeeze().numpy()
|
||||
array = self.cmap[array]
|
||||
|
||||
# Plot the image
|
||||
ax = plt.axes(projection=projection)
|
||||
ax.imshow(array, origin="lower", extent=bbox[:4], transform=transform)
|
||||
ax = plt.axes()
|
||||
ax.imshow(array, origin="lower")
|
||||
ax.axis("off")
|
||||
plt.show()
|
||||
|
|
|
@ -6,13 +6,11 @@ import os
|
|||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, Optional, Sequence
|
||||
|
||||
import cartopy.crs as ccrs
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from cartopy.crs import CRS as CCRS
|
||||
from rasterio.crs import CRS as RCRS
|
||||
from rasterio.crs import CRS
|
||||
from rasterio.vrt import WarpedVRT
|
||||
from rtree.index import Index, Property
|
||||
from torch import Tensor
|
||||
|
@ -20,8 +18,7 @@ from torch import Tensor
|
|||
from .geo import GeoDataset
|
||||
from .utils import BoundingBox
|
||||
|
||||
_ccrs = ccrs.UTM(16)
|
||||
_rcrs = RCRS.from_epsg(32616)
|
||||
_crs = CRS.from_epsg(32616)
|
||||
|
||||
|
||||
class Landsat(GeoDataset, abc.ABC):
|
||||
|
@ -52,7 +49,7 @@ class Landsat(GeoDataset, abc.ABC):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
crs: RCRS = _rcrs,
|
||||
crs: CRS = _crs,
|
||||
bands: Sequence[str] = [],
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
|
@ -139,20 +136,11 @@ class Landsat(GeoDataset, abc.ABC):
|
|||
|
||||
return sample
|
||||
|
||||
def plot(
|
||||
self,
|
||||
image: Tensor,
|
||||
bbox: BoundingBox,
|
||||
projection: CCRS = _ccrs,
|
||||
transform: CCRS = _ccrs,
|
||||
) -> None:
|
||||
def plot(self, image: Tensor) -> None:
|
||||
"""Plot an image on a map.
|
||||
|
||||
Args:
|
||||
image: the image to plot
|
||||
bbox: the bounding box of the image
|
||||
projection: :term:`projection` of map
|
||||
transform: :term:`coordinate reference system (CRS)` of data
|
||||
"""
|
||||
# Convert from CxHxW to HxWxC
|
||||
image = image.permute((1, 2, 0))
|
||||
|
@ -165,8 +153,9 @@ class Landsat(GeoDataset, abc.ABC):
|
|||
array = np.clip(array, 0, 1)
|
||||
|
||||
# Plot the image
|
||||
ax = plt.axes(projection=projection)
|
||||
ax.imshow(array, origin="lower", extent=bbox[:4], transform=transform)
|
||||
ax = plt.axes()
|
||||
ax.imshow(array, origin="lower")
|
||||
ax.axis("off")
|
||||
plt.show()
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче