ココアのお勉強ブログ

技術職の一般人です。趣味でコード書いたりソフト触ったり。

AUGMIX: A SIMPLE DATA PROCESSING METHOD TO IMPROVE ROBUSTNESS AND UNCERTAINTY 自分用にメモ

論文リンク
AugMix: A Simple Method to Improve Robustness and Uncertainty under Data Shift | OpenReview
公式実装(Pytorch)
GitHub - google-research/augmix: AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty

どんなもの?

Datasetのshiftの下で、堅牢性と不確実性を改善する、シンプルだけど効果的な方法を提案。
下のgif(公式実装より引用)を見ると概要が良く分かる。

先行研究と比べてどこがすごい?

Hendrycks&Dietterichによって提案されたベンチマークで目に見えない崩壊?に対する堅牢性を大幅に改善したらしい。
CIFAR-10-C、CIFAR-100-Cで前のSOTAと Clean Error(壊れてないtestデータに対する通常の分類Error)の差を半分以上縮めたらしい。
よく分からないけどとりあえずError率下がったってことかな...?
https://twitter.com/balajiln/status/1202764043733229568 より)

技術や手法のキモはどこ?

連続してシフト処理(多分色空間的なシフトも含む?)みたいなことをすると元の画像から遠く離れた(異なる)画像を生成し、非現実的な画像になってしまうことがある(下図)。
でもこの違いは、ステップ数を調整することでバランスを取れる。
複数の拡張画像を生成してmixすることで多様性を高めることができる。

実際のアルゴリズムと画像に対する処理の流れが以下のようになる。

まずAugmentAndMix関数(2:~14:)について見ていきます。

Step1

画像のmix比の重み w_k を乱数で決める。

公式実装だと以下のようにしてる。

# alpha=1, width=3
ws = np.float32(
     np.random.dirichlet([alpha] * width))
# 出力例array([0.15282747, 0.6870145 , 0.160158  ], dtype=float32)

widthは3で固定だった。

Step2

i(=[1,2,3])それぞれについてoperationsを3つずつ(op1,op2,op3)ランダムで決める。
ここで言うoperationsとはtranslate_x(x軸方向に画像をずらす)、rotate(回転処理)のような処理を指す。
ちなみに公式実装ではoperationsは以下のリストからランダムに選ばれるようになっていた。

augmentations = [
    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
    translate_x, translate_y
]

op_1,op_2,op_3 を決めたから全て適用するというわけではなく、
1. op_1だけ適用するか
2. op_1を適用した後さらにop_2を適用するか
3. op_1を適用しop_2を適用した後さらに op_3も適用するか

の3択 [op1,op12,op123]からさらにランダムで1つ決めます。

f:id:hotcocoastudy:20191207144419p:plain

それらk(=3)個の処理画像をStep1で求めた重みを使ってアルファ合成します。

これにより得た画像を x_{aug} とします。

Step3

ベータ分布 Beta(β,β)から乱数 mを得ます。

mを合成比として使用し元画像[tex:x{orig}]と[tex:x{aug}]をアルファ合成します。

以上により得た最終的な画像を  x_{augmix} とします。
これを返すことでAugmentAndMix関数ができます。

実際には元画像 [tex:x{orig}] に対して2回のAugMix処理をして [tex:x{augmix1}] と x_{augmix2}の2つの画像を得た後、

通常のLoss(SoftmaxCrossEntropyとか)に加えて以下のようなJensen-Shannon divergenceを加えたLossで評価する。

以上がAugMixの一連の流れになります。

※Jensen-Shannon divergenceについては以下の記事程度の理解です。
http://yusuke-ujitoko.hatenablog.com/entry/2017/05/07/200022

補足 Jensen-Shannon Consistency Lossは、多様な入力範囲に対して、安定性、一貫性をモデルに促すらしい。よく分からんけど。

どうやって有効だと検証した?

上で説明したけどもう一度、 Hendrycks&Dietterichによって提案されたベンチマークで目に見えない崩壊?に対する堅牢性を大幅に改善したらしい。
CIFAR-10-C、CIFAR-100-Cで前のSOTAと Clean Error(壊れてないtestデータに対する通常の分類Error)の差を半分以上縮めたらしい。

https://twitter.com/balajiln/status/1202764043733229568 より)

AugMixは、予測不確実性の推定も大幅に改善した。 AugMix + Deep Ensemblesは、増加するデータシフトの下でImageNet-CでSOTAキャリブレーションを達成した(Ovadia et al。2019)。 https://twitter.com/balajiln/status/1202765799636627457 より)

議論はある?

次に読むべき論文は?

Adversarial Examples Improve Image Recognition

分かりづらい解説記事になってしまってすみません...
何か間違いとかあったら指摘してもらえると嬉しいです