Skip to content
68 changes: 50 additions & 18 deletions python_bindings/src/halide/halide_/PyIROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,16 @@ void define_operators(py::module &m) {
return py::cast(false_expr_value);
});

m.def("mux", (Expr(*)(const Expr &, const std::vector<Expr> &))&mux);
m.def("mux", static_cast<Expr (*)(const Expr &, const std::vector<Expr> &)>(&mux));
m.def("mux", static_cast<Expr (*)(const Expr &, const Tuple &)>(&mux));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Today I learned: overload selection when taking function pointers.

m.def("mux", static_cast<Tuple (*)(const Expr &, const std::vector<Tuple> &)>(&mux));

m.def("sin", &sin);
m.def("asin", &asin);
m.def("cos", &cos);
m.def("acos", &acos);
m.def("tan", &tan);
m.def("atan", &atan);
m.def("atan", &atan2);
m.def("atan2", &atan2);
m.def("sinh", &sinh);
m.def("asinh", &asinh);
Expand All @@ -150,6 +151,8 @@ void define_operators(py::module &m) {
m.def("log", &log);
m.def("pow", &pow);
m.def("erf", &erf);
m.def("fast_sin", &fast_sin);
m.def("fast_cos", &fast_cos);
m.def("fast_log", &fast_log);
m.def("fast_exp", &fast_exp);
m.def("fast_pow", &fast_pow);
Expand All @@ -163,55 +166,84 @@ void define_operators(py::module &m) {
m.def("is_nan", &is_nan);
m.def("is_inf", &is_inf);
m.def("is_finite", &is_finite);
m.def("reinterpret", (Expr(*)(Type, Expr))&reinterpret);
m.def("cast", (Expr(*)(Type, Expr))&cast);
m.def("reinterpret", static_cast<Expr (*)(Type, Expr)>(&reinterpret));
m.def("cast", static_cast<Expr (*)(Type, Expr)>(&cast));

m.def("print", [](const py::args &args) -> Expr {
return print(collect_print_args(args));
});

m.def(
"print_when", [](const Expr &condition, const py::args &args) -> Expr {
return print_when(condition, collect_print_args(args));
},
py::arg("condition"));

m.def(
"require", [](const Expr &condition, const Expr &value, const py::args &args) -> Expr {
auto v = args_to_vector<Expr>(args);
v.insert(v.begin(), value);
return require(condition, v);
},
py::arg("condition"), py::arg("value"));

m.def("lerp", &lerp);
m.def("popcount", &popcount);
m.def("count_leading_zeros", &count_leading_zeros);
m.def("count_trailing_zeros", &count_trailing_zeros);
m.def("div_round_to_zero", &div_round_to_zero);
m.def("mod_round_to_zero", &mod_round_to_zero);
m.def("random_float", (Expr(*)())&random_float);
m.def("random_uint", (Expr(*)())&random_uint);
m.def("random_int", (Expr(*)())&random_int);
m.def("random_float", (Expr(*)(Expr))&random_float, py::arg("seed"));
m.def("random_uint", (Expr(*)(Expr))&random_uint, py::arg("seed"));
m.def("random_int", (Expr(*)(Expr))&random_int, py::arg("seed"));
m.def("undef", (Expr(*)(Type))&undef);
m.def("random_float", [] { return random_float(); });
m.def("random_float", &random_float, py::arg("seed"));
m.def("random_uint", [] { return random_uint(); });
m.def("random_uint", &random_uint, py::arg("seed"));
m.def("random_int", [] { return random_int(); });
m.def("random_int", &random_int, py::arg("seed"));
m.def("undef", static_cast<Expr (*)(Type)>(&undef));

m.def(
"memoize_tag", [](const Expr &result, const py::args &cache_key_values) -> Expr {
return Internal::memoize_tag_helper(result, args_to_vector<Expr>(cache_key_values));
},
py::arg("result"));

m.def("likely", &likely);
m.def("likely_if_innermost", &likely_if_innermost);
m.def("saturating_cast", (Expr(*)(Type, Expr))&saturating_cast);
m.def("saturating_cast", static_cast<Expr (*)(Type, Expr)>(&saturating_cast));
m.def("strict_float", &strict_float);
m.def("scatter", static_cast<Expr (*)(const std::vector<Expr> &)>(&scatter));
m.def("gather", static_cast<Expr (*)(const std::vector<Expr> &)>(&gather));
m.def("extract_bits", static_cast<Expr (*)(Type, const Expr &, const Expr &)>(&extract_bits));
m.def("concat_bits", &concat_bits);
m.def("widen_right_add", &widen_right_add);
m.def("widen_right_mul", &widen_right_mul);
m.def("widen_right_sub", &widen_right_sub);
m.def("widening_add", &widening_add);
m.def("widening_mul", &widening_mul);
m.def("widening_sub", &widening_sub);
m.def("widening_shift_left", static_cast<Expr (*)(Expr, int)>(&widening_shift_left));
m.def("widening_shift_left", static_cast<Expr (*)(Expr, Expr)>(&widening_shift_left));
m.def("widening_shift_right", static_cast<Expr (*)(Expr, int)>(&widening_shift_right));
m.def("widening_shift_right", static_cast<Expr (*)(Expr, Expr)>(&widening_shift_right));
m.def("rounding_shift_left", static_cast<Expr (*)(Expr, int)>(&rounding_shift_left));
m.def("rounding_shift_left", static_cast<Expr (*)(Expr, Expr)>(&rounding_shift_left));
m.def("rounding_shift_right", static_cast<Expr (*)(Expr, int)>(&rounding_shift_right));
m.def("rounding_shift_right", static_cast<Expr (*)(Expr, Expr)>(&rounding_shift_right));
m.def("saturating_add", &saturating_add);
m.def("saturating_sub", &saturating_sub);
m.def("halving_add", &halving_add);
m.def("rounding_halving_add", &rounding_halving_add);
m.def("halving_sub", &halving_sub);
m.def("mul_shift_right", static_cast<Expr (*)(Expr, Expr, int)>(&mul_shift_right));
m.def("mul_shift_right", static_cast<Expr (*)(Expr, Expr, Expr)>(&mul_shift_right));
m.def("rounding_mul_shift_right", static_cast<Expr (*)(Expr, Expr, int)>(&rounding_mul_shift_right));
m.def("rounding_mul_shift_right", static_cast<Expr (*)(Expr, Expr, Expr)>(&rounding_mul_shift_right));
m.def("target_arch_is", &target_arch_is);
m.def("target_bits", &target_bits);
m.def("target_has_feature", &target_has_feature);
m.def("target_natural_vector_size", [](const Type &t) -> Expr {
return target_natural_vector_size(t);
});
m.def("target_natural_vector_size", static_cast<Expr (*)(Type)>(&target_natural_vector_size));
m.def("target_os_is", &target_os_is);
m.def("logical_not", [](const Expr &expr) -> Expr {
return !expr;
});
m.def("logical_not", [](const Expr &expr) -> Expr { return !expr; });
}

} // namespace PythonBindings
Expand Down
2 changes: 1 addition & 1 deletion src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1555,7 +1555,7 @@ f(scatter(3, 5)) = f(select(p, gather(5, 3), gather(3, 5)));
f(select(p, scatter(3, 5, 5), scatter(1, 2, 3))) = f(select(p, gather(5, 3, 3), gather(2, 3, 1)));
\endcode
*
* Note that in the p == true case, we redudantly load from 3 and write
* Note that in the p == true case, we redundantly load from 3 and write
* to 5 twice.
*/
//@{
Expand Down
Loading