
Conversation

@will-cromar (Collaborator) commented May 15, 2024

See #6751

  • This implementation is intentionally minimal to start with. The main improvement compared to sync is that exceptions are handled sanely.
  • Update the README example to use xla.step (a minimal usage sketch follows below). Remove ParallelLoader, since it mostly does not make a difference for MP, and we should keep our starting point as simple as possible.
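For reference, a rough sketch of what a README-style training loop wrapped in xla.step might look like. The model, optimizer, loss, and loader below are placeholders for illustration, not the actual README contents:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

for data, target in loader:
  with torch_xla.step():
    # Move the batch to the XLA device and trace one training step;
    # the traced graph is materialized when the context exits.
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)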

@will-cromar changed the title from "[WIP] Add xla.step context manager" to "Add xla.step context manager" on May 16, 2024
@will-cromar requested a review from JackCaoG on May 16, 2024 20:09
@will-cromar marked this pull request as ready for review on May 16, 2024 20:09

# Create a DataLoader
dataset = TensorDataset(input_data, target_data)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Collaborator:

Doesn't DataLoader take a device argument?

will-cromar (Author):

No. In normal PyTorch, you have to move the data with tensor.to: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu
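For illustration, the pattern from that tutorial looks roughly like this (the model and loader names are placeholders):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for data, target in loader:
  # The DataLoader itself knows nothing about devices; each batch is
  # moved explicitly, and the same pattern applies to an XLA device.
  data, target = data.to(device), target.to(device)
  output = model(data)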



@contextlib.contextmanager
def step():

Collaborator:

The reason I find step a bit confusing is that we don't call mark_step upon entering the step:

with xla.step():
  y = x + z

y += 1

step as a context manager kind of suggests that execution will only cover what happened inside the context manager, but that's actually not the case.

will-cromar (Author):

I agree. This should either print a warning if there are pending operations, or just call mark_step twice. Which do you think is better?

Collaborator:

Let's try mark_step twice and benchmark it with one of the examples, resnet50 with fake data.
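For reference, a rough sketch of the "mark_step twice" variant under discussion; this is an illustration of the idea, not necessarily the merged code:

import contextlib
import torch_xla.core.xla_model as xm

@contextlib.contextmanager
def step():
  # Flush anything traced before entering the context...
  xm.mark_step()
  try:
    yield
  finally:
    # ...and materialize what was traced inside it, even if an
    # exception was raised along the way.
    xm.mark_step()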

will-cromar (Author):

I'm going to hold off on modifying the examples until we're running tests on them. Here's my patch:

--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -45,15 +45,16 @@ class TrainResNetBase():
     self.model.train()
     loader = itertools.islice(loader, self.num_steps)
     for step, (data, target) in enumerate(loader):
-      self.optimizer.zero_grad()
-      output = self.model(data)
-      loss = self.loss_fn(output, target)
-      loss.backward()
-      self.run_optimizer()
+      with torch_xla.step():
+        self.optimizer.zero_grad()
+        output = self.model(data)
+        loss = self.loss_fn(output, target)
+        loss.backward()
+        self.run_optimizer()
+
       tracker.add(self.batch_size)
       if step % 10 == 0:
-        xm.add_step_closure(
-            self._train_update, args=(step, loss, tracker, epoch))
+        self._train_update(step, loss, tracker, epoch)

Before:

epoch: 1, step: 290, loss: 6.608619213104248, rate: 1747.0911849087843
epoch: 1, step: 290, loss: 6.606635570526123, rate: 1747.0763868012214
epoch: 1, step: 290, loss: 6.618781566619873, rate: 1747.2648104487325
epoch: 1, step: 290, loss: 6.605813980102539, rate: 1746.9924093597208

After:

epoch: 1, step: 290, loss: 6.603261947631836, rate: 1752.4689284654187
epoch: 1, step: 290, loss: 6.607376575469971, rate: 1752.4377415557715
epoch: 1, step: 290, loss: 6.611710071563721, rate: 1752.2556378789855
epoch: 1, step: 290, loss: 6.638012886047363, rate: 1752.400066823619

@will-cromar merged commit 3c59087 into master on May 17, 2024
zpcore pushed a commit that referenced this pull request on May 20, 2024