PyTorchで画像分類(その1)

PyTorchで画像分類をやってみたので、何回かに分けて解説します。

題材としては画像分類のシンプルな問題であるkaggleの犬/猫の画像分類コンペを使います。

データセット(Dataset)の実装

このコンペでは画像ファイルが提供されますが、PyTorchで処理するためにはテンソル形式に変換する必要があります。また、ディープラーニングでは読み込んだ画像を単純にテンソル形式に変換しただけでは精度が出ないので、1つの画像を色々と変換してロバスト性(頑強性。外乱に対して安定しているか。外乱とは外部から系を安定状態からずらそうとする力)を高めます。具体的には、中心で切り抜き、リサイズ、回転、反転等の変換をします。

それらを一手に担うのがデータセット(Dataset)です。

ではデータセットを実装してみましょう。

画像ファイルを読み込んでテンソル形式に変換するクラスです。PyTorchで提供されるdata.Datasetを継承することで簡単に実装できます。

class DogVsCatDataset(data.Dataset):
    
    def __init__(self, file_list, transform=None, phase='train'):    
        self.file_list = file_list
        self.transform = transform
        self.phase = phase
        
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        
        # 画像パスを取得する
        img_path = self.file_list[idx]
        # 画像ファイルを読み込む
        img = Image.open(img_path)
        
        # 画像に対してリサイズ等の変換を行う
        img_transformed = self.transform(img, self.phase)
        
        # ラベルを取得する
        # 今回の問題ではファイル名にラベルが含まれているので、ファイル名を分解して取得する
        label = img_path.split('/')[-1].split('.')[0]
        if label == 'dog':
            label = 1
        elif label == 'cat':
            label = 0

        return img_transformed, label

扱う画像の種類が変わろうと、実装する内容はほとんど変わりません。

各メソッドは以下のような役割になっています。

  • __init__
    • インスタンス変数を設定。
  • __len__
    • データの件数を返す。
  • __getitem__
    • インデックス番号を渡したら、画像ファイルを読み込み、各種変換を施した上で、変換後の画像データ(テンソル形式)とラベルを返す。

しかし、Datasetクラスを見ても画像を変換するロジックが見当たりません。

はい、変換用のクラスは別途作成しないといけません。

class ImageTransform():
    
    def __init__(self, resize, mean, std):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(resize),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
        }
        
    def __call__(self, img, phase):
        return self.data_transform[phase](img)

こちらのクラスも扱う画像の種類が変わろうと、実装する内容はほとんど変わらず、どのような加工を加えたいのかで微調整する形となります。

各メソッドは以下のような役割になっています。

  • __init__
    • インスタンス変数を設定。
  • __call__
    • インスタンス変数に設定されている変換処理を実行。

ここで余談ですが、Pythonの仕様で__call__メソッドは特殊メソッドであり、インスタンスを関数のように呼び出すことができるようになります。イメージはこちらです。

t = ImageTransform(size, mean, std)
# インスタンス化

t(img, phase)
# ImageTransformの__call__が呼ばれる。

次回

データセットが作成できたので、画像ファイルを読み込むことが可能になりました。次回はディープラーニングの魂であるモデルの作成を解説したいと思います!