論文を読む"Distilling the knowledge in a neural network."⑬
背景
「蒸留(Distillation)」に関して、論文"Ravi Teja Mullapudi, Online Model Distillation for Efficient Video Inference."を読んだ。
その中で引用されていたモデル蒸留のために広く使用されている技術の論文として、次の論文を読み進めている。
[18] G. Hinton, O. Vinyals, and J. Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
前回まで
Abstract
- 背景
機械学習アルゴリズムは、複数の異なるモデルの推論を平均化することで性能が向上する(モデルのアンサンブル)。
- 問題
モデルのアンサンブルは計算コストが高すぎる。
- 従来手法
Caruanaらの"Model compression."は、アンサンブル内の知識を単一のモデルに圧縮することを示している。
- 提案手法
従来手法をさらに異なる圧縮技術を用いることで発展させる。また、フルモデルとスペシャリストモデルとで構成するアンサンブルも提案。
- 効果
MNISTでいくつかの驚くべき結果を達成。
商用音響モデルを大幅に改善。
スペシャリストモデルは迅速かつ並行して訓練することが可能。
1 Introduction
- トレーニングと展開、それぞれの最適化
- 機械学習における知識
- 蒸留という有望なアプローチを阻害するのは、「モデルの中の知識と学習されたパラメータを区別してしまう傾向」。
- 知識とは「学習済みマッピング」である。
- だからこそ、不正解に対する確率もごくわずかであれ定義されてしまう。
- モデルを一般化する正しい方法
- 目的関数はユーザーの真の目的を反映すべきであるが、実際は反映できていない。
- モデルを一般化する正しい方法に関する情報が必要であるが、普通は利用不可能。
- ただし、大きなモデルを学習済みで、蒸留によって小さなモデルを作る場合、大きなモデルの学習に用いた方法が小さなモデルにも適用できる。
- 転送の具体的な方法
- 「面倒なモデル」の出力情報の活用
- Caruanaらは確率ではなくロジットを用いた。
- 本論文の「蒸留」では、ソフトマックスの温度(temperature)を使って、より適切なターゲットのソフトセットを生成する。
2 Distillation
- 「温度付きソフトマックス関数」について
- 最も単純な形式の蒸留
- 蒸留モデルは転送セットで訓練される。
- 転送セットは、温度を高くしたソフトマックス関数を用いた面倒なモデルで生成されたもの。
- 蒸留モデルの学習時は同じ温度、学習後は1。
- 転送セットのラベルが既知の場合、より良い改善方法は、異なる2つの目的関数の加重平均の使用
2.1 Matching logits is a special case of distillation
- 第1の目的関数は、蒸留モデルと面倒なモデルのロジット間の差の二乗を最小化すること
- 温度は高すぎず低すぎない中間の値が有効
3 Preliminary experiments on MNIST(MNISTでの予備実験)
- 大モデルに対する小モデルの比較をソフトターゲットマッチング追加タスク有無で行ない、蒸留の効果を確認
- 温度とユニットの数を振った実験の実施
- 転送セットから一部のクラスを削除した場合でも、ロバスト性を確認
4 Experiments on speech recognition(音声認識に関する実験)
- 自動音声認識(ASR)で使用されているディープニューラルネットワーク(DNN)音響モデルを、アンサンブルすることの効果について調査
- 本稿で提案する蒸留戦略が望ましい効果を達成することを証明
- すなわち、同じサイズのモデルを同じトレーニングデータから直接学習するよりも、はるかにうまく機能する単一のモデルを、モデルのアンサンブルから蒸留することが可能
5 Training ensembles of specialists on very big datasets(非常に大きなデータセットに関する専門家アンサンブルのトレーニング)
- アンサンブル用の個々のモデルのネットワーク規模、データセットサイズが非常に大きい場合、計算量大
- 「専門家モデル」を学習することで、 アンサンブルの学習に必要な計算量を削減
5.1 The JFT dataset
- データセットサイズが非常に大きい"JFT dataset"は、当時訓練に半年を要した。これからアンサンブルを作ろうとすれば、数年の訓練時間が必要になり、現実的ではない。
論文読解
Google翻訳した上で、自分の理解しやすいように修正しながら読んでいく。
5.2 Specialist Models
クラスの数が非常に多い場合は、面倒なモデルが、すべてのデータについてトレーニングされた1つのジェネラリストモデルと、非常に紛らわしいクラスサブセット(異なる種類のキノコのように)からの例で、非常に充実したデータについてトレーニングされた多数の「専門家」モデルと、を含むアンサンブルであることは意味があります。
このタイプのスペシャリストのsoftmaxは、専門ではないすべてのクラスを1つのゴミ箱クラスにまとめることによって、はるかに小さくすることができます。
大きなデータセット(例えばJFTデータセット)に対するアンサンブルとして、ここでは「ジェネラリストモデル」と「専門家モデル」とを含むアンサンブルを考える。
- 専門家モデル
考え方としては簡単で、紛らわしいクラスに対応するために、問題を小分けにして別途対応する推論器を生成する。ただ、大分類を行う推論器、その後、分類されたクラスの中で詳細な分類を行う推論器といった二段階処理を考えてしまうのが普通だと思うが、それをアンサンブルで一つの推論器にできる、しかもサイズは増やさずに、というのが面白い。
過剰適合を減らし、低レベルの特徴検出器を学習する作業を分担するために、各専門家モデルはジェネラリストモデルの重みで初期化されます。
これらの重みは、半分がその特別なサブセットから、半分が残りの訓練セットから無作為にサンプリングされる学習データセットで、専門家を訓練することによってわずかに修正される。
トレーニング後、専門家クラスがオーバーサンプリングされた割合のログで、ゴミ箱クラスのロジットを増加させることで、偏ったトレーニングセットを補正できます。
専門家モデルは、ジェネラリストモデルの重みで初期化される。これによって学習時間を大きく短縮する。
専門家用の学習は、半分が専門家用データ、半分が専門家用以外のデータからなる学習セットによって行う。
アンサンブルを作るために複数の学習を行う際、異なるデータセットを用いるというのは当たり前なので、このあたりに本論文の重要なノウハウがありそう。ここでは学習データの構成しか書いていないが、例えば重み固定などは行わないのだろうか? また、どれくらい学習したら収束したとみなすのだろう。
最後、ロジットを補正については良くイメージができない。実際に手を動かしてみないと分からなそう。
5.3 Assigning classes to specialists(専門家にクラスを割り当てる)
専門家のためのオブジェクトカテゴリのグループ化を導き出すために、私たちは私たちの完全なネットワークがしばしば混同するカテゴリに焦点を合わせることにしました。
混同行列を計算し、そのようなクラスタを見つける方法として使用することができたとしても、我々はより単純なアプローチ、クラスタを構築するための真のラベルを必要としないアプローチ、を選びました。
まず、専門家に与えるべきオブジェクトカテゴリのグループ化を導き出す必要がある。ジェネラリストモデルが間違えやすいものを専門家に解かせるのはいいが、ではジェネラリストモデルが間違えやすいデータベースを作るところから始めなくてはならない、ということか。
で、そこにもいろいろ工夫ができそうだが、ここではなるべく単純なアプローチを選ぶ、と。
特に、我々はジェネラリストモデルの推論の共分散行列にクラスタリングアルゴリズムを適用し、一緒に推論されることが多いクラスの集合が、我々の専門家モデルの1つ、のためのターゲットとして使用されるようにする。
我々はK-meansアルゴリズムのオンライン版を共分散行列の列に適用し、合理的なクラスタを得ました(表2に示す)。
同様の結果が得られるいくつかのクラスタリングアルゴリズムを試しました。
JFT 1:ティーパーティー。イースター;ブライダルシャワー;ベビーシャワー;イースターのウサギ; ...
JFT 2:橋;斜張橋。つり橋高架橋煙突; ...
JFT 3:トヨタカローラE100。オペルSignum;オペルアストラ;マツダファミリア...
表2:我々の共分散行列クラスタリングアルゴリズムによって計算されたクラスタからのクラス例
単純なアプローチとは、ジェネラリストモデルの推論の共分散行列に、K-meansアルゴリズムなど、単純なクラスタリングアルゴリズムを適用する、というもの。
これによって、例えば車種が違うだけのクラスタなどを得ることができる。
5.4 Performing inference with ensembles of specialists(専門家アンサンブルによる推論の実行)
専門家モデルが蒸留されたときに何が起きているかを調べる前に、専門家を含むアンサンブルがどの程度うまく機能しているかを、我々は確認したいと思いました。
専門家モデルに加えて、我々は常にジェネラリストモデルを持っており、専門家を持たないクラスを扱うことができ、どの専門家を使うかを決定できるようになっています。
入力画像が与えられると、2つのステップでトップワン分類(top-one classification)を行います。
Step 1:各テストケースについて、ジェネラリストモデルに従ってn個の最確クラスを見つけます。このクラスの集合をkと呼びます。我々の実験では、を使用しました。
Step 2:我々は次に、混同しやすいクラスの特別なサブセットを持つ、全ての専門家モデルmが、kと空でない交点を持つようにし、これを専門家のアクティブ集合と呼びます(この集合は空であってもよいです)。我々はそして、次式の最小化により、すべてのクラスにわたる、完全な確率分布を発見します。
ここで、
- KL:KL divergence(カルバック・ライブラー情報量)
- :専門家モデルの確率分布
- :ジェネラリストフルモデルの確率分布
を表します。
分布は、mのすべての専門家クラスと、1つのゴミ箱クラスの分布です。そのため、フルq分布からそのKL情報量を計算するとき、我々は、フルq分布がmのごみ箱中のすべてのクラスに割り当てる確率をすべて合計します。
蒸留の前のジェネラリストモデルと専門家モデルからなるアンサンブルの性能を確認しているようです。その方法は、2つのステップからなるトップワン分類(top-one classification)によって行う、とのことですが、良く理解できませんでした。
当たり前に考えれば、入力画像をまずジェネラリストモデルを使って分類し(クラスkに分類される)、それをさらに詳細に分類できる専門家モデル(クラスkをさらに分類できるアクティブ集合)を動作させ分類する、といったところでしょうか。
式5はモデルが出力する確率分布と真の確率分布とを比較する式であり、これの最小化により、真の確率分布を取得しようとしていると考えられます。
すべてのモデルが各クラスに対して単一の確率を生成する場合、KL(p,q)またはKL(q,p)のどちらを使用するかに依存して、解は算術平均または幾何平均のいずれかになります。それにもかかわらず、式5は一般閉形式解(general closed form solution)を持ちません。
我々は(T=1で)q=softmax(z)をパラメータ化し、そして勾配降下法を使用してロジットz を式5を用いて最適化します。この最適化は各画像に対して実行されなければならないことに留意してください。
式5は一般閉形式解(general closed form solution)を持たない、つまり解けないので、勾配降下法を使うと理解しました。qをパラメータ化して最小値を求め、そのときのロジットzを得ます。
5.5 Results
訓練済みのベースラインフルネットワークから始めて、専門家モデルの訓練は非常に高速です(数週間かかるJFTに対して数日)。また、すべての専門家モデルは、完全に独立して訓練されます。
表3に、ベースラインシステムと、専門家モデルと組み合わせたベースラインシステムとの、絶対テスト精度を示します。 61のスペシャリストモデルを使用では、全体としてテスト精度に4.4%の改善が見られます。
また、条件付きテストの精度についても報告します。これは、専門家クラスに属する例を考慮し、推論をそのクラスのサブセットに限定することによる精度です。
結果として、以下のようなことを述べています。
- 専門家モデルの訓練は非常に高速
- 独立に訓練可能
- 絶対テスト精度 の改善、条件付きテストの精度
JFT専門家実験のために、それぞれ300のクラス(と、ゴミ箱クラス)を持つ、61の専門家モデルを訓練しました。
専門家のためのクラスのセットは互いに素ではないため、特定の画像クラスをカバーする複数の専門家が存在しました。
表4は、以下を示します。
- テストセット例の数
- 専門家を使用した場合に1の位置で正しい例の数の変化
- クラスをカバーする専門家の数でブレークダウンされたJFTデータセットのためのtop1精度における、相対的なパーセンテージの改善
我々は、独立した専門家モデルの訓練は並列化が非常に容易であるため、特定のクラスをカバーする専門家が多くなれば精度がより改善するという一般的な傾向に、励まされています。
あるクラスをカバーする専門家は1つでなければならないという縛りがないとのことです。そう言った縛りがないことは、実用上、データの取り扱いはかなり楽になることでしょう。
しかも精度も上がるのかな?
まとめ
論文"Distilling the knowledge in a neural network."の5.2~5.5節を読みました。