Use compile-time promotion to reduce fmod size & build time #3456

Closed. Wants to merge 4 commits.
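The change in each kernel below is the same: the old code switched over the runtime common_type with an extra ET_SWITCH_REAL_TYPES level, so the innermost lambda was instantiated for every (CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT) combination; the new code computes the promoted type at compile time via promote_types<CTYPE_A, CTYPE_B>::type and dispatches through a small helper struct specialized on can_cast<CTYPE_IN, CTYPE_OUT>::value. Dropping one switch level means roughly a factor-of-N fewer instantiated inner bodies for N supported element types, which is the intended source of the size and build-time savings. What follows is a minimal standalone sketch of the dispatch pattern, not the actual ExecuTorch code: Promote, kCanCast, and Inner are illustrative stand-ins for promote_types, can_cast, and the per-op *Inner structs added in this PR.

// Minimal sketch of compile-time promotion + can_cast dispatch (illustrative only).
#include <cstdio>
#include <type_traits>

namespace sketch {

// Stand-in for torch::executor::promote_types<A, B>::type.
template <typename A, typename B>
using Promote = typename std::common_type<A, B>::type;

// Stand-in for torch::executor::can_cast<From, To>::value.
template <typename From, typename To>
constexpr bool kCanCast = std::is_convertible<From, To>::value;

// Primary template, selected on a compile-time bool, so only the valid
// combinations get a real operator body instantiated.
template <bool can_cast, typename A, typename B, typename IN, typename OUT>
struct Inner;

// Valid combination: cast both operands to the promoted type once.
template <typename A, typename B, typename IN, typename OUT>
struct Inner<true, A, B, IN, OUT> {
  static OUT run(A a, B b) {
    IN a_casted = static_cast<IN>(a);
    IN b_casted = static_cast<IN>(b);
    return static_cast<OUT>(a_casted + b_casted);  // op body goes here
  }
};

// Invalid combination: compiles to a stub instead of a full kernel body.
// In the real kernels this is ET_DCHECK_MSG(false, ...).
struct ReportCanCastBug {
  template <typename A, typename B>
  static int run(A, B) {
    std::fprintf(stderr, "BUG: canCast should have been checked above\n");
    return 0;
  }
};

template <typename A, typename B, typename IN, typename OUT>
struct Inner<false, A, B, IN, OUT> : public ReportCanCastBug {};

}  // namespace sketch

int main() {
  using A = int;
  using B = float;
  using IN = sketch::Promote<A, B>;  // float, computed at compile time
  float out =
      sketch::Inner<sketch::kCanCast<IN, float>, A, B, IN, float>::run(2, 0.5f);
  std::printf("%f\n", out);
  return 0;
}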
93 changes: 64 additions & 29 deletions kernels/portable/cpu/op_floor_divide.cpp
@@ -20,6 +20,60 @@ namespace native {
using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {
template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct FloorDivideInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct FloorDivideInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void
run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
if (val_b == 0) {
div_by_zero_error = true;
return static_cast<CTYPE_OUT>(0);
}
}
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = utils::floor_divide<CTYPE_IN>(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct FloorDivideInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

} // namespace

Tensor& floor_divide_out(
RuntimeContext& ctx,
const Tensor& a,
@@ -46,36 +100,17 @@ Tensor& floor_divide_out(
Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool, b_type, ctx, "floor_divide.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES(
common_type, ctx, "floor_divide.out", CTYPE_IN, [&]() {
ET_SWITCH_REAL_TYPES(
out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[common_type, &div_by_zero_error](
const CTYPE_A val_a, const CTYPE_B val_b) {
if (isIntegralType(
common_type, /*includeBool=*/true)) {
if (val_b == 0) {
div_by_zero_error = true;
return static_cast<CTYPE_OUT>(0);
}
}
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = utils::floor_divide<CTYPE_IN>(
a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
FloorDivideInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out, div_by_zero_error);
});
});
});
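A second detail in the floor_divide change above (and the fmod change below): the divide-by-zero guard moves from a runtime isIntegralType(common_type, /*includeBool=*/true) check, which forced the lambda to capture common_type, to the compile-time constant is_integral_type<CTYPE_IN, /*includeBool=*/true>::value. Because it is an ordinary `if` on a constant rather than `if constexpr`, the zero branch must still compile for floating-point CTYPE_IN (it does, since val_b == 0 is valid for floats), but the optimizer can fold it away in those instantiations. A tiny sketch of that folded-branch idea follows, using std::is_integral as a stand-in for the ExecuTorch trait; checked_divide and the plain a / b body are illustrative only.

#include <cstdio>
#include <type_traits>

// The outer `if` tests a compile-time constant, so for a floating-point
// CTYPE_IN the zero-divisor path is dead code and drops out of that
// instantiation; for integral CTYPE_IN it stays and sets the error flag.
template <typename CTYPE_IN>
CTYPE_IN checked_divide(CTYPE_IN a, CTYPE_IN b, bool& div_by_zero_error) {
  if (std::is_integral<CTYPE_IN>::value) {
    if (b == 0) {
      div_by_zero_error = true;
      return static_cast<CTYPE_IN>(0);
    }
  }
  return static_cast<CTYPE_IN>(a / b);  // placeholder for utils::floor_divide
}

int main() {
  bool err = false;
  std::printf("%d\n", checked_divide(7, 0, err));     // int: err becomes true
  std::printf("%f\n", checked_divide(7.0, 2.0, err)); // double: branch folded away
  return 0;
}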
93 changes: 65 additions & 28 deletions kernels/portable/cpu/op_fmod.cpp
@@ -19,6 +19,60 @@ namespace native {

using Tensor = exec_aten::Tensor;

namespace {
template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct FmodInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct FmodInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void
run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
if (val_b == 0) {
div_by_zero_error = true;
return static_cast<CTYPE_OUT>(0);
}
}
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = std::fmod(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct FmodInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

} // namespace

Tensor& fmod_Tensor_out(
RuntimeContext& ctx,
const Tensor& a,
@@ -44,35 +98,18 @@ Tensor& fmod_Tensor_out(
Bool, a_type, ctx, "fmod.Tensor_out", CTYPE_A, [&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool, b_type, ctx, "fmod.Tensor_out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES(
common_type, ctx, "fmod.Tensor_out", CTYPE_IN, [&]() {
ET_SWITCH_REAL_TYPES(
out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[common_type, &div_by_zero_error](
const CTYPE_A val_a, const CTYPE_B val_b) {
if (isIntegralType(
common_type, /*includeBool=*/true)) {
if (val_b == 0) {
div_by_zero_error = true;
return static_cast<CTYPE_OUT>(0);
}
}
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = std::fmod(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
FmodInner<
!std::is_same<CTYPE_IN, bool>::value &&
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out, div_by_zero_error);
});
});
});
68 changes: 54 additions & 14 deletions kernels/portable/cpu/op_maximum.cpp
@@ -20,6 +20,50 @@ const T& max(const T& a, const T& b) {
return (b > a) ? b : a;
}

template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MaximumInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = max(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

} // namespace

Tensor& maximum_out(
@@ -44,20 +88,16 @@ Tensor& maximum_out(

ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() {
ET_SWITCH_REALB_TYPES(common_type, ctx, "maximum.out", CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = max(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
MaximumInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
69 changes: 54 additions & 15 deletions kernels/portable/cpu/op_minimum.cpp
@@ -20,6 +20,50 @@ const T& min(const T& a, const T& b) {
return (b < a) ? b : a;
}

template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MinimumInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MinimumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = min(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MinimumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

} // namespace

Tensor& minimum_out(
@@ -44,22 +88,17 @@ Tensor& minimum_out(

ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() {
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() {
using CTYPE_IN =
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool, common_type, ctx, "minimum.out", CTYPE_IN, [&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = min(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
MinimumInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});