
Commit

Merge branch 'tk2' of github.com:/HazyResearch/ThunderKittens into tk2
benjaminfspector committed Oct 27, 2024
2 parents e137133 + fc62669 commit 9626be2
Showing 9 changed files with 593 additions and 979 deletions.
379 changes: 140 additions & 239 deletions tests/python/attention/implementations.py
@@ -1,9 +1,6 @@
import torch
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import os
import time
import thunderkittens as tk

try:
@@ -44,247 +41,151 @@ def get_attention_inputs(b, h, n, dv, dt=torch.bfloat16, pad_multiple=0, ):
    return q, k, v, dO


def pytorch_test(dt, b, h, n, dv, causal, is_forwards, verbose=True, **kwargs):
    q, k, v, dO = get_attention_inputs(b, h, n, dv, dt)

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]

    try:
        torch.cuda.synchronize()
        start_events[0].record()

        QK = torch.matmul(q, k.transpose(-2, -1))
        QK /= (q.size(-1) ** 0.5)
        if causal:
            mask = torch.triu(torch.ones(QK.size(-2), QK.size(-1)), 1).to(torch.bool).to(QK.device)
            QK.masked_fill_(mask, float('-inf'))
        QK = torch.nn.functional.softmax(QK, dim=-1)
        y = torch.matmul(QK, v)

        end_events[0].record()
        torch.cuda.synchronize()
        tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)][0]

    except Exception as e:
        if verbose:
            print(f"Error: {e}")
        tot = -1
        y = None

    if is_forwards:
        return y, tot

    try:
        torch.cuda.synchronize()
        start_events[0].record()

        y.backward(dO)

        q_grad = q.grad
        k_grad = k.grad
        v_grad = v.grad

        q_grad = q_grad.to(torch.bfloat16)
        k_grad = k_grad.to(torch.bfloat16)
        v_grad = v_grad.to(torch.bfloat16)
        y = y.to(torch.bfloat16)

        end_events[0].record()
        torch.cuda.synchronize()
        tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)][0]

    except Exception as e:
        if verbose:
            print(f"Error: {e}")
        tot = -1
        q_grad = None
        k_grad = None
        v_grad = None

    return (q_grad, k_grad, v_grad), tot
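
Each of these per-method tests wraps the call it measures in the same CUDA-event timing pattern. A minimal standalone sketch of that pattern, with run_once standing in for any of the attention calls (the helper name is illustrative, not part of the repo):

import torch

def time_cuda_ms(run_once, iters=1):
    # One (start, end) event pair per iteration; elapsed_time reports milliseconds.
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for i in range(iters):
        torch.cuda.synchronize()   # drain earlier work so it is not charged to this op
        starts[i].record()
        run_once()
        ends[i].record()
        torch.cuda.synchronize()   # ensure the end event has actually completed
    return sum(s.elapsed_time(e) for s, e in zip(starts, ends)) / iters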


def fa2_test(dt, b, h, n, dv, causal, is_forwards, verbose=True, **kwargs):
    q, k, v, dO = get_attention_inputs(b, h, n, dv, dt)

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    try:
        torch.cuda.synchronize()
        start_events[0].record()

        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal)

        end_events[0].record()
        torch.cuda.synchronize()
        tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)][0]
    except Exception as e:
        if verbose:
            print(f"Error: {e}")
        tot = -1
        y = None

    if is_forwards:
        return y, tot

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    try:
        torch.cuda.synchronize()
        start_events[0].record()

        y.backward(dO)

        q_grad = q.grad
        k_grad = k.grad
        v_grad = v.grad

        end_events[0].record()
        torch.cuda.synchronize()
        tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)][0]
    except Exception as e:
        if verbose:
            print(f"Error: {e}")
        tot = -1
        q_grad = None
        k_grad = None
        v_grad = None
    return (q_grad, k_grad, v_grad), tot


def tk_test(dt, b, h, n, dv, causal, is_forwards, verbose=True, **kwargs):
    q, k, v, dO = get_attention_inputs(b, h, n, dv, dt) #, pad_multiple=192)

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]

    try:
        torch.cuda.synchronize()
        start_events[0].record()

        y, l_vec = tk.mha_forward(q, k, v, causal)

        end_events[0].record()
        torch.cuda.synchronize()
        tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)][0]
        assert not np.isnan(y.detach().float().cpu()).any(), "NaN values detected in output 'o'"
        assert not np.isinf(y.detach().float().cpu()).any(), "Inf values detected in output 'o'"
    except Exception as e:
        if verbose:
            print(f"Error: {e}")
        tot = -1
        y = None

    if is_forwards:
        return y, tot

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    try:
        torch.cuda.synchronize()
        start_events[0].record()

        q_grad, k_grad, v_grad = tk.mha_backward(q, k, v, y, l_vec, dO, causal)

        end_events[0].record()
        torch.cuda.synchronize()
        tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)][0]
        assert not np.isnan(q_grad.float().cpu()).any(), "NaN values detected in output 'q_grad'"
        assert not np.isinf(q_grad.float().cpu()).any(), "Inf values detected in output 'q_grad'"
        assert not np.isnan(k_grad.float().cpu()).any(), "NaN values detected in output 'k_grad'"
        assert not np.isinf(k_grad.float().cpu()).any(), "Inf values detected in output 'k_grad'"
        assert not np.isnan(v_grad.float().cpu()).any(), "NaN values detected in output 'v_grad'"
        assert not np.isinf(v_grad.float().cpu()).any(), "Inf values detected in output 'v_grad'"

    except Exception as e:
        if verbose:
            print(f"Error: {e}")
        tot = -1
        q_grad = None
        k_grad = None
        v_grad = None

    return (q_grad, k_grad, v_grad), tot


def fa3_test(dt, b, h, n, dv, causal, is_forwards, verbose=True, **kwargs):
    q, k, v, dO = get_attention_inputs(b, h, n, dv, dt)
    q = q.transpose(1,2)
    k = k.transpose(1,2)
    v = v.transpose(1,2)
    dO = dO.transpose(1,2)

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(1)]

    try:
        torch.cuda.synchronize()
        start_events[0].record()

        y, lse = flash_attn_func3(q, k, v, causal=causal)

        end_events[0].record()
        torch.cuda.synchronize()
        tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)][0]
        assert not np.isnan(y.detach().float().cpu()).any(), "NaN values detected in output 'o'"
        assert not np.isinf(y.detach().float().cpu()).any(), "Inf values detected in output 'o'"
def attention_test(dt, b, h, n, dv, causal, is_forwards, method_str, num_iters=10, verbose=True, **kwargs):

    # Two passes over the same loop: the 'warmup' pass primes kernels and caches,
    # and only the events recorded in the final ('timed') pass feed the average below.
    for stage in ['warmup', 'timed']:

        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]

        for i in range(num_iters):
            try:
                q, k, v, dO = get_attention_inputs(b, h, n, dv, dt)

                # Forward pass: each branch is bracketed by its own CUDA events.
                if method_str == "pytorch":
                    torch.cuda.synchronize()
                    start_events[i].record()
                    QK = torch.matmul(q, k.transpose(-2, -1))
                    QK /= (q.size(-1) ** 0.5)
                    if causal:
                        mask = torch.triu(torch.ones(QK.size(-2), QK.size(-1)), 1).to(torch.bool).to(QK.device)
                        QK.masked_fill_(mask, float('-inf'))
                    QK = torch.nn.functional.softmax(QK, dim=-1)
                    y = torch.matmul(QK, v)
                    end_events[i].record()
                    torch.cuda.synchronize()

                elif method_str == "fa2":
                    torch.cuda.synchronize()
                    start_events[i].record()
                    y = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal)
                    end_events[i].record()
                    torch.cuda.synchronize()

                elif method_str == 'tk':
                    torch.cuda.synchronize()
                    start_events[i].record()
                    y, l_vec = tk.mha_forward(q, k, v, causal)
                    end_events[i].record()
                    torch.cuda.synchronize()

                elif method_str == "fa3":
                    # FA3 expects (b, n, h, d) layout; transpose outside the timed region.
                    q = q.transpose(1,2)
                    k = k.transpose(1,2)
                    v = v.transpose(1,2)
                    dO = dO.transpose(1,2)
                    torch.cuda.synchronize()
                    start_events[i].record()
                    y, lse = flash_attn_func3(q, k, v, causal=causal)
                    end_events[i].record()
                    torch.cuda.synchronize()

                else:
                    assert 0, f"Unknown method: {method_str}"

                outputs = y

            except Exception as e:
                if verbose:
                    print(f"Error: {e}")
                return None, -1

            # Forward-only runs skip the backward section during the timed pass,
            # so the recorded events keep the forward timings.
            if is_forwards and stage == 'timed':
                continue

            try:

                # Backward pass: the same events are re-recorded, so when
                # is_forwards is False the averaged time covers the backward.
                if method_str == "pytorch":
                    torch.cuda.synchronize()
                    start_events[i].record()
                    y.backward(dO)
                    q_grad = q.grad.to(torch.bfloat16)
                    k_grad = k.grad.to(torch.bfloat16)
                    v_grad = v.grad.to(torch.bfloat16)
                    y = y.to(torch.bfloat16)
                    end_events[i].record()
                    torch.cuda.synchronize()

                elif method_str == "fa2":
                    torch.cuda.synchronize()
                    start_events[i].record()
                    y.backward(dO)
                    q_grad = q.grad
                    k_grad = k.grad
                    v_grad = v.grad
                    end_events[i].record()
                    torch.cuda.synchronize()

                elif method_str == 'tk':
                    torch.cuda.synchronize()
                    start_events[i].record()
                    q_grad, k_grad, v_grad = tk.mha_backward(q, k, v, y, l_vec, dO, causal)
                    end_events[i].record()
                    torch.cuda.synchronize()

                elif method_str == "fa3":
                    torch.cuda.synchronize()
                    start_events[i].record()
                    y.backward(dO)
                    q_grad = q.grad
                    k_grad = k.grad
                    v_grad = v.grad
                    end_events[i].record()
                    torch.cuda.synchronize()

                else:
                    assert 0, f"Unknown method: {method_str}"

                outputs = (q_grad, k_grad, v_grad)

            except Exception as e:
                if verbose:
                    print(f"Error: {e}")
                return (None, None, None), -1

            torch.cuda.empty_cache()
    except Exception as e:
        if verbose:
            print(f"Error: {e}")
        tot = -1
        y = None
    tot = sum([s.elapsed_time(e) for s, e in zip(start_events, end_events)])/num_iters
    # for output in outputs:
    #     try:
    #         assert not np.isnan(output.detach().float().cpu()).any(), "NaN values detected in output 'o'"
    #         assert not np.isinf(output.detach().float().cpu()).any(), "Inf values detected in output 'o'"
    #     except:
    #         breakpoint()
    return outputs, tot
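
A usage sketch of the unified benchmark above; the shapes, dtype, and variable names are illustrative, not taken from the repo's benchmark configs:

# Hypothetical call: average forward latency (ms) of the ThunderKittens kernel over num_iters runs.
y, ms = attention_test(torch.bfloat16, b=16, h=16, n=4096, dv=64,
                       causal=True, is_forwards=True, method_str="tk")

# With is_forwards=False the first return value is the (q_grad, k_grad, v_grad) tuple
# and the averaged time covers the backward pass instead.
grads, ms_bwd = attention_test(torch.bfloat16, b=16, h=16, n=4096, dv=64,
                               causal=True, is_forwards=False, method_str="tk")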

    if is_forwards:
        return y, tot

    try:
        torch.cuda.synchronize()
        start_events[0].record()

        y.backward(dO)

        q_grad = q.grad
        k_grad = k.grad
        v_grad = v.grad

        end_events[0].record()
        torch.cuda.synchronize()
        tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)][0]
    except Exception as e:
        if verbose:
            print(f"Error: {e}")
        tot = -1
        q_grad = None
        k_grad = None
        v_grad = None

    return (q_grad, k_grad, v_grad), tot



from functools import partial
IMPLEMENTATIONS = {
    "fwd_PT_c=t": partial(pytorch_test, causal=True, is_forwards=True),
    "fwd_FA2_c=t": partial(fa2_test, causal=True, is_forwards=True),
    "fwd_FA3_c=t": partial(fa3_test, causal=True, is_forwards=True),
    "fwd_TK_c=t": partial(tk_test, causal=True, is_forwards=True),
    "fwd_PT_c=f": partial(pytorch_test, causal=False, is_forwards=True),
    "fwd_FA2_c=f": partial(fa2_test, causal=False, is_forwards=True),
    "fwd_FA3_c=f": partial(fa3_test, causal=False, is_forwards=True),
    "fwd_TK_c=f": partial(tk_test, causal=False, is_forwards=True),
    "attn_fwd_PT_c=t": partial(attention_test, causal=True, is_forwards=True, method_str="pytorch"),
    "attn_fwd_FA2_c=t": partial(attention_test, causal=True, is_forwards=True, method_str="fa2"),
    "attn_fwd_FA3_c=t": partial(attention_test, causal=True, is_forwards=True, method_str="fa3"),
    "attn_fwd_TK_c=t": partial(attention_test, causal=True, is_forwards=True, method_str="tk"),
    "attn_fwd_PT_c=f": partial(attention_test, causal=False, is_forwards=True, method_str="pytorch"),
    "attn_fwd_FA2_c=f": partial(attention_test, causal=False, is_forwards=True, method_str="fa2"),
    "attn_fwd_FA3_c=f": partial(attention_test, causal=False, is_forwards=True, method_str="fa3"),
    "attn_fwd_TK_c=f": partial(attention_test, causal=False, is_forwards=True, method_str="tk"),
}

IMPLEMENTATIONS = {
    "bwd_PT_c=t": partial(pytorch_test, causal=True, is_forwards=False),
    "bwd_FA2_c=t": partial(fa2_test, causal=True, is_forwards=False),
    "bwd_FA3_c=t": partial(fa3_test, causal=True, is_forwards=False),
    "bwd_TK_c=t": partial(tk_test, causal=True, is_forwards=False),
    "bwd_PT_c=f": partial(pytorch_test, causal=False, is_forwards=False),
    "bwd_FA2_c=f": partial(fa2_test, causal=False, is_forwards=False),
    "bwd_FA3_c=f": partial(fa3_test, causal=False, is_forwards=False),
    "bwd_TK_c=f": partial(tk_test, causal=False, is_forwards=False),
}

IMPLEMENTATIONS_BWD = {
    "attn_bwd_PT_c=t": partial(attention_test, causal=True, is_forwards=False, method_str="pytorch"),
    "attn_bwd_FA2_c=t": partial(attention_test, causal=True, is_forwards=False, method_str="fa2"),
    "attn_bwd_FA3_c=t": partial(attention_test, causal=True, is_forwards=False, method_str="fa3"),
    "attn_bwd_TK_c=t": partial(attention_test, causal=True, is_forwards=False, method_str="tk"),
    "attn_bwd_PT_c=f": partial(attention_test, causal=False, is_forwards=False, method_str="pytorch"),
    "attn_bwd_FA2_c=f": partial(attention_test, causal=False, is_forwards=False, method_str="fa2"),
    "attn_bwd_FA3_c=f": partial(attention_test, causal=False, is_forwards=False, method_str="fa3"),
    "attn_bwd_TK_c=f": partial(attention_test, causal=False, is_forwards=False, method_str="tk"),
}

NAME = "ATTENTION"
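
These dictionaries are presumably consumed by a shared benchmark driver elsewhere under tests/python; a sketch of how such a driver could iterate them (the loop below is illustrative, not the repo's actual harness):

# Hypothetical driver loop: causal/is_forwards/method_str are already bound by partial,
# so each entry only needs the dtype and problem size (dt, b, h, n, dv).
for name, fn in {**IMPLEMENTATIONS, **IMPLEMENTATIONS_BWD}.items():
    outputs, ms = fn(torch.bfloat16, 16, 16, 4096, 64)
    status = "FAILED" if ms < 0 else f"{ms:.3f} ms"
    print(f"{name:24s} {status}")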

