当前位置: 首页 > news >正文

深度学习基础之BatchNorm和LayerNorm

文章目录

  • BatchNorm
  • LayerNorm
  • 总结
  • 参考

BatchNorm

Batch Normalization(下文简称 Batch Norm)是 2015 年提出的方法。Batch Norm虽然是一个问世不久的新方法,但已经被很多研究人员和技术人员广泛使用。实际上,看一下机器学习竞赛的结果,就会发现很多通过使用这个方法而获得优异结果的例子。
Batch Norm有以下优点。
(1) 可以使学习快速进行(可以增大学习率)。
(2)不那么依赖初始值(对于初始值不用那么神经质)。
(3)抑制过拟合(降低Dropout等的必要性)

Batch Norm,顾名思义,以进行学习时的mini-batch为单位,按mini-batch进行正规化。具体而言,就是进行使数据分布的均值为0、方差为1的正规化。用数学式表示的话,如下所示。

在这里插入图片描述
看公式,是不是有点像经典机器学习里为了消除量纲的标准化的操作。
在这里插入图片描述
这就是Batch Normalization的算法了。

简单来说,其实就是对一个batch的数据进行标准化操作。

我们可以使用pytorch为我们写好的方法直接调用验证一下:

import torch.nn as nn
import torch as th
data = [[[1,2,5],[2,5,8.5],[3,3,3]],
        [[2,8,4],[1,3,9],[2,6,4]],
        [[1,1,1],[1,3,5],[0.5,6,0.2]]]
data = th.tensor(data)
data_bn = nn.BatchNorm1d(3)(data)
data_ln = nn.LayerNorm(3)(data)
mean = th.sum(data_bn)
mu = th.sum(th.pow(data_bn-mean, 2) / 27)
print(data_bn)
print(mean)
print(mu)

在这里插入图片描述
众所周知,浮点数运算会飘,所以2.3842e-07就相当于是0了
方差差计算出来是1
正好符合计算的结果。

所以batch norm是对一个batch的所有数据一起进行标准化操作。
在这里插入图片描述
这是使用手写数据集进行的测试实验,发现初始化参数不同时,对学习效果的影响是很大的,但是使用了batch norm之后,受到的影响就比较小了。

batch norm主要用于CV领域

LayerNorm

layer norm也是一种标准化的方法,公式也差不多,不过是对每个batch(3维)里的每个样本的每行进行标准化,主要是用于NLP领域的。

话不多说,上代码:

import torch.nn as nn
import torch as th
data = [[[1,2,5],[2,5,8.5],[3,3,3]],
        [[2,8,4],[1,3,9],[2,6,4]],
        [[1,1,1],[1,3,5],[0.5,6,0.2]]]
data = th.tensor(data)
data_ln = nn.LayerNorm(3)(data)
print(data_ln)
for b in data_ln:
        for line in b:
                mean = th.sum(line)
                mu = th.sum(th.pow(line-mean, 2))
                print(mean, mu / line.shape[0])

输出:
在这里插入图片描述
所有,使用layer norm 对应到NLP里就是相当于对每个词向量各自进行标准化。

总结

batch norm适用于CV,因为计算机视觉喂入的数据都是像素点,可以说数据点与点之间是可以比较的,所以使用batch norm可以有比较好的效果,而NLP里,每个词的词向量是一组向量表示一个词,一个词向量割裂开来看是没有意义的,因此不同词向量里的数据点是不能混为一谈的,所以batch norm之后可能会使得词损失语义,效果就可能不好了,但是使用layer norm只是让各个词向量进行标准化,就能够有比较理想的效果了。

参考

深度学习老师的课件

相关文章:

  • 大数据开发(离线实时音乐数仓)
  • 美易官方:盘前道指期货涨0.5%,游戏驿站跌逾15%
  • 知识图谱操作的探索与利用
  • vue实例的data属性,可以在哪些生命周期中获取到
  • 数据库-索引快速学
  • 38. 单调递增的数字(力扣LeetCode)
  • css常用的选择器介绍
  • Python爬虫实战入门:爬取360模拟翻译(仅实验)
  • KaiwuDB 拿下“物联之星”双项殊荣
  • WPF真入门教程29--MVVM常用框架之MvvmLight
  • 数据隐私安全趋势
  • Alist访问主页显示空白解决方法
  • 【Spring】面向切面编程详解(AOP)
  • 力扣(LeetCode)75. 颜色分类(C语言)
  • LeetCode算法题整理(200题左右)
  • flowable-ui绘图常见错误
  • 前端最新基础知识
  • 【2022年玄武云科技AI算法岗秋招面试记录】
  • ROBOGUIDE软件:FANUC机器人电弧跟踪功能介绍与示教编程操作
  • 5 个 TypeScript 库来改进你的代码
  • 高校教学管理信息系统/教学管理系统
  • 网络规划设计师上午真题(2020)
  • Java基础知识-char
  • 小程序是什么?
  • Springboot 对接云端服务器
  • PyTorch之数据集随机值
  • #边学边记 必修4 高项:对事的管理 第六章 项目质量管理之质量控制
  • docker的网络模式
  • 集中供暖热计量温控一体化管理系统
  • mybatis案例--mapper代理开发
  • 【程序运行时的两种环境】
  • vue的相关概念