[dynamo] Added dyn shapes support for math trigo ops: sin(h), cos(h), tan(h) ... (pytorch#114866)

Description:
- Added dynamic shapes support for `math` trigonometric ops: `sin(h)`, `cos(h)`, `tan(h)`, `asin`, `acos`, `atan` (and reworked `sym_sqrt` into the same generated `torch._sym_*` helpers).

```python
import math
import torch

def func(x, a, b):
    c = 0
    c = c + math.sqrt(a)
    c = c + math.cos(a)
    c = c + math.cosh(a)
    c = c + math.sin(a)
    c = c + math.sinh(a)
    c = c + math.tan(a)
    c = c + math.tanh(a)
    c = c + math.asin(b)
    c = c + math.acos(b)
    c = c + math.atan(a)
    y = x + c
    return y

cfunc = torch.compile(func, dynamic=True, fullgraph=True)

device = "cpu"  # or "cuda"
x = torch.tensor([0, 1, 2, 3], dtype=torch.float32, device=device)
a = 12
b = 1

out = cfunc(x, a, b)
expected = func(x, a, b)
torch.testing.assert_close(out, expected)
```
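
Because the function is compiled with `dynamic=True`, the integer argument `a` is traced as a `SymInt` (see `L_a_` in the graph below), while `b` is specialized to `1` on the first call. A quick follow-up check (an illustration, not part of the original repro script):

```python
# Calling again with a different `a` should reuse the already-compiled graph
# rather than retrace, since `a` stayed symbolic; `b` is kept fixed because it
# was specialized during the first trace.
out2 = cfunc(x, 25, b)
torch.testing.assert_close(out2, func(x, 25, b))
```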

Captured graph, logged with `TORCH_LOGS=+graph_code python check_math_ops.py`:

<details>
<summary>
graph code
</summary>

```
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  ===== __compiled_fn_0 =====
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self, L_a_ : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_a_ = L_a_
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_x_ = L_x_
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:57, code: c = c + math.sqrt(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_sqrt = torch.sym_sqrt(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add = 0 + sym_sqrt;  sym_sqrt = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:58, code: c = c + math.cos(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_cos = torch.sym_cos(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_1 = add + sym_cos;  add = sym_cos = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:59, code: c = c + math.cosh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_cosh = torch.sym_cosh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_2 = add_1 + sym_cosh;  add_1 = sym_cosh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:60, code: c = c + math.sin(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_sin = torch.sym_sin(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_3 = add_2 + sym_sin;  add_2 = sym_sin = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:61, code: c = c + math.sinh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_sinh = torch.sym_sinh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_4 = add_3 + sym_sinh;  add_3 = sym_sinh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:62, code: c = c + math.tan(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_tan = torch.sym_tan(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_5 = add_4 + sym_tan;  add_4 = sym_tan = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:63, code: c = c + math.tanh(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_tanh = torch.sym_tanh(l_a_)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_6 = add_5 + sym_tanh;  add_5 = sym_tanh = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:64, code: c = c + math.asin(b)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_7 = add_6 + 1.5707963267948966;  add_6 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:65, code: c = c + math.acos(b)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_8 = add_7 + 0.0;  add_7 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:66, code: c = c + math.atan(a)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sym_atan = torch.sym_atan(l_a_);  l_a_ = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_9 = add_8 + sym_atan;  add_8 = sym_atan = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: check_math_ops.py:67, code: y = x + c
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         y = l_x_ + add_9;  l_x_ = add_9 = None
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (y,)
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
```
</details>

Generated C++ and Triton code, logged with `TORCH_LOGS=+output_code python check_math_ops.py`:
<details>
<summary>
C++ code
</summary>

```
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] cpp_fused_add_0 = async_compile.cpp('''
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] #include "/tmp/torchinductor_root/2l/c2ljzlm4sosod7u6lyrroqdba6hmfcyijrric6p4t3fhbcmw6osp.h"
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] extern "C" void kernel(const float* in_ptr0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]                        float* out_ptr0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]                        const long ks0,
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]                        const long ks1)
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]     {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]         #pragma GCC ivdep
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]         for(long x0=static_cast<long>(0L); x0<static_cast<long>(ks0); x0+=static_cast<long>(1L))
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]         {
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]             auto tmp0 = in_ptr0[static_cast<long>(x0)];
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]             auto tmp1 = c10::convert<float>(1.57079632679490 + (std::sqrt(ks1)) + (std::atan(ks1)) + (std::cos(ks1)) + (std::cosh(ks1)) + (std::sin(ks1)) + (std::sinh(ks1)) + (std::tan(ks1)) + (std::tanh(ks1)));
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]             auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]             out_ptr0[static_cast<long>(x0)] = tmp2;
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]         }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG]     }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] }
[2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''')
```

</details>

<details>
<summary>
Triton code
</summary>

```
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] @pointwise(
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     size_hints=[4],
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     filename=__file__,
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': []},
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     min_elem_per_thread=0
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] )
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] @triton.jit
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] def triton_(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     xoffset = tl.program_id(0) * XBLOCK
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     xmask = xindex < xnumel
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     x0 = xindex
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tmp1 = 1.57079632679490 + (tl.math.sqrt(ks0.to(tl.float32))) + (tl.math.atan((ks0).to(tl.float32))) + (tl.math.cos((ks0).to(tl.float32))) + (tl.math.cosh((ks0).to(tl.float32))) + (tl.math.sin((ks0).to(tl.float32))) + (tl.math.sinh((ks0).to(tl.float32))) + (tl.math.tan((ks0).to(tl.float32))) + (tl.math.tanh((ks0).to(tl.float32)))
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tmp2 = tmp1.to(tl.float32)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tmp3 = tmp0 + tmp2
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG]     tl.store(out_ptr0 + (x0), tmp3, xmask)
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''')
```

</details>

Pull Request resolved: pytorch#114866
Approved by: https://github.com/peterbell10
vfdev-5 authored and pytorchmergebot committed Jan 11, 2024
1 parent 2b5a201 commit 7005a4b
Showing 16 changed files with 304 additions and 36 deletions.
2 changes: 1 addition & 1 deletion test/export/test_export.py
@@ -2223,7 +2223,7 @@ def forward(self, x):
ep = export(M(), (torch.ones(16, 4),), dynamic_shapes={'x': {0: Dim("dim")}})
_ExportPassBaseDeprecatedDoNotUse()(ep.graph_module)
FileCheck().check_count(
"torch.sym_sqrt", 1, exactly=True
"torch._sym_sqrt", 1, exactly=True
).run(ep.graph_module.code)

def test_check_specialized_int(self):
28 changes: 28 additions & 0 deletions test/inductor/test_torchinductor_dynamic_shapes.py
@@ -19,6 +19,7 @@
from torch.testing._internal.common_utils import (
IS_CI,
IS_WINDOWS,
parametrize,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TestCase,
@@ -496,6 +497,33 @@ def fn(a):
actual = cfn(5)
self.assertEqual(expect, actual)

@parametrize(
"op",
[
math.sqrt,
math.sin,
math.cos,
math.cosh,
math.sin,
math.sinh,
math.tan,
math.tanh,
math.asin,
math.acos,
math.atan,
],
)
def test_math_ops(self, device, op):
def func(x, fn, a):
return x + fn(a)

cfunc = self.compile_fn(func, fullgraph=True)
x = torch.rand(10, device=device)
a = -1 if op in (math.asin, math.acos) else 12
expected = func(x, op, a)
output = cfunc(x, op, a)
self.assertEqual(output, expected)


instantiate_device_type_tests(TestInductorDynamic, globals())

6 changes: 3 additions & 3 deletions test/test_dynamic_shapes.py
@@ -16,7 +16,7 @@
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx
-from torch.fx.experimental.sym_node import to_node, sym_sqrt, SymNode, method_to_operator
+from torch.fx.experimental.sym_node import to_node, SymNode, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
DimConstraints,
DimDynamic,
@@ -394,7 +394,7 @@ def test_sym_int(self):
def test_sym_sqrt(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 4)
-        r = sym_sqrt(a0)
+        r = torch._sym_sqrt(a0)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""")
@@ -764,7 +764,7 @@ def test_method(self, fn, first_type, second_type):
values = (
0.0,
1.0,
-            2.5,
+            0.5 if fn in ("sym_acos", "sym_asin") else 2.5  # avoid math domain error
)

neg_values = tuple(-x for x in values)
35 changes: 26 additions & 9 deletions torch/__init__.py
@@ -57,7 +57,6 @@ def _running_with_deploy():
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'SymBool', 'sym_not', 'unravel_index',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
-    'sym_sqrt',
'export', 'autocast', 'cond',
]

@@ -491,15 +490,33 @@ def sym_min(a, b):
return b.__sym_min__(a)
return builtins.min(a, b) # type: ignore[operator]

-# Drop in replacement for math.sqrt
-def sym_sqrt(a):
-    from .overrides import has_torch_function_unary, handle_torch_function
+# Drop in replacement for math.sqrt, math.sin, math.cos etc
+current_module = sys.modules[__name__]
+
+def _get_sym_math_fn(name):
+    def fn(a):
+        from .overrides import has_torch_function_unary, handle_torch_function
+
+        if has_torch_function_unary(a):
+            return handle_torch_function(fn, (a,), a)
+        if hasattr(a, f"__sym_{name}__"):
+            return getattr(a, f"__sym_{name}__")()
+        return getattr(math, name)(a)
+
+    return fn
+
+for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"):
+    sym_name = f"_sym_{name}"
+    fn = _get_sym_math_fn(name)
+    fn.__qualname__ = fn.__name__ = sym_name
+    setattr(current_module, sym_name, fn)
+
+# Adding temporary shortcut
+sym_sqrt = current_module._sym_sqrt
+__all__.append("sym_sqrt")
+
+del fn, name, sym_name, current_module

-    if has_torch_function_unary(a):
-        return handle_torch_function(sym_sqrt, (a,), a)
-    if hasattr(a, "__sym_sqrt__"):
-        return a.__sym_sqrt__()
-    return math.sqrt(a)

def sym_ite(b, t, f):
from .overrides import has_torch_function, handle_torch_function
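For reference, a quick illustration of the generated `torch._sym_*` helpers added above (a sketch, assuming a torch build that contains this change; the `_sym_*` names are private, dynamically created attributes):

```python
import math
import torch

# On plain Python numbers the helpers just fall through to math.*:
assert torch._sym_cos(0.0) == math.cos(0.0)
assert torch._sym_tanh(1.0) == math.tanh(1.0)

# SymInt/SymFloat values produced while tracing with dynamic shapes implement the
# __sym_cos__/__sym_tanh__/... hooks, so the same calls stay symbolic in the
# captured graph instead of being burned in as constants.
```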
9 changes: 5 additions & 4 deletions torch/_dynamo/variables/torch.py
@@ -524,10 +524,11 @@ def fn_with_prim_types(x):
# of value + args to determine this.
fn_ = self.value
if any(isinstance(x, SymNodeVariable) for x in args):
-            if self.value == math.sqrt:
-                from torch.fx.experimental.sym_node import sym_sqrt
-
-                fn_ = sym_sqrt
+            torch_sym_op = f"_sym_{self.value.__name__}"
+            if getattr(self.value, "__module__", None) == "math" and hasattr(
+                torch, torch_sym_op
+            ):
+                fn_ = getattr(torch, torch_sym_op)

if fn_ is torch.tensor:

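The lookup above generalizes the previous `math.sqrt`-only special case: any `math.*` function with a matching `torch._sym_*` helper is swapped in when a symbolic argument is seen. A standalone sketch of that rule (illustrative only; `_resolve_sym_fn` is a hypothetical name, not the dynamo code path):

```python
import math
import torch

def _resolve_sym_fn(fn):
    # math.cos -> torch._sym_cos, math.tan -> torch._sym_tan, ... when available
    torch_sym_op = f"_sym_{fn.__name__}"
    if getattr(fn, "__module__", None) == "math" and hasattr(torch, torch_sym_op):
        return getattr(torch, torch_sym_op)
    return fn

assert _resolve_sym_fn(math.cos).__name__ == "_sym_cos"
assert _resolve_sym_fn(math.hypot) is math.hypot  # no torch._sym_hypot; left as-is
```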
1 change: 0 additions & 1 deletion torch/_export/serde/serialize.py
@@ -165,7 +165,6 @@ def _reverse_map(d: Dict[Any, Enum]):
operator.sub,
operator.floordiv,
operator.mod,
-    torch.sym_sqrt,
torch.sym_int,
torch.sym_ite,
torch.sym_max,
36 changes: 36 additions & 0 deletions torch/_inductor/codegen/common.py
@@ -399,6 +399,42 @@ def _print_Min(self, expr):
assert len(expr.args) >= 2
return f"min({', '.join(map(self._print, expr.args))})"

def _print_cos(self, expr):
assert len(expr.args) == 1
return f"math.cos({self._print(expr.args[0])})"

def _print_cosh(self, expr):
assert len(expr.args) == 1
return f"math.cosh({self._print(expr.args[0])})"

def _print_acos(self, expr):
assert len(expr.args) == 1
return f"math.acos({self._print(expr.args[0])})"

def _print_sin(self, expr):
assert len(expr.args) == 1
return f"math.sin({self._print(expr.args[0])})"

def _print_sinh(self, expr):
assert len(expr.args) == 1
return f"math.sinh({self._print(expr.args[0])})"

def _print_asin(self, expr):
assert len(expr.args) == 1
return f"math.asin({self._print(expr.args[0])})"

def _print_tan(self, expr):
assert len(expr.args) == 1
return f"math.tan({self._print(expr.args[0])})"

def _print_tanh(self, expr):
assert len(expr.args) == 1
return f"math.tanh({self._print(expr.args[0])})"

def _print_atan(self, expr):
assert len(expr.args) == 1
return f"math.atan({self._print(expr.args[0])})"

def _print_Round(self, expr):
assert len(expr.args) == 1
return f"round({self._print(expr.args[0])})"
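These `_print_*` methods follow the standard sympy printer pattern: each override maps a sympy function node onto the backend's runtime call (`math.*` for the Python printer here, `std::*` and `tl.math.*` in the C++ and Triton printers below). A minimal standalone sketch of the same idea, built on sympy's stock `StrPrinter` rather than Inductor's actual printer classes:

```python
import sympy
from sympy.printing.str import StrPrinter

class PyMathPrinter(StrPrinter):
    """Print sympy trig nodes as calls into Python's math module."""

    def _print_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

s0 = sympy.Symbol("s0", integer=True, positive=True)
print(PyMathPrinter().doprint(sympy.cos(s0)))  # -> math.cos(s0)
```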
36 changes: 36 additions & 0 deletions torch/_inductor/codegen/cpp.py
@@ -399,6 +399,42 @@ def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"std::abs({self._print(expr.args[0])})"

def _print_cos(self, expr):
assert len(expr.args) == 1
return f"std::cos({self._print(expr.args[0])})"

def _print_cosh(self, expr):
assert len(expr.args) == 1
return f"std::cosh({self._print(expr.args[0])})"

def _print_acos(self, expr):
assert len(expr.args) == 1
return f"std::acos({self._print(expr.args[0])})"

def _print_sin(self, expr):
assert len(expr.args) == 1
return f"std::sin({self._print(expr.args[0])})"

def _print_sinh(self, expr):
assert len(expr.args) == 1
return f"std::sinh({self._print(expr.args[0])})"

def _print_asin(self, expr):
assert len(expr.args) == 1
return f"std::asin({self._print(expr.args[0])})"

def _print_tan(self, expr):
assert len(expr.args) == 1
return f"std::tan({self._print(expr.args[0])})"

def _print_tanh(self, expr):
assert len(expr.args) == 1
return f"std::tanh({self._print(expr.args[0])})"

def _print_atan(self, expr):
assert len(expr.args) == 1
return f"std::atan({self._print(expr.args[0])})"

def _print_Round(self, expr):
assert len(expr.args) == 1
return f"std::lrint({self._print(expr.args[0])})"
36 changes: 36 additions & 0 deletions torch/_inductor/codegen/triton.py
@@ -310,6 +310,42 @@ def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"tl.abs({self._print(expr.args[0])})"

def _print_cos(self, expr):
assert len(expr.args) == 1
return f"tl.math.cos(({self._print(expr.args[0])}).to(tl.float32))"

def _print_cosh(self, expr):
assert len(expr.args) == 1
return f"tl.math.cosh(({self._print(expr.args[0])}).to(tl.float32))"

def _print_acos(self, expr):
assert len(expr.args) == 1
return f"tl.math.acos(({self._print(expr.args[0])}).to(tl.float32))"

def _print_sin(self, expr):
assert len(expr.args) == 1
return f"tl.math.sin(({self._print(expr.args[0])}).to(tl.float32))"

def _print_sinh(self, expr):
assert len(expr.args) == 1
return f"tl.math.sinh(({self._print(expr.args[0])}).to(tl.float32))"

def _print_asin(self, expr):
assert len(expr.args) == 1
return f"tl.math.asin(({self._print(expr.args[0])}).to(tl.float32))"

def _print_tan(self, expr):
assert len(expr.args) == 1
return f"tl.math.tan(({self._print(expr.args[0])}).to(tl.float32))"

def _print_tanh(self, expr):
assert len(expr.args) == 1
return f"tl.math.tanh(({self._print(expr.args[0])}).to(tl.float32))"

def _print_atan(self, expr):
assert len(expr.args) == 1
return f"tl.math.atan(({self._print(expr.args[0])}).to(tl.float32))"

def _print_FloorDiv(self, expr):
if expr.is_integer:
return super()._print_FloorDiv(expr)
2 changes: 1 addition & 1 deletion torch/_prims_common/__init__.py
@@ -63,7 +63,7 @@
torch.sym_int,
torch.sym_max,
torch.sym_min,
-    torch.sym_sqrt,
+    torch._sym_sqrt,  # type: ignore[attr-defined]
torch.sym_ite,
torch.Tensor.dim,
torch.Tensor.ndim.__get__, # type: ignore[attr-defined]
