
tf.data filter dataset too slow #67330

Open
huangrt01 opened this issue May 10, 2024 · 2 comments
Assignees
Labels
comp:data tf.data related issues TF 2.16 type:performance Performance Issue

Comments

huangrt01 commented May 10, 2024

Issue type

Performance

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

2.4

Custom code

Yes

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

Similar to issue #53169, I have observed that the "filter before batch" pattern is significantly slower: filtering the dataset alone takes 430 ms, whereas the "batch+map" pipeline needs only about 20 ms. In principle the per-element work of filter and map is comparable, yet "filter before batch" consumes far more time.
I tried filtering after batching instead, but Dataset.filter requires the predicate to return a scalar boolean, so it cannot filter within batched elements.

My question is:
Can this performance issue be optimized? I would like to build a custom operation that filters batched elements (taking [M,]-shaped tensors as input and producing [N,]-shaped tensors as output). Is there a more efficient approach?
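One possible workaround (a sketch, not an official tf.data feature) is to batch first and then mask each batch with `tf.boolean_mask`, which takes an `[M,]` tensor plus an `[M,]` boolean mask and returns the `[N,]` tensor of surviving elements — the shapes described above:

```python
import tensorflow as tf

def filter_fn(x):
    # Element-wise predicate: works on scalars and on batched [M,] tensors.
    return tf.math.equal(tf.math.mod(x, 2), 1)

dataset = tf.data.Dataset.range(10000)

# Batch first, then drop non-matching elements inside each batch.
# Each output element is an [N,] tensor with N <= 256.
masked = dataset.batch(256).map(lambda x: tf.boolean_mask(x, filter_fn(x)))

# If fixed-size batches are needed downstream, flatten and re-batch
# (this costs extra, but still avoids the per-element filter).
rebatched = masked.unbatch().batch(256)
```

Because `filter_fn` here is already element-wise, no custom op is required; the trade-off is that `masked` yields variable-length batches.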


Standalone code to reproduce the issue

import time

import tensorflow as tf
fast_dataset = tf.data.Dataset.range(10000)


def fast_benchmark(dataset, name, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Test", name, "Execution time(ms):", 1000 * (time.perf_counter() - start_time))


def increment(x):
    return x+1


def filter_fn(x):
  return tf.math.equal(tf.math.mod(x, 2), 1)


if __name__ == '__main__':
  fast_benchmark(
    fast_dataset
    .map(increment)
    .batch(256)
    ,
    "map+batch"
  )
  fast_benchmark(
    fast_dataset
    .batch(256)
    .map(increment)
    ,
    "batch+map"
  )
  fast_benchmark(
    fast_dataset
    .map(increment)
    .batch(256)
    .prefetch(tf.data.AUTOTUNE)
    ,
    "map+batch+prefetch"
  )
  fast_benchmark(
    fast_dataset
    .batch(256)
    .map(increment)
    .prefetch(tf.data.AUTOTUNE)
    ,
    "batch+map+prefetch"
  )
  fast_benchmark(
    fast_dataset
    .prefetch(tf.data.AUTOTUNE)
    .batch(256)
    .map(increment)
    ,
    "prefetch+batch+map"
  )
  fast_benchmark(
    fast_dataset
    .batch(256)
    .prefetch(tf.data.AUTOTUNE)
    .map(increment)
    ,
    "batch+prefetch+map"
  )
  fast_benchmark(
    fast_dataset
    .filter(filter_fn)
    .batch(256)
    ,
    "filter+batch"
  )
  fast_benchmark(
    fast_dataset
    .batch(256)
    .filter(filter_fn)
    ,
    "batch+filter"
  )

Relevant log output

Test map+batch Execution time(ms): 585.0009880959988
Test batch+map Execution time(ms): 23.16068299114704
Test map+batch+prefetch Execution time(ms): 503.9997957646847
Test batch+map+prefetch Execution time(ms): 19.63987946510315
Test prefetch+batch+map Execution time(ms): 54.23441715538502
Test batch+prefetch+map Execution time(ms): 16.469698399305344
Test filter+batch Execution time(ms): 282.77427703142166
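For comparison, the batched-mask variant can be timed with the same harness (a sketch; absolute numbers will vary by machine and TF version):

```python
import time
import tensorflow as tf

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, name, num_epochs=2):
    start_time = time.perf_counter()
    for _ in range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Test", name, "Execution time(ms):",
             1000 * (time.perf_counter() - start_time))

def filter_fn(x):
    return tf.math.equal(tf.math.mod(x, 2), 1)

# Per-element filter before batching (the slow path reported above).
fast_benchmark(fast_dataset.filter(filter_fn).batch(256), "filter+batch")

# Batch first, then mask inside each batch in a single vectorized map.
fast_benchmark(
    fast_dataset.batch(256).map(lambda x: tf.boolean_mask(x, filter_fn(x))),
    "batch+boolean_mask")
```

Both pipelines produce the same surviving elements; only the batching of the output differs.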
@google-ml-butler google-ml-butler bot added the type:performance Performance Issue label May 10, 2024
@SuryanarayanaY SuryanarayanaY added comp:data tf.data related issues TF 2.16 labels May 14, 2024
@SuryanarayanaY (Collaborator) commented:

Hi @huangrt01 ,

I have tested the code with tf-nightly and found that execution fails. Could you please check the gist and confirm whether anything is missing in the submitted code?

@SuryanarayanaY SuryanarayanaY added the stat:awaiting response Status - Awaiting response from author label May 14, 2024
@huangrt01 (Author) commented:

> Hi @huangrt01 ,
>
> I have tested the code with tf-nightly and found that execution fails. Could you please check the gist and confirm whether anything is missing in the submitted code?

Hi @SuryanarayanaY,
you can simply ignore the error and the "batch+filter" benchmark; the log results for the other cases were still produced:
[screenshot: benchmark log results]
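For reference, the `batch+filter` failure comes from `Dataset.filter` requiring a scalar boolean predicate: on a batched element the predicate returns an `[M,]` boolean tensor, which `filter` rejects when the dataset is constructed. A minimal sketch of the constraint (the `reduce_any` variant is only illustrative — it keeps whole batches, it does not filter per element):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(8).batch(4)

def filter_fn(x):
    return tf.math.equal(tf.math.mod(x, 2), 1)

# On a [4,] batch, filter_fn returns a [4,] boolean tensor, but
# Dataset.filter requires a scalar tf.bool, so construction fails.
try:
    ds.filter(filter_fn)
except (TypeError, ValueError):
    print("batch+filter rejected: predicate must return a scalar boolean")

# A scalar reduction is legal, but it keeps or drops whole batches,
# e.g. keep only batches containing at least one odd element.
kept = ds.filter(lambda x: tf.reduce_any(filter_fn(x)))
```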

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label May 14, 2024