GPT2 实现笔记(3)-ReFT微调

• 164 min read • 32666 words

GPT2 实现笔记(3)-ReFT微调

该笔记是对 Stanford CS224-n 的 hw5 的 ReFT 微调部分的整理,用于整理本人实现的对特定任务的 ReFT 微调实现以及对特定下游任务的处理。

1. ReFT 微调总体架构

ReFT 微调总体架构和 LoRA 架构类似:

与 LoRA 不同,ReFT 使用模型返回的 hidden_state 进行微调,因此不需要额外实现自注意力层。

2. ReFT 线性层实现

ReFT 线性层是 ReFT 微调的核心部分,它实现了 ReFT 对 hidden_state 进行干预的核心逻辑。

a.a. 初始化

我们并不直接使用论文中 LoReFT 的公式来构建 ReFT:

ΦLoReFT(h)=h+RT(Wh+bRh)\Phi_{\text{LoReFT}}(h) = h + R^T(Wh + b - Rh)

LoReFT 只学习在固定子空间内的变换 WW。它假设我们事先选定的随机子空间 RR 就是足够好的。

ReFT 本身更像一个框架或思想:在预训练模型前向传播的过程中,对其中某一层或某几层的隐藏状态 hh 进行干预,生成一个新的隐藏状态 hh',而模型本身的权重保持不变。

而是使用更为灵活的设计:将 ReFT 论文中的干预转化为为一个迷你的、低秩的前馈神经网络,它被直接注入到模型中。同时使用下面的策略更新 hidden_state

h=h+Wup(activate(Wdown(h)))h = h + W_{up}(\text{activate}(W_{down}(h)))

这里的 WupW_{up}WdownW_{down} 矩阵是受 LoFT 启发的设计,它们的维度形状也和 LoRA 中的AABB矩阵类似。

我们的 Adapter 模块设计直接让模型学习如何进行降维和升维。这意味着让模型自己去发现哪个子空间是最重要的,这在理论上可能比一个固定的随机子空间更强大、更灵活。

这样的 ReFT 实现被看作是将 LoRA 的思想从“修改权重”迁移到了“修改激活”上,并加入了一个非线性激活函数来增强效果。这种简介有效的结构是Adapter模块的常见实现方式。

def __init__(
    self,
    hidden_size: int,
    rank: int = 4,
    alpha: float = 16.0,
    dropout: float = 0.0,
    activation: str = 'relu',
):
    super().__init__()
        
    self.hidden_size = hidden_size
    self.rank = rank
    self.alpha = alpha
    self.scaling = alpha / rank if rank > 0 else 0.0
        
    # ReFT intervention matrices (low-rank decomposition)
    if rank > 0:
        # Down projection: hidden_size -> rank
        self.reft_down = nn.Parameter(torch.zeros(rank, hidden_size))
        # Up projection: rank -> hidden_size  
        self.reft_up = nn.Parameter(torch.zeros(hidden_size, rank))
        self.reft_dropout = nn.Dropout(dropout)
            
        # Initialize ReFT matrices
        nn.init.kaiming_uniform_(self.reft_down, a=math.sqrt(5))
        nn.init.zeros_(self.reft_up)
            
        # Activation function for the intervention
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

b.b. forward 模块

ReFT 的 forward 模块的详细流程如下:

  1. 获取当前的隐藏层 hidden_state
  2. hidden_state 进行干预:
# Down projection: [..., hidden_size] -> [..., rank]
down_proj = torch.matmul(hidden_states, self.reft_down.T)
        
# Apply activation
activated = self.activation(down_proj)
        
# Apply dropout
activated = self.reft_dropout(activated)
        
# Up projection: [..., rank] -> [..., hidden_size]
up_proj = torch.matmul(activated, self.reft_up.T)
        
# Apply scaling and add to original states (residual connection)
intervention = up_proj * self.scaling
  1. 进行残差连接。
# Apply scaling and add to original states (residual connection)
        intervention = up_proj * self.scaling
        modified_states = original_states + intervention

3. ReFTGPT2 层实现

ReFTGPT2 层将前面的 ReFT 组件注入到 GPT2Model 中。这种注入式设计从 ReFTGPT2Layer 初始化的方式就能看出来。

a.a. 初始化

ReFTGPT2Layer 的初始化和 LoRAGPT2Layer 类似,都是预先创建好需要的组件:

# ReFT intervention modules (only these are trainable)
self.reft_interventions = nn.ModuleDict()
        
# Apply interventions based on configuration
if 'attention' in reft_config.intervention_locations:
    self.reft_interventions['attention'] = ReFTIntervention(
        hidden_size=config.hidden_size,
        rank=reft_config.rank,
        alpha=reft_config.alpha,
        dropout=reft_config.dropout,
        activation=reft_config.activation
    )
            
    if 'ffn' in reft_config.intervention_locations:
        self.reft_interventions['ffn'] = ReFTIntervention(
            hidden_size=config.hidden_size,
            rank=reft_config.rank,
            alpha=reft_config.alpha,
            dropout=reft_config.dropout,
            activation=reft_config.activation
        )

但是 ReFTGPT2Layer 的初始化有两个特别的地方:

  • ReFT相关配置存放在 nn.ModuleDict() 中。
  • 我们在 ReFTGPT2Layer 的初始化中完成对原有模型参数的冻结。
# ReFT intervention modules (only these are trainable)
self.reft_interventions = nn.ModuleDict()
        
# Apply interventions based on configuration
if 'attention' in reft_config.intervention_locations:
    self.reft_interventions['attention'] = ReFTIntervention(
        hidden_size=config.hidden_size,
        rank=reft_config.rank,
        alpha=reft_config.alpha,
        dropout=reft_config.dropout,
        activation=reft_config.activation
    )
            
if 'ffn' in reft_config.intervention_locations:
    self.reft_interventions['ffn'] = ReFTIntervention(
        hidden_size=config.hidden_size,
        rank=reft_config.rank,
        alpha=reft_config.alpha,
        dropout=reft_config.dropout,
        activation=reft_config.activation
    )

......

# Freeze all original parameters
self.freeze_original_parameters()

和 LoRA 不同,ReFT并没有实现一个单独的 ReFTLayer,而是复用 GPT2Model 原有的层级结构,因此需要使用 ModuleDict 把多个独立的干预模块“注入”到一个完整的、原始的层结构中。而 LoRA 实现了自己的层结构 LoRALayer,直接替换原有层级结构即可。

b.b. forward 模块

ReFT 微调的 forward 模块和 GPT2Layer 基本一致,不过对从注意力层和前馈网络取出的 hidden_state,需要使用 ReFT 组件进行干预:

# Self-attention with pre-layer norm (same as original)
ln_output = self.attention_layer_norm(hidden_states)
att_output = self.self_attention(ln_output, attention_mask)
hidden_states = self.add(hidden_states, att_output, self.attention_dense, self.attention_dropout)
        
# Apply ReFT intervention after attention if configured
if 'attention' in self.reft_interventions:
    hidden_states = self.reft_interventions['attention'](hidden_states)
        
# Feed-forward with pre-layer norm (same as original)
ln_output = self.out_layer_norm(hidden_states)
interm_output = self.interm_af(self.interm_dense(ln_output))
hidden_states = self.add(hidden_states, interm_output, self.out_dense, self.out_dropout)
        
# Apply ReFT intervention after FFN if configured
if 'ffn' in self.reft_interventions:
    hidden_states = self.reft_interventions['ffn'](hidden_states)
        
return hidden_states

4. ReFTGPT2 模型实现

同样地,我们将 ReFTGPT2Layer 组装到我们的模型中:

self.gpt_layers = nn.ModuleList([
    ReFTGPT2Layer(config, self.reft_config) 
    if self.reft_config.intervention_layers is None or i in self.reft_config.intervention_layers
    else ReFTGPT2Layer(config, None)  # No intervention for this layer
    for i in range(config.num_hidden_layers)
])

注意 ReFTGPT2ModelLoRAGPT2Model 的区别:在 ReFTGPT2Model 中,我们根据我们的 reft_config 选择性应用 ReFTGPT2Layer;但在 LoRAGPT2Model 中,我们默认将所有层都替换为 LoRAGPT2Layer

Comments

Total words: 32666