IT博客汇

    Using a repeated dataset correctly with timm’s data loader

    RobinDong, posted 2025-06-25 00:10:37

    For a metaformer experiment, I was trying to add the CIFAR100 dataset to the training script. Since CIFAR100 is small, I needed it to repeat multiple times within one epoch, so I added a new type of dataset:

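    from torch.utils.data import Dataset


    class RepeatDataset(Dataset):
        """Wraps a dataset so that one epoch iterates over it `repeats` times."""

        def __init__(self, dataset, repeats):
            self.dataset = dataset
            self.repeats = repeats
            self.length = len(dataset) * repeats

        def __getitem__(self, idx):
            # Indices past the end of the base dataset wrap around to its start.
            return self.dataset[idx % len(self.dataset)]

        def __len__(self):
            return self.length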
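To see what the wrapper does to an epoch, here is a minimal check that runs without PyTorch. It re-declares the same class on a plain `object` base (an assumption for portability; the wrapping logic does not depend on `torch.utils.data.Dataset`) and wraps a three-element list that stands in for CIFAR100.

```python
class RepeatDataset:
    """Same logic as the wrapper above, minus the torch Dataset base class."""

    def __init__(self, dataset, repeats):
        self.dataset = dataset
        self.repeats = repeats
        self.length = len(dataset) * repeats

    def __getitem__(self, idx):
        # Indices past the end of the base dataset wrap around to its start.
        return self.dataset[idx % len(self.dataset)]

    def __len__(self):
        return self.length


# Hypothetical toy data: a three-item list stands in for CIFAR100.
base = ["a", "b", "c"]
repeated = RepeatDataset(base, repeats=2)

print(len(repeated))                                # 6
print([repeated[i] for i in range(len(repeated))])  # ['a', 'b', 'c', 'a', 'b', 'c']
```

With a shuffling sampler that draws indices from `range(len(dataset))` without replacement, every one of the 50,000 CIFAR100 training images is then visited `repeats` times per epoch.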

    But training then failed with this error:

    Traceback (most recent call last):                                                                                    
      File "/home/robin/code/metaformer/train.py", line 970, in <module>                                                  
        main()                                                                                                            
      File "/home/robin/code/metaformer/train.py", line 732, in main                                                      
        train_metrics = train_one_epoch(                       
                        ^^^^^^^^^^^^^^^^                                                                                  
      File "/home/robin/code/metaformer/train.py", line 798, in train_one_epoch                                           
        for batch_idx, (input, target) in enumerate(loader):                                                              
                                          ^^^^^^^^^^^^^^^^^                                                               
      File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/timm/data/loader.py", line 131, in __iter__                                                                                                                     
        for next_input, next_target in self.loader:                                                                       
                                       ^^^^^^^^^^^                                                                        
      File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 733, in __next__                                                                                                          
        data = self._next_data()                                                                                                                                                                                                                
               ^^^^^^^^^^^^^^^^^                                                                                          
      File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1515, in _next_data                                                                                                       
        return self._process_data(data, worker_id)                                                                        
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                        
      File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1550, in _process_data                                                                                                    
        data.reraise()                                         
      File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/_utils.py", line 750, in reraise                                                                                                                          
        raise exception                                        
    AttributeError: Caught AttributeError in DataLoader worker process 0.                                                 
    Original Traceback (most recent call last):                                                                           
      File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop                                                                                                   
        data = fetcher.fetch(index)  # type: ignore[possibly-undefined]                                                   
               ^^^^^^^^^^^^^^^^^^^^                            
      File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch                                                                                                            
        return self.collate_fn(data)                           
               ^^^^^^^^^^^^^^^^^^^^^                           
      File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/timm/data/mixup.py", line 305, in __call__                                                                                                                      
        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)                                         
                                           ^^^^^^^^^^^^^^^^^                                                              
    AttributeError: 'Image' object has no attribute 'shape'. Did you mean: 'save'?                         

    It took me quite a long time to solve this. The key is in the implementation of “timm.data.create_loader”: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/loader.py#L291. There, it assigns a new value to “dataset.transform”, and the dataset in “timm.data.dataset” (https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/dataset.py#L66-L67) checks for and applies this newly assigned “transform” when loading each sample:

    ...
            if self.transform is not None:
                img = self.transform(img)     
    ...

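    Because my own RepeatDataset class does nothing special with the “dataset.transform = create_transform()” assignment, the new transform is simply stored as an attribute on the wrapper and never reaches the wrapped CIFAR100 dataset. The wrapped dataset therefore never applies it and returns raw PIL images, which is why timm’s collate function fails when it reads “.shape”.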
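The failure can be reproduced without timm or PyTorch. In this sketch, `ToyDataset` is a hypothetical stand-in for CIFAR100 that mimics the `if self.transform is not None` check quoted above, and `assign_transform` is a hypothetical stand-in for the `dataset.transform = ...` line in `create_loader`; neither name is part of timm. `str.upper` plays the role of the image transform.

```python
class ToyDataset:
    """Hypothetical stand-in for CIFAR100: applies self.transform if one is set."""

    def __init__(self, items, transform=None):
        self.items = items
        self.transform = transform

    def __getitem__(self, idx):
        item = self.items[idx]
        if self.transform is not None:
            item = self.transform(item)
        return item

    def __len__(self):
        return len(self.items)


class RepeatDataset:
    """The original wrapper (torch base class omitted); it does not forward `transform`."""

    def __init__(self, dataset, repeats):
        self.dataset = dataset
        self.repeats = repeats
        self.length = len(dataset) * repeats

    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]

    def __len__(self):
        return self.length


def assign_transform(dataset, transform):
    # Hypothetical stand-in for what create_loader does to the dataset it receives.
    dataset.transform = transform


wrapped = RepeatDataset(ToyDataset(["img0", "img1"]), repeats=3)
assign_transform(wrapped, str.upper)

print(wrapped.transform is str.upper)  # True: stored on the wrapper...
print(wrapped.dataset.transform)       # None: ...but never reaches the inner dataset
print(wrapped[3])                      # img1 (untransformed)
```

In the real run, the untransformed sample is a PIL `Image`, which is what the collate code in timm's `mixup.py` trips over when it reads `.shape`.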

    The fix came from ChatGPT, and I think it’s pretty good: make “transform” a property that forwards both reads and writes to the wrapped dataset:

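    from torch.utils.data import Dataset


    class RepeatDataset(Dataset):
        """Wraps a dataset so that one epoch iterates over it `repeats` times."""

        def __init__(self, dataset, repeats):
            self.dataset = dataset
            self.repeats = repeats
            self.length = len(dataset) * repeats

        # Forward `transform` to the wrapped dataset, so that the
        # `dataset.transform = create_transform(...)` assignment in
        # timm's create_loader reaches the dataset that actually applies it.
        @property
        def transform(self):
            return self.dataset.transform

        @transform.setter
        def transform(self, value):
            self.dataset.transform = value

        def __getitem__(self, idx):
            return self.dataset[idx % len(self.dataset)]

        def __len__(self):
            return self.length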
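A dependency-free check that the property really forwards the assignment. `ToyDataset` (a hypothetical stand-in for CIFAR100 that applies `self.transform` when it is set) and `assign_transform` (a hypothetical stand-in for the `dataset.transform = ...` line in `create_loader`) are not timm names, and the torch base class is dropped so the snippet runs on its own.

```python
class ToyDataset:
    """Hypothetical stand-in for CIFAR100: applies self.transform if one is set."""

    def __init__(self, items, transform=None):
        self.items = items
        self.transform = transform

    def __getitem__(self, idx):
        item = self.items[idx]
        if self.transform is not None:
            item = self.transform(item)
        return item

    def __len__(self):
        return len(self.items)


class RepeatDataset:
    """Fixed wrapper: reads and writes of `transform` go to the wrapped dataset."""

    def __init__(self, dataset, repeats):
        self.dataset = dataset
        self.repeats = repeats
        self.length = len(dataset) * repeats

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, value):
        self.dataset.transform = value

    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]

    def __len__(self):
        return self.length


def assign_transform(dataset, transform):
    # Hypothetical stand-in for the `dataset.transform = ...` line in create_loader.
    dataset.transform = transform


wrapped = RepeatDataset(ToyDataset(["img0", "img1"]), repeats=3)
assign_transform(wrapped, str.upper)

print(wrapped.dataset.transform is str.upper)  # True: forwarded to the inner dataset
print(wrapped[3])                              # IMG1
```

Because `__init__` never assigns `self.transform`, the wrapper holds no copy of its own: every read and write goes to the inner dataset, so the two cannot drift apart.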

