stack states/observations for Disc/Enc
Parent: d28a1534a7
Commit: f42e3dad1c
@@ -125,6 +125,7 @@ class TorchPPOOptimizer(TorchOptimizer):
# Convert to tensors
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
actions = AgentAction.from_buffer(batch)
@@ -174,7 +175,7 @@ class TorchPPOOptimizer(TorchOptimizer):
loss_masks,
decay_eps,
)
self.loss = (
self.loss += (
policy_loss
+ 0.5 * value_loss
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
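Review note: one of the two adjacent `self.loss` lines above replaces the other (the rendering drops the +/- markers). The practical difference is assignment versus in-place accumulation: with `+=`, any term already written to `self.loss` earlier in the optimizer update, such as an ASE diversity loss, is preserved instead of being overwritten by the PPO terms.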
@@ -265,7 +265,7 @@ class ASESettings(RewardSignalSettings):
batch_size: int = 1024
shared_discriminator: bool = True
demo_path: str = attr.ib(kw_only=True)
timestep: int = 1


# SAMPLERS #############################################################################
class ParameterRandomizationType(Enum):
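The new `timestep` field is the only settings change in this hunk; it is read in `DiscriminatorEncoder.__init__` below as `self.timestep = ase_settings.timestep` and drives the observation stacking for the discriminator/encoder. A hedged sketch of inspecting the setting, assuming every field other than `demo_path` has a default as the hunk suggests (nothing else here is part of the commit):

# Hedged sketch: exercises only the fields visible in this hunk.
from mlagents.trainers.settings import ASESettings

settings = ASESettings(demo_path="Demos/ExpertWalker.demo")  # demo_path is keyword-only
print(settings.batch_size)            # 1024: discriminator/encoder update batch
print(settings.shared_discriminator)  # True: encoder reuses the discriminator body
print(settings.timestep)              # 1: extra observations stacked per input (new field)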
@@ -8,12 +8,10 @@ from mlagents.trainers.buffer import AgentBuffer, AgentBufferField
from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.exception import TrainerConfigError, TrainerError
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.settings import ASESettings
from mlagents.trainers.settings import ASESettings, NetworkSettings
from mlagents.trainers.torch_entities.action_flattener import ActionFlattener
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.components.reward_providers import (
BaseRewardProvider,
)
from mlagents.trainers.torch_entities.components.reward_providers import BaseRewardProvider
from mlagents.trainers.torch_entities.layers import linear_layer
from mlagents.trainers.torch_entities.networks import NetworkBody
from mlagents.trainers.torch_entities.utils import ModelUtils
@@ -25,6 +23,7 @@ class ASERewardProvider(BaseRewardProvider):
def __init__(self, specs: BehaviorSpec, settings: ASESettings) -> None:
super().__init__(specs, settings)
self._ignore_done = False

self._discriminator_encoder = DiscriminatorEncoder(specs, settings)
_, self._demo_buffer = demo_to_buffer(settings.demo_path, 1, specs, True)
self._settings = settings
@@ -33,41 +32,33 @@ class ASERewardProvider(BaseRewardProvider):
self.diversity_objective_weight = settings.omega_do
self.update_batch_size = settings.batch_size

def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
#TODO update mini_batch - must be done when saving the Agent Buffer
#mini_batch = update_batch_w_stacked_obs(mini_batch)

with torch.no_grad():
disc_reward, encoder_reward = self._discriminator_encoder.compute_rewards(
mini_batch
)
disc_reward, encoder_reward = self._discriminator_encoder.compute_rewards(mini_batch)

return ModelUtils.to_numpy(
disc_reward.squeeze(dim=1)
+ encoder_reward.squeeze(dim=1) * self._settings.beta_sdo
)
disc_reward.squeeze(dim=1) + encoder_reward.squeeze(dim=1) * self._settings.beta_sdo)

def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
expert_batch = self._demo_buffer.sample_mini_batch(self.update_batch_size, 1)

expert_batch = self._demo_buffer.sample_mini_batch(
self.update_batch_size, 1
)

if self.update_batch_size > mini_batch.num_experiences:
raise TrainerError(
"Discriminator batch size should be less than Policy batch size."
)
raise TrainerError('Discriminator batch size should be less than Policy batch size.')

if self.update_batch_size <= mini_batch.num_experiences:
if self.update_batch_size < mini_batch.num_experiences:
mini_batch = mini_batch.sample_mini_batch(self.update_batch_size)

self._discriminator_encoder.discriminator_network_body.update_normalization(
expert_batch
)
#self._discriminator_encoder.discriminator_network_body.update_normalization(expert_batch)

(
disc_loss,
disc_stats_dict,
) = self._discriminator_encoder.compute_discriminator_loss(
mini_batch, expert_batch
)
enc_loss, enc_stats_dict = self._discriminator_encoder.compute_encoder_loss(
mini_batch
)
disc_loss, disc_stats_dict = self._discriminator_encoder.compute_discriminator_loss(mini_batch, expert_batch)
enc_loss, enc_stats_dict = self._discriminator_encoder.compute_encoder_loss(mini_batch)
loss = disc_loss + self._settings.encoder_scaling * enc_loss
self.optimizer.zero_grad()
loss.backward()
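As a reading aid for the `evaluate` path above: the per-step ASE reward handed back to PPO is the discriminator reward plus the skill-discovery (encoder) reward weighted by `beta_sdo`. A minimal sketch of that mixing, assuming one-column reward tensors as in the hunk; the names and shapes here are illustrative, not part of the commit:

import torch

def combined_ase_reward(disc_reward: torch.Tensor,
                        encoder_reward: torch.Tensor,
                        beta_sdo: float) -> torch.Tensor:
    # Both inputs are (batch, 1); squeeze to (batch,) and mix, as in evaluate().
    return disc_reward.squeeze(dim=1) + encoder_reward.squeeze(dim=1) * beta_sdo

disc_r = torch.full((4, 1), 0.7)
enc_r = torch.full((4, 1), 0.2)
print(combined_ase_reward(disc_r, enc_r, beta_sdo=0.5))  # tensor of 0.8s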
@@ -75,9 +66,8 @@ class ASERewardProvider(BaseRewardProvider):
stats_dict = {**disc_stats_dict, **enc_stats_dict}
return stats_dict

def compute_diversity_loss(
self, policy: TorchPolicy, policy_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
def compute_diversity_loss(self, policy: TorchPolicy, policy_batch: AgentBuffer) -> Tuple[
torch.Tensor, Dict[str, np.ndarray]]:
return self._discriminator_encoder.compute_diversity_loss(policy, policy_batch)
@@ -95,30 +85,52 @@ class DiscriminatorEncoder(nn.Module):
observation_specs = copy.deepcopy(behavior_spec.observation_specs)
self.latent_key = self.get_latent_key(observation_specs)
del observation_specs[self.latent_key]

# Add stacked observation instead
self.timestep = ase_settings.timestep
observation_specs = self.get_stacked_observation(self.timestep, observation_specs)

network_settings = ase_settings.network_settings
self.discriminator_network_body = NetworkBody(
observation_specs, network_settings
)

if ase_settings.shared_discriminator:
self.encoder_network_body = self.discriminator_network_body
else:
self.encoder_network_body = NetworkBody(observation_specs, network_settings)
self.encoding_size = network_settings.hidden_units
# self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2),
# nn.Sigmoid())
self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2))
self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2),
nn.Sigmoid())
# self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2))

self.encoder_output_layer = nn.Linear(
self.encoding_size, ase_settings.latent_dim
)
self.encoder_output_layer = nn.Linear(self.encoding_size, ase_settings.latent_dim)
self.latent_dim = ase_settings.latent_dim
self.encoder_reward_scale = ase_settings.encoder_scaling
self.discriminator_reward_scale = 1
self.gradient_penalty_weight = ase_settings.omega_gp
self._action_flattener = ActionFlattener(behavior_spec.action_spec)

@staticmethod
def get_latent_key(observation_specs: List[ObservationSpec]) -> int: # type: ignore
def get_stacked_observation(n: int, observation_specs: List[ObservationSpec]) -> List[torch.Tensor]:
new_observation_specs = []

for i, spec in enumerate(observation_specs):
new_dim = 2*observation_specs[i].shape[0]
new_observation_specs.append(ObservationSpec(
shape=(new_dim,),
dimension_property=observation_specs[i].dimension_property,
observation_type=observation_specs[i].observation_type,
name=observation_specs[i].name,
))
return new_observation_specs

@staticmethod
def get_latent_key(observation_specs: List[ObservationSpec]) -> int:
try:
for idx, spec in enumerate(observation_specs):
if spec.name == "EmbeddingSensor":
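To make the new `get_stacked_observation` concrete: each vector ObservationSpec fed to the discriminator/encoder body is replaced by a spec twice as wide, so the network consumes a current/next observation pair. Note that the helper doubles the first shape dimension regardless of its `n` argument. A toy sketch of the effect on spec shapes (plain tuples stand in for ObservationSpec; the helper name below is mine):

# Hedged sketch of the spec doubling performed by get_stacked_observation.
from typing import List, Tuple

def stack_spec_shapes(shapes: List[Tuple[int, ...]]) -> List[Tuple[int, ...]]:
    # Mirrors new_dim = 2 * spec.shape[0]: every vector observation doubles in width.
    return [(2 * shape[0],) for shape in shapes]

print(stack_spec_shapes([(105,), (14,)]))  # [(210,), (28,)]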
@@ -129,14 +141,12 @@ class DiscriminatorEncoder(nn.Module):
def forward(self, inputs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
discriminator_network_output, _ = self.discriminator_network_body(inputs)
encoder_network_output, _ = self.encoder_network_body(inputs)
discriminator_output = self.discriminator_output_layer(
discriminator_network_output
)
discriminator_output = self.discriminator_output_layer(discriminator_network_output)
encoder_output = self.encoder_output_layer(encoder_network_output)
encoder_output = torch.nn.functional.normalize(encoder_output, dim=-1)
return discriminator_output, encoder_output

def update_latents(self, expert_batch: AgentBuffer, mini_batch: AgentBuffer): # type: ignore
def update_latents(self, expert_batch: AgentBuffer, mini_batch: AgentBuffer):
n_obs = len(self.discriminator_network_body.processors)
latents = mini_batch[ObsUtil.get_name_at(self.latent_key)]
for i in range(n_obs - 2, -1, -1):
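One detail of `forward` worth calling out: the encoder head's output is L2-normalized so that, like the sampled ASE latents, it lies on the unit hypersphere, and the encoder reward further down reduces to a dot product. A small self-contained sketch of that normalization, independent of the commit's network classes:

import torch
import torch.nn.functional as F

encoder_raw = torch.randn(8, 16)                   # (batch, latent_dim) head output
encoder_output = F.normalize(encoder_raw, dim=-1)  # unit-length rows
print(encoder_output.norm(dim=-1))                 # all ~1.0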
@@ -146,22 +156,20 @@ class DiscriminatorEncoder(nn.Module):
break
expert_batch[ObsUtil.get_name_at(self.latent_key)] = latents

def replace_latents(self, mini_batch: AgentBuffer, new_latents: np.ndarray): # type: ignore
def replace_latents(self, mini_batch: AgentBuffer, new_latents: np.ndarray):
new_mini_batch = copy.deepcopy(mini_batch)
new_latents = AgentBufferField(new_latents.tolist())
new_mini_batch[ObsUtil.get_name_at(self.latent_key)] = new_latents
return new_mini_batch

def remove_latents(self, mini_batch: AgentBuffer): # type: ignore
def remove_latents(self, mini_batch: AgentBuffer):
new_mini_batch = copy.deepcopy(mini_batch)
del new_mini_batch[ObsUtil.get_name_at(self.latent_key)]
del new_mini_batch[ObsUtil.get_name_at_next(self.latent_key)]

return new_mini_batch

def get_state_inputs(
self, mini_batch: AgentBuffer, ignore_latent: bool = True
) -> List[torch.Tensor]:
def get_state_inputs(self, mini_batch: AgentBuffer, ignore_latent: bool = True, for_policy: bool = False) -> List[torch.Tensor]:
n_obs = len(self.discriminator_network_body.processors) + 1
np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
# Convert to tensors
@@ -170,23 +178,75 @@ class DiscriminatorEncoder(nn.Module):
for index, obs in enumerate(np_obs):
if ignore_latent and index == self.latent_key:
continue
tensor_obs.append(ModelUtils.list_to_tensor(obs))

nlast = len(obs)-1

if ignore_latent == False and index == self.latent_key:
tensor_obs.append(ModelUtils.list_to_tensor(obs)) #[0:nlast]))
elif for_policy == False:
# stack observation inputs
stacked_o = []
for i, o in enumerate(obs):
if i == nlast:
#TODO: quick hack to make sure it has same dimension
stacked_o.append(np.hstack((o, o)))
#continue
else:
stacked_o.append(np.hstack((o, obs[i+1])))

tensor_obs.append(ModelUtils.list_to_tensor(stacked_o)) #(obs))
else:
tensor_obs.append(ModelUtils.list_to_tensor(obs))#[0:nlast]))

return tensor_obs

def get_state_inputs_expert(self, mini_batch: AgentBuffer): # type: ignore
def get_state_inputs_expert(self, mini_batch: AgentBuffer) -> List[torch.Tensor]:
n_obs = len(self.discriminator_network_body.processors)
np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
# Convert to tensors
tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

#tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
tensor_obs = []
for index, obs in enumerate(np_obs):
nlast = len(obs)-1
if True: # index == self.latent_key:
# tensor_obs.append(ModelUtils.list_to_tensor(obs))#[0:nlast]))
#else:
# stack observation inputs
stacked_o = []
for i, o in enumerate(obs):
if i == nlast:
#TODO: quick hack to make sure it has same dimension
stacked_o.append(np.hstack((o, o)))
#continue
else:
stacked_o.append(np.hstack((o, obs[i+1])))

tensor_obs.append(ModelUtils.list_to_tensor(stacked_o)) #(obs))
#print(tensor_obs[index].shape)
return tensor_obs

def get_next_state_inputs(self, mini_batch: AgentBuffer) -> List[torch.Tensor]:
n_obs = len(self.discriminator_network_body.processors)
np_obs_next = ObsUtil.from_buffer_next(mini_batch, n_obs)
# Convert to tensors
tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs_next]
#tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs_next]
tensor_obs = []
for index, obs in enumerate(np_obs):
nlast = len(obs)-1
if index == self.latent_key:
tensor_obs.append(ModelUtils.list_to_tensor(obs))#[0:nlast]))
else:
# stack observation inputs
stacked_o = []
for i, o in enumerate(obs):
if i == nlast:
#TODO: quick hack to make sure it has same dimension
stacked_o.append(np.hstack((o, o)))
#continue
else:
stacked_o.append(np.hstack((o, obs[i+1])))

tensor_obs.append(ModelUtils.list_to_tensor(stacked_o)) #(obs))

return tensor_obs

def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
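The same stacking loop appears three times in this hunk (policy, expert, and next-state inputs): each observation at step i is concatenated with the observation at step i+1, and the last step is paired with itself as a stated quick hack so every row has the doubled width. Two review notes: the loop in `get_next_state_inputs` iterates over `np_obs`, which is not assigned in that method as shown, so `np_obs_next` appears to be the intended variable; and the repeated logic could live in one helper. A minimal sketch of such a helper, assuming per-step 1-D numpy observations (the helper name is mine, not part of the commit):

import numpy as np
from typing import List

def stack_with_next(obs: List[np.ndarray]) -> List[np.ndarray]:
    # Pair each step with the following one; the final step is paired with itself,
    # mirroring the "quick hack" in the diff so every row has the doubled width.
    nlast = len(obs) - 1
    return [
        np.hstack((o, obs[i + 1] if i < nlast else o))
        for i, o in enumerate(obs)
    ]

steps = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
print(stack_with_next(steps))
# [array([1., 2., 3., 4.]), array([3., 4., 5., 6.]), array([5., 6., 5., 6.])]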
@@ -194,7 +254,7 @@ class DiscriminatorEncoder(nn.Module):

def get_actions(self, policy: TorchPolicy, mini_batch: AgentBuffer) -> torch.Tensor:
with torch.no_grad():
obs = self.get_state_inputs(mini_batch, False)
obs = self.get_state_inputs(mini_batch, False, True)
action, _, _ = policy.actor.get_action_and_stats(obs)
mu = action.continuous_tensor
return mu
@@ -204,9 +264,7 @@ class DiscriminatorEncoder(nn.Module):
ase_latents = inputs[self.latent_key]
return ase_latents

def compute_rewards(
self, mini_batch: AgentBuffer
) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_rewards(self, mini_batch: AgentBuffer) -> Tuple[torch.Tensor, torch.Tensor]:
# self.discriminator_network_body.update_normalization(mini_batch)
disc_output, enc_output = self.compute_estimates(mini_batch)
ase_latents = self.get_ase_latents(mini_batch)
@@ -214,36 +272,28 @@ class DiscriminatorEncoder(nn.Module):
disc_reward = self._calc_disc_reward(disc_output)
return disc_reward, enc_reward

def compute_estimates(
self, mini_batch: AgentBuffer, ignore_latents: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_estimates(self, mini_batch: AgentBuffer, ignore_latents: bool = True) -> Tuple[
torch.Tensor, torch.Tensor]:
inputs = self.get_state_inputs(mini_batch)
disc_output, enc_output = self.forward(inputs)

return disc_output, enc_output

def compute_estimates_expert(
self, mini_batch: AgentBuffer
) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_estimates_expert(self, mini_batch: AgentBuffer) -> Tuple[torch.Tensor, torch.Tensor]:
inputs = self.get_state_inputs_expert(mini_batch)
self.discriminator_network_body.update_normalization_input(inputs)
disc_output, enc_output = self.forward(inputs)
return disc_output, enc_output

def compute_cat_estimates(
self, mini_batch: AgentBuffer
) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_cat_estimates(self, mini_batch: AgentBuffer) -> Tuple[torch.Tensor, torch.Tensor]:
inputs = self.get_state_inputs(mini_batch)
next_inputs = self.get_next_state_inputs(mini_batch)
inputs_cat = [
torch.cat([inp, next_inp], dim=0)
for inp, next_inp in zip(inputs, next_inputs)
]
inputs_cat = [torch.cat([inp, next_inp], dim=0) for inp, next_inp in zip(inputs, next_inputs)]
disc_output, enc_output = self.forward(inputs_cat)
return disc_output, enc_output

def compute_discriminator_loss(
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
def compute_discriminator_loss(self, policy_batch: AgentBuffer, expert_batch: AgentBuffer) -> \
Tuple[
torch.Tensor, Dict[str, np.ndarray]]:
# needs to compute the loss like ase:amp_agent.py:470, includes samples from a replay buffer???
# uses torch.nn.bcewithlogitloss, so need to remove sigmoid at the output of the disc
# also need to change gradient mag computation
@@ -251,28 +301,14 @@ class DiscriminatorEncoder(nn.Module):
stats_dict: Dict[str, np.ndarray] = {}
policy_estimate, _ = self.compute_estimates(policy_batch)
expert_estimate, _ = self.compute_estimates_expert(expert_batch)

disc_loss_policy = torch.nn.BCEWithLogitsLoss(policy_estimate, torch.zeros_like(policy_estimate))
disc_loss_expert = torch.nn.BCEWithLogitsLoss(expert_estimate, torch.ones_like(expert_estimate))

total_loss += 0.5 * (disc_loss_policy + disc_loss_expert)

#logit reg
# not implemented yet

stats_dict["Policy/ASE Discriminator Policy Estimate"] = policy_estimate.mean().item()
stats_dict["Policy/ASE Discriminator Expert Estimate"] = expert_estimate.mean().item()
# discriminator_loss = -(
#     torch.log(expert_estimate + self.EPSILON) + torch.log(1.0 - policy_estimate + self.EPSILON)).mean()
# total_loss += discriminator_loss

# grad penalty

discriminator_loss = -(
torch.log(expert_estimate + self.EPSILON) + torch.log(1.0 - policy_estimate + self.EPSILON)).mean()
total_loss += discriminator_loss
if self.gradient_penalty_weight > 0:
gradient_magnitude_loss = (
self.gradient_penalty_weight
* self.compute_gradient_magnitude(policy_batch, expert_batch)
)
self.gradient_penalty_weight * self.compute_gradient_magnitude(policy_batch, expert_batch))
stats_dict["Policy/ASE Grad Mag Loss"] = gradient_magnitude_loss.item()
total_loss += gradient_magnitude_loss
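Two discriminator-loss variants appear side by side in this hunk (the rendering drops the +/- markers): the log-likelihood GAN form -(log D(expert) + log(1 - D(policy))) computed on sigmoid outputs, and a BCE-with-logits form. The latter, as written, calls the `torch.nn.BCEWithLogitsLoss` class directly on tensors; if that route is picked up later, with the sigmoid removed from the head as the TODO comments suggest, the functional form would look roughly like the sketch below. This is a sketch of that alternative, not the committed behaviour:

import torch
import torch.nn.functional as F

def disc_bce_loss(policy_logits: torch.Tensor, expert_logits: torch.Tensor) -> torch.Tensor:
    # Policy samples labelled 0, expert samples labelled 1, averaged as in the hunk.
    loss_policy = F.binary_cross_entropy_with_logits(policy_logits, torch.zeros_like(policy_logits))
    loss_expert = F.binary_cross_entropy_with_logits(expert_logits, torch.ones_like(expert_logits))
    return 0.5 * (loss_policy + loss_expert)

print(disc_bce_loss(torch.randn(32, 1), torch.randn(32, 1)))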
@@ -280,9 +316,7 @@ class DiscriminatorEncoder(nn.Module):

return total_loss, stats_dict

def compute_gradient_magnitude(
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> torch.Tensor:
def compute_gradient_magnitude(self, policy_batch: AgentBuffer, expert_batch: AgentBuffer) -> torch.Tensor:
policy_inputs = self.get_state_inputs(policy_batch)
expert_inputs = self.get_state_inputs_expert(expert_batch)
interp_inputs = []
@@ -293,16 +327,13 @@ class DiscriminatorEncoder(nn.Module):
interp_inputs.append(interp_input)
hidden, _ = self.discriminator_network_body(interp_inputs)
estimate = self.discriminator_output_layer(hidden).squeeze(1).sum()
gradient = torch.autograd.grad(
estimate, tuple(interp_inputs), create_graph=True
)[0]
safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
gradient = torch.autograd.grad(estimate, tuple(interp_inputs), create_graph=True)[0]
safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
gradient_mag = torch.mean((safe_norm - 1) ** 2)
return gradient_mag

def compute_encoder_loss(
self, policy_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
def compute_encoder_loss(self, policy_batch: AgentBuffer) -> Tuple[
torch.Tensor, Dict[str, np.ndarray]]:
total_loss = torch.zeros(1)
stats_dict: Dict[str, np.ndarray] = {}
_, encoder_prediction = self.compute_estimates(policy_batch)
@@ -312,9 +343,8 @@ class DiscriminatorEncoder(nn.Module):
stats_dict["Losses/ASE Encoder Loss"] = total_loss.item()
return total_loss, stats_dict

def compute_diversity_loss(
self, policy: TorchPolicy, policy_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
def compute_diversity_loss(self, policy: TorchPolicy, policy_batch: AgentBuffer) -> Tuple[
torch.Tensor, Dict[str, np.ndarray]]:
# currently only supports continuous actions
total_loss = torch.zeros(1)
stats_dict: Dict[str, np.ndarray] = {}
@@ -338,7 +368,7 @@ class DiscriminatorEncoder(nn.Module):
stats_dict["Losses/ASE Diversity Loss"] = total_loss.item()
return total_loss, stats_dict

def sample_latents(self, n) -> np.ndarray: # type: ignore
def sample_latents(self, n) -> np.ndarray:
# torch version for future reference
# z = torch.normal(torch.zeros([n, self.latent_dim], device=default_device()))
# z = torch.nn.functional.normalize(z, dim=-1)
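`sample_latents` (its body continues in the next hunk) draws the ASE skill latents: Gaussian samples projected onto the unit hypersphere by dividing by their norm, matching the commented torch version above. A small self-contained numpy sketch of that sampling, with an assumed latent_dim:

import numpy as np

def sample_latents(n: int, latent_dim: int = 16) -> np.ndarray:
    # N(0, I) samples normalized row-wise onto the unit hypersphere.
    z = np.random.normal(size=(n, latent_dim))
    denom = np.linalg.norm(z, axis=1, keepdims=True)
    return z / denom

z = sample_latents(4)
print(np.linalg.norm(z, axis=1))  # all ~1.0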
@@ -347,9 +377,7 @@ class DiscriminatorEncoder(nn.Module):
denom = np.linalg.norm(z, axis=1, keepdims=True)
return z / denom

def _calc_encoder_reward(
self, encoder_prediction: torch.Tensor, ase_latents: torch.Tensor
) -> torch.Tensor:
def _calc_encoder_reward(self, encoder_prediction: torch.Tensor, ase_latents: torch.Tensor) -> torch.Tensor:
error = self._calc_encoder_error(encoder_prediction, ase_latents)
enc_reward = torch.clamp(-error, 0.0)
enc_reward *= self.encoder_reward_scale
@@ -357,16 +385,10 @@ class DiscriminatorEncoder(nn.Module):

def _calc_disc_reward(self, discriminator_prediction: torch.Tensor) -> torch.Tensor:
disc_reward = -torch.log(
torch.maximum(
1 - discriminator_prediction,
torch.tensor(0.0001, device=default_device()),
)
)
torch.maximum(1 - discriminator_prediction, torch.tensor(0.0001, device=default_device())))
disc_reward *= self.discriminator_reward_scale
return disc_reward

@staticmethod
def _calc_encoder_error(
encoder_prediction: torch.Tensor, ase_latents: torch.Tensor
) -> torch.Tensor:
return -torch.sum(encoder_prediction * ase_latents, dim=-1, keepdim=True)
def _calc_encoder_error(encoder_prediction: torch.Tensor, ase_latents: torch.Tensor) -> torch.Tensor:
return -torch.sum(encoder_prediction * ase_latents, dim=-1, keepdim=True)
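Putting the two private reward helpers side by side: the discriminator reward is -log(max(1 - D(s, s'), 1e-4)) and the encoder reward is the dot product between the normalized encoder prediction and the sampled latent, clipped at zero and scaled by `encoder_reward_scale`. A compact sketch of both formulas outside the class (tensor shapes are assumed, not specified by the commit):

import torch

def disc_reward(d: torch.Tensor) -> torch.Tensor:
    # d is the sigmoid discriminator output in (0, 1); clamp plays the role of torch.maximum here.
    return -torch.log(torch.clamp(1.0 - d, min=1e-4))

def encoder_reward(pred: torch.Tensor, latents: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Negative of the encoder error: the prediction/latent dot product, clipped at 0.
    dot = torch.sum(pred * latents, dim=-1, keepdim=True)
    return torch.clamp(dot, min=0.0) * scale

d = torch.tensor([[0.9], [0.5]])
print(disc_reward(d))  # higher reward when the discriminator scores the transition as expert-like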
@@ -82,6 +82,11 @@ class ObservationEncoder(nn.Module):
"""
return self._total_goal_enc_size

def update_normalization_input(self, inputs: List[torch.Tensor]) -> None:
for vec_input, enc in zip(inputs, self.processors):
if isinstance(enc, VectorInput):
enc.update_normalization(vec_input)

def update_normalization(self, buffer: AgentBuffer) -> None:
obs = ObsUtil.from_buffer(buffer, len(self.processors))
for vec_input, enc in zip(obs, self.processors):
@@ -216,6 +221,9 @@ class NetworkBody(nn.Module):
else:
self.lstm = None # type: ignore

def update_normalization_input(self, inputs: List[torch.Tensor]) -> None:
self.observation_encoder.update_normalization_input(inputs)

def update_normalization(self, buffer: AgentBuffer) -> None:
self.observation_encoder.update_normalization(buffer)
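Design note on the two hunks above: `update_normalization_input` is the plumbing that lets normalization work with the stacked expert observations. Because `get_state_inputs_expert` builds the stacked tensors before they reach the network, the existing buffer-based `update_normalization` cannot be applied to them, so `compute_estimates_expert` updates the running statistics directly from the tensor list via this new method, which forwards from NetworkBody to ObservationEncoder and updates only the VectorInput processors.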
@@ -764,3 +772,4 @@ class LearningRate(nn.Module):
# Todo: add learning rate decay
super().__init__()
self.learning_rate = torch.Tensor([lr])