博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
weighted_cross_entropy_with_logits
阅读量:6999 次
发布时间:2019-06-27

本文共 2049 字,大约阅读时间需要 6 分钟。

weighted_cross_entropy_with_logits

原创文章,请勿转载!!!

weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):

此函数功能以及计算方式基本与tf_nn_sigmoid_cross_entropy_with_logits差不多,但是加上了权重的功能,是计算具有权重的sigmoid交叉熵函数

计算方法 :

\[pos_weight*targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits))\]

官方文档定义及推导过程:

通常的cross-entropy交叉熵函数定义如下:

\[targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits))\]

对于加了权值pos_weight的交叉熵函数:

\[ targets * -log(sigmoid(logits)) * pos_weight + (1 - targets) * -log(1 - sigmoid(logits))\]

现在我们使用 x = logits, z = targets, q = pos_weight的代数式

The loss is:        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))      = (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

我们把l = (1 + (q - 1) * z), 来确保稳定性并且比避免溢出,公式为:

\[(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))\]

logits and targets 必须要有相同的数据类型和shape.

参数:

_sentinel:本质上是不用的参数,不用填

targets:一个和logits具有相同的数据类型(type)和尺寸形状(shape)的张量(tensor)

shape:[batch_size,num_classes],单样本是[num_classes]

logits:一个数据类型(type)是float32或float64的张量

pos_weight:正样本的一个系数

name:操作的名字,可填可不填

实例代码

import numpy as npimport tensorflow as tfinput_data = tf.Variable(np.random.rand(3, 3), dtype=tf.float32)# np.random.rand()传入一个shape,返回一个在[0,1)区间符合均匀分布的arrayoutput = tf.nn.weighted_cross_entropy_with_logits(logits=input_data,                                                  targets=[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]],                                                  pos_weight=2.0)with tf.Session() as sess:    init = tf.global_variables_initializer()    sess.run(init)    print(sess.run(output))# [[ 1.04947078  0.89594436  0.92146152]#  [ 0.70252579  1.00673866  1.08856964]#  [ 1.07195592  1.18525708  1.04106498]]

转载于:https://www.cnblogs.com/cloud-ken/p/7435579.html

你可能感兴趣的文章
JXL读写Excel
查看>>
mysql自定义排序
查看>>
java UDP 一对一文件传输
查看>>
Netty5入门学习笔记003-TCP粘包/拆包问题的解决之道(下)
查看>>
SpringMVC之@ResponseBody
查看>>
Ubuntu开机自动挂载Windows分区(NTFS FAT32)教程
查看>>
Oracle学习笔记6
查看>>
Centos7开通端口方法
查看>>
php数据库永久链接其实一般没必要使用,如果网站并发量大,数据库支持的连接数小就会出问题...
查看>>
oracle--架构
查看>>
动态规划的基本方法---多阶段决策过程及实例
查看>>
顺序数据---隐马尔科夫模型
查看>>
Spring boot 使用jpa时对于数据库的配置
查看>>
驰骋工作流引擎设计系列02
查看>>
Spring Security源码分析十:初识Spring Security OAuth2
查看>>
HDOJ 2087 KMP算法
查看>>
【转载】erlang 如何自定义 behaviour
查看>>
apache tomcat 集群 负债均衡 部署
查看>>
一步一步学Ruby(四):Ruby标准类型
查看>>
Node.js + WebSocket 实现的简易聊天室
查看>>