lhw414 commented Jul 26, 2024

  • Added backward var operation and kernel.
  • Added driver test and gtest for var.
  • The new API is guarded by the MIOPEN_BETA_API macro.
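For reference, here is the gradient that the backward var operation has to produce. It follows from var = sum((x_i - mean)^2) / (N - correction). Because sum(x_i - mean) = 0, the mean's dependence on x cancels, which leaves dx_i = dy * 2 * (x_i - mean) / (N - correction). The following is a minimal CPU sketch of that formula. It assumes a full reduction over all elements and a scalar output gradient dy. The function name VarBackwardRef and the correction parameter are illustrative only and are not the MIOpen API. The actual kernel also handles reduction dims and strided layouts.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU reference for d(var)/dx when the variance is reduced over ALL elements.
// Assumptions (not taken from the PR): scalar output gradient dy, and
// correction = 1 selects the unbiased estimator (divide by N - 1).
std::vector<float> VarBackwardRef(const std::vector<float>& x, float dy, std::size_t correction = 1)
{
    const std::size_t n = x.size();
    assert(n > correction && "need more elements than the correction term");

    // Accumulate the mean in double to limit rounding error on large inputs.
    double mean = 0.0;
    for(float v : x)
        mean += v;
    mean /= static_cast<double>(n);

    // dx_i = dy * 2 * (x_i - mean) / (N - correction)
    const double scale = 2.0 * static_cast<double>(dy) / static_cast<double>(n - correction);
    std::vector<float> dx(n);
    for(std::size_t i = 0; i < n; ++i)
        dx[i] = static_cast<float>(scale * (static_cast<double>(x[i]) - mean));
    return dx;
}
```

For x = {1, 2, 3, 4} and dy = 1, the sketch returns {-1, -1/3, 1/3, 1}. The gradients sum to zero, which makes a quick sanity check for any var backward implementation.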

Compared with ROCm, the newly developed MIOpen var kernel is faster for inputs in a specific size range (1024 to 1024 * 1024 * 2 elements).

Type Direction geomean (ROCm / MIOpen)
fp32(cont) bwd 3.24
fp16(cont) bwd 3.26
bfp16(cont) bwd 3.23
fp32(non-cont) bwd 2.85
fp16(non-cont) bwd 3.11
bfp16(non-cont) bwd 3.66
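The ROCm / MIOpen column is the ratio of the two average kernel times, so a value above 1.0 means MIOpen was faster. This sketch assumes the geomean column above is the geometric mean of those per-shape ratios; the PR does not spell this out. Under that assumption, the figures can be reproduced with something like the following. Speedup and GeoMean are illustrative helper names.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Speedup of one benchmark row, as in the "ROCm / MIOpen" column:
// values above 1.0 mean the MIOpen kernel was faster.
double Speedup(double rocm_avg, double miopen_avg) { return rocm_avg / miopen_avg; }

// Geometric mean of the per-row ratios, computed in log space so that
// long lists of ratios cannot overflow the running product.
double GeoMean(const std::vector<double>& ratios)
{
    assert(!ratios.empty());
    double log_sum = 0.0;
    for(double r : ratios)
        log_sum += std::log(r);
    return std::exp(log_sum / static_cast<double>(ratios.size()));
}
```

Applying Speedup to the first float32 (contiguous) row gives 25,088 / 11,786, which is about 2.13 and matches the table.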

float32(contiguous)
op_name dtype size direction rocm_kernel_avg MIOpen_kernel_avg ROCm / MIOpen
var float32 [2048] bwd 25,088 11,786 2.13
var float32 [8192] bwd 26,960 9,991 2.70
var float32 [65536] bwd 35,584 9,600 3.71
var float32 [131072] bwd 31,664 11,484 2.76
var float32 [262144] bwd 32,608 14,613 2.23
var float32 [524288] bwd 34,544 19,022 1.82
var float32 [1048576] bwd 45,103 32,977 1.37
var float32 [80 40] bwd 52,063 10,115 5.15
var float32 [80 250] bwd 41,632 9,689 4.30
var float32 [30 300] bwd 34,784 9,831 3.54
var float32 [30 40] bwd 23,072 10,755 2.15
var float32 [100 200] bwd 47,520 9,689 4.90
var float32 [300 400] bwd 30,528 11,289 2.70
var float32 [500 600] bwd 32,912 14,240 2.31
var float32 [900 1000] bwd 57,839 30,240 1.91
var float32 [40 50 30] bwd 45,775 9,618 4.76
var float32 [30 50 10] bwd 35,040 10,009 3.50
var float32 [50 40 50] bwd 35,024 10,115 3.46
var float32 [40 60 10] bwd 36,480 9,653 3.78
var float32 [20 30 40] bwd 33,872 9,653 3.51
var float32 [50 30 10] bwd 27,072 9,867 2.74
var float32 [100 200 10] bwd 47,904 11,609 4.13
var float32 [100 200 50] bwd 54,911 32,835 1.67
var float32 [100 200 100] bwd 110,046 57,031 1.93
var float32 [50 10 5 20] bwd 42,047 9,813 4.28
var float32 [50 30 5 4] bwd 38,912 10,009 3.89
var float32 [5 100 10 3] bwd 36,192 10,080 3.59
var float32 [10 20 10 5] bwd 34,144 9,973 3.42
var float32 [20 30 5 5] bwd 64,976 9,778 6.65
var float32 [20 5 10 10] bwd 40,047 9,582 4.18
var float32 [30 10 15 10] bwd 31,472 9,457 3.33
var float32 [10 10 10 10] bwd 26,112 9,689 2.70
var float32 [100 10 10 10] bwd 49,984 9,955 5.02
var float32 [100 20 10 10] bwd 40,592 11,484 3.53
var float32 [100 100 10 10] bwd 130,271 32,906 3.96
var float32 [30 10 5 10 2] bwd 47,264 9,636 4.90
var float32 [30 5 20 5 4] bwd 39,583 10,080 3.93
var float32 [20 10 3 12 4] bwd 48,591 9,813 4.95
var float32 [20 5 10 5 10] bwd 47,487 9,760 4.87
var float32 [40 20 3 2 5] bwd 44,943 9,635 4.66
var float32 [15 3 5 20 12] bwd 56,847 9,493 5.99
var float32 [12 12 4 8 10] bwd 36,144 9,404 3.84
var float32 [5 5 5 10 10] bwd 37,952 9,600 3.95
var float32 [10 4 8 2 4] bwd 24,544 10,649 2.30
var float32 [100 10 10 10 10] bwd 54,159 32,533 1.66
var float32 [100 20 10 10 10] bwd 69,455 57,529 1.21

float16(contiguous)
op_name dtype size direction rocm_kernel_avg MIOpen_kernel_avg ROCm / MIOpen
var float16 [2048] bwd 24,160 11,751 2.06
var float16 [8192] bwd 27,152 9,369 2.90
var float16 [65536] bwd 35,104 9,671 3.63
var float16 [131072] bwd 32,176 11,218 2.87
var float16 [262144] bwd 32,208 13,973 2.31
var float16 [524288] bwd 36,032 19,200 1.88
var float16 [1048576] bwd 42,192 31,732 1.33
var float16 [80 40] bwd 53,487 9,938 5.38
var float16 [80 250] bwd 40,447 9,404 4.30
var float16 [30 300] bwd 34,864 9,458 3.69
var float16 [30 40] bwd 23,872 11,218 2.13
var float16 [100 200] bwd 47,423 9,422 5.03
var float16 [300 400] bwd 30,560 11,324 2.70
var float16 [500 600] bwd 33,648 14,951 2.25
var float16 [900 1000] bwd 57,696 29,226 1.97
var float16 [40 50 30] bwd 41,904 9,689 4.32
var float16 [30 50 10] bwd 34,752 9,956 3.49
var float16 [50 40 50] bwd 35,712 10,080 3.54
var float16 [40 60 10] bwd 38,703 9,386 4.12
var float16 [20 30 40] bwd 29,296 9,458 3.10
var float16 [50 30 10] bwd 26,544 9,635 2.75
var float16 [100 200 10] bwd 48,208 11,395 4.23
var float16 [100 200 50] bwd 53,535 31,555 1.70
var float16 [100 200 100] bwd 98,415 54,684 1.80
var float16 [50 10 5 20] bwd 41,711 9,493 4.39
var float16 [50 30 5 4] bwd 41,439 9,849 4.21
var float16 [5 100 10 3] bwd 37,040 9,689 3.82
var float16 [10 20 10 5] bwd 35,536 9,724 3.65
var float16 [20 30 5 5] bwd 61,920 9,511 6.51
var float16 [20 5 10 10] bwd 39,232 9,902 3.96
var float16 [30 10 15 10] bwd 32,464 9,315 3.49
var float16 [10 10 10 10] bwd 26,192 9,369 2.80
var float16 [100 10 10 10] bwd 48,832 9,707 5.03
var float16 [100 20 10 10] bwd 41,104 11,324 3.63
var float16 [100 100 10 10] bwd 128,559 31,662 4.06
var float16 [30 10 5 10 2] bwd 46,064 9,404 4.90
var float16 [30 5 20 5 4] bwd 39,152 9,742 4.02
var float16 [20 10 3 12 4] bwd 50,320 9,440 5.33
var float16 [20 5 10 5 10] bwd 46,880 9,618 4.87
var float16 [40 20 3 2 5] bwd 40,816 9,493 4.30
var float16 [15 3 5 20 12] bwd 54,543 9,280 5.88
var float16 [12 12 4 8 10] bwd 36,415 9,262 3.93
var float16 [5 5 5 10 10] bwd 37,600 9,387 4.01
var float16 [10 4 8 2 4] bwd 24,272 10,293 2.36
var float16 [100 10 10 10 10] bwd 52,368 31,306 1.67
var float16 [100 20 10 10 10] bwd 63,311 55,021 1.15

bfloat16(contiguous)
op_name dtype size direction rocm_kernel_avg MIOpen_kernel_avg ROCm / MIOpen
var bfloat16 [2048] bwd 19,568 11,396 1.72
var bfloat16 [8192] bwd 26,512 9,920 2.67
var bfloat16 [65536] bwd 35,120 9,671 3.63
var bfloat16 [131072] bwd 32,224 11,609 2.78
var bfloat16 [262144] bwd 31,440 14,293 2.20
var bfloat16 [524288] bwd 35,712 19,982 1.79
var bfloat16 [1048576] bwd 41,760 33,617 1.24
var bfloat16 [80 40] bwd 55,503 9,991 5.56
var bfloat16 [80 250] bwd 41,632 9,511 4.38
var bfloat16 [30 300] bwd 35,024 9,600 3.65
var bfloat16 [30 40] bwd 24,384 11,271 2.16
var bfloat16 [100 200] bwd 49,712 9,582 5.19
var bfloat16 [300 400] bwd 31,680 11,520 2.75
var bfloat16 [500 600] bwd 33,824 14,898 2.27
var bfloat16 [900 1000] bwd 58,319 30,613 1.91
var bfloat16 [40 50 30] bwd 44,159 9,636 4.58
var bfloat16 [30 50 10] bwd 34,624 10,169 3.40
var bfloat16 [50 40 50] bwd 35,295 10,151 3.48
var bfloat16 [40 60 10] bwd 38,624 9,742 3.96
var bfloat16 [20 30 40] bwd 30,272 9,635 3.14
var bfloat16 [50 30 10] bwd 26,624 9,742 2.73
var bfloat16 [100 200 10] bwd 48,608 11,769 4.13
var bfloat16 [100 200 50] bwd 54,128 33,386 1.62
var bfloat16 [100 200 100] bwd 100,527 58,275 1.73
var bfloat16 [50 10 5 20] bwd 44,144 9,475 4.66
var bfloat16 [50 30 5 4] bwd 42,208 10,009 4.22
var bfloat16 [5 100 10 3] bwd 36,240 9,778 3.71
var bfloat16 [10 20 10 5] bwd 36,480 9,902 3.68
var bfloat16 [20 30 5 5] bwd 61,023 9,671 6.31
var bfloat16 [20 5 10 10] bwd 40,512 9,760 4.15
var bfloat16 [30 10 15 10] bwd 32,560 9,404 3.46
var bfloat16 [10 10 10 10] bwd 26,080 9,636 2.71
var bfloat16 [100 10 10 10] bwd 50,751 9,866 5.14
var bfloat16 [100 20 10 10] bwd 43,295 11,431 3.79
var bfloat16 [100 100 10 10] bwd 128,303 33,528 3.83
var bfloat16 [30 10 5 10 2] bwd 49,040 9,600 5.11
var bfloat16 [30 5 20 5 4] bwd 42,032 9,831 4.28
var bfloat16 [20 10 3 12 4] bwd 51,455 9,760 5.27
var bfloat16 [20 5 10 5 10] bwd 45,903 9,724 4.72
var bfloat16 [40 20 3 2 5] bwd 41,008 9,476 4.33
var bfloat16 [15 3 5 20 12] bwd 55,775 9,404 5.93
var bfloat16 [12 12 4 8 10] bwd 36,624 9,351 3.92
var bfloat16 [5 5 5 10 10] bwd 38,960 9,493 4.10
var bfloat16 [10 4 8 2 4] bwd 24,432 10,436 2.34
var bfloat16 [100 10 10 10 10] bwd 55,456 33,226 1.67
var bfloat16 [100 20 10 10 10] bwd 68,383 58,275 1.17

float32(noncontiguous)
op_name dtype size direction rocm_kernel_avg MIOpen_kernel_avg ROCm / MIOpen
var float32 [2048] bwd 28,912 12,213 2.37
var float32 [8192] bwd 28,032 9,973 2.81
var float32 [65536] bwd 43,456 9,724 4.47
var float32 [131072] bwd 34,768 11,520 3.02
var float32 [262144] bwd 36,736 14,436 2.54
var float32 [524288] bwd 40,928 19,769 2.07
var float32 [1048576] bwd 55,664 33,102 1.68
var float32 [80 40] bwd 49,792 11,929 4.17
var float32 [80 250] bwd 41,312 11,467 3.60
var float32 [30 300] bwd 29,728 11,556 2.57
var float32 [30 40] bwd 34,352 12,427 2.76
var float32 [100 200] bwd 46,208 11,431 4.04
var float32 [300 400] bwd 30,640 15,680 1.95
var float32 [500 600] bwd 42,704 22,684 1.88
var float32 [900 1000] bwd 59,871 55,484 1.08
var float32 [40 50 30] bwd 42,992 12,089 3.56
var float32 [30 50 10] bwd 36,511 12,071 3.02
var float32 [50 40 50] bwd 36,928 14,204 2.60
var float32 [40 60 10] bwd 37,424 11,164 3.35
var float32 [20 30 40] bwd 31,904 11,556 2.76
var float32 [50 30 10] bwd 41,103 11,662 3.52
var float32 [100 200 10] bwd 47,408 15,947 2.97
var float32 [100 200 50] bwd 56,399 97,848 0.58
var float32 [50 10 5 20] bwd 40,800 11,644 3.50
var float32 [50 30 5 4] bwd 40,319 11,680 3.45
var float32 [5 100 10 3] bwd 37,712 11,076 3.40
var float32 [10 20 10 5] bwd 34,096 10,898 3.13
var float32 [20 30 5 5] bwd 63,727 10,898 5.85
var float32 [20 5 10 10] bwd 45,296 11,751 3.85
var float32 [30 10 15 10] bwd 32,144 10,898 2.95
var float32 [10 10 10 10] bwd 39,696 11,396 3.48
var float32 [100 10 10 10] bwd 47,727 12,356 3.86
var float32 [100 20 10 10] bwd 39,968 28,249 1.41
var float32 [100 100 10 10] bwd 131,646 82,861 1.59
var float32 [30 10 5 10 2] bwd 45,376 10,827 4.19
var float32 [30 5 20 5 4] bwd 40,607 11,040 3.68
var float32 [20 10 3 12 4] bwd 48,736 10,346 4.71
var float32 [20 5 10 5 10] bwd 43,952 10,738 4.09
var float32 [40 20 3 2 5] bwd 48,464 11,111 4.36
var float32 [15 3 5 20 12] bwd 52,752 11,111 4.75
var float32 [12 12 4 8 10] bwd 34,784 10,560 3.29
var float32 [5 5 5 10 10] bwd 42,720 11,040 3.87
var float32 [10 4 8 2 4] bwd 33,616 11,698 2.87
var float32 [100 10 10 10 10] bwd 79,618 43,576 1.83
var float32 [100 20 10 10 10] bwd 81,288 63,989 1.27

float16(noncontiguous)
op_name dtype size direction rocm_kernel_avg MIOpen_kernel_avg ROCm / MIOpen
var float16 [2048] bwd 30,272 11,929 2.54
var float16 [8192] bwd 28,704 9,475 3.03
var float16 [65536] bwd 32,880 9,582 3.43
var float16 [131072] bwd 34,544 11,413 3.03
var float16 [262144] bwd 37,776 14,258 2.65
var float16 [524288] bwd 48,000 19,715 2.43
var float16 [1048576] bwd 63,791 31,804 2.01
var float16 [80 40] bwd 59,519 11,858 5.02
var float16 [80 250] bwd 42,224 11,360 3.72
var float16 [30 300] bwd 29,920 11,413 2.62
var float16 [30 40] bwd 31,680 11,911 2.66
var float16 [100 200] bwd 49,231 11,235 4.38
var float16 [300 400] bwd 32,015 15,324 2.09
var float16 [500 600] bwd 44,831 22,507 1.99
var float16 [900 1000] bwd 75,199 54,737 1.37
var float16 [40 50 30] bwd 42,896 11,591 3.70
var float16 [30 50 10] bwd 35,087 11,022 3.18
var float16 [50 40 50] bwd 37,712 14,240 2.65
var float16 [40 60 10] bwd 39,120 10,898 3.59
var float16 [20 30 40] bwd 31,056 11,200 2.77
var float16 [50 30 10] bwd 55,072 11,058 4.98
var float16 [100 200 10] bwd 53,968 15,075 3.58
var float16 [100 200 50] bwd 64,160 85,315 0.75
var float16 [50 10 5 20] bwd 49,600 11,182 4.44
var float16 [50 30 5 4] bwd 40,576 11,182 3.63
var float16 [5 100 10 3] bwd 35,840 10,507 3.41
var float16 [10 20 10 5] bwd 33,008 10,204 3.23
var float16 [20 30 5 5] bwd 43,808 10,382 4.22
var float16 [20 5 10 10] bwd 50,704 11,342 4.47
var float16 [30 10 15 10] bwd 32,912 10,222 3.22
var float16 [10 10 10 10] bwd 46,864 10,506 4.46
var float16 [100 10 10 10] bwd 53,152 11,858 4.48
var float16 [100 20 10 10] bwd 39,600 23,360 1.70
var float16 [100 100 10 10] bwd 75,632 48,835 1.55
var float16 [30 10 5 10 2] bwd 45,072 10,418 4.33
var float16 [30 5 20 5 4] bwd 37,679 10,809 3.49
var float16 [20 10 3 12 4] bwd 56,991 10,187 5.59
var float16 [20 5 10 5 10] bwd 54,336 10,258 5.30
var float16 [40 20 3 2 5] bwd 61,152 10,364 5.90
var float16 [15 3 5 20 12] bwd 65,359 10,240 6.38
var float16 [12 12 4 8 10] bwd 36,768 9,867 3.73
var float16 [5 5 5 10 10] bwd 44,624 10,755 4.15
var float16 [10 4 8 2 4] bwd 33,232 10,880 3.05
var float16 [100 10 10 10 10] bwd 71,254 41,356 1.72
var float16 [100 20 10 10 10] bwd 79,803 59,832 1.33

bfloat16(noncontiguous)
op_name dtype size direction rocm_kernel_avg MIOpen_kernel_avg ROCm / MIOpen
var bfloat16 [2048] bwd 30,352 11,680 2.60
var bfloat16 [8192] bwd 32,160 9,760 3.30
var bfloat16 [65536] bwd 34,992 9,707 3.60
var bfloat16 [131072] bwd 37,776 11,893 3.18
var bfloat16 [262144] bwd 48,288 15,004 3.22
var bfloat16 [524288] bwd 65,087 20,462 3.18
var bfloat16 [1048576] bwd 101,455 33,457 3.03
var bfloat16 [80 40] bwd 44,896 12,426 3.61
var bfloat16 [80 250] bwd 45,936 11,360 4.04
var bfloat16 [30 300] bwd 32,416 11,911 2.72
var bfloat16 [30 40] bwd 34,064 12,978 2.62
var bfloat16 [100 200] bwd 53,424 11,378 4.70
var bfloat16 [300 400] bwd 35,520 15,360 2.31
var bfloat16 [500 600] bwd 57,023 22,542 2.53
var bfloat16 [900 1000] bwd 117,070 55,484 2.11
var bfloat16 [40 50 30] bwd 53,903 11,733 4.59
var bfloat16 [30 50 10] bwd 34,912 11,840 2.95
var bfloat16 [50 40 50] bwd 45,119 13,938 3.24
var bfloat16 [40 60 10] bwd 42,416 11,111 3.82
var bfloat16 [20 30 40] bwd 33,184 11,875 2.79
var bfloat16 [50 30 10] bwd 69,087 11,662 5.92
var bfloat16 [100 200 10] bwd 69,327 15,928 4.35
var bfloat16 [100 200 50] bwd 76,322 19,825 3.85
var bfloat16 [50 10 5 20] bwd 55,072 10,933 5.04
var bfloat16 [50 30 5 4] bwd 46,160 11,413 4.04
var bfloat16 [5 100 10 3] bwd 34,800 10,862 3.20
var bfloat16 [10 20 10 5] bwd 40,207 10,933 3.68
var bfloat16 [20 30 5 5] bwd 52,895 10,738 4.93
var bfloat16 [20 5 10 10] bwd 62,927 11,875 5.30
var bfloat16 [30 10 15 10] bwd 35,312 10,755 3.28
var bfloat16 [10 10 10 10] bwd 60,528 10,613 5.70
var bfloat16 [100 10 10 10] bwd 68,527 11,911 5.75
var bfloat16 [100 20 10 10] bwd 45,888 23,360 1.96
var bfloat16 [100 100 10 10] bwd 121,263 49,688 2.44
var bfloat16 [30 10 5 10 2] bwd 47,792 10,649 4.49
var bfloat16 [30 5 20 5 4] bwd 40,320 10,755 3.75
var bfloat16 [20 10 3 12 4] bwd 76,591 10,258 7.47
var bfloat16 [20 5 10 5 10] bwd 68,847 10,258 6.71
var bfloat16 [40 20 3 2 5] bwd 86,159 10,898 7.91
var bfloat16 [15 3 5 20 12] bwd 65,535 10,364 6.32
var bfloat16 [12 12 4 8 10] bwd 37,552 9,955 3.77
var bfloat16 [5 5 5 10 10] bwd 66,432 10,844 6.13
var bfloat16 [10 4 8 2 4] bwd 37,344 11,111 3.36
var bfloat16 [100 10 10 10 10] bwd 89,745 46,571 1.93
var bfloat16 [100 20 10 10 10] bwd 91,338 69,890 1.31

et16kr commented Jul 29, 2024

/home/MIOpen/src/solver/var/../../kernels/tensor_utils.hpp:51:21: error: function 'GET_STRIDED_INDEX' defined in a header file; function definitions in header files can lead to ODR violations [misc-definitions-in-headers,-warnings-as-errors]
   51 | __device__ uint64_t GET_STRIDED_INDEX(const uint64_t indices[5], const uint64_t strides[5])
      |                     ^
/home/MIOpen/src/solver/var/../../kernels/tensor_utils.hpp:51:21: note: make as 'inline'
   51 | __device__ uint64_t GET_STRIDED_INDEX(const uint64_t indices[5], const uint64_t strides[5])
      |                     ^
      | inline
/home/MIOpen/src/solver/var/backward_var.cpp:44:56: error: parameter 'context' is unused [misc-unused-parameters,-warnings-as-errors]
   44 | bool VarBackward::IsApplicable(const ExecutionContext& context,
      |                                                        ^~~~~~~
      |                                                         /*context*/
/home/MIOpen/src/solver/var/backward_var.cpp:54:63: error: parameter 'context' is unused [misc-unused-parameters,-warnings-as-errors]
   54 | ConvSolution VarBackward::GetSolution(const ExecutionContext& context,
      |                                                               ^~~~~~~
      |                                                                /*context*/
make[3]: *** [src/CMakeFiles/tidy-target-MIOpen-solver_var_backward_var_cpp.dir/build.make:71: src/CMakeFiles/tidy-target-MIOpen-solver_var_backward_var_cpp] Error 1
make[2]: *** [CMakeFiles/Makefile2:9448: src/CMakeFiles/tidy-target-MIOpen-solver_var_backward_var_cpp.dir/all] Error 2

Please check make analyze.
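The misc-definitions-in-headers check fires because a non-inline function defined in a header gets its own definition in every translation unit that includes it, which violates the one-definition rule at link time. clang-tidy's suggested fix is to mark the function inline. Below is a minimal sketch of the fixed definition. The function body is an assumption, since the original body is not shown: it computes a dot product of five indices and five strides, which is what the name suggests. The __device__ guard exists only so the snippet also compiles with a plain host compiler.

```cpp
#include <cassert>
#include <stdint.h>

// Host-compiler fallback so the sketch builds outside hipcc;
// under hipcc, __device__ comes from the HIP runtime headers.
#ifndef __HIPCC__
#define __device__
#endif

// `inline` permits an identical definition in every translation unit that
// includes this header, which is what misc-definitions-in-headers requires.
// Assumed body: linear offset = sum of indices[i] * strides[i] over 5 dims.
__device__ inline uint64_t GET_STRIDED_INDEX(const uint64_t indices[5], const uint64_t strides[5])
{
    uint64_t offset = 0;
    for(int i = 0; i < 5; ++i)
        offset += indices[i] * strides[i];
    return offset;
}
```

The offset comes from explicit strides, so the same helper works for packed and non-contiguous (e.g. transposed) views. Only the stride values differ.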

lhw414 (Author) commented Jul 30, 2024

Fixed.
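The misc-unused-parameters errors can be resolved by keeping the parameter in the signature, presumably because the solver interface requires it, while commenting out its name as clang-tidy's fix-it suggests. The sketch below uses hypothetical stand-in types (ExecutionContext, ProblemDescription, VarBackwardSketch) so that it compiles on its own. It does not reproduce the real MIOpen solver declarations.

```cpp
#include <cassert>

// Hypothetical stand-ins so the sketch is self-contained; not the MIOpen types.
struct ExecutionContext
{
};

struct ProblemDescription
{
    bool is_supported = true; // hypothetical applicability flag
};

struct VarBackwardSketch
{
    // The unused parameter keeps its type (so the signature is unchanged)
    // but its name is commented out, which silences misc-unused-parameters.
    bool IsApplicable(const ExecutionContext& /*context*/, const ProblemDescription& problem) const
    {
        return problem.is_supported;
    }
};
```

C++17's [[maybe_unused]] attribute on the parameter is an alternative when the name should stay visible for documentation.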

kyeonghwanryu commented

Please separate cont and non-cont in the performance measurement summary.

lhw414 (Author) commented Aug 1, 2024

I measured the non-cont kernel performance separately and added those test results.

kyeonghwanryu commented

The 1D cases should be excluded from the non-cont statistics. Other than that, it looks good. Thanks for your work.
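To exclude the 1D shapes from the non-cont statistics, rows only need to be filtered by rank before taking the geometric mean. The following is a minimal sketch under two assumptions: each row's shape is kept as the bracketed string used in the tables, and the summary is a geometric mean of the ROCm / MIOpen column. BenchRow, Rank and GeoMeanMinRank are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// One row of the benchmark tables (hypothetical representation).
struct BenchRow
{
    std::string shape; // e.g. "[80 40]"
    double ratio;      // ROCm / MIOpen
};

// Number of dimensions in a shape written as "[d0 d1 ...]".
int Rank(const std::string& shape)
{
    if(shape.size() < 2)
        return 0;
    std::istringstream in(shape.substr(1, shape.size() - 2)); // strip the brackets
    int rank = 0;
    long long dim = 0;
    while(in >> dim)
        ++rank;
    return rank;
}

// Geometric mean of the ratios of rows with at least min_rank dimensions.
// Returns 0.0 when no row qualifies.
double GeoMeanMinRank(const std::vector<BenchRow>& rows, int min_rank)
{
    double log_sum = 0.0;
    std::size_t count = 0;
    for(const BenchRow& row : rows)
    {
        if(Rank(row.shape) < min_rank)
            continue; // e.g. skip 1D shapes when min_rank == 2
        log_sum += std::log(row.ratio);
        ++count;
    }
    return count == 0 ? 0.0 : std::exp(log_sum / static_cast<double>(count));
}
```

Calling GeoMeanMinRank(rows, 2) on the non-cont rows gives the statistic without the 1D shapes. Calling it with min_rank = 1 includes every row.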
