TensorFlow定义与混淆矩阵相关的实用程序
2018-09-14 16:08 更新
#版权所有2015 TensorFlow作者.版权所有.
#
#根据Apache许可证版本2.0(“许可证”)许可;
#除非符合许可证,否则您不得使用此文件.
#您可以获得许可证的副本
#
#http://www.apache.org/licenses/LICENSE-2.0
#
#除非适用法律要求或书面同意,否则根据许可证分发的软件
#均在“按原样”基础上分发,
#无任何明示或暗示的保证或条件.
#有关许可证下管理权限和限制的特定语言,
#请参阅许可证.
# =============================================== =============================
"""混淆矩阵相关的实用程序.

@@remove_squeezable_dimensions
@@confusion_matrix
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeezes the last dimension when ranks are off from expected by exactly 1.

  Normally `labels` and `predictions` are expected to have matching shapes, so
  `expected_rank_diff` defaults to 0 and the trailing dimension of whichever
  tensor has the larger rank is squeezed when the ranks differ by 1.

  When, for example, `labels` holds class IDs and `predictions` holds one
  probability per class, `predictions` is expected to have one more dimension
  than `labels`, so `expected_rank_diff` would be 1. Then `labels` is squeezed
  if `rank(predictions) - rank(labels) == 0`, and `predictions` is squeezed if
  `rank(predictions) - rank(labels) == 2`.

  Static shapes are used when both ranks are known. Otherwise graph operations
  are added, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)

    static_pred_shape = predictions.get_shape()
    pred_ndims = static_pred_shape.ndims
    static_label_shape = labels.get_shape()
    label_ndims = static_label_shape.ndims

    if (pred_ndims is not None) and (label_ndims is not None):
      # Both ranks are known at graph-construction time, so decide in Python
      # and add at most one squeeze op.
      static_diff = pred_ndims - label_ndims
      if static_diff == expected_rank_diff + 1:
        predictions = array_ops.squeeze(predictions, [-1])
      elif static_diff == expected_rank_diff - 1:
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # At least one rank is unknown: compare ranks inside the graph instead.
    dynamic_diff = array_ops.rank(predictions) - array_ops.rank(labels)

    def _squeeze_last_if_diff_is(tensor, wanted_diff):
      # Conditionally squeeze `tensor` when the runtime rank difference equals
      # `wanted_diff`; otherwise pass it through unchanged.
      return control_flow_ops.cond(
          math_ops.equal(wanted_diff, dynamic_diff),
          lambda: array_ops.squeeze(tensor, [-1]),
          lambda: tensor)

    # A conditional squeeze is only added when the static shape does not
    # already rule it out (rank unknown, or last dimension compatible with 1).
    if (pred_ndims is None) or (
        static_pred_shape.dims[-1].is_compatible_with(1)):
      predictions = _squeeze_last_if_diff_is(
          predictions, expected_rank_diff + 1)
    if (label_ndims is None) or (
        static_label_shape.dims[-1].is_compatible_with(1)):
      labels = _squeeze_last_if_diff_is(labels, expected_rank_diff - 1)
    return labels, predictions
def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
                     name=None, weights=None):
  """Computes the confusion matrix from predictions and labels.

  Calculate the Confusion Matrix for a pair of prediction and
  label 1-D int arrays.

  The matrix columns represent the prediction labels and the rows represent the
  real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for a given classification task. Both
  prediction and labels must be 1-D arrays of the same shape in order for this
  function to work.

  If `num_classes` is None, then `num_classes` will be set to one plus
  the maximum value in either predictions or labels.
  Class labels are expected to start at 0. E.g., if `num_classes` was
  three, then the possible labels would be `[0, 1, 2]`.

  If `weights` is not `None`, then each prediction contributes its
  corresponding weight to the total value of the confusion matrix cell.

  For example:

  ```python
    tf.contrib.metrics.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
  resulting in a 5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can
                 have. If this value is not provided, it will be calculated
                 using both predictions and labels array.
    dtype: Data type of the confusion matrix.
    name: Scope name.
    weights: An optional `Tensor` (or value convertible to one) whose shape
      matches `predictions`.

  Returns:
    A k X k matrix representing the confusion matrix, where k is the number of
    possible labels in the classification task.

  Raises:
    ValueError: If both predictions and labels are not 1-D vectors and have
      mismatched shapes, or if `weights` is not `None` and its shape doesn't
      match `predictions`.
  """
  with ops.name_scope(name, 'confusion_matrix',
                      (predictions, labels, num_classes, weights)) as name:
    labels, predictions = remove_squeezable_dimensions(
        ops.convert_to_tensor(labels, name='labels'),
        ops.convert_to_tensor(
            predictions, name='predictions'))
    # Indices of a SparseTensor are int64, so normalise both inputs up front.
    predictions = math_ops.cast(predictions, dtypes.int64)
    labels = math_ops.cast(labels, dtypes.int64)

    if num_classes is None:
      # Infer the matrix size from the largest class id seen in either input.
      num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
                                     math_ops.reduce_max(labels)) + 1

    if weights is not None:
      # Convert first so array-like weights (lists, NumPy arrays) expose
      # `get_shape()`; this is a no-op for inputs that are already tensors.
      weights = ops.convert_to_tensor(weights, name='weights')
      predictions.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.cast(weights, dtype)

    shape = array_ops.stack([num_classes, num_classes])
    # Each row of `indices` is a (label, prediction) cell coordinate.
    indices = array_ops.transpose(array_ops.stack([labels, predictions]))
    values = (array_ops.ones_like(predictions, dtype)
              if weights is None else weights)
    cm_sparse = sparse_tensor.SparseTensor(
        indices=indices, values=values, dense_shape=math_ops.to_int64(shape))

    # Add the sparse entries onto a dense zero matrix to get a dense result.
    zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)

    return sparse_ops.sparse_add(zero_matrix, cm_sparse)
以上内容是否对您有帮助:
更多建议: