PyTorchで画像分類(その4)

前回は学習に必要な損失関数と最適化アルゴリズムを作成したので、今回は実際に学習をします。

題材は前回までと同じkaggleの犬/猫の画像分類コンペを使います。

学習

早速ですが、実装します。

def train_model(net, dataloader_dict, criterion, optimizer, num_epoch):
    
    # ベストなネットワークの重みを保持する変数
    best_model_wts = copy.deepcopy(net.state_dict())
    best_acc = 0.0

    # GPUが使えるのであればGPUを有効化する
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    
    # (エポック)回分のループ
    for epoch in range(num_epoch):
        print('Epoch {}/{}'.format(epoch + 1, num_epoch))
        print('-'*20)
        
        for phase in ['train', 'val']:
            
            if phase == 'train':
                # 学習モード
                net.train()
            else:
                # 推論モード
                net.eval()
                
            epoch_loss = 0.0
            epoch_corrects = 0
            
            # 第1回で作成したDataLoaderを使って画像データを読み込む
            for inputs, labels in tqdm(dataloader_dict[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                # 勾配を初期化する
                optimizer.zero_grad()
                
                # 学習モードの場合のみ勾配の計算を可能にする
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(inputs)
                    _, preds = torch.max(outputs, 1)
                    # 第3回で作成した損失関数を使って損失を計算する
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        # 誤差を逆伝搬する
                        loss.backward()
                        # パラメータを更新する
                        optimizer.step()
                        
                    epoch_loss += loss.item() * inputs.size(0)
                    epoch_corrects += torch.sum(preds == labels.data)
                    
            # 1エポックでの損失を計算
            epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset)
            # 1エポックでの正解率を計算
            epoch_acc = epoch_corrects.double() / len(dataloader_dict[phase].dataset)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            # 推論モードでベストの正解率を出したモデルを保存する
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(net.state_dict())
                
    print('Best val acc: {:4f}'.format(best_acc))

    # 以後の推論等ではベストのモデルを使うため、ベストのモデルを戻り値とする
    net.load_state_dict(best_model_wts)
    return net

ソースコードのコメントに簡単な説明を入れました。

Kerasの場合、損失関数と最適化アルゴリズムをcompileメソッドに渡してfitメソッドを呼べば学習できてしまいますが、それに比べてPyTorchは記述量が多いです。が、処理の中身がブラックボックス化されておらず、カスタマイズの余地があるため、個人的にはPyTorchの方が好きです。(あくまでも個人の好みです)

学習用の関数の実装が終わったので、実際に学習させましょう。

num_epoch = 30
net = train_model(net, dataloader_dict, criterion, optimizer, num_epoch)

これでしばらく(GPUを使って数時間)待てば、netに学習済みのモデルがされます。

推論(Kaggleに提出するためのCSV作成)

学習が終わったので、学習で使用していない画像を使って推論してみましょう。(最終的にKaggleに提出するためのCSVを作成します)

id_list = []
pred_list = []

# 推論のため勾配の計算をしない
with torch.no_grad():
    for test_path in tqdm(test_list):
        img = Image.open(test_path)
        _id = int(test_path.split('/')[-1].split('.')[0])

        transform = ImageTransform(size, mean, std)
        # valフェーズで画像変換
        img = transform(img, phase='val')
        # [x,x,x,...] -> [[x,x,x,...]]に変換
        img = img.unsqueeze(0)
        img = img.to(device)

        # 推論モード
        net.eval()

        # 推論
        outputs = net(img)
        # ソフトマックス関数を使って確率を出力
        preds = F.softmax(outputs, dim=1)[:, 1].tolist()
        
        id_list.append(_id)
        pred_list.append(preds[0])
    
    
res = pd.DataFrame({
    'id': id_list,
    'label': pred_list
})

res.sort_values(by='id', inplace=True)
res.reset_index(drop=True, inplace=True)

# CSV出力
res.to_csv('submission.csv', index=False)

CSVファイルが成果物として出力されるので、これをKaggleにsubmitすればスコアが確認できます!

終わりに

全4回を通してPyTorchでの画像分類の流れを解説しましたが、精度を上げるためには数多のテクニックがありますので、また次の機会にアップできればと思います!