
init tls grad_mode/local_dispatch_key set while fork new thread in #113246

Closed
wants to merge 5 commits

Conversation

@zhuhaozhe (Collaborator) commented Nov 8, 2023

TorchDynamo guards on the grad mode and the local dispatch key set (https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/dynamo/guards.cpp#L13-L16):

```
struct LocalState {
  // TLS state that changes operators
  c10::impl::LocalDispatchKeySet dispatch_modifier;
  bool grad_mode_enabled;
  // ...
};
```
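For context, here is a minimal sketch of how such a guard could record the TLS at compile time and re-check it on later calls. This is illustrative only, not the actual guards.cpp logic; `RecordedState` and `tls_matches` are hypothetical names:

```
#include <ATen/core/grad_mode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

// Record the TLS that is active when the graph is compiled.
struct RecordedState {
  c10::impl::LocalDispatchKeySet dispatch_modifier;
  bool grad_mode_enabled;
  RecordedState()
      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
        grad_mode_enabled(at::GradMode::is_enabled()) {}
};

// Fail the guard whenever the calling thread's TLS differs from the record.
bool tls_matches(const RecordedState& recorded) {
  const auto current = c10::impl::tls_local_dispatch_key_set();
  return at::GradMode::is_enabled() == recorded.grad_mode_enabled &&
         current.included_ == recorded.dispatch_modifier.included_ &&
         current.excluded_ == recorded.dispatch_modifier.excluded_;
}
```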

When `ThroughputBenchmark` forks its worker threads, this TLS state is not initialized to match the main thread's (https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/utils/throughput_benchmark-inl.h#L64-L94):

```
callers.emplace_back([&, thread_id]() {
  // We use a condition variable as a barrier to make sure each thread
  // performs the required warmup iterations before we start measuring
  for (const auto j : c10::irange(config.num_warmup_iters)) {
    (void)j;
    runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
    ++input_iters[thread_id];
  }
  {
    std::unique_lock<std::mutex> lock(m);
    ++initialized;
    worker_main_cv.notify_one();
    // NOLINTNEXTLINE(bugprone-infinite-loop)
    while (!start) {
      main_worker_cv.wait(lock);
    }
  }
  LOG(INFO) << "Starting forward thread " << thread_id;
  while (num_attempted_iters.fetch_add(1) < config.num_iters) {
    runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
    ++input_iters[thread_id];
  }
  {
    std::unique_lock<std::mutex> lock(m);
    ++finished;
    worker_main_cv.notify_one();
    LOG(INFO) << "Shutting down forward thread " << thread_id
              << ". Total number of finished threads: " << finished;
  }
});
```
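The underlying issue is that this state is thread-local: a freshly spawned worker starts with default TLS regardless of what the main thread set up. A standalone illustration of the effect (my own sketch, not code from this PR):

```
#include <iostream>
#include <thread>
#include <ATen/core/grad_mode.h>

int main() {
  // Like entering torch.no_grad() on the main thread.
  at::GradMode::set_enabled(false);
  std::cout << "main grad mode: " << at::GradMode::is_enabled() << "\n";  // 0

  std::thread worker([] {
    // A fresh thread falls back to the default (grad enabled), so its TLS
    // no longer matches the state TorchDynamo guarded on.
    std::cout << "worker grad mode: " << at::GradMode::is_enabled() << "\n";  // 1
  });
  worker.join();
}
```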

Running the following script

```
import torch
linear = torch.nn.Linear(128, 128)
compiled = torch.compile(linear)
x = torch.rand(10, 128)
with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    compiled(x)
    compiled(x)

from torch._dynamo import config
config.error_on_recompile = True
from torch.utils import ThroughputBenchmark
with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    bench = ThroughputBenchmark(compiled)
    bench.add_input(x)
    stats = bench.benchmark(
        num_calling_threads=10,
        num_warmup_iters=100,
        num_iters=100,
    )
    print(stats)
```

leads to two recompile reasons:

```
triggered by the following guard failure(s): ___check_global_state()
triggered by the following guard failure(s): tensor 'x' dispatch key set mismatch.
```

These guard failures force TorchDynamo to recompile. But since `ThroughputBenchmark` exists to share weights across threads, the model should not change while the benchmark is running. This PR therefore initializes each worker thread's TLS state to match the main thread's, so `ThroughputBenchmark` can run TorchDynamo-optimized models without recompilation.
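A minimal sketch of the idea, assuming ATen's `at::ThreadLocalState`/`at::ThreadLocalStateGuard` capture and restore grad mode and the local dispatch key set; the exact integration into `throughput_benchmark-inl.h` may differ:

```
#include <iostream>
#include <thread>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/grad_mode.h>

int main() {
  at::GradMode::set_enabled(false);  // main-thread setup, e.g. no_grad
  // Capture the main thread's TLS once, before workers are spawned.
  const at::ThreadLocalState main_thread_state;

  std::thread worker([&] {
    // Restore the captured TLS for this worker's lifetime, so the state
    // TorchDynamo guarded on matches the main thread's.
    at::ThreadLocalStateGuard guard(main_thread_state);
    std::cout << "worker grad mode: " << at::GradMode::is_enabled() << "\n";  // now 0
  });
  worker.join();
}
```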

Stack from ghstack (oldest at bottom):

throughputbenchmark

pytorch-bot (bot) commented Nov 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113246

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f7f34af with merge base 3ff4572:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zhuhaozhe added a commit that referenced this pull request Nov 8, 2023
throughputbenchmark

ghstack-source-id: 2d5dfa163d852914ea845c4df42c58a8153856f4
Pull Request resolved: #113246
zhuhaozhe requested a review from jgong5 on November 8, 2023
zhuhaozhe added a commit that referenced this pull request Nov 9, 2023
throughputbenchmark

ghstack-source-id: 86fba3f6380ccbd9e7bd61fb807519c53441f1e7
Pull Request resolved: #113246
@zhuhaozhe (Collaborator, Author)

@pytorchbot merge

pytorch-bot added the ciflow/trunk label Nov 15, 2023
@pytorchmergebot (Collaborator)

Merge failed

Reason: Approval needed from one of the following:
iseeyuan, chenyang78, lc0, atalman, ananthsub, ...

Raised by workflow job.

Failing merge rule: Core Maintainers

@zhuhaozhe (Collaborator, Author)

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased gh/zhuhaozhe/2/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/113246)

pytorchmergebot pushed a commit that referenced this pull request Nov 19, 2023
throughputbenchmark

ghstack-source-id: cf24d8bb677b5d408cd3cec5d99ce987602de4f1
Pull Request resolved: #113246
@zhuhaozhe (Collaborator, Author)

Hi @iseeyuan, @chenyang78, @atalman, could you help review this PR? The mergebot shows you are the owners.

zhuhaozhe added a commit that referenced this pull request Nov 20, 2023
throughputbenchmark

ghstack-source-id: 1b0c97fb8939d397220f0541f57581facc3bc574
Pull Request resolved: #113246
zhuhaozhe requested review from desertfire and removed the request for ananthsub on December 21, 2023
@zhuhaozhe (Collaborator, Author)

Hi @desertfire, could you help review this PR?

@zhuhaozhe (Collaborator, Author)

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased gh/zhuhaozhe/2/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/113246)

pytorchmergebot pushed a commit that referenced this pull request Jan 3, 2024
throughputbenchmark

ghstack-source-id: 86bb4b94dbfe36c4fef326f595f16e1f427e34dc
Pull Request resolved: #113246
@zhuhaozhe (Collaborator, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Raised by workflow job.

@zhuhaozhe (Collaborator, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status here.

facebook-github-bot deleted the gh/zhuhaozhe/2/head branch on January 14, 2024
Labels: ciflow/trunk, Merged, open source, topic: not user facing
Projects
Status: Done