STYLE: Fix pre-commit errors (#736)
* STYLE: Apply some autopep8 fixes * Fix more pre-commit errors
This commit is contained in:
Родитель
ad9aac5413
Коммит
edc72eda8c
1
.flake8
1
.flake8
|
@ -3,3 +3,4 @@ ignore = E226,E302,E41,W391, E701, W291, E722, W503, E128, E126, E127, E731, E40
|
|||
max-line-length = 160
|
||||
max-complexity = 25
|
||||
exclude = fastMRI/ test_outputs/ hi-ml/
|
||||
min_python_version = 3.7
|
||||
|
|
|
@ -30,7 +30,7 @@ jobs:
|
|||
}' -f org=$ORGANIZATION -F number=$PROJECT_NUMBER > project_data.json
|
||||
|
||||
echo 'PROJECT_ID='$(jq '.data.organization.projectNext.id' project_data.json) >> $GITHUB_ENV
|
||||
|
||||
|
||||
- name: Add issue to project
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.INNEREYE_OSS_PROJECT_ACCESS_TOKEN }}
|
||||
|
@ -43,4 +43,4 @@ jobs:
|
|||
id
|
||||
}
|
||||
}
|
||||
}' -f project=$PROJECT_ID -f issue=$ISSUE_ID --jq '.data.addProjectNextItem.projectNextItem.id')"
|
||||
}' -f project=$PROJECT_ID -f issue=$ISSUE_ID --jq '.data.addProjectNextItem.projectNextItem.id')"
|
||||
|
|
|
@ -3,7 +3,9 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, Dict, List, Optional, OrderedDict, Set, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from typing import OrderedDict
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
|
|
@ -72,6 +72,7 @@ class HelloDataModule(LightningDataModule):
|
|||
For cross validation (if required) we use k-fold cross-validation. The test set remains unchanged
|
||||
while the training and validation data cycle through the k-folds of the remaining data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_folder: Path,
|
||||
|
|
|
@ -26,6 +26,7 @@ class PassThroughModel(SegmentationModelBase):
|
|||
"""
|
||||
Dummy model that returns a fixed segmentation, explained in make_nesting_rectangles.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
fg_classes = ["spinalcord", "lung_r", "lung_l", "heart", "esophagus"]
|
||||
fg_display_names = ["SpinalCord", "Lung_R", "Lung_L", "Heart", "Esophagus"]
|
||||
|
|
|
@ -26,6 +26,7 @@ class ResNetV2Block(nn.Module):
|
|||
ResNetV2 (https://arxiv.org/pdf/1603.05027.pdf) uses pre activation in the ResNet blocks.
|
||||
Big Transfer replaces BatchNorm with GroupNorm
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
|
@ -82,6 +83,7 @@ class ResNetV2Layer(nn.Module):
|
|||
"""
|
||||
Single layer of ResNetV2
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
|
@ -110,6 +112,7 @@ class BiTResNetV2(nn.Module):
|
|||
https://arxiv.org/pdf/1912.11370.pdf
|
||||
https://github.com/google-research/big_transfer
|
||||
"""
|
||||
|
||||
def __init__(self, num_groups: int = 32,
|
||||
num_classes: int = 21843,
|
||||
num_blocks_in_layer: Tuple[int, int, int, int] = (3, 4, 23, 3),
|
||||
|
|
|
@ -22,6 +22,7 @@ class WindowNormalizationForScalarItem(Transform3D[ScalarItem]):
|
|||
Transform3D to apply window normalization to "images" of a ScalarItem.
|
||||
"""
|
||||
# noinspection PyMissingConstructor
|
||||
|
||||
def __init__(self,
|
||||
output_range: Tuple[float, float] = (0, 1),
|
||||
sharpen: float = 1.9,
|
||||
|
|
|
@ -134,7 +134,6 @@ def get_labels_and_predictions_for_prediction_target_set(csv: Path,
|
|||
def print_metrics_for_thresholded_output_for_all_prediction_targets(csv_to_set_optimal_threshold: Path,
|
||||
csv_to_compute_metrics: Path,
|
||||
config: ScalarModelBase) -> None:
|
||||
|
||||
"""
|
||||
Given csvs written during inference for the validation and test sets, print out metrics for every combination of
|
||||
prediction targets that exist in the dataset (i.e. for every subset of classes that occur in the dataset).
|
||||
|
|
|
@ -70,7 +70,7 @@ def initialize_rpdb() -> None:
|
|||
# rpdb signal trapping does not work on Windows, as there is no SIGTRAP:
|
||||
if not is_linux():
|
||||
return
|
||||
import rpdb
|
||||
rpdb = __import__('rpdb') # hack so that the pre-commit hook does not flag this line
|
||||
rpdb_port = 4444
|
||||
rpdb.handle_trap(port=rpdb_port)
|
||||
# For some reason, os.getpid() does not return the ID of what appears to be the currently running process.
|
||||
|
|
|
@ -277,6 +277,7 @@ class DummySimCLRData(VisionDataset):
|
|||
"""
|
||||
Returns a constant vector of size three [1., 1., 1.]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
|
|
Загрузка…
Ссылка в новой задаче