【AI】ニューラルネットワークから重み、バイアスを取得する方法(Keras版)

こんにちは、ヒガシです。

 

このページでは、Keras使って構築したAIモデル(ニューラルネットワーク)が保有する「重み」と「バイアス」をnumpyの配列として取得する方法をご紹介していきます。

 

それではさっそくやっていきましょう!

 

スポンサーリンク

使用するニューラルネットワークの紹介

今回サンプルとして使用するモデルはシンプルなMLP(Multi Layer Perceptron)モデルです。

 

構造は以下の通りです。

入力は4(上には書いてないけど)、中間層が20,10、出力が1 のモデルになっています。

 

このモデル内のDense3層について重みとバイアスを取得してみます。

 

なお、今後このモデルは変数名「model」に格納されている状態で話を進めていきます。

 

スポンサーリンク

モデルから層を切り出す方法

重みとバイアスを取り出す際、まずはモデルを各層に分割してあげる必要があります。

 

というわけでモデルから層を切り出す方法は以下の通りです。

layer = model.layers[ i ]
layer:切り出した層の変数
i :切り出す層

 

これで変数layerに指定した層を切り出すことができます。

 

そしてこの変数layerの中に今回取り出したい重みとバイアスが含まれています。

 

スポンサーリンク

取り出した層から重みとバイアスを取得する

それではこの記事の本題である「重み」と「バイアス」を取り出す方法をご紹介します。

 

先ほど紹介したやり方で層を切り出し、変数layerに格納された状態を想定して書いています。

weights=layer.get_weights()[0]
bias=layer.get_weights()[1]

これで変数weightsに重みが、変数biasにバイアスが配列として格納されます。

スポンサーリンク

層の切り出し⇒重み&バイアス取得を1行でやる場合

ここまで紹介したやり方は1st_stepで層の切り出し、2nd_stepで重み、バイアスの取得を行いましたが、これらはまとめて実行することもできます。

 

以下がそのやり方です。

weights = model.layers[i].get_weights()[0]
bias=model.layers[i].get_weights()[1]

これでも変数weightsに重みを、変数biasにバイアスを配列として格納することができます。

スポンサーリンク

モデル内のすべての層の重みとバイアスを取得した結果

それでは冒頭に紹介した3層のMLPモデルに対して、先ほど紹介したやり方で各層の重みとバイアスを取得してみます。

 

以下がそのコードです。

※層の切り出し、重み&バイアス取得を分けて実行しています。

#1層目の係数を取得
layer1 = model.layers[0]
L1_weights=layer1.get_weights()[0]
L1_bias=layer1.get_weights()[1]
#2層目の係数を取得
layer2 = model.layers[1]
L2_weights=layer2.get_weights()[0]
L2_bias=layer2.get_weights()[1]
#3層目の係数を取得
layer3 = model.layers[2]
L3_weights=layer3.get_weights()[0]
L3_bias=layer3.get_weights()[1]

 

これで各変数に重み、バイアスが格納されています。

 

中身の数値を見ても仕方ないので大きさだけ確認してみましょう。

以下のコードで確認します。

print('1層目の重み、バイアス')
print(L1_weights.shape)
print(L1_bias.shape)
print('2層目の重み、バイアス')
print(L2_weights.shape)
print(L2_bias.shape)
print('3層目の重み、バイアス')
print(L3_weights.shape)
print(L3_bias.shape)

こいつを実行すると以下の結果が出力されました。

先ほども紹介した通り、今回使用したモデルは、入力が4(上には書いてないけど)、中間層が20,10、出力が1 のMLPモデルですので、問題なく対応する形で重み、バイアスが取得できていることがわかりますね。

 

スポンサーリンク

おわりに

というわけで今回はKerasで構築したニューラルネットワークのモデルから、「重み」と「バイアス」を取得する方法をご紹介しました。

 

実務においては非常に役立つスキルですので、ぜひ覚えておきましょう。

 

このように、私のブログでは様々なスキルを紹介しています。

過去記事一覧

 

今は仕事中で時間がないかもしれませんが、ぜひ通勤時間中などに他の記事も読んでいただけると嬉しいです。
⇒興味をもった方は【ヒガサラ】で検索してみてください。

確実にスキルアップできるはずです。

 

最後に、この記事が役に立ったという方は、ぜひ応援よろしくお願いします。
↓ 応援ボタン
にほんブログ村 IT技術ブログへ
にほんブログ村

それではまた!

コメント

タイトルとURLをコピーしました