从零学习大模型(十二)-----基于梯度的重要性剪枝(Gradient-based Pruning)

梯度的重要性定义

权重重要性(Weight Importance)

权重重要性通常指的是某个权重参数在模型输出中的影响程度。权重重要性的评估通常基于以下几个方面:

  • 绝对值:一个常见的方法是直接使用权重的绝对值作为其重要性指标。权重越大,表示其对模型输出的影响越大,因此可以认为其重要性越高。
  • 对模型性能的影响:通过在剪枝前后比较模型的性能,可以间接评估某个权重的重要性。若剪除某个权重后,模型性能显著下降,则说明该权重是重要的。

梯度的重要性(Gradient Importance)

梯度的重要性则是基于模型的训练过程来评估权重的重要性,具体体现在以下几个方面:

  • 梯度的计算:在反向传播过程中,每个权重的梯度值反映了该权重对损失函数的影响。如果一个权重的梯度较大,意味着它在当前训练样本下对模型的输出变化有较大贡献。
  • 动态性:与静态的权重重要性不同,梯度的重要性是动态的,会随着训练过程中的样本和损失函数的变化而变化。这意味着在不同的训练阶段或针对不同的数据样本,某些权重的梯度可能会显示出不同的“重要性”。

区别与联系

  • 静态 vs. 动态:权重重要性通常是静态评估,基于权重本身的值;而梯度的重要性是动态的,依赖于当前的训练状态和样本。
  • 应用场景:在剪枝时,可以综合考虑两者。某些方法可能首先基于权重重要性进行初步剪枝,然后使用梯度重要性进行精细调整,以优化剪枝后的模型性能。
  • 互补性:两者可以结合使用,提供更全面的剪枝策略。例如,优先剪除权重绝对值小且梯度小的权重,以最小化对模型性能的影响。

基于梯度的重要性剪枝的工作原理

1.梯度的重要性评估

**损失函数:**假设有一个损失函数 L L L,通常是模型输出与真实标签之间的差异,比如均方误差或交叉熵:
L = 1 N ∑ i = 1 N l ( y i , y ^ i ) L = \frac{1}{N} \sum_{i=1}^{N} l(y_i, \hat{y}_i) L=N1i=1Nl(yi,y^i)
其中 y i y_i yi 是真实标签, y ^ i \hat{y}_i y^i 是模型预测。

**权重的梯度计算:**在反向传播过程中,计算每个权重 w j w_j wj 的梯度:
g j = ∂ L ∂ w j g_j = \frac{\partial L}{\partial w_j} gj=wjL
这里 g j g_j gj 是与权重 w j w_j wj 相关的梯度。

2.确定剪枝阈值

**重要性度量:**通常用梯度的绝对值作为权重的重要性指标:
I j = ∣ g j ∣ I_j = |g_j| Ij=gj
其中 I j I_j Ij 是权重 w j w_j wj 的重要性。

**剪枝阈值:**设定一个阈值 θ \theta θ,可以选择某个比例的权重进行剪枝。例如,保留最重要的 k % k\% k% 的权重:
Threshold = quantile ( I , 1 − k ) \text{Threshold} = \text{quantile}(I, 1 - k) Threshold=quantile(I,1k)

3.执行剪枝

**剪除不重要的权重:**设定剪枝规则,所有重要性低于阈值的权重被剪除:
w j ′ = { 0 if  I j < θ w j otherwise w_j' = \begin{cases} 0 & \text{if } I_j < \theta \\ w_j & \text{otherwise} \end{cases} wj={0wjif Ij<θotherwise
这里 w j ′ w_j' wj剪枝后的权重。

4.重训练

再训练模型:在剪枝后进行再训练,以调整剩余的权重,使模型恢复性能。可以使用梯度下降优化:
w j ← w j − η g j w_j \leftarrow w_j - \eta g_j wjwjηgj
其中 η \eta η学习率。

基于CNN的梯度重要性剪枝代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

# 定义简单的CNN模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5)  # 输入通道改为3(RGB)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.fc1 = nn.Linear(32 * 5 * 5, 128)  # 根据输入尺寸调整
        self.fc2 = nn.Linear(128, 10)  # 输出10个类

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 32 * 5 * 5)  # 根据输入尺寸调整
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 训练模型
def train_model(model, train_loader, optimizer, criterion, num_epochs=1):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)  # 将数据转移到GPU
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()  # 累加损失

        average_loss = total_loss / len(train_loader)  # 计算平均损失
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}')  # 输出平均损失

# 剪枝权重
def prune_weights(model, threshold):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name:  # 只剪枝权重参数
                importance = param.grad.abs()  # 计算权重的梯度绝对值
                mask = importance < threshold  # 创建剪枝掩码
                param[mask] = 0  # 剪除不重要的权重

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 选择设备:如果有可用的MPS,使用MPS;否则如果有可用的GPU,使用GPU;否则使用CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")  # MacBook M3
elif torch.cuda.is_available():
    device = torch.device("cuda")  # Windows GPU
else:
    device = torch.device("cpu")  # CPU

# 初始化模型、损失函数和优化器
model = SimpleCNN().to(device)  # 将模型转移到GPU
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
train_model(model, train_loader, optimizer, criterion, num_epochs=5)

# 计算梯度并剪枝
model.train()  # 确保模型处于训练模式
for data, target in train_loader:
    data, target = data.to(device), target.to(device)  # 将数据转移到GPU
    optimizer.zero_grad()  # 清空之前的梯度
    output = model(data)
    loss = criterion(output, target)
    loss.backward()  # 计算梯度
    break  # 只计算一次梯度用于剪枝

# 剪枝阈值设定
threshold = 0.01  # 设定剪枝阈值
prune_weights(model, threshold)

# 重训练模型
train_model(model, train_loader, optimizer, criterion, num_epochs=5)

print("模型剪枝与重训练完成。")

注意

置为零的权重的影响
  1. 前向传播
    • 在前向传播阶段,设置为零的权重确实不会影响模型的输出。具体来说,如果权重 w j = 0 w_j=0 wj=0,那么在计算输出时,所有与 w j w_j wj相关的激活值都会被乘以零,结果为零。因此,零权重不会对损失函数产生任何贡献。
  2. 损失函数的贡献
    • 损失函数是根据模型输出和真实标签之间的差异计算的。对于那些与零权重相关的部分,它们的输出为零,所以这些权重不会影响损失的计算,也就不会产生梯度(贡献)。
  3. 反向传播
    • 在反向传播过程中,只有对损失函数有贡献的权重才会计算梯度。由于零权重没有参与损失函数的计算,它们的梯度将会是零。

基于LLAMA2梯度重要性剪枝的代码

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AdamW
from datasets import load_dataset

def gradient_based_pruning(model, pruning_rate):
    """
    对模型进行梯度剪枝
    :param model: 待剪枝的模型
    :param pruning_rate: 剪枝率 (0到1之间的浮点数)
    """
    for name, param in model.named_parameters():
        if 'weight' in name and param.grad is not None:
            # 计算权重的梯度绝对值
            weight_grad = param.grad.abs()
            
            # 计算剪枝阈值
            threshold = torch.quantile(weight_grad, pruning_rate)
            
            # 剪枝
            mask = weight_grad < threshold
            param.data[mask] = 0  # 将低于阈值的权重置为0

# 加载IMDB情感分析数据集
dataset = load_dataset("imdb")

# 加载预训练的LLAMA2模型和分词器
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b")
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b")

# 准备输入数据(取前1000条样本进行示例)
train_texts = dataset['train']['text'][:1000]
train_labels = dataset['train']['label'][:1000]

# 编码输入数据
inputs = tokenizer(train_texts, padding=True, truncation=True, return_tensors="pt")

# 设置优化器
optimizer = AdamW(model.parameters(), lr=5e-5)

# 训练循环示例
model.train()
for epoch in range(3):  # 假设训练3个epoch
    optimizer.zero_grad()  # 清除梯度
    
    # 获取模型输出
    outputs = model(input_ids=inputs['input_ids'], labels=inputs['input_ids'])
    loss = outputs.loss  # 获取损失
    loss.backward()  # 反向传播

    # 在这里进行梯度剪枝
    pruning_rate = 0.2  # 剪去20%的权重
    gradient_based_pruning(model, pruning_rate)

    optimizer.step()  # 更新模型参数

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# 训练完成后,你可以继续评估剪枝后的模型

基于梯度重要性剪枝的局限性

1. 依赖于梯度信息

  • 基于梯度的重要性剪枝方法依赖于当前的梯度信息来决定哪些参数是重要的。然而,在一些情况下,梯度可能并不能充分反映参数的重要性。例如,某些参数可能在特定的训练阶段或特定的输入上具有重要性,但在当前的梯度计算中未能体现。

2. 训练过程中的不稳定性

  • 梯度信息会随着训练的进行而变化,这可能导致剪枝结果的不一致性。剪枝时使用的阈值选择可能会在不同的训练周期间产生不同的效果,导致剪枝的稳定性和可重复性较差。

3. 对模型性能的影响

  • 剪枝可能会影响模型的性能,尤其是在剪枝过于激进的情况下。模型的表达能力可能受到损害,从而导致在测试集上的性能下降。

4. 剪枝后重训练的需求

  • 一般来说,剪枝后需要对模型进行重训练,以恢复由于剪枝带来的性能损失。这一过程需要额外的计算资源和时间。

5. 缺乏全局视角

  • 基于梯度的剪枝方法通常是局部的,关注的是单个参数的影响。它可能没有考虑到模型结构的全局信息,例如某些参数可能在网络的不同层之间有相互作用,而单独剪枝某一层的参数可能会影响整个模型的性能。

6. 剪枝阈值的选择

  • 剪枝的效果在很大程度上依赖于所选择的阈值。这一阈值往往需要通过实验进行调优,没有统一的标准,这使得在不同任务上应用时可能需要额外的调整和优化。

7. 对动态输入的敏感性

  • 对于输入变化较大的模型(如图像分类中的多种图像处理),基于梯度的剪枝可能无法适应新的数据分布,导致模型在不同输入下表现不稳定。

http://www.niftyadmin.cn/n/5744209.html

相关文章

《Spring Boot从入门到实战》第五章习题答案

5.7 本章练习 1&#xff09;创建Spring Boot Web项目&#xff0c;使用Thymeleaf页面模板引擎实现人员管理模块的功能。 答案&#xff1a; 1. 创建人员实体类 创建一个 Person 实体类&#xff0c;用于定义人员属性 package com.example.demo.bean;import javax.persistence.…

[JAVA]有关一些Maven的介绍

Maven 是一个强大的项目管理和构建自动化工具&#xff0c;主要用于Java项目&#xff0c;也可以用于其他基于JVM&#xff08;Java虚拟机&#xff09;的语言项目&#xff0c;它的作用介绍如下 1.依赖管理 在Java开发中&#xff0c;一个项目通常会依赖许多外部的库&#xff0c;比…

Latex中给公式加边框

1、这里使用的不是 amsmath 的 \boxed 命令, 而是 empheq 的 empheq 环境以及 xcolor 的 \fcolorbox 命令, 下面是代码, 可以分别阅读这两个手册来获取更多的信息 \documentclass{article} \usepackage{xcolor} \usepackage{empheq} \usepackage{amsmath} \begin{document}\be…

城镇住房保障:SpringBoot系统架构解析

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

深入了解Git、GitHub、GitLab及其应用技巧

在现代软件开发中&#xff0c;掌握版本控制系统&#xff08;VCS&#xff09;是至关重要的&#xff0c;其中Git是最流行的分布式版本控制工具之一。本文将详细介绍Git的用途及其基本操作&#xff0c;并深入探讨GitLab、GitHub、和Git Desktop的使用方法&#xff0c;同时总结Git的…

【Linux】解锁操作系统潜能,高效线程管理的实战技巧

目录 1. 线程的概念2. 线程的理解3. 地址空间和页表4. 线程的控制4.1. POSIX线程库4.2 线程创建 — pthread_create4.3. 获取线程ID — pthread_self4.4. 线程终止4.5. 线程等待 — pthread_join4.6. 线程分离 — pthread_detach 5. 线程的特点5.1. 优点5.2. 缺点5.3. 线程异常…

SQL 数据结构查询

1&#xff1a;查询变得结构。 SELECT COLID,SO.NAME,EP.VALUE,SO.LENGTH,MIN(ST.NAME) AS TYPE FROM SYS.EXTENDED_PROPERTIES EP RIGHT JOIN SYS.SYSCOLUMNS SO ON MAJOR_IDID AND COLIDMINOR_ID LEFT JOIN SYS.SYSTYPES ST ON ST.XTYPESO…

Pinia 如何在项目中使用?

Pinia 是一个现代的、轻量级的状态管理库&#xff0c;专为 Vue 3 设计。它简化了 Vuex 的复杂性&#xff0c;提供了更直观和灵活的方式来管理应用的状态。下面详细介绍 Pinia 的基本概念、安装、配置和使用方法。 1. 先安装对应的包 npm add pinia2. 在项目中注册 /main.ts …