"""
CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables,
which share the same memory pool. Sharing a memory pool is an extremely
important optimization when chaining multiple CUDA graphs together, as it
prevents you from needing to copy intermediate tensors from one graph to the
next, and reduces overall memory usage by allowing dead memory from the first
pool to be reused in the second.
The standard graph/make_graph_callables support sharing a memory pool, but
with a lot of caveats. CUDA graph trees remove these restrictions:
* Previously, if you recorded graphs A, B, you had to replay A, B in that
order. With CUDA graph trees, after replaying A, you can change your
mind and record/replay a different graph B'; we will support efficient
execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In
other words: we support arbitrary trees of CUDA graph operations, not just
sequences (this is why this feature is called CUDA graph trees.)
* Previously, if you executed graph A, some non-CUDA graph code, and then
graph B, after executing graph B, it was not safe to retain any references
to intermediates produced by A. With CUDA graph trees, we track if any
outputs of graph A are still live by the time graph B is run, and make
sure graph B doesn't clobber their memory when reusing the CUDA graphs
pool. You'll get a separate recording of B depending on what tensors
stay live or dead.
CUDA graph trees are flexible enough to be used in Dynamo across graph breaks,
which is their primary use case.
The ability to switch from replay to record is fairly nontrivial: remember that
when you replay a CUDA graph, you only replay CUDA operations; no CPU side state
is updated. In particular, the CPU-side book-keeping for the allocator is not
reconstructed. However, to record a new child CUDA graph, we must restore this
book-keeping. This is what checkpoint pool state is used for.
"""
from __future__ import annotations
import contextlib
import dataclasses
import functools
import gc
import itertools
import operator
import sys
import threading
import traceback
import warnings
import weakref
from collections import defaultdict
from enum import auto, Enum
from typing import (
Any,
Callable,
cast,
ContextManager,
Dict,
Generator,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
import torch.fx
from torch import Tensor
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.utils import counters, preserve_rng_state
from torch._inductor.compile_fx import (
align_inputs_from_check_idxs,
copy_misaligned_inputs,
get_expanded_dims,
get_input_idxs_to_check,
index_expanded_dims,
remove_unaligned_input_idxs,
static_input,
)
from torch._inductor.cudagraph_utils import (
check_for_mutation,
CheckInvariantStatus,
FunctionID,
log_cudagraph_skip_and_bump_counter,
log_data_ptr_mismatch,
maybe_warning_due_to_dynamic_shape,
ModelType,
OutputType,
PlaceholderInfo,
WrappedFunction,
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.storage import UntypedStorage
from torch.utils import _pytree as pytree
from torch.utils.weak import TensorWeakRef
if TYPE_CHECKING:
from torch._inductor.utils import InputType
from torch.types import _bool
StorageWeakRefPointer = int
StorageDataPtr = int
NBytes = int
S = TypeVar("S", bound="StorageWeakRefWrapper")
if torch.backends.cuda.is_built():
from torch._C import (
_cuda_CUDAAllocator_AllocatorState as AllocatorState,
_set_cached_tensors_enabled as _set_cached_tensors_enabled,
)
else:
class AllocatorState: # type: ignore[no-redef]
pass
def _set_cached_tensors_enabled(enabled: _bool) -> None:
pass
log = torch._logging.getArtifactLogger(__name__, "cudagraphs")
from . import config
@dataclasses.dataclass(frozen=True)
class GraphID:
"Unique counter of a cuda graph recording"
id: int
def clear_cublass_cache() -> None:
"""
Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for
doing warmup within a CUDAGraph private pool because we do not want persistent allocations from
one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors
from the previous generation are freed. This frees their memory in the memory pool, but not elsewhere.
A tensor in the cublas workspace would continue to be in use in the workspace, but would also get allocated
in the next run. The memory would be in use in two places.
To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required
it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the
program. There is no overhead to this on replay since cudagraphs removes allocation overhead.
"""
torch._C._cuda_clearCublasWorkspaces()
@contextlib.contextmanager
def clear_cublas_manager() -> Generator[None, None, None]:
"Context manager around clearing cublas caches that will clear on enter and exit"
clear_cublass_cache()
try:
yield
finally:
clear_cublass_cache()
@contextlib.contextmanager
def disable_conv_cache_emptying() -> Generator[None, None, None]:
prev = torch._C._cuda_get_conv_benchmark_empty_cache()
torch._C._cudnn_set_conv_benchmark_empty_cache(False)
try:
yield
finally:
torch._C._cudnn_set_conv_benchmark_empty_cache(prev)
@contextlib.contextmanager
def enable_history_recording() -> Generator[None, None, None]:
"Turns on history recording in the CUDA Caching Allocator"
enabled = torch._C._cuda_isHistoryEnabled()
try:
if not enabled:
torch.cuda.memory._record_memory_history()
yield
finally:
if not enabled:
torch.cuda.memory._record_memory_history(None)
def get_history_recording() -> ContextManager[None]:
# TODO - remove, prevents cleanup
if not config.triton.cudagraph_trees_history_recording:
return contextlib.nullcontext()
return enable_history_recording()
class TreeManagerContainer:
"""
Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator,
the tree and its corresponding memory pool should be kept alive as long as any outstanding
graph or tensor which is an output of a graph remains alive.
There is a single tree manager container per device.
The lifecycle of a tree_manager is:
- Is constructed, no graph, no fns, no tensors
- Tree manager is fetched, resulting in tree manager being allocated
- We generate a bunch of functions, calling add_strong_reference
- These functions die, calling finalize_reference
- When all the functions die, we finalize_tree_manager.
TODO: in the future, we would like to do the following once storage weak refs land
- We look for all the live storages and add references to THOSE
- We count as storages die
- All the storages are dead, we deallocate the tree manager
"""
def __init__(self, device_index: int) -> None:
# This class keeps a strong reference to tree_manager,
# but once all other strong references to the tree_manager die, it will reset it to None.
# We need a strong reference so that we can still access its attributes upon cleanup.
self.tree_manager: Optional[CUDAGraphTreeManager] = None
# Number of outstanding references to the current tree manager
self.live_cudagraphify_fns = 0
self.device_index = device_index
# Following two objects are only set in the case that Tensor outputs outlive
# the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from
# deallocation.
self.live_storages_count = 0
self.graph: Optional[torch.cuda.CUDAGraph] = None
self.lock = threading.Lock()
def _finalize_tensor(self) -> None:
with self.lock:
self.live_storages_count -= 1
if self.live_storages_count == 0:
self.graph = None
# if the manager was used again after the existing cleanup,
# we shouldn't set it to None
if self.live_cudagraphify_fns == 0:
self.tree_manager = None
def finalize_cudagraphify_fn(self) -> None:
with self.lock:
self.live_cudagraphify_fns -= 1
if self.live_cudagraphify_fns == 0:
self._finalize_tree_manager()
def _finalize_tree_manager(self) -> None:
assert self.lock.locked()
self.tree_manager = None
# TODO - when issue #91395 is landed, we can set a weakref on
# storages and trigger a deallocation when all outputs of the
# cudagraph are dead.
# live_storages = list(
# tree_manager.live_cudagraph_pool_storages_in_curr_execution()
# )
# # Maintain reference to graph to keep tensors alive
# assert len(tree_manager.roots) > 0, "expected at least one use"
# root = next(tree_manager.get_roots())
# self.graph = root.graph
# seen_storages = set()
# for stor in live_storages:
# if stor in seen_storages:
# continue
# seen_storages.add(stor)
# self.live_storages_count += 1
# . weakref.finalize(stor, self._finalize_tensor)
def add_strong_reference(self, fn: Callable[..., Any]) -> None:
with self.lock:
self.live_cudagraphify_fns += 1
weakref.finalize(fn, self.finalize_cudagraphify_fn)
def get_tree_manager(self) -> CUDAGraphTreeManager:
with self.lock:
if self.tree_manager is None:
self.tree_manager = CUDAGraphTreeManager(self.device_index)
return self.tree_manager
local = threading.local()
# one tree manager per device
local.tree_manager_containers = {}
local.tree_manager_locks = defaultdict(threading.Lock)
# only incremented by user call of mark_step_begin
class MarkStepBox:
mark_step_counter = 0
# We need to register this as an object that will be copied over as TLS when new
# threads are created in autograd
torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers)
torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks)
def mark_step_begin() -> None:
"Indicates that a new iteration of inference or training is about to begin."
# iterate down to distinguish from GenerationTracking counter
MarkStepBox.mark_step_counter -= 1
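# Illustrative usage (sketch): the corresponding public entry point is
# torch.compiler.cudagraph_mark_step_begin(), called between iterations when the
# outputs of one step are not reused in the next, e.g.:
#
#   for data in loader:
#       torch.compiler.cudagraph_mark_step_begin()
#       out = compiled_model(data)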
def reset_cudagraph_trees() -> None:
"Clear all cudagraph trees"
# see shutdown below for why this is necessary
container_dict = get_obj(local, "tree_manager_containers")
locks_dict = get_obj(local, "tree_manager_locks")
for device, lock in locks_dict.items():
with lock:
container = container_dict.get(device)
if not container or not container.tree_manager:
continue
container.tree_manager.shutdown()
_set_cached_tensors_enabled(False)
container_dict.clear()
MarkStepBox.mark_step_counter = 0
def get_obj(local: Any, attr_name: str) -> Any:
if hasattr(local, attr_name):
return getattr(local, attr_name)
else:
assert torch._C._is_key_in_tls(attr_name)
return torch._C._get_obj_in_tls(attr_name)
def get_container(device_index: int) -> TreeManagerContainer:
container_dict = get_obj(local, "tree_manager_containers")
lock = get_obj(local, "tree_manager_locks")[device_index]
with lock:
if device_index not in container_dict:
container_dict[device_index] = TreeManagerContainer(device_index)
return container_dict[device_index]
def get_manager(
device_index: int, create_if_none_exists: bool = True
) -> Optional[CUDAGraphTreeManager]:
if create_if_none_exists:
return get_container(device_index).get_tree_manager()
return get_container(device_index).tree_manager
def cudagraphify_impl(
model: ModelType,
inputs: List[InputType],
static_input_idxs: Sequence[int],
*args: Any,
**kwargs: Any,
) -> ModelType:
fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {}
# Detect int inputs: we need to index on these
int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None
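# Note: operator.itemgetter(i) returns a single element while
# operator.itemgetter(i, j, ...) returns a tuple, so the int_key computed by
# get_ints(inputs) below (and used as the fn_cache key) is either a single int,
# a tuple of ints, or None when there are no int inputs.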
has_warn = False
del inputs
def deferred_cudagraphify(inputs: List[InputType]) -> OutputType:
nonlocal has_warn
int_key = get_ints(inputs)
fn = fn_cache.get(int_key)
if fn is not None:
return fn(inputs)
if int_key is None:
log.info("recording cudagraph tree for graph without symints")
else:
log.info("recording cudagraph tree for symint key %s", int_key)
if not has_warn:
has_warn = maybe_warning_due_to_dynamic_shape(fn_cache, int_key)
# first get indices we need to check to align, then update our static inputs,
# and finally copy
check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
copy_misaligned_inputs(inputs, check_input_idxs)
fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs)
fn_cache[int_key] = fn
return out
return deferred_cudagraphify
def cudagraphify(
model: ModelType,
inputs: List[InputType],
static_input_idxs: Sequence[int] = (),
*,
device_index: int,
is_backward: bool,
is_inference: bool,
stack_traces: Optional[StackTraces] = None,
constants: Tuple[torch.Tensor, ...] = (),
placeholders: Tuple[PlaceholderInfo, ...] = (),
mutated_input_idxs: Tuple[int, ...] = (),
) -> Tuple[ModelType, OutputType]:
manager = get_container(device_index).get_tree_manager()
assert not (is_backward and is_inference)
mode = (
CompilationMode.BACKWARD
if is_backward
else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD)
)
return manager.add_function(
model,
inputs,
static_input_idxs,
stack_traces,
mode,
constants,
placeholders,
mutated_input_idxs,
)
class StorageWeakRefWrapper:
"""
Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
"""
__slots__ = ["ref", "_data_ptr", "extra_ref_check"]
storage_ref: Optional[StorageWeakRef]
def __init__(
self,
inp: Union[Tensor, UntypedStorage],
extra_ref_check: Optional[Callable[[], bool]] = None,
) -> None:
"""
extra_ref_check is an additional check we need to run to check if the
weak ref has expired. When checking the storage use count, we assume extra_ref_check
will hold an additional reference to the storage.
"""
if isinstance(inp, Tensor):
stor = inp.untyped_storage()
else:
assert isinstance(inp, UntypedStorage)
stor = inp
self.ref = StorageWeakRef(stor)
self._data_ptr = stor.data_ptr()
self.extra_ref_check = extra_ref_check
@classmethod
def from_weakref_and_data_ptr(
cls: Type[S],
cdata: Any,
data_ptr: int,
extra_ref_check: Optional[Callable[[], bool]] = None,
) -> StorageWeakRefWrapper:
instance = cls.__new__(cls)
instance._data_ptr = data_ptr
instance.ref = StorageWeakRef.from_weakref(cdata)
instance.extra_ref_check = extra_ref_check
return instance
def __call__(self) -> Optional[StorageWeakRefPointer]:
if self.expired():
return None
return self.ref.cdata
def swap_weakref(self, cdata: Any) -> None:
self.ref.__del__()
self.ref.cdata = cdata
def data_ptr(self) -> int:
"NB: returns the data ptr even if the storage has expired"
return self._data_ptr
def remove_extra_reference(self) -> None:
self.extra_ref_check = None
def expired(self) -> bool:
if self.extra_ref_check is not None and not self.extra_ref_check():
return False
# if extra_ref_check is not None we expect an additional reference
stor_count = torch._C._storage_Use_Count(self.ref.cdata)
return (stor_count - (self.extra_ref_check is not None)) == 0
def __repr__(self) -> str:
if self.ref is None or self.ref.expired():
return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"
else:
return f"StorageWeakRefWrapper to {self.data_ptr()}; alive"
def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool:
return maybe_deref(weak_ref) is not None
def maybe_deref(
weak_ref: Optional[StorageWeakRefWrapper],
) -> Optional[Tuple[StorageWeakRefPointer, int]]:
if weak_ref is None:
return None
r = weak_ref()
if r is None:
return None
# NB: r.data_ptr() does not necessarily equal weak_ref.data_ptr()
return r, weak_ref.data_ptr()
@contextlib.contextmanager
def _use_cuda_memory_pool_manager(
device: int, mem_pool: Tuple[int, int], stream: torch.cuda.Stream
) -> Generator[None, None, None]:
"""
Context manager to use cuda graph pool for new allocations. If you use this manager
all cudagraph tensors in use should be reflected in the allocator or they will be overwritten.
existing_graph should already have been used in a capture, and the mem_pool must already exist,
because this manager will not preserve a reference to the pool which keeps it alive.
"""
torch.cuda.synchronize()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream), torch.device(device):
torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
try:
yield
finally:
torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
torch._C._cuda_releasePool(device, mem_pool)
torch.cuda.current_stream().wait_stream(stream)
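# Illustrative usage (sketch, not executed; `pool` and `side_stream` are local names
# for this example, and in real use the pool has already been used for a capture,
# per the docstring): allocations made while the manager is active are routed into
# the given cudagraph private pool.
#
#   pool = torch.cuda.graph_pool_handle()
#   side_stream = torch.cuda.Stream()
#   with _use_cuda_memory_pool_manager(device=0, mem_pool=pool, stream=side_stream):
#       scratch = torch.empty(1024, device="cuda")  # lands in the private pool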
def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
if not isinstance(t, torch.Tensor):
assert t is None
return None
return StorageWeakRefWrapper(t)
# A path index of (depth, offset): indexes the node that is `depth` nodes from the root
# and that node's output at index `offset`
PathOutputIndex = Tuple[int, int]
# For each node in the path, for each output, is the output alive
PathLiveness = List[List[bool]]
StackTraces = List[Optional[str]]
class CUDAWarmupNode:
"""
Simplified wrapper around a CUDA model that wraps outputs in storage refs and exposes
APIs to get the live storages in the current chain of warmup.
A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have
CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable
memory addresses.
CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes.
- Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the
first instance of warmup, these are not finalized yet.
- All inputs to the RecordedFunction must be copied over to the cuda graph memory pool; this is unnecessary in warmup.
- CUDAWarmupNode is only used once and so does not need as much bookkeeping optimization. It is much simpler.
NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and
`self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility.
"""
def __init__(
self,
wrapped_function: WrappedFunction,
parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]],
cuda_graphs_pool: Tuple[int, int],
existing_cuda_graph: Optional[torch.cuda.CUDAGraph],
device_index: int,
stack_traces: Optional[StackTraces],
stream: torch.cuda.Stream,
already_warm: bool,
id: GraphID,
) -> None:
self.wrapped_function = wrapped_function
self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent
self.cuda_graphs_pool = cuda_graphs_pool
self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
self.existing_cuda_graph = existing_cuda_graph
self.has_run = False
self.device_index = device_index
self.stack_traces = stack_traces
self.stream = stream
self.already_warm = already_warm
self.id = id
def run(self, new_inputs: Any) -> OutputType:
assert not self.has_run, "Wrapped function should never be run twice"
# See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created
# storages in path_live_weakrefs.
existing_path_data_ptrs = {
t.data_ptr() for t in self.path_live_weakrefs() if t()
}
def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
non_cudagraph_inps = []
for t in itertools.chain(new_inputs, self.wrapped_function.constants):
if (
isinstance(t, torch.Tensor)
and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
):
non_cudagraph_inps.append(weakref.ref(t.untyped_storage()))
return non_cudagraph_inps
non_cudagraph_inps_storages = get_non_cudagraph_inps()
if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
refs = list(self.path_live_weakrefs())
check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
with torch.cuda.device(
self.device_index
), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
), get_history_recording():
out = self.wrapped_function.model(new_inputs)
# We need to know which outputs are allocated within the cudagraph pool
# so that we can deallocate them at the beginning of the next cudagraph step,
# and set their access to error.
# We use a weakref to the inputs storage, in case a block which was previously
# allocated to the general caching allocator pool gets reallocated to a private pool.
non_cudagraph_inps_storage_ptrs = set()
for storage in non_cudagraph_inps_storages:
s = storage()
if s is not None:
non_cudagraph_inps_storage_ptrs.add(s._cdata)
assert len(new_inputs) == 0
# sdpa returns cpu tensors when not recording cuda graph
def add_ref(o: Any) -> bool:
return (
isinstance(o, torch.Tensor)
and o.is_cuda
and o.untyped_storage()._cdata not in non_cudagraph_inps_storage_ptrs
and o.untyped_storage().data_ptr() != 0
)
self.outputs_weakrefs.extend(
[map_to_ref(o) if add_ref(o) else None for o in out]
)
self.tensor_weakrefs.extend(
[TensorWeakRef(o) if add_ref(o) else None for o in out]
)
if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
out_refs = list(self.path_live_weakrefs())
check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)
return out
@property
def _path_from_root(
self,
) -> Generator[Union[CUDAGraphNode, CUDAWarmupNode], None, None]:
nodes = []
node: Union[CUDAGraphNode, CUDAWarmupNode] = self
while node:
nodes.append(node)
node = node.parent # type: ignore[assignment]
yield from reversed(nodes)
def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
"Returns all live storages weakrefs that created by nodes in this path"
for node in self._path_from_root:
for output in node.outputs_weakrefs:
if is_live(output):
yield output # type: ignore[misc]
def all_outputs_are_dead(self) -> bool:
return not list(self.path_live_weakrefs())
def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool:
for storage_weak_ref in self.path_live_weakrefs():
if t.untyped_storage().data_ptr() == storage_weak_ref.data_ptr():
return True
return False
# Aliases for List that say what the indices denote
InputList = List # input indexes
OutputList = List # output indexes
LevelList = List # levels (distance from root of tree)
class OutputAliasInfo:
pass
class _UnaliasedStorage(OutputAliasInfo):
"Singleton to mark that the graph output constructs a new alias or is None"
UnaliasedStorage = _UnaliasedStorage()
class AliasesPriorGraphOutput(OutputAliasInfo):
"Marks that the graph output aliases an output of a prior graph"
__slots__ = ["index"]
index: PathOutputIndex
def __init__(self, index: PathOutputIndex) -> None:
assert isinstance(index, tuple)
self.index = index
class AliasesNewOutput(OutputAliasInfo):
"Marks that the graph output aliases an index in the new, returned outputs"
__slots__ = ["index"]
index: int
def __init__(self, index: int) -> None:
assert isinstance(index, int)
self.index = index
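# Illustrative mapping (sketch): for a recorded function whose outputs are roughly
# `a = torch.empty(...)`, `b = prior_graph_output.view(-1)`, and `c = a[0]`, the
# per-output alias info would be approximately:
#   a -> UnaliasedStorage                            (fresh storage, or the output is None)
#   b -> AliasesPriorGraphOutput((depth, output_index))
#   c -> AliasesNewOutput(0)                         (aliases output `a` in this same output list)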
class CUDAGraphNode:
"""
A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool
and are structured into a tree, where there is a single recording that can precede it (parent) and multiple
subsequent recordings that may follow (children). A node will have no parent if it is the first recording
in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which
would force a dependency.
On first recording, all of the live tensors in the current CUDA Graph Node path will be
reflected in the corresponding private pool. On subsequent executions, the caching allocator
is unaffected when the graph is replayed.
In order to support recording a subsequent cuda graph recording after execution of this graph,
we checkpoint the state of the memory pool so that it may later be resumed.
WrappedFunction should have already been warmed up prior to invocation.
See [setCheckpointPoolState] for further explanation, as well as
https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png
"""
def __init__(
self,
wrapped_function: WrappedFunction,
id: GraphID,
parent: Optional[CUDAGraphNode],
inputs: List[InputType],
cuda_graphs_pool: Tuple[int, int],
device_index: int,
stack_traces: Optional[StackTraces],
stream: torch.cuda.Stream,
) -> None:
assert isinstance(inputs, (list, tuple))
self.wrapped_function = wrapped_function
self.id = id
self.device = device_index
self.stack_traces = stack_traces
self.stream = stream
# Enable re-recording a cudagraph when a static tensor address changes.
# If not enabled, we should error when it changes.
self.rerecord_if_static_inputs_change = (
torch._dynamo.config.inline_inbuilt_nn_modules
or torch._inductor.config.triton.cudagraph_support_input_mutation
)
# if this is a root, parent will be None. use weakref to prevent a reference cycle
self._parent = weakref.ref(parent) if parent is not None else None
# reference to the shared memory pool for the entire cuda graphs tree
self.cuda_graphs_pool = cuda_graphs_pool
# A single wrapped function may be recorded multiple times if memory patterns or
# invariants change from one execution to the next
self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
# StorageWeakRef maintains whether the Storage C++ object remains allocated,
# not whether the corresponding memory has been deallocated. In order
# to use them to track memory deallocations we must maintain a single StorageWeakRef
# for all Storages that reference that memory (even if we are constructing Storages
# that do not have a deallocator function). We maintain one single storage_cache
# as we execute any tree path. When we retrieve a storage from the cache we
# check that it is still alive, and we hash based on observed recording data ptr
# and storage cdata.
# we preserve a single reference to executed outputs that is then referenced
# in children to avoid children having to chase parent pointers in the hot path
# DO NOT reassign output_weakrefs, only call `clear()`
# Path is a series of nodes from root to the current node
self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = []
self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [
node.outputs_weakrefs for node in self._path_from_root
]
self.path_stacktraces: LevelList[Optional[StackTraces]] = [
node.stack_traces for node in self._path_from_root
]
self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []
# tensors which are outputs of previous graphs in the tree
self.cudagraph_managed_idxs: List[int] = [
idx
for idx, t in enumerate(inputs)
if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
]
self.static_input_idxs: List[int] = list(
set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
)
self.non_static_input_idx: LevelList[int] = [
i for i in range(len(inputs)) if i not in self.static_input_idxs
]
counters["inductor"]["cudagraph_recorded_non_static_inputs"] += len(
self.non_static_input_idx
)
self.non_managed_static_input_idxs: LevelList[int] = [
i
for i in wrapped_function.static_input_idxs
if i not in self.cudagraph_managed_idxs
]
def maybe_get_static_data_ptr(
idx: int,
inputs: List[InputType],
static_input_idxs: List[int],
) -> Optional[int]:
inp = inputs[idx]
if isinstance(inp, torch.Tensor) and idx in static_input_idxs:
return inp.data_ptr()
return None
self.static_input_data_ptrs: InputList[Optional[int]] = [
maybe_get_static_data_ptr(i, inputs, self.static_input_idxs)
for i in range(len(inputs))
]
# When we checkpoint, and free generations, we will be manually freeing the outputs
# of CUDAGraphNodes. We should not be freeing parameters, nor do we need to account for
# their liveness (they are static), so we need to compute which outputs are aliases of
# parameters. Some static inputs are saved tensors from the forward that die in the backward.
# Their locations are static but lifetimes are not. We only include the persistent static
# data ptrs below because the non persistent data ptrs may be outputs of this record and
# fresh allocations.
# precompute expanded dims to avoid computing in the hot path
self.expanded_dims: List[List[int]] = [
get_expanded_dims(x)
if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
else []
for idx, x in enumerate(inputs)
]
# For each node in path, which outputs were observed to be live
# before invoking graph recording, and after graph recording
self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = []
self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = []
# List of Tuples of (depth, output_index) that index into node at depth
# number of nodes from root and output_index of outputs. Will index into
# path_weakrefs.
self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
self.expected_dead_indices_after_graph: List[PathOutputIndex] = []
# all live indices after graph recording
self.live_indices_after_graph: List[PathOutputIndex] = []
if self.parent is not None:
previous_liveness = self.parent.recorded_liveness_after_graph
curr_liveness = self._get_liveness(self.path_weakrefs)
different_indices = self._get_different_indices(
previous_liveness, curr_liveness
)
self.recorded_liveness_before_graph = curr_liveness
self.expected_dead_indices_before_graph = different_indices
recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
# recording inputs will copy over memory, so we can free non recording inputs
inputs.clear()
del inputs
# graph used for recording model invocation
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
# we allocate non-static inputs within the same memory pool as the CUDAGraph
# which we will record the model with. For memory efficiency, it is important
# to reclaim the input memory when the inputs are no longer live. To accomplish this,
# we reconstruct tensors at the correct data pointers of our inputs which are
# non owning and do not prevent deallocation. On subsequent executions, input values
# will be copied over to these tensors.
self.reconstructed_inputs: List[InputType] = [
self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
if isinstance(x, torch.Tensor)
else x
for x in recording_inputs
]
# DO THE RECORDING!!!
# We record the CUDA graph in the constructor of CUDAGraphNode, which
# gives you what the CPU side compute of the function would do. We
# don't throw the recording outputs away: their memory is
# correctly accounted for in the CUDAGraphs caching allocator. This
# means on the very FIRST run of the CUDA graph node, we can directly
# do more recording, because we have a valid caching allocator state.
# NB: This relies on run() being called immediately after the
# constructor, otherwise this optimization would not be valid.
# initialized below in _record
self.checkpointed_caching_state: Optional[AllocatorState] = None
# Output Storage Alias information, can be:
# - A new, unaliased storage, or the output is None
# - An alias of an output of a prior graph
# - An alias of an output already created in the reconstructed outputs
# This is None if the output in question is an int
self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = []
# Is the output Storage unaliased in subsequent outputs, across all subsequent paths?
# If it is, we cache the output tensor and adjust storage liveness tracking to also
# check that the output tensor does not have an additional python reference.
# If a descendant node discovers it has an alias of a prior output, then the output
# will no longer be cached in the ancestor.
# The large majority of tensors are unaliased, and preserving aliased output tensors would add
# significant additional complexity with marginal gains
# The cached tensor outputs are added on the first execution, and cleared whenever we need
# to do subsequent recording
self.unaliased_in_all_paths: OutputList[bool] = []
self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []
# if an output aliases a static, persistent input then the corresponding Tensor will
# be set here. These are different than cached tensors, because they are tensors that
# are aliases of parameters that are always live.
self.static_output_tensors: OutputList[Optional[Tensor]] = []
# Cleared after recording
self.recording_outputs: Optional[OutputType] = self._record(
wrapped_function.model, recording_inputs
)
self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []
# As with inputs, we do not want to keep the outputs permanently alive because that would prevent
# their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
# needed to reconstruct instead.
assert self.recording_outputs is not None
for out in self.recording_outputs:
if isinstance(out, torch.Tensor):
self.outputs_metadata.append(
self._tensor_metadata(out, ignore_storage_offset=False)
)
else:
assert isinstance(out, (int, type(None))), type(out)
self.outputs_metadata.append(out)
self.graph.replay()
def _copy_inputs_and_remove_from_src(
self, dsts: List[InputType], srcs: List[InputType]
) -> None:
dst_tensors = []
src_tensors = []
for idx in self.non_static_input_idx:
if not isinstance(srcs[idx], torch.Tensor):
continue
expanded_dims = self.expanded_dims[idx]
dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims)) # type: ignore[arg-type]
src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims)) # type: ignore[arg-type]