GPT2 实现笔记(3)-ReFT微调
GPT2 实现笔记(3)-ReFT微调
该笔记是对 Stanford CS224-n 的 hw5 的 ReFT 微调部分的整理,用于整理本人实现的对特定任务的 ReFT 微调实现以及对特定下游任务的处理。
1. ReFT 微调总体架构
ReFT 微调总体架构和 LoRA 架构类似:
与 LoRA 不同,ReFT 使用模型返回的 hidden_state
进行微调,因此不需要额外实现自注意力层。
2. ReFT 线性层实现
ReFT 线性层是 ReFT 微调的核心部分,它实现了 ReFT 对 hidden_state
进行干预的核心逻辑。
初始化
我们并不直接使用论文中 LoReFT 的公式来构建 ReFT:
LoReFT 只学习在固定子空间内的变换 。它假设我们事先选定的随机子空间 就是足够好的。
ReFT 本身更像一个框架或思想:在预训练模型前向传播的过程中,对其中某一层或某几层的隐藏状态 进行干预,生成一个新的隐藏状态 ,而模型本身的权重保持不变。
而是使用更为灵活的设计:将 ReFT 论文中的干预转化为为一个迷你的、低秩的前馈神经网络,它被直接注入到模型中。同时使用下面的策略更新 hidden_state
:
这里的 和 矩阵是受 LoFT 启发的设计,它们的维度形状也和 LoRA 中的、矩阵类似。
我们的 Adapter 模块设计直接让模型学习如何进行降维和升维。这意味着让模型自己去发现哪个子空间是最重要的,这在理论上可能比一个固定的随机子空间更强大、更灵活。
这样的 ReFT 实现被看作是将 LoRA 的思想从“修改权重”迁移到了“修改激活”上,并加入了一个非线性激活函数来增强效果。这种简介有效的结构是Adapter模块的常见实现方式。
def __init__(
self,
hidden_size: int,
rank: int = 4,
alpha: float = 16.0,
dropout: float = 0.0,
activation: str = 'relu',
):
super().__init__()
self.hidden_size = hidden_size
self.rank = rank
self.alpha = alpha
self.scaling = alpha / rank if rank > 0 else 0.0
# ReFT intervention matrices (low-rank decomposition)
if rank > 0:
# Down projection: hidden_size -> rank
self.reft_down = nn.Parameter(torch.zeros(rank, hidden_size))
# Up projection: rank -> hidden_size
self.reft_up = nn.Parameter(torch.zeros(hidden_size, rank))
self.reft_dropout = nn.Dropout(dropout)
# Initialize ReFT matrices
nn.init.kaiming_uniform_(self.reft_down, a=math.sqrt(5))
nn.init.zeros_(self.reft_up)
# Activation function for the intervention
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'gelu':
self.activation = nn.GELU()
elif activation == 'tanh':
self.activation = nn.Tanh()
else:
raise ValueError(f"Unsupported activation: {activation}")
forward
模块
ReFT 的 forward
模块的详细流程如下:
- 获取当前的隐藏层
hidden_state
。 - 对
hidden_state
进行干预:
# Down projection: [..., hidden_size] -> [..., rank]
down_proj = torch.matmul(hidden_states, self.reft_down.T)
# Apply activation
activated = self.activation(down_proj)
# Apply dropout
activated = self.reft_dropout(activated)
# Up projection: [..., rank] -> [..., hidden_size]
up_proj = torch.matmul(activated, self.reft_up.T)
# Apply scaling and add to original states (residual connection)
intervention = up_proj * self.scaling
- 进行残差连接。
# Apply scaling and add to original states (residual connection)
intervention = up_proj * self.scaling
modified_states = original_states + intervention
3. ReFTGPT2 层实现
ReFTGPT2 层将前面的 ReFT 组件注入到 GPT2Model
中。这种注入式设计从 ReFTGPT2Layer
初始化的方式就能看出来。
初始化
ReFTGPT2Layer
的初始化和 LoRAGPT2Layer
类似,都是预先创建好需要的组件:
# ReFT intervention modules (only these are trainable)
self.reft_interventions = nn.ModuleDict()
# Apply interventions based on configuration
if 'attention' in reft_config.intervention_locations:
self.reft_interventions['attention'] = ReFTIntervention(
hidden_size=config.hidden_size,
rank=reft_config.rank,
alpha=reft_config.alpha,
dropout=reft_config.dropout,
activation=reft_config.activation
)
if 'ffn' in reft_config.intervention_locations:
self.reft_interventions['ffn'] = ReFTIntervention(
hidden_size=config.hidden_size,
rank=reft_config.rank,
alpha=reft_config.alpha,
dropout=reft_config.dropout,
activation=reft_config.activation
)
但是 ReFTGPT2Layer
的初始化有两个特别的地方:
- ReFT相关配置存放在
nn.ModuleDict()
中。 - 我们在
ReFTGPT2Layer
的初始化中完成对原有模型参数的冻结。
# ReFT intervention modules (only these are trainable)
self.reft_interventions = nn.ModuleDict()
# Apply interventions based on configuration
if 'attention' in reft_config.intervention_locations:
self.reft_interventions['attention'] = ReFTIntervention(
hidden_size=config.hidden_size,
rank=reft_config.rank,
alpha=reft_config.alpha,
dropout=reft_config.dropout,
activation=reft_config.activation
)
if 'ffn' in reft_config.intervention_locations:
self.reft_interventions['ffn'] = ReFTIntervention(
hidden_size=config.hidden_size,
rank=reft_config.rank,
alpha=reft_config.alpha,
dropout=reft_config.dropout,
activation=reft_config.activation
)
......
# Freeze all original parameters
self.freeze_original_parameters()
和 LoRA 不同,ReFT并没有实现一个单独的
ReFTLayer
,而是复用GPT2Model
原有的层级结构,因此需要使用ModuleDict
把多个独立的干预模块“注入”到一个完整的、原始的层结构中。而 LoRA 实现了自己的层结构LoRALayer
,直接替换原有层级结构即可。
forward
模块
ReFT 微调的 forward
模块和 GPT2Layer
基本一致,不过对从注意力层和前馈网络取出的 hidden_state
,需要使用 ReFT 组件进行干预:
# Self-attention with pre-layer norm (same as original)
ln_output = self.attention_layer_norm(hidden_states)
att_output = self.self_attention(ln_output, attention_mask)
hidden_states = self.add(hidden_states, att_output, self.attention_dense, self.attention_dropout)
# Apply ReFT intervention after attention if configured
if 'attention' in self.reft_interventions:
hidden_states = self.reft_interventions['attention'](hidden_states)
# Feed-forward with pre-layer norm (same as original)
ln_output = self.out_layer_norm(hidden_states)
interm_output = self.interm_af(self.interm_dense(ln_output))
hidden_states = self.add(hidden_states, interm_output, self.out_dense, self.out_dropout)
# Apply ReFT intervention after FFN if configured
if 'ffn' in self.reft_interventions:
hidden_states = self.reft_interventions['ffn'](hidden_states)
return hidden_states
4. ReFTGPT2 模型实现
同样地,我们将 ReFTGPT2Layer
组装到我们的模型中:
self.gpt_layers = nn.ModuleList([
ReFTGPT2Layer(config, self.reft_config)
if self.reft_config.intervention_layers is None or i in self.reft_config.intervention_layers
else ReFTGPT2Layer(config, None) # No intervention for this layer
for i in range(config.num_hidden_layers)
])
注意
ReFTGPT2Model
和LoRAGPT2Model
的区别:在ReFTGPT2Model
中,我们根据我们的reft_config
选择性应用ReFTGPT2Layer
;但在LoRAGPT2Model
中,我们默认将所有层都替换为LoRAGPT2Layer
。
Comments