
[BUG] There is an error in timm/train.py when I use the webdataset (timm/imagenet-w21-wds on Hugging Face) with a class map #2154

Open
TheDarkKnight-21th opened this issue Apr 19, 2024 · 9 comments


TheDarkKnight-21th commented Apr 19, 2024

Describe the bug

I wanted to train on a subset of the IN21k-winter classes, so I made a class map for IN21k-winter and ran training.

But there is an error in timm/train.py when I use the webdataset (timm/imagenet-w21-wds) with the class map.

The error is a KeyError.

When I use the same class map with ImageFolder (the original IN21k-winter dataset, not the wds version), there is no error.

How should I run the training scripts with a class map for timm/imagenet-w21-wds?

(Please check pytorch-image-models/timm/data/readers/reader_wds.py.)

To Reproduce
Steps to reproduce the behavior:

  1. If you don't have timm/imagenet-w21-wds, download the dataset from Hugging Face.
  2. Make a class map file (.txt) listing the class names, one per line (e.g. n03613592, n04116512, n04027706, n04367480, ...; see the example file below).
  3. Run timm/train.py with the --class-map option pointing at that file.
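For reference, a class map file in the format I used looks like this (one WordNet synset ID per line; the IDs below are just the examples from step 2):

    n03613592
    n04116512
    n04027706
    n04367480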

Expected behavior

I just want to be able to run the training scripts with a class map for timm/imagenet-w21-wds.

Screenshots

[screenshot: class map example (.txt)]

[screenshot: the KeyError traceback]

Desktop (please complete the following information):

  • OS: Ubuntu
  • This repository version: latest
  • PyTorch version w/ CUDA/cuDNN: not provided


@TheDarkKnight-21th TheDarkKnight-21th added the bug Something isn't working label Apr 19, 2024
@TheDarkKnight-21th TheDarkKnight-21th changed the title [BUG] There is the error in timm/train.py when i use the class map of Webdataset (timm/imagenet-w21-wds in huggingface) [BUG] There is the error in timm/train.py when i use the Webdataset (timm/imagenet-w21-wds in huggingface) with class map Apr 19, 2024
@rwightman (Collaborator)

The webdataset pipeline doesn't have access to the string class names at that point; it's using integer indices from the get-go, so the map capability is pretty minimal: it can only really remap integers, not filter samples out. I'd have to add specific filtering/remap functionality, but I'm not sure it's easy to do in a generic way without looking closer; it would need to use other metadata to get the class name (synset)...

@TheDarkKnight-21th (Author) commented Apr 19, 2024

> The webdataset pipeline doesn't have access to the string class names at that point; it's using integer indices from the get-go, so the map capability is pretty minimal: it can only really remap integers, not filter samples out. I'd have to add specific filtering/remap functionality, but I'm not sure it's easy to do in a generic way without looking closer; it would need to use other metadata to get the class name (synset)...

So you mean I can't use a class map for wds right now. If not, what can I do? Please tell me the solution in detail :) I just want to do class pruning...

@rwightman (Collaborator)

Quickest path would be to hack the _decode function at this line.

Decode the json there, get the class_name from the json, and then do a lookup on your class map there (you'd have to make the map accessible there). Return None before the image decode if the class isn't in the map (this will skip the sample), and overwrite the class_label with the mapped value if it is there.

So not too crazy, but not trivial either.
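A minimal sketch of what that hack could look like (hedged: the actual _decode signature and sample keys in reader_wds.py differ in detail; class_to_idx here is assumed to be a dict mapping synset name to the new integer label):

    import io
    import json

    from PIL import Image


    def _decode(sample, image_key='jpg', target_key='cls', class_to_idx=None):
        # Per-sample metadata shipped in the shard; assumed to carry the
        # synset under 'class_name', e.g. 'n03613592'.
        meta = json.loads(sample['json'])
        class_name = meta['class_name']
        if class_to_idx is not None:
            if class_name not in class_to_idx:
                return None  # skip the sample before the (expensive) image decode
            # Overwrite the label with the remapped, consecutive index.
            sample[target_key] = class_to_idx[class_name]
        img = Image.open(io.BytesIO(sample[image_key])).convert('RGB')
        return dict(jpg=img, cls=sample[target_key])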

@TheDarkKnight-21th (Author)

> Quickest path would be to hack the _decode function at this line.
>
> Decode the json there, get the class_name from the json, and then do a lookup on your class map there (you'd have to make the map accessible there). Return None before the image decode if the class isn't in the map (this will skip the sample), and overwrite the class_label with the mapped value if it is there.
>
> So not too crazy, but not trivial either.

Okay, I will try it soon and report back to you.

@TheDarkKnight-21th (Author) commented Apr 24, 2024

> Quickest path would be to hack the _decode function at this line.
>
> Decode the json there, get the class_name from the json, and then do a lookup on your class map there (you'd have to make the map accessible there). Return None before the image decode if the class isn't in the map (this will skip the sample), and overwrite the class_label with the mapped value if it is there.
>
> So not too crazy, but not trivial either.

Hi rwightman,

I added the code you suggested and ran train.py, but the running time is extremely long (no class map => 2 hours per epoch; with class map => 5 days per epoch).

What should I do? (I've attached my reader_wds.py.)
reader_wds.zip

(The code is at line 157 of timm/data/readers/reader_wds.py.)


if class_map.endswith(".txt"):
    meta = json.loads(sample['json'])
    class_name = meta["class_name"]
    with open(class_map) as f:
        class_list = [name.split("\n")[0] for name in f.readlines()]
        # class_to_idx = {v.strip(): k for k, v in enumerate(f)}
    if class_name not in class_list:
        return None

[screenshot of the code]

@rwightman (Collaborator) commented Apr 24, 2024

@TheDarkKnight-21th that's going to be extremely slow; you're loading the same mapping file for every sample.

You'd want to add a class_to_idx argument to the _decode fn. If it's a valid map (a dict with entries), execute the remap block. You also want to remap the actual label, not just filter out valid samples, because in most use cases you'd want to collapse the label space to consecutive indices.

In the __init__ method of the reader, the class map is already loaded into class_to_idx, so you'd just bind that as an argument to _decode via the partial that's already there...
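A sketch of that binding (hedged: the load_class_map helper and the pipeline wiring are simplified here, and the file name is hypothetical; timm ships its own class-map loader that behaves similarly for .txt files):

    from functools import partial

    def load_class_map(path):
        # Read the synset list once and assign consecutive indices,
        # collapsing the label space for the filtered subset.
        with open(path) as f:
            names = [line.strip() for line in f if line.strip()]
        return {name: idx for idx, name in enumerate(names)}

    class_to_idx = load_class_map('in21k_subset.txt')  # hypothetical file

    # Bind the map into the decoder once, at reader construction time; the
    # wds pipeline then calls decode_fn per sample without re-reading the file.
    decode_fn = partial(_decode, class_to_idx=class_to_idx)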

@TheDarkKnight-21th (Author) commented Apr 25, 2024

Yeah, you are right. I changed the code as you said: I added class_to_idx to _decode and also added the argument to the pipeline partial. But it is still very slow (1 epoch => 3.5 days).
Also, the training iteration count doesn't go down (the iterations per epoch before pruning and after pruning are the same).

[screenshot: _decode()]

[screenshot: ReaderWds __init__()]

(I added this code as well)
[screenshot: ReaderWds __iter__()]

I've also attached my reader_wds.py in reader_wds.zip; please unzip it.
reader_wds.zip

@rwightman (Collaborator) commented Apr 25, 2024

With sharded datasets there is no way of knowing which samples are still valid after filtering, so there is no way of knowing the dataset length without calculating it yourself. You have to provide a new number of samples as an estimate for the filtered dataset.

Now with that in mind: since it will continue to use the same # of samples and thus steps, if you are significantly filtering, i.e. only taking a small % of the classes, it will be extremely inefficient, because you have to iterate over ALL the samples in the shards to get the ones you want. You still have to read all that data (you just avoid the decode), which is likely to slow your dataloading per sample extracted. If you took, say, 100 of 19000 classes, you'd have to pass over the dataset ~190x to get the same # of samples.

So you probably want to evaluate your motivation for doing this. If you want to use, say, 50% of the classes (roughly evenly distributed in frequency), that'd be okay; if you want to use a few thousand of the least frequent classes, it will be inefficient. Your CPU + disk throughput will determine where the limit is...
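As a back-of-the-envelope sketch of that estimate (hedged: the totals below are illustrative placeholders, not the exact imagenet-w21-wds counts, and real per-synset frequencies vary):

    # Estimate the filtered dataset size under a uniform-frequency assumption.
    total_samples = 13_000_000   # placeholder: full wds sample count
    total_classes = 19_000       # placeholder: full class count
    kept_classes = 100           # classes kept by the class map

    est_samples = total_samples * kept_classes // total_classes
    passes_needed = total_classes / kept_classes

    print(est_samples)    # ~68k samples to report as the new dataset length
    print(passes_needed)  # ~190 passes over the shards per epoch of the subset

Whatever number comes out of this would then be supplied as the reader's sample count (e.g. via train.py's --train-num-samples option, if your timm version has it).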

@TheDarkKnight-21th (Author)

> With sharded datasets there is no way of knowing which samples are still valid after filtering, so there is no way of knowing the dataset length without calculating it yourself. You have to provide a new number of samples as an estimate for the filtered dataset.
>
> Now with that in mind: since it will continue to use the same # of samples and thus steps, if you are significantly filtering, i.e. only taking a small % of the classes, it will be extremely inefficient, because you have to iterate over ALL the samples in the shards to get the ones you want. You still have to read all that data (you just avoid the decode), which is likely to slow your dataloading per sample extracted. If you took, say, 100 of 19000 classes, you'd have to pass over the dataset ~190x to get the same # of samples.
>
> So you probably want to evaluate your motivation for doing this. If you want to use, say, 50% of the classes (roughly evenly distributed in frequency), that'd be okay; if you want to use a few thousand of the least frequent classes, it will be inefficient. Your CPU + disk throughput will determine where the limit is...

To summarize what you have been saying: if I proceed as described in the quoted text above, I can selectively train on samples using a class_map, but for the reasons mentioned (e.g., repeated passes over the shards to gather the kept samples, limited resources), training could be slow.
And to accurately set the number of training iterations for the subset selected through the class_map, I would need to define the total number of samples myself, right?
