fix error of parameter initialization for LoRa embedding

This commit is contained in:
Liu Jianfeng 2023-03-15 08:50:38 +00:00
Родитель d3b25ab9fc
Коммит 9cea83c7f7
1 изменённых файлов: 2 добавлений и 2 удалений

Просмотреть файл

@ -56,8 +56,8 @@ class Embedding(nn.Embedding, LoRALayer):
nn.Embedding.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.zeros_(self.lora_A)
nn.init.normal_(self.lora_B)
nn.init.normal_(self.lora_A)
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
nn.Embedding.train(self, mode)