GPT2 实现笔记(2)-LoRA微调

• 172 min read • 34329 words
Tags: LLM NLP
Categories: NLP

GPT2 实现笔记(2)-LoRA微调

该笔记是对 Stanford CS224-n 的 hw5 的 LoRA 微调部分的整理,用于整理本人实现的对特定任务的 LoRA 微调实现以及对特定下游任务的处理。

1. LoRA 微调总体架构

我们采用和原有GPT2模型实现类似的架构实现带LoRA微调的GPT2模型,详细架构如下:

2. LoRA 线性层实现

LoRA 线性层是 LoRA 微调的核心部分,它实现了 LoRA 微调包括低秩矩阵AABB、旁路注入结构等核心组件。

a.a. 初始化

在LoRA 中最重要的是以下的参数:

  • 低秩矩阵 AABB
  • rr
  • 缩放因子 scale=αrscale=\frac{\alpha}{r}

同时,我们还需要冻结住原始模型线性层的参数,不在训练中改变它们。

首先我们确定AABB的形状。我们需要将大小为 in_features 的输入映射到大小为 rankrank 的低秩空间中,然后再映射回原来的大小 out_features 中:

inputATBToutput\text{input} \rightarrow{A^T} \rightarrow{B^T} \rightarrow \text{output}

因此 AA 的形状为 r, in_featuresBB 的形状为 out_features, r

其他的配置信息直接读取 lora_config 对应内容即可:

def __init__(
    self,
    in_features: int,
    out_features: int,
    rank: int = 4,
    alpha: float = 16.0,
    dropout: float = 0.0,
    bias: bool = True,
):
    super().__init__()
        
    self.in_features = in_features
    self.out_features = out_features
    self.rank = rank
    self.alpha = alpha
    self.scaling = alpha / rank if rank > 0 else 0.0
        
    # Original linear layer (will be frozen)
    self.linear = nn.Linear(in_features, out_features, bias=bias)
        
    # LoRA matrices
    if rank > 0:
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.lora_dropout = nn.Dropout(dropout)
            
        # Initialize LoRA matrices
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
        
    # Freeze the original linear layer
    for param in self.linear.parameters():
        param.requires_grad = False

使用如上的初始化AABB矩阵的方法,既可以保证AB=0AB=0、让LoRA模型最开始可以接收到没有损失的原模型内容,同时在开始微调后参数能够比较稳定地变化。

b.b. forward 模块

LoRA的 forward 模块包含 LoRA 微调的核心计算步骤:

h=W0x+ΔWx=W0x+BAxh = W_0x + ΔW x = W_0x + BAx

def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Original linear transformation (frozen)
    result = self.linear(x)

    if self.rank > 0:
        lora_result = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
        result = result + lora_result
    
    return result

这部分并不包含原始的预训练模型的参数,也就是所谓的“参数被冻结了”。

c.c. convert_linear_to_lora 模块

一个简单的辅助函数,可以简洁地将线性层转换成我们的LoRA线性层:

def convert_linear_to_lora(
    linear_layer: nn.Linear,
    rank: int = 4,
    alpha: float = 16.0,
    dropout: float = 0.0
) -> LoRALayer:
    lora_layer = LoRALayer(
        in_features=linear_layer.in_features,
        out_features=linear_layer.out_features,
        rank=rank,
        alpha=alpha,
        dropout=dropout,
        bias=linear_layer.bias is not None
    )
    
    # Copy the weights from the original layer
    lora_layer.load_pretrained_weights(
        linear_layer.weight.data,
        linear_layer.bias.data if linear_layer.bias is not None else None
    )
    
    return lora_layer

3. LoRA 注意力层实现

LoRA 注意力层的实现较为简单,只需要把GPT2的注意力层的线性层参数换成 LoRALayer 即可。同时我们需要将GPT2预训练模型的参数权重迁移到我们的注意力层中:

@classmethod
def from_pretrained_attention(cls, original_attention, config, lora_config=None):
    lora_attention = cls(config, lora_config)
        
    # Load pretrained weights into LoRA layers
    lora_attention.query.load_pretrained_weights(
        original_attention.query.weight.data,
        original_attention.query.bias.data if original_attention.query.bias is not None else None
    )
    lora_attention.key.load_pretrained_weights(
        original_attention.key.weight.data,
        original_attention.key.bias.data if original_attention.key.bias is not None else None
    )
    lora_attention.value.load_pretrained_weights(
        original_attention.value.weight.data,
        original_attention.value.bias.data if original_attention.value.bias is not None else None
    )
        
    return lora_attention

其它代码实现略过。

4. LoRAGPT2 层实现

a.a. 总体设计

LoRAGPT2 层将 LoRA 自注意力层和 LoRA 线性层组合起来,构成最终的 LoRAGPT2Layer 层。这部分的实现也较为简单,将自注意力方法和线性层换成先前实现的 LoRA 组件即可。

同样地,我们也需要将预训练好的 GPT2 模型的参数权重加载到 LoRAGPT2 层中:

@classmethod
def from_pretrained_layer(cls, original_layer, config, lora_config=None):
    lora_layer = cls(config, lora_config)
        
    # Load attention weights
    lora_layer.self_attention = LoRACausalSelfAttention.from_pretrained_attention(
        original_layer.self_attention, config, lora_config
    )
        
    # Load attention dense layer weights
    lora_layer.attention_dense.load_pretrained_weights(
        original_layer.attention_dense.weight.data,
        original_layer.attention_dense.bias.data if original_layer.attention_dense.bias is not None else None
    )
        
    # Load feed-forward weights
    lora_layer.interm_dense.load_pretrained_weights(
        original_layer.interm_dense.weight.data,
        original_layer.interm_dense.bias.data if original_layer.interm_dense.bias is not None else None
    )
    lora_layer.out_dense.load_pretrained_weights(
        original_layer.out_dense.weight.data,
        original_layer.out_dense.bias.data if original_layer.out_dense.bias is not None else None
    )
        
    # Copy layer norm parameters (these will remain trainable)
    lora_layer.attention_layer_norm.load_state_dict(original_layer.attention_layer_norm.state_dict())
    lora_layer.out_layer_norm.load_state_dict(original_layer.out_layer_norm.state_dict())
        
    return lora_layer

b.b. logit 计算

在实际应用中,我们并不能直接用 hidden_state 进行处理,而是需要将它转换成逻辑单元(logit),然后经过 softmax、交叉熵损失计算,得到用来反向传播的信息。这个过程可以通过将 hidden_state 与嵌入矩阵的转置 ETE^T 相乘得到:

def hidden_state_to_token(self, hidden_state):
    return torch.matmul(hidden_state, self.word_embedding.weight.T)

从直觉角度,大模型的输入过程相当于一个 wordvecword \rightarrow vec 的过程,在模型的最后、我们得到最终的向量后,我们当然要通过 vecwordvec \rightarrow word 的解码过程,来获取这个结果向量的实际含义。而 hidden_state * E^T 这个操作,就是在高效地完成这个过程。它hidden_stateE 中的每一个词向量进行一次“相似度”计算(点积),一次性返回 hidden_state 与词汇表中所有词的相似度分数

5. LoRAGPT2 模型实现

实现完 LoRA 的这些层后,我们在 LoRAGPT2Model 中将它们组装起来。这和我们在 GPT2 实现中的做法类似。我们使用 nn.ModuleList 将我们的 LoRAGPT2Layer 层应用到模型中:

self.gpt_layers = nn.ModuleList([
    LoRAGPT2Layer(config, self.lora_config) for _ in range(config.num_hidden_layers)
])

LoRAGPT2Model 并不继承 GPT2Model,而是继承最基础的 GPTPreTrainedModel,因为它们所依赖的组件不同。不过大部分实现都是一致的。同时 LoRAGPT2Model 会从 GPT2Model 中加载预训练好的参数权重。

注:这只是本人的实现方法,也可能有直接基于 GPT2Model 改造的实现方法。

6. ParaPharse Detection 下游任务适配

下面我们以 hw5 中提出的 paraPharse Detection 任务为例详细讲解如何根据具体的下游任务进行微调。

我们使用一个新的类 LoRAParaphraseGPT 来封装整个下游任务训练流程。

a.a. 初始化

在初始化阶段,我们需要创建下游任务需要的组件:

  • 实际执行微调的 LoRAGPT2Model 模型。我们调用它的 from_pretrained 方法,让这个模型的初始参数与预训练过的 GPT2 模型一致。
  • 用于 ParaPharse Detection 分类任务的 token。

b.b. forward 模块

forward 模块实现了完成 ParaPharse Detection 任务的核心逻辑。具体流程如下:

  1. 我们获取 LoRAGPT2Model 的输出:
outputs = self.gpt(input_ids, attention_mask)

在PyTorch中,使用 Model(input) 调用一个 nn.Module 实例时,会自动调用它的 forward 方法。

  1. 我们只需要返回结果的最后一个token,因为它通常是对前面内容概括性最强的。我们使用 LoRAGPT2Model 将这个 token 转换为对应的逻辑变量:
last_token_hidden = outputs['last_token']  # [batch_size, hidden_size]
        
# Convert to token logits using weight tying
all_token_logits = self.gpt.hidden_state_to_token(last_token_hidden)  # [batch_size, vocab_size]
  1. 我们从获取的所有词的逻辑变量中提取我们需要的 yesno 的逻辑变量,以便之后的操作:
# Extract logits for "yes" and "no" tokens
yes_logits = all_token_logits[:, self.yes_token_id]  # [batch_size]
no_logits = all_token_logits[:, self.no_token_id]   # [batch_size]
  1. yes_logitsno_logits 都是形状为 [8] 的张量,我们把它堆叠成 [yes_logits, no_logits] 的形式,这样就可以直接发送给 F.cross_entropy 进行交叉熵计算了:
logits = torch.stack([no_logits, yes_logits], dim=1)  # [batch_size, 2]

c.c. train 模块

我们通过一个简单的 train 函数训练模型,具体流程如下:

  1. 加载数据集。
  2. 初始化模型、早停机制类等。
  3. 开始进行 epochs 次训练。每次训练:
  • 将模型切换到训练模式,并遍历所有训练批次:
model.train()
train_loss = 0
num_train_batches = 0
        
for batch in tqdm(para_train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
  • 从数据集中提取 LoRAGPT2Modelforward 推理方法需要的参数:
    • 从当前批次中取出 token_idsattention_mask,并用 .to(device) 把它们发送到GPU上,准备进行高速计算。
    • 从当前批次中取出 label,将从数据集中读出的原始 label,转换成损失函数能够理解的0-1分类索引。
label_ids = batch['labels'].to(device).long()
binary_labels = torch.zeros(label_ids.size(0), dtype=torch.long, device=device)
binary_labels[label_ids == YES_TOKEN_ID] = 1  # "yes" -> class 1
binary_labels[label_ids == NO_TOKEN_ID] = 0   # "no" -> class 0
labels = binary_labels
  1. 核心训练五步曲: 这是PyTorch训练的“标准动作”,整个学习过程包括:
    • 梯度清零。在进行新一轮的计算之前,必须清除上一轮留下的旧梯度。否则,梯度会累积,导致错误的更新。
    • 前向传播 (Forward Pass)。将一批数据喂给模型,模型进行一系列计算,最终给出它的预测结果 logits
    • 计算损失 (Loss Calculation)。将模型的预测 logits 与真实的 labels 进行比较,计算出一个交叉熵损失 loss 值。这个数值代表了模型这次“错得有多离谱”。
    • 反向传播 (Backward Pass)。PyTorch会根据loss值,自动计算出LoRA的参数应该如何调整才能让loss变小。这个“调整方向”就是梯度。
    • 参数更新 (Parameter Update)。优化器(AdamW)根据上一步算出的梯度,去实际地更新 LoRA 矩阵的权重。这样,模型就完成了一次学习。
# Forward pass
optimizer.zero_grad()
logits = model(b_ids, b_mask)
loss = F.cross_entropy(logits, labels, reduction='mean')
            
# Backward pass
loss.backward()
optimizer.step()

train_loss += loss.item()
num_train_batches += 1
  1. 记录当前批次的损失值,用于后续计算整个Epoch的平均训练损失:
train_loss = train_loss / num_train_batches
  1. 接下来我们就需要检验训练的效果了,我们使用 dev set 来检验,包括如下步骤:
    • 从“训练模式”切换到“评估模式”。
    • 关闭梯度计算。
    • 遍历所有训练批次,计算验证损失。
model.eval()
val_loss = 0
num_val_batches = 0
        
with torch.no_grad():
    for batch in tqdm(para_dev_dataloader, desc=f'val-{epoch}', disable=TQDM_DISABLE):
        b_ids = batch['token_ids'].to(device)
        b_mask = batch['attention_mask'].to(device)
                
        # Labels processing for validation loss
        label_ids = batch['labels'].to(device).long()
        binary_labels = torch.zeros(label_ids.size(0), dtype=torch.long, device=device)
        binary_labels[label_ids == YES_TOKEN_ID] = 1  # "yes" -> class 1
        binary_labels[label_ids == NO_TOKEN_ID] = 0   # "no" -> class 0
        labels = binary_labels
                
        logits = model(b_ids, b_mask)
        loss = F.cross_entropy(logits, labels, reduction='mean')
                
        val_loss += loss.item()
        num_val_batches += 1
  1. 评估与报告:在每个Epoch的最后,我们有了 train_lossval_loss,我们计算并打印这些指标:
val_loss = val_loss / num_val_batches

# Evaluate accuracy on dev set
dev_acc, dev_f1, *_ = model_eval_paraphrase(para_dev_dataloader, model, device)

# Check for improvement and save best model
improvement_status = ""
if val_loss < best_val_loss:
    best_val_loss = val_loss
    best_dev_acc = dev_acc
    best_epoch = epoch
    save_model(model, optimizer, args, args.filepath)
    improvement_status = " ⭐ NEW BEST!"

print(f"Epoch {epoch}: train loss :: {train_loss:.3f}, val loss :: {val_loss:.3f}, dev acc :: {dev_acc:.3f}{improvement_status}")
        
# Early stopping check
if early_stopping(val_loss):
    print(f"Early stopping triggered at epoch {epoch}")
    print(f"Best validation loss: {best_val_loss:.3f} at epoch {best_epoch}")
    break

print(f"\nTraining completed! Best validation loss: {best_val_loss:.3f} at epoch {best_epoch}")
print(f"Best dev accuracy: {best_dev_acc:.3f}")

这样,模型的训练就完成了。

Comments

Total words: 34329