This commit is contained in:
mahon94 2023-02-09 10:26:45 -08:00
Родитель e00a6d0eeb
Коммит f1be69c86a
1 изменённых файлов: 19 добавлений и 9 удалений

Просмотреть файл

@ -104,9 +104,11 @@ class DiscriminatorEncoder(nn.Module):
else:
self.encoder_network_body = NetworkBody(observation_specs, network_settings)
self.encoding_size = network_settings.hidden_units
# self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2),
# nn.Sigmoid())
self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2))
# self.discriminator_output_layer = nn.Sequential(linear_layer(
# network_settings.hidden_units, 1, kernel_gain=0.2), nn.Sigmoid())
self.discriminator_output_layer = nn.Sequential(
linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2)
)
self.encoder_output_layer = nn.Linear(
self.encoding_size, ase_settings.latent_dim
@ -252,16 +254,24 @@ class DiscriminatorEncoder(nn.Module):
policy_estimate, _ = self.compute_estimates(policy_batch)
expert_estimate, _ = self.compute_estimates_expert(expert_batch)
disc_loss_policy = torch.nn.BCEWithLogitsLoss(policy_estimate, torch.zeros_like(policy_estimate))
disc_loss_expert = torch.nn.BCEWithLogitsLoss(expert_estimate, torch.ones_like(expert_estimate))
disc_loss_policy = torch.nn.BCEWithLogitsLoss(
policy_estimate, torch.zeros_like(policy_estimate)
)
disc_loss_expert = torch.nn.BCEWithLogitsLoss(
expert_estimate, torch.ones_like(expert_estimate)
)
total_loss += 0.5 * (disc_loss_policy + disc_loss_expert)
total_loss += torch.mean(0.5 * (disc_loss_policy + disc_loss_expert))
#logit reg
# logit reg
# not implemented yet
stats_dict["Policy/ASE Discriminator Policy Estimate"] = policy_estimate.mean().item()
stats_dict["Policy/ASE Discriminator Expert Estimate"] = expert_estimate.mean().item()
stats_dict[
"Policy/ASE Discriminator Policy Estimate"
] = policy_estimate.mean().item()
stats_dict[
"Policy/ASE Discriminator Expert Estimate"
] = expert_estimate.mean().item()
# discriminator_loss = -(
# torch.log(expert_estimate + self.EPSILON) + torch.log(1.0 - policy_estimate + self.EPSILON)).mean()
# total_loss += discriminator_loss