зеркало из https://github.com/microsoft/hi-ml.git
STY: Ignore styling in git blame
This commit is contained in:
Родитель
6f0cff82d6
Коммит
64511b92b5
|
@ -1,2 +1,2 @@
|
|||
# black styling refactor
|
||||
hashgoeshere
|
||||
6f0cff82d6a5e9bdd8adc6e4832a3d5935f97437
|
||||
|
|
|
@ -24,11 +24,21 @@ for folder in [repo_root / "hi-ml-azure" / "src", repo_root / "hi-ml" / "src"]:
|
|||
sys.path.append(str(folder))
|
||||
|
||||
from health_ml.utils.logging import AzureMLLogger # noqa: E402
|
||||
from health_azure.utils import (set_environment_variables_for_multi_node, # noqa: E402
|
||||
is_local_rank_zero, is_global_rank_zero)
|
||||
from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT, # noqa: E402
|
||||
ENV_AMLT_DATAREFERENCE_OUTPUT, is_amulet_job, get_amulet_aml_working_dir,
|
||||
get_amulet_data_dir, get_amulet_output_dir, prepare_amulet_job)
|
||||
from health_azure.utils import (
|
||||
set_environment_variables_for_multi_node, # noqa: E402
|
||||
is_local_rank_zero,
|
||||
is_global_rank_zero,
|
||||
)
|
||||
from health_azure.amulet import (
|
||||
ENV_AMLT_PROJECT_NAME,
|
||||
ENV_AMLT_INPUT_OUTPUT, # noqa: E402
|
||||
ENV_AMLT_DATAREFERENCE_OUTPUT,
|
||||
is_amulet_job,
|
||||
get_amulet_aml_working_dir,
|
||||
get_amulet_data_dir,
|
||||
get_amulet_output_dir,
|
||||
prepare_amulet_job,
|
||||
)
|
||||
from health_azure import submit_to_azure_if_needed # noqa: E402
|
||||
|
||||
|
||||
|
@ -186,15 +196,16 @@ def run_training_loop(logging_folder: Optional[Path] = None) -> None:
|
|||
# Write all metrics also to AzureML natively, so that they are visible in the AzureML UI
|
||||
loggers.append(AzureMLLogger())
|
||||
|
||||
trainer = Trainer(accelerator=accelerator,
|
||||
strategy=strategy,
|
||||
max_epochs=2,
|
||||
logger=loggers,
|
||||
num_nodes=1,
|
||||
devices=devices,
|
||||
# Setting the logging interval to a very small value because we have a tiny dataset
|
||||
log_every_n_steps=1
|
||||
)
|
||||
trainer = Trainer(
|
||||
accelerator=accelerator,
|
||||
strategy=strategy,
|
||||
max_epochs=2,
|
||||
logger=loggers,
|
||||
num_nodes=1,
|
||||
devices=devices,
|
||||
# Setting the logging interval to a very small value because we have a tiny dataset
|
||||
log_every_n_steps=1,
|
||||
)
|
||||
model = BoringModel()
|
||||
data_module = BoringDataModule()
|
||||
trainer.fit(model, datamodule=data_module)
|
||||
|
|
|
@ -120,8 +120,8 @@ class MultiImageModel(ImageModel):
|
|||
super().__init__(**kwargs)
|
||||
assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"
|
||||
|
||||
def forward(
|
||||
self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None # type: ignore[override]
|
||||
def forward( # type: ignore[override]
|
||||
self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None
|
||||
) -> ImageModelOutput:
|
||||
with torch.set_grad_enabled(not self.freeze_encoder):
|
||||
patch_x, pooled_x = self.encoder(
|
||||
|
|
Загрузка…
Ссылка в новой задаче