зеркало из https://github.com/microsoft/LoRA.git
fix error of parameter initialization for LoRa embedding
This commit is contained in:
Родитель
d3b25ab9fc
Коммит
9cea83c7f7
|
@ -56,8 +56,8 @@ class Embedding(nn.Embedding, LoRALayer):
|
|||
nn.Embedding.reset_parameters(self)
|
||||
if hasattr(self, 'lora_A'):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.zeros_(self.lora_A)
|
||||
nn.init.normal_(self.lora_B)
|
||||
nn.init.normal_(self.lora_A)
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
nn.Embedding.train(self, mode)
|
||||
|
|
Загрузка…
Ссылка в новой задаче