JavaScriptを有効にしてください

【PyTorch】「CUDA error: device-side assert triggered」 解決の手引き

 ·  ☕ 4 min read

はじめに

  • PyTorchにて, “RuntimeError: CUDA error: device-side assert triggered"というエラーに出くわすことがある
    • ネットに転がってるモデルで発生すると特に厄介である (自分が作った沼ではないので…)
    • またMAEでのマスク処理のような, テクニカルな処理を行う場合などにも頻発
    • 再現性が取れず, 出力されるエラー内容も二転三転. 一定確率で上記のエラーが発生する.
    • 今まではうまく行ってたのに急にエラーが頻発することなども多々あり
  • ということで, 金輪際このエラーに立ち往生しないよう, “CUDA error: device-side assert triggered"に終止符を打とう!
    • ネットに転がってる議論は入力のshapeが云々・loss関数が云々と具体的で狙い撃ち的すぎる
    • なので, より実践的な解決の手引きをメモ程度にまとめておく

TL;DR

  • エラーが発生するタイミングが何となくでも分かっていれば, 意図的にCUDAからCPUに切り替えてpdbでデバッグすれば良し
  • タイミングが全くわからない場合はCUDA_LAUNCH_BLOCKING=1を設定する

エラー解決の手引き

① まず前提として, 大抵はindexがおかしいことが多い

  • サイズよりも大きいindexを指定していると発生しがち
    • e.g. attn[:,idx]idxattn.shape[-1]を超えてない?
  • ただし, 案外エラーがtorch内部に根ざしていることも多いため注意が必要
  • そういう場合は後述の2つを参照すべし

② 難点はassertionがエラー発生地点で発動しないコト

  • つまり, pythonがエラー発生地点駅を寝過ごしているのである…!
  • こういうときはCUDA_LAUNCH_BLOCKING=1を設定すると, 発生地点でエラーを吐いてくれる
    • e.g. CUDA_LAUNCH_BLOCKING=1 python train.py --hogehoge
  • CUDA-guide 曰く, 上記のおまじないにより, CUDAの非同期処理を無効化(つまり同期処理に切り替える)ことができる
  • したがって, CUDA_LAUNCH_BLOCKINGによってpythonくんがエラー発生地点を見過ごさないようになってくれる

Programmers can globally disable asynchronicity of kernel launches for all CUDA applications running on a system by setting the CUDA_LAUNCH_BLOCKING environment variable to 1. This feature is provided for debugging purposes only and should not be used as a way to make production software run reliably.

引用: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html

③ (重要) エラーが発生するタイミングで, 意図的にCUDAからCPUに切り替えるとエラーが明確になることが多い

  • 例えば, 「このデータ(index)をモデルに流したときだけエラーが発生するなあ」とか
  • 「DataLoaderがtdqmで10%, 11%が過ぎた頃合いでいつもエラーが発生するなあ」など
    • こういう場合は, 意図的にCUDAからCPUに切り替えると良い. 下に例を示す.
    • 下の例はイテレーションが123回前後でエラーが頻発する場合の解法
    • こうすることで, CPU上で処理が走るため原因が見つけやすい
1
2
3
4
5
6
7
8
for iter, (x,y) in enumerate(dataloader):
    if iter > 123: # 123の前後でエラーが発生することが多い場合
        device = "cpu"
        model = model.cpu()
    else:
        device = "cuda"
        model = model.cuda()
    x, y = x.to(device), y.to(device)
具体例
  • 具体例を見るとわかりやすいので, 次の例を見てみよう
    • M2-Transformerの動作確認で実際に遭遇した例を示す.
    • 上のように意図的にcpuへ処理を切り替えると以下のように, 具体的なエラー箇所が確認できる👇
Epoch 0 - train:  10%|███████▋                    | 123/1230 [02:25<20:19,  2.03it/s, loss=8.28
Traceback (most recent call last):
  File "train.py", line 282, in <module>
    train_loss = train_xe(model, dataloader_train, optim, text_field)
  File "train.py", line 99, in train_xe
    out = model(detections, captions)
  File "/home/initial/.pyenv/versions/anaconda3-2020.07/envs/m2release/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/initial/workspace/meshed-memory-transformer/models/transformer/transformer.py", line 29, in forward
    dec_output = self.decoder(seq, enc_output, mask_enc)
  File "/home/initial/.pyenv/versions/anaconda3-2020.07/envs/m2release/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/initial/workspace/meshed-memory-transformer/models/transformer/decoders.py", line 95, in forward
    out = self.word_emb(input) + self.pos_emb(seq)
  File "/home/initial/.pyenv/versions/anaconda3-2020.07/envs/m2release/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/initial/.pyenv/versions/anaconda3-2020.07/envs/m2release/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 126, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/home/initial/.pyenv/versions/anaconda3-2020.07/envs/m2release/lib/python3.6/site-packages/torch/nn/functional.py", line 1852, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    
IndexError: index out of range in self

  • 次に, pdbを使って下から順にどこが原因の根本なのかを探っていく
  • 順に調べていくとout = self.word_emb(input) + self.pos_emb(seq)の行が怪しいとわかるので, word_embpos_embを精査してみる
  • self.pos_emb(seq)をpdb上で叩くとエラーを吐くので, ここでself.pos_embの内部における配列外参照が原因であることが決定する
  • 最終的に, 以下のようにseqがEmbeddingのサイズに収まっていないことが原因であることがわかった
1
2
(pdb) self.pos_emb -> Embedding(55, 512)
(pdb) seq.shape -> torch.Size([150, 61])

まとめ

  • エラーが発生するタイミングが何となくでも分かっていれば, 意図的にCUDAからCPUに切り替えてpdbでデバッグすれば良し
  • タイミングが全くわからない場合はCUDA_LAUNCH_BLOCKING=1を設定する
共有

YuWd (Yuiga Wada)
著者
YuWd (Yuiga Wada)
機械学習・競プロ・iOS・Web