Skip to content

Commit

Permalink
[inductor] refine loop split logic
Browse files (browse the repository at this point in the history)
ghstack-source-id: a0ffb42b1c0b2159b72f278aa4184ab75325cd03
Pull Request resolved: pytorch#128812
  • Loading branch information
zhuhaozhe committed Jul 24, 2024
1 parent 406f510 commit c3d519c
Show file tree
Hide file tree
Showing 4 changed files with 718 additions and 396 deletions.
4 changes: 2 additions & 2 deletions test/inductor/test_cpu_repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -3544,13 +3544,13 @@ def test_non_contiguous_reduction_store(self):
# Small test module: reduce the input with max() over dim 3, then apply a
# single Conv2d (39 input channels -> 1 output channel, stride (2, 2)).
# NOTE(review): this text is a diff rendering whose +/- markers and
# indentation were lost in extraction. The two consecutive `self.conv = ...`
# lines are presumably the removed (kernel_size=(1, 17)) and added
# (kernel_size=(1, 20)) sides of this commit's change; only one of them
# exists in the real file. Confirm against the commit before reusing.
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(39, 1, kernel_size=(1, 17), stride=(2, 2))
self.conv = torch.nn.Conv2d(39, 1, kernel_size=(1, 20), stride=(2, 2))

def forward(self, x):
# max(3) returns (values, indices); only the values feed the conv.
return self.conv(x.max(3).values)

m = M()
x = torch.randn(1, 39, 1, 18, 17)
x = torch.randn(1, 39, 1, 18, 20)
self.common(m, (x,))

def test_embedding_vec(self):
Expand Down
7 changes: 6 additions & 1 deletion torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
DeferredLineBase,
generate_assert,
IndentedBuffer,
PlaceHolderLine,
sympy_dot,
sympy_subs,
unique,
Expand Down Expand Up @@ -1143,7 +1144,11 @@ def __call__(self):
V.kernel.inplaced_to_remove,
)
):
return self.line
if isinstance(self.line, PlaceHolderLine):
line = self.line()
else:
line = self.line
return line
return None

def _new_line(self, line):
Expand Down
Loading

0 comments on commit c3d519c

Please sign in to comment.