数学がわからない

日々の勉強をアウトプットする。

Noise2Noiseを理解する(2)

論文"Noise2Noise: Learning Image Restoration without Clean Data"を読みます。

前回、"1. Introduction"を読みました。
gyokan.hatenablog.com

今回は"2. Theoretical Background"です。ちょっと長くなったので目次を載せます。

"2. Theoretical Background(理論的背景)"を読む

未知の真値を推定する一般的な戦略:ML推定

室温の測定値の信頼できないセット(y_1, y_2, ...)があると仮定します。


真の未知の温度を推定するための一般的な戦略は、いくつかの損失関数Lに従って、測定値から最小の平均偏差を持つ、数zを見つけることです。



\begin{align}
\underset{z}{argmin} \mathbb{E}_y\{L(z, y)\}.
\tag{2} 
\end{align}


L_2損失L(z,y) = (z-y)^2の場合、この最小値は観測値の算術平均で求められます。


\begin{align}
z = \mathbb{E}_y\{y\}.
\tag{3} 
\end{align}

L2損失関数と算術平均

同じものを測定したデータが複数あるとき、それがばらついているなら平均をとろう、というのはごく当たり前の話です。ここでは、それを少し詳細に考察しています。

まず、真値を得るための損失関数を一般化してLとし、最小化問題を解こう、としているのが式2です。具体的な損失関数を「差の二乗」とすると、結局算術平均に落ち着くことを示しているのが式3です。

 \mathbb{E}_y\{f(y)\}は、全てのyについての関数f(y)の値の平均値を表しています。つまり、yをy_1からy_NまでのN個の値をとすると、


\begin{align}
\mathbb{E}_y\{f(y)\} = \displaystyle \frac{\sum_{i=1}^N f(y_i)}{N}
\end{align}
\tag{A}

となります。これを用いて、損失関数がL_2損失L(z,y) = (z-y)^2である場合の式2から式3を導出する過程を以下に示します。まず式2を書き換えます。


\begin{align}
\underset{z}{argmin} \mathbb{E}_y\{L(z, y)\} = \underset{z}{argmin} \displaystyle \frac{\sum_{i=1}^N (z-y_i)^2}{N}
\tag{2'} 
\end{align}

zの最小値を求めるためには、次のようにzについて微分したものをゼロとして解けばよいです。


\begin{align}
\displaystyle \frac{d}{dz} \frac{\sum_{i=1}^N (z-y_i)^2}{N} &= 0 \\
\displaystyle \frac{d}{dz} \sum_{i=1}^N (z^2-2zy_i+y_i^2) &= 0\\
\displaystyle \frac{d}{dz} \left( \sum_{i=1}^N z^2 - \sum_{i=1}^N 2zy_i + \sum_{i=1}^N y_i^2 \right) &= 0\\
\displaystyle \frac{d}{dz} \left( N \times z^2 - 2z \sum_{i=1}^N y_i + \sum_{i=1}^N y_i^2 \right) &= 0\\
\displaystyle \frac{d}{dz} N \times z^2 - \frac{d}{dz} 2z \sum_{i=1}^N y_i + \frac{d}{dz} \sum_{i=1}^N y_i^2  &= 0\\
\displaystyle N \times 2z - 2\sum_{i=1}^N y_i &= 0\\
\displaystyle z  &= \frac{\sum_{i=1}^N y_i}{N}\\
\tag{B} 
\end{align}

これで式3が示せました。

L1損失関数と中央値

L_1損失である絶対偏差の合計L(z,y) = |z  -  y|の場合は、観測値の中央値でその最適値が得られます。

同じように損失関数がL_1損失L(z,y) = |z  -  y|である場合の式2を変形してみます。


\begin{align}
\underset{z}{argmin} \mathbb{E}_y\{L(z, y)\} &= \underset{z}{argmin} \displaystyle \frac{\sum_{i=1}^N |z-y_i|}{N} \\
&= \underset{z}{argmin} \displaystyle \frac{\sum_{z > y_i}(z-y_i) + \sum_{z < y_i} (y_i-z)}{N}
\tag{2''} 
\end{align}


zの最小値を求めるために、zについて微分し、ゼロとなる方程式を解きます。

\begin{align}
\displaystyle \frac{d}{dz} \frac{\sum_{z > y_i}(z-y_i) + \sum_{z < y_i} (y_i-z)}{N} &= 0 \\
\displaystyle \frac{d}{dz} \sum_{z > y_i}(z-y_i) + \frac{d}{dz} \sum_{z < y_i} (y_i-z) &= 0 \\
\displaystyle \sum_{z > y_i}1 + \sum_{z < y_i} (-1) &= 0
\tag{C} 
\end{align}

ここでN個のy_iのうち、M個のy_iがzより小さいと仮定すると、次のようになります。


\begin{align}
\displaystyle \sum_{z > y_i}1 + \sum_{z < y_i} (-1) &= 0 \\
\displaystyle M \times 1 + (N-M) \times (-1) &= 0 \\
\displaystyle M - N + M &= 0 \\
\displaystyle 2M &= N \\
\displaystyle M &= \frac{N}{2} \\
\tag{C'} 
\end{align}
つまり、N個のy_iのうち、M=N/2個、つまり半分のy_iがzより小さいとき、この方程式は成り立ちます。
これは、zがy_iの中央値であることを意味します。

一般的なクラスの偏差最小化推定量は、M-estimatorとして知られています(Huber、1964)。統計的観点から、これらの共通損失関数を使用した要約推定(summary estimation)は、損失関数を負の対数尤度として解釈することにより、ML推定(最尤推定)と見なすことができます。

ML推定(最尤推定)とは、「与えられたデータから、データが従う確率分布の母数を、点推定する方法」です。

ここでは、与えられたデータ「室温の測定値の信頼できないセット(y_1, y_2, ...)」から、データが従う確率分布の母数(平均値)を、点推定しているので、確かにML推定(最尤推定)と見なすことができます。

だから何だ、という話ですが、続きます。

点推定手順の一般化

ニューラルネットワークリグレッサ(regressors)の訓練は、この点推定手順の一般化です。入力とターゲットのペアのセット(x_i,y_i)に対する、典型的なトレーニングタスクの形式を観察しましょう。ここで、次式のネットワーク関数f_\theta(x)は、\thetaによってパラメーター化されています。



\begin{align}
\underset{\theta}{argmin} \mathbb{E}_{(x,y)} \{ L(f_\theta(x),y) \}.
\tag{4}
\end{align}

式(4)はつまり、式(2)のzを一般化し、f_{\theta}(x)に置き換えたものです。

室温を例に、この式(4)は、入力xが外気温、f_{\theta}(x)を外気温xのときの室温、yを温度計での測定値と考えてみます。ここでパラメータθは例えば部屋の断熱性能であり、最小化問題により求める対象です。

確かに、入力データへの依存性を取り除き、単に学習済みスカラーを出力する単純なf_\thetaを使用すると、式(4)は式(2)になります。

例えるなら、断熱性θが凄すぎて、入力である外気温xに室温zが依存しないときの話ですね。入力xもパラメータθも式から無くなってしまいます。

逆に、完全なトレーニングタスクは、すべてのトレーニングサンプルにおいて、同じ最小化問題に分解されます。簡単な操作は、式(4)が次式と等価であることを示します。



\begin{align}
\underset{\theta}{argmin} \mathbb{E}_x \{\mathbb{E}_{y|x} \{ L(f_\theta(x),y) \}\}.
\tag{5}
\end{align}

式(4)と式(5)が等価であることを示す式変形は良く分からないです・・・。ただ、感覚的には、式(4)がx、yを両方に関する平均値を求めるもの、式(5)はあるxという条件のとき、yに関する平均値を求め、それをさらに全てのxについて平均するというもので、やっていることは同じですが、手順の違いを強調した表現でしょうか。

室温を例に具体的に考えると、

  • 式(4):様々なタイミングで外気温xと室温測定値yのペアを取得してその誤差の平均値を求める。
  • 式(5):ある外気温xにおいてばらつきのある測定値yを複数取得して誤差の平均を求める。それを様々な外気温xに関して行って平均する。

ということだと理解します。

ネットワークは、理論的には、入力サンプルごとに別々に点推定問題を解くことによってこの損失を最小化できます。したがって、根本的な損失の特性は、ニューラルネットワークのトレーニングによって継承されます。

式(4)を式(5)に変形することで、入力xとターゲットyをペアで考えなくても、入力xごとに別々に考えれば良いことが分かります。

気温の例では、様々なタイミングの外気温xと測定値yのペアを取得するより、ある外気温xのときにまとめて複数の測定値yを取得する方が手間がかからず大量のデータを取得できるでしょう。

そうやって入力サンプルごとに点推定を解いても「根本的な損失の特性」なるものはちゃんと得られる、というのがニューラルネットワークのトレーニングの能力である、と論文は述べているのだと思います(たぶん)。

式(1)の隠している微妙な点

有限数の入力とターゲットのペア(x_i,y_i)を用いた式1による、リグレッサーをトレーニングする通常のプロセスは、微妙な点を隠しています。すなわち、入力とターゲットの間の1:1マッピングであるとプロセスによって(偽って)ほのめかされていますが、正しくはマッピングは多値です。


例えば、すべての自然画像に対する超解像タスク(Ledig et al., 2017)では、低解像度画像xは多くの異なる高解像度画像yで説明できます、エッジとテクスチャの正確な位置と方向に関する知識は間引きによって失われているので。言い換えれば、p(y|x)は、低解像度xと一致する自然画像の非常に複雑な分布です。

ところどころ英語の意味が分かりませんが、具体例をイメージすると分かりやすいです。超解像タスクにおいて、低解像度画像xを超解像度画像yにしようとしても、xに対し正解のyは1つではなく複数あるということを言っており、リグレッサーはこれに目を瞑っています。

L2損失を用いて、低解像度画像と高解像度画像のトレーニングペアでニューラルネットワークのレグレッサーをトレーニングすると、ネットワークはすべてのもっともらしい説明(たとえば、異なる量だけシフトしたエッジ)の平均を出力することを学習し、結果として、ネットワークの推論には空間的なぼけが生じます。

そして複数ある高解像度画像yを用いてニューラルネットワークを学習することが(式5)で表されています。さらに、誤差関数をL2損失とするなら、得られる高解像度画像z=f_{\theta}(x)は、複数あるyの平均値であり、このことが(式3)で表されています。

複数ある高解像度画像yはそれぞれが高解像度でも、平均をとるとぼけ画像になってしまいます。とすると、超解像タスクでは結局ぼけ画像しか得られません。

このよく知られた傾向に対抗するために、かなりの量の研究が行われてきています。例えば、学習済み識別関数を損失として使用する(Ledig et al., 2017; Isola et al., 2017)ことよって。

  • [Ledig et al., 2017]: Isola, Phillip, Zhu, Jun-Yan, Zhou, Tinghui, and Efros, Alexei A. Image-to-image translation with conditional adversarial networks. In Proc. CVPR 2017, 2017.
  • [Isola et al., 2017]: Ledig, Christian, Theis, Lucas, Huszar, Ferenc, Caballero, Jose, Aitken, Andrew P., Tejani, Alykhan, Totz, Johannes, Wang, Zehan, and Shi, Wenzhe. Photo-realistic single image super-resolution using a generative adversarial network. In Proc. CVPR, pp. 105–114, 2017.

以上のことはもちろん一般によく知られていてたくさんの研究があります。(なんだか、解決しそうにない問題に思えるのですが・・・。)

L2損失のもたらす予期しない利益

我々の観察は、ある問題にとってはこの傾向が予期しない利益をもたらす、というものです。


些細で、そして一見役に立たないL2最小化の性質は、もし予想をターゲットと一致させる乱数でターゲットを置き換えても、予想に反して、見積もりは変わらない、というものです。

ここに書かれていることがNoise2Noiseにおいて、最も重要な原理なのだと思います。

これまで見てきた通り、例えば超解像タスクでは、低解像度画像xを超解像度変換する関数f_\thetaを求めたいのですが、低解像度画像xに対応する高解像度画像yは複数存在するため、L2最小化を誤差関数として機械学習を行うと、関数f_\thetaはyの平均値を出力するような学習が進んでしまい、結局ぼかし処理を生成してしまいます。

この性質にはがっかりなのですが、結局yの平均値を出力するような関数f_\thetaが得られると分かっているなら、その関数が得られる条件内でyを置き換えても良い。結局得られる関数は変わらないのだから。

これは直感的には( ゚Д゚)ハァ?なのですが、式を見ると確かにそのようです。

これは簡単に理解できます。すなわち、yの集合がどんな特定の分布から引き出されたとしても、式(3)は成り立ちます。


その結果、入力条件付き目標分布p(y|x)が同じ条件付き期待値を有する任意の分布に置き換えられても、式(5)の最適ネットワークパラメータθも変化しないままである。

入力低解像度画像xにおける目標の高解像度画像yをどう置き換えても、ぼかした結果を出力する超解像度処理fθが生成される、ということは変わりません。

結論

これは、原理的には、ニューラルネットワークのトレーニングターゲットを、ゼロ平均ノイズで破損させても、ニューラルネットワークが学習するものは変わらないことを意味しています。
これを式1からの破損入力と組み合わせると、経験的リスク(empirical risk)最小化タスクが残ります。



\begin{align}
\underset{\theta}{argmin} \sum_i L(f_{\theta}(\hat{x}_i),\hat{y}_i),
\tag{6}
\end{align}

この式は、式1のターゲットy_iを、ノイズ画像\hat{y}_iに置き換えた式です。
式1では問題を解くにはノイズのないデータが必要でしたが、式6では不要になってしまいました。

ここで、入力とターゲットの両方は、破損した分布(必ずしも同じ必要はない)から引き出されます。\mathbb{E} \{\hat{y}_i | \hat{x}_i\} = y_iのような、観測されていないクリーンなターゲットy_iが根底にあることを条件に。


無限のデータが与えられると、解は(1)の解と同じになります。


有限データの場合、分散はターゲット内の破損の平均分散をトレーニングサンプルの数で割ったものです(付録を参照)。

とはいえ、条件はあるようです。

興味深いことに、上記のどれも破損の尤度モデルにも、根本的なクリーン画像多様体(clean image manifold)に対する密度モデル(事前)(density model (prior) )にも依存していない。


つまり、明示的なp(noisy|clean)またはp(clean)は必要ありません。データがそれらに従って分散されている限り。

条件はあっても、クリーンなデータの分布や、その分布におけるノイズの分布といったものが必要ないことが強味なのだと思います。

ノイズのあるデータ事態の収集は容易ですが、そのデータからノイズの分布を求めるのは容易ではありません。本手法を使うと、それを求めることなく、データを収集することでノイズ低減処理ができるようになる、と。

多くの画像復元タスクにおいて、破損した入力データが期待しているのは、クリーンなターゲットであり、我々は復元を目指しています。


暗い場所での写真撮影はその一例です。この例では、長時間のノイズのない露光は、短時間の独立したノイズの多い露光の平均です。


これを念頭に置くと、上記は、高価または困難な可能性のある長時間露光なしに、ノイズのある画像ペアのみを与えられて、光子ノイズを除去することを学習できる可能性を示唆しています。

暗所撮影写真の例では、入手するのが難しいノイズ除去の学習には長時間露光画像が一見必要に思えますが、実は入手が容易な短時間露光画像からでもノイズ除去の学習ができます。

他の損失関数についても同様の観察が可能です。


例えば、L1損失はターゲットの中央値を回復します。これは、次のことを意味しています。すなわち、ニューラルネットワークは、重要な(上位50%まで)異常値の内容を含む画像を修復するように訓練することができ、やはりそのような破損した画像の対へのアクセスのみを必要とすることを意味します。

良く分からないですが・・・、L2損失は平均値でしたが、L1損失は中央値を求める処理になります。この場合も、例えばクリーンな画像をターゲットとしてクリーンな画像の中央値を求めるのも、ノイズ画像をターゲットとしてノイズ画像の中央値を求めるのみ、得られる中央値は変わらない、ということでしょうか。

次のセクションでは、これらの理論的能力も実際には効率的に実現可能であることを実証する、多種多様な例を紹介します。

以上、理論のお話でした。なんとなく分かったような分からないような・・・。

まとめ

"Noise2Noise: Learning Image Restoration without Clean Data"の"2. Theoretical Background(理論的背景)"を読みました。次は"3. Practical Experiments(実技実験)"を読みます。