* STYLE: Apply some autopep8 fixes

* Fix more pre-commit errors
This commit is contained in:
Fernando Pérez-García 2022-06-07 10:14:00 +02:00 коммит произвёл GitHub
Родитель ad9aac5413
Коммит edc72eda8c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 14 добавлений и 5 удалений

Просмотреть файл

@ -3,3 +3,4 @@ ignore = E226,E302,E41,W391, E701, W291, E722, W503, E128, E126, E127, E731, E40
max-line-length = 160
max-complexity = 25
exclude = fastMRI/ test_outputs/ hi-ml/
min_python_version = 3.7

Просмотреть файл

@ -3,7 +3,9 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Dict, List, Optional, OrderedDict, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING
if TYPE_CHECKING:
from typing import OrderedDict
import pytorch_lightning as pl
import torch

Просмотреть файл

@ -72,6 +72,7 @@ class HelloDataModule(LightningDataModule):
For cross validation (if required) we use k-fold cross-validation. The test set remains unchanged
while the training and validation data cycle through the k-folds of the remaining data.
"""
def __init__(
self,
root_folder: Path,

Просмотреть файл

@ -26,6 +26,7 @@ class PassThroughModel(SegmentationModelBase):
"""
Dummy model that returns a fixed segmentation, explained in make_nesting_rectangles.
"""
def __init__(self, **kwargs: Any) -> None:
fg_classes = ["spinalcord", "lung_r", "lung_l", "heart", "esophagus"]
fg_display_names = ["SpinalCord", "Lung_R", "Lung_L", "Heart", "Esophagus"]

Просмотреть файл

@ -26,6 +26,7 @@ class ResNetV2Block(nn.Module):
ResNetV2 (https://arxiv.org/pdf/1603.05027.pdf) uses pre activation in the ResNet blocks.
Big Transfer replaces BatchNorm with GroupNorm
"""
def __init__(self,
in_channels: int,
out_channels: int,
@ -82,6 +83,7 @@ class ResNetV2Layer(nn.Module):
"""
Single layer of ResNetV2
"""
def __init__(self,
in_channels: int,
out_channels: int,
@ -110,6 +112,7 @@ class BiTResNetV2(nn.Module):
https://arxiv.org/pdf/1912.11370.pdf
https://github.com/google-research/big_transfer
"""
def __init__(self, num_groups: int = 32,
num_classes: int = 21843,
num_blocks_in_layer: Tuple[int, int, int, int] = (3, 4, 23, 3),

Просмотреть файл

@ -22,6 +22,7 @@ class WindowNormalizationForScalarItem(Transform3D[ScalarItem]):
Transform3D to apply window normalization to "images" of a ScalarItem.
"""
# noinspection PyMissingConstructor
def __init__(self,
output_range: Tuple[float, float] = (0, 1),
sharpen: float = 1.9,

Просмотреть файл

@ -134,7 +134,6 @@ def get_labels_and_predictions_for_prediction_target_set(csv: Path,
def print_metrics_for_thresholded_output_for_all_prediction_targets(csv_to_set_optimal_threshold: Path,
csv_to_compute_metrics: Path,
config: ScalarModelBase) -> None:
"""
Given csvs written during inference for the validation and test sets, print out metrics for every combination of
prediction targets that exist in the dataset (i.e. for every subset of classes that occur in the dataset).

Просмотреть файл

@ -70,7 +70,7 @@ def initialize_rpdb() -> None:
# rpdb signal trapping does not work on Windows, as there is no SIGTRAP:
if not is_linux():
return
import rpdb
rpdb = __import__('rpdb') # hack so that the pre-commit hook does not flag this line
rpdb_port = 4444
rpdb.handle_trap(port=rpdb_port)
# For some reason, os.getpid() does not return the ID of what appears to be the currently running process.

Просмотреть файл

@ -277,6 +277,7 @@ class DummySimCLRData(VisionDataset):
"""
Returns a constant vector of size three [1., 1., 1.]
"""
def __init__(
self,
root: str,