【Chainer実践編1】深層学習に必要なデータってどんなの?ニューラルネットワークを訓練してみたら理解できた。

【Chainer実践編1】深層学習に必要なデータってどんなの?ニューラルネットワークを訓練してみたら理解できた。

晴れた日は大体家に引きこもります。

アイーンです。

 

今回から、Chainerの実践編ということで、花(アヤメ)の品種分類を行うことで、ちまたで噂のscikit-learnとは何者なのかを理解しつつ

分類とはどういうものか?

ディープラーニングに必要なデータとはどんなものなのか?

も併せて知ってしまおうという企画です。

scikit-learn(サイキットラーン)とは?

機械学習に必要なアルゴリズムがたっぷりと詰め込まれたありがた~いライブラリです。

データ解析やデータの分類といった、数学の深い知識が必要なプログラムでも、概要だけ知っておけばscikit-learnでちょちょいっとコード化することが出来たりします。

ちなみにライブラリってのは、Numpyとかmatplotlibとかの便利機能が入ってるものだと思ってください。

scikit-learnは、Anaconda3にもともと導入されています。

 

アヤメの品種分類

今回は、scikit-learnが持っているデータセットの中にある、アヤメの品種データを使って品種分類をやってみようという趣向です。

これがアヤメの花ですね。

アヤメの品種分類ということですが、別にscikit-learnの中にアヤメの画像があるわけじゃございません。

というのも、アヤメの品種の見分け方をご説明しますと

このように、がく片の長さと幅、花弁の長さと幅によって見分けるそうです。

すなわち入っているのは

sepal length (cm) がく片の長さ

sepal width (cm) がく片の幅

petal length (cm) 花弁の長さ

petal width (cm) 花弁の幅

という4種類の数値をもったリストです。しょぼ…。

 

データセットの中には3種類のアヤメデータが用意されており

 

setosa(セトサ)

 

versicolor(バージカラー)

 

virginica(ヴァージニカ)

以上3つが用意され、それぞれ

0がセトサ

1がバージカラー

2がヴァージニカ

という正解ラベルも持っています。

つまり、入っているのは

「5.1  3.5  1.4  0.2」「0」(セトサ)

てなもんですよ。情緒もクソもあったもんじゃない。

まぁ情緒は置いといて

 

このように4つの測定値と、品種のデータが

全部で150組用意されてます。

これを、分類できるようにニューラルネットワークを作り、訓練していきましょう。

ニューラルネットワーク訓練の概要

今回は4層のNN(ニューラルネットワーク)をつくり、学習させます。

先ほどの4データを入力層に渡し、出力層にはそれぞれの品種に該当する出力値が出てきます。

閾値を超え、出力が1に最も近い物が、「この品種である確率が高い!」とNNが判断した物になります。

また、今回は学習が終わったNNを使って、実際に品種の分類が可能かまでをやってみましょう。

 

それでは、次回からコードを書いてみます。

ここが出来れば、テストデータが理解出来ると思います。

 

これで貴方の肩書は「アヤメの品種分類師」です。

アヤメの見分け方を身に着けました。居酒屋で語れる程度に。

 

それではまた。