Use compile-time promotion to reduce floor_divide size & build time (pytorch#3455)

Summary:

Continuing rollout of this technique: compute the promoted compute type at compile time with promote_types<CTYPE_A, CTYPE_B> instead of switching over common_type at runtime, which removes one level of ET_SWITCH nesting and the duplicated kernel instantiations it generated (a sketch follows the change summary below).

Reviewed By: manuelcandales

Differential Revision: D56827786
swolchok authored and facebook-github-bot committed May 7, 2024
1 parent 388caf9 commit 80d5bb7
Showing 2 changed files with 70 additions and 29 deletions.
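
The technique, in a minimal self-contained sketch. The names here (promoted_t, floor_div_one) are toy stand-ins for illustration, not the ExecuTorch API: deriving the compute type from the input types at compile time means the compiler emits one kernel body per input-type pair, instead of one per reachable entry of a runtime dtype switch.

#include <cstdio>
#include <type_traits>

// Toy stand-in for torch::executor::promote_types: promote to the wider type.
template <typename A, typename B>
using promoted_t = std::conditional_t<(sizeof(A) >= sizeof(B)), A, B>;

template <typename A, typename B, typename Out>
Out floor_div_one(A a, B b) {
  // The compute type is a compile-time function of A and B, so no runtime
  // switch over it is needed and no extra bodies get instantiated for it.
  using In = promoted_t<A, B>;
  In q = static_cast<In>(a) / static_cast<In>(b);  // floors for non-negative inputs
  return static_cast<Out>(q);
}

int main() {
  std::printf("%d\n", floor_div_one<short, int, int>(7, 2));  // prints 3
}

In the diff below, the ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type) line is the runtime counterpart of this idea: it only verifies that the compile-time promotion agrees with the common_type computed at runtime.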
93 changes: 64 additions & 29 deletions kernels/portable/cpu/op_floor_divide.cpp
@@ -20,6 +20,60 @@ namespace native {
using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct FloorDivideInner;

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct FloorDivideInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
  static void
  run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
        [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
          if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
            if (val_b == 0) {
              div_by_zero_error = true;
              return static_cast<CTYPE_OUT>(0);
            }
          }
          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
          CTYPE_IN value = utils::floor_divide<CTYPE_IN>(a_casted, b_casted);

          return static_cast<CTYPE_OUT>(value);
        },
        a,
        b,
        out);
  }
};

struct ReportCanCastBug {
  static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
  }
};

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct FloorDivideInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
    : public ReportCanCastBug {};

} // namespace

Tensor& floor_divide_out(
    RuntimeContext& ctx,
    const Tensor& a,
@@ -46,36 +100,17 @@ Tensor& floor_divide_out(
      Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() {
        ET_SWITCH_REAL_TYPES_AND(
            Bool, b_type, ctx, "floor_divide.out", CTYPE_B, [&]() {
              using CTYPE_IN = typename torch::executor::
                  promote_types<CTYPE_A, CTYPE_B>::type;
              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
              ET_SWITCH_REAL_TYPES(
                  common_type, ctx, "floor_divide.out", CTYPE_IN, [&]() {
                    ET_SWITCH_REAL_TYPES(
                        out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
                          apply_binary_elementwise_fn<
                              CTYPE_A,
                              CTYPE_B,
                              CTYPE_OUT>(
                              [common_type, &div_by_zero_error](
                                  const CTYPE_A val_a, const CTYPE_B val_b) {
                                if (isIntegralType(
                                        common_type, /*includeBool=*/true)) {
                                  if (val_b == 0) {
                                    div_by_zero_error = true;
                                    return static_cast<CTYPE_OUT>(0);
                                  }
                                }
                                CTYPE_IN a_casted =
                                    static_cast<CTYPE_IN>(val_a);
                                CTYPE_IN b_casted =
                                    static_cast<CTYPE_IN>(val_b);
                                CTYPE_IN value = utils::floor_divide<CTYPE_IN>(
                                    a_casted, b_casted);

                                return static_cast<CTYPE_OUT>(value);
                              },
                              a,
                              b,
                              out);
                        });
                  out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
                    FloorDivideInner<
                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
                        CTYPE_A,
                        CTYPE_B,
                        CTYPE_IN,
                        CTYPE_OUT>::run(a, b, out, div_by_zero_error);
                  });
            });
      });
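The FloorDivideInner / ReportCanCastBug pair above is a compile-time dispatch pattern: a bool template parameter picks, per (CTYPE_IN, CTYPE_OUT) combination, either the real kernel body or one shared stub. A minimal sketch of the pattern, with plain doubles as assumed stand-ins for Tensors:

#include <cassert>
#include <cstdio>

// Primary template: declared but never defined, so a combination that matches
// neither specialization is a compile error.
template <bool can_cast, typename In, typename Out>
struct Inner;

// can_cast == true: the real body is instantiated for this (In, Out) pair.
template <typename In, typename Out>
struct Inner<true, In, Out> {
  static void run(double a, double b, double& out) {
    out = static_cast<double>(
        static_cast<Out>(static_cast<In>(a) / static_cast<In>(b)));
  }
};

// can_cast == false: every invalid pair shares this single trivial stub, so
// no kernel body is generated for it.
struct ReportBug {
  static void run(double, double, double&) {
    assert(false && "BUG: canCast should have been checked above");
  }
};

template <typename In, typename Out>
struct Inner<false, In, Out> : ReportBug {};

int main() {
  double out = 0;
  Inner<true, int, float>::run(7, 2, out);  // statically selects the real body
  std::printf("%g\n", out);                 // prints 3 (int division floors here)
}

In the kernel above, can_cast<CTYPE_IN, CTYPE_OUT>::value supplies the bool, so apply_binary_elementwise_fn is instantiated only for type pairs the runtime canCast check can actually reach; that is where the size and build-time savings come from.
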
6 changes: 6 additions & 0 deletions runtime/core/exec_aten/util/scalar_type_util.h
@@ -349,6 +349,12 @@ inline constexpr bool isIntegralType(
      t == exec_aten::ScalarType::Short);
}

template <typename T, bool includeBool>
struct is_integral_type
    : public std::integral_constant<
          bool,
          isIntegralType(CppTypeToScalarType<T>::value, includeBool)> {};

inline constexpr bool isFloatingType(exec_aten::ScalarType t) {
  return (
      t == exec_aten::ScalarType::Double || t == exec_aten::ScalarType::Float ||
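The new is_integral_type trait lifts the existing constexpr isIntegralType(ScalarType, bool) check to the type level, which is what lets FloorDivideInner branch on it inside a template. Expected behavior, sketched as hypothetical static_asserts (assuming the usual CppTypeToScalarType mappings and the torch::executor namespace):

static_assert(
    is_integral_type<int32_t, /*includeBool=*/true>::value,
    "int32_t maps to ScalarType::Int, which is integral");
static_assert(
    is_integral_type<bool, /*includeBool=*/true>::value &&
        !is_integral_type<bool, /*includeBool=*/false>::value,
    "bool counts as integral only when includeBool is true");
static_assert(
    !is_integral_type<float, /*includeBool=*/true>::value,
    "floating-point types are never integral");

Because the value is a compile-time constant, the divide-by-zero guard in the kernel lambda folds away entirely for floating-point instantiations.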
