JavaScriptを有効にしてください

【論文メモ】Deep Learning without Shortcuts: Shaping the Kernel with Tailored Rectifiers

 ·  ☕ 8 min read

はじめに

  • ICLR22 [paper]
    • 深層学習において, 残差接続は不可欠な存在となりつつある
      • 残差接続により, より深い層数のNNを実現できるようになった
    • 残差接続に対する解釈の矛盾
      • 昨今の研究により残差接続は比較的浅い層をアンサンブルするような効果があるとの見方が強まっている
      • しかし, 「深層」学習という名が体を表す通り, 一般には「層を増やす」ことがモデルの表現力を高めていると言われており, ここに残差接続に対する解釈の矛盾が存在する
    • また残差接続は推論時においてメモリを圧迫しているとの見方も存在する
      • 残差接続が結合されるまで, 入力を保持する必要があるため, 一つのスキップでメモリを倍使う (下図参照)
      • 例えば, 残差接続はResNet-50における特徴量の40%もメモリを使用している
  • したがって, 残差接続の再考が必要であり, 残差接続を用いず層を増やす手法としてTATを提案
    • NNをカーネル関数へ近似し, Q/C mapsを用いてNNの挙動を理論解析する


引用: RepVGG: Making VGG-style ConvNets Great Again



カーネルの近似

  • 活性化関数を ϕ()としたとき, 全結合のネットワーク fの各層における出力は以下のように書ける.

xl+1=ϕ(Wlxl+bl)Rdl+1

  • ただし, 重みは WliidN(0,1/dl)で初期化され, バイアス bl0で初期化されるとする.
  • このとき, fθl(x):=x として, fθl:RkRdlのカーネル関数 κfl(x1,x2)を以下のように定義すると,

κfl(x1,x2)=1dlfθl(x1)fθl(x2)

  • ネットワーク fの層の幅を無限大に大きくしたときに, カーネル κf1(x1,x2)は以下のような κ~f1(Σx1,x2)によって近似できることが知られている. (上のカーネルは fによって直接記述されているが, 近似されたカーネルは活性化関数 ϕによって書き下されている点に留意されたい)

κ~f1(Σx1,x20)=EzN(0,Σx1,x20)[ϕ(z)ϕ(z)]=:Σ1

Σx1,x20=1d0[x1x1x2x1x1x2x2x2]

  • 各層ごとの ΣlΣl+1の間にも以下のような漸化式が成り立ち, 各層のカーネルを計算することができる.

Σl+1=EzN(0,Σl)[ϕ(z)ϕ(z)]

Σl=[κ~fl(x1,x1)κ~fl(x1,x2)κ~fl(x1,x2)κ~fl(x2,x2)]

  • (幅を無限大に飛ばす→NTKが想起されるが, NTKとは若干異なる)

Q/C maps

  • Q map

    • Σl+1の対角成分 qil+1Σlの対角成分である qilにのみ依存するので,

    qil+1=EzN(0,qil)[ϕ(z)2]=EzN(0,1)[ϕ(qilz)2]

    • ただし, qi0=|xi|2d0
    • このとき, qil+1=Q(qil)であるような Qlocal Q mapと呼ぶ
    • また L層のネットワーク f全体において, Qf(q)=QQQQL times(q)global Q mapと呼ぶ
    • カーネル K(x,y)再生核ヒルベルト空間において x,y間の類似度を表すので, 対角成分 q及び Qは入力の振幅を表す
  • C map

    • 一方で Σl+1の非対角成分 cl+1については, cl+1=C(cl,q1,q2)と全成分に依存するので, 少し計算が厄介 
    • Clocal C mapと呼び, 一般に, 以下のように計算される. (説明略)

    cl+1=C(cl,q1l,q2l)=E[z1 z2]N(0,Σl)[ϕ(z1)ϕ(z2)]Q(q1l)Q(q2l)

    Σl=[q1lq1lq2lclq1lq2lclq2l]

    • ただし, c0=x1x2/d0

    • 非対角成分 cκ~fl(x1,x2)であるから, ある層の異なるノードにおける入力 x1,x2の類似度を計算することになる

      • つまり, location-wiseな入力の類似度を計算することになる
    • こちらも同様, L層のネットワーク f全体において, Cf(c)=CCCCL times(c)global C mapと呼ぶ

      • この関数は c0=x1x2/d0=類似度を入力として, 類似度 cfの出力ごとの類似度とどう関係があるかをマッピングする
      • Cf(c)は入力の類似度に対しどれだけ出力の類似度を保持できているかを表すので, なるだけ Cf(c)は非線形である方が良い
      • Cが線形であればあるほど活性化関数も線形に近づく (14.3)
        • 線形だと, 類似度をそのまま出力してる = 線形
      • 逆に Cf(c)が一様に1に近づけば近づくほど, 入力間の類似度を正しく fが測れないため, ネットワークの出力から入力間の相対的な距離を推測するのが困難であることになり, 勾配による学習が進まなくなる証拠となる
      • 例えば ReLUを使った1層のネットワークの場合↓


  • これが何層にも連なると, C mapの値は1へと収束し, 単純に層を増やすだけでは学習が困難になる傍証が得られる↓



引用: Rapid training of deep neural networks without skip connections or normalization layers using Deep Kernel Shaping

Tailored Activation Transformation for Leakly ReLU

  • Cf(c)について望ましい状態
    • Cf(0)=0 つまり, 全く類似してないサンプルは出力の類似度も0であってほしい
    • Cf(1)=1つまり, なるだけ Cf(x)=1に接近するような平坦な形は望ましくない

  • Leakly ReLU (LReLu) ϕα(x)について
    ϕα(x)=max{x,0}+αmin{x,0},

ϕ~α(x)=21+α2ϕα(x)

  • という活性化関数を定義すると,

Q(q)=q,C(c)=c+(1α)2π(1+α2)(1c2ccos1(c))

  • が成り立つ. ( αについては後述の方法で求める)

  • Q(q)=qについて

    • 各層の入力に対して, 摂動に強くなる方向へ制約がかかるため, カーネルの近似誤差が小さくなる
    • ここについてはあまりよくわかっていない
  • C(c)について

    • Cf(c)について望ましい状態について再考すると
      • Cf(1)=1→ 満足可能
      • Cf(0)=0 → これを満足するには α=1とする必要があり, これだと Cf(c)=cと線形になってしまうのでよろしくない
        • そこで, Cf(0)=ηとして, Cf(c)の線形度合いを調整できるようにする
  • 以上より,
    Q(q)=q,Cf(1)=1,Cf(0)=η

  • を満たすことができるように αを決定すべき

    • そのようなLReLUのことをTReLUと呼び, 活性化関数をこのように変化させる手法としてTailored Activation Transformation(TAT)を提案

    • また, このような ϕ~α(x)を使ったネットワークは, 活性化関数を 2max{x,0}として残差接続の重みを 21+α2としたResNetと同等であることが証明できる

TReLU計算アルゴリズム

  • 目標: Cf(0)=ηを満たすような αを見つける

  • このとき, Cf()ではなく, fの全てのサブネットワーク gについて μf0(α)を定義し, μf0(α)について最適化を行う
    μf0(α)=maxg:gfCg(0)

  • α,C(c)の性質について

  • C(c)は以下のように cのみに依存して簡単に計算できる.
    C(c)=c+(1α)2π(1+α2)(1c2ccos1(c))=:A

  • このとき,
    dCdα=2A(α21)π(α2+1)2<0(α(1,1))

  • より, C(c)αについて単調減少なので, 二分探索ができる

アルゴリズム
  1. C(c)は以下のように cのみに依存して簡単に計算できるので, まず全てのサブネットワーク gから μf0(α)を求める
    C(c)=c+(1α)2π(1+α2)(1c2ccos1(c))

  2. 次に, α(1,1) について, μf0(α)=ηを満たす αを二分探索で求める

  3. ϕ~α(x)=21+α2ϕα(x)を活性化関数としてモデルに適応する


一般化アルゴリズム

  • LReLUだけでなく, 一般に滑らかな活性化関数であれば同様の手法を用いることができる


評価



ロジックの整理 & misc

  • FFNにおけるカーネルを用意する→二次のGram行列の漸化式でカーネルの近似解が求まる→対角成分と非対角成分をそれぞれQ/C mapsと呼び, 入力の振幅と類似度を図る指標となる→ネットワーク全体における global Q/C mapsが理論解析のための道具となる

  • 理想的なQ/C mapsの値が存在→maximal slope functionでサブネット全体を最適化

  • C mapを最大化するサブネットワークを探す =:μ

  • μαについて単調に減少するので, α二分探索

  • αを活性化関数にapply

  • 論文中に αの範囲は明記されていなかったが, 下のコードを見る限り, α[1,1]っぽい (supprimentalから引用)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
def binary_search(fn, target, input_=0.0, min_=-1.0, max_=1.0, tol=1e-6):
  value = fn(input_)

  if np.abs(value - target) < tol:
    return input_

  if value < target:
    new_input = 0.5 * (input_ + min_)
    max_ = input_
  elif value > target:
    if np.isinf(max_):
      new_input = input_ * 2
    else:
      new_input = 0.5 * (input_ + max_)
    min_ = input_

  return binary_search(fn, target, new_input, min_, max_, tol=tol)
共有

YuWd (Yuiga Wada)
著者
YuWd (Yuiga Wada)
機械学習・競プロ・iOS・Web