Fix isort failures in metrics (PL^5528)
Remove from skipped module in pyproject.toml and fix failures on: - pytorch_lightning/metrics/*.py (cherry picked from commit 135af5d0131c140f4522a33bb6ef5041281b4ff7)
This commit is contained in:
Родитель
5cf4e45e6b
Коммит
9385e19d16
|
@ -11,29 +11,27 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401
|
||||
|
||||
from pytorch_lightning.metrics.classification import ( # noqa: F401
|
||||
Accuracy,
|
||||
AveragePrecision,
|
||||
ConfusionMatrix,
|
||||
F1,
|
||||
FBeta,
|
||||
HammingDistance,
|
||||
IoU,
|
||||
Precision,
|
||||
Recall,
|
||||
ConfusionMatrix,
|
||||
PrecisionRecallCurve,
|
||||
AveragePrecision,
|
||||
Recall,
|
||||
ROC,
|
||||
FBeta,
|
||||
F1,
|
||||
StatScores
|
||||
StatScores,
|
||||
)
|
||||
|
||||
from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401
|
||||
from pytorch_lightning.metrics.regression import ( # noqa: F401
|
||||
MeanSquaredError,
|
||||
MeanAbsoluteError,
|
||||
MeanSquaredLogError,
|
||||
ExplainedVariance,
|
||||
MeanAbsoluteError,
|
||||
MeanSquaredError,
|
||||
MeanSquaredLogError,
|
||||
PSNR,
|
||||
R2Score,
|
||||
SSIM,
|
||||
R2Score
|
||||
)
|
||||
|
|
|
@ -14,10 +14,10 @@
|
|||
from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.f_beta import F1, FBeta # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.hamming_distance import HammingDistance # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.iou import IoU # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.stat_scores import StatScores # noqa: F401
|
||||
from pytorch_lightning.metrics.classification.iou import IoU # noqa: F401
|
||||
|
|
|
@ -15,8 +15,8 @@ from typing import Any, Callable, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.metrics.functional.accuracy import _accuracy_update, _accuracy_compute
|
||||
|
||||
|
||||
class Accuracy(Metric):
|
||||
|
|
|
@ -11,15 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Any, Union, List
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics import Metric
|
||||
from pytorch_lightning.metrics.functional.average_precision import (
|
||||
_average_precision_update,
|
||||
_average_precision_compute
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
|
|
|
@ -15,10 +15,7 @@ from typing import Any, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.confusion_matrix import (
|
||||
_confusion_matrix_update,
|
||||
_confusion_matrix_compute
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
|
|
|
@ -15,10 +15,7 @@ from typing import Any, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.f_beta import (
|
||||
_fbeta_update,
|
||||
_fbeta_compute
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
|
|
@ -14,8 +14,9 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_compute, _hamming_distance_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_update, _hamming_distance_compute
|
||||
|
||||
|
||||
class HammingDistance(Metric):
|
||||
|
|
|
@ -11,12 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.utils import to_onehot, select_topk
|
||||
from pytorch_lightning.metrics.utils import select_topk, to_onehot
|
||||
|
||||
|
||||
def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
|
||||
from pytorch_lightning.metrics.functional.iou import _iou_from_confmat
|
||||
|
||||
|
|
|
@ -11,15 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Any, Union, Tuple, List
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics import Metric
|
||||
from pytorch_lightning.metrics.functional.precision_recall_curve import (
|
||||
_precision_recall_curve_compute,
|
||||
_precision_recall_curve_update,
|
||||
_precision_recall_curve_compute
|
||||
)
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
|
|
|
@ -11,15 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Any, Union, List, Tuple
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics import Metric
|
||||
from pytorch_lightning.metrics.functional.roc import (
|
||||
_roc_update,
|
||||
_roc_compute
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.roc import _roc_compute, _roc_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@ from typing import Any, Callable, Optional, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics import Metric
|
||||
from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class StatScores(Metric):
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# TODO: unify metrics between class and functional, add below
|
||||
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.average_precision import average_precision # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
|
||||
auc,
|
||||
|
@ -22,11 +24,9 @@ from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
|
|||
to_categorical,
|
||||
to_onehot,
|
||||
)
|
||||
# TODO: unify metrics between class and functional, add below
|
||||
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.f_beta import f1, fbeta # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401
|
||||
from pytorch_lightning.metrics.functional.iou import iou # noqa: F401
|
||||
|
|
|
@ -11,9 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
||||
|
||||
|
||||
|
|
|
@ -11,13 +11,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Sequence, Tuple, Union, List
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.precision_recall_curve import (
|
||||
_precision_recall_curve_compute,
|
||||
_precision_recall_curve_update,
|
||||
_precision_recall_curve_compute
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -16,20 +16,17 @@ from functools import wraps
|
|||
from typing import Callable, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
|
||||
from pytorch_lightning.metrics.functional.iou import iou as __iou
|
||||
from pytorch_lightning.metrics.functional.precision_recall_curve import (
|
||||
_binary_clf_curve,
|
||||
precision_recall_curve as __prc
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
|
||||
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve as __prc
|
||||
from pytorch_lightning.metrics.functional.roc import roc as __roc
|
||||
from pytorch_lightning.metrics.utils import (
|
||||
to_categorical as __tc,
|
||||
to_onehot as __to,
|
||||
get_num_classes as __gnc,
|
||||
reduce,
|
||||
class_reduce,
|
||||
)
|
||||
from pytorch_lightning.metrics.utils import class_reduce
|
||||
from pytorch_lightning.metrics.utils import get_num_classes as __gnc
|
||||
from pytorch_lightning.metrics.utils import reduce
|
||||
from pytorch_lightning.metrics.utils import to_categorical as __tc
|
||||
from pytorch_lightning.metrics.utils import to_onehot as __to
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Union, Tuple, Sequence
|
||||
from typing import Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
||||
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update
|
||||
from pytorch_lightning.metrics.functional.reduction import reduce
|
||||
from pytorch_lightning.metrics.utils import get_num_classes
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores
|
||||
from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Sequence, Tuple, List, Union
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
from typing import Tuple, Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
|
|
@ -13,7 +13,8 @@
|
|||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.utils import reduce as __reduce, class_reduce as __cr
|
||||
from pytorch_lightning.metrics.utils import class_reduce as __cr
|
||||
from pytorch_lightning.metrics.utils import reduce as __reduce
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
|
|
|
@ -11,13 +11,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Sequence, Tuple, List, Union
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.precision_recall_curve import (
|
||||
_binary_clf_curve,
|
||||
_precision_recall_curve_update,
|
||||
_binary_clf_curve
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -11,9 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
||||
|
||||
|
||||
|
|
|
@ -11,10 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.psnr import PSNR # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.r2score import R2Score # noqa: F401
|
||||
from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401
|
||||
|
|
|
@ -11,15 +11,16 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.explained_variance import (
|
||||
_explained_variance_compute,
|
||||
_explained_variance_update,
|
||||
)
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.metrics.functional.explained_variance import (
|
||||
_explained_variance_update,
|
||||
_explained_variance_compute,
|
||||
)
|
||||
|
||||
|
||||
class ExplainedVariance(Metric):
|
||||
|
|
|
@ -11,14 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.mean_absolute_error import (
|
||||
_mean_absolute_error_compute,
|
||||
_mean_absolute_error_update,
|
||||
_mean_absolute_error_compute
|
||||
)
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class MeanAbsoluteError(Metric):
|
||||
|
|
|
@ -11,14 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.mean_squared_error import (
|
||||
_mean_squared_error_compute,
|
||||
_mean_squared_error_update,
|
||||
_mean_squared_error_compute
|
||||
)
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class MeanSquaredError(Metric):
|
||||
|
|
|
@ -11,14 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.mean_squared_log_error import (
|
||||
_mean_squared_log_error_compute,
|
||||
_mean_squared_log_error_update,
|
||||
_mean_squared_log_error_compute
|
||||
)
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class MeanSquaredLogError(Metric):
|
||||
|
|
|
@ -11,14 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.metrics.functional.psnr import (
|
||||
_psnr_update,
|
||||
_psnr_compute,
|
||||
)
|
||||
|
||||
|
||||
class PSNR(Metric):
|
||||
|
|
|
@ -15,11 +15,8 @@ from typing import Any, Callable, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.metrics.functional.r2score import (
|
||||
_r2score_update,
|
||||
_r2score_compute
|
||||
)
|
||||
|
||||
|
||||
class R2Score(Metric):
|
||||
|
|
|
@ -11,12 +11,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.metrics.functional.ssim import _ssim_update, _ssim_compute
|
||||
|
||||
|
||||
class SSIM(Metric):
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче