
How to convert custom .pth Model to .onnx? #193

Closed
endh1337 opened this issue Feb 21, 2022 · 18 comments

Comments

@endh1337

Hey, first of all thanks for sharing your great work!

As described in the docs, I trained a custom model based on the U²Net architecture to remove specific backgrounds, and the results were fine. However, it seems like you dropped custom model support in 3b18bad. Are you planning to add it again in the future? Could you please give some insight into how you converted the u2net .pth models to the .onnx ones?

Thanks!

@danielgatis
Owner

try this:
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

@endh1337
Author

Wow! Thanks for your fast and also helpful response!

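I converted my custom model with this script:

```python
import torch
import torch.onnx
from model import U2NET  # model/u2net.py from the cloned U²Net repository

torch_model = U2NET(3, 1)  # 3 input channels (RGB), 1 output mask channel
model_path = "<pathToStateDict>.pth"
batch_size = 1

# map_location="cpu" lets weights saved on a GPU load on a CPU-only machine
torch_model.load_state_dict(torch.load(model_path, map_location="cpu"))
torch_model.eval()  # inference mode: BatchNorm uses its running statistics

# Dummy input with the shape the model will see: (batch, channels, height, width)
x = torch.randn(batch_size, 3, 320, 320)
torch_out = torch_model(x)  # reference output, useful for checking the export

torch.onnx.export(
    torch_model,
    x,
    "model.onnx",
    export_params=True,  # store the trained weights inside the .onnx file
    opset_version=11,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
```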

And now it's working again with rembg!
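The PyTorch tutorial linked above ends by running the exported file in ONNX Runtime and comparing its output with the PyTorch output via `np.testing.assert_allclose`; the export script computes `torch_out` but never uses it. A minimal sketch of that check, assuming `onnxruntime` is installed and that `x` and `torch_out` are the variables from the script above (the helper name is hypothetical, and the tolerances follow the tutorial as I remember it):

```python
import numpy as np


def verify_onnx_export(onnx_path, x, torch_out, rtol=1e-3, atol=1e-5):
    """Run the exported ONNX file on the dummy input and compare with PyTorch.

    x is the dummy tensor passed to torch.onnx.export; torch_out is what the
    PyTorch model returned for it (U2NET returns a tuple of side outputs).
    """
    import onnxruntime as ort  # imported lazily so this module loads without it

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name  # "input" in the export script
    ort_outs = session.run(None, {input_name: x.detach().cpu().numpy()})

    # Compare the first output; for U2NET this is d0, the fused saliency map.
    reference = torch_out[0] if isinstance(torch_out, (tuple, list)) else torch_out
    np.testing.assert_allclose(
        reference.detach().cpu().numpy(), ort_outs[0], rtol=rtol, atol=atol
    )
    return ort_outs
```

Calling `verify_onnx_export("model.onnx", x, torch_out)` right after `torch.onnx.export` raises an `AssertionError` if the two runtimes disagree beyond the tolerances.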

@suri199507

@endh1337 How did you train the model with custom images, and where did you get the `model` module in `from model import U2NET`?

@danielgatis
Owner

Hi @endh1337,
It would be wonderful if you could describe the steps you took to train your model.

@endh1337
Author

endh1337 commented Mar 1, 2022

Hey, sorry @danielgatis and @suri199507, recent events kept me quite busy.

First of all, I'm a newbie to Python, machine learning and so on; I'm just crawling the web for information I can use to achieve some specific background/object removal tasks. So please don't judge this non-professional response 😅

TL;DR

  1. Create your own dataset, for example modeled on the DUTS-TR dataset
  2. Check out and install the U²Net repository
  3. (optional) Get it to work with the DUTS-TR dataset so you can see for yourself how it works
  4. Change the hard-coded DUTS-TR references in u2net_train.py (file system paths etc.) so they point to your own dataset
  5. If you know what you are doing, change the training parameters and loss function. I didn't, so I left them as they were
  6. Maybe extend your training data by mirroring and rotating your images and masks
  7. Run python u2net_train.py and leave it running until the automatically saved results in saved_models satisfy you
  8. Convert it to ONNX as I described earlier
  9. Use it with rembg :)
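Step 6 (mirroring and rotating) could be sketched with NumPy as below. It assumes each image and its mask are already loaded as arrays of matching height and width (for example with `np.array(Image.open(path))` from Pillow); `augment_pair` is a hypothetical helper, not part of U²Net. The key point is that every transform is applied to the image and its mask together, so they stay aligned.

```python
import numpy as np


def augment_pair(image, mask):
    """Return the original pair plus mirrored and rotated copies.

    image is (H, W) or (H, W, C), mask is (H, W); the same geometric
    transform is applied to both so each mask still lines up with its image.
    """
    pairs = [(image, mask)]
    # Horizontal mirror: flip along the width axis (axis 1).
    pairs.append((np.fliplr(image), np.fliplr(mask)))
    # Rotations by 90, 180 and 270 degrees in the height/width plane only.
    for k in (1, 2, 3):
        pairs.append(
            (np.rot90(image, k, axes=(0, 1)), np.rot90(mask, k, axes=(0, 1)))
        )
    return pairs
```

If your subjects always appear upright in real photos, you may want to keep only the mirrored copy, since 90-degree rotations produce views the model will never see in practice.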

Long answer (no in-depth guide!)

The model in the U²Net repository was originally trained on the DUTS-TR dataset, which is a set of images and their counterpart masks. So you have images like

DUTS-TR/DUTS-TR-Image/ILSVRC2012_test_00000022.jpg

and their counterpart mask

DUTS-TR/DUTS-TR-Mask/ILSVRC2012_test_00000022.png

which, according to the resources I found, is the ground-truth binary mask (I guess this means only black and white) marking which element of the image should be segmented.

So first, you need to create your own dataset like DUTS-TR: mask the objects you want to segment in white, and leave the background (the parts you want rembg to remove) black. By the way, rembg does not only work for background removal: you can also train a U²Net model to segment a specific part of the image that you want removed (leave that part black in the mask and all its surroundings white). You can change this behavior, but by default you have one directory with the images (.jpg) and another directory containing only the masks (same name as the original image, but with a .png extension).
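Masks painted in an image editor often have anti-aliased, gray edges rather than pure black and white. If you want masks that are strictly binary as described above, a minimal NumPy sketch like the following can snap every pixel to 0 or 255. It assumes masks loaded as uint8 arrays (for example via Pillow); the function name and the threshold of 128 are my own choices, not something defined by U²Net or rembg.

```python
import numpy as np


def binarize_mask(mask, threshold=128):
    """Force a mask to pure black (0) and white (255).

    Accepts a grayscale (H, W) or color (H, W, C) uint8 array. Color masks
    are reduced to one channel by averaging RGB (alpha is ignored).
    """
    mask = np.asarray(mask)
    if mask.ndim == 3:
        mask = mask[..., :3].mean(axis=-1)
    # Pixels at or above the threshold become white (keep), others black.
    return np.where(mask >= threshold, 255, 0).astype(np.uint8)
```

Saving the result with `Image.fromarray(binarize_mask(arr)).save("name.png")` keeps the same base name as the image, which is what the default image/mask pairing expects.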

I cloned the U²Net repository and made a few changes to its u2net_train.py file (like changing the model name and the directories for the training data, because the references to them were quite hard-coded). Because U²Net was originally trained on the DUTS-TR dataset, you'll need to point some file system references at your own dataset. Here is an example from u2net_train.py:

```python
import os  # already imported at the top of u2net_train.py

model_name = 'u2net' #'u2netp'

data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep)
tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)
```
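To point training at your own data, those DUTS paths can be swapped for your own folders. The sketch below uses a hypothetical `train_data/my_dataset/images` and `train_data/my_dataset/masks` layout and, like u2net_train.py as far as I can tell, derives each mask path from the image's base name plus the `.png` extension. `build_pairs` is a hypothetical helper; it also reports images without a mask, which would otherwise only fail once the data loader tries to open them.

```python
import glob
import os

# Hypothetical layout replacing the DUTS paths:
#   train_data/my_dataset/images/*.jpg
#   train_data/my_dataset/masks/*.png
data_dir = os.path.join(os.getcwd(), "train_data" + os.sep)
tra_image_dir = os.path.join("my_dataset", "images" + os.sep)
tra_label_dir = os.path.join("my_dataset", "masks" + os.sep)
image_ext = ".jpg"
label_ext = ".png"


def build_pairs(data_dir, image_dir, label_dir, image_ext=".jpg", label_ext=".png"):
    """Pair every image with the mask of the same base name; collect gaps."""
    pairs, missing = [], []
    pattern = os.path.join(data_dir, image_dir, "*" + image_ext)
    for img_path in sorted(glob.glob(pattern)):
        # splitext keeps dots inside the name, e.g. "photo.v2.jpg" -> "photo.v2"
        stem = os.path.splitext(os.path.basename(img_path))[0]
        label_path = os.path.join(data_dir, label_dir, stem + label_ext)
        if os.path.exists(label_path):
            pairs.append((img_path, label_path))
        else:
            missing.append(img_path)
    return pairs, missing


if __name__ == "__main__":
    pairs, missing = build_pairs(data_dir, tra_image_dir, tra_label_dir, image_ext, label_ext)
    print(f"{len(pairs)} image/mask pairs, {len(missing)} images without a mask")
```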

If you change model_name, some other references won't work anymore, so you need to make some adjustments, like switching from

```python
# ------- 3. define model --------
# define the net
if(model_name=='u2net'):
    net = U2NET(3, 1)
elif(model_name=='u2netp'):
    net = U2NETP(3,1)
```

to

```python
# ------- 3. define model --------
net = U2NET(3, 1)
```

You will need to change some other parts I won't describe here, e.g. auto-saving models, CUDA support, etc. Crawl the U²Net issues for more information. If you (unlike me) know what you are doing, you can adjust the model parameters, like the loss function in u2net_train.py, or others in model/u2net.py. Anyway, I left them as they were and nevertheless got good results, even though my dataset has nothing to do with "salient object detection".

After fixing all the errors that occur while executing python u2net_train.py (almost every error/warning was due to the changed file system references/directories etc.), you can leave it running for some time (I trained on my dataset for almost 2 weeks to get satisfying results). Your model weights are saved every certain number of iterations, and you can then test those weights on test images with u2net_test.py.
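To decide which saved weights to test and convert, a small helper can pick the newest checkpoint in saved_models. It assumes u2net_train.py's default naming (model name, `_bce_itr_`, the iteration count, then the training losses), which is what I believe the script writes but is worth checking against your copy; `latest_checkpoint` is a hypothetical name.

```python
import os
import re

# Assumed file name pattern written by u2net_train.py, for example
# "u2net_bce_itr_8000_train_0.123456_tar_0.012345.pth"
_ITERATION = re.compile(r"_itr_(\d+)_")


def latest_checkpoint(model_dir):
    """Return the path of the .pth file with the highest iteration number, or None."""
    best_path, best_iter = None, -1
    for name in os.listdir(model_dir):
        match = _ITERATION.search(name)
        if not name.endswith(".pth") or match is None:
            continue  # skip files that don't follow the checkpoint pattern
        iteration = int(match.group(1))
        if iteration > best_iter:
            best_path, best_iter = os.path.join(model_dir, name), iteration
    return best_path
```

The newest checkpoint is not necessarily the best one, so it is still worth comparing a few of them on your test images before converting.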

After that, you convert it to .onnx as described before, and it works wonderfully with rembg.

To @suri199507:

> and How did you get this from model import U2NET

The script I posted before for converting the .pth weights to ONNX just sits in the same cloned U²Net repository, so it imports the model class from model/u2net.py. As far as I know, PyTorch needs this model definition in addition to the saved weights (the .pth file only holds the state dict, i.e. the parameter tensors) to rebuild the network and export it as an .onnx model.


My response is messier than I thought it would be. Hope it's helpful anyway.

@940smiley

This helped me a lot! Thanks!

@khadija23

```
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: input for the following indices
index: 1 Got: 3 Expected: 1
Please fix either the inputs or the model
```

I'm getting this error with my model exported to ONNX. I trained my own U2NET on the midv500 dataset.

@khadija23

This is my code to export the model:

```python
import pathlib

import torch
import torch.onnx
from utils import models  # project-specific module from this training code

model_path = "<pathToStateDict>.pth"  # path elided in the original
batch_size = 1

# NOTE: n_channels=1 builds a model that expects 1-channel (grayscale) input.
model = models.UNet(n_channels=1, n_classes=1)
checkpoint = torch.load(pathlib.Path(model_path), map_location="cpu")
model.load_state_dict(checkpoint)  # loading the state dict once is enough
model.eval()

# dynamic_axes keys must match input_names/output_names exactly,
# otherwise the dynamic axes are not applied
dynamic_axes_dict = {
    "input": [0, 2, 3],  # batch, height and width may vary
    "output": [0],
}
x = torch.randn(batch_size, 1, 320, 320)
torch_out = model(x)

torch.onnx.export(
    model, x, "model.onnx",
    export_params=True, opset_version=11, do_constant_folding=True,
    input_names=["input"], output_names=["output"],
    dynamic_axes=dynamic_axes_dict,
)
```
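The error says the model's input expects 1 channel at index 1 but received 3. As far as I can tell from rembg's source, its U²Net-style sessions convert every image to RGB, resize it to 320×320, divide by the image's maximum value, normalize with the ImageNet mean/std and pass a float32 tensor of shape (1, 3, 320, 320). A NumPy sketch of that tensor (resizing left out; `to_u2net_input` is a hypothetical name, and the details may differ between rembg versions) shows why a model exported with `n_channels=1` and a `(1, 1, 320, 320)` dummy input cannot accept it: the network has to be trained with 3 input channels, like `U2NET(3, 1)` in the earlier script.

```python
import numpy as np

MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations


def to_u2net_input(rgb):
    """Turn an (H, W, 3) uint8 RGB array into a (1, 3, H, W) float32 tensor."""
    rgb = np.asarray(rgb, dtype=np.float64)
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) RGB array, got shape {rgb.shape}")
    peak = rgb.max()
    # rembg divides by the image maximum; the zero guard is an addition here
    scaled = rgb / peak if peak > 0 else rgb
    normalized = (scaled - np.array(MEAN)) / np.array(STD)
    # HWC -> CHW, then add the batch dimension in front
    return normalized.transpose(2, 0, 1)[np.newaxis].astype(np.float32)
```

A grayscale-trained network would only fit if every image were reduced to one channel before inference, which rembg does not do, so retraining with RGB input is the simpler route.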

@nazzimsulaiman

How can I use rembg to load a custom-trained model for prediction?

@AeroDEmi

AeroDEmi commented Feb 1, 2023

Is there a way to extract the weights (.pth) from the onnx model?
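I don't know of a built-in reverse of `torch.onnx.export`, but an ONNX file stores its weights as graph initializers, which the `onnx` package can read back as NumPy arrays. A minimal sketch, assuming `onnx` is installed (`onnx_initializers` is a hypothetical name):

```python
def onnx_initializers(onnx_path):
    """Return {initializer name: NumPy array} for every weight in an ONNX file."""
    # Imported lazily so this helper can be defined without onnx installed.
    import onnx
    from onnx import numpy_helper

    model = onnx.load(onnx_path)
    return {init.name: numpy_helper.to_array(init) for init in model.graph.initializer}
```

Because the exporter may rename, fold or fuse tensors (for example with `do_constant_folding=True`), the names and shapes are not guaranteed to match the original model's state_dict keys, so turning them back into a loadable .pth may need manual mapping. Printing `{name: arr.shape for name, arr in onnx_initializers("model.onnx").items()}` is a quick way to see what you have.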

@RacoonPy-Ai

I have tried the script above to convert my model. The model is converted, but it's not working.

@Gradylo

Gradylo commented Apr 22, 2023

How can I use rembg to load a custom-trained model for prediction?

@eliassama

> ...u2net_train.py or others in model\u2net.py. Anyways, I left them like they were and got nevertheless good results, although my dataset has nothing to do with "salient object detection".

What should I do with my own dataset? Also, is there an easy, quick way to turn a color picture into a black-and-white binary mask for a subject I choose?

@tdp1996

tdp1996 commented Oct 13, 2023

> How can I use rembg to load a custom trained model for prediction!

Hello, I have the same problem. I don't know how to use my custom-trained model with rembg :(

@tdp1996

tdp1996 commented Oct 14, 2023

Hi guys, how can I use my custom-trained model with rembg?

@danielgatis
Owner

Try this:

```
rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
```
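The same thing from Python might look like the sketch below. It assumes a recent rembg 2.x, where `new_session` forwards keyword arguments such as `model_path` to the `u2net_custom` session (as far as I can tell, this is also how the CLI's `-x` JSON is passed on), and that Pillow is installed. `remove_with_custom_model` is a hypothetical helper; the `os.path.expanduser` call is a precaution in case `~` is not expanded for you.

```python
import os


def remove_with_custom_model(input_path, output_path, model_path):
    """Remove the background of one image using a custom U²Net ONNX model."""
    # Imported lazily so this helper can be defined without rembg/Pillow installed.
    from PIL import Image
    from rembg import new_session, remove

    # Same model name and model_path option as the CLI's -m / -x arguments.
    session = new_session("u2net_custom", model_path=os.path.expanduser(model_path))
    with Image.open(input_path) as img:
        result = remove(img, session=session)
    result.save(output_path)  # save as PNG to keep the transparent background
```

For example, `remove_with_custom_model("path/to/input.png", "path/to/output.png", "~/.u2net/u2net.onnx")`. When processing many images, create the session once and reuse it for every `remove` call, since loading the ONNX model is the expensive part.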

@Jonathunky

Dear friends,

I've heavily improved this code and created a dedicated rembg-trainer repo!

It's much faster now (it uses hardware acceleration and multithreading where possible), more reliable and easier to get started with, and it saves the model in ONNX format every x iterations, so you can easily compare the model's behaviour after each x iterations. It should be very intuitive and understandable.

Please kindly take a look. Thanks ever so much!

@jianwei164274

> try this:
>
> rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png

Can I use the remove function with a custom model_path?
