nit linter
This commit is contained in:
Родитель
e00a6d0eeb
Коммит
f1be69c86a
|
@ -104,9 +104,11 @@ class DiscriminatorEncoder(nn.Module):
|
|||
else:
|
||||
self.encoder_network_body = NetworkBody(observation_specs, network_settings)
|
||||
self.encoding_size = network_settings.hidden_units
|
||||
# self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2),
|
||||
# nn.Sigmoid())
|
||||
self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2))
|
||||
# self.discriminator_output_layer = nn.Sequential(linear_layer(
|
||||
# network_settings.hidden_units, 1, kernel_gain=0.2), nn.Sigmoid())
|
||||
self.discriminator_output_layer = nn.Sequential(
|
||||
linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2)
|
||||
)
|
||||
|
||||
self.encoder_output_layer = nn.Linear(
|
||||
self.encoding_size, ase_settings.latent_dim
|
||||
|
@ -252,16 +254,24 @@ class DiscriminatorEncoder(nn.Module):
|
|||
policy_estimate, _ = self.compute_estimates(policy_batch)
|
||||
expert_estimate, _ = self.compute_estimates_expert(expert_batch)
|
||||
|
||||
disc_loss_policy = torch.nn.BCEWithLogitsLoss(policy_estimate, torch.zeros_like(policy_estimate))
|
||||
disc_loss_expert = torch.nn.BCEWithLogitsLoss(expert_estimate, torch.ones_like(expert_estimate))
|
||||
disc_loss_policy = torch.nn.BCEWithLogitsLoss(
|
||||
policy_estimate, torch.zeros_like(policy_estimate)
|
||||
)
|
||||
disc_loss_expert = torch.nn.BCEWithLogitsLoss(
|
||||
expert_estimate, torch.ones_like(expert_estimate)
|
||||
)
|
||||
|
||||
total_loss += 0.5 * (disc_loss_policy + disc_loss_expert)
|
||||
total_loss += torch.mean(0.5 * (disc_loss_policy + disc_loss_expert))
|
||||
|
||||
#logit reg
|
||||
# logit reg
|
||||
# not implemented yet
|
||||
|
||||
stats_dict["Policy/ASE Discriminator Policy Estimate"] = policy_estimate.mean().item()
|
||||
stats_dict["Policy/ASE Discriminator Expert Estimate"] = expert_estimate.mean().item()
|
||||
stats_dict[
|
||||
"Policy/ASE Discriminator Policy Estimate"
|
||||
] = policy_estimate.mean().item()
|
||||
stats_dict[
|
||||
"Policy/ASE Discriminator Expert Estimate"
|
||||
] = expert_estimate.mean().item()
|
||||
# discriminator_loss = -(
|
||||
# torch.log(expert_estimate + self.EPSILON) + torch.log(1.0 - policy_estimate + self.EPSILON)).mean()
|
||||
# total_loss += discriminator_loss
|
||||
|
|
Загрузка…
Ссылка в новой задаче