GPT2 实现笔记(2)-LoRA微调
GPT2 实现笔记(2)-LoRA微调
该笔记是对 Stanford CS224-n 的 hw5 的 LoRA 微调部分的整理,用于整理本人实现的对特定任务的 LoRA 微调实现以及对特定下游任务的处理。
1. LoRA 微调总体架构
我们采用和原有GPT2模型实现类似的架构实现带LoRA微调的GPT2模型,详细架构如下:
2. LoRA 线性层实现
LoRA 线性层是 LoRA 微调的核心部分,它实现了 LoRA 微调包括低秩矩阵、、旁路注入结构等核心组件。
初始化
在LoRA 中最重要的是以下的参数:
- 低秩矩阵 ,
- 秩
- 缩放因子
同时,我们还需要冻结住原始模型线性层的参数,不在训练中改变它们。
首先我们确定、的形状。我们需要将大小为 in_features
的输入映射到大小为 的低秩空间中,然后再映射回原来的大小 out_features
中:
因此 的形状为 r, in_features
, 的形状为 out_features, r
。
其他的配置信息直接读取 lora_config
对应内容即可:
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 4,
alpha: float = 16.0,
dropout: float = 0.0,
bias: bool = True,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.alpha = alpha
self.scaling = alpha / rank if rank > 0 else 0.0
# Original linear layer (will be frozen)
self.linear = nn.Linear(in_features, out_features, bias=bias)
# LoRA matrices
if rank > 0:
self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
self.lora_dropout = nn.Dropout(dropout)
# Initialize LoRA matrices
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
# Freeze the original linear layer
for param in self.linear.parameters():
param.requires_grad = False
使用如上的初始化、矩阵的方法,既可以保证、让LoRA模型最开始可以接收到没有损失的原模型内容,同时在开始微调后参数能够比较稳定地变化。
forward
模块
LoRA的 forward
模块包含 LoRA 微调的核心计算步骤:
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Original linear transformation (frozen)
result = self.linear(x)
if self.rank > 0:
lora_result = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
result = result + lora_result
return result
这部分并不包含原始的预训练模型的参数,也就是所谓的“参数被冻结了”。
convert_linear_to_lora
模块
一个简单的辅助函数,可以简洁地将线性层转换成我们的LoRA线性层:
def convert_linear_to_lora(
linear_layer: nn.Linear,
rank: int = 4,
alpha: float = 16.0,
dropout: float = 0.0
) -> LoRALayer:
lora_layer = LoRALayer(
in_features=linear_layer.in_features,
out_features=linear_layer.out_features,
rank=rank,
alpha=alpha,
dropout=dropout,
bias=linear_layer.bias is not None
)
# Copy the weights from the original layer
lora_layer.load_pretrained_weights(
linear_layer.weight.data,
linear_layer.bias.data if linear_layer.bias is not None else None
)
return lora_layer
3. LoRA 注意力层实现
LoRA 注意力层的实现较为简单,只需要把GPT2的注意力层的线性层参数换成 LoRALayer
即可。同时我们需要将GPT2预训练模型的参数权重迁移到我们的注意力层中:
@classmethod
def from_pretrained_attention(cls, original_attention, config, lora_config=None):
lora_attention = cls(config, lora_config)
# Load pretrained weights into LoRA layers
lora_attention.query.load_pretrained_weights(
original_attention.query.weight.data,
original_attention.query.bias.data if original_attention.query.bias is not None else None
)
lora_attention.key.load_pretrained_weights(
original_attention.key.weight.data,
original_attention.key.bias.data if original_attention.key.bias is not None else None
)
lora_attention.value.load_pretrained_weights(
original_attention.value.weight.data,
original_attention.value.bias.data if original_attention.value.bias is not None else None
)
return lora_attention
其它代码实现略过。
4. LoRAGPT2 层实现
总体设计
LoRAGPT2 层将 LoRA 自注意力层和 LoRA 线性层组合起来,构成最终的 LoRAGPT2Layer
层。这部分的实现也较为简单,将自注意力方法和线性层换成先前实现的 LoRA 组件即可。
同样地,我们也需要将预训练好的 GPT2 模型的参数权重加载到 LoRAGPT2
层中:
@classmethod
def from_pretrained_layer(cls, original_layer, config, lora_config=None):
lora_layer = cls(config, lora_config)
# Load attention weights
lora_layer.self_attention = LoRACausalSelfAttention.from_pretrained_attention(
original_layer.self_attention, config, lora_config
)
# Load attention dense layer weights
lora_layer.attention_dense.load_pretrained_weights(
original_layer.attention_dense.weight.data,
original_layer.attention_dense.bias.data if original_layer.attention_dense.bias is not None else None
)
# Load feed-forward weights
lora_layer.interm_dense.load_pretrained_weights(
original_layer.interm_dense.weight.data,
original_layer.interm_dense.bias.data if original_layer.interm_dense.bias is not None else None
)
lora_layer.out_dense.load_pretrained_weights(
original_layer.out_dense.weight.data,
original_layer.out_dense.bias.data if original_layer.out_dense.bias is not None else None
)
# Copy layer norm parameters (these will remain trainable)
lora_layer.attention_layer_norm.load_state_dict(original_layer.attention_layer_norm.state_dict())
lora_layer.out_layer_norm.load_state_dict(original_layer.out_layer_norm.state_dict())
return lora_layer
logit
计算
在实际应用中,我们并不能直接用 hidden_state
进行处理,而是需要将它转换成逻辑单元(logit),然后经过 softmax、交叉熵损失计算,得到用来反向传播的信息。这个过程可以通过将 hidden_state
与嵌入矩阵的转置 相乘得到:
def hidden_state_to_token(self, hidden_state):
return torch.matmul(hidden_state, self.word_embedding.weight.T)
从直觉角度,大模型的输入过程相当于一个 的过程,在模型的最后、我们得到最终的向量后,我们当然要通过 的解码过程,来获取这个结果向量的实际含义。而
hidden_state * E^T
这个操作,就是在高效地完成这个过程。它将hidden_state
与E
中的每一个词向量进行一次“相似度”计算(点积),一次性返回hidden_state
与词汇表中所有词的相似度分数。
5. LoRAGPT2 模型实现
实现完 LoRA 的这些层后,我们在 LoRAGPT2Model
中将它们组装起来。这和我们在 GPT2 实现中的做法类似。我们使用 nn.ModuleList
将我们的 LoRAGPT2Layer
层应用到模型中:
self.gpt_layers = nn.ModuleList([
LoRAGPT2Layer(config, self.lora_config) for _ in range(config.num_hidden_layers)
])
LoRAGPT2Model
并不继承 GPT2Model
,而是继承最基础的 GPTPreTrainedModel
,因为它们所依赖的组件不同。不过大部分实现都是一致的。同时 LoRAGPT2Model
会从 GPT2Model
中加载预训练好的参数权重。
注:这只是本人的实现方法,也可能有直接基于
GPT2Model
改造的实现方法。
6. ParaPharse Detection 下游任务适配
下面我们以 hw5 中提出的 paraPharse Detection 任务为例详细讲解如何根据具体的下游任务进行微调。
我们使用一个新的类 LoRAParaphraseGPT
来封装整个下游任务训练流程。
初始化
在初始化阶段,我们需要创建下游任务需要的组件:
- 实际执行微调的
LoRAGPT2Model
模型。我们调用它的from_pretrained
方法,让这个模型的初始参数与预训练过的 GPT2 模型一致。 - 用于 ParaPharse Detection 分类任务的 token。
forward
模块
forward
模块实现了完成 ParaPharse Detection 任务的核心逻辑。具体流程如下:
- 我们获取
LoRAGPT2Model
的输出:
outputs = self.gpt(input_ids, attention_mask)
在PyTorch中,使用
Model(input)
调用一个nn.Module
实例时,会自动调用它的forward
方法。
- 我们只需要返回结果的最后一个token,因为它通常是对前面内容概括性最强的。我们使用
LoRAGPT2Model
将这个 token 转换为对应的逻辑变量:
last_token_hidden = outputs['last_token'] # [batch_size, hidden_size]
# Convert to token logits using weight tying
all_token_logits = self.gpt.hidden_state_to_token(last_token_hidden) # [batch_size, vocab_size]
- 我们从获取的所有词的逻辑变量中提取我们需要的
yes
和no
的逻辑变量,以便之后的操作:
# Extract logits for "yes" and "no" tokens
yes_logits = all_token_logits[:, self.yes_token_id] # [batch_size]
no_logits = all_token_logits[:, self.no_token_id] # [batch_size]
yes_logits
和no_logits
都是形状为[8]
的张量,我们把它堆叠成[yes_logits, no_logits]
的形式,这样就可以直接发送给F.cross_entropy
进行交叉熵计算了:
logits = torch.stack([no_logits, yes_logits], dim=1) # [batch_size, 2]
train
模块
我们通过一个简单的 train
函数训练模型,具体流程如下:
- 加载数据集。
- 初始化模型、早停机制类等。
- 开始进行
epochs
次训练。每次训练:
- 将模型切换到训练模式,并遍历所有训练批次:
model.train()
train_loss = 0
num_train_batches = 0
for batch in tqdm(para_train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
- 从数据集中提取
LoRAGPT2Model
的forward
推理方法需要的参数:- 从当前批次中取出
token_ids
和attention_mask
,并用.to(device)
把它们发送到GPU上,准备进行高速计算。 - 从当前批次中取出
label
,将从数据集中读出的原始label
,转换成损失函数能够理解的0-1分类索引。
- 从当前批次中取出
label_ids = batch['labels'].to(device).long()
binary_labels = torch.zeros(label_ids.size(0), dtype=torch.long, device=device)
binary_labels[label_ids == YES_TOKEN_ID] = 1 # "yes" -> class 1
binary_labels[label_ids == NO_TOKEN_ID] = 0 # "no" -> class 0
labels = binary_labels
- 核心训练五步曲: 这是PyTorch训练的“标准动作”,整个学习过程包括:
- 梯度清零。在进行新一轮的计算之前,必须清除上一轮留下的旧梯度。否则,梯度会累积,导致错误的更新。
- 前向传播 (Forward Pass)。将一批数据喂给模型,模型进行一系列计算,最终给出它的预测结果
logits
。 - 计算损失 (Loss Calculation)。将模型的预测
logits
与真实的labels
进行比较,计算出一个交叉熵损失loss
值。这个数值代表了模型这次“错得有多离谱”。 - 反向传播 (Backward Pass)。PyTorch会根据loss值,自动计算出LoRA的参数应该如何调整才能让loss变小。这个“调整方向”就是梯度。
- 参数更新 (Parameter Update)。优化器(AdamW)根据上一步算出的梯度,去实际地更新 LoRA 矩阵的权重。这样,模型就完成了一次学习。
# Forward pass
optimizer.zero_grad()
logits = model(b_ids, b_mask)
loss = F.cross_entropy(logits, labels, reduction='mean')
# Backward pass
loss.backward()
optimizer.step()
train_loss += loss.item()
num_train_batches += 1
- 记录当前批次的损失值,用于后续计算整个Epoch的平均训练损失:
train_loss = train_loss / num_train_batches
- 接下来我们就需要检验训练的效果了,我们使用 dev set 来检验,包括如下步骤:
- 从“训练模式”切换到“评估模式”。
- 关闭梯度计算。
- 遍历所有训练批次,计算验证损失。
model.eval()
val_loss = 0
num_val_batches = 0
with torch.no_grad():
for batch in tqdm(para_dev_dataloader, desc=f'val-{epoch}', disable=TQDM_DISABLE):
b_ids = batch['token_ids'].to(device)
b_mask = batch['attention_mask'].to(device)
# Labels processing for validation loss
label_ids = batch['labels'].to(device).long()
binary_labels = torch.zeros(label_ids.size(0), dtype=torch.long, device=device)
binary_labels[label_ids == YES_TOKEN_ID] = 1 # "yes" -> class 1
binary_labels[label_ids == NO_TOKEN_ID] = 0 # "no" -> class 0
labels = binary_labels
logits = model(b_ids, b_mask)
loss = F.cross_entropy(logits, labels, reduction='mean')
val_loss += loss.item()
num_val_batches += 1
- 评估与报告:在每个Epoch的最后,我们有了
train_loss
和val_loss
,我们计算并打印这些指标:
val_loss = val_loss / num_val_batches
# Evaluate accuracy on dev set
dev_acc, dev_f1, *_ = model_eval_paraphrase(para_dev_dataloader, model, device)
# Check for improvement and save best model
improvement_status = ""
if val_loss < best_val_loss:
best_val_loss = val_loss
best_dev_acc = dev_acc
best_epoch = epoch
save_model(model, optimizer, args, args.filepath)
improvement_status = " ⭐ NEW BEST!"
print(f"Epoch {epoch}: train loss :: {train_loss:.3f}, val loss :: {val_loss:.3f}, dev acc :: {dev_acc:.3f}{improvement_status}")
# Early stopping check
if early_stopping(val_loss):
print(f"Early stopping triggered at epoch {epoch}")
print(f"Best validation loss: {best_val_loss:.3f} at epoch {best_epoch}")
break
print(f"\nTraining completed! Best validation loss: {best_val_loss:.3f} at epoch {best_epoch}")
print(f"Best dev accuracy: {best_dev_acc:.3f}")
这样,模型的训练就完成了。
Comments