зеркало из https://github.com/microsoft/torchgeo.git
Faster, more space-efficient tutorials (#1124)
* Speed up notebook tests * Black fix * Mock rest of variables * Undo URL changes * Update conda deps * Notebooks also plot images * Fix undefined variable * Test with serial data loading * Use tempfile for all data download directories * Encode timeout in notebook * Share datasets across processes * Fix missing import * Pretrained Weights: use EuroSAT100 * Transforms: use EuroSAT100 * Trainers: use EuroSAT100 * Blacken * MPLBACKEND is already Agg by default on Linux * Indices: use EuroSAT100 * Pretrained Weights: add output * Pretrained Weights: add output * Trainers: save output * Pretrained Weights: ResNet 50 -> 18 * Trainers: better graph * Indices: add missing plot * Cache downloads * Small edit * Revert "Cache downloads" This reverts commit5276c53a06
. * Revert "Revert "Cache downloads"" This reverts commit137c69e776
. * env only * half env * Variable with no braces * Set tmpdir elsewhere * Give up on tmpdir caching * Trainers: clear output * lightning.pytorch package import * nbstripout * Rerun upon failure * Re-add caching * Rerun failures on release branch too
This commit is contained in:
Родитель
cfe4541c2e
Коммит
ce4c4b1455
|
@ -72,12 +72,10 @@ jobs:
|
|||
- name: Install pip dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
pip install .[datasets,docs,tests]
|
||||
pip install .[docs,tests] planetary_computer pystac pytest-rerunfailures
|
||||
pip list
|
||||
- name: Run notebook checks
|
||||
env:
|
||||
MLHUB_API_KEY: ${{ secrets.MLHUB_API_KEY }}
|
||||
run: pytest --nbmake docs/tutorials --durations=10
|
||||
run: pytest --nbmake --durations=10 --reruns=10 docs/tutorials
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
|
|
@ -30,12 +30,10 @@ jobs:
|
|||
- name: Install pip dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
pip install .[datasets,docs,tests]
|
||||
pip install .[docs,tests] planetary_computer pystac pytest-rerunfailures
|
||||
pip list
|
||||
- name: Run notebook checks
|
||||
env:
|
||||
MLHUB_API_KEY: ${{ secrets.MLHUB_API_KEY }}
|
||||
run: pytest --nbmake --nbmake-timeout=3000 docs/tutorials --durations=10
|
||||
run: pytest --nbmake --durations=10 --reruns=10 docs/tutorials
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -58,7 +58,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1629238744113
|
||||
|
@ -90,12 +90,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_root = tempfile.gettempdir()\n",
|
||||
"naip_root = os.path.join(data_root, \"naip\")\n",
|
||||
"naip_root = os.path.join(tempfile.gettempdir(), \"naip\")\n",
|
||||
"naip_url = (\n",
|
||||
" \"https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/\"\n",
|
||||
")\n",
|
||||
|
@ -118,12 +117,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chesapeake_root = os.path.join(data_root, \"chesapeake\")\n",
|
||||
"\n",
|
||||
"chesapeake_root = os.path.join(tempfile.gettempdir(), \"chesapeake\")\n",
|
||||
"chesapeake = ChesapeakeDE(chesapeake_root, download=True)"
|
||||
]
|
||||
},
|
||||
|
@ -143,7 +141,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1629238744228
|
||||
|
@ -167,6 +165,34 @@
|
|||
" return toc - tic, i"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The following variables can be modified to control the number of samples drawn per epoch."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"nbmake": {
|
||||
"mock": {
|
||||
"batch_size": 1,
|
||||
"length": 1,
|
||||
"size": 1,
|
||||
"stride": 1000000
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"size = 1000\n",
|
||||
"length = 888\n",
|
||||
"batch_size = 12\n",
|
||||
"stride = 500"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
|
@ -183,7 +209,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1629248963725
|
||||
|
@ -197,24 +223,15 @@
|
|||
"outputId": "edcc8199-bd09-4832-e50c-7be8ac78995b",
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"296.582683801651 74\n",
|
||||
"54.20210099220276 74\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for cache in [False, True]:\n",
|
||||
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
|
||||
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
|
||||
" dataset = chesapeake & naip\n",
|
||||
" sampler = RandomGeoSampler(dataset, size=1000, length=888)\n",
|
||||
" sampler = RandomGeoSampler(dataset, size=size, length=length)\n",
|
||||
" dataloader = DataLoader(\n",
|
||||
" dataset, batch_size=12, sampler=sampler, collate_fn=stack_samples\n",
|
||||
" dataset, batch_size=batch_size, sampler=sampler, collate_fn=stack_samples\n",
|
||||
" )\n",
|
||||
" duration, count = time_epoch(dataloader)\n",
|
||||
" print(duration, count)"
|
||||
|
@ -236,7 +253,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1629239313388
|
||||
|
@ -250,24 +267,15 @@
|
|||
"outputId": "159ce99f-a438-4ecc-d218-9b9e28d02055",
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"391.90197944641113 74\n",
|
||||
"118.0611424446106 74\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for cache in [False, True]:\n",
|
||||
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
|
||||
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
|
||||
" dataset = chesapeake & naip\n",
|
||||
" sampler = GridGeoSampler(dataset, size=1000, stride=500)\n",
|
||||
" sampler = GridGeoSampler(dataset, size=size, stride=stride)\n",
|
||||
" dataloader = DataLoader(\n",
|
||||
" dataset, batch_size=12, sampler=sampler, collate_fn=stack_samples\n",
|
||||
" dataset, batch_size=batch_size, sampler=sampler, collate_fn=stack_samples\n",
|
||||
" )\n",
|
||||
" duration, count = time_epoch(dataloader)\n",
|
||||
" print(duration, count)"
|
||||
|
@ -289,7 +297,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1629249843438
|
||||
|
@ -303,22 +311,15 @@
|
|||
"outputId": "497f6869-1ab7-4db7-bbce-e943b493ca41",
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"230.51380324363708 74\n",
|
||||
"53.99923872947693 74\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for cache in [False, True]:\n",
|
||||
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
|
||||
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
|
||||
" dataset = chesapeake & naip\n",
|
||||
" sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=12, length=888)\n",
|
||||
" sampler = RandomBatchGeoSampler(\n",
|
||||
" dataset, size=size, batch_size=batch_size, length=length\n",
|
||||
" )\n",
|
||||
" dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=stack_samples)\n",
|
||||
" duration, count = time_epoch(dataloader)\n",
|
||||
" print(duration, count)"
|
||||
|
@ -349,10 +350,10 @@
|
|||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "ipython",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
"version": "3.10.8"
|
||||
},
|
||||
"nteract": {
|
||||
"version": "nteract-front-end@1.0.0"
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -2,7 +2,9 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "p63J-QmUrMN-"
|
||||
},
|
||||
"source": [
|
||||
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
|
||||
"\n",
|
||||
|
@ -10,21 +12,23 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "XRSkMFqyrMOE"
|
||||
},
|
||||
"source": [
|
||||
"# Pretrained Weights\n",
|
||||
"\n",
|
||||
"In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial.\n",
|
||||
"In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images.\n",
|
||||
"\n",
|
||||
"It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started."
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "NBa5RPAirMOF"
|
||||
},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
|
@ -34,16 +38,23 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "5AIQ1B9DrMOG",
|
||||
"outputId": "6bf360ea-8f60-45cf-c96e-0eac54818079"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install torchgeo"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "IcCOnzVLrMOI"
|
||||
},
|
||||
"source": [
|
||||
"## Imports\n",
|
||||
"\n",
|
||||
|
@ -53,42 +64,61 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "rjEGiiurrMOI"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import csv\n",
|
||||
"import tempfile\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"import timm\n",
|
||||
"import torch\n",
|
||||
"from lightning.pytorch import Trainer\n",
|
||||
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"\n",
|
||||
"from torchgeo.datamodules import EuroSATDataModule\n",
|
||||
"from torchgeo.datamodules import EuroSAT100DataModule\n",
|
||||
"from torchgeo.trainers import ClassificationTask\n",
|
||||
"from torchgeo.models import ResNet50_Weights, ViTSmall16_Weights"
|
||||
"from torchgeo.models import ResNet18_Weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "njAH71F3rMOJ"
|
||||
},
|
||||
"source": [
|
||||
"The following variables can be used to control training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "TVG_Z9MKrMOJ",
|
||||
"nbmake": {
|
||||
"mock": {
|
||||
"batch_size": 1,
|
||||
"fast_dev_run": true,
|
||||
"max_epochs": 1,
|
||||
"num_workers": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n",
|
||||
"# skip the expensive training.\n",
|
||||
"in_tests = \"PYTEST_CURRENT_TEST\" in os.environ"
|
||||
"batch_size = 10\n",
|
||||
"num_workers = 2\n",
|
||||
"max_epochs = 10\n",
|
||||
"fast_dev_run = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "QNnDoIf2rMOK"
|
||||
},
|
||||
"source": [
|
||||
"## Datamodule\n",
|
||||
"\n",
|
||||
|
@ -98,18 +128,22 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "ia5ktOVerMOL"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"root = os.path.join(tempfile.gettempdir(), \"eurosat\")\n",
|
||||
"\n",
|
||||
"datamodule = EuroSATDataModule(root=root, batch_size=64, num_workers=4, download=True)"
|
||||
"root = os.path.join(tempfile.gettempdir(), \"eurosat100\")\n",
|
||||
"datamodule = EuroSAT100DataModule(\n",
|
||||
" root=root, batch_size=batch_size, num_workers=num_workers, download=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "ksoszRjZrMOL"
|
||||
},
|
||||
"source": [
|
||||
"## Weights\n",
|
||||
"\n",
|
||||
|
@ -121,16 +155,19 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "wJOrRqBGrMOM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"weights = ResNet50_Weights.SENTINEL2_ALL_MOCO"
|
||||
"weights = ResNet18_Weights.SENTINEL2_ALL_MOCO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "EIpnXuXgrMOM"
|
||||
},
|
||||
"source": [
|
||||
"This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic."
|
||||
]
|
||||
|
@ -138,11 +175,31 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 86,
|
||||
"referenced_widgets": [
|
||||
"8fb7022c6e4947d8955ed6da0d89ef05",
|
||||
"c5adf2ec163a43f08b3de98d0eabf8df",
|
||||
"4c6928b34aef4e778c837993c4197bcd",
|
||||
"ec96c2767d30412c8af5306a5c2f5ee3",
|
||||
"cbfb522eafbd4d4991ca5f45ad32bd2c",
|
||||
"5d02318e1c9e4034bd686527cdbb18ef",
|
||||
"739421651fb84d31a7baeacef2f8226c",
|
||||
"8d83a7dad192492facd4b3e03cbb2392",
|
||||
"fc423924e14a49b18b99064ab45e3f1e",
|
||||
"17dd6c0bbbf947b1a47a9cb8267d43ef",
|
||||
"8704c42feea442abb67cfd55ae3c4fa9"
|
||||
]
|
||||
},
|
||||
"id": "RZ8MPYH1rMON",
|
||||
"outputId": "fa683b8f-da21-4f26-ca3a-46163c9f12bf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"task = ClassificationTask(\n",
|
||||
" model=\"resnet50\",\n",
|
||||
" model=\"resnet18\",\n",
|
||||
" loss=\"ce\",\n",
|
||||
" weights=weights,\n",
|
||||
" in_channels=13,\n",
|
||||
|
@ -153,9 +210,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "dWidC6vDrMON"
|
||||
},
|
||||
"source": [
|
||||
"If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:"
|
||||
]
|
||||
|
@ -163,18 +221,21 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "ZaZQ07jorMOO"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_chans = weights.meta[\"in_chans\"]\n",
|
||||
"model = timm.create_model(\"resnet50\", in_chans=in_chans, num_classes=10)\n",
|
||||
"model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
|
||||
"model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n",
|
||||
"model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "vgNswWKOrMOO"
|
||||
},
|
||||
"source": [
|
||||
"## Training\n",
|
||||
"\n",
|
||||
|
@ -184,40 +245,226 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "0Sf-CBorrMOO"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"experiment_dir = os.path.join(tempfile.gettempdir(), \"eurosat_results\")\n",
|
||||
"\n",
|
||||
"checkpoint_callback = ModelCheckpoint(\n",
|
||||
" monitor=\"val_loss\", dirpath=experiment_dir, save_top_k=1, save_last=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n",
|
||||
"\n",
|
||||
"csv_logger = CSVLogger(save_dir=experiment_dir, name=\"pretrained_weights_logs\")"
|
||||
"accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||
"default_root_dir = os.path.join(tempfile.gettempdir(), \"experiments\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "veVvF-5LrMOP",
|
||||
"outputId": "698e3e9e-8a53-4897-d40e-13b43470e29e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer = Trainer(\n",
|
||||
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
|
||||
" logger=[csv_logger],\n",
|
||||
" default_root_dir=experiment_dir,\n",
|
||||
" accelerator=accelerator,\n",
|
||||
" default_root_dir=default_root_dir,\n",
|
||||
" fast_dev_run=fast_dev_run,\n",
|
||||
" log_every_n_steps=1,\n",
|
||||
" min_epochs=1,\n",
|
||||
" max_epochs=10,\n",
|
||||
" fast_dev_run=in_tests,\n",
|
||||
" max_epochs=max_epochs,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 544,
|
||||
"referenced_widgets": [
|
||||
"3435cd7003324d5dbee2222c1d919595",
|
||||
"0a2c8c97b5df4d6c9388ed9637c19e8f",
|
||||
"4f5209cdf8e346368da573e818ac8b98",
|
||||
"8d9a2d2377b74bb397cadf9af65849ae",
|
||||
"08871807406540fba8a935793cd28f93",
|
||||
"cc743f08d18f4302b7c3fc8506c4c9d4",
|
||||
"287c893f66294687910eccac3bb6353a",
|
||||
"088567dd04944a7682b2add8353a4076",
|
||||
"3af55148c50045e5b043485200ba7547",
|
||||
"b68eec0296fc45458b5d4cd2111afe6e",
|
||||
"b3423fc7e10a4cf3a16d6bc08f99dd17",
|
||||
"9af62a957121474d8d8b2be00a083287",
|
||||
"c728392f8d1f4da79492fb65fa10fbbe",
|
||||
"01ac3f0b67844950a6d1e5fe44978d1b",
|
||||
"d0140e681f9a437084b496927244f8ac",
|
||||
"cbfee603fbad42daacef920f3acc6a22",
|
||||
"cf52dda9f70249b99c893ff886a96c5b",
|
||||
"9800bfe6afbb4185b25a81017cabe720",
|
||||
"10538d0692ee42648400a204abb788dc",
|
||||
"d6169eaaff4b4918acc91f88ce82b14b",
|
||||
"71ca3d663ae8402497c9f6bb507e5e1b",
|
||||
"2dfc6c6897d0449d9078ed2fd7e35e4b",
|
||||
"8020d3f767d849cdbe2126c50dc546c8",
|
||||
"198bcafb3e4146a984e26260dc0ca2fa",
|
||||
"e774bef4588241e8bfd2a46a3ad0489d",
|
||||
"bf6c49957c3b4b3ebc09819d6259c3bf",
|
||||
"10b348c37c0b46018960247d083ed439",
|
||||
"2174d7e0a0804ff687c4f6985d28b9c5",
|
||||
"e500c7a093d14eed8ea9ac66c759f206",
|
||||
"31eddc17aa7843beb41d740698b7276d",
|
||||
"9a935b8fdbd4425988fb52e01dcb5d2d",
|
||||
"487eb01386324ecfb7d4fae26b84c93d",
|
||||
"b08c4d2cf51049c79ec61aa72b6e787e",
|
||||
"a0b2c0f1d01d4b4ba8ab4ef2d1db6c1a",
|
||||
"8ace9fe0a980414e90ae09c431f5f126",
|
||||
"2813e52affed45e6901c468c573a69cb",
|
||||
"a085a49f93a54e289d4007e0595d6de1",
|
||||
"df0bf05e12404f199d5a412c8a077da0",
|
||||
"ca5cdfec77ff49f7a088a96b910b796c",
|
||||
"82ad1d0d948b44829471f1818ccc5b30",
|
||||
"dedb7f740221433ea8d4668145b7818d",
|
||||
"4b1d82016e4c47e48d89e86492aacdcb",
|
||||
"e774743224c44ea5bf997a3842395d1b",
|
||||
"24f5ac9cf9f44be892a138892a7bd219",
|
||||
"7debe41bd81346eb951605e0048fadf9",
|
||||
"9b4a4015605142bdb6d5147dc5cbc9e2",
|
||||
"91f26f9100af4d3da3b2628f3ec6c3d3",
|
||||
"43437ff07611497eb2bf4822fb434e53",
|
||||
"ca455b6a71274921b8dbb789e9005449",
|
||||
"164ccc2485c44f4ab37b72c47e5f0898",
|
||||
"13a7139766e54d1e9073648c018b41bc",
|
||||
"ea23ac32759c4c6d899911f3bad213ee",
|
||||
"109c5832df954967b40487d6a51929c1",
|
||||
"ec16f29395b34aa6b127d0d63502a88b",
|
||||
"38071ab74d6e464a9b320fc2fc40bda3",
|
||||
"fd6a665f5a1943c2ad22b97b0c1d6999",
|
||||
"05b08c2cc14f4e0cae59fef46e283431",
|
||||
"f42a11c6eb4c449eb214287e3e61884a",
|
||||
"2167acd001b34219875104f2555a54ee",
|
||||
"f71c8ebc88cb483c8adac440b787b7b8",
|
||||
"ce776b04772f4795b5cf3c2279377634",
|
||||
"c5c491755a644c5ebd1e7c34d5a2437c",
|
||||
"e2c89dc3dea942aa9261a215b2a81d15",
|
||||
"e2e4c0a732384398bec0c89af39a1c95",
|
||||
"2aaa5c2a04d74d408434a43632e38d7a",
|
||||
"84128d9cef9f4926b3c321ccd83d761d",
|
||||
"6bcc169344064f15b4410261de4c16a5",
|
||||
"f449d843801f4f829ea75a2c6e4b67db",
|
||||
"9cc12d2e6b0846828d10cf8e48f1874d",
|
||||
"5afb9ef1a858408485139346bf4960f0",
|
||||
"1453657473fc45fd8cc43a3fd6d911ad",
|
||||
"ac0e5a1eb7394e6a9a20b50b9cb3b2d6",
|
||||
"d7882f9f6ff947828a9d965b07ae6308",
|
||||
"8fca6a1b68ee4bfa96dafdc5c438ef91",
|
||||
"dceb2d342e04422c9b965d6759b740d1",
|
||||
"92c572a2a1d74e789576318f3edb08bf",
|
||||
"96fee82a24e84a3ab39e8543e97fe539",
|
||||
"2128512dbe3e4a02a68f86d13b0f467c",
|
||||
"1b9d867c345843c78fdc27dd0efac5ff",
|
||||
"44487915e4204d6abf86766d6b17e3df",
|
||||
"e9a2537dd76144b395e308bdb0494321",
|
||||
"2ca0650dfd9e4f30a206a81a16258c2e",
|
||||
"13a0b8dcb53f4f50a1e253c33c197dd6",
|
||||
"6430bd284b4c4453ab64b88bfdba4f9b",
|
||||
"e91501471db04da991e73bbf28abb71d",
|
||||
"a6e3c391120a4171b59e69f7111d34cd",
|
||||
"df8651ab692b4a648a2e412b2e045bae",
|
||||
"2133201a6113468490d75b3a4ef942b1",
|
||||
"e6698265603c46aca5ce96aee64b485b",
|
||||
"d6680844b3e345e3af2bb27d9afc60f0",
|
||||
"e0dd466941a14188a40bfbe9d9f818ee",
|
||||
"e6a846177d7d49e6a05eecc1174dc574",
|
||||
"e39fed0066c14e16bfd3f9c366f0f288",
|
||||
"f59c4c11e1b34a4288656f28c228ec76",
|
||||
"f1aea0224e7246a39011c36f7ca72702",
|
||||
"5ae6f538d2aa4be785b7be93af885562",
|
||||
"4929a4a1103e47c3bd16e1db31f478e8",
|
||||
"3b8ccaf505604661a65efd0267c0df0a",
|
||||
"6bcd9a3684d0499395c21fbb617ed890",
|
||||
"a0bbee764501493887d45b15f58d1c84",
|
||||
"899c206bc1b948b7aede52ba26b0ab0e",
|
||||
"6c0d5a1b905743ff86613e88e151b120",
|
||||
"98c9c5edc84e4d70a63406fdd1e360e5",
|
||||
"3200b22ab9bd4ffab358630eb09d21bf",
|
||||
"eb0c0a2a3d814c08b8e78bd99f99048f",
|
||||
"366fbb8b74554e7d87329676bbe53488",
|
||||
"736cf1d1b23e4d6aba833d9977b77626",
|
||||
"001bf9ab98ec495782911cab6a91fd67",
|
||||
"8875242b17124440bac9eada00cbf7cc",
|
||||
"a2e6022491e04e5b882907db35ac5c66",
|
||||
"b040763ebdf646fd8d3d1d24914bcb47",
|
||||
"a380e1c490a84b75b60f4e50e979f707",
|
||||
"70382059066c41328a3348a7056032d1",
|
||||
"81b1f031892043bfb7bc385edf52ce48",
|
||||
"4f3da49fe0504640ad7c08fdf0b80113",
|
||||
"69ff1319e4c64d47b9925bb17b9c7c93",
|
||||
"3820af62fa1d498f80598e204a69e60d",
|
||||
"d769c7a4405f474284f9f5ca94116858",
|
||||
"d71c86b25cc048508a351238b4cd171f",
|
||||
"51c78f4a481c4982a7576b4b9c1c81fb",
|
||||
"811893b6b9d24f7f9976328d67eeb0c4",
|
||||
"78206d211645442bbdb8675c175cc233",
|
||||
"b1146db74d2f45859997e2173ce3350c",
|
||||
"62fad5fc0fc24099a0a95305206a5950",
|
||||
"d5ff5612fc3b448c8ad18b06fcf4f9cf",
|
||||
"17b95a1469594b2791a0d057cf4b3367",
|
||||
"8f09c3ccd53c44faac8ad55234c1487b",
|
||||
"4c3e55fdbc024ad2be6ac2facf1a44db",
|
||||
"b88c7c0bc00b4d88a1c49f11389126b5",
|
||||
"fd2d47e147794aa4bedb9933fc4900ce",
|
||||
"7a429ced76e0431ea49a795ec30d5ed2",
|
||||
"08d02351fc8543228e70629577344c7a",
|
||||
"4e38ef1829cb4c25ad84a85c3a7ab221",
|
||||
"25193dcc82f54db8a2993bc4f828fa05",
|
||||
"cfa2f244cd8740798360a2c4722d5748",
|
||||
"ed353a8c04d6417a80ed5ddc1bcebd80",
|
||||
"2252979760814e9da9e455fd34cce955",
|
||||
"8414aeb74f4b4a03b66544867bbe5e8e",
|
||||
"acb7fa2daf5646dca3eb9c9a0d29dc06",
|
||||
"efaf4f813c8d4385a4308a8c7db4958a",
|
||||
"f860da95c8b24bc1969b887d36ad9902",
|
||||
"e33fb707901547d5a4045ea80e5a7c76",
|
||||
"d2144b82918b4dbca37e9a0c1a90c2e6",
|
||||
"eabe4f7af9fe47a6af5509bba81ec0cd",
|
||||
"fead1396f54f4d86a52a04acb2deeea0",
|
||||
"c0c405dcda684c88a6995805bbe1a714",
|
||||
"ae14e90e6a084d0ca79dd0376dcbaa3a",
|
||||
"621eb1d3097147b1991faa1474e7d966",
|
||||
"aa7b75f8100749519494896ba4d12b1c",
|
||||
"84cdf31dac71421b948474bd9e4cde53",
|
||||
"4f5ad2bae09e4112b6b3a95397d7fb55",
|
||||
"8e3980c9433447f89740624bd85a5768",
|
||||
"7e22fc0da31d490aa43be2153cfd7403",
|
||||
"a692d570c31649c28fea3bfac1aaab01",
|
||||
"85755c2e08a34a2daefd4555df0b4fff",
|
||||
"23618d84213a4ea98b9c17bf969272cd",
|
||||
"44690735319544eda0fbd509b92b52c0",
|
||||
"dbc11af4883c4f85b6805f8b52e8170c",
|
||||
"78aed98077a74a75bc408d6fffa05065",
|
||||
"941c2e43942448fba0ff0b13e4f6bf66",
|
||||
"e32858052ad54f4f859a599dce4d13d6",
|
||||
"19c4c142c6224c09ad29b13d0f9af2fe",
|
||||
"6899eab30b4044a8a6493cd5acd48e7a",
|
||||
"a23b211aefd54270821aea17dd96f88a",
|
||||
"17af3f9a527d419ca4b369db183ae659",
|
||||
"756215384ee8444ab59695bf05d29b55",
|
||||
"d4e4b2592d54418f9744e8d3bfced6e5",
|
||||
"bcfe68ab1f46499baf4e29bc9628ef64",
|
||||
"f5145a7dbec343c39ea80958b5331caa",
|
||||
"aed96231191b47a3adeac49c4121f814",
|
||||
"4aad22e983004b1e9e4cebae0fe2897a",
|
||||
"71782ed9d2f94f55be80f7c6dc1b0a59",
|
||||
"945afa9b582e4431a71c8a39f4d7cacc",
|
||||
"d723ac8b9b314ae8ae9861570d192bdf",
|
||||
"11bed5e93c8b46c7a21de82c10e5cfd4",
|
||||
"b2d11d577ed54cb3bcd49f766cc5013a"
|
||||
]
|
||||
},
|
||||
"id": "9WQD4cuwrMOP",
|
||||
"outputId": "590f75dd-064b-4bcc-b504-167bf2ad6cfb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.fit(model=task, datamodule=datamodule)"
|
||||
|
@ -225,8 +472,16 @@
|
|||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"execution": {
|
||||
"timeout": 1200
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
"display_name": "torchEnv",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -240,9 +495,8 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10 (default, Nov 14 2022, 12:59:47) \n[GCC 9.4.0]"
|
||||
"version": "3.10.8"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60"
|
||||
|
@ -250,5 +504,5 @@
|
|||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -26,7 +26,7 @@ dependencies:
|
|||
- laspy>=2
|
||||
- lightning>=1.8
|
||||
- mypy>=0.900
|
||||
- nbmake>=0.1
|
||||
- nbmake>=1.3.3
|
||||
- nbsphinx>=0.8.5
|
||||
- omegaconf>=2.1
|
||||
- opencv-python>=3.4.2.17
|
||||
|
|
|
@ -48,6 +48,6 @@ pyupgrade==1.24.0
|
|||
|
||||
# tests
|
||||
mypy==0.900
|
||||
nbmake==0.1
|
||||
nbmake==1.3.3
|
||||
pytest==6.1.2
|
||||
pytest-cov==2.4.0
|
||||
|
|
|
@ -122,8 +122,8 @@ style =
|
|||
tests =
|
||||
# mypy 0.900+ required for pyproject.toml support
|
||||
mypy>=0.900,<2
|
||||
# nbmake 0.1+ required to fix path_source bug
|
||||
nbmake>=0.1,<2
|
||||
# nbmake 1.3.3+ required for variable mocking
|
||||
nbmake>=1.3.3,<2
|
||||
# pytest 6.1.2+ required by nbmake
|
||||
pytest>=6.1.2,<8
|
||||
# pytest-cov 2.4+ required for pytest --cov flags
|
||||
|
|
Загрузка…
Ссылка в новой задаче