ディープラーニングの精度を上げるテクニックとしてTTAというものがあります。
TTAは有用なテクニックとして知られていますが、使い方によっては思ったほど精度が引き上がらないということも起こりうるようです。
そのあたりについて触れている論文があったので、今回紹介したいと思います。
本記事の元ネタとなる論文はこちらです。
TTAとは、、、
そもそもTTA(Test-Time Augmentation)とは何かと言うと、ディープラーニングの学習時に行っているAugmentation(データに様々な加工を加えてデータ量を増やし精度を上げるテクニック)を推論時にも行い、精度を引き上げるテクニックです。
TTAの効果を引き上げるには、、、
論文ではTTAは基本的に有効だけど、取り扱うデータの量や特徴、ネットワークモデルによって挙動が変わるので、それらを意識したほうがベターということが書かれています。
論文を読んで、気になったポイントを列挙します。
- TTAの精度を引き上げるためには、不正解のものが正解になったものだけでなく、正解のものが不正解になってしまったものも意識しないといけない。
- Standard TTAは反転、切り取り、拡大を組み合わせて30パターンの画像を作成して推論を行う。Expand TTAはPILライブラリを使ったバイナリ変換等を行ったりして128パターンの画像を作成して推論を行う。
- この論文ではImageNetとFlowers-102の2種類のデータセットでのTTAの挙動を調べているが、Flowers-102の方がTTAの効果は少ない。Flower-102の方が画像の構図が似ているのでTTAの恩恵は受けにくいためではないかと考えられる。
- モデルによってTTAの効果は変わる。複雑なモデルほどTTAの効果は少ない。複雑なモデルはTTAが与えるデータの揺らぎを吸収してしまうためと考えられる。モデルがどの要素に対して揺らぎを吸収するのか(不変性)が大事。
- データ量が少ないほどTTAの効果は大きい。
- ラベルの定義の仕方によってTTAから受ける影響は変わる。ImageNetとFlowers-102では前者の方がTTAから受ける影響は大きい。ImageNetでは、ラベルの中にTVと画面のような階層的なパターンがあったり、1枚の画像の中に複数のオブジェクトが含まれていたり、ラベルの中に犬種のように似たようなものがあったりするため。特に3番目については大きさの違いでしか見分けがつかない場合にTTAで拡大してしまうと誤った判定を引き起こしてしまう。
- 単純なネットワークモデルでもTTAを行うことで複雑なネットワークモデルに勝つことができる。
これらを踏まえた上でTTAの効果を引き出すパターンは以下のようになると考えられます。
- 取り扱う画像データの構図にばらつきがある。
- ネットワークモデルがシンプル。
- データ量が少ない。
- ラベル同士が似通っていない。
まとめ
TTAは推論時のデータ量を増やすため、少なからず計算時間に影響を及ぼします。そのため、少しでも処理時間対効果を引き上げるために、これらの特性を意識した実装が大切だと思います。