
Major performance degradation when multiple metrics/losses #20388

Open
EtayLivne opened this issue Nov 3, 2024 · 0 comments
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.4.x

Comments

@EtayLivne
Contributor

Bug description

The issue
The same exact model will train considerably slower if the last layer is interpreted as multiple outputs rather than a single output. A benchmark I've built takes a model with about 72 million weights across several linear layers, where the last layer has 4 weight. I compare the case where these weights are considered a single output, with one metric and one loss, vs cases where I have two losses + metrics, and four losses + metrics. I've found what appears to be an increasing performance degradation with additional metrics and losses, accumulating to as much as 40-50 % in some runs. The same does not appear to replicate in equivalent training loops I've set up with vanilla Pytorch.

Given that each step in the training loop involves hundreds of millions of calculations (forward + backward pass), I find a slowdown of tens of percent just to support another loss + metric to be quite a hefty price to pay, and at least part of the issue appears to be something that cropped up in Lightning and isn't present in vanilla torch.

The benchmark
The lightning benchmark is set up to run with barebones=True, on a machine with 2 CPUs and one GPU (container on an A100). The data is a set of in-memory random vectors with random labels. I've written the models and metrics as "flat" as I can, not using any nested structures to hold them (no lists/dicts, no nn.Sequential, etc.). The training loop is of 10_000 steps, in a single training epoch and not including a validation epoch. I use seed_everything to make the experiment deterministic across runs, in both the vanilla torch and lightning case. In the vanilla case I tried to make sure all computations happen on-device.

(I've also tried changing the order in which I call each training loop, thinking there might be some global system state that changes from call to call, but the results replicate regardless of order.)

Further info

I ran this benchmark after encountering the performance degradation in my work. I used the Lightning profiler to investigate and found that the entire difference in performance between having one head and two was that there were twice as many function calls associated with metrics and losses. I've disabled profiling in the benchmark in favor of the barebones approach to reduce noise as much as possible, so the results are fairly black-box, but this can definitely be a direction to investigate.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

https://github.com/EtayLivne/pl-multihead-benchmark

Follow the (short and simple) instructions in the README of the above repo.

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA L40S
    • available: True
    • version: 12.1
  • Lightning:
    • lightning: 2.4.0
    • lightning-utilities: 0.11.8
    • pytorch-lightning: 2.4.0
    • sagemaker-pytorch-training: 2.8.1
    • torch: 2.5.1
    • torchaudio: 2.5.1
    • torchmetrics: 1.5.1
    • torchvision: 0.15.2a0
  • Packages:
    • aiobotocore: 2.15.1
    • aiohappyeyeballs: 2.4.3
    • aiohttp: 3.10.10
    • aioitertools: 0.12.0
    • aiosignal: 1.3.1
    • angie-data-reader: 0.7.2
    • angie-shuffle-service: 0.15.1
    • antlr4-python3-runtime: 4.9.3
    • anyio: 4.6.2.post1
    • async-timeout: 4.0.3
    • attrs: 23.2.0
    • autocommand: 2.2.2
    • backports.tarfile: 1.2.0
    • bcrypt: 4.2.0
    • beartype: 0.19.0
    • blinker: 1.8.2
    • boto3: 1.35.23
    • botocore: 1.35.23
    • brotli: 1.1.0
    • cachetools: 5.5.0
    • certifi: 2024.8.30
    • cffi: 1.17.1
    • charset-normalizer: 3.4.0
    • click: 8.1.7
    • cloud-logging: 0.5.44
    • cloud-storage-utils: 1.0.53
    • cloudpickle: 2.2.1
    • colorama: 0.4.6
    • comet-ml: 3.47.1
    • configobj: 5.0.9
    • contextlib2: 21.6.0
    • copier: 8.1.0
    • cryptography: 43.0.3
    • decorator: 5.1.1
    • dill: 0.3.9
    • dl-optimizer: 24.11022.0
    • dl-optimizer-common: 24.11022.0
    • docker: 7.1.0
    • dulwich: 0.22.3
    • dunamai: 1.22.0
    • durationpy: 0.9
    • everett: 3.1.0
    • exceptiongroup: 1.2.2
    • faker: 30.8.1
    • filelock: 3.16.1
    • fpdf: 1.7.2
    • frozenlist: 1.5.0
    • fsspec: 2024.10.0
    • funcy: 2.0
    • future: 1.0.0
    • getdaft: 0.3.0.dev0
    • gevent: 24.10.3
    • gmpy2: 2.1.5
    • google-api-core: 2.22.0
    • google-api-python-client: 2.149.0
    • google-auth: 2.35.0
    • google-auth-httplib2: 0.2.0
    • google-cloud-container: 2.53.0
    • google-cloud-core: 2.4.1
    • google-cloud-storage: 2.18.2
    • google-crc32c: 1.1.2
    • google-pasta: 0.2.0
    • google-resumable-media: 2.7.2
    • googleapis-common-protos: 1.65.0
    • greenlet: 3.1.1
    • grpcio: 1.62.2
    • grpcio-status: 1.62.2
    • h11: 0.14.0
    • h2: 4.1.0
    • hera: 5.13.1
    • hpack: 4.0.0
    • httpcore: 1.0.6
    • httplib2: 0.22.0
    • httpx: 0.27.2
    • hydra-core: 1.3.2
    • hyperframe: 6.0.1
    • hyperopt: 0.2.7
    • idna: 3.10
    • importlib-metadata: 6.10.0
    • importlib-resources: 6.4.5
    • inflect: 7.3.1
    • iniconfig: 2.0.0
    • inotify-simple: 1.2.1
    • jaraco.collections: 5.1.0
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.1
    • jaraco.text: 3.12.1
    • jinja2: 3.1.4
    • jinja2-ansible-filters: 1.3.2
    • jmespath: 1.0.1
    • jsonschema: 4.23.0
    • jsonschema-specifications: 2024.10.1
    • kubernetes: 31.0.0
    • lightning: 2.4.0
    • lightning-utilities: 0.11.8
    • lxml: 4.9.3
    • lz4: 4.3.3
    • markdown-it-py: 3.0.0
    • markupsafe: 3.0.2
    • mdurl: 0.1.2
    • me-auth-client: 0.29.0
    • menta3: 0.25.3
    • more-itertools: 10.3.0
    • mpmath: 1.3.0
    • msal: 1.31.0
    • multidict: 6.1.0
    • multiprocess: 0.70.17
    • networkx: 3.4.2
    • nssd-errors: 1.6.3
    • numpy: 1.26.4
    • oauthlib: 3.2.2
    • omegaconf: 2.3.0
    • opencv-python: 4.10.0.84
    • packaging: 24.1
    • pandas: 2.2.3
    • paramiko: 3.5.0
    • pathos: 0.3.3
    • pathspec: 0.12.1
    • pillow: 9.4.0
    • pip: 24.3.1
    • pkgutil-resolve-name: 1.3.10
    • platformdirs: 4.3.6
    • pluggy: 1.5.0
    • plumbum: 1.9.0
    • pox: 0.3.5
    • ppft: 1.7.6.9
    • prompt-toolkit: 3.0.36
    • propcache: 0.2.0
    • proto-plus: 1.25.0
    • protobuf: 3.20.3
    • psutil: 6.1.0
    • py4j: 0.10.9.7
    • pyarrow: 16.1.0
    • pyasn1: 0.6.1
    • pyasn1-modules: 0.4.1
    • pycparser: 2.22
    • pydantic: 1.10.17
    • pydot: 3.0.1
    • pygments: 2.18.0
    • pyjwt: 2.9.0
    • pynacl: 1.5.0
    • pynamodb: 5.5.1
    • pyopenssl: 24.2.1
    • pyparsing: 3.2.0
    • pysocks: 1.7.1
    • pytest: 7.4.4
    • python-box: 6.1.0
    • python-dateutil: 2.9.0
    • python-json-logger: 2.0.7
    • pytorch-lightning: 2.4.0
    • pytz: 2024.1
    • pyu2f: 0.1.5
    • pyyaml: 6.0.2
    • pyyaml-include: 1.3
    • questionary: 2.0.1
    • referencing: 0.35.1
    • requests: 2.32.3
    • requests-oauthlib: 2.0.0
    • requests-toolbelt: 1.0.0
    • retrying: 1.3.4
    • rich: 13.9.3
    • rpds-py: 0.20.0
    • rsa: 4.9
    • ruamel.yaml: 0.18.6
    • ruamel.yaml.clib: 0.2.8
    • s3fs: 2024.10.0
    • s3path: 0.5.8
    • s3transfer: 0.10.3
    • sagemaker: 2.210.0
    • sagemaker-pytorch-training: 2.8.1
    • sagemaker-training: 4.8.1
    • schema: 0.7.7
    • scipy: 1.14.1
    • semantic-version: 2.10.0
    • sentry-sdk: 2.16.0
    • setuptools: 75.1.0
    • six: 1.16.0
    • smart-open: 7.0.5
    • smdebug-rulesconfig: 1.0.1
    • sniffio: 1.3.1
    • sympy: 1.13.3
    • tabulate: 0.9.0
    • tblib: 2.0.0
    • tenacity: 9.0.0
    • termcolor: 2.5.0
    • toml: 0.10.2
    • tomli: 2.0.2
    • torch: 2.5.1
    • torchaudio: 2.5.1
    • torchmetrics: 1.5.1
    • torchvision: 0.15.2a0
    • tqdm: 4.66.6
    • triton: 3.1.0
    • typeguard: 4.3.0
    • typing-extensions: 4.12.2
    • tzdata: 2024.2
    • uritemplate: 4.1.1
    • urllib3: 1.26.19
    • wcwidth: 0.2.13
    • websocket-client: 1.8.0
    • werkzeug: 3.0.6
    • wheel: 0.44.0
    • wrapt: 1.16.0
    • wurlitzer: 3.1.1
    • yarl: 1.16.0
    • zipp: 3.20.2
    • zope.event: 5.0
    • zope.interface: 7.1.1
  • System:

More info

No response

@EtayLivne EtayLivne added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Nov 3, 2024