@loreloc commented Dec 16, 2024

  • Added an optimization that fuses an outer product followed by a sum over the parameters (e.g., embedding weights) into a single einsum. This is crucial for avoiding out-of-memory (OOM) errors and for reproducing the benchmark results in the SOS circuits paper.
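To illustrate why this fusion matters, here is a minimal NumPy sketch of the general idea, not cirkit's actual implementation. It assumes an embedding-style input layer whose weights have shape (K, V), with V the number of values of the variable. Multiplying two such layers (as happens when squaring a circuit) gives, for each value v, the outer product of two weight columns; summing over v then contracts them. The function names are hypothetical.

```python
import numpy as np


def naive_sum_of_outer(w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """Build the full (K1, K2, V) outer-product tensor, then sum over V."""
    outer = w1[:, None, :] * w2[None, :, :]  # shape (K1, K2, V): the memory hog
    return outer.sum(axis=-1)                # shape (K1, K2)


def fused_sum_of_outer(w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """Contract over V in one einsum, without the (K1, K2, V) intermediate."""
    return np.einsum("iv,jv->ij", w1, w2)    # shape (K1, K2)


rng = np.random.default_rng(0)
w1 = rng.standard_normal((4, 10))  # K1 = 4 units, V = 10 values
w2 = rng.standard_normal((5, 10))  # K2 = 5 units, V = 10 values
print(np.allclose(naive_sum_of_outer(w1, w2), fused_sum_of_outer(w1, w2)))
```

The fused version needs memory proportional to K1 * K2 instead of K1 * K2 * V. With hypothetical sizes K1 = K2 = 512 and V = 50,000, the naive intermediate would hold about 1.3e10 entries (roughly 52 GB in float32). `torch.einsum` accepts the same subscript string.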

Other changes:

  • Torch parameter pointers now hold the torch tensor parameter they point to as a registered submodule.
  • Removed the TorchCircuit.device() property and the ._set_device() method, for two reasons. First, modules are NOT stored on devices; tensors are. Second, when the circuit is a submodule of another module, its stored device was NOT updated when calling .to() on the parent module. This is because .to() is applied recursively to the tensors (buffers and parameters) of the module tree, and it does not update plain attributes stored on the modules. This broke the implementation of SOS circuits.
  • Evaluating partition functions no longer requires passing a dummy tensor to the circuit: constant layers now take only the batch size as input, rather than an allocated dummy tensor.
  • Fixed a few pylint errors.
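The pointer and device changes above both rest on how `nn.Module.to()` recurses. The PyTorch sketch below is illustrative only; `Params` and `Pointer` are hypothetical stand-ins, not cirkit classes. It uses a dtype conversion as a CPU-only stand-in for a device move, since both go through the same recursion. A target held as a registered submodule is converted, a target hidden in a plain list is not, and a cached attribute goes stale while a value read from the tensor stays correct.

```python
import torch
from torch import nn


class Params(nn.Module):
    """Hypothetical stand-in for a torch tensor parameter."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3))
        # Anti-pattern: a plain attribute snapshot; .to() never updates it.
        self.cached_dtype = self.weight.dtype

    @property
    def current_dtype(self) -> torch.dtype:
        # Read from the tensor itself, so it always reflects the latest .to().
        return self.weight.dtype


class Pointer(nn.Module):
    """Hypothetical parameter pointer referring to a Params module."""

    def __init__(self, target: Params, register: bool) -> None:
        super().__init__()
        if register:
            self.target = target     # nn.Module attribute: registered submodule
        else:
            self._hidden = [target]  # plain Python list: NOT registered


registered = Params()
unregistered = Params()
parent = nn.ModuleDict({
    "a": Pointer(registered, register=True),
    "b": Pointer(unregistered, register=False),
})
parent.to(torch.float64)  # recurses only into registered submodules

print(registered.weight.dtype)    # converted
print(unregistered.weight.dtype)  # unchanged
print(registered.cached_dtype)    # stale snapshot
print(registered.current_dtype)   # up to date
```

Deriving the device (or dtype) from the tensors themselves, instead of caching it on the module, avoids the stale-attribute problem described in the second bullet.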
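The constant-layer change can be sketched in a few lines of NumPy. This is a hypothetical illustration, not cirkit's API: `ConstantLayer` and its methods are assumed names. The old style reads the batch size off a dummy input, while the new style takes the batch size directly, so no tensor has to be allocated just to compute a partition function.

```python
import numpy as np


class ConstantLayer:
    """Hypothetical constant layer whose output does not depend on any input."""

    def __init__(self, value) -> None:
        self.value = np.asarray(value, dtype=float)  # shape (num_units,)

    def forward_with_dummy(self, x: np.ndarray) -> np.ndarray:
        # Old style: a dummy tensor is passed only to read its batch dimension.
        return self.forward(x.shape[0])

    def forward(self, batch_size: int) -> np.ndarray:
        # New style: broadcast the constant to (batch_size, num_units) as a view.
        return np.broadcast_to(self.value, (batch_size, *self.value.shape))


layer = ConstantLayer([1.0, 2.0, 3.0])
print(layer.forward(4).shape)  # (4, 3)
```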

@loreloc added the enhancement (New feature or request) label Dec 16, 2024
@loreloc added this to the cirkit 0.2.0 (murmur) milestone Dec 16, 2024
@loreloc self-assigned this Dec 16, 2024
@loreloc merged commit 696a5c7 into main Dec 16, 2024
1 of 2 checks passed
@loreloc deleted the speed-up-sos branch December 16, 2024 12:22