[Bug] Cannot register custom collate_fn function #1880

Open
ShangWeiKuo opened this issue Mar 6, 2024 · 0 comments
@ShangWeiKuo
Branch

main branch (mmpretrain version)

Describe the bug

This is my custom config, configs/_base_/datasets/PN_bs64_pil.py. After running mim train mmpretrain ..., I got the runtime error shown in the screenshot below.

However, I cannot tell from the error message what went wrong on the line with @FUNCTIONS.register_module().

How should I deal with this problem?
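To see what the @FUNCTIONS.register_module() line is supposed to achieve, the name lookup behind `collate_fn=dict(type='custom_collate_fn')` can be modelled with a toy registry. This is a minimal sketch, not mmengine's actual `Registry`: `MiniRegistry` and `build_collate_fn` are hypothetical names. It shows that a string `type` only resolves after the decorated definition has actually been executed (normally by importing the module that contains it) into the same registry the builder queries.

```python
from typing import Callable, Dict, Optional


class MiniRegistry:
    """Toy string-to-callable registry (hypothetical, not mmengine.Registry)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Callable] = {}

    def register_module(self, name: Optional[str] = None) -> Callable[[Callable], Callable]:
        # Returns a decorator. Registration happens when the decorated
        # definition is executed, i.e. when its module is imported.
        def decorator(fn: Callable) -> Callable:
            key = name or fn.__name__
            if key in self._modules:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._modules[key] = fn
            return fn
        return decorator

    def get(self, key: str) -> Optional[Callable]:
        return self._modules.get(key)


def build_collate_fn(registry: MiniRegistry, cfg: dict) -> Callable:
    # Resolve a config like dict(type='custom_collate_fn'): a string type is
    # looked up by name, a callable is used as-is. Extra keys are ignored here.
    cfg = dict(cfg)
    fn_type = cfg.pop('type')
    fn = registry.get(fn_type) if isinstance(fn_type, str) else fn_type
    if fn is None:
        raise KeyError(f'{fn_type!r} is not registered in {registry.name}')
    return fn


FUNCTIONS = MiniRegistry('function')

# A lookup before the decorated definition has run fails...
try:
    build_collate_fn(FUNCTIONS, dict(type='custom_collate_fn'))
except KeyError as err:
    print('lookup failed:', err)


# ...and succeeds once the definition (and its decorator) has executed.
@FUNCTIONS.register_module()
def custom_collate_fn(batch):
    return list(batch)


collate = build_collate_fn(FUNCTIONS, dict(type='custom_collate_fn'))
```

In this model, a string lookup works only if the defining module is imported before the dataloader is built and registers into the registry the builder searches; passing the function object itself skips the name lookup entirely.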

from mmengine.dataset import DefaultSampler

from mmpretrain.datasets import (CenterCrop, CustomDataset, LoadImageFromFile,
                                 PackInputs, RandomFlip, RandomResizedCrop,
                                 ResizeEdge)
from mmpretrain.evaluation import Accuracy
from mmengine.registry import FUNCTIONS
import torch

@FUNCTIONS.register_module()
def custom_collate_fn(data):
    # Sort samples by the length of their first row, shortest first.
    data.sort(key=lambda x: len(x[0][0]))
    # Crop every sample to the shortest length in the batch.
    min_len = len(data[0][0][0])
    data_list = [sample[0][:, :min_len] for sample in data]
    label_list = [sample[1] for sample in data]
    # torch.tensor() cannot build a batch from a list of multi-element
    # tensors; convert each sample and stack along a new batch dimension.
    data_tensor = torch.stack(
        [torch.as_tensor(x, dtype=torch.float32) for x in data_list])
    label_tensor = torch.tensor(label_list, dtype=torch.float32)
    return data_tensor, label_tensor

# dataset settings
dataset_type = CustomDataset
data_preprocessor = dict(
    num_classes=2,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=RandomFlip, prob=0.5, direction='horizontal'),
    dict(type=PackInputs),
]

val_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=PackInputs),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=4,
    collate_fn=dict(type='custom_collate_fn'),
    dataset=dict(
        type=dataset_type,
        data_root=r'\path\to\train',
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = dict(
    batch_size=16,
    num_workers=4,
    collate_fn=dict(type='custom_collate_fn'),
    dataset=dict(
        type=dataset_type,
        data_root=r'\path\to\val',
        pipeline=val_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, ))  # top-5 needs at least 5 classes; num_classes=2 here

test_dataloader = dict(
    batch_size=16,
    num_workers=4,
    collate_fn=dict(type='custom_collate_fn'),
    dataset=dict(
        type=dataset_type,
        data_root=r'\path\to\test',
        pipeline=val_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=False),
)
test_evaluator = dict(type=Accuracy, topk=(1, ))  # top-5 needs at least 5 classes; num_classes=2 here
[Screenshot 2024-03-06 at 2.22.36 PM: runtime error message]
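The batching idea in custom_collate_fn (crop every sample to the shortest length in the batch, then batch them together) can be checked without PyTorch or MMEngine. This is a minimal sketch, assuming each sample is a `(features, label)` pair where `features` is a list of equal-length rows, as the original indexing implies; `truncate_collate` is a hypothetical name and nested lists stand in for tensors. Samples produced by a pipeline ending in `PackInputs` may instead arrive as dicts, in which case the indexing would need adapting.

```python
from typing import List, Sequence, Tuple

# (2-D features as rows x length, integer label); a hypothetical layout.
Sample = Tuple[List[List[float]], int]


def truncate_collate(batch: Sequence[Sample]) -> Tuple[List[List[List[float]]], List[float]]:
    """Crop every sample to the shortest row length in the batch."""
    # Shortest length, measured on the first row of each sample,
    # like len(x[0][0]) in custom_collate_fn.
    min_len = min(len(features[0]) for features, _ in batch)
    # Equivalent of sample[0][:, :min_len]: keep the first min_len
    # columns of every row.
    data = [[row[:min_len] for row in features] for features, _ in batch]
    # Labels become floats, mirroring dtype=torch.float32 in the original.
    labels = [float(label) for _, label in batch]
    return data, labels


batch = [
    ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], 0),
    ([[7.0, 8.0], [9.0, 10.0]], 1),
]
data, labels = truncate_collate(batch)
# data   -> [[[1.0, 2.0], [4.0, 5.0]], [[7.0, 8.0], [9.0, 10.0]]]
# labels -> [0.0, 1.0]
```

Using min() instead of data.sort() leaves the caller's list unmodified and keeps the samples in their original order, whereas the original function reorders the batch shortest-first.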

Environment

'sys.platform': 'win32',
'Python': '3.8.18 (default, Sep 11 2023, 13:39:12) [MSC v.1916 64 bit ' '(AMD64)]',
'CUDA available': True,
'MUSA available': False,
'numpy_random_seed': 2147483648,
'GPU 0': 'NVIDIA GeForce RTX 3070',
'CUDA_HOME': 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2',
'NVCC': 'Cuda compilation tools, release 12.2, V12.2.91',
'GCC': 'n/a',
'PyTorch': '1.10.1',
'TorchVision': '0.11.2',
'OpenCV': '4.9.0',
'MMEngine': '0.10.3',
'MMCV': '2.1.0',
'MMPreTrain': '1.2.0+17a886c'

Other information

No response
