
    Hanging of PyTorch’s data loader

    Posted by RobinDong on 2023-05-05 01:25:34

    Long story short: I am trying to build a Siamese network for audio classification. With 50% probability, “dataset.py” picks a pair of audio clips from the same category but from different files; the other 50% of the time, it picks a pair from different categories. But when evaluation starts, it hangs after fetching a few batches. Interrupting it gives this trace:

    Traceback (most recent call last):                                                                                                                                                                                                        
      File "/home/robin/song/birdclef/old_train.py", line 395, in <module>                                                
        train(args, train_loader, eval_loader)                                                                                                                                                                                                  
      File "/home/robin/song/birdclef/old_train.py", line 280, in train                                                   
        accuracy = evaluate(args, net, eval_loader)                                                                                                                                                                                             
      File "/home/robin/song/birdclef/old_train.py", line 91, in evaluate                                                 
        sounds1, sounds2, type_ids = next(batch_iterator)                                                                 
      File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
        data = self._next_data()                                                                                                                                                                                                                
      File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
        idx, data = self._get_data()                                                                                      
      File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1285, in _get_data                                                                                                              
        success, data = self._try_get_data()                                                                                                                                                                                                    
      File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
        data = self._data_queue.get(timeout=timeout)                                                                      
      File "/home/robin/miniconda3/envs/bird/lib/python3.10/queue.py", line 180, in get                                   
        self.not_empty.wait(remaining)                                                                                    
      File "/home/robin/miniconda3/envs/bird/lib/python3.10/threading.py", line 324, in wait                              
        gotit = waiter.acquire(True, timeout)                                                                                                                                                                                                   
    KeyboardInterrupt 

    As usual, I started by suspecting PyTorch. Is the version (2.0) so new that it contains some bugs? I quickly rejected that idea: if PyTorch were the problem, why didn’t the same hang occur when I was not using the Siamese network?

    Then I found this issue on the PyTorch GitHub page. It pointed me to the clue: the new code in “dataset.py”. Then I noticed the problem in my code:

                arr = self.cat_map[ebird_code]
                pair_wav_name = np.random.choice(arr)
                while pair_wav_name == wav_name:
                    pair_wav_name = np.random.choice(arr)
                pair_sound = self.get_sound(pair_wav_name, ebird_code)

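    If a category has only one file, this loop never terminates: np.random.choice(arr) always returns wav_name, so the loop condition stays true. The data loader never receives the batch, and the main process keeps waiting on the data queue, as the trace shows. This is the reason for the hang.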
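To check whether a dataset is affected, the categories with a single file can be listed up front. The sketch below is only an illustration: it assumes cat_map is a plain dict from category code to a list of file names, and both the helper name single_file_categories and the sample data are hypothetical, not taken from the original “dataset.py”.

```python
def single_file_categories(cat_map):
    """Return the categories that have fewer than two distinct file names."""
    return sorted(
        code for code, names in cat_map.items() if len(set(names)) < 2
    )


# Hypothetical data, for illustration only.
cat_map = {
    "species_a": ["a1.wav", "a2.wav"],
    "species_b": ["b1.wav"],
}

# Any category listed here would make the while loop spin forever.
print(single_file_categories(cat_map))
```

Running a check like this once, before training, turns a silent hang into an explicit list of categories to handle.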

    The solution is simple:

                arr = self.cat_map[ebird_code]
                if len(arr) > 1:
                    # Re-draw until we get a different file from the same category
                    pair_wav_name = np.random.choice(arr)
                    while pair_wav_name == wav_name:
                        pair_wav_name = np.random.choice(arr)
                else:
                    # Only one file in this category: pair it with itself
                    pair_wav_name = wav_name
                pair_sound = self.get_sound(pair_wav_name, ebird_code)
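
A variant that avoids the retry loop altogether is to filter out the current file first and then draw once from what is left. This is a minimal sketch, not the code from the original “dataset.py”: the function name pick_pair_name is hypothetical, and it uses the standard library's random module instead of np.random so that the example is self-contained.

```python
import random


def pick_pair_name(names, wav_name, rng=random):
    """Pick a file name from `names` that differs from `wav_name`.

    Falls back to `wav_name` itself when the category has no other file,
    which mirrors the else branch of the fix above.
    """
    candidates = [name for name in names if name != wav_name]
    if not candidates:
        return wav_name
    return rng.choice(candidates)


# Hypothetical file names, for illustration only.
print(pick_pair_name(["a1.wav"], "a1.wav"))
print(pick_pair_name(["a1.wav", "a2.wav"], "a1.wav"))
```

Because the candidate list is built before drawing, the function always returns after a single draw, even if the list contains only copies of wav_name.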

