Fix dropout in TFMobileBert (#5150)
This commit is contained in:
Родитель
5ed94b2312
Коммит
f1679d7c48
|
@ -370,7 +370,7 @@ class TFMobileBertOutput(tf.keras.layers.Layer):
|
||||||
|
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
if not self.use_bottleneck:
|
if not self.use_bottleneck:
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче