Rework strict_float to use individual op intrinsics instead #8641
Conversation
This commit removes the strict_float intrinsic, which was problematic because it was a scoped thing, so hoisting an Expr out of it can lead to incorrect codegen. This is hard for various lowering passes to handle.
Instead, this adds strict versions of all the floating point ops and comparisons that either round or need to worry about nan/inf: strict_add, strict_mul, strict_eq, etc. The strict_float helper is now a mutator that replaces all adds, muls, eqs, etc. with the strict intrinsic equivalents.
The simplifier has no rules that handle these intrinsics, so it naturally doesn't do anything it shouldn't. The changes to the simplifier look large, but they are all just removing can-I-simplify-this checks, which changes indentation.
There is also a helper that does the reverse, which can be used by analysis or codegen passes. The C-like backends just unstrictify these intrinsics when encountered, and it's on the consumer to not set fast-math flags. When CodeGen_LLVM encounters one of these intrinsics it sets a scoped flag to indicate it is now in strict mode, and then recursively codegens the unstrictified version of the op. If the strict_float target flag is set, this scoped flag is just always true. CodeGen_LLVM defaults every floating point operation to be strict. When it encounters an op that has a strict equivalent but the scoped flag is currently false, it temporarily sets the floating point flags to relaxed for that op only. In this way it fails safe: use of fast-math flags must be explicitly requested at each emission of a floating point op.
Without this fail-safe behavior there were problems where something like a vector slice operation was marked as no-nans, which was then fed into a strict comparison, making the comparison not actually strict, because the tag on the previous op meant that LLVM was allowed to assume the arguments were not nan. no-nans and similar flags are viral.
The above only applies if there is any usage of strict float intrinsics in the module or if the strict_float target flag is set. If not, it just defaults all ops to fast, as on main. For the purposes of this check, is_nan, is_inf, and is_finite all count as strict float intrinsics.
There should be no visible changes for users. There are no performance differences in the apps with the flag off.
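As a rough illustration of the mutator pattern described above (a sketch only, not the PR's actual StrictifyFloat implementation; it assumes the Call::strict_add and Call::strict_mul intrinsic names this PR introduces), only floating-point ops get rewritten, and the simplifier then has nothing to match on:

#include "Halide.h"

using namespace Halide;
using namespace Halide::Internal;

// Rewrites float adds/muls into their strict intrinsic forms so that the
// simplifier, which has no rules for these calls, leaves them untouched.
class StrictifySketch : public IRMutator {
    using IRMutator::visit;

    Expr visit(const Add *op) override {
        if (!op->type.is_float()) {
            return IRMutator::visit(op);
        }
        // a + b  ->  strict_add(a, b)
        return Call::make(op->type, Call::strict_add,
                          {mutate(op->a), mutate(op->b)}, Call::PureIntrinsic);
    }

    Expr visit(const Mul *op) override {
        if (!op->type.is_float()) {
            return IRMutator::visit(op);
        }
        // a * b  ->  strict_mul(a, b)
        return Call::make(op->type, Call::strict_mul,
                          {mutate(op->a), mutate(op->b)}, Call::PureIntrinsic);
    }
};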
mcourteaux
left a comment
Looks mostly good. One serious concern regarding the recursive behaviour of the new UnstrictifyFloat mutator.
Other than that: it'd be nice to get access to those strict ops from IR.h with helper functions (something like the sketch below) instead of forcing users through Call::make or the StrictifyFloat/strict_float() mutator.
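Something like this, purely to illustrate what I mean (hypothetical; the wrapper name is made up and this is not code in the PR):

#include "Halide.h"

// Hypothetical helper that just wraps the Call::make boilerplate for the new
// strict intrinsics; strict_mul, strict_eq, etc. would look the same.
Halide::Expr strict_add(Halide::Expr a, Halide::Expr b) {
    using namespace Halide::Internal;
    return Call::make(a.type(), Call::strict_add,
                      {std::move(a), std::move(b)}, Call::PureIntrinsic);
}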
Perhaps a bit too much to ask, but as this is super relevant: immediately adding a strict_fma would be awesome (and would simplify work on my upcoming PRs, I believe, regarding Horner polynomial evaluations).
        handle_expr_bounds(e);
    } else if (op->is_strict_float_intrinsic()) {
        Expr e = unstrictify_float(op);
        handle_expr_bounds(e);
Are we sure handle_expr_bounds does not "simplify" any math?
Yes, it just recursively visits the children. This unstrictification is preserving the behavior of main - strict floating point math is analyzed by bounds inference as if it's infinite-precision. This is dubious - I'm sure you could construct code that exploits this to read out of bounds. However it's what main does, so I'd like to defer any changes here to future work. I'll make it an issue.
    fp_flags.setNoNaNs();
    fp_flags.setNoInfs();
    fp_flags.setNoSignedZeros();
Are we sure no-nans and no-infs are a good idea? I'm thinking of cases where we want to let Halide/LLVM reassociate and optimize the math without discarding the possibility that expressions can go to nan or inf. I'm thinking about 1/x, sec(x).
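For illustration (plain C++, not Halide code): under no-nans style fast-math, e.g. compiling with clang's -ffinite-math-only, the compiler is allowed to assume no value is ever nan, so the usual isnan idiom can legally be folded away:

// May be constant-folded to false under no-nans, even when x really is nan.
bool probably_nan(float x) {
    return x != x;
}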
This is preserving the status quo when there's no strict_float usage - I just moved this code around. We can have a discussion at the dev meeting about changing these flags if you like, but it might have impacts on production code.
    // TODO: Enable/Disable RelaxedPrecision flags?
    internal_assert(op->args.size() == 1);
    op->args[0].accept(this);
    Expr e = unstrictify_float(op);
This feels wrong. I checked the IRMutator, and it seems to destrictify the entire IR tree under this node. As such we will lose the information on which ops were strict and which weren't. The whole overhaul of the strict float mechanic is supposed to tell us which ops are strict and which aren't. The IRMutator shouldn't recurse, I believe. The codegen visitors should simply handle nested strict float intrinsic Calls again, on this very same line, without destrictifying recursively.
The current state of the PR very conservatively makes all things under a strict float intrinsic strict, whether or not unstrictify is recursive, because codegen_llvm sets a scoped flag before recursively codegenning the args. Now this is a bit weird, because (even with a non-recursive unstrictify) the following two exprs result in different codegen:
let t = a*b + c in strict_add(t, 2) // fma
strict_add(a*b + c, 2) // no fma
I agree this is gross, so I'll make unstrictify non-recursive and I'll evaluate the strict op args outside of the scope of the flag.
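Roughly what I have in mind for "evaluate the args outside of the scope of the flag" - a sketch only, with assumed names, not the actual CodeGen_LLVM code: the args are codegen'd first under whatever fast-math flags are already active, and only the strict op itself is emitted with all relaxations switched off.

#include "llvm/IR/IRBuilder.h"

// The guard restores whatever fast-math flags the builder had on exit; the
// strict add itself is emitted with an empty (fully strict) set of flags.
// a and b are assumed to have been codegen'd already, outside this function.
llvm::Value *emit_strict_fadd(llvm::IRBuilder<> &builder,
                              llvm::Value *a, llvm::Value *b) {
    llvm::IRBuilderBase::FastMathFlagGuard guard(builder);
    builder.setFastMathFlags(llvm::FastMathFlags());  // strict: no relaxations
    return builder.CreateFAdd(a, b);
}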
src/StrictifyFloat.cpp
    (op->name == "sqrt_f16" ||
     op->name == "sqrt_f32" ||
     op->name == "sqrt_f64")) {
    return Call::make(op->type, Call::strict_sqrt,
Why is there strict_sqrt, and not anything like strict_exp, strict_log, strict_tan, strict_sin, strict_lambert_w, strict_bessel, etc? It feels arbitrary. I realize there are hardware instructions on some platforms for sqrt, but aren't those always supposed to be exact? PTX has some rounding options:
sqrt.approx{.ftz}.f32 d, a; // fast, approximate square root
sqrt.rnd{.ftz}.f32 d, a; // IEEE 754 compliant rounding
sqrt.rnd.f64 d, a; // IEEE 754 compliant rounding
.rnd = { .rn, .rz, .rm, .rp };
So that would sort of explain why there is strict_sqrt: to ensure the most common rounding mode and to not emit the .ftz (flush to zero) variant? But then PTX also has instructions for these:
sqrt
rsqrt
sin
cos
lg2
ex2
tanh
As there is an endless number of transcendental calls that could have a strict version, I am hesitant to say that this is the best way forward. On the other hand, we are considering fast_tan, etc. as the fast equivalents... So that might indicate we need a strict version of every transcendental too, and not just the sqrt one here.
For most of our transcendentals (excluding the ones starting with fast_) there is no relaxed version, so every version is strict and no unsafe optimizations take place. For sqrt, at the llvm level it's an intrinsic that can be tagged with no-nans (meaning the input can be assumed to be >= 0) or not.
However, now that I say that, I realize that the way we direct things to the intrinsic in src/runtime/posix_math.ll never tags it with no-nans when the module has any strict float usage, so there's actually no way to do a relaxed sqrt in that situation. Halide itself doesn't have any simplifier rules around sqrt, so the only thing a relaxed sqrt would buy us is telling llvm it can do unsafe optimizations relating to it. I think I'll just delete strict_sqrt.
To manage scope, I'm intentionally leaving out fma for now.
- Make unstrictify non-recursive, and evaluate args outside of the strict scope in CodeGen
- Fix C++ emission and add preprocessor warning that looks for common fast-math flags
- Remove strict_sqrt
mcourteaux
left a comment
LGTM except for one minor comment regarding a potentially useful internal_assert we could add.