stack states/observations for Disc/Enc

This commit is contained in:
Yejin Kim 2023-02-10 14:01:45 -08:00
Parent d28a1534a7
Commit f42e3dad1c
4 changed files with 152 additions and 120 deletions
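In short, this commit makes the ASE discriminator/encoder consume stacked observation pairs instead of single observations: ASESettings gains a timestep field (defaulting to 1 extra frame), every observation spec handed to the Disc/Enc network bodies is doubled in width, and at batch time each observation is concatenated with its successor, with the last entry paired with itself so shapes match. A minimal sketch of that per-sensor stacking step, mirroring the loops added to get_state_inputs / get_state_inputs_expert / get_next_state_inputs below (the helper name is mine, not part of the commit):

import numpy as np

def stack_consecutive(obs):
    # Pair each step's observation with the next step's; the final step is
    # paired with itself so every stacked row has the same doubled width.
    last = len(obs) - 1
    return [
        np.hstack((o, obs[i] if i == last else obs[i + 1]))
        for i, o in enumerate(obs)
    ]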

View file

@@ -125,6 +125,7 @@ class TorchPPOOptimizer(TorchOptimizer):
# Convert to tensors
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
actions = AgentAction.from_buffer(batch)
@@ -174,7 +175,7 @@ class TorchPPOOptimizer(TorchOptimizer):
loss_masks,
decay_eps,
)
self.loss = (
self.loss += (
policy_loss
+ 0.5 * value_loss
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)

View file

@@ -265,7 +265,7 @@ class ASESettings(RewardSignalSettings):
batch_size: int = 1024
shared_discriminator: bool = True
demo_path: str = attr.ib(kw_only=True)
timestep: int = 1
# SAMPLERS #############################################################################
class ParameterRandomizationType(Enum):

View file

@@ -8,12 +8,10 @@ from mlagents.trainers.buffer import AgentBuffer, AgentBufferField
from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.exception import TrainerConfigError, TrainerError
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.settings import ASESettings
from mlagents.trainers.settings import ASESettings, NetworkSettings
from mlagents.trainers.torch_entities.action_flattener import ActionFlattener
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.components.reward_providers import (
BaseRewardProvider,
)
from mlagents.trainers.torch_entities.components.reward_providers import BaseRewardProvider
from mlagents.trainers.torch_entities.layers import linear_layer
from mlagents.trainers.torch_entities.networks import NetworkBody
from mlagents.trainers.torch_entities.utils import ModelUtils
@@ -25,6 +23,7 @@ class ASERewardProvider(BaseRewardProvider):
def __init__(self, specs: BehaviorSpec, settings: ASESettings) -> None:
super().__init__(specs, settings)
self._ignore_done = False
self._discriminator_encoder = DiscriminatorEncoder(specs, settings)
_, self._demo_buffer = demo_to_buffer(settings.demo_path, 1, specs, True)
self._settings = settings
@@ -33,41 +32,33 @@ class ASERewardProvider(BaseRewardProvider):
self.diversity_objective_weight = settings.omega_do
self.update_batch_size = settings.batch_size
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
#TODO update mini_batch - must be done when saving the Agent Buffer
#mini_batch = update_batch_w_stacked_obs(mini_batch)
with torch.no_grad():
disc_reward, encoder_reward = self._discriminator_encoder.compute_rewards(
mini_batch
)
disc_reward, encoder_reward = self._discriminator_encoder.compute_rewards(mini_batch)
return ModelUtils.to_numpy(
disc_reward.squeeze(dim=1)
+ encoder_reward.squeeze(dim=1) * self._settings.beta_sdo
)
disc_reward.squeeze(dim=1) + encoder_reward.squeeze(dim=1) * self._settings.beta_sdo)
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
expert_batch = self._demo_buffer.sample_mini_batch(self.update_batch_size, 1)
expert_batch = self._demo_buffer.sample_mini_batch(
self.update_batch_size, 1
)
if self.update_batch_size > mini_batch.num_experiences:
raise TrainerError(
"Discriminator batch size should be less than Policy batch size."
)
raise TrainerError('Discriminator batch size should be less than Policy batch size.')
if self.update_batch_size <= mini_batch.num_experiences:
if self.update_batch_size < mini_batch.num_experiences:
mini_batch = mini_batch.sample_mini_batch(self.update_batch_size)
self._discriminator_encoder.discriminator_network_body.update_normalization(
expert_batch
)
#self._discriminator_encoder.discriminator_network_body.update_normalization(expert_batch)
(
disc_loss,
disc_stats_dict,
) = self._discriminator_encoder.compute_discriminator_loss(
mini_batch, expert_batch
)
enc_loss, enc_stats_dict = self._discriminator_encoder.compute_encoder_loss(
mini_batch
)
disc_loss, disc_stats_dict = self._discriminator_encoder.compute_discriminator_loss(mini_batch, expert_batch)
enc_loss, enc_stats_dict = self._discriminator_encoder.compute_encoder_loss(mini_batch)
loss = disc_loss + self._settings.encoder_scaling * enc_loss
self.optimizer.zero_grad()
loss.backward()
@@ -75,9 +66,8 @@ class ASERewardProvider(BaseRewardProvider):
stats_dict = {**disc_stats_dict, **enc_stats_dict}
return stats_dict
def compute_diversity_loss(
self, policy: TorchPolicy, policy_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
def compute_diversity_loss(self, policy: TorchPolicy, policy_batch: AgentBuffer) -> Tuple[
torch.Tensor, Dict[str, np.ndarray]]:
return self._discriminator_encoder.compute_diversity_loss(policy, policy_batch)
@@ -95,30 +85,52 @@ class DiscriminatorEncoder(nn.Module):
observation_specs = copy.deepcopy(behavior_spec.observation_specs)
self.latent_key = self.get_latent_key(observation_specs)
del observation_specs[self.latent_key]
# Add stacked observation instead
self.timestep = ase_settings.timestep
observation_specs = self.get_stacked_observation(self.timestep, observation_specs)
network_settings = ase_settings.network_settings
self.discriminator_network_body = NetworkBody(
observation_specs, network_settings
)
if ase_settings.shared_discriminator:
self.encoder_network_body = self.discriminator_network_body
else:
self.encoder_network_body = NetworkBody(observation_specs, network_settings)
self.encoding_size = network_settings.hidden_units
# self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2),
# nn.Sigmoid())
self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2))
self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2),
nn.Sigmoid())
# self.discriminator_output_layer = nn.Sequential(linear_layer(network_settings.hidden_units, 1, kernel_gain=0.2))
self.encoder_output_layer = nn.Linear(
self.encoding_size, ase_settings.latent_dim
)
self.encoder_output_layer = nn.Linear(self.encoding_size, ase_settings.latent_dim)
self.latent_dim = ase_settings.latent_dim
self.encoder_reward_scale = ase_settings.encoder_scaling
self.discriminator_reward_scale = 1
self.gradient_penalty_weight = ase_settings.omega_gp
self._action_flattener = ActionFlattener(behavior_spec.action_spec)
@staticmethod
def get_latent_key(observation_specs: List[ObservationSpec]) -> int: # type: ignore
def get_stacked_observation(n: int, observation_specs: List[ObservationSpec]) -> List[ObservationSpec]:
new_observation_specs = []
for i, spec in enumerate(observation_specs):
new_dim = 2*observation_specs[i].shape[0]
new_observation_specs.append(ObservationSpec(
shape=(new_dim,),
dimension_property=observation_specs[i].dimension_property,
observation_type=observation_specs[i].observation_type,
name=observation_specs[i].name,
))
return new_observation_specs
@staticmethod
def get_latent_key(observation_specs: List[ObservationSpec]) -> int:
try:
for idx, spec in enumerate(observation_specs):
if spec.name == "EmbeddingSensor":
@@ -129,14 +141,12 @@ class DiscriminatorEncoder(nn.Module):
def forward(self, inputs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
discriminator_network_output, _ = self.discriminator_network_body(inputs)
encoder_network_output, _ = self.encoder_network_body(inputs)
discriminator_output = self.discriminator_output_layer(
discriminator_network_output
)
discriminator_output = self.discriminator_output_layer(discriminator_network_output)
encoder_output = self.encoder_output_layer(encoder_network_output)
encoder_output = torch.nn.functional.normalize(encoder_output, dim=-1)
return discriminator_output, encoder_output
def update_latents(self, expert_batch: AgentBuffer, mini_batch: AgentBuffer): # type: ignore
def update_latents(self, expert_batch: AgentBuffer, mini_batch: AgentBuffer):
n_obs = len(self.discriminator_network_body.processors)
latents = mini_batch[ObsUtil.get_name_at(self.latent_key)]
for i in range(n_obs - 2, -1, -1):
@@ -146,22 +156,20 @@ class DiscriminatorEncoder(nn.Module):
break
expert_batch[ObsUtil.get_name_at(self.latent_key)] = latents
def replace_latents(self, mini_batch: AgentBuffer, new_latents: np.ndarray): # type: ignore
def replace_latents(self, mini_batch: AgentBuffer, new_latents: np.ndarray):
new_mini_batch = copy.deepcopy(mini_batch)
new_latents = AgentBufferField(new_latents.tolist())
new_mini_batch[ObsUtil.get_name_at(self.latent_key)] = new_latents
return new_mini_batch
def remove_latents(self, mini_batch: AgentBuffer): # type: ignore
def remove_latents(self, mini_batch: AgentBuffer):
new_mini_batch = copy.deepcopy(mini_batch)
del new_mini_batch[ObsUtil.get_name_at(self.latent_key)]
del new_mini_batch[ObsUtil.get_name_at_next(self.latent_key)]
return new_mini_batch
def get_state_inputs(
self, mini_batch: AgentBuffer, ignore_latent: bool = True
) -> List[torch.Tensor]:
def get_state_inputs(self, mini_batch: AgentBuffer, ignore_latent: bool = True, for_policy: bool = False) -> List[torch.Tensor]:
n_obs = len(self.discriminator_network_body.processors) + 1
np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
# Convert to tensors
@@ -170,23 +178,75 @@ class DiscriminatorEncoder(nn.Module):
for index, obs in enumerate(np_obs):
if ignore_latent and index == self.latent_key:
continue
tensor_obs.append(ModelUtils.list_to_tensor(obs))
nlast = len(obs)-1
if ignore_latent == False and index == self.latent_key:
tensor_obs.append(ModelUtils.list_to_tensor(obs)) #[0:nlast]))
elif for_policy == False:
# stack observation inputs
stacked_o = []
for i, o in enumerate(obs):
if i == nlast:
#TODO: quick hack to make sure it has same dimension
stacked_o.append(np.hstack((o, o)))
#continue
else:
stacked_o.append(np.hstack((o, obs[i+1])))
tensor_obs.append(ModelUtils.list_to_tensor(stacked_o)) #(obs))
else:
tensor_obs.append(ModelUtils.list_to_tensor(obs))#[0:nlast]))
return tensor_obs
def get_state_inputs_expert(self, mini_batch: AgentBuffer): # type: ignore
def get_state_inputs_expert(self, mini_batch: AgentBuffer) -> List[torch.Tensor]:
n_obs = len(self.discriminator_network_body.processors)
np_obs = ObsUtil.from_buffer(mini_batch, n_obs)
# Convert to tensors
tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
#tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
tensor_obs = []
for index, obs in enumerate(np_obs):
nlast = len(obs)-1
if True: # index == self.latent_key:
# tensor_obs.append(ModelUtils.list_to_tensor(obs))#[0:nlast]))
#else:
# stack observation inputs
stacked_o = []
for i, o in enumerate(obs):
if i == nlast:
#TODO: quick hack to make sure it has same dimension
stacked_o.append(np.hstack((o, o)))
#continue
else:
stacked_o.append(np.hstack((o, obs[i+1])))
tensor_obs.append(ModelUtils.list_to_tensor(stacked_o)) #(obs))
#print(tensor_obs[index].shape)
return tensor_obs
def get_next_state_inputs(self, mini_batch: AgentBuffer) -> List[torch.Tensor]:
n_obs = len(self.discriminator_network_body.processors)
np_obs_next = ObsUtil.from_buffer_next(mini_batch, n_obs)
# Convert to tensors
tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs_next]
#tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs_next]
tensor_obs = []
for index, obs in enumerate(np_obs_next):
nlast = len(obs)-1
if index == self.latent_key:
tensor_obs.append(ModelUtils.list_to_tensor(obs))#[0:nlast]))
else:
# stack observation inputs
stacked_o = []
for i, o in enumerate(obs):
if i == nlast:
#TODO: quick hack to make sure it has same dimension
stacked_o.append(np.hstack((o, o)))
#continue
else:
stacked_o.append(np.hstack((o, obs[i+1])))
tensor_obs.append(ModelUtils.list_to_tensor(stacked_o)) #(obs))
return tensor_obs
def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
@@ -194,7 +254,7 @@ class DiscriminatorEncoder(nn.Module):
def get_actions(self, policy: TorchPolicy, mini_batch: AgentBuffer) -> torch.Tensor:
with torch.no_grad():
obs = self.get_state_inputs(mini_batch, False)
obs = self.get_state_inputs(mini_batch, False, True)
action, _, _ = policy.actor.get_action_and_stats(obs)
mu = action.continuous_tensor
return mu
@@ -204,9 +264,7 @@ class DiscriminatorEncoder(nn.Module):
ase_latents = inputs[self.latent_key]
return ase_latents
def compute_rewards(
self, mini_batch: AgentBuffer
) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_rewards(self, mini_batch: AgentBuffer) -> Tuple[torch.Tensor, torch.Tensor]:
# self.discriminator_network_body.update_normalization(mini_batch)
disc_output, enc_output = self.compute_estimates(mini_batch)
ase_latents = self.get_ase_latents(mini_batch)
@@ -214,36 +272,28 @@ class DiscriminatorEncoder(nn.Module):
disc_reward = self._calc_disc_reward(disc_output)
return disc_reward, enc_reward
def compute_estimates(
self, mini_batch: AgentBuffer, ignore_latents: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_estimates(self, mini_batch: AgentBuffer, ignore_latents: bool = True) -> Tuple[
torch.Tensor, torch.Tensor]:
inputs = self.get_state_inputs(mini_batch)
disc_output, enc_output = self.forward(inputs)
return disc_output, enc_output
def compute_estimates_expert(
self, mini_batch: AgentBuffer
) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_estimates_expert(self, mini_batch: AgentBuffer) -> Tuple[torch.Tensor, torch.Tensor]:
inputs = self.get_state_inputs_expert(mini_batch)
self.discriminator_network_body.update_normalization_input(inputs)
disc_output, enc_output = self.forward(inputs)
return disc_output, enc_output
def compute_cat_estimates(
self, mini_batch: AgentBuffer
) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_cat_estimates(self, mini_batch: AgentBuffer) -> Tuple[torch.Tensor, torch.Tensor]:
inputs = self.get_state_inputs(mini_batch)
next_inputs = self.get_next_state_inputs(mini_batch)
inputs_cat = [
torch.cat([inp, next_inp], dim=0)
for inp, next_inp in zip(inputs, next_inputs)
]
inputs_cat = [torch.cat([inp, next_inp], dim=0) for inp, next_inp in zip(inputs, next_inputs)]
disc_output, enc_output = self.forward(inputs_cat)
return disc_output, enc_output
def compute_discriminator_loss(
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
def compute_discriminator_loss(self, policy_batch: AgentBuffer, expert_batch: AgentBuffer) -> \
Tuple[
torch.Tensor, Dict[str, np.ndarray]]:
# needs to compute the loss like ase:amp_agent.py:470, includes samples from a replay buffer???
# uses torch.nn.BCEWithLogitsLoss, so need to remove the sigmoid at the output of the disc
# also need to change gradient mag computation
@@ -251,28 +301,14 @@ class DiscriminatorEncoder(nn.Module):
stats_dict: Dict[str, np.ndarray] = {}
policy_estimate, _ = self.compute_estimates(policy_batch)
expert_estimate, _ = self.compute_estimates_expert(expert_batch)
disc_loss_policy = torch.nn.BCEWithLogitsLoss(policy_estimate, torch.zeros_like(policy_estimate))
disc_loss_expert = torch.nn.BCEWithLogitsLoss(expert_estimate, torch.ones_like(expert_estimate))
total_loss += 0.5 * (disc_loss_policy + disc_loss_expert)
#logit reg
# not implemented yet
stats_dict["Policy/ASE Discriminator Policy Estimate"] = policy_estimate.mean().item()
stats_dict["Policy/ASE Discriminator Expert Estimate"] = expert_estimate.mean().item()
# discriminator_loss = -(
# torch.log(expert_estimate + self.EPSILON) + torch.log(1.0 - policy_estimate + self.EPSILON)).mean()
# total_loss += discriminator_loss
# grad penalty
discriminator_loss = -(
torch.log(expert_estimate + self.EPSILON) + torch.log(1.0 - policy_estimate + self.EPSILON)).mean()
total_loss += discriminator_loss
if self.gradient_penalty_weight > 0:
gradient_magnitude_loss = (
self.gradient_penalty_weight
* self.compute_gradient_magnitude(policy_batch, expert_batch)
)
self.gradient_penalty_weight * self.compute_gradient_magnitude(policy_batch, expert_batch))
stats_dict["Policy/ASE Grad Mag Loss"] = gradient_magnitude_loss.item()
total_loss += gradient_magnitude_loss
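For reference, the comment at the top of compute_discriminator_loss sketches a planned switch to a logits-based loss with the Sigmoid removed from the discriminator head. A hedged sketch of what that variant would look like (not what this commit ships, which keeps the Sigmoid output and the log-based loss above):

import torch

bce = torch.nn.BCEWithLogitsLoss()
policy_logits = torch.randn(8, 1)   # discriminator outputs before any Sigmoid
expert_logits = torch.randn(8, 1)
disc_loss = 0.5 * (
    bce(policy_logits, torch.zeros_like(policy_logits))
    + bce(expert_logits, torch.ones_like(expert_logits))
)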
@@ -280,9 +316,7 @@ class DiscriminatorEncoder(nn.Module):
return total_loss, stats_dict
def compute_gradient_magnitude(
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> torch.Tensor:
def compute_gradient_magnitude(self, policy_batch: AgentBuffer, expert_batch: AgentBuffer) -> torch.Tensor:
policy_inputs = self.get_state_inputs(policy_batch)
expert_inputs = self.get_state_inputs_expert(expert_batch)
interp_inputs = []
@@ -293,16 +327,13 @@ class DiscriminatorEncoder(nn.Module):
interp_inputs.append(interp_input)
hidden, _ = self.discriminator_network_body(interp_inputs)
estimate = self.discriminator_output_layer(hidden).squeeze(1).sum()
gradient = torch.autograd.grad(
estimate, tuple(interp_inputs), create_graph=True
)[0]
safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
gradient = torch.autograd.grad(estimate, tuple(interp_inputs), create_graph=True)[0]
safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
gradient_mag = torch.mean((safe_norm - 1) ** 2)
return gradient_mag
def compute_encoder_loss(
self, policy_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
def compute_encoder_loss(self, policy_batch: AgentBuffer) -> Tuple[
torch.Tensor, Dict[str, np.ndarray]]:
total_loss = torch.zeros(1)
stats_dict: Dict[str, np.ndarray] = {}
_, encoder_prediction = self.compute_estimates(policy_batch)
@@ -312,9 +343,8 @@ class DiscriminatorEncoder(nn.Module):
stats_dict["Losses/ASE Encoder Loss"] = total_loss.item()
return total_loss, stats_dict
def compute_diversity_loss(
self, policy: TorchPolicy, policy_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
def compute_diversity_loss(self, policy: TorchPolicy, policy_batch: AgentBuffer) -> Tuple[
torch.Tensor, Dict[str, np.ndarray]]:
# currently only supports continuous actions
total_loss = torch.zeros(1)
stats_dict: Dict[str, np.ndarray] = {}
@@ -338,7 +368,7 @@ class DiscriminatorEncoder(nn.Module):
stats_dict["Losses/ASE Diversity Loss"] = total_loss.item()
return total_loss, stats_dict
def sample_latents(self, n) -> np.ndarray: # type: ignore
def sample_latents(self, n) -> np.ndarray:
# torch version for future reference
# z = torch.normal(torch.zeros([n, self.latent_dim], device=default_device()))
# z = torch.nn.functional.normalize(z, dim=-1)
@@ -347,9 +377,7 @@ class DiscriminatorEncoder(nn.Module):
denom = np.linalg.norm(z, axis=1, keepdims=True)
return z / denom
def _calc_encoder_reward(
self, encoder_prediction: torch.Tensor, ase_latents: torch.Tensor
) -> torch.Tensor:
def _calc_encoder_reward(self, encoder_prediction: torch.Tensor, ase_latents: torch.Tensor) -> torch.Tensor:
error = self._calc_encoder_error(encoder_prediction, ase_latents)
enc_reward = torch.clamp(-error, 0.0)
enc_reward *= self.encoder_reward_scale
@@ -357,16 +385,10 @@ class DiscriminatorEncoder(nn.Module):
def _calc_disc_reward(self, discriminator_prediction: torch.Tensor) -> torch.Tensor:
disc_reward = -torch.log(
torch.maximum(
1 - discriminator_prediction,
torch.tensor(0.0001, device=default_device()),
)
)
torch.maximum(1 - discriminator_prediction, torch.tensor(0.0001, device=default_device())))
disc_reward *= self.discriminator_reward_scale
return disc_reward
@staticmethod
def _calc_encoder_error(
encoder_prediction: torch.Tensor, ase_latents: torch.Tensor
) -> torch.Tensor:
return -torch.sum(encoder_prediction * ase_latents, dim=-1, keepdim=True)
def _calc_encoder_error(encoder_prediction: torch.Tensor, ase_latents: torch.Tensor) -> torch.Tensor:
return -torch.sum(encoder_prediction * ase_latents, dim=-1, keepdim=True)
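Putting this file's pieces together, the per-step reward returned by evaluate() combines the discriminator and encoder terms. A hedged, self-contained restatement of that path (the function name is mine; beta_sdo and the encoder scale come from ASESettings):

import torch

def ase_reward(disc_pred, enc_pred, latents, beta_sdo, enc_scale):
    # _calc_disc_reward: -log(max(1 - D(s, s'), 1e-4))
    disc_r = -torch.log(torch.maximum(1 - disc_pred, torch.tensor(1e-4)))
    # _calc_encoder_reward: clamp(z . q(s, s'), min=0) scaled by the encoder reward scale
    enc_r = torch.clamp(
        torch.sum(enc_pred * latents, dim=-1, keepdim=True), min=0.0
    ) * enc_scale
    # evaluate(): discriminator reward + beta_sdo * encoder reward, per step
    return disc_r.squeeze(1) + beta_sdo * enc_r.squeeze(1)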

View file

@@ -82,6 +82,11 @@ class ObservationEncoder(nn.Module):
"""
return self._total_goal_enc_size
def update_normalization_input(self, inputs: List[torch.Tensor]) -> None:
for vec_input, enc in zip(inputs, self.processors):
if isinstance(enc, VectorInput):
enc.update_normalization(vec_input)
def update_normalization(self, buffer: AgentBuffer) -> None:
obs = ObsUtil.from_buffer(buffer, len(self.processors))
for vec_input, enc in zip(obs, self.processors):
@@ -216,6 +221,9 @@ class NetworkBody(nn.Module):
else:
self.lstm = None # type: ignore
def update_normalization_input(self, inputs: List[torch.Tensor]) -> None:
self.observation_encoder.update_normalization_input(inputs)
def update_normalization(self, buffer: AgentBuffer) -> None:
self.observation_encoder.update_normalization(buffer)
@@ -764,3 +772,4 @@ class LearningRate(nn.Module):
# Todo: add learning rate decay
super().__init__()
self.learning_rate = torch.Tensor([lr])