jjzjj

VAE损失函数的推导及实现

小憨憨。 2023-12-17 原文

VAE损失函数的推导

VAE最原始的优化目标

我们从解码器的角度来引出VAE的优化目标,即传入一个变量z,我们期待解码器能生成我们所期望生成的数据。

我们举个简单的例子来说明一下:假设在我们当前的任务下解码器的目标是根据输入的z来生成一张手写数字图片。当我们传入z之后,解码器的输出可能是各种各样的,但我们希望解码器能生成手写数字图片,而不是生成一个汉字或者是其他奇奇怪怪的符号,而这就是VAE的最原始的优化目标。

我们使用p代表解码器,p(x|z)代表给定z时解码器产生x的概率,其中x并非一个具体的值,而可以看作是一类数据,比如在我们上述的例子中,x可以代表某种风格的手写体数字,p(x|z)就是生成这些数字的概率,这里的概率也并非一个具体的值,而是某一风格的每个数字对应了一个概率,其输出的是一个概率分布。

当我们明白了这些时,我们就可以写出来VAE的优化目标,即最大化解码器输出x的概率,即最大化p(x)。

损失函数推导前的准备

我们可以将p(x)其改写为包含了传入参数的形式,即

当我们将z从离散分布变为连续分布时,该式就变成了

这里的p(z)可以是任意分布,在VAE中我们常常假设p(z)服从标准正态分布。

我们同时也需要知道KL散度的一些相关知识:KL散度用于衡量两个分布之间的差异,其值越大则两个分布的差异越大,同时两个分布的KL散度非负。计算a、b两个分布的KL散度的公式如下

损失函数的推导其一

为了最大化p(x),我们可以采用极大似然估计的方法来进行,即最大化
对应于我们之前给的例子,这里的每个x可以代表了某一个风格的手写体,我们的目标是生成手写体数字,因此我们并不会局限其风格,只要生成的正确就要最大化其概率。

由于最大化L即相当于最大化log p(x),因此后续目标调整为最大化log p(x)。我们假设q代表了编码器,q(z|x)就代表了给定x时编码器产生z的概率。由于

即不管给定何种x,其产生不同z的概率之和恒为1。又因为p(x)与z无关,因此我们可以将log p(x)改写为如下的形式。

由于p(x) = p(x, z) / p(z|x) = (p(x, z) / q(z|x)) * (q(z|x) / p(z|x))

其中第一次变化使用了概率论的定理,第二次变化仅仅加入了一个中间项,可以直接约分掉,并不影响结果。

此时我们可以将log p(x)写为如下形式。

我们将log里的乘积拆开,变为两项之和,即

结合之前提到过的KL散度相关的知识,我们可以看出第二项其实就是KL(q(z|x) || p(z|x))。因为该值为非负项,所以log p(x)不可能小于第一项,我们使用Lb来指代第一项,从而便于书写。

结合我们在准备阶段所提到的

我们可以知道,当p(x|z)不变时,p(x)也不变,从而log p(x)也不变,那么Lb+KL(q(z|x) || p(z|x))的值就不会变。这时如果我们利用q(z|x)来最大化Lb,那么Lb就会增大,而KL(q(z|x) || p(z|x))的值就会减小。

那么如果q(z|x)不变呢?此时当我们增大p(x|z)时,Lb会增大且p(x)会增大,即log p(x)也会增大。

由此我们可以得出结论,只要我们最大化Lb就能使log p(x)最大化。

损失函数的推导其二

此时我们的目标变为了最大化Lb。
由于p(x,z)=p(z)*p(x|z),我们将Lb中的p(x,z)替换为p(z)*p(x|z),并将其从log里的拆开,可以得到如下结果

我们可以看出Lb的第一项为-KL(q(z|x) || p(z)),即q(z|x)与p(z)两个分布之间的Kl散度的相反数。Lb的第二项可以看作是在q(z|x)这个分布下log p(x|z)的期望,即
此时VAE的最终目标就一目了然了,VAE的训练目标有两个:
第一,最小化KL(q(z|x) || p(z)),使q(z|x)的分布尽量向p(z)靠近。
第二,最大化在q(z|x)这个分布下log p(x|z)的期望,其中q(z|x)为编码器输入x时产生z的概率。假设解码器利用z生成出了x’,我们就需要使x’尽可能向x靠近,以最大化log p(x|z)。

实际使用时所用到的损失函数

根据上述的两个训练目标,VAE的损失函数也被设计为两个:

  1. L1用于最小化KL(q(z|x) || p(z)),VAE假设q(z|x)的分布为正态分布,而p(z)为标准正态分布。计算两个正态分布之间的KL散度的公式如下:

    由于此处p(z)为标准正态分布,因此其μ为0,σ为1,那么我们带入后可得

    其中σ为q(z|x)的标准差,μ为q(z|x)的均值。

实际实现时,当编码器接收到x时,我们并不会让编码器直接输出对应的z,而是会使编码器输出z的分布的均值和标准差,此时我们就可以使用上述的式子作为损失函数,从而更新编码器参数。

此时我们得到了第一个损失函数。

在训练解码器时,我们会从标准正态分布中随机取样,使其乘上上述得到的方差,之后使其加上上述的均值,以此来构建解码器的输入,这样做相当于是给输入加上了噪音,使得解码器的稳定性更好。
2. L2使解码器输出的x’尽可能向x靠近,要做到这个,我们只需要最小化x’和x之间的均方误差即可,即

损失函数的代码实现

def loss_function(recon, x, mu, std) -> torch.Tensor:
    """
    :param recon: output of the decoder
    :param x: encoder input
    :param mu: mean
    :param std: standard deviation
    :return:
    """
    recon_loss = torch.nn.functional.mse_loss(recon, x, reduction="sum")
    kl_loss = -0.5 * (1 + 2 * torch.log(std) - mu.pow(2) - std.pow(2))
    kl_loss = torch.sum(kl_loss)
    loss = recon_loss + kl_loss
    return loss

有关VAE损失函数的推导及实现的更多相关文章

  1. ruby - 在没有 sass 引擎的情况下使用 sass 颜色函数 - 2

    我想在一个没有Sass引擎的类中使用Sass颜色函数。我已经在项目中使用了sassgem,所以我认为搭载会像以下一样简单:classRectangleincludeSass::Script::FunctionsdefcolorSass::Script::Color.new([0x82,0x39,0x06])enddefrender#hamlengineexecutedwithcontextofself#sothatwithintemlateicouldcall#%stop{offset:'0%',stop:{color:lighten(color)}}endend更新:参见上面的#re

  2. ruby-on-rails - 在 ruby​​ 中使用 gsub 函数替换单词 - 2

    我正在尝试用ruby​​中的gsub函数替换字符串中的某些单词,但有时效果很好,在某些情况下会出现此错误?这种格式有什么问题吗NoMethodError(undefinedmethod`gsub!'fornil:NilClass):模型.rbclassTest"replacethisID1",WAY=>"replacethisID2andID3",DELTA=>"replacethisID4"}end另一个模型.rbclassCheck 最佳答案 啊,我找到了!gsub!是一个非常奇怪的方法。首先,它替换了字符串,所以它实际上修改了

  3. ruby - 如何根据特征实现 FactoryGirl 的条件行为 - 2

    我有一个用户工厂。我希望默认情况下确认用户。但是鉴于unconfirmed特征,我不希望它们被确认。虽然我有一个基于实现细节而不是抽象的工作实现,但我想知道如何正确地做到这一点。factory:userdoafter(:create)do|user,evaluator|#unwantedimplementationdetailshereunlessFactoryGirl.factories[:user].defined_traits.map(&:name).include?(:unconfirmed)user.confirm!endendtrait:unconfirmeddoenden

  4. ruby - 在 Ruby 中有条件地定义函数 - 2

    我有一些代码在几个不同的位置之一运行:作为具有调试输出的命令行工具,作为不接受任何输出的更大程序的一部分,以及在Rails环境中。有时我需要根据代码的位置对代码进行细微的更改,我意识到以下样式似乎可行:print"Testingnestedfunctionsdefined\n"CLI=trueifCLIdeftest_printprint"CommandLineVersion\n"endelsedeftest_printprint"ReleaseVersion\n"endendtest_print()这导致:TestingnestedfunctionsdefinedCommandLin

  5. ruby - 在 Ruby 中按名称传递函数 - 2

    如何在Ruby中按名称传递函数?(我使用Ruby才几个小时,所以我还在想办法。)nums=[1,2,3,4]#Thisworks,butismoreverbosethanI'dlikenums.eachdo|i|putsiend#InJS,Icouldjustdosomethinglike:#nums.forEach(console.log)#InF#,itwouldbesomethinglike:#List.iternums(printf"%A")#InRuby,IwishIcoulddosomethinglike:nums.eachputs在Ruby中能不能做到类似的简洁?我可以只

  6. 华为OD机试用Python实现 -【明明的随机数】 2023Q1A - 2

    华为OD机试题本篇题目:明明的随机数题目输入描述输出描述:示例1输入输出说明代码编写思路最近更新的博客华为od2023|什么是华为od,od薪资待遇,od机试题清单华为OD机试真题大全,用Python解华为机试题|机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南华为o

  7. 基于C#实现简易绘图工具【100010177】 - 2

    C#实现简易绘图工具一.引言实验目的:通过制作窗体应用程序(C#画图软件),熟悉基本的窗体设计过程以及控件设计,事件处理等,熟悉使用C#的winform窗体进行绘图的基本步骤,对于面向对象编程有更加深刻的体会.Tutorial任务设计一个具有基本功能的画图软件**·包括简单的新建文件,保存,重新绘图等功能**·实现一些基本图形的绘制,包括铅笔和基本形状等,学习橡皮工具的创建**·设计一个合理舒适的UI界面**注明:你可能需要先了解一些关于winform窗体应用程序绘图的基本知识,以及关于GDI+类和结构的知识二.实验环境Windows系统下的visualstudio2017C#窗体应用程序三.

  8. C51单片机——实现用独立按键控制LED亮灭(调用函数篇) - 2

    说在前面这部分我本来是合为一篇来写的,因为目的是一样的,都是通过独立按键来控制LED闪灭本质上是起到开关的作用,即调用函数和中断函数。但是写一篇太累了,我还是决定分为两篇写,这篇是调用函数篇。在本篇中你主要看到这些东西!!!1.调用函数的方法(主要讲语法和格式)2.独立按键如何控制LED亮灭3.程序中的一些细节(软件消抖等)1.调用函数的方法思路还是比较清晰地,就是通过按下按键来控制LED闪灭,即每按下一次,LED取反一次。重要的是,把按键与LED联系在一起。我打算用K1来作为开关,看了一下开发板原理图,K1连接的是单片机的P31口,当按下K1时,P31是与GND相连的,也就是说,当我按下去时

  9. MIMO-OFDM无线通信技术及MATLAB实现(1)无线信道:传播和衰落 - 2

     MIMO技术的优缺点优点通过下面三个增益来总体概括:阵列增益。阵列增益是指由于接收机通过对接收信号的相干合并而活得的平均SNR的提高。在发射机不知道信道信息的情况下,MIMO系统可以获得的阵列增益与接收天线数成正比复用增益。在采用空间复用方案的MIMO系统中,可以获得复用增益,即信道容量成倍增加。信道容量的增加与min(Nt,Nr)成正比分集增益。在采用空间分集方案的MIMO系统中,可以获得分集增益,即可靠性性能的改善。分集增益用独立衰落支路数来描述,即分集指数。在使用了空时编码的MIMO系统中,由于接收天线或发射天线之间的间距较远,可认为它们各自的大尺度衰落是相互独立的,因此分布式MIMO

  10. 【Java入门】使用Java实现文件夹的遍历 - 2

    遍历文件夹我们通常是使用递归进行操作,这种方式比较简单,也比较容易理解。本文为大家介绍另一种不使用递归的方式,由于没有使用递归,只用到了循环和集合,所以效率更高一些!一、使用递归遍历文件夹整体思路1、使用File封装初始目录,2、打印这个目录3、获取这个目录下所有的子文件和子目录的数组。4、遍历这个数组,取出每个File对象4-1、如果File是否是一个文件,打印4-2、否则就是一个目录,递归调用代码实现publicclassSearchFile{publicstaticvoidmain(String[]args){//初始目录Filedir=newFile("d:/Dev");Datebeg

随机推荐