jjzjj

python - 如何计算 Tensorflow 中的所有二阶导数(仅 Hessian 矩阵的对角线)?

coder 2023-08-21 原文

我有一个损失值/函数,我想计算关于张量f(大小为n)的所有二阶导数。我设法使用了 tf.gradients 两次,但在第二次应用它时,它对第一个输入的导数求和(请参阅我的代码中的 second_derivatives)。

我还设法检索了 Hessian 矩阵,但我只想计算它的对角线以避免额外计算。

import tensorflow as tf
import numpy as np

f = tf.Variable(np.array([[1., 2., 0]]).T)
loss = tf.reduce_prod(f ** 2 - 3 * f + 1)

first_derivatives = tf.gradients(loss, f)[0]

second_derivatives = tf.gradients(first_derivatives, f)[0]

hessian = [tf.gradients(first_derivatives[i,0], f)[0][:,0] for i in range(3)]

model = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(model)
    print "\nloss\n", sess.run(loss)
    print "\nloss'\n", sess.run(first_derivatives)
    print "\nloss''\n", sess.run(second_derivatives)
    hessian_value = np.array(map(list, sess.run(hessian)))
    print "\nHessian\n", hessian_value

我的想法是 tf.gradients(first_derivatives, f[0, 0])[0] 可以检索例如关于 f_0 的二阶导数,但似乎 tensorflow 没有不允许从张量的切片中导出。

最佳答案

tf.gradients([f1,f2,f3],...) 计算 f=f1+f2+f3 的梯度 此外,区分 x[0] 是有问题的,因为 x[0] 指的是一个新的 Slice 节点,它不是你的损失,因此关于它的导数将是 None。你可以通过使用 packx[0], x[1], ... 粘合到 xx 中来绕过它,并让你的损失取决于 xx 而不是 x 。另一种方法是为各个组件使用单独的变量,在这种情况下计算 Hessian 看起来像这样。

def replace_none_with_zero(l):
  return [0 if i==None else i for i in l] 

tf.reset_default_graph()

x = tf.Variable(1.)
y = tf.Variable(1.)
loss = tf.square(x) + tf.square(y)
grads = tf.gradients([loss], [x, y])
hess0 = replace_none_with_zero(tf.gradients([grads[0]], [x, y]))
hess1 = replace_none_with_zero(tf.gradients([grads[1]], [x, y]))
hessian = tf.pack([tf.pack(hess0), tf.pack(hess1)])
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
print hessian.eval()

你会看到

[[ 2.  0.]
 [ 0.  2.]]

关于python - 如何计算 Tensorflow 中的所有二阶导数(仅 Hessian 矩阵的对角线)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38200982/

有关python - 如何计算 Tensorflow 中的所有二阶导数(仅 Hessian 矩阵的对角线)?的更多相关文章

  1. ruby - 如何使用 Nokogiri 的 xpath 和 at_xpath 方法 - 2

    我正在学习如何使用Nokogiri,根据这段代码我遇到了一些问题:require'rubygems'require'mechanize'post_agent=WWW::Mechanize.newpost_page=post_agent.get('http://www.vbulletin.org/forum/showthread.php?t=230708')puts"\nabsolutepathwithtbodygivesnil"putspost_page.parser.xpath('/html/body/div/div/div/div/div/table/tbody/tr/td/div

  2. ruby - 如何从 ruby​​ 中的字符串运行任意对象方法? - 2

    总的来说,我对ruby​​还比较陌生,我正在为我正在创建的对象编写一些rspec测试用例。许多测试用例都非常基础,我只是想确保正确填充和返回值。我想知道是否有办法使用循环结构来执行此操作。不必为我要测试的每个方法都设置一个assertEquals。例如:describeitem,"TestingtheItem"doit"willhaveanullvaluetostart"doitem=Item.new#HereIcoulddotheitem.name.shouldbe_nil#thenIcoulddoitem.category.shouldbe_nilendend但我想要一些方法来使用

  3. ruby - 其他文件中的 Rake 任务 - 2

    我试图在一个项目中使用rake,如果我把所有东西都放到Rakefile中,它会很大并且很难读取/找到东西,所以我试着将每个命名空间放在lib/rake中它自己的文件中,我添加了这个到我的rake文件的顶部:Dir['#{File.dirname(__FILE__)}/lib/rake/*.rake'].map{|f|requiref}它加载文件没问题,但没有任务。我现在只有一个.rake文件作为测试,名为“servers.rake”,它看起来像这样:namespace:serverdotask:testdoputs"test"endend所以当我运行rakeserver:testid时

  4. ruby-on-rails - Ruby net/ldap 模块中的内存泄漏 - 2

    作为我的Rails应用程序的一部分,我编写了一个小导入程序,它从我们的LDAP系统中吸取数据并将其塞入一个用户表中。不幸的是,与LDAP相关的代码在遍历我们的32K用户时泄漏了大量内存,我一直无法弄清楚如何解决这个问题。这个问题似乎在某种程度上与LDAP库有关,因为当我删除对LDAP内容的调用时,内存使用情况会很好地稳定下来。此外,不断增加的对象是Net::BER::BerIdentifiedString和Net::BER::BerIdentifiedArray,它们都是LDAP库的一部分。当我运行导入时,内存使用量最终达到超过1GB的峰值。如果问题存在,我需要找到一些方法来更正我的代

  5. python - 如何使用 Ruby 或 Python 创建一系列高音调和低音调的蜂鸣声? - 2

    关闭。这个问题是opinion-based.它目前不接受答案。想要改进这个问题?更新问题,以便editingthispost可以用事实和引用来回答它.关闭4年前。Improvethisquestion我想在固定时间创建一系列低音和高音调的哔哔声。例如:在150毫秒时发出高音调的蜂鸣声在151毫秒时发出低音调的蜂鸣声200毫秒时发出低音调的蜂鸣声250毫秒的高音调蜂鸣声有没有办法在Ruby或Python中做到这一点?我真的不在乎输出编码是什么(.wav、.mp3、.ogg等等),但我确实想创建一个输出文件。

  6. ruby-on-rails - Rails 3 中的多个路由文件 - 2

    Rails2.3可以选择随时使用RouteSet#add_configuration_file添加更多路由。是否可以在Rails3项目中做同样的事情? 最佳答案 在config/application.rb中:config.paths.config.routes在Rails3.2(也可能是Rails3.1)中,使用:config.paths["config/routes"] 关于ruby-on-rails-Rails3中的多个路由文件,我们在StackOverflow上找到一个类似的问题

  7. ruby-on-rails - 如何验证 update_all 是否实际在 Rails 中更新 - 2

    给定这段代码defcreate@upgrades=User.update_all(["role=?","upgraded"],:id=>params[:upgrade])redirect_toadmin_upgrades_path,:notice=>"Successfullyupgradeduser."end我如何在该操作中实际验证它们是否已保存或未重定向到适当的页面和消息? 最佳答案 在Rails3中,update_all不返回任何有意义的信息,除了已更新的记录数(这可能取决于您的DBMS是否返回该信息)。http://ar.ru

  8. ruby-on-rails - 'compass watch' 是如何工作的/它是如何与 rails 一起使用的 - 2

    我在我的项目目录中完成了compasscreate.和compassinitrails。几个问题:我已将我的.sass文件放在public/stylesheets中。这是放置它们的正确位置吗?当我运行compasswatch时,它不会自动编译这些.sass文件。我必须手动指定文件:compasswatchpublic/stylesheets/myfile.sass等。如何让它自动运行?文件ie.css、print.css和screen.css已放在stylesheets/compiled。如何在编译后不让它们重新出现的情况下删除它们?我自己编译的.sass文件编译成compiled/t

  9. ruby - 如何将脚本文件的末尾读取为数据文件(Perl 或任何其他语言) - 2

    我正在寻找执行以下操作的正确语法(在Perl、Shell或Ruby中):#variabletoaccessthedatalinesappendedasafileEND_OF_SCRIPT_MARKERrawdatastartshereanditcontinues. 最佳答案 Perl用__DATA__做这个:#!/usr/bin/perlusestrict;usewarnings;while(){print;}__DATA__Texttoprintgoeshere 关于ruby-如何将脚

  10. ruby - 如何指定 Rack 处理程序 - 2

    Rackup通过Rack的默认处理程序成功运行任何Rack应用程序。例如:classRackAppdefcall(environment)['200',{'Content-Type'=>'text/html'},["Helloworld"]]endendrunRackApp.new但是当最后一行更改为使用Rack的内置CGI处理程序时,rackup给出“NoMethodErrorat/undefinedmethod`call'fornil:NilClass”:Rack::Handler::CGI.runRackApp.newRack的其他内置处理程序也提出了同样的反对意见。例如Rack

随机推荐