STY: Ignore styling in git blame

This commit is contained in:
Peter Hessey 2023-03-23 15:08:00 +00:00
Родитель 6f0cff82d6
Коммит 64511b92b5
3 изменённых файлов: 28 добавлений и 17 удалений

Просмотреть файл

@ -1,2 +1,2 @@
# black styling refactor
hashgoeshere
6f0cff82d6a5e9bdd8adc6e4832a3d5935f97437

Просмотреть файл

@ -24,11 +24,21 @@ for folder in [repo_root / "hi-ml-azure" / "src", repo_root / "hi-ml" / "src"]:
sys.path.append(str(folder))
from health_ml.utils.logging import AzureMLLogger # noqa: E402
from health_azure.utils import (set_environment_variables_for_multi_node, # noqa: E402
is_local_rank_zero, is_global_rank_zero)
from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT, # noqa: E402
ENV_AMLT_DATAREFERENCE_OUTPUT, is_amulet_job, get_amulet_aml_working_dir,
get_amulet_data_dir, get_amulet_output_dir, prepare_amulet_job)
from health_azure.utils import (
set_environment_variables_for_multi_node, # noqa: E402
is_local_rank_zero,
is_global_rank_zero,
)
from health_azure.amulet import (
ENV_AMLT_PROJECT_NAME,
ENV_AMLT_INPUT_OUTPUT, # noqa: E402
ENV_AMLT_DATAREFERENCE_OUTPUT,
is_amulet_job,
get_amulet_aml_working_dir,
get_amulet_data_dir,
get_amulet_output_dir,
prepare_amulet_job,
)
from health_azure import submit_to_azure_if_needed # noqa: E402
@ -186,15 +196,16 @@ def run_training_loop(logging_folder: Optional[Path] = None) -> None:
# Write all metrics also to AzureML natively, so that they are visible in the AzureML UI
loggers.append(AzureMLLogger())
trainer = Trainer(accelerator=accelerator,
strategy=strategy,
max_epochs=2,
logger=loggers,
num_nodes=1,
devices=devices,
# Setting the logging interval to a very small value because we have a tiny dataset
log_every_n_steps=1
)
trainer = Trainer(
accelerator=accelerator,
strategy=strategy,
max_epochs=2,
logger=loggers,
num_nodes=1,
devices=devices,
# Setting the logging interval to a very small value because we have a tiny dataset
log_every_n_steps=1,
)
model = BoringModel()
data_module = BoringDataModule()
trainer.fit(model, datamodule=data_module)

Просмотреть файл

@ -120,8 +120,8 @@ class MultiImageModel(ImageModel):
super().__init__(**kwargs)
assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"
def forward(
self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None # type: ignore[override]
def forward( # type: ignore[override]
self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None
) -> ImageModelOutput:
with torch.set_grad_enabled(not self.freeze_encoder):
patch_x, pooled_x = self.encoder(