
GlobalAveragePooling1D fails with empty inputs and a mask #67023

Open
theomeb opened this issue May 6, 2024 · 2 comments
Labels: comp:ops (OPs related issues), TF 2.15 (For issues related to 2.15.x), type:bug (Bug)

Comments

theomeb commented May 6, 2024

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

source

TensorFlow version

2.15.0

Custom code

No

OS platform and distribution

No response

Mobile device

No response

Python version

3.10.12

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

tf.keras.layers.GlobalAveragePooling1D cannot be called on an empty input tensor together with an (empty) mask. This causes problems when a model that uses this layer runs under a distribution strategy such as MirroredStrategy, which splits each batch across multiple GPUs. For instance, with a dataset of 3 samples, a batch_size of 2 and 3 GPUs, at least one GPU receives an empty batch, and calling the layer on it raises the error.
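The arithmetic behind the 3-sample / batch_size 2 / 3-GPU example can be sketched in plain Python. This is a simplified, hypothetical model of how a global batch is divided among replicas (an even split, with any remainder handed to the first replicas); MirroredStrategy's actual rebatching may assign examples differently, but whenever a global batch has fewer examples than there are replicas, some replica must get zero.

```python
def split_global_batch(global_batch_size: int, num_replicas: int) -> list[int]:
    """Hypothetical model of splitting one global batch across replicas.

    Assumption: each replica gets global_batch_size // num_replicas examples,
    and the remainder goes one extra example at a time to the first replicas.
    """
    base, remainder = divmod(global_batch_size, num_replicas)
    return [base + (1 if i < remainder else 0) for i in range(num_replicas)]


# 3 samples with batch_size=2: one full batch of 2, then a final partial batch of 1.
num_samples, batch_size, num_gpus = 3, 2, 3
global_batches = [
    min(batch_size, num_samples - start) for start in range(0, num_samples, batch_size)
]

# Per-replica batch sizes for each global batch under the assumed split.
per_replica = [split_global_batch(b, num_gpus) for b in global_batches]
print(per_replica)
```

Under this model both steps leave at least one replica with an empty batch, which is exactly the input shape that triggers the error.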

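The cause is the cast math_ops.cast(mask, inputs[0].dtype) in GlobalAveragePooling1D.call(): inputs[0] slices out the first sample along the batch axis, so it assumes the inputs tensor is non-empty (this is the StridedSlice that fails in the log). Using math_ops.cast(mask, inputs.dtype) instead yields the same dtype without slicing, and should avoid the error.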
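The difference between the two casts can be illustrated with a NumPy analogue. This is not the Keras code itself, only the same pattern: reading the dtype of the whole array never touches the batch axis, whereas indexing the first sample requires at least one sample. In TensorFlow the failing index shows up as the StridedSlice InvalidArgumentError in the log; in NumPy it is an IndexError.

```python
import numpy as np

# An empty batch shaped like the repro: 0 samples, 3 steps, 4 features.
inputs = np.zeros((0, 3, 4), dtype=np.float32)
mask = np.zeros((0, 3), dtype=bool)

# Proposed pattern: take the dtype from the whole array (works for any batch size).
mask_cast = mask.astype(inputs.dtype)

# Current pattern: inputs[0] indexes the first sample, which does not exist here.
try:
    _ = inputs[0]
    first_sample_ok = True
except IndexError:
    first_sample_ok = False

print(mask_cast.dtype, mask_cast.shape, first_sample_ok)
```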

Standalone code to reproduce the issue

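```python
import tensorflow as tf
import numpy as np

# An empty batch: 0 samples, 3 timesteps, 4 features, plus a matching empty mask.
x = np.random.rand(0, 3, 4)
mask = np.random.rand(0, 3)
print(x, mask)

# In TF 2.15 this raises InvalidArgumentError (StridedSlice index out of bounds).
y = tf.keras.layers.GlobalAveragePooling1D()(x, mask=mask)
print(y)
```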

Relevant log output

```
[] []
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-20-0f7c323aaaab> in <cell line: 7>()
      5 print(x, mask)
      6 
----> 7 y = tf.keras.layers.GlobalAveragePooling1D()(x, mask=mask)
      8 print(y)

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   5881 def raise_from_not_ok_status(e, name) -> NoReturn:
   5882   e.message += (" name: " + str(name if name is not None else ""))
-> 5883   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   5884 
   5885 

InvalidArgumentError: Exception encountered when calling layer 'global_average_pooling1d_8' (type GlobalAveragePooling1D).

{{function_node __wrapped__StridedSlice_device_/job:localhost/replica:0/task:0/device:CPU:0}} slice index 0 of dimension 0 out of bounds. [Op:StridedSlice] name: global_average_pooling1d_8/strided_slice/

Call arguments received by layer 'global_average_pooling1d_8' (type GlobalAveragePooling1D):
  • inputs=tf.Tensor(shape=(0, 3, 4), dtype=float32)
  • mask=array([], shape=(0, 3), dtype=float64)
```
google-ml-butler bot added the type:bug label on May 6, 2024.
theomeb changed the title from "GlobalAveragePooling1D fails on empty mask" to "GlobalAveragePooling1D fails with empty inputs and a mask" on May 6, 2024.
theomeb (Author) commented May 6, 2024

A colab environment with the bug and the fix: https://colab.research.google.com/drive/196CeYNTFTXdSj2-1sBlqQvrWIER3Q36L?usp=sharing
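The Colab's contents are not reproduced in this issue. As an independent workaround sketch, assuming you can swap the pooling layer in your own model, the hypothetical layer below (SafeMaskedGlobalAveragePooling1D, not part of Keras) computes the same masked average but casts the mask with inputs.dtype. It assumes channels_last data (steps on axis 1) and does not implement the built-in layer's keepdims or data_format options.

```python
import numpy as np
import tensorflow as tf


class SafeMaskedGlobalAveragePooling1D(tf.keras.layers.Layer):
    """Hypothetical masked average over axis 1 that tolerates empty batches."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs, mask=None):
        if mask is None:
            return tf.reduce_mean(inputs, axis=1)
        # Take the dtype from the whole tensor; no slicing of the batch axis.
        mask = tf.cast(mask, inputs.dtype)
        mask = tf.expand_dims(mask, axis=-1)  # (batch, steps, 1) broadcasts over features
        summed = tf.reduce_sum(inputs * mask, axis=1)  # (batch, features)
        counts = tf.reduce_sum(mask, axis=1)  # (batch, 1)
        return summed / counts

    def compute_mask(self, inputs, mask=None):
        # The steps axis is pooled away, so there is no mask to propagate.
        return None


# Same inputs as the repro: an empty batch with an empty mask.
y = SafeMaskedGlobalAveragePooling1D()(np.random.rand(0, 3, 4), mask=np.random.rand(0, 3))
print(y.shape)
```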

sushreebarsa added the comp:ops and TF 2.15 labels on May 8, 2024.

sushreebarsa (Contributor) commented:

@SuryanarayanaY I was able to replicate the issue reported here.
Thank you!
