您现在的位置是:首页 > 正文

LeNet模型——tensorflow实现

2024-01-30 22:27:44阅读 0


版本:tensorflow2.0.0rcl
github地址 https://github.com/yang-ze-kang/LeNet5

MNIST数据集

MNIST数据集简介

  1. 包含0~9手写数字
  2. 60000个训练集、10000个测试集
  3. 数据格式:28*28
  4. 灰度图(单通道)

MNIST数据集加载

方法一:加载网络数据

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

方法二:将数据下载到本地电脑,加载本地电脑数据
数据下载地址:https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

f = np.load("path")
x_train, y_train = f['x_train'],f['y_train']
x_test, y_test = f['x_test'],f['y_test']
f.close()

导入成功检验

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

结果
在这里插入图片描述

MNIST数据集可视化

依赖第三方库

import matplotlib.pyplot as plt

显示图片及label

image_index = 123
print(y_train[image_index])     #查看随机一张图片的label
plt.imshow(x_train[image_index], cmap='Greys')  #图片显示
plt.show()

数据集格式转换

x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=0)  #将图片从28*28扩展成32*32
x_train = x_train.astype('float32')  #数据类型转换
x_train /= 255  #数据正则化
x_train = x_train.reshape(x_train.shape[0], 32, 32, 1)  #数据维度转换
print(x_train.shape)

tendorflow

模型类

Model

  1. 实例化
    在这里插入图片描述
  2. 函数

sumary 查看模型
complie 编译时定义优化器、参数等
fit 训练模型
save 保存训练好的模型
evalute 评估模型

Sequetial

  1. 实例化:只有一个参数layers
  2. 继承自Model

卷积类Conv2D

filters, #卷积核个数
kernel_size, #卷积核大小
strides=(1,1), #步长
padding=’ ', #valid或SAME
data_format=None, #默认channels_last
activation, #激活函数

池化类AveragePooling2D

pool_size, #必须设置
stride
padding
data_format

LeNet模型

模型结构

LeNet5
在这里插入图片描述

模型构建

Model方法

class LeNet(tf.keras.Model):
    def __init__(self):
        super().__init__()
        #模型
        self.conv_layer_1 = tf.keras.layers.Conv2D(
            filters=6,
            kernel_size=(5, 5),
            padding='valid',
            activation=tf.nn.relu)

        self.pool_layer_1 = tf.keras.layers.MaxPool2D(
            pool_size=(2, 2),
            padding='same')

        self.conv_layer_2 = tf.keras.layers.Conv2D(
            filters=16,
            kernel_size=(5, 5),
            padding='valid',
            activation=tf.nn.relu)

        self.pool_layer_2 = tf.keras.layers.MaxPool2D(
            pool_size=(2, 2),
            padding='same')

        self.flatten = tf.keras.layers.Flatten()

        self.fc_layer_1 = tf.keras.layers.Dense(
            units=120,
            activation=tf.nn.relu)

        self.fc_layer_2 = tf.keras.layers.Dense(
            units=84,
            activation=tf.nn.relu)

        self.output_layer = tf.keras.layers.Dense(
            units=10,
            activation=tf.nn.relu)

    def call(self, inputs):
        x = self.conv_layer_1(inputs)
        x = self.pool_layer_1(x)
        x = self.conv_layer_2(x)
        x = self.pool_layer_2(x)
        x = self.flatten(x)
        x = self.fc_layer_1(x)
        x = self.fc_layer_2(x)
        output = self.output_layer(x)

        return output

Sequential方法

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), padding='valid', activation=tf.nn.relu, input_shape=(32, 32, 1)),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='same'),
    tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), padding='valid', activation=tf.nn.relu),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='same'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=120, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=84, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=10, activation=tf.nn.relu)
])

模型summary结果

在这里插入图片描述

模型训练与预测

模型训练

超参数

epochs -> 数据被训练的次数
batch_size ->小批梯度下降每批的大小
learning_rate -> 学习率

优化:Adam

源码

#------------------------------【训练】---------------------------------

#超参数设置
num_epochs = 10
batch_size = 64
learning_rate = 0.01

#优化器
adam_optimizer = tf.keras.optimizers.Adam(learning_rate)

#编译
model.compile(optimizer=adam_optimizer,
               loss=tf.keras.losses.sparse_categorical_crossentropy,
               metrics=['accuracy'])

#训练
start_time = datetime.datetime.now()

model.fit(x=x_train,
         y=y_train,
         batch_size=batch_size,
         epochs=num_epochs)

endtime = datetime.datetime.now()

time_cost = endtime - start_time
print('time_cost = ', time_cost)

保存/加载模型

#保存/加载模型
model.save('lenet.h5')
#model = tf.keras.models.load_model('lenet.h5')

评估

#------------------------------【评估】---------------------------------

x_test = DataFormat(x_test)
print(x_test.shape)

print(model.evaluate(x_test, y_test))

预测

#------------------------------【预测】---------------------------------

image_index = 2333
print(x_test[image_index].shape)
plt.imshow(x_test[image_index].reshape(32, 32), cmap='Greys')
plt.show()

pred = model.predict(x_test[image_index].reshape(1, 32, 32, 1))
print(pred.argmax())

网站文章

  • CCNP-第六篇-OSPF高级版(三)

    CCNP-第六篇-OSPF高级版(三)

    CCNP-第六篇-OSPF高级版(三) 这一节差不多都是密码认证了,还有个NSSA和OE1,OE2 OSPF默认路由 OSPF认证问题 OSPF特殊区域,NSSA,STUB OSPF下发默认路由 其实...

    2024-01-30 22:27:37
  • tornado mysql 操作_tornado 数据库操作

    tornado是python的web框架,web程序开发中数据库操作是必须的。安装:tornado的官方文档中提供的说明比较少,而且提供的模块中未找到数据库方面的模块,难道没有针对数据库操作进行封装?百度查询了一下才发现,tornado在升级过程中把数据库模块独立出来了。模块为torndb模块。模块git地址:https://github.com/bdarnell/torndb 官方文档地址...

    2024-01-30 22:27:29
  • 各种排序算法的时间复杂度和空间复杂度

    各种排序算法的时间复杂度和空间复杂度

    2024-01-30 22:27:20
  • 机器学习入门与Python实战(四):K近邻分类(KNN)

    机器学习入门与Python实战(四):K近邻分类(KNN)

    目录现实问题:“物以类聚,人以群分”一.KNN算法概述二.KNN算法介绍K近邻分类模型算法步骤距离计算方式KNN分类图K值选择三.KNN特点KNN算法的优势和劣势知识巩固Python实战:KNN数据分...

    2024-01-30 22:26:51
  • HTTP1.0 HTTP 1.1 HTTP 2.0主要区别

    HTTP1.0 HTTP 1.1主要区别长连接HTTP 1.0需要使用keep-alive参数来告知服务器端要建立一个长连接,而HTTP1.1默认支持长连接。HTTP是基于TCP/IP协议的,创建一个TCP连接是需要经过三次握手的,有一定的开销,如果每次通讯都要重新建立连接的话,对性能有影响。因此最好能维持一个长连接,可以用个长连接来发多个请求。节约带宽HTTP 1.1支持只发送he...

    2024-01-30 22:26:45
  • Ansible安装与配置(自动化运维管理工具) 热门推荐

    Ansible安装与配置(自动化运维管理工具) 热门推荐

    原文链接:http://blog.csdn.net/xyang81/article/details/51568227Ansible是一个简单高效的自动化运维管理工具,用Python开发,能大批量管理N...

    2024-01-30 22:26:39
  • 在Vue中封装一个select组件 热门推荐

    在Vue中封装一个select组件 热门推荐

    我们使用iview封装一个select组件 封装的是每一个select下拉框 <template> <div class='select'> <i-select :model.sync='selecteddata' :placeholder='placeholdertext' filterable multiple...

    2024-01-30 22:25:59
  • jsp ajax不返回数据,【100分】ajax在jsp页面接受不到数据解决方法

    当前位置:我的异常网» Java Web开发»【100分】ajax在jsp页面接受不到数据解决方法【100分】ajax在jsp页面接受不到数据解决方法www.myexceptions.net网友分享于:2013-03-20浏览:42次【100分】ajax在jsp页面接受不到数据大家帮我看看是哪的问题啊 ,我刚接触ajax-------jsp页面--------var xmlHttp;f...

    2024-01-30 22:25:53
  • jquery autocomplete前后台整合实例(1)

    最近在做项目时需要用到搜索自动提示,例如姓名查找模糊匹配提示。目前Jquery的自动提示插件非常多,我会例举几款,写出一些与后台交互的例子本文介绍一款Jquery autocomplete官方地址:https://github.com/devbridge/jQuery-Autocomplete下面直接来实例,不玩虚的,不参与后台交互的这里就不做介绍了,本文涉及的后台开发语言是java,

    2024-01-30 22:25:44
  • hadoop安装教程,分布式配置 CentOS7 Hadoop3.1.2

    hadoop安装教程,分布式配置 CentOS7 Hadoop3.1.2

    安装前的准备 1、 准备4台机器、或虚拟机 4台机器的名称和IP对应如下 master:192.168.199.128 slave1:192.168.199.129 slave2:192.168.199.130 slave3:192.168.199.131 2、分别为4台机器安装JDK8 步骤详细请参考:CentOS7卸载 OpenJDK 安装Sun的JDK...

    2024-01-30 22:25:11