あどりぶろぐ=adli"B"log

表情認識を研究する海なし県民のブログ

めっちゃ簡単!scikit-learnでConfusion Matrix(混同行列)を実装しました!

 

f:id:shubarie:20180215170920p:plain

はじめに

今週はscikit-learnを使ってConfusion Matrixの作成と図示、保存の機能を実装しました。

Confusion Matrix(混同行列)

機械学習を用いたクラス分類の精度を評価するには、混同行列 (Confusion matrix) を作成して、正しく識別できた件数、誤って識別した件数を比較することが一般的です。混同行列は横方向に識別モデルが算出した識別結果、縦に実際の値 (正解データ) を記します。

参考:scikit-learn でクラス分類結果を評価する – Python でデータサイエンス

scikit-learn

Simple and efficient tools for data mining and data analysis

参考:http://scikit-learn.org/stable/

僕は表情認識の研究を以前から連続して行っており、Confusion Matrixは実装したことがありました。

 

先週、そのスクリプトを引っ張り出してきて見たのですが、これがものsssssっすごく読みにくい。。

 

当時の僕は、「自分が作って自分が読むだけのプログラムなんてMainに全部放り込めばええやろw 動けばいいんやw」という考え方だったのですが、、いや、自分でも読めねえよ過去の俺。。

f:id:shubarie:20180215170504p:plain

1年前の自分が作ったスクリプトが吐き出した画像


というわけで書き直すに至ったのですが、どうせならライブラリ使ってスマートに実装しようぜということで、今回のscikit-learnを使っての実装に至りました。

 

結果

 冒頭にも紹介したコチラが今回実装したスクリプトから吐き出されるConfusion Matrixです。

f:id:shubarie:20180215170920p:plain

現在の自分が作ったスクリプトが吐き出した画像

横軸ラベルが若干見切れてたり、ラベルが番号(0=Angry, 1=Disgust, 2=Fear, 3=Happy, 4=Sad, 5=Surprise, 6=Neutral)になっていたりと、まだまだ手直しする部分はありますが、必要十分なものができたと思います。

 

ついでに、precision、recall、f1-scoreなどのモデル評価値の計算と可視化も実装しました。こちらはテキスト(.txt)形式で保存されます。

 

f:id:shubarie:20180215171945p:plain

モデル評価値の計算と可視化

以下のリポジトリの、scripts/evaluation.pyの各種パスをお好みで定義し、実行することで、データに対する学習モデルのテストが行われ、その結果が保存されます。

github.com

実装方法

中で実際に 動いているソースコードはこんな感じです。


def evaluate(self,input_dir,result_dir,labels):
    # confusion matrixをプロットし画像として保存する関数
    # 参考: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
        def plot_confusion_matrix(cm, classes, output_file,
                                  normalize=False,
                                  title='Confusion matrix',
                                  cmap=plt.cm.Blues):
            """
            This function prints and plots the confusion matrix.
            Normalization can be applied by setting `normalize=True`.
            """
            if normalize:
                cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
                print("Normalized confusion matrix")
            else:
                print('Confusion matrix, without normalization')

            print(cm)

            plt.imshow(cm, interpolation='nearest', cmap=cmap)
            plt.title(title)
            plt.colorbar()
            tick_marks = np.arange(len(classes))
            plt.xticks(tick_marks, classes, rotation=45)
            plt.yticks(tick_marks, classes)

            fmt = '.2f' if normalize else 'd'
            thresh = cm.max() / 2.
            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
                plt.text(j, i, format(cm[i, j], fmt),
                         horizontalalignment="center",
                         color="white" if cm[i, j] > thresh else "black")

            plt.tight_layout()
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
            plt.savefig(output_file)
            
        # 検証用ディレクトリを掘りながらpredictする
        y_true=[]
        y_pred=[]
        files=os.listdir(input_dir)
        for file_ in files:
            sub_path=os.path.join(input_dir,file_)
            subfiles=os.listdir(sub_path)
            for subfile in subfiles:
                img_path=os.path.join(sub_path,subfile)
                img=cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)
                _,reshaped_img=self.dataset.reshape_img(img)
                label,_=self.predict(reshaped_img)
                y_true.append(int(file_))
                y_pred.append(label)
                
        # 有効桁数を下2桁とする
        np.set_printoptions(precision=2)
        
        # accuracyの計算
        accuracy=accuracy_score(y_true,y_pred)
        
        # confusion matrixの作成
        cnf_matrix=confusion_matrix(y_true,y_pred,labels=labels)
        
        # report(各種スコア)の作成と保存
        report=classification_report(y_true,y_pred,labels=labels)
        report_file=open(result_dir+"/report.txt","w")
        report_file.write(report)
        report_file.close()
        print(report)
        
        # confusion matrixのプロット、保存、表示
        title="overall accuracy:"+str(accuracy)
        plt.figure()
        plot_confusion_matrix(cnf_matrix, classes=labels,output_file=result_dir+"/CM_without_normalize.png",title=title)
        plt.figure()
        plot_confusion_matrix(cnf_matrix, classes=labels,output_file=result_dir+"/CM_normalized.png", normalize=True,title=title)
        plt.show()

 

このうち、実際にConfusion Matrixを作成しているのはこの部分。


# confusion matrixの作成
cnf_matrix=confusion_matrix(y_true,y_pred,labels=labels)

肝心のConfusion Matrixは、なんと1行で実装できてしまったw

 

それ以外の部分は、sklearn.metrics.confusion_matrixの入力調整と、出力されたConfusion Matrixを表示・保存を行っているだけです。

 

sklearn.metrics.confusion_matrixへの入力

40~53行目に対応

 

sklearn.metrics.confusion_matrixのリファレンスを見るとどうやら、入力であるy_trueとy_predは「サンプル数と同じ長さを持つラベルの入った配列」らしい。

 

そこで今回は、ディレクトリを掘りながら、画像1枚1枚をテストし、出力ラベルをy_predに格納する方法で入力を作成した。

Confusion Matrixの表示と保存

5~38行目および71~77行目に対応

 

 これに関しては公式のサンプルを流用した。

Confusion matrix — scikit-learn 0.19.1 documentation

 

感想

比較的、再利用性の高いスクリプトがかけた(気がする)。

今後はたくさんデータをとり、それを学習モデルにかけて、生成されたConfusion Matrixを見ながら映像コンテンツ推薦アルゴリズムを開発するための知見を集めていきます。

 

あと関係ないけど、ブログにスクリプトを張るのが想像以上に面倒くさかったです。

まだまだ見た目が悪い…

以下、参考にしました。

mae.chab.in