Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing for user defined transforms #10308

Open
wants to merge 15 commits into
base: main
Choose a base branch
from

Conversation

edkazcarlson
Copy link

@edkazcarlson edkazcarlson commented Apr 25, 2024

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Enhancements and Flexibility in Data Handling and Model Configuration

📊 Key Changes

  • Introduced inputCh configuration to specify the number of input channels for images.
  • Added flexibility in data augmentation and label transformation with override_label_transforms and append_label_transforms options.
  • Enhanced model and dataset initialization to support the new configurations.
  • Improved plot handling for datasets with non-standard (non-3) image channels.

🎯 Purpose & Impact

  • Custom Input Channels: Allows models to handle images with different numbers of channels (e.g., grayscale or multispectral images), making the library more versatile.
  • Data Augmentation Customization: The addition of label transformation options provides users with the ability to customize or extend the data preprocessing and augmentation steps. This can lead to better model performance by tailoring the preprocessing steps to specific dataset characteristics.
  • Better Support for Non-Standard Images: The updates ensure that plotting functions gracefully handle images with non-standard channels, avoiding errors and improving user experience when working with such data.

These changes enhance the library's flexibility, making it more adaptable to various types of data and specific project requirements.

Copy link

github-actions bot commented Apr 25, 2024

All Contributors have signed the CLA. ✅
Posted by the CLA Assistant Lite bot.

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👋 Hello @edkazcarlson, thank you for submitting an Ultralytics YOLOv8 🚀 PR! To allow your work to be integrated as seamlessly as possible, we advise you to:

  • ✅ Verify your PR is up-to-date with ultralytics/ultralytics main branch. If your PR is behind you can update your code by clicking the 'Update branch' button or by running git pull and git merge main locally.
  • ✅ Verify all YOLOv8 Continuous Integration (CI) checks are passing.
  • ✅ Update YOLOv8 Docs for any new or updated features.
  • ✅ Reduce changes to the absolute minimum required for your bug fix or feature addition. "It is not daily increase but daily decrease, hack away the unessential. The closer to the source, the less wastage there is." — Bruce Lee

See our Contributing Guide for details and let us know if you have any questions!

Copy link

codecov bot commented Apr 25, 2024

Codecov Report

Attention: Patch coverage is 77.41935% with 14 lines in your changes are missing coverage. Please review.

Project coverage is 70.41%. Comparing base (51c3169) to head (2df3330).

Files Patch % Lines
ultralytics/data/dataset.py 57.89% 8 Missing ⚠️
ultralytics/models/yolo/classify/train.py 70.00% 3 Missing ⚠️
ultralytics/models/yolo/detect/train.py 83.33% 1 Missing ⚠️
ultralytics/models/yolo/detect/val.py 87.50% 1 Missing ⚠️
ultralytics/utils/plotting.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main   #10308      +/-   ##
==========================================
- Coverage   74.59%   70.41%   -4.19%     
==========================================
  Files         124      124              
  Lines       15664    15702      +38     
==========================================
- Hits        11685    11057     -628     
- Misses       3979     4645     +666     
Flag Coverage Δ
Benchmarks 35.49% <51.61%> (-0.02%) ⬇️
GPU 36.68% <51.61%> (-6.27%) ⬇️
Tests 66.43% <75.80%> (-3.86%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@edkazcarlson
Copy link
Author

I have read the CLA Document and I sign the CLA

@edkazcarlson edkazcarlson changed the title User/ecarlson/adding transformation compatability Allowing for user defined transforms Apr 28, 2024
@Burhan-Q Burhan-Q added the enhancement New feature or request label May 13, 2024
@edkazcarlson
Copy link
Author

Unsure if this will help with reviewing, but I tested the changes with the following code

def preprocess32(x):
    """
    Convert to float32 and normalize to 0-1
    Returns: np array
    """
    if (type(x) == type(torch.tensor([]))):
        x = x.transpose(0,1).transpose(1,2) # go from channel, height, width to height, width, channel
        x = x.numpy().astype(np.float32)
        if np.max(x) > 1.1:
            x /= 255
    elif type(x) == type(Image.Image()):
        x = np.float32(np.asarray(x))
        x /= 255
    elif type(x) == type(np.array([])):
        x = np.float32(x)
        x /= 255
    else: 
        print(f'In preprocess32 found {type(x)} wanted {type(Image.Image())} or {type(np.array())} or {type(torch.tensor([]))}')
        exit()
        
    assert np.max(x) <= 1.01, f'np.max(x) {np.max(x)}'
    assert np.min(x) >= -.01, f'np.min(x) {np.min(x)}'
    x = np.clip(x, 0, 1)
    return x

def FourChannelTransformMethod(x):
    x = preprocess32(x)
    x = torch.tensor(x, dtype= torch.float)
    x = x.transpose(2,1).transpose(1,0) # hwc -> chw
    x = torch.cat((x, torch.zeros_like(x[0]).unsqueeze(0)), dim=0)
    return x #c h w 

class FourChannelTransform(object):
    """Changes an image from bgr to lrgb.

    Args: normalizeSB: boolean that is true if the saturation and brightness are normalized around 0
    """
    def __init__(self):
        pass
    def __call__(self, labels):
        img = FourChannelTransformMethod(labels['img'])
        labels['img'] = img.to(torch.float16)
        return labels
    
    
class ThreeChannelTransform(object):
    """Changes an image from bgr to lrgb.

    Args: normalizeSB: boolean that is true if the saturation and brightness are normalized around 0
    """
    def __init__(self, dtype):
        self.dtype = dtype
        pass
    def __call__(self, labels):
        labels['img'] = labels['img'].to(self.dtype)
        return labels

class TwoChannelTransform(object):
    def __init__(self):
        pass

    def __call__(self, labels):
        labels['img'] = labels['img'][0:2]
        return labels

def firstTest():
    print('Default DetectionTrainer')
    overrides = {'epochs': 2, 'imgsz': 640, 'data': 'coco.yaml', 'model': f'yolov8n.yaml', 'inputCh': 3, 'batch': 8, 'close_mosaic': 1}
    trainer = DetectionTrainer(overrides=overrides)
    trainer.train()

def secondTest():
    print('3 channel detection trainer with float16')
    overrides = {'epochs': 2, 'imgsz': 640, 'data': 'coco.yaml', 'model': f'yolov8n.yaml', 'inputCh': 3, 'batch': 8, 'close_mosaic': 1}
    trainer = DetectionTrainer(overrides=overrides, append_label_transforms=ThreeChannelTransform(torch.float16))
    trainer.train()

def thirdTest():
    print('3 channel detection trainer with float32')
    overrides = {'epochs': 2, 'imgsz': 640, 'data': 'coco.yaml', 'model': f'yolov8n.yaml', 'inputCh': 3, 'batch': 8, 'close_mosaic': 1}
    trainer = DetectionTrainer(overrides=overrides, append_label_transforms=ThreeChannelTransform(torch.float32))
    trainer.train()
    
def fourthTest():
    print('4 channel detection')
    overrides = {'epochs': 2, 'imgsz': 640, 'data': 'coco.yaml', 'model': f'yolov8n.yaml', 'inputCh': 4, 'batch': 8, 'close_mosaic': 1}
    trainer = DetectionTrainer(overrides=overrides, append_label_transforms=FourChannelTransform())
    trainer.train()
    
def fifthTest():
    print('2 channel detection')
    overrides = {'epochs': 2, 'imgsz': 640, 'data': 'coco.yaml', 'model': f'yolov8n.yaml', 'inputCh': 2, 'batch': 8, 'close_mosaic': 1}
    trainer = DetectionTrainer(overrides=overrides, append_label_transforms=TwoChannelTransform())
    trainer.train()
    

@glenn-jocher
Copy link
Member

Thanks for sharing your testing code! It looks comprehensive and covers a variety of scenarios with different channel configurations and data types. This will definitely help in understanding how the changes perform across different setups. If you encounter any issues or have further suggestions, feel free to share! 🚀

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants