
mx: small speedup with dim0 cast #1980

Merged: vkuzo merged 72 commits into main from gh/vkuzo/86/head on Apr 1, 2025

vkuzo (Contributor) commented Mar 28, 2025

Summary:

Removes the unnecessary cast to bfloat16 in the MX dim0 casting code.
This is a 2.6% speedup on a 16k by 16k shape (M = K = 16384):
https://www.internalfb.com/phabricator/paste/view/P1769373804

Note: this PR also includes a couple of cleanups around the e8m0 dtype and
NaN handling that I found while working on it. I kept them in this PR
instead of splitting them into a separate PR since they are all safe.
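To give some context for the e8m0 and NaN cleanups, here is a small pure-Python sketch of a dim0 MX scale computation, loosely following the OCP Microscaling (MX) v1.0 recipe: each block of BLOCK_SIZE = 32 consecutive elements along a row gets a shared exponent floor(log2(amax)) - emax_elem, stored as a biased e8m0 byte, with 0xFF reserved for NaN. The fp8 e4m3 element type, the helper names (`e8m0_scale_for_block`, `scale_block`, `dim0_scales`), and the handling of all-zero and non-finite blocks are assumptions made for illustration. This is not the torchao kernel, and it does not show which bfloat16 cast this PR removed.

```python
import math

BLOCK_SIZE = 32   # elements that share one scale, as in the benchmark output
E4M3_EMAX = 8     # largest normal exponent of float8_e4m3fn (max normal 448 = 1.75 * 2**8)
E8M0_BIAS = 127   # e8m0 stores only a biased exponent: value = 2 ** (byte - 127)
E8M0_NAN = 0xFF   # the all-ones e8m0 encoding is reserved for NaN


def e8m0_scale_for_block(block):
    """Return the biased e8m0 scale byte for one block of floats (hypothetical helper)."""
    # Assumption of this sketch: any non-finite input turns the block's scale into NaN.
    if not all(math.isfinite(x) for x in block):
        return E8M0_NAN
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        # Assumption of this sketch: an all-zero block gets the smallest scale, 2**-127.
        return 0
    # floor(log2(amax)) computed exactly: frexp gives amax = m * 2**e with 0.5 <= m < 1.
    _, e = math.frexp(amax)
    shared_exp = (e - 1) - E4M3_EMAX
    # Clamp to the finite e8m0 range; 255 stays reserved for NaN.
    return max(0, min(254, shared_exp + E8M0_BIAS))


def scale_block(block, scale_byte):
    """Divide a block by its shared scale (the final cast to fp8 e4m3 is omitted)."""
    scale = 2.0 ** (scale_byte - E8M0_BIAS)
    return [x / scale for x in block]


def dim0_scales(rows):
    """One scale per BLOCK_SIZE consecutive elements along each row (hypothetical helper)."""
    out = []
    for row in rows:
        if len(row) % BLOCK_SIZE != 0:
            raise ValueError("row length must be a multiple of BLOCK_SIZE")
        out.append([e8m0_scale_for_block(row[i:i + BLOCK_SIZE])
                    for i in range(0, len(row), BLOCK_SIZE)])
    return out
```

For example, `dim0_scales([[1.0] * 64, [448.0] * 64])` returns `[[119, 119], [127, 127]]`. A block whose amax is 448 needs no rescaling (scale 2**0), while a block of ones gets scale 2**-8, so `scale_block` multiplies its values by 256.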

Test Plan:

Before this PR:

```bash
(pytorch) [vasiliy@devgpu023.atn1 ~/local/ao (20250321_mx_dim1_triton_kernel)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.8.0a0+git25309a1
triton version: 3.3.0
mode: dim0_mx
time_us 152.90741052631583
mem_bw_gbps 5321.488168553876
```

After this PR:

```bash
(pytorch) [vasiliy@devgpu023.atn1 ~/local/ao (20250321_mx_dim1_triton_kernel)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.8.0a0+git25309a1
triton version: 3.3.0
mode: dim0_mx
time_us 149.03950980392162
mem_bw_gbps 5459.5924065404415
```
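As a sanity check on the two benchmark runs, the sketch below recomputes the speedup and the reported `mem_bw_gbps` from `time_us`. It assumes, without confirmation from `cast_bench.py` itself, that the bandwidth counts bf16 input reads, 1-byte fp8 output writes, and one e8m0 scale byte per 32-element block. Under that assumption the byte count reproduces both reported bandwidth figures. It also shows that the 2.6% figure is the throughput ratio (152.91 / 149.04); measured as a reduction in kernel time, the gain is about 2.5%.

```python
# Sketch: re-derive the speedup and mem_bw_gbps from the time_us values in the test plan.
# ASSUMPTION: the bandwidth counts bf16 reads, fp8 writes, and one e8m0 byte per block.
M = K = 16384
BLOCK_SIZE = 32


def bytes_moved(m: int, k: int, block_size: int) -> int:
    read_bytes = m * k * 2                    # bf16 input, 2 bytes per element (assumed)
    data_write_bytes = m * k                  # fp8 output, 1 byte per element (assumed)
    scale_write_bytes = m * k // block_size   # one e8m0 scale byte per block (assumed)
    return read_bytes + data_write_bytes + scale_write_bytes


def mem_bw_gbps(time_us: float) -> float:
    seconds = time_us * 1e-6
    return bytes_moved(M, K, BLOCK_SIZE) / seconds / 1e9


before_us = 152.90741052631583  # first run in the test plan
after_us = 149.03950980392162   # second run in the test plan

bw_before = mem_bw_gbps(before_us)
bw_after = mem_bw_gbps(after_us)
speedup_pct = (before_us / after_us - 1.0) * 100.0     # throughput gain
time_saved_pct = (1.0 - after_us / before_us) * 100.0  # latency reduction

print(f"bandwidth: {bw_before:.1f} -> {bw_after:.1f} GB/s")  # bandwidth: 5321.5 -> 5459.6 GB/s
print(f"speedup: {speedup_pct:.2f}%, time reduced by {time_saved_pct:.2f}%")
# speedup: 2.60%, time reduced by 2.53%
```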

vkuzo added 30 commits March 21, 2025 06:59
vkuzo added 4 commits March 28, 2025 13:00
vkuzo added a commit that referenced this pull request Mar 28, 2025
ghstack-source-id: 47fb1df
ghstack-comment-id: 2762318741
Pull Request resolved: #1980
vkuzo added 5 commits March 28, 2025 13:02
vkuzo added a commit that referenced this pull request Mar 28, 2025
vkuzo added 4 commits March 28, 2025 13:03
vkuzo added a commit that referenced this pull request Mar 28, 2025
vkuzo added 3 commits March 28, 2025 13:03
vkuzo added a commit that referenced this pull request Mar 28, 2025
vkuzo added a commit that referenced this pull request Apr 1, 2025
vkuzo added 2 commits April 1, 2025 09:40
vkuzo added a commit that referenced this pull request Apr 1, 2025
vkuzo changed the base branch from gh/vkuzo/85/head to main April 1, 2025 16:41
vkuzo added a commit that referenced this pull request Apr 1, 2025
vkuzo merged commit aafc1ba into main Apr 1, 2025
liangel-02 pushed a commit that referenced this pull request Aug 25, 2025

Labels: CLA Signed, topic: performance


3 participants