[Feature] New methods pruning and adapters #9

Merged
Xmaster6y merged 5 commits into main from methods on Sep 24, 2025

Xmaster6y (Owner) commented Sep 24, 2025

What does this PR do?

Key insights about the PR.

Linked Issues

  • Closes #?
  • #?

Summary by Sourcery

Implement new Pruning and Adapters hook factories for weight manipulation, improve hook callback handling to accept None return values, extend the caching context with cache clearing, and remove the obsolete SAE module.

New Features:

  • Add Pruning hook factory to support global and structured pruning with importance callbacks
  • Add Adapters hook factory to inject adapter modules into forward hooks with optional cross-layer mapping
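
As a rough illustration of the global pruning flow with an importance callback, here is a framework-free sketch in which plain Python lists stand in for tensors. The names `prune_globally`, `restore_weights`, and `magnitude` are hypothetical and are not tdhook's API. The ideas borrowed from this PR are that a callback scores each parameter, that a callback returning None leaves that parameter untouched, and that the original weights come back when the context exits.

```python
from contextlib import contextmanager
from copy import deepcopy


def prune_globally(params, importance_callback, amount):
    """Zero the `amount` fraction of entries with the lowest importance (hypothetical helper)."""
    # Score every entry; a callback returning None excludes that parameter.
    scored = []
    for name, values in params.items():
        scores = importance_callback(values, name)
        if scores is None:
            continue
        scored.extend((abs(s), name, i) for i, s in enumerate(scores))
    # Rank across all parameters at once, then zero the lowest-scoring entries.
    scored.sort(key=lambda item: item[0])
    for _, name, i in scored[: round(amount * len(scored))]:
        params[name][i] = 0.0


@contextmanager
def restore_weights(params):
    """Back up the weights on entry and put them back on exit (hypothetical helper)."""
    backup = deepcopy(params)
    try:
        yield params
    finally:
        params.clear()
        params.update(backup)


def magnitude(values, name):
    # Hypothetical importance callback: skip biases, score weights by magnitude.
    return None if name == "bias" else values


params = {"weight": [0.1, -2.0, 0.05, 3.0], "bias": [0.01, 0.02]}
with restore_weights(params) as live:
    prune_globally(live, magnitude, amount=0.5)
    pruned = deepcopy(live)  # snapshot of the pruned state inside the context
```

Inside the context the two smallest weights are zeroed and the bias is untouched. After the block, `params` holds the original values again.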

Enhancements:

  • Allow hook callbacks to return None without triggering a type mismatch error
  • Extend HookingContextWithCache with a clear_cache option to reset cache on context enter
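
The `clear_cache` behavior can be pictured with a small stand-in class. This is only a sketch: `CachingContext` is a hypothetical name, a plain dict replaces whatever cache type tdhook uses, and passing the cache through the constructor is an assumption. Only the `clear_cache` flag, the `cache` property, `clear()`, and the clearing on `__enter__` come from the description above.

```python
class CachingContext:
    """Hypothetical stand-in for a context that owns a cache."""

    def __init__(self, cache=None, clear_cache=False):
        # Assumption: the cache can be supplied by the caller and shared.
        self._cache = {} if cache is None else cache
        self._clear_cache = clear_cache

    @property
    def cache(self):
        return self._cache

    def clear(self):
        self._cache.clear()

    def __enter__(self):
        # Reset the cache on entry only when asked to.
        if self._clear_cache:
            self.clear()
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


shared = {"stale": 1}
with CachingContext(cache=shared) as ctx:
    kept = dict(ctx.cache)      # default: earlier entries survive
with CachingContext(cache=shared, clear_cache=True) as ctx:
    cleared = dict(ctx.cache)   # clear_cache=True: cache is empty on entry
```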

Tests:

  • Add tests for pruning behavior, including global and module-specific pruning, skip logic, and weight restoration
  • Add tests for adapters to verify output modification and restoration, including cross-layer scenarios
  • Add test ensuring hooks with None-returning callbacks do not alter values
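
The None-return rule can be sketched as follows. `apply_hook` is a hypothetical helper, not the code in `src/tdhook/hooks.py`, and the exact type check used there is not shown in this PR page, so the comparison below is an assumption.

```python
def apply_hook(value, callback):
    """Run a callback on a value; a None result means "leave the value unchanged"."""
    result = callback(value)
    if result is None:
        # Observation-only callback: skip the type check and keep the value.
        return value
    if type(result) is not type(value):
        raise TypeError(
            f"callback returned {type(result).__name__}, expected {type(value).__name__}"
        )
    return result


seen = []
observed = apply_hook([1, 2], lambda v: seen.append(sum(v)))  # append() returns None
doubled = apply_hook([1, 2], lambda v: [x * 2 for x in v])
```

This lets a callback record a value without having to return it.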

Chores:

  • Remove deprecated sae weight module and clean up outdated TODOs

sourcery-ai (Bot, Contributor) commented Sep 24, 2025

Reviewer's Guide

This pull request introduces two major HookingContextFactory-based features—pruning and adapters—alongside enhancements to the caching context, a relaxation of the hook return-type check, and accompanying tests and minor cleanups.

Entity relationship diagram for Adapters module

erDiagram
    HOOKED_MODULE {
        adapters ModuleDict
    }
    ADAPTERS {
        adapters Dict
        cache_callback Callable
        relative bool
        directions List
    }
    HOOKED_MODULE ||--o{ ADAPTERS : uses

Entity relationship diagram for Pruning module

erDiagram
    PRUNING_CONTEXT {
        _old_weights TensorDict
    }
    PRUNING {
        importance_callback Callable
        amount_to_prune float
        modules_to_prune Dict
        skip_modules Callable
        relative bool
        relative_path str
    }
    PRUNING_CONTEXT ||--o{ PRUNING : used_by

Class diagram for new Pruning and Adapters features

classDiagram
    class HookingContextFactory {
    }
    class PruningContext {
        - _old_weights
        + __enter__()
        + __exit__()
    }
    class Pruning {
        - _importance_callback
        - _amount_to_prune
        - _modules_to_prune
        - _skip_modules
        - _relative
        - _relative_path
        + _prepare_module()
        + default_skip()
    }
    class HookedModule {
    }
    class HookedModuleWithAdapters {
        - adapters
    }
    class Adapters {
        - _adapters
        - _cache_callback
        - _relative
        - _directions
        + _hook_module()
    }
    class HookingContextWithCache {
        - _cache
        - _clear_cache
        + cache
        + clear()
        + __enter__()
    }
    HookingContextFactory <|-- Pruning
    HookingContext <|-- PruningContext
    HookingContextFactory <|-- Adapters
    HookedModule <|-- HookedModuleWithAdapters
    HookingContextWithCache <|-- Adapters

File-Level Changes

Change: Introduce Pruning context and logic for global and structured weight pruning
Details:
  • Add PruningContext to back up and restore module weights
  • Implement Pruning class with global_unstructured and ln_structured flows
  • Support modules_to_prune dict, skip_modules callback, and relative path resolution
  • Raise an error when the pruning amount is unspecified
  • Remove pruning masks after application
Files:
  • src/tdhook/weights/pruning.py
  • tests/weights/test_pruning.py

Change: Implement Adapters context factory for injecting adapter modules via hooks
Details:
  • Define HookedModuleWithAdapters to store adapter modules
  • Use HookingContextWithCache to buffer intermediate values
  • Create callback_factory to route inputs through adapters
  • Support same-layer and cross-layer adapter wiring via MultiHookHandle
Files:
  • src/tdhook/weights/adapters.py
  • tests/weights/test_adapters.py

Change: Extend HookingContextWithCache to support cache clearing on entry
Details:
  • Add clear_cache flag to constructor
  • Implement clear() method on context
  • Clear internal cache when entering context if clear_cache is True
Files:
  • src/tdhook/contexts.py

Change: Relax hook return-type constraint to allow None values
Details:
  • Adjust hook logic to skip the type check when the callback returns None
  • Add a test ensuring that hooks whose callbacks return None leave values unchanged
Files:
  • src/tdhook/hooks.py
  • tests/test_hooks.py

Change: Perform minor cleanup in weights __init__.py
Details:
  • Remove outdated TODO comments related to SAE and model diffing
Files:
  • src/tdhook/weights/__init__.py

sourcery-ai (Bot, Contributor) left a comment

Hey there - I've reviewed your changes and they look great!

Prompt for AI Agents
Please address the comments from this code review:

## Individual Comments

### Comment 1
<location> `src/tdhook/weights/adapters.py:42-50` </location>
<code_context>
+                    adapter_input = cache_proxy.resolve()
+                else:
+                    adapter_input = kwargs.pop(DIRECTION_TO_RETURN[kwargs["direction"]])
+                return adapter(adapter_input, **kwargs)
+
+            return callback
</code_context>

<issue_to_address>
**suggestion (bug_risk):** Passing all kwargs to the adapter may cause unexpected keyword argument errors.

If the adapter's forward method doesn't accept arbitrary kwargs, this could cause runtime errors. Filter kwargs to only those accepted by the adapter, or clearly document the required adapter signature.

```suggestion
            import inspect

            def callback(**kwargs):
                nonlocal adapter, cache_proxy
                if cache_proxy is not None:
                    adapter_input = cache_proxy.resolve()
                else:
                    adapter_input = kwargs.pop(DIRECTION_TO_RETURN[kwargs["direction"]])
                # Filter kwargs to only those accepted by the adapter
                adapter_params = inspect.signature(adapter).parameters
                filtered_kwargs = {k: v for k, v in kwargs.items() if k in adapter_params}
                return adapter(adapter_input, **filtered_kwargs)

            return callback
```
</issue_to_address>
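
A standalone sketch of the filtering idea from Comment 1, using only the standard library. `filter_kwargs`, the two adapter functions, and the `module_key` entry are hypothetical. Unlike the suggestion above, this version passes everything through when the callable declares `**kwargs`. One caveat, stated as an assumption about PyTorch rather than something shown in this PR: for a `torch.nn.Module`, `inspect.signature(module)` most likely describes `Module.__call__` (which takes `*args, **kwargs`), so inspecting `module.forward` is probably what is wanted there.

```python
import inspect


def filter_kwargs(func, kwargs):
    """Keep only the keyword arguments that `func` can accept (hypothetical helper)."""
    params = inspect.signature(func).parameters
    # A **kwargs parameter means the callable accepts any keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {k: v for k, v in kwargs.items() if k in accepted}


def strict_adapter(x, scale=1):
    # Accepts a single named option and nothing else.
    return x * scale


def flexible_adapter(x, **extra):
    # Accepts any keyword; returns the names it received for inspection.
    return x, sorted(extra)


# "direction" appears in the code context above; "module_key" is made up.
hook_kwargs = {"scale": 3, "direction": "fwd", "module_key": "linear1"}
strict_out = strict_adapter(2, **filter_kwargs(strict_adapter, hook_kwargs))
flexible_out = flexible_adapter(2, **filter_kwargs(flexible_adapter, hook_kwargs))
```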

### Comment 2
<location> `tests/weights/test_pruning.py:79-88` </location>
<code_context>
+                torch.tensor(0.9),
+            )
+
+    def test_no_pruning(self):
+        model = nn.Linear(10, 10)
+        inp = torch.randn(10)
+        original_output = model(inp)
+
+        pruning = Pruning(importance_callback=_importance_cb, amount_to_prune=0)
+        ctx = pruning.prepare(model)
+        with ctx as hooked:
+            output = hooked(inp)
+            assert torch.allclose(output, original_output)
+
+    def test_skip_importance_cb(self):
</code_context>

<issue_to_address>
**suggestion (testing):** Add a test for negative pruning amounts.

Please add a test case for negative amount_to_prune to verify that the code correctly raises an error or handles the input as intended.

Suggested implementation:

```python
    def test_negative_pruning_amount(self):
        model = nn.Linear(10, 10)
        with pytest.raises(ValueError):
            Pruning(importance_callback=_importance_cb, amount_to_prune=-0.1)

    def test_skip_importance_cb(self):
        model = nn.Linear(10, 10)
        original_state = (model.weight.clone(), model.bias.clone())

```

Make sure that `pytest` is imported at the top of the file if it is not already:
```python
import pytest
```
Also, ensure that the `Pruning` class raises a `ValueError` for negative `amount_to_prune`. If it does not, you will need to add this validation in the `Pruning` class implementation.
</issue_to_address>
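
If the validation mentioned in Comment 2 has to be added, it could look roughly like the sketch below. `validate_amount` is a hypothetical helper, and two things are assumptions: that `amount_to_prune` is a fraction in [0, 1], and that `ValueError` is the exception type (the PR only says an error is raised when the amount is unspecified).

```python
def validate_amount(amount_to_prune):
    """Return the amount unchanged if it is a valid fraction, else raise ValueError."""
    if amount_to_prune is None:
        raise ValueError("amount_to_prune must be specified")
    if not 0.0 <= amount_to_prune <= 1.0:
        raise ValueError(f"amount_to_prune must be in [0, 1], got {amount_to_prune}")
    return amount_to_prune


def raises_value_error(value):
    # Small helper so the behavior can be checked without pytest.
    try:
        validate_amount(value)
    except ValueError:
        return True
    return False
```

With this in place, `amount_to_prune=0` stays valid (the `test_no_pruning` case) while `-0.1` is rejected.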

### Comment 3
<location> `tests/weights/test_pruning.py:90-98` </location>
<code_context>
+            output = hooked(inp)
+            assert torch.allclose(output, original_output)
+
+    def test_skip_importance_cb(self):
+        model = nn.Linear(10, 10)
+        original_state = (model.weight.clone(), model.bias.clone())
+
+        pruning = Pruning(importance_callback=_importance_cb_skip_weight, amount_to_prune=0.5)
+        ctx = pruning.prepare(model)
+        with ctx:
+            assert torch.allclose(model.weight, original_state[0])
+            assert not torch.allclose(model.bias, original_state[1])
</code_context>

<issue_to_address>
**suggestion (testing):** Consider testing skip_modules logic.

Please add a test for skip_modules by passing a custom function and confirming that the intended modules are excluded from pruning.
</issue_to_address>

### Comment 4
<location> `tests/weights/test_adapters.py:30-28` </location>
<code_context>
+        restored_out = default_test_model(data["input"])
+        assert torch.allclose(baseline_out, restored_out)
+
+    def test_adapter_crosslayer(self, default_test_model):
+        data = TensorDict({"input": torch.randn(4, 10)}, batch_size=4)
+        baseline_out = default_test_model(data["input"]).detach().clone()
+
+        adapters = {"linear2": (_DoubleAdapter(), "linear1", "linear2")}
+        ctx_factory = Adapters(adapters=adapters)
+
+        with ctx_factory.prepare(default_test_model) as hooked:
+            patched_data = hooked(data.clone())
+            patched_out = patched_data["output"]
+            assert not torch.allclose(baseline_out, patched_out)
+
+        restored_out = default_test_model(data["input"])
+        assert torch.allclose(baseline_out, restored_out)
</code_context>

<issue_to_address>
**suggestion (testing):** Consider adding tests for multiple adapters and directionality.

Expanding tests to cover multiple adapters, various directions, and edge cases like mismatched input/output shapes will improve test coverage and reliability.
</issue_to_address>

### Comment 5
<location> `tests/weights/test_pruning.py:15-17` </location>
<code_context>
def _importance_cb_skip_weight(parameter, parameter_name, **_):
    if parameter_name == "weight":
        return None
    return parameter

</code_context>

<issue_to_address>
**suggestion (code-quality):** We've found these issues:

- Lift code into else after jump in control flow ([`reintroduce-else`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/reintroduce-else/))
- Replace if statement with if expression ([`assign-if-exp`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/assign-if-exp/))

```suggestion
    return None if parameter_name == "weight" else parameter
```
</issue_to_address>

### Comment 6
<location> `tests/weights/test_pruning.py:21-23` </location>
<code_context>
def _importance_cb_skip_bias(parameter, parameter_name, **_):
    if parameter_name == "bias":
        return None
    return parameter

</code_context>

<issue_to_address>
**suggestion (code-quality):** We've found these issues:

- Lift code into else after jump in control flow ([`reintroduce-else`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/reintroduce-else/))
- Replace if statement with if expression ([`assign-if-exp`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/assign-if-exp/))

```suggestion
    return None if parameter_name == "bias" else parameter
```
</issue_to_address>


codecov (Bot) commented Sep 24, 2025

Codecov Report

❌ Patch coverage is 97.41379% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.81%. Comparing base (7c93ebd) to head (bd4aaca).
⚠️ Report is 16 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/tdhook/contexts.py | 75.00% | 2 Missing ⚠️ |
| src/tdhook/weights/pruning.py | 98.48% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main       #9      +/-   ##
==========================================
+ Coverage   94.87%   95.81%   +0.94%     
==========================================
  Files          29       30       +1     
  Lines        1815     1913      +98     
==========================================
+ Hits         1722     1833     +111     
+ Misses         93       80      -13     


Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Xmaster6y merged commit 7490a7a into main on Sep 24, 2025 (7 checks passed)
Xmaster6y deleted the methods branch on September 24, 2025 at 13:26