Master.double dqn bug fix (#315)
* issue#264 bug fix
* double DQN bug fix
* rm print

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
This commit is contained in:
Родитель
7e02baeb3d
Коммит
df0e4a9d18
|
@@ -109,11 +109,12 @@ class DQN(AbsAgent):
|
|||
|
||||
q_all = self._get_q_values(states)
|
||||
q = select_by_actions(q_all, actions)
|
||||
next_q_all = self._get_q_values(next_states, is_eval=False, training=False)
|
||||
next_q_all_target = self._get_q_values(next_states, is_eval=False, training=False)
|
||||
if self.config.double:
|
||||
next_q = select_by_actions(next_q_all) # (N,)
|
||||
next_q_all_eval = self._get_q_values(next_states, training=False)
|
||||
next_q = select_by_actions(next_q_all_target, next_q_all_eval.max(dim=1)[1]) # (N,)
|
||||
else:
|
||||
next_q, _ = get_max(next_q_all) # (N,)
|
||||
next_q, _ = get_max(next_q_all_target) # (N,)
|
||||
|
||||
loss = get_td_errors(q, next_q, rewards, self.config.reward_discount, loss_func=self.config.loss_func)
|
||||
self.model.step(loss.mean())
|
||||
|
|
Загрузка…
Ссылка в новой задаче