您现在的位置是:首页 > 正文

二、现有网络模型(torchvision.models.vgg16)的修改与使用

2024-02-29 11:49:52阅读 2

1.torchvision.models.vgg16

官方文档 : https://pytorch.org/vision/stable/models.html#id2
在这里插入图片描述

pretrained (bool) – If True, returns a model pre-trained on ImageNet

ImageNet数据集太大不好下载

2.pretrained设置不同时网络模型的差别

在这里插入图片描述

3.如何修改现有网络结构

修改vgg16_true网路结构,添加linear层

import torchvision
from torch.nn import Linear

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)
# vgg16_true.add_module("add_linear",Linear(1000,10))
vgg16_true.classifier.add_module("add_linear",Linear(1000,10))
print(vgg16_true)

在这里插入图片描述
修改vgg16_false网路结构,更改分类器第6层为指定linear层

print(vgg16_false)
vgg16_false.classifier[6]=Linear(4096,10)
print(vgg16_false)

在这里插入图片描述

4.模型的保存、加载

vgg16_method1 结构+参数

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)
# vgg16_method1 结构+参数
torch.save(vgg16, "vgg16_method1.pth")

# 模型加载(在另一个文件加载)
model = torch.load("vgg16_method1.pth")
print(model)

方式1的陷阱
自定义网络结构如下:

import torch
import torchvision
from torch import nn

class Qu(nn.Module):
    def __init__(self):
        super(Qu, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return
        
qu = Qu()
torch.save(qu, "qu_method1.pth")

在另一个文件加载该模型,会报错
AttributeError: Can’t get attribute ‘Qu’ on <module ‘main’ from ‘D:/documents/Gra_Proj/DL_Pytorch/B_tudui/model_load.py’>

正确的调用格式需要复制原模型的类定义

class Qu(nn.Module):
    def __init__(self):
        super(Qu, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return

model = torch.load("qu_method1.pth")
print(model)

或者用import

from model_save import *

model = torch.load("qu_method1.pth")
print(model)

vgg16_method2 参数(官方推荐)

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)
# vgg16_method2 参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 模型加载(在另一个文件加载)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

网站文章

  • android平台架构介绍

    android平台架构介绍

    下面这张图片是在google官网上下载的关于android系统的体系结构图: 从上面的图片我们可以看出来,android系统的底层建立在linux系统之上,该平台由操作系统、中间件、用户界面和应用软件4层架构组成,它采用的主要方法被称为软件叠层,这种叠层方法结构使得层与层之间相互分离,明确了各层之间的分工,这种分工保证了低耦合性,当下层的层内或者层下发生改变时,上层的应用程序无需

    2024-02-29 11:49:24
  • 扩散加权成像(DWI):从原理到临床

    扩散加权成像(DWI):从原理到临床

    2019独角兽企业重金招聘Python工程师标准&gt;&gt;&gt; ...

    2024-02-29 11:49:16
  • 为什么要使用Maven?

    为什么要使用Maven?

    为什么要使用Maven? 之所以会提出这个问题,是因为即使不使用Maven我们仍然可以进行B/S结构项目的开发。从表述层、业务逻辑层到持久化层再到数据库都有成熟的解决方案——不使用Maven我们一样可...

    2024-02-29 11:49:08
  • 职高计算机专业考本科要多少分,职高多少分才能上本科 需要多少分数

    职高计算机专业考本科要多少分,职高多少分才能上本科 需要多少分数

    职高多少分才能上本科,需要多少分数,小编整理了相关信息,来看一下!职高多少分才能上本科高职高考是“3+证书”本科的分数,总分现在是550分。职高高考本科院校招生录取一般在400多分,各招生院校的录取分...

    2024-02-29 11:49:02
  • SegNetr: 重新思考 U 形网络中的局部-全局交互和跳过连接

    SegNetr: 重新思考 U 形网络中的局部-全局交互和跳过连接

    近年来,U 形网络因其简单且易于调整的结构而在医学图像分割领域占据主导地位。然而,现有的U型分割网络:1)大多侧重于设计复杂的自注意力模块来弥补基于卷积运算的长期依赖性的不足,这增加了网络的总体参数数...

    2024-02-29 11:48:33
  • ubuntu查看修改主机名

    ubuntu查看修改主机名

    为什么80%的码农都做不了架构师?>>> ...

    2024-02-29 11:48:28
  • GCC中通过--wrap选项使用包装函数

    GCC中通过--wrap选项使用包装函数

    在使用GCC编译器时,如果不想工程使用系统的库函数,例如在自己的工程中可以根据选项来控制是否使用系统中提供的malloc/free, new/delete函数,可以有两种方法: (1). 使用LD_P...

    2024-02-29 11:48:19
  • 数据库学习记录——错题总结(一)

    数据库学习记录——错题总结(一)

    第一套 1.对关系模型叙述错误的是( )。 正确答案: D 你的答案: D (正确) 建立在严格的数学理论、集合论和谓词演算公式的基础之上 微机 DBMS 绝大部分采取关系数据模型 用二维表表示关系模...

    2024-02-29 11:47:50
  • Android Fragment生命周期及各个方法使用

    Android Fragment生命周期及各个方法使用

    在Android开发中,我们都少不了使用Fragment,一直在使用,但是没有很详细的理解过具体生命周期的回调,这段时间比较闲,特定写一下总结:就像activity一样,fragment也有它们自己的...

    2024-02-29 11:47:43
  • 蒸米ROP学习笔记(一步一步学 ROP 之 Linux_x86 篇)

    这里写自定义目录标题欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中...

    2024-02-29 11:47:34