您现在的位置是:首页 > 正文

《TensorFlow深度学习》(七)——Keras高层接口

2024-01-30 22:54:12阅读 0

常见功能模块

常见网络层类

tf.keras.layer提供了大量常见网络的类,使用call方法即可完成前向计算

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

x = tf.constant([2.,1.,0.1])
layer = layers.Softmax(axis = -1)
out = layer(x)

网络容器

对于常见的网络,需要手动调用每一层的类实例完成前向传播运算,这部分代码显得臃肿,可以用Keras提供的Sequential将多个网络层封装为一个大网络模型,只需要调用网络模型的实例依次即可完成数据从第一层到最末层的顺序传播运算。

#导入Sequential容器
#两层的全连接层加上单独的激活层函数
from tensorflow.keras import layers,Sequential
network = Sequential([
	layers.Dense(3,activation=None)
	layers.ReLu()
	layers.Dense(2,activation=None)
	layers.ReLu()
])
x = tf.random.normal([4,3])
out=network(x)

Sequential容器也可以通过add方法追加新的网络层,通过summary可以方便的打印出网络结构和参数量

模型装配、训练与测试

在训练网络时,一般的流程是通过前向计算获得网络的输出值,再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。对于这种常用的训练逻辑,可以直接通过Keras 提供的模型装配与训练等高层接口实现,简洁清晰。

# 创建5 层的全连接网络
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(4, 28*28))
network.summary()

首先通过compile 函数指定网络使用的优化器对象、损失函数类型,评价指标等设定,这一步称为装配。例如

# 导入优化器,损失函数模块
from tensorflow.keras import optimizers,losses
# 模型装配
# 采用Adam 优化器,学习率为0.01;采用交叉熵损失函数,包含Softmax
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'] # 设置测量指标为准确率
)

模型装配完成后,即可通过fit()函数送入待训练的数据集和验证用的数据集,这一步称为模型训练

# 指定训练集为train_db,验证集为val_db,训练5 个epochs,每2 个epoch 验证一次
# 返回训练轨迹信息保存在history 对象中
history = network.fit(train_db, epochs=5, validation_data=val_db,
validation_freq=2)

其中history.history 为字典对象,包含了训练过程中的loss、测量指标等记录项,我们可以直接查看这些训练数据。

通过 Model.predict(x)方法即可完成模型的预测,例如:

# 加载一个batch 的测试数据
x,y = next(iter(db_test))
print('predict x:', x.shape) # 打印当前batch 的形状
out = network.predict(x) # 模型预测,预测结果保存在out 中
print(out)

模型保存

在 Keras 中,有三种常用的模型保存与加载方法。

  1. 张量方式
# 保存模型参数到文件上
network.save_weights('weights.ckpt')
print('saved weights.')
del network # 删除网络对象
# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
print('loaded weights!')
  1. 网络方式
    通过Model.save(path)函数可以将模型的结构以及模型的参数保存到path 文件上,在不需要网络源文件的条件下,通过keras.models.load_model(path)即可恢复网络结构和网络参数。
# 保存模型结构与模型参数到文件
network.save('model.h5')
print('saved total model.')
del network # 删除网络对象
# 从文件恢复网络结构与网络参数
network = keras.models.load_model('model.h5')
  1. SavedModel 方式
    当需要将模型部署到其他平台时,采用TensorFlow 提出的SavedModel 方式更具有平台无关性
# 保存模型结构与模型参数到文件
tf.saved_model.save(network, 'model-savedmodel')
print('saving savedmodel.')
del network # 删除网络对象
print('load savedmodel from file.')
# 从文件恢复网络结构与网络参数
network = tf.saved_model.load('model-savedmodel')
# 准确率计量器
acc_meter = metrics.CategoricalAccuracy()
for x,y in ds_val: # 遍历测试集
	pred = network(x) # 前向计算
	acc_meter.update_state(y_true=y, y_pred=pred) # 更新准确率统计
# 打印准确率
print("Test Accuracy:%f" % acc_meter.result())

自定义网络

假设需要一个没有偏置向量的全连接层,即bias 为0,同时固定激活函数为ReLU 函数

class MyDense(layers.Layer):
# 自定义网络层
def __init__(self, inp_dim, outp_dim):
	super(MyDense, self).__init__()
	# 创建权值张量并添加到类管理列表中,设置为需要优化
	self.kernel = self.add_variable('w', [inp_dim, outp_dim],
	trainable=True)
def call(self, inputs, training=None):
	# 实现自定义类的前向计算逻辑
	# X@W
	out = inputs @ self.kernel
	# 执行激活函数运算
	out = tf.nn.relu(out)
	return out

模型乐园

对于常用的网络模型,如ResNet、VGG 等,不需要手动创建网络,可以直接从keras.applications 子模块中通过一行代码即可创建并使用这些经典模型,同时还可以通过设置weights 参数加载预训练的网络参数,非常方便

测量工具

Keras 的测量工具的使用方法一般有4 个主要步骤:新建测量器,写入数据,读取统计数据和清零测量器。keras.metrics模块:
# 新建平均测量器,适合Loss 数据 loss_meter = metrics.Mean()

# 在每个Step 结束时采集一次loss 值
# 记录采样的数据,通过float()函数将张量转换为普通数值
loss_meter.update_state(float(loss))
# 打印统计期间的平均loss
print(step, 'loss:', loss_meter.result())
if step % 100 == 0:
# 打印统计的平均loss
	print(step, 'loss:', loss_meter.result())
	loss_meter.reset_states() # 打印完后,清零测量器

可视化

# 安装TensorBoard
pip install tensorboard

在模型端,需要创建写入监控数据的Summary 类,并在需要的时候写入监控数据。首先通过tf.summary.create_file_writer 创建监控对象类实例,并指定监控数据的写入目录,代码如下:

# 创建监控类,监控数据将写入log_dir 目录
summary_writer = tf.summary.create_file_writer(log_dir)
with summary_writer.as_default(): # 写入环境
	# 当前时间戳step 上的数据为loss,写入到名为train-loss 数据库中
	tf.summary.scalar('train-loss', float(loss), step=step)
# 对于图片类型的数据,可以通过tf.summary.image 函数写入监控图片数据
with summary_writer.as_default():# 写入环境
	# 写入测试准确率
	tf.summary.scalar('test-acc', float(total_correct/total),
	step=step)
	# 可视化测试用的图片,设置最多可视化9 张图片
	tf.summary.image("val-onebyone-images:", val_images,
	max_outputs=9, step=step)

在运行程序时,监控数据被写入到指定文件目录中。如果要实时远程查看、可视化这些数据,还需要借助于浏览器和Web 后端。

通过在cmd 终端运行tensorboard --logdir path 指定Web 后端监控的文件目录path,此时打开浏览器,并输入网址http://localhost:6006 (也可以通过IP 地址远程访问,具体端口号可能会变动,可查看命令提示) 即可监控网络训练进度

在这里插入图片描述

网站文章

  • 详解基于MATLAB的车牌识别系统设计与实现(1):车牌定位

    详解基于MATLAB的车牌识别系统设计与实现(1):车牌定位

    车牌识别系统主要包括车牌定位、字符分割和字符识别三个核心模块。车牌定位是利用车牌的颜色和形状特征确认并获取汽车的车牌位置;字符分割是将获取到的车牌切割成单个字符;字符识别目前主要有基于模板匹配算法和基于人工神经网络算法对切割的字符进行识别。本节内容主要讲解车牌定位,主要内容有:读取图像预处理边缘检测形态学操作定位裁剪主函数代码如下 :// main.mclose all;cle...

    2024-01-30 22:53:41
  • ACPI电源管理的6个状态(S0-S5) 热门推荐

    ACPI 电源管理的6个状态:S0: 主机正常工作状态S1: CPU停止工作,wakeup时间为:0sS2: CPU关闭,wakeup时间为:0.1sS3: 主机中除了内存以外其他所有部件都停止工 作,wakeup时间为:0.5sS4: 主机将内存中的信息写入到硬盘中,之后关闭所有组件,wakeup时间为:30sS5: 关机

    2024-01-30 22:53:33
  • Servlet Cookie取不到值原因

    现象: 在测试带Cookie的HTTP请求时发现,服务端用request.getHeader("cookie")可以去到值;但是用request.getCookies()却不行Co...

    2024-01-30 22:53:16
  • Java中的类与变量

    一、静态 1、类的静态成员与类直接相关,与对象无关,在一个类的所有实例之间共享同一个静态成员,该类的对象共享其静态成员变量的值 2、静态成员函数中不能调用非静态成员,静态成员变量可被该类的所有方法访问...

    2024-01-30 22:52:48
  • 网络安全,别报!!!

    网络安全,别报!!!

    网络安全究竟要不要报?7年网工如是说

    2024-01-30 22:52:38
  • 思考R-CNN的一些问题,如何提取特征,分类,训练,测试

    思考R-CNN的一些问题,如何提取特征,分类,训练,测试

    1.R-CNN是什么 论文链接 把region proposal和CNNs结合起来,所以该方法被称为R-CNN:Regions with CNN features,整个检测系统有三个模块构成。 第一个...

    2024-01-30 22:52:31
  • React自定义Hooks——useLocalStorage

    封装任何的函数,都是将一些重复的逻辑单独封装到一个函数。而不是为了封装而封装。 定义useLocalStorage,是因为localStorage,在不同的组件中获取,更新。每次更新或者获取都是loc...

    2024-01-30 22:52:25
  • 数据通信与计算机网络 TASK0

    数据通信与计算机网络 TASK0

    数据通信与计算机网络 TASK0

    2024-01-30 22:51:55
  • goland的激活码 热门推荐

    goland的激活码 http://idea.youbbs.org

    2024-01-30 22:51:49
  • Numpy基础一:ndarray

    Numpy基础一:ndarray

    1. ndarray对象numpy中的内置多维数组类型ndarray由一系列同类型数据所构成,并以0作为下标开始索引。ndarray中的每个元素在内存中都有相同的存储大小区域。ndarray由以下四部...

    2024-01-30 22:51:43