Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SOTA 2-bit quants - part 2 #4856

Merged
merged 11 commits into from
Jan 11, 2024
Prev Previous commit
Next Next commit
iq2_xs: faster AVX2 dot product
21.4 t/s for TG-128, 59.2 t/s for PP-512.
The latter is 2x faster than the previous version.
  • Loading branch information
Kawrakow committed Jan 10, 2024
commit 8299b03a99fe88220720eed3ecdeea8202ca7dcd
43 changes: 30 additions & 13 deletions ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -7605,36 +7605,53 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest

#elif defined(__AVX2__)

const __m128i m4 = _mm_set1_epi8(0xf);
const __m128i m1 = _mm_set1_epi8(1);
const __m128i m511 = _mm_set1_epi16(511);
const __m128i m127 = _mm_set1_epi16(127);

const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

uint64_t aux64;

// somewhat hacky, but gives a significant boost in performance
__m128i aux_gindex, aux_sindex;
const uint16_t * gindex = (const uint16_t *)&aux_gindex;
const uint16_t * sindex = (const uint16_t *)&aux_sindex;

__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const uint8_t * restrict sc = x[i].scales;
const int8_t * restrict q8 = y[i].qs;

memcpy(&aux64, x[i].scales, 8);
__m128i stmp = _mm_set1_epi64x(aux64);
stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);

__m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[q2[3] & 511], iq2xs_grid[q2[2] & 511], iq2xs_grid[q2[1] & 511], iq2xs_grid[q2[0] & 511]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[q2[7] & 511], iq2xs_grid[q2[6] & 511], iq2xs_grid[q2[5] & 511], iq2xs_grid[q2[4] & 511]);
const __m256i s2_1 = _mm256_set_epi64x(signs64[q2[3] >> 9], signs64[q2[2] >> 9], signs64[q2[1] >> 9], signs64[q2[0] >> 9]);
const __m256i s2_2 = _mm256_set_epi64x(signs64[q2[7] >> 9], signs64[q2[6] >> 9], signs64[q2[5] >> 9], signs64[q2[4] >> 9]);
const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
aux_gindex = _mm_and_si128(q2_data, m511);
aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);

const uint16_t ls1 = 2*(sc[0] & 0xf) + 1, ls2 = 2*(sc[0] >> 4) + 1;
const uint16_t ls3 = 2*(sc[1] & 0xf) + 1, ls4 = 2*(sc[1] >> 4) + 1;
const __m256i p1 = _mm256_madd_epi16(dot1, MM256_SET_M128I(_mm_set1_epi16(ls2), _mm_set1_epi16(ls1)));
const __m256i p2 = _mm256_madd_epi16(dot2, MM256_SET_M128I(_mm_set1_epi16(ls4), _mm_set1_epi16(ls3)));
sumi1 = _mm256_add_epi32(sumi1, p1);
sumi2 = _mm256_add_epi32(sumi2, p2);
q2 += 8;
sc += 2;
const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));

sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
}

accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
Expand Down
Loading