【Chainer基礎3】Chainってなに?分かりやすく5分で解説

【Chainer基礎3】Chainってなに?分かりやすく5分で解説

ダンゴムシの腹が蛇腹状になってるのは異様にムカつくけど

エビの腹が蛇腹状になってるのは「美味そっ♪」となります。

アイーンです。

 

本日もChainerを使いこなす基礎学習の続き

今日はChainについて勉強していきましょう。

 

それではどうぞ~

 

Chainってなに?

Chainとは、ニューラルネットワークの全体を管理する親方的存在です。あだ名は「おやっさん」

複数のLinkを組み合わせて構成されるニューラルネットワークでは、Chainクラスの配下にLinkを作ってまとめ上げることで、一つのChainオブジェクトとして管理することが出来ます。

 

また、CPUを使った計算処理からGPUでの計算へとデータを移動したりでき、これにより通常の4倍(当社比)で計算するなどの赤い彗星顔負けの効果をもたらします。

 

別におやっさ……Chainが無くても、VariableとLinksで十分ニューラルネットワークは組めます。

しかし、Chainの優れた箇所の紹介は既に終わりました(爆)

長いコードを書く際に有効なので、今回は簡単にコードの比較でお茶を濁してみましょう。

 

関数、クラス、Chainの3種類でニューラルネットワークを組んでみます。

 

コードの比較:関数で作るニューラルネットワーク

さ、importしてみます。

前回との違いは、3行目ですね。

from chainer import Variable , Chain

Chainの記述が増えていますね。これでChainが使えます。

 

さっそくLinearを書きましょう。

l1 = L.Linear( 4 , 3)

l2 = L.Linear( 3 , 2)

 

 

次に、この2つをまとめる関数を書きます。

 

def forward(x):

    b = l1(x)

    return l2(b)

 

入力4個を出力3個に変換するl1

l1から来た入力3個をl2出力2個に変換します。

これで完成です。

 

それでは実験用のVariableオブジェクトを作りましょう。

input_array = np.aray( [ [1 , 2 , 3 , 4 ]] , dtype = np.float32)

最初の入力は4個なので

[1.2.3.4.]の入ったinput_arrayを作ります。

 

x = Variable(input_array)

を使ってVariableオブジェクトに変換しました。

さぁ、入れてみましょう。

 

y = forward(x)

とすると、y.dataは出力2個になりました。

 

次に、クラスを使ってニューラルネットワークを書いてみましょう。

 

コードの比較:クラスで作るニューラルネットワーク

クラスを使って、関数の時と同様のl1とl2を書くとこんな風になります。

class Class():
    def __init__(self):
        self.l1 = L.Linear(4,3)
        self.l2 = L.Linear(3,2)

    def forward(self,x):
        b = self.l1(x)
        return self.l2(b)

 

__init__(初期設定)でl1とl2を書き

下にLinkを組み合わせる処理を書きます。

では、クラスを使って処理を行いましょう。

 

また下準備。

また[1. 2.3.4.]のxを作ります。

 

class1 = Class()

まずはclass1というインスタンスを作成しまして・・・

y = class1.forward(x)

class1のforward()にxを入力っと。

 

はい、出力結果は2つになりましたね。

成功成功。

 

最後にChainを使って同じ処理を書きましょう。

 

コードの比較:Chainで作るニューラルネットワーク

class firstChain(Chain):
    def __init__(self):
        super().__init__(
            l1=L.Linear(4, 3),
            l2=L.Linear(3, 2),
        )

    def __call__(self, x):
         h = self.l1(x)
         return self.l2(h)

 

Chainの書き方はクラスと同様です。firstChainという名前で作ります。

super().__init__でChainクラスを継承し、初期設定にl1とl2を書いてみましょう。

ちなみにChainの配下に書かれたLinkを子リンクと呼びます。

 

あと、__call__を使って、呼び出すだけで入出力の処理を行えるようにします。

 

実験用のVariableオブジェクト作成。

おんなじ記述です。

 

chain = firstChain()

まずはchainというインスタンスを作成しまして・・・

y = chain(x)

chain()にxを入力すると

 

出力が2つになりました。

おんなじ処理を三回するのは、なかなか飽きます。

 

 

Chainerを使ってニューラルネットワークを記述する際、今後はChainを使って記入していきましょう。

 

これで貴方の肩書は「Chain使い」です。

武器なら、たぶんムチ的な使い方をするんでしょう。

 

それではまた。