【Pytorch学习笔记】12.修改预训练模型权重参数的方法(用于对单通道灰度图使用预训练模型)

文章目录

    • 1.导出模型参数,修改参数
    • 2.修改模型结构,导回参数

我们在训练单通道图像,即灰度图(如医学影像数据)时,常会使用预训练模型进行训练。
但是一般的预训练模型是以ImageNet数据集预训练的,训练的对象是三通道的彩色图片。
这需要对模型的参数进行修改,让第一个卷积层的参数从3通道卷积改成1通道卷积。
(比如下图是将三通道改成单通道后卷积层的变化)

在这里插入图片描述

我们知道灰度图是三通道图各个通道的加权平均,所以我们可以假设改成单通道后,将3个通道对应的卷积矩阵对应位置相加(sum)得到1通道的卷积矩阵,再去卷积灰度图,这样几乎不折损对图像的特征提取能力。

下面以Resnet50预训练模型为例来修改第1个卷积层的参数,使其能用于单通道图片的训练。

1.导出模型参数,修改参数

Pytorch中修改模型的参数,如果涉及网络结构的变化,需要先修改网络结构再赋予参数值。
即导出预训练模型参数→修改预训练模型参数→修改模型的网络结构→导回修改后的模型

from torchvision.models import resnet50
net = resnet50(pretrained=True)
print(net.conv1)  # 查看第一个卷积层的结构

weights = net.state_dict()  # state_dict()以 有序字典 罗列参数
print(weights.keys())  # 查看参数的key
weights['conv1.weight'].shape  # 根据key取到参数,查看形状

在这里插入图片描述

修改模型参数

weights['conv1.weight'] = weights['conv1.weight'].sum(1, keepdim=True)  # 修改第一个卷积层的参数,从3通道卷积改成1通道卷积
weights['conv1.weight'].shape  # 查看修改后的形状

在这里插入图片描述

2.修改模型结构,导回参数

import torch.nn as nn

# 修改第一个卷积层的结构
net.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# 导入修改后的参数
net.load_state_dict(weights)
net

在这里插入图片描述

修改成功!
接下去就可以用你自己的灰度图数据集来微调这个预训练模型了。


http://www.niftyadmin.cn/n/250503.html

相关文章

招商银行fintech 2021年机试

5月2日 选择题 包含单选(40道)和多项选择(10道),累计120分,最后有一个所谓20分的编程题,实际上是下面2个小时的编程题的链接,因此第一个小时所有时间都给选择题,最后拿…

JVM垃圾回收与调优

文章目录 1、如何判断对象可以回收1.1、 引用计数法1.2、可达性分析法1.3、五种引用类型1.3.1 、强引用1.3.2 、软、弱引用1.3.3 、虚引用、终结器引用1.3.4、 终结器引用1.3.5 、总结 2. 垃圾清除算法2.1、标记清除2.2 、标记整理2.3、 复制 3. 分代垃圾回收3.1 、新生代、老年…

sqlacodegen生成SQLAlchemy模型

SQLAlchemy是一个Python SQL工具包和ORM框架,它提供了一种方便的方式来与关系型数据库进行交互。sqlacodegen是一个用于生成SQLAlchemy模型代码的命令行工具,它可以根据现有的数据库表结构自动生成Python代码。 生成MySQL模型代 可以通过以下命令生成M…

解决在vue中使用elementUI自定义校验及点击提交不生效问题

前言: 本章讲述的主要是对身份证号码的校验 及 为何校验了但提交不生效问题。 拓展小知识: 🍀 1、身份证号码(二代18位身份证)的含义: 1️⃣ 1-2位:代表所属省级政府的代码; 2️⃣ 3…

awk命令常用例子

按列排序 awk {print $2, $1} filename | sort这个命令将文件中的第二列和第一列交换,并按照第二列进行排序。 统计行数 awk END{print NR} filename这个命令将统计文件中的行数并输出。 按照条件过滤 awk $1 > 10 {print $0} filename这个命令将输出第一列…

代码优化- 基本概念

思考一个问题:我们可以再抽象语法树上做编译优化吗? 答案是否定的,如果在抽象语法树上做编译优化的话,程序员所写的可能包含错误的代码,可能就被删除了,比如,对下面的程序做不可达代码删除优化…

【内摹访谈】谈谈AI爆发前夜的B端设计

本文来自摹客产品设计团队(MPD)的设计专栏“内摹访谈”。专栏介绍:专栏名称来源于西方美学理论「内摹仿说」,意指审美活动与摹仿活动紧密相连,审美不只针对表象动作,其核心在于由物及我,从表观带…

数据库基础篇 《9. 子查询》

目录 1. 需求分析与问题解决 1.1 实际问题 1.2 子查询的基本使用 ​编辑1.3 子查询的分类 分类方式1:我们按内查询的结果返回一条还是多条记录,将子查询分为 单行子查询 、 多行子查询 。 分类方式2: 我们按内查询是否被执行多次&#x…