jjzjj

对于训练时loss出现负值的情况

翰墨大人 2025-01-03 原文


在训练时候loss出现负值,就立马停下来分析一下原因在哪。最有可能是损失函数出现问题,开始只使用交叉熵损失时没有出现过,在加上了dice loss时就出现了问题。于是就去dice loss中寻找原因。
1:首先需要明白语义分割的GT,每一个像素点的值就是像素的类别。

# -*- coding: utf-8 -*-
import numpy as np
from torchvision import transforms
import torch
from PIL import Image
img = Image.open('C:/Users/翰墨大人/Desktop/0003_lable.png') #图像所在位置
img1 = np.array(img)
img1 = torch.from_numpy(img1).type(torch.FloatTensor)
# trans = transforms.ToTensor()
# img1 = trans(img1)

a = torch.unique(img1) # 查看图片内的像素值
print(a)
print(img.mode) # 查看图片模式

打印结果:

tensor([ 0.,  1.,  5.,  7.,  8., 26., 29., 38., 40.])
P

原图一共有四十个类别,在003_lable这张GT上只出现了上述的类别。剩下的像素点和上述像素点是重复的。
注意:
这里将np格式转换为tensor格式时候,不能使用transforms.ToTensor(),他会将像素值发生改变。这样新的像素值点和类别就不是一一对应的。

tensor([0.0000, 0.0039, 0.0196, 0.0275, 0.0314, 0.1020, 0.1137, 0.1490, 0.1569])
P

2:语义分割的GT一共有四十类,没有通道,而在模型中pred的输出为一个通道。
在计算二分类dice loss时候,首先要将pred进行sigmoid。GT是四十个类别,如果是二分类的话那么标签就必须是[0,1]。而loss为负值的原因就是没有将标签转换为[0,1]。
X是pred,Y是GT,分子就是X和Y进行矩阵的相乘再相加,当Y中含有大于1的类别,比如30,40的话,而X又是进过sigmoid之后再(0,1)之内,那么分子除以分母的值就会大于1,造成dice loss就变成了负值。

如何将GT变为[0,1]呢?因为我需要的是对GT进行提边,使用一个拉普拉斯对GT进行卷积,然后再使用一个阈值,大于阈值为1,小于为0。为是边缘和不是边缘。GT就又灰度图像转换为了二值图像。

l = F.conv2d(x,sobel,padding=1,stride=1)
print(l.shape)
ll = torch.unique(l) # 查看图片内的像素值
print(ll)

l[l>0.1]=1
l[l<0.1]=0
l_ = torch.unique(l) # 查看图片内的像素值
print(l_)

结果:

对于多分类的dice loss:
GT需要进行one-hot编码,这样一个通道,像素点为(0-40)的GT就会变为四十个通道,每个通道像素点为(0,1),每个通道都可以看做一个二分类问题,属于该类别和不属于该类别。而pred的输出通道也为40,通过计算pred的每个通道和GT的每个通道的loss,最后求均值得到总loss。
3:代码代码参考
单分类代码:

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Model):
	def __init__(self):
		super(BinaryDiceLoss, self).__init__()
	
	def forward(self, input, targets):
		# 获取每个批次的大小 N
		N = targets.size()[0]
		# 平滑变量
		smooth = 1
		# 将宽高 reshape 到同一纬度
		input_flat = input.view(N, -1)
		targets_flat = targets.view(N, -1)
	
		# 计算交集
		intersection = input_flat * targets_flat 
		N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
		# 计算一个批次中平均每张图的损失
		loss = 1 - dice_eff.sum() / N
		return loss

多分类代码:

import torch
import torch.nn as nn

class MultiClassDiceLoss(nn.Module):
	def __init__(self, weight=None, ignore_index=None, **kwargs):
		super(MultiClassDiceLoss, self).__init__()
		self.weight = weight
		self.ignore_index = ignore_index
		self.kwargs = kwargs
	
	def forward(self, input, target):
		"""
			input tesor of shape = (N, C, H, W)
			target tensor of shape = (N, H, W)
		"""
		# 先将 target 进行 one-hot 处理,转换为 (N, C, H, W)
		nclass = input.shape[1]
		target = one_hot(target.long(), nclass)

		assert input.shape == target.shape, "predict & target shape do not match"
		
		binaryDiceLoss = BinaryDiceLoss()
		total_loss = 0
		
		# 归一化输出
		logits = F.softmax(input, dim=1)
		C = target.shape[1]
		
		# 遍历 channel,得到每个类别的二分类 DiceLoss
		for i in range(C):
			dice_loss = binaryDiceLoss(logits[:, i], target[:, i])
			total_loss += dice_loss
		
		# 每个类别的平均 dice_loss
		return total_loss / C

有关对于训练时loss出现负值的情况的更多相关文章

  1. ruby - 默认情况下使选项为 false - 2

    这是在Ruby中设置默认值的常用方法:classQuietByDefaultdefinitialize(opts={})@verbose=opts[:verbose]endend这是一个容易落入的陷阱:classVerboseNoMatterWhatdefinitialize(opts={})@verbose=opts[:verbose]||trueendend正确的做法是:classVerboseByDefaultdefinitialize(opts={})@verbose=opts.include?(:verbose)?opts[:verbose]:trueendend编写Verb

  2. ruby - 在没有 sass 引擎的情况下使用 sass 颜色函数 - 2

    我想在一个没有Sass引擎的类中使用Sass颜色函数。我已经在项目中使用了sassgem,所以我认为搭载会像以下一样简单:classRectangleincludeSass::Script::FunctionsdefcolorSass::Script::Color.new([0x82,0x39,0x06])enddefrender#hamlengineexecutedwithcontextofself#sothatwithintemlateicouldcall#%stop{offset:'0%',stop:{color:lighten(color)}}endend更新:参见上面的#re

  3. ruby - 在不使用 RVM 的情况下在 Mac 上卸载和升级 Ruby - 2

    我最近决定从我的系统中卸载RVM。在thispage提出的一些论点说服我:实际上,我的决定是,我根本不想担心Ruby的多个版本。我只想使用1.9.2-p290版本而不用担心其他任何事情。但是,当我在我的Mac上运行ruby--version时,它告诉我我的版本是1.8.7。我四处寻找如何简单地从我的Mac上卸载这个Ruby,但奇怪的是我没有找到任何东西。似乎唯一想卸载Ruby的人运行linux,而使用Mac的每个人都推荐RVM。如何从我的Mac上卸载Ruby1.8.7?我想升级到1.9.2-p290版本,并且我希望我的系统上只有一个版本。 最佳答案

  4. ruby - 使用 rbenv 和 ruby​​-build 构建 ruby​​ 失败,出现 undefined symbol : SSLv2_method - 2

    我正在尝试在配备ARMv7处理器的SynologyDS215j上安装ruby​​2.2.4或2.3.0。我用了optware-ng安装gcc、make、openssl、openssl-dev和zlib。我根据README中的说明安装了rbenv(版本1.0.0-19-g29b4da7)和ruby​​-build插件。.这些是随optware-ng安装的软件包及其版本binutils-2.25.1-1gcc-5.3.0-6gconv-modules-2.21-3glibc-opt-2.21-4libc-dev-2.21-1libgmp-6.0.0a-1libmpc-1.0.2-1libm

  5. ruby - 在什么情况下会使用 Sinatra 或 Merb? - 2

    我正在学习Rails,对Sinatra和Merb知之甚少。我想知道您会在哪些情况下使用Merb/Sinatra。感谢您的反馈! 最佳答案 Sinatra是一个比Rails更小、更轻的框架。如果你想让一些东西快速运行,只需发送几个URL并返回一些简单的内容,就可以使用它。看看Sinatrahomepage;这就是启动和运行“Hello,World”所需的全部内容,而在Rails中,您需要生成整个项目结构、设置Controller和View、设置路由等等(我还没有有一段时间写了一个Rails应用程序,所以我不知道“Hello,World

  6. ruby - 是否可以在不实际发送或读取数据的情况下查明 ruby​​ 套接字是否处于 ESTABLISHED 或 CLOSE_WAIT 状态? - 2

    s=Socket.new(Socket::AF_INET,Socket::SOCK_STREAM,0)s.connect(Socket.pack_sockaddr_in('port','hostname'))ssl=OpenSSL::SSL::SSLSocket.new(s,sslcert)ssl.connect从这里开始,如果ssl连接和底层套接字仍然是ESTABLISHED,或者它是否在默认值7200之后进入CLOSE_WAIT,我想检查一个线程几秒钟甚至更糟的是在实际上不需要.write()或.read()的情况下关闭。是用select()、IO.select()还是其他方法完成

  7. ruby-on-rails - 在这种情况下我如何模拟一个对象?没有明显的方法可以用模拟替换对象 - 2

    假设我在Store的模型中有这个非常简单的方法:defgeocode_addressloc=Store.geocode(address)self.lat=loc.latself.lng=loc.lngend如果我想编写一些不受地理编码服务影响的测试脚本,这些脚本可能已关闭、有限制或取决于我的互联网连接,我该如何模拟地理编码服务?如果我可以将地理编码对象传递到该方法中,那将很容易,但我不知道在这种情况下该怎么做。谢谢!特里斯坦 最佳答案 使用内置模拟和stub的rspecs,你可以做这样的事情:setupdo@subject=MyCl

  8. ruby - 在没有基准或时间的情况下用 Ruby 测量用户时间或系统时间 - 2

    因为我现在正在做一些时间测量,我想知道是否可以在不使用Benchmark类或命令行实用程序time的情况下测量用户时间或系统时间。使用Time类只显示挂钟时间,而不显示系统和用户时间,但是我正在寻找具有相同灵active的解决方案,例如time=TimeUtility.now#somecodeuser,system,real=TimeUtility.now-time原因是我有点不喜欢Benchmark,因为它不能只返回数字(编辑:我错了-它可以。请参阅下面的答案。)。当然,我可以解析输出,但感觉不对。*NIX系统的time实用程序也应该可以解决我的问题,但我想知道是否已经在Ruby中实

  9. ruby - 如何更优雅地记下这三种情况? - 2

    是否可以让这段代码更紧凑?我在这里错过了什么吗?ifvaluemax_ratemax_rateelsevalueend 最佳答案 这里有一些完全不同的东西:[min_rate,value,max_rate].sort[1] 关于ruby-如何更优雅地记下这三种情况?,我们在StackOverflow上找到一个类似的问题: https://stackoverflow.com/questions/13309740/

  10. ruby - 在没有 root 的情况下安装 Jekyll - 2

    我想在共享服务器上建立一个jekyll博客。当我尝试安装Jekyll时,我得到“您没有写权限”。我该如何在没有root或sudo的情况下解决这个问题?更多细节:我在共享服务器上有空间,但没有根访问权限。我无法安装Ruby,尽管托管公司应我的要求安装了它。当我尝试安装Jekyll时我使用user@hosting.org[~]#geminstalljekyll这是我得到的回应:ERROR:Whileexecutinggem...(Gem::FilePermissionError)Youdon'thavewritepermissionsintothe/usr/lib/ruby/gems/1.

随机推荐