From 7005a4bcb63f595b93d35f175339152600799905 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 11 Jan 2024 11:52:28 +0000
Subject: [PATCH] [dynamo] Added dyn shapes support for math trigo ops: sin(h),
 cos(h), tan(h) ... (#114866)

Description:
- Added dynamic shapes support for the math trigonometric ops: sin(h), cos(h), tan(h), etc.

```python
import math
import torch


def func(x, a, b):
    c = 0
    c = c + math.sqrt(a)
    c = c + math.cos(a)
    c = c + math.cosh(a)
    c = c + math.sin(a)
    c = c + math.sinh(a)
    c = c + math.tan(a)
    c = c + math.tanh(a)
    c = c + math.asin(b)
    c = c + math.acos(b)
    c = c + math.atan(a)
    y = x + c
    return y


cfunc = torch.compile(func, dynamic=True, fullgraph=True)

device = "cpu"  # or "cuda"
x = torch.tensor([0, 1, 2, 3], dtype=torch.float32, device=device)
a = 12
b = 1

out = cfunc(x, a, b)
expected = func(x, a, b)
torch.testing.assert_close(out, expected)
```

The traced graph, captured with `TORCH_LOGS=+graph_code python check_math_ops.py`:
graph code:
```
TRACED GRAPH
 ===== __compiled_fn_0 =====
 .0 class GraphModule(torch.nn.Module):
    def forward(self, L_a_ : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
        l_a_ = L_a_
        l_x_ = L_x_

        # File: check_math_ops.py:57, code: c = c + math.sqrt(a)
        sym_sqrt = torch.sym_sqrt(l_a_)
        add = 0 + sym_sqrt; sym_sqrt = None

        # File: check_math_ops.py:58, code: c = c + math.cos(a)
        sym_cos = torch.sym_cos(l_a_)
        add_1 = add + sym_cos; add = sym_cos = None

        # File: check_math_ops.py:59, code: c = c + math.cosh(a)
        sym_cosh = torch.sym_cosh(l_a_)
        add_2 = add_1 + sym_cosh; add_1 = sym_cosh = None

        # File: check_math_ops.py:60, code: c = c + math.sin(a)
        sym_sin = torch.sym_sin(l_a_)
        add_3 = add_2 + sym_sin; add_2 = sym_sin = None

        # File: check_math_ops.py:61, code: c = c + math.sinh(a)
        sym_sinh = torch.sym_sinh(l_a_)
        add_4 = add_3 + sym_sinh; add_3 = sym_sinh = None

        # File: check_math_ops.py:62, code: c = c + math.tan(a)
        sym_tan = torch.sym_tan(l_a_)
        add_5 = add_4 + sym_tan; add_4 = sym_tan = None

        # File: check_math_ops.py:63, code: c = c + math.tanh(a)
        sym_tanh = torch.sym_tanh(l_a_)
        add_6 = add_5 + sym_tanh; add_5 = sym_tanh = None

        # File: check_math_ops.py:64, code: c = c + math.asin(b)
        add_7 = add_6 + 1.5707963267948966; add_6 = None

        # File: check_math_ops.py:65, code: c = c + math.acos(b)
        add_8 = add_7 + 0.0; add_7 = None

        # File: check_math_ops.py:66, code: c = c + math.atan(a)
        sym_atan = torch.sym_atan(l_a_); l_a_ = None
        add_9 = add_8 + sym_atan; add_8 = sym_atan = None

        # File: check_math_ops.py:67, code: y = x + c
        y = l_x_ + add_9; l_x_ = add_9 = None
        return (y,)
```
Generated code with `TORCH_LOGS=+output_code python check_math_ops.py`:
C++ code:
```
cpp_fused_add_0 = async_compile.cpp('''
#include "/tmp/torchinductor_root/2l/c2ljzlm4sosod7u6lyrroqdba6hmfcyijrric6p4t3fhbcmw6osp.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr0,
                       const long ks0,
                       const long ks1)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(ks0); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(x0)];
            auto tmp1 = c10::convert<float>(1.57079632679490 + (std::sqrt(ks1)) + (std::atan(ks1)) + (std::cos(ks1)) + (std::cosh(ks1)) + (std::sin(ks1)) + (std::sinh(ks1)) + (std::tan(ks1)) + (std::tanh(ks1)));
            auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
            out_ptr0[static_cast<long>(x0)] = tmp2;
        }
    }
}
''')
```
Triton code:
```
@pointwise(
    size_hints=[4],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': []},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 1.57079632679490 + (tl.math.sqrt(ks0.to(tl.float32))) + (tl.math.atan((ks0).to(tl.float32))) + (tl.math.cos((ks0).to(tl.float32))) + (tl.math.cosh((ks0).to(tl.float32))) + (tl.math.sin((ks0).to(tl.float32))) + (tl.math.sinh((ks0).to(tl.float32))) + (tl.math.tan((ks0).to(tl.float32))) + (tl.math.tanh((ks0).to(tl.float32)))
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''')
```
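For context, a minimal sketch (not part of the patch) of how the new private wrappers added in `torch/__init__.py` behave once this change is applied; the printed values are only illustrative:

```python
import math
import torch

# With this patch, torch exposes private helpers such as torch._sym_sqrt,
# torch._sym_cos, ..., torch._sym_atan (see the torch/__init__.py hunk below).

# On plain Python numbers they simply defer to the math module:
print(torch._sym_cos(0.0))   # 1.0, same as math.cos(0.0)
print(torch._sym_sqrt(4))    # 2.0, same as math.sqrt(4)

# On SymInt/SymFloat inputs they dispatch to the __sym_*__ methods installed by
# torch/fx/experimental/sym_node.py, so the result stays symbolic (e.g. cos(s0))
# instead of being specialized to a constant; dynamo rewrites math.cos(a) into
# torch._sym_cos(a) when `a` is symbolic (torch/_dynamo/variables/torch.py).
```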
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114866
Approved by: https://github.com/peterbell10
---
 test/export/test_export.py                   |  2 +-
 .../test_torchinductor_dynamic_shapes.py     | 28 +++++++
 test/test_dynamic_shapes.py                  |  6 +-
 torch/__init__.py                            | 35 ++++++---
 torch/_dynamo/variables/torch.py             |  9 ++-
 torch/_export/serde/serialize.py             |  1 -
 torch/_inductor/codegen/common.py            | 36 +++++++++
 torch/_inductor/codegen/cpp.py               | 36 +++++++++
 torch/_inductor/codegen/triton.py            | 36 +++++++++
 torch/_prims_common/__init__.py              |  2 +-
 torch/fx/experimental/sym_node.py            | 77 ++++++++++++++++---
 torch/fx/experimental/validator.py           |  4 +-
 torch/overrides.py                           | 11 ++-
 torch/utils/_sympy/interp.py                 |  3 +
 torch/utils/_sympy/reference.py              |  2 +-
 torch/utils/_sympy/value_ranges.py           | 52 +++++++++++++
 16 files changed, 304 insertions(+), 36 deletions(-)

diff --git a/test/export/test_export.py b/test/export/test_export.py
index 1a1a0c30d7f12..2d3cdda49cc17 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -2223,7 +2223,7 @@ def forward(self, x):
         ep = export(M(), (torch.ones(16, 4),), dynamic_shapes={'x': {0: Dim("dim")}})
         _ExportPassBaseDeprecatedDoNotUse()(ep.graph_module)
         FileCheck().check_count(
-            "torch.sym_sqrt", 1, exactly=True
+            "torch._sym_sqrt", 1, exactly=True
         ).run(ep.graph_module.code)
 
     def test_check_specialized_int(self):
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index c06c1ffd19362..df8c30be8e75d 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -19,6 +19,7 @@
 from torch.testing._internal.common_utils import (
     IS_CI,
     IS_WINDOWS,
+    parametrize,
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
     TestCase,
@@ -496,6 +497,33 @@ def fn(a):
         actual = cfn(5)
         self.assertEqual(expect, actual)
 
+    @parametrize(
+        "op",
+        [
+            math.sqrt,
+            math.sin,
+            math.cos,
+            math.cosh,
+            math.sin,
+            math.sinh,
+            math.tan,
+            math.tanh,
+            math.asin,
+            math.acos,
+            math.atan,
+        ],
+    )
+    def test_math_ops(self, device, op):
+        def func(x, fn, a):
+            return x + fn(a)
+
+        cfunc = self.compile_fn(func, fullgraph=True)
+        x = torch.rand(10, device=device)
+        a = -1 if op in (math.asin, math.acos) else 12
+        expected = func(x, op, a)
+        output = cfunc(x, op, a)
+        self.assertEqual(output, expected)
+
 
 instantiate_device_type_tests(TestInductorDynamic, globals())
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 48d6e40b62acb..732361e2f81e7 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -16,7 +16,7 @@
 from torch._C import _disabled_torch_function_impl
 from torch.fx.experimental import sym_node
 from torch.fx.experimental.proxy_tensor import make_fx
-from torch.fx.experimental.sym_node import to_node, sym_sqrt, SymNode, method_to_operator
+from torch.fx.experimental.sym_node import to_node, SymNode, method_to_operator
 from torch.fx.experimental.symbolic_shapes import (
     DimConstraints,
     DimDynamic,
@@ -394,7 +394,7 @@ def test_sym_int(self):
     def test_sym_sqrt(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 4)
-        r = sym_sqrt(a0)
+        r = torch._sym_sqrt(a0)
         self.assertEqual(r, 2)
         self.assertIsInstance(r, torch.SymFloat, msg=type(r))
         self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""")
@@ -764,7 +764,7 @@ def test_method(self, fn, first_type, second_type):
         values = (
             0.0,
             1.0,
-            2.5,
+            0.5 if fn in ("sym_acos", "sym_asin") else 2.5,  # avoid math domain error
         )
 
         neg_values = tuple(-x for x in values)
diff --git a/torch/__init__.py b/torch/__init__.py
index 98c9a43511c39..6eb6cdd6bff9f 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -57,7 +57,6 @@ def _running_with_deploy():
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
     'SymBool', 'sym_not', 'unravel_index',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
-    'sym_sqrt',
     'export', 'autocast', 'cond',
 ]
 
@@ -491,15 +490,33 @@ def sym_min(a, b):
         return b.__sym_min__(a)
     return builtins.min(a, b)  # type: ignore[operator]
 
-# Drop in replacement for math.sqrt
-def sym_sqrt(a):
-    from .overrides import has_torch_function_unary, handle_torch_function
+# Drop in replacement for math.sqrt, math.sin, math.cos etc
+current_module = sys.modules[__name__]
+
+def _get_sym_math_fn(name):
+    def fn(a):
+        from .overrides import has_torch_function_unary, handle_torch_function
+
+        if has_torch_function_unary(a):
+            return handle_torch_function(fn, (a,), a)
+        if hasattr(a, f"__sym_{name}__"):
+            return getattr(a, f"__sym_{name}__")()
+        return getattr(math, name)(a)
+
+    return fn
+
+for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"):
+    sym_name = f"_sym_{name}"
+    fn = _get_sym_math_fn(name)
+    fn.__qualname__ = fn.__name__ = sym_name
+    setattr(current_module, sym_name, fn)
+
+# Adding temporary shortcut
+sym_sqrt = current_module._sym_sqrt
+__all__.append("sym_sqrt")
+
+del fn, name, sym_name, current_module
 
-    if has_torch_function_unary(a):
-        return handle_torch_function(sym_sqrt, (a,), a)
-    if hasattr(a, "__sym_sqrt__"):
-        return a.__sym_sqrt__()
-    return math.sqrt(a)
 
 def sym_ite(b, t, f):
     from .overrides import has_torch_function, handle_torch_function
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 9b6ef1c3073b2..77624dbe0f760 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -524,10 +524,11 @@ def fn_with_prim_types(x):
             # of value + args to determine this.
             fn_ = self.value
             if any(isinstance(x, SymNodeVariable) for x in args):
-                if self.value == math.sqrt:
-                    from torch.fx.experimental.sym_node import sym_sqrt
-
-                    fn_ = sym_sqrt
+                torch_sym_op = f"_sym_{self.value.__name__}"
+                if getattr(self.value, "__module__", None) == "math" and hasattr(
+                    torch, torch_sym_op
+                ):
+                    fn_ = getattr(torch, torch_sym_op)
 
             if fn_ is torch.tensor:
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index 7e117eaa3b4a9..92f221fc88761 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -165,7 +165,6 @@ def _reverse_map(d: Dict[Any, Enum]):
     operator.sub,
     operator.floordiv,
     operator.mod,
-    torch.sym_sqrt,
     torch.sym_int,
     torch.sym_ite,
     torch.sym_max,
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index a09cf97a838b3..b201ed905f28c 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -399,6 +399,42 @@ def _print_Min(self, expr):
         assert len(expr.args) >= 2
         return f"min({', '.join(map(self._print, expr.args))})"
 
+    def _print_cos(self, expr):
+        assert len(expr.args) == 1
+        return f"math.cos({self._print(expr.args[0])})"
+
+    def _print_cosh(self, expr):
+        assert len(expr.args) == 1
+        return f"math.cosh({self._print(expr.args[0])})"
+
+    def _print_acos(self, expr):
+        assert len(expr.args) == 1
+        return f"math.acos({self._print(expr.args[0])})"
+
+    def _print_sin(self, expr):
+        assert len(expr.args) == 1
+        return f"math.sin({self._print(expr.args[0])})"
+
+    def _print_sinh(self, expr):
+        assert len(expr.args) == 1
+        return f"math.sinh({self._print(expr.args[0])})"
+
+    def _print_asin(self, expr):
+        assert len(expr.args) == 1
+        return f"math.asin({self._print(expr.args[0])})"
+
+    def _print_tan(self, expr):
+        assert len(expr.args) == 1
+        return f"math.tan({self._print(expr.args[0])})"
+
+    def _print_tanh(self, expr):
+        assert len(expr.args) == 1
+        return f"math.tanh({self._print(expr.args[0])})"
+
+    def _print_atan(self, expr):
+        assert len(expr.args) == 1
+        return f"math.atan({self._print(expr.args[0])})"
+
     def _print_Round(self, expr):
         assert len(expr.args) == 1
         return f"round({self._print(expr.args[0])})"
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 9119d1bf5fb6a..fb2a5aaa98cff 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -399,6 +399,42 @@ def _print_Abs(self, expr):
         assert len(expr.args) == 1
         return f"std::abs({self._print(expr.args[0])})"
 
+    def _print_cos(self, expr):
+        assert len(expr.args) == 1
+        return f"std::cos({self._print(expr.args[0])})"
+
+    def _print_cosh(self, expr):
+        assert len(expr.args) == 1
+        return f"std::cosh({self._print(expr.args[0])})"
+
+    def _print_acos(self, expr):
+        assert len(expr.args) == 1
+        return f"std::acos({self._print(expr.args[0])})"
+
+    def _print_sin(self, expr):
+        assert len(expr.args) == 1
+        return f"std::sin({self._print(expr.args[0])})"
+
+    def _print_sinh(self, expr):
+        assert len(expr.args) == 1
+        return f"std::sinh({self._print(expr.args[0])})"
+
+    def _print_asin(self, expr):
+        assert len(expr.args) == 1
+        return f"std::asin({self._print(expr.args[0])})"
+
+    def _print_tan(self, expr):
+        assert len(expr.args) == 1
+        return f"std::tan({self._print(expr.args[0])})"
+
+    def _print_tanh(self, expr):
+        assert len(expr.args) == 1
+        return f"std::tanh({self._print(expr.args[0])})"
+
+    def _print_atan(self, expr):
+        assert len(expr.args) == 1
+        return f"std::atan({self._print(expr.args[0])})"
+
     def _print_Round(self, expr):
         assert len(expr.args) == 1
         return f"std::lrint({self._print(expr.args[0])})"
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0f29230668cf5..5e6dccd7fcf24 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -310,6 +310,42 @@ def _print_Abs(self, expr):
         assert len(expr.args) == 1
         return f"tl.abs({self._print(expr.args[0])})"
 
+    def _print_cos(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.cos(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_cosh(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.cosh(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_acos(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.acos(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_sin(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.sin(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_sinh(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.sinh(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_asin(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.asin(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_tan(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.tan(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_tanh(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.tanh(({self._print(expr.args[0])}).to(tl.float32))"
+
+    def _print_atan(self, expr):
+        assert len(expr.args) == 1
+        return f"tl.math.atan(({self._print(expr.args[0])}).to(tl.float32))"
+
     def _print_FloorDiv(self, expr):
         if expr.is_integer:
             return super()._print_FloorDiv(expr)
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index db8673024d0b7..c956efb3dc902 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -63,7 +63,7 @@
     torch.sym_int,
     torch.sym_max,
     torch.sym_min,
-    torch.sym_sqrt,
+    torch._sym_sqrt,  # type: ignore[attr-defined]
    torch.sym_ite,
     torch.Tensor.dim,
     torch.Tensor.ndim.__get__,  # type: ignore[attr-defined]
diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py
index 8f2dff475b2ff..3bc3aae1cf1b2 100644
--- a/torch/fx/experimental/sym_node.py
+++ b/torch/fx/experimental/sym_node.py
@@ -17,6 +17,8 @@
 from functools import lru_cache, update_wrapper
 from typing import Optional, Type, TYPE_CHECKING, Union
 
+import torch
+
 # NB: The sym_* functions are used via getattr() and must be imported here.
 from torch import (  # noqa: F401
     sym_float,
@@ -24,7 +26,6 @@
     sym_max,
     sym_min,
     sym_not,
-    sym_sqrt,
     SymBool,
     SymFloat,
     SymInt,
@@ -41,7 +42,7 @@
 
 log = logging.getLogger(__name__)
 
-__all__ = ["SymNode", "method_to_operator", "magic_methods", "sym_sqrt"]
+__all__ = ["SymNode", "method_to_operator", "magic_methods"]
 
 
 SymTypes = (SymInt, SymFloat, SymBool)
@@ -301,9 +302,6 @@ def sym_max(self, other) -> "SymNode":  # noqa: F811
     def sym_ite(self, then_val, else_val) -> "SymNode":
         return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]
 
-    def sym_sqrt(self) -> "SymNode":
-        return self._sym_sqrt()  # type: ignore[attr-defined]
-
     def is_contiguous(self, sizes, strides) -> "SymNode":
         return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]
@@ -436,7 +434,6 @@ def is_constant(self):
     "sym_max": sym_max,
     "sym_min": sym_min,
     "sym_not": sym_not,
-    "sym_sqrt": sym_sqrt,
     "truediv": operator.truediv,
 }
 
@@ -446,10 +443,39 @@ def is_constant(self):
     "ceil",
     "floor",
     "neg",
-    "sym_sqrt",
     "sym_not",
 }
 
+
+# Adding math ops: sqrt, cos, sin, ...
+def _get_sym_node_fn(name):
+    def fn(self):
+        return getattr(self, f"_sym_{name}")()
+
+    return fn
+
+
+math_op_names = (
+    "sqrt",
+    "cos",
+    "cosh",
+    "sin",
+    "sinh",
+    "tan",
+    "tanh",
+    "asin",
+    "acos",
+    "atan",
+)
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    priv_sym_name = f"_{sym_name}"
+    setattr(SymNode, sym_name, _get_sym_node_fn(name))
+    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
+    unary_magic_methods.add(sym_name)
+    __all__.append(sym_name)
+
+
 # Unary methods that are not magic methods
 unary_nonmagic_methods = {
     "is_integer",
@@ -473,7 +499,13 @@ def is_constant(self):
 magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
 
 
-always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"}
+always_float_magic_methods = {"truediv", "sym_float", "pow"}
+
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    always_float_magic_methods.add(sym_name)
+
+
 always_int_magic_methods = {"ceil", "floor"}
 always_bool_magic_methods = {
     "eq",
@@ -639,10 +671,25 @@ def _sympy_ite(a, t, f):
     return sympy.Piecewise((t, a), (f, True))
 
 
-def _sympy_sqrt(a):
-    import sympy
+current_module = sys.modules[__name__]
+
+
+def _get_sym_math_fn(name):
+    def fn(a):
+        import sympy
+
+        return getattr(sympy, name)(a)
+
+    return fn
 
-    return sympy.sqrt(a)
+
+for name in math_op_names:
+    priv_sympy_name = f"_sympy_{name}"
+    fn = _get_sym_math_fn(name)
+    fn.__qualname__ = fn.__name__ = priv_sympy_name
+    setattr(current_module, priv_sympy_name, fn)
+
+del fn, name, priv_sympy_name
 
 
 def _sympy_abs(a):
@@ -690,13 +737,19 @@ def _sympy_is_integer(a):
     "sym_min": _sympy_min,
     "sym_max": _sympy_max,
     "sym_ite": _sympy_ite,
-    "sym_sqrt": _sympy_sqrt,
     "abs": _sympy_abs,
     "round": _sympy_round,
     "is_integer": _sympy_is_integer,
 }
 
 
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")
+
+del name, sym_name, math_op_names, current_module
+
+
 def sympy_is_contiguous(sizes, strides):
     dim = len(sizes)
     return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1)))
diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py
index a87056fe3c6fd..46c0c716a315a 100644
--- a/torch/fx/experimental/validator.py
+++ b/torch/fx/experimental/validator.py
@@ -243,8 +243,6 @@ def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z
 # 2. Calls an operation that corresponds to 'op', but works with Z3
 #    inhabitants (left as is if it works as is)
 def z3op(op: Callable, validator: "TranslationValidator") -> Callable:
-    from torch.fx.experimental.sym_node import sym_sqrt
-
     # Operations that have booleans as their argument.
     # This is needed because the argument of some FX nodes were
     # literal integers, instead of booleans. So, whenever this flag
@@ -297,7 +295,7 @@ def wrapper(*args):
         torch.sym_max: lift(ops.max),
         torch.sym_min: lift(ops.min),
         torch.sym_ite: lift(lambda b, t, f: t if b else f),
-        sym_sqrt: lift(ops.sqrt),
+        torch._sym_sqrt: lift(ops.sqrt),  # type: ignore[attr-defined]
         # Not lifted because we only use this function as a
         # marker for adding the expression as validator input.
         torch._assert: torch._assert,
diff --git a/torch/overrides.py b/torch/overrides.py
index 96b78f80d9c0f..13f5681dd4896 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1065,7 +1065,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.sym_min: lambda a, b: -1,
         torch.sym_not: lambda input: -1,
         torch.sym_ite: lambda a, b, c: -1,
-        torch.sym_sqrt: lambda input: -1,
+        torch._sym_sqrt: lambda input: -1,
+        torch._sym_cos: lambda input: -1,
+        torch._sym_cosh: lambda input: -1,
+        torch._sym_sin: lambda input: -1,
+        torch._sym_sinh: lambda input: -1,
+        torch._sym_tan: lambda input: -1,
+        torch._sym_tanh: lambda input: -1,
+        torch._sym_asin: lambda input: -1,
+        torch._sym_acos: lambda input: -1,
+        torch._sym_atan: lambda input: -1,
         torch.nansum: lambda input, dim=None: -1,
         torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
         torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py
index d6622dba09734..86515b6b1aa77 100644
--- a/torch/utils/_sympy/interp.py
+++ b/torch/utils/_sympy/interp.py
@@ -72,6 +72,9 @@ def handlers():
         Round: "round",
         RoundDecimal: "round",
     }
+    for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
+        HANDLERS[getattr(sympy, name)] = name
+
     return HANDLERS
diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py
index 0c49f5749aea0..adb25c7ffb0fe 100644
--- a/torch/utils/_sympy/reference.py
+++ b/torch/utils/_sympy/reference.py
@@ -195,7 +195,7 @@ def log(x):
 
     @staticmethod
     def sqrt(x):
-        return torch.sym_sqrt(x)
+        return torch._sym_sqrt(x)  # type: ignore[attr-defined]
 
     @staticmethod
     def minimum(a, b):
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py
index bb8479558f217..6c87cfe83f028 100644
--- a/torch/utils/_sympy/value_ranges.py
+++ b/torch/utils/_sympy/value_ranges.py
@@ -513,6 +513,58 @@ def piecewise(*ranges):
             init_range = init_range | expr_range
         return init_range
 
+    @staticmethod
+    def cos(x):
+        # TODO: We should tighten value ranges
+        # If input range span is pi + 2*pi*k, then output range is (-1, 1)
+        # otherwise the minimum of the value of the function on the extremes
+        return ValueRanges(-1.0, 1.0)
+
+    @staticmethod
+    def cosh(x):
+        x = ValueRanges.wrap(x)
+        if x.lower > 0:
+            return ValueRanges.increasing_map(x, sympy.cosh)
+        elif x.upper < 0:
+            return ValueRanges.decreasing_map(x, sympy.cosh)
+        return ValueRanges(0.0, sympy.oo)
+
+    @staticmethod
+    def sin(x):
+        # TODO: We should tighten value ranges
+        # See details on cos
+        return ValueRanges(-1.0, 1.0)
+
+    @staticmethod
+    def sinh(x):
+        return ValueRanges.increasing_map(x, sympy.sinh)
+
+    @staticmethod
+    def tan(x):
+        return ValueRanges(-sympy.oo, sympy.oo)
+
+    @staticmethod
+    def tanh(x):
+        return ValueRanges.increasing_map(x, sympy.tanh)
+
+    @staticmethod
+    def asin(x):
+        x = ValueRanges.wrap(x)
+        if -1 <= x.lower and x.upper <= 1:
+            return ValueRanges.increasing_map(x, sympy.asin)
+        return ValueRanges.unknown()
+
+    @staticmethod
+    def acos(x):
+        x = ValueRanges.wrap(x)
+        if -1 <= x.lower and x.upper <= 1:
+            return ValueRanges.decreasing_map(x, sympy.acos)
+        return ValueRanges.unknown()
+
+    @staticmethod
+    def atan(x):
+        return ValueRanges.increasing_map(x, sympy.atan)
+
 
 class ValueRangeAnalysis(SymPyValueRangeAnalysis):
     def __init__(self):