Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[dynamo] Added dynamic shapes support for math trigonometric ops: sin(h), cos(h),…
… tan(h) ... (pytorch#114866) Description: - Added dynamic shapes support for math trigo ops: sin(h), cos(h), tan(h) ... ```python import math import torch def func(x, a, b): c = 0 c = c + math.sqrt(a) c = c + math.cos(a) c = c + math.cosh(a) c = c + math.sin(a) c = c + math.sinh(a) c = c + math.tan(a) c = c + math.tanh(a) c = c + math.asin(b) c = c + math.acos(b) c = c + math.atan(a) y = x + c return y cfunc = torch.compile(func, dynamic=True, fullgraph=True) device = "cpu" # or "cuda" x = torch.tensor([0, 1, 2, 3], dtype=torch.float32, device=device) a = 12 b = 1 out = cfunc(x, a, b) expected = func(x, a, b) torch.testing.assert_close(out, expected) ``` and the graph `TORCH_LOGS=+graph_code python check_math_ops.py`: <details> <summary> graph code </summary> ``` [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] ===== __compiled_fn_0 ===== [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] <eval_with_key>.0 class GraphModule(torch.nn.Module): [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] def forward(self, L_a_ : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_a_ = L_a_ [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_x_ = L_x_ [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:57, code: c = c + math.sqrt(a) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_sqrt = torch.sym_sqrt(l_a_) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add = 0 + sym_sqrt; sym_sqrt = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] 
[2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:58, code: c = c + math.cos(a) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_cos = torch.sym_cos(l_a_) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_1 = add + sym_cos; add = sym_cos = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:59, code: c = c + math.cosh(a) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_cosh = torch.sym_cosh(l_a_) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_2 = add_1 + sym_cosh; add_1 = sym_cosh = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:60, code: c = c + math.sin(a) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_sin = torch.sym_sin(l_a_) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_3 = add_2 + sym_sin; add_2 = sym_sin = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:61, code: c = c + math.sinh(a) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_sinh = torch.sym_sinh(l_a_) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_4 = add_3 + sym_sinh; add_3 = sym_sinh = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:62, code: c = c + math.tan(a) [2023-11-30 
22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_tan = torch.sym_tan(l_a_) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_5 = add_4 + sym_tan; add_4 = sym_tan = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:63, code: c = c + math.tanh(a) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_tanh = torch.sym_tanh(l_a_) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_6 = add_5 + sym_tanh; add_5 = sym_tanh = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:64, code: c = c + math.asin(b) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_7 = add_6 + 1.5707963267948966; add_6 = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:65, code: c = c + math.acos(b) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_8 = add_7 + 0.0; add_7 = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:66, code: c = c + math.atan(a) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sym_atan = torch.sym_atan(l_a_); l_a_ = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_9 = add_8 + sym_atan; add_8 = sym_atan = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] 
torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:67, code: y = x + c [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] y = l_x_ + add_9; l_x_ = add_9 = None [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] return (y,) [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-30 22:16:10,654] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] ``` </details> Generated code with `TORCH_LOGS=+output_code python check_math_ops.py`: <details> <summary> C++ code </summary> ``` [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] cpp_fused_add_0 = async_compile.cpp(''' [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] #include "/tmp/torchinductor_root/2l/c2ljzlm4sosod7u6lyrroqdba6hmfcyijrric6p4t3fhbcmw6osp.h" [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] extern "C" void kernel(const float* in_ptr0, [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] float* out_ptr0, [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] const long ks0, [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] const long ks1) [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] { [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] { [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] #pragma GCC ivdep [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] for(long x0=static_cast<long>(0L); x0<static_cast<long>(ks0); x0+=static_cast<long>(1L)) [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] { [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] auto tmp0 = in_ptr0[static_cast<long>(x0)]; [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] auto tmp1 = 
c10::convert<float>(1.57079632679490 + (std::sqrt(ks1)) + (std::atan(ks1)) + (std::cos(ks1)) + (std::cosh(ks1)) + (std::sin(ks1)) + (std::sinh(ks1)) + (std::tan(ks1)) + (std::tanh(ks1))); [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] auto tmp2 = decltype(tmp0)(tmp0 + tmp1); [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] out_ptr0[static_cast<long>(x0)] = tmp2; [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] } [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] } [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] } [2023-11-30 22:19:09,709] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''') ``` </details> <details> <summary> Triton code </summary> ``` [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] @pointwise( [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] size_hints=[4], [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] filename=__file__, [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]}, [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': []}, [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] min_elem_per_thread=0 [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] ) [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] @triton.jit [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] def triton_(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): 
[2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] xoffset = tl.program_id(0) * XBLOCK [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] xindex = xoffset + tl.arange(0, XBLOCK)[:] [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] xmask = xindex < xnumel [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] x0 = xindex [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp0 = tl.load(in_ptr0 + (x0), xmask) [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp1 = 1.57079632679490 + (tl.math.sqrt(ks0.to(tl.float32))) + (tl.math.atan((ks0).to(tl.float32))) + (tl.math.cos((ks0).to(tl.float32))) + (tl.math.cosh((ks0).to(tl.float32))) + (tl.math.sin((ks0).to(tl.float32))) + (tl.math.sinh((ks0).to(tl.float32))) + (tl.math.tan((ks0).to(tl.float32))) + (tl.math.tanh((ks0).to(tl.float32))) [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp2 = tmp1.to(tl.float32) [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tmp3 = tmp0 + tmp2 [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] tl.store(out_ptr0 + (x0), tmp3, xmask) [2023-11-30 22:20:00,383] [0/0] torch._inductor.graph.__output_code: [DEBUG] ''') ``` </details> Pull Request resolved: pytorch#114866 Approved by: https://github.com/peterbell10
- Loading branch information