Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
GZYZG committed Mar 17, 2022
1 parent 40b0c76 commit b157d66
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 63 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Exercise and practice when I'm learning :open_book:
这个项目用于记录我读研期间的一些练习,作为我的练兵场,记录我练习技能的场所!
:fire: :fire: :fire:
:fire: :fire: :fire:

现在已经有的内容:
- [通用指标的计算](./general_metrics.py),实现了混淆矩阵和混淆矩阵可视化
- [语义分割指标的计算](./segmentation_metrics.py),只实现了 dice 系数和 IoU 的计算
-
63 changes: 63 additions & 0 deletions general_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import matplotlib.pyplot as plt


def confusion_matrix(label, predict, n):
"""
计算混淆矩阵
:param label: 标签,np.array类型。形状可以是(n_sample,) 或者 (n_sample, n_classes),当为第二种形状时可以表示多标签分类的情况
:param predict: 预测值,与 `label` 同理
:param n: 类别数目
:return: 混淆矩阵,np.array类型。shape 为 (n, n)。$cm_{ij}$表示真实标签为 $i$,预测标签为 $j$ 的样本个数
"""
k = (label >= 0) & (label < n)
# bincount()函数用于统计数组内每个非负整数的个数
# 详见 https://docs.scipy.org/doc/numpy/reference/generated/numpy.bincount.html
return np.bincount(n * label[k].astype(int) + predict[k], minlength=n ** 2).reshape(n, n)


def plot_confusion_matrix(cm, classes, normalized=False, title="Confusion matrix", cmap=plt.cm.Blues):
"""
画混淆矩阵
:param cm: 混淆矩阵
:param classes: 类别列表。classes[i] 表示 i 所对应的类名
:param normalized: 是否进行归一化
:param title: str,表头
:param cmap: 配色方案
:return:
"""
if normalized:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print("Confusion matrix with out normalization")

print(cm)

plt.imshow(cm, interpolation="nearest", cmap=cmap)
plt.title(title, fontsize=14)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = '.2f' if normalized else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()


if __name__ == "__main__":
label = np.array([2, 0, 1, 1]) # np.array([[0, 1, 1], [1, 1, 0], [0, 0, 1]])
predict = np.array([0, 0, 1, 1]) # np.array([[0, 1, 0], [1, 0, 1], [0, 0, 0]])
cm = confusion_matrix(label, predict, 3)
plot_confusion_matrix(cm, [0, 1, 2], title="Confusion Matrix")
print(f"Confusion Matrix: \n{cm}")

print(f"mIoU: {IoU(label, predict, 3)}")
62 changes: 0 additions & 62 deletions segmentation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,9 @@
图像分割的评价指标
"""
import numpy as np
import matplotlib.pyplot as plt
import itertools


def confusion_matrix(label, predict, n):
"""
计算混淆矩阵
:param label: 标签,np.array类型。形状可以是(n_sample,) 或者 (n_sample, n_classes),当为第二种形状时可以表示多标签分类的情况
:param predict: 预测值,与 `label` 同理
:param n: 类别数目
:return: 混淆矩阵,np.array类型。shape 为 (n, n)。$cm_{ij}$表示真实标签为 $i$,预测标签为 $j$ 的样本个数
"""
k = (label >= 0) & (label < n)
# bincount()函数用于统计数组内每个非负整数的个数
# 详见 https://docs.scipy.org/doc/numpy/reference/generated/numpy.bincount.html
return np.bincount(n * label[k].astype(int) + predict[k], minlength=n ** 2).reshape(n, n)


def IoU(label, predict, class_n):
"""
计算各类的IoU, Intersection over Union
Expand All @@ -39,43 +24,6 @@ def IoU(label, predict, class_n):
return miou


def plot_confusion_matrix(cm, classes, normalized=False, title="Confusion matrix", cmap=plt.cm.Blues):
"""
画混淆矩阵
:param cm: 混淆矩阵
:param classes: 类别列表。classes[i] 表示 i 所对应的类名
:param normalized: 是否进行归一化
:param title: str,表头
:param cmap: 配色方案
:return:
"""
if normalized:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print("Confusion matrix with out normalization")

print(cm)

plt.imshow(cm, interpolation="nearest", cmap=cmap)
plt.title(title, fontsize=14)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = '.2f' if normalized else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()


def dice_coefficient(pred, target, smooth=1):
"""
计算dice系数
Expand All @@ -92,13 +40,3 @@ def dice_coefficient(pred, target, smooth=1):
intersection = (t1 * t2).sum()

return 2. * intersection / (t1.sum() + t2.sum() + smooth)


if __name__ == "__main__":
label = np.array([2, 0, 1, 1]) # np.array([[0, 1, 1], [1, 1, 0], [0, 0, 1]])
predict = np.array([0, 0, 1, 1]) # np.array([[0, 1, 0], [1, 0, 1], [0, 0, 0]])
cm = confusion_matrix(label, predict, 3)
plot_confusion_matrix(cm, [0, 1, 2], title="Confusion Matrix")
print(f"Confusion Matrix: \n{cm}")

print(f"mIoU: {IoU(label, predict, 3)}")

0 comments on commit b157d66

Please sign in to comment.