Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/AssociativeOpsTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,32 @@ void populate_ops_table_single_uint32_select(const vector<Type> &types, vector<A
table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add
}

// NaN test used by the NaN-propagating min/max select patterns.
//
// This builds the is_nan_f16/f32/f64 extern call directly on `x`, with no
// strict_float wrapping. That is required because the Solve module strips
// strict_float on one side of the pattern matching, which would otherwise make
// the nan-propagating min/max patterns fail to match.
// TODO: Once strict_float has been reworked, this should be removed.
//
// Returns a Bool expression with the same lane count as `x`. Only Float(16),
// Float(32) and Float(64) element types are supported; any other element type
// trips the internal assertion.
Expr is_nan_not_strict(Expr x) {
    const Type elem = x.type().element_of();
    // Capture the result type now: `x` is moved into the call's argument list below.
    const Type result_type = Bool(x.type().lanes());

    const char *extern_name = nullptr;
    if (elem == Float(64)) {
        extern_name = "is_nan_f64";
    } else if (elem == Float(16)) {
        extern_name = "is_nan_f16";
    } else {
        internal_assert(elem == Float(32));
        extern_name = "is_nan_f32";
    }
    return Call::make(result_type, extern_name, {std::move(x)}, Call::PureExtern);
}

void populate_ops_table_single_float_select(const vector<Type> &types, vector<AssociativePattern> &table) {
declare_vars_single(types);
// Propagating max operators
table.emplace_back(select(is_nan_not_strict(x0) || x0 > y0, x0, y0), tmin_0, true);
table.emplace_back(select(is_nan_not_strict(x0) || x0 >= y0, x0, y0), tmin_0, true);

// Propagating min operators
table.emplace_back(select(is_nan_not_strict(x0) || x0 < y0, x0, y0), tmax_0, true);
table.emplace_back(select(is_nan_not_strict(x0) || x0 <= y0, x0, y0), tmax_0, true);
}

const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePattern> &)> val_type_to_populate_luts_fn = {
{TableKey(ValType::All, IRNodeType::Add, 1), &populate_ops_table_single_general_add},
{TableKey(ValType::All, IRNodeType::Mul, 1), &populate_ops_table_single_general_mul},
Expand All @@ -275,6 +301,10 @@ const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePatter

{TableKey(ValType::UInt32, IRNodeType::Cast, 1), &populate_ops_table_single_uint32_cast},
{TableKey(ValType::UInt32, IRNodeType::Select, 1), &populate_ops_table_single_uint32_select},

{TableKey(ValType::Float16, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
{TableKey(ValType::Float32, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
{TableKey(ValType::Float64, IRNodeType::Select, 1), &populate_ops_table_single_float_select},
};

const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, IRNodeType root, size_t dim) {
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct PipelineContents;
*
* The 'name' field specifies the type of Autoscheduler
* to be used (e.g. Adams2019, Mullapudi2016). If this is an empty string,
* no autoscheduling will be done; if not, it mustbe the name of a known Autoscheduler.
* no autoscheduling will be done; if not, it must be the name of a known Autoscheduler.
*
* At this time, well-known autoschedulers include:
* "Mullapudi2016" -- heuristics-based; the first working autoscheduler; currently built in to libHalide
Expand Down
60 changes: 60 additions & 0 deletions test/correctness/rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,64 @@ int rfactor_precise_bounds_test() {
return 0;
}

// Chooses how isnan_max_rfactor_test combines its comparison with the NaN
// check inside the select() condition:
//   BitwiseOr -> (acc > v) |  is_nan(acc)
//   LogicalOr -> (acc > v) || is_nan(acc)
// Kept as an unscoped enum so the enumerators can be named unqualified as
// template arguments.
enum MaxRFactorTestVariant {
    BitwiseOr = 0,
    LogicalOr = 1,
};

template<MaxRFactorTestVariant variant>
int isnan_max_rfactor_test() {
RDom r(0, 16);
RVar ri("ri");
Var x("x"), y("y"), u("u");

ImageParam in(Float(32), 2);

const auto make_reduce = [&](const char *name) {
Func reduce(name);
reduce(y) = Float(32).min();
switch (variant) {
case BitwiseOr:
reduce(y) = select(reduce(y) > cast(reduce.type(), in(r, y)) | is_nan(reduce(y)), reduce(y), cast(reduce.type(), in(r, y)));
break;
case LogicalOr:
reduce(y) = select(reduce(y) > cast(reduce.type(), in(r, y)) || is_nan(reduce(y)), reduce(y), cast(reduce.type(), in(r, y)));
break;
}
return reduce;
};

Func reference = make_reduce("reference");

Func reduce = make_reduce("reduce");
reduce.update(0).split(r, r, ri, 8).rfactor(ri, u);

float tests[][16] = {
{NAN, 0.29f, 0.19f, 0.68f, 0.44f, 0.40f, 0.39f, 0.53f, 0.23f, 0.21f, 0.85f, 0.19f, 0.37f, 0.03f, 0.14f, 0.64f},
{0.98f, 0.65f, 0.86f, 0.16f, 0.14f, 0.91f, 0.74f, 0.99f, 0.91f, 0.01f, 0.11f, 0.59f, 0.05f, 0.90f, 0.93f, NAN},
{0.84f, 0.14f, 0.99f, 0.19f, 0.63f, 0.12f, 0.51f, 0.67f, NAN, 0.34f, 0.89f, 0.93f, 0.72f, 0.69f, 0.58f, 0.63f},
{0.44f, 0.12f, 0.00f, 0.30f, 0.80f, 0.88f, 0.95f, 0.12f, 0.90f, 0.99f, 0.67f, 0.71f, 0.35f, 0.67f, 0.18f, 0.93f},
};

Buffer<float, 2> buf{tests};
in.set(buf);

Buffer<float, 1> ref_vals = reference.realize({4}, get_jit_target_from_environment().with_feature(Target::StrictFloat));
Buffer<float, 1> fac_vals = reduce.realize({4}, get_jit_target_from_environment().with_feature(Target::StrictFloat));

for (int i = 0; i < 4; i++) {
if (std::isnan(fac_vals(i)) && std::isnan(ref_vals(i))) {
continue;
}
if (fac_vals(i) != ref_vals(i)) {
std::cerr << "At index " << i << ", expected: " << ref_vals(i) << ", got: " << fac_vals(i) << "\n";
return 1;
}
}

return 0;
}

} // namespace

int main(int argc, char **argv) {
Expand Down Expand Up @@ -1100,6 +1158,8 @@ int main(int argc, char **argv) {
{"argmin rfactor test", argmin_rfactor_test},
{"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test},
{"rfactor bounds tests", rfactor_precise_bounds_test},
{"isnan max rfactor test (bitwise or)", isnan_max_rfactor_test<BitwiseOr>},
{"isnan max rfactor test (logical or)", isnan_max_rfactor_test<LogicalOr>},
};

using Sharder = Halide::Internal::Test::Sharder;
Expand Down
Loading