Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Use compile-time promotion to reduce bitwise op size & build time #3487

Closed
wants to merge 7 commits into from
63 changes: 18 additions & 45 deletions kernels/portable/cpu/op_bitwise_and.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
Expand All @@ -17,20 +19,6 @@ namespace torch {
namespace executor {
namespace native {

namespace {

template <typename CTYPE>
CTYPE bitwise_and(CTYPE a, CTYPE b) {
return a & b;
}

template <>
bool bitwise_and<bool>(bool a, bool b) {
return a && b;
}

} // namespace

using Tensor = exec_aten::Tensor;

Tensor& bitwise_and_Tensor_out(
Expand All @@ -55,38 +43,23 @@ Tensor& bitwise_and_Tensor_out(
Bool, a_type, ctx, "bitwise_and.Tensor_out", CTYPE_A, [&]() {
ET_SWITCH_INT_TYPES_AND(
Bool, b_type, ctx, "bitwise_and.Tensor_out", CTYPE_B, [&]() {
ET_SWITCH_INT_TYPES_AND(
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool,
common_type,
out_type,
ctx,
"bitwise_and.Tensor_out",
CTYPE_IN,
CTYPE_OUT,
[&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool,
out_type,
ctx,
"bitwise_and.Tensor_out",
CTYPE_OUT,
[&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value =
bitwise_and(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
internal::BitwiseOpInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
std::bit_and,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
Expand Down Expand Up @@ -142,8 +115,8 @@ Tensor& bitwise_and_Scalar_out(
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value =
bitwise_and(a_casted, b_casted);
CTYPE_IN value = std::bit_and<CTYPE_IN>()(
a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
Expand Down
61 changes: 18 additions & 43 deletions kernels/portable/cpu/op_bitwise_or.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
Expand All @@ -17,20 +19,6 @@ namespace torch {
namespace executor {
namespace native {

namespace {

template <typename CTYPE>
CTYPE bitwise_or(CTYPE a, CTYPE b) {
return a | b;
}

template <>
bool bitwise_or<bool>(bool a, bool b) {
return a || b;
}

} // namespace

using Tensor = exec_aten::Tensor;

Tensor& bitwise_or_Tensor_out(
Expand All @@ -55,37 +43,23 @@ Tensor& bitwise_or_Tensor_out(
Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() {
ET_SWITCH_INT_TYPES_AND(
Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() {
ET_SWITCH_INT_TYPES_AND(
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool,
common_type,
out_type,
ctx,
"bitwise_or.Tensor_out",
CTYPE_IN,
CTYPE_OUT,
[&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool,
out_type,
ctx,
"bitwise_or.Tensor_out",
CTYPE_OUT,
[&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = bitwise_or(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
internal::BitwiseOpInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
std::bit_or,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
Expand Down Expand Up @@ -141,7 +115,8 @@ Tensor& bitwise_or_Scalar_out(
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = bitwise_or(a_casted, b_casted);
CTYPE_IN value =
std::bit_or<CTYPE_IN>()(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
Expand Down
64 changes: 18 additions & 46 deletions kernels/portable/cpu/op_bitwise_xor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
Expand All @@ -17,28 +19,13 @@ namespace torch {
namespace executor {
namespace native {

namespace {

template <typename CTYPE>
CTYPE bitwise_xor(CTYPE a, CTYPE b) {
return a ^ b;
}

template <>
bool bitwise_xor<bool>(bool a, bool b) {
return a != b;
}

} // namespace

using Tensor = exec_aten::Tensor;

Tensor& bitwise_xor_Tensor_out(
RuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out) {
// Determine output size and resize for dynamic shapes
ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
Expand All @@ -56,38 +43,23 @@ Tensor& bitwise_xor_Tensor_out(
Bool, a_type, ctx, "bitwise_xor.Tensor_out", CTYPE_A, [&]() {
ET_SWITCH_INT_TYPES_AND(
Bool, b_type, ctx, "bitwise_xor.Tensor_out", CTYPE_B, [&]() {
ET_SWITCH_INT_TYPES_AND(
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool,
common_type,
out_type,
ctx,
"bitwise_xor.Tensor_out",
CTYPE_IN,
CTYPE_OUT,
[&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool,
out_type,
ctx,
"bitwise_xor.Tensor_out",
CTYPE_OUT,
[&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value =
bitwise_xor(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
internal::BitwiseOpInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
std::bit_xor,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
Expand Down Expand Up @@ -143,8 +115,8 @@ Tensor& bitwise_xor_Scalar_out(
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value =
bitwise_xor(a_casted, b_casted);
CTYPE_IN value = std::bit_xor<CTYPE_IN>()(
a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
Expand Down
18 changes: 9 additions & 9 deletions kernels/portable/cpu/op_clamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ __ET_NODISCARD bool check_bounds(
}
});
} else if (isFloatingType(out_type)) {
ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
if (std::isfinite(val) &&
is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
ET_LOG(Error, "%s value out of bounds", val_name);
Expand Down Expand Up @@ -119,7 +119,7 @@ Tensor& clamp_out(

ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
ET_SWITCH_REALH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
// Extract optional min value
CTYPE_OUT min = 0;
if (has_min) {
Expand All @@ -140,7 +140,7 @@ Tensor& clamp_out(
});
}

ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(in_type, ctx, "clamp", CTYPE_IN, [&]() {
apply_unary_map_fn(
[has_min, min, has_max, max](const CTYPE_IN val_in) {
CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
Expand Down Expand Up @@ -195,20 +195,20 @@ Tensor& clamp_tensor_out(
ScalarType out_type = out.scalar_type();

if (has_min) {
common_type = promoteTypes(common_type, min_type);
common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true);
}
if (has_max) {
common_type = promoteTypes(common_type, max_type);
common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true);
}

ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

constexpr auto name = "clamp.Tensor_out";

ET_SWITCH_REALB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
ET_SWITCH_REALB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
ET_SWITCH_REALB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
ET_SWITCH_REALB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
apply_ternary_elementwise_fn<
CTYPE_IN,
CTYPE_MIN,
Expand Down