VAEによる異常検知(マハラノビス距離編)
前回の続きでここではmnistの「1」という数字以外(今回は「9」)の画像に対し異常検知するプログラムで説明していきます。
今回の実装では以上のgithubと、以下のサイトを参考にして実装しました。
VariationalAutoencoderRunner.pyの
X_train, X_test = min_max_scale(mnist.train.images, mnist.test.images)
直下に以下のコードを追加します。
VariationalAutoencoder.pyの最後に以下のコードを追加します。
# インポートと結果画像の保存先指定 %matplotlib inline import matplotlib.pyplot as plt import numpy as np import os
# テストデータを「1」だけにする xx1, yy1 , zz1 = ,, aa=np.array([0,1,0,0,0,0,0,0,0,0]) #「1」のラベル bb=np.array([0,0,0,0,0,0,0,0,0,1]) #「9」のラベル for i in range(len(X_test)): if (y[i] == aa).all(): xx1.append(X_test[i]) yy1.append(y) xx1 = np.array(xx1) yy1 = np.array(yy1)
# テストデータを「9」だけにする xx9, yy9 , zz9 = ,, for i in range(len(X_test)): if (y[i] == bb).all(): xx9.append(X_test[i]) yy9.append(y[i]) xx9 = np.array(xx9) yy9 = np.array(yy9)
# テスト画像「1」と「9」をそれぞれエンコードする dxx1 = autoencoder.transform(xx1) dxx9 = autoencoder.transform(xx9)
# 「1」からのマハラノビス距離をそれぞれ計算する #calculate mahalanobis distance from scipy.spatial.distance import mahalanobis
# calculate covariance matrix for training data sigma = np.cov(dxx1, rowvar=False)
# make inverse of it inv_sigma = np.linalg.inv(np.cov(dxx1, rowvar=False))
# calculate mean for training data mean = np.mean(dxx1, axis=0)
# calculate mahalanobis distances outlier_dists = [mahalanobis(mean, outlier, inv_sigma) for outlier in dxx9] inlier_dists = [mahalanobis(mean, inlier, inv_sigma) for inlier in dxx1]
average_outlier_dist = np.mean(outlier_dists) average_inlier_dist = np.mean(inlier_dists) print('average_outlier_dist', average_outlier_dist) print('average_inlier_dist', average_inlier_dist) print(sigma)
# マハラノビス距離を用いて閾値を決定する THRESHOLD = 0 N_OUTLIERS=len(outlier_dists) N_INLIERS=len(inlier_dists) print(N_OUTLIERS) print(N_INLIERS) thr= f_value= rec= pre= r=0 b=0 for i in range(100): threshold = THRESHOLD + 0.5 * i for a in range(N_OUTLIERS): if outlier_dists[a] > threshold: r=r+1 for a in range(len(inlier_dists)): if inlier_dists[a] > threshold: b=b+1 # the number of predicted outliers n = r + b print("{}回目".format(i+1)) print("n={}".format(n))
precision = r / n recall = r / N_OUTLIERS f = 2 * precision * recall / (precision + recall) # print('thr={},f={},p={},r={}'.format(threshold, f, precision, recall)) print('{} {} {} {}\n'.format(threshold, f, precision, recall)) thr.append(threshold) f_value.append(f) rec.append(recall) pre.append(precision) r=0 b=0
# display F values plt.figure(figsize=(10, 5)) thrs = prs = res = [] plt.xlabel('threshold', fontsize=20) plt.plot(thr, rec, marker='.', label='Recall') plt.plot(thr, f_value, marker='.', label='F value') plt.plot(thr, pre, marker='.', label='Precision')
plt.legend(loc='best') plt.show()
max=np.argmax(f_value) print("thresholdの値が {} のとき".format(round(thr[max],2))) print("F_valueは最大値 {} をとる".format(round(f_value[max],3))) print("このときPrecisionは{}、Recallは{}である。".format(pre[max],rec[max]))
# 閾値を用いて異常検知をする from numpy.random import * plt.figure(figsize=(12, 4)) k=randint(200) batch_xs = X_test[k] batch_xsd=autoencoder.transform(batch_xs.reshape([1,784])) ax = plt.subplot(2, 2, 1) plt.imshow(batch_xs.reshape(28, 28)) plt.gray() ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) mah=mahalanobis(mean, batch_xsd, inv_sigma) print("マハラノビス距離:" + str(mah)) if mah > thr[max]: print('1ではない(異常)') else: print('1である(正常)') |
まったく同じ結果にはなりませんが閾値とそのときのF値などがでます。
最後にテスト画像(「1」か「9」)からランダムに1枚選び検知するようにしました。
以下がそのときの結果です。今回の場合マハラノビス距離が上で求められた閾値17.5以下の場合正常、17.5以上の場合異常と判断するようになっています。
マハラノビス距離を用いた異常検知
参考に「1」と「9」のマハラノビス距離の分布を示すと以下のようになっています。
直観的にも17.5付近でうまく分けることができることが期待できます。