* issue#264 bug fix

* double DQN bug fix

* rm print

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
This commit is contained in:
ysqyang 2021-04-01 09:52:29 +08:00 коммит произвёл GitHub
Родитель 7e02baeb3d
Коммит df0e4a9d18
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённый файл: 4 добавления и 3 удаления

Просмотреть файл

@@ -109,11 +109,12 @@ class DQN(AbsAgent):
q_all = self._get_q_values(states)
q = select_by_actions(q_all, actions)
next_q_all = self._get_q_values(next_states, is_eval=False, training=False)
next_q_all_target = self._get_q_values(next_states, is_eval=False, training=False)
if self.config.double:
next_q = select_by_actions(next_q_all) # (N,)
next_q_all_eval = self._get_q_values(next_states, training=False)
next_q = select_by_actions(next_q_all_target, next_q_all_eval.max(dim=1)[1]) # (N,)
else:
next_q, _ = get_max(next_q_all) # (N,)
next_q, _ = get_max(next_q_all_target) # (N,)
loss = get_td_errors(q, next_q, rewards, self.config.reward_discount, loss_func=self.config.loss_func)
self.model.step(loss.mean())