
Not able to load pretrained weights from training on custom data #1621

Open
AGRocky opened this issue Dec 13, 2023 · 4 comments


@AGRocky

AGRocky commented Dec 13, 2023

Hey guys, please hear me out.
I have trained a model and have the pretrained weights, but I am not able to load the model to test it on a different set of images that I have prepared. Please help me out with this one. It would be deeply appreciated.

Thanks in advance.

@JustinasLekavicius

Was it a CycleGAN or a pix2pix model?

To test it on a different dataset, you could use this command as an example (replace /content/data/A with your own image directory):

```
python test.py --dataroot /content/data/A --name model_name --model test --netD n_layers --n_layers_D 3 --netG=unet_256 --norm=instance --direction AtoB --dataset_mode single --preprocess none --input_nc 1 --output_nc 3 --ndf 64 --ngf 64 --num_test 512
```

Make sure to replace the parameters with the same ones you used when training the model (netD, n_layers_D, netG, ndf, ngf, etc.).

@AGRocky
Author

AGRocky commented Dec 27, 2023

Hey @JustinasLekavicius, thank you for replying to my issue. I am using CycleGAN, and the purpose of training is to denoise images. I have the pretrained weights, which work well when used with the command line:
"python test.py --dataroot directory, e.g. (/content/data/A) --name model_name --model test..."

But when I try to do the same thing by writing a class that loads the model, takes an image as input, and returns the denoised image, it doesn't work: the model loads, but the denoised output image isn't generated accurately, whereas it is with the command above.
Please help me with this.

Heartfelt thanks in advance.

@AGRocky
Author

AGRocky commented Dec 28, 2023

```python
import torch
from models.networks import define_G
from PIL import Image
from torchvision import transforms
from IPython.display import display

# Define the generator model
generator = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks', norm='instance', use_dropout=False, init_type='normal', init_gain=0.02)

# Load the pre-trained weights from a saved checkpoint
generator_checkpoint_path = 'latest_net_G.pth'
checkpoint = torch.load(generator_checkpoint_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

# Load the generator state_dict
generator.load_state_dict(checkpoint, strict=False)

# Set the model to evaluation mode (important if using dropout during training)
generator.eval()

# Load your input image
input_image_path = 'nt6.jpg'
input_image = Image.open(input_image_path).convert('RGB')

# Resize the input image to the expected size
input_image = input_image.resize((512, 512))

# Convert the input image to a PyTorch tensor
input_tensor = transforms.ToTensor()(input_image).unsqueeze(0)  # Add batch dimension

# Move the input tensor to the GPU if available
if torch.cuda.is_available():
    input_tensor = input_tensor.to('cuda')

# Move the generator to the same device as the input tensor
generator = generator.to(input_tensor.device)

# Generate the output image
with torch.no_grad():
    output_tensor = generator(input_tensor)

# Move the output tensor to the CPU if necessary
output_tensor = output_tensor.cpu()

# Convert the output tensor to a PIL image
output_image = transforms.ToPILImage()(output_tensor.squeeze(0))

# Display the generated image
display(output_image)

# Save the generated image
output_image.save('generated_image.jpg')
```

This is my code.
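A likely cause, offered as an assumption rather than something confirmed in this thread: the repository's test pipeline does not feed raw [0, 1] tensors to the generator. Its `get_transform` in `data/base_dataset.py` applies `transforms.ToTensor()` followed by `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, which maps pixels to [-1, 1], and `util.tensor2im` maps the generator's tanh output back to 0..255 with `(x + 1) / 2 * 255`. The code above skips both steps. The plain-Python sketch below (the function names are hypothetical) shows the two mappings on single pixel values; in the torch code, the equivalent changes would be adding that `Normalize` transform after `ToTensor()` and passing `((output_tensor + 1) / 2).clamp(0, 1)` to `ToPILImage()`. Loading with the default `strict=True` instead of `strict=False` would also make any key mismatch between the checkpoint and the generator raise an error instead of being silently skipped.

```python
def to_model_range(pixel):
    """Map an 8-bit pixel value (0..255) to [-1, 1], as ToTensor() + Normalize(0.5, 0.5) would."""
    return (pixel / 255.0 - 0.5) / 0.5


def to_pixel(value):
    """Map a tanh output in [-1, 1] back to 0..255, mirroring tensor2im (int() truncates like a uint8 cast)."""
    return int((value + 1) / 2.0 * 255.0)


if __name__ == "__main__":
    # A black pixel is -1.0 in the range the generator was trained on.
    print(to_model_range(0))   # -1.0
    # ToTensor() alone feeds black as 0.0, which under the training convention is mid-gray.
    print(to_pixel(0.0))       # 127
    print(to_pixel(1.0))       # 255
```

Reading the generator's output as if it were already in [0, 1] is the same mismatch in the other direction: the darker half of the output range is negative.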

@ystoneman

Hi AGRocky,

To help troubleshoot your CycleGAN model issue, could you provide:

  1. Versions: Exact versions of Python, PyTorch, and other libraries used.
  2. System Specs: Your GPU model and overall system configuration.
  3. Error Details: Any specific error messages or warnings during model loading.

These details will help in accurately replicating the issue and providing a solution.
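For item 1 (and the GPU part of item 2), a small script along these lines could collect the details in one go. It is a minimal sketch: the package list is an assumption, and the GPU entry is only filled in when PyTorch imports successfully.

```python
import platform
import sys
from importlib import metadata


def environment_report(packages=("torch", "torchvision", "Pillow")):
    """Collect Python, OS, package, and (if available) CUDA device information."""
    report = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
    }
    # Installed package versions, without importing the packages themselves.
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    # GPU details need torch itself; skip gracefully if it cannot be imported.
    try:
        import torch
        if torch.cuda.is_available():
            report["gpu"] = torch.cuda.get_device_name(0)
            report["cuda"] = torch.version.cuda
        else:
            report["gpu"] = "CUDA not available"
    except (ImportError, OSError):
        report["gpu"] = "unknown (torch not importable)"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```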

Thanks!
