Reviewer's Guide
This pull request introduces two major HookingContextFactory-based features (pruning and adapters), alongside enhancements to the caching context, a relaxation of the hook return-type check, and accompanying tests and minor cleanups.
Entity relationship diagram for Adapters module
erDiagram
HOOKED_MODULE {
adapters ModuleDict
}
ADAPTERS {
adapters Dict
cache_callback Callable
relative bool
directions List
}
HOOKED_MODULE ||--o{ ADAPTERS : uses
Entity relationship diagram for Pruning module
erDiagram
PRUNING_CONTEXT {
_old_weights TensorDict
}
PRUNING {
importance_callback Callable
amount_to_prune float
modules_to_prune Dict
skip_modules Callable
relative bool
relative_path str
}
PRUNING_CONTEXT ||--o{ PRUNING : used_by
Class diagram for new Pruning and Adapters features
classDiagram
class HookingContextFactory {
}
class PruningContext {
- _old_weights
+ __enter__()
+ __exit__()
}
class Pruning {
- _importance_callback
- _amount_to_prune
- _modules_to_prune
- _skip_modules
- _relative
- _relative_path
+ _prepare_module()
+ default_skip()
}
class HookedModule {
}
class HookedModuleWithAdapters {
- adapters
}
class Adapters {
- _adapters
- _cache_callback
- _relative
- _directions
+ _hook_module()
}
class HookingContextWithCache {
- _cache
- _clear_cache
+ cache
+ clear()
+ __enter__()
}
HookingContextFactory <|-- Pruning
HookingContext <|-- PruningContext
HookingContextFactory <|-- Adapters
HookedModule <|-- HookedModuleWithAdapters
HookingContextWithCache <|-- Adapters
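The class diagram shows a `PruningContext` that keeps `_old_weights` and implements `__enter__` and `__exit__`, which suggests a snapshot-prune-restore pattern. The following is a minimal, library-free sketch of that pattern, not the implementation in this pull request: the class name `WeightRestoringContext`, the use of plain lists in place of tensors, and the reading of `amount_to_prune` as a fraction in [0, 1] are all assumptions made for illustration.

```python
from typing import Callable, Dict, List


class WeightRestoringContext:
    """Hypothetical stand-in for a pruning context: snapshot, prune, restore."""

    def __init__(
        self,
        weights: Dict[str, List[float]],
        amount_to_prune: float,
        importance: Callable[[float], float] = abs,
    ):
        # Assumption: amount_to_prune is a fraction of entries per parameter.
        if not 0.0 <= amount_to_prune <= 1.0:
            raise ValueError("amount_to_prune must be in [0, 1]")
        self._weights = weights
        self._amount = amount_to_prune
        self._importance = importance
        self._old_weights: Dict[str, List[float]] = {}

    def __enter__(self) -> Dict[str, List[float]]:
        # Snapshot every parameter before modifying it in place.
        self._old_weights = {name: list(vals) for name, vals in self._weights.items()}
        for vals in self._weights.values():
            n_prune = int(len(vals) * self._amount)
            # Indices ordered from least to most important.
            order = sorted(range(len(vals)), key=lambda i: self._importance(vals[i]))
            for i in order[:n_prune]:
                vals[i] = 0.0
        return self._weights

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Restore in place so external references see the original values again.
        for name, old in self._old_weights.items():
            self._weights[name][:] = old
        self._old_weights = {}
        return False  # never swallow exceptions


weights = {"linear.weight": [0.1, -2.0, 0.5, 3.0]}
with WeightRestoringContext(weights, amount_to_prune=0.5) as pruned:
    inside = list(pruned["linear.weight"])  # the two smallest magnitudes are zeroed
after = list(weights["linear.weight"])      # original values are back
```

Restoring in `__exit__` rather than returning a pruned copy keeps the model object identical before and after the `with` block, which is the behavior the tests in this review check for.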
Hey there - I've reviewed your changes and they look great!
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location> `src/tdhook/weights/adapters.py:42-50` </location>
<code_context>
+            adapter_input = cache_proxy.resolve()
+        else:
+            adapter_input = kwargs.pop(DIRECTION_TO_RETURN[kwargs["direction"]])
+        return adapter(adapter_input, **kwargs)
+
+    return callback
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Passing all kwargs to the adapter may cause unexpected keyword argument errors.
If the adapter's forward method doesn't accept arbitrary kwargs, this could cause runtime errors. Filter kwargs to only those accepted by the adapter, or clearly document the required adapter signature.
```suggestion
import inspect

def callback(**kwargs):
    nonlocal adapter, cache_proxy
    if cache_proxy is not None:
        adapter_input = cache_proxy.resolve()
    else:
        adapter_input = kwargs.pop(DIRECTION_TO_RETURN[kwargs["direction"]])
    # Filter kwargs to only those accepted by the adapter
    adapter_params = inspect.signature(adapter).parameters
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in adapter_params}
    return adapter(adapter_input, **filtered_kwargs)

return callback
```
</issue_to_address>
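One caveat with signature-based filtering is that a callable declaring `**kwargs` accepts every keyword, so a plain name lookup in `parameters` would wrongly drop arguments it could handle. The sketch below is a standalone illustration of that case under stated assumptions: `filter_kwargs`, `strict_adapter`, and `permissive_adapter` are hypothetical names, not part of the library under review. For wrapper objects whose `__call__` is itself declared as `(*args, **kwargs)`, inspecting the underlying method (for example `forward`) is probably what is wanted.

```python
import inspect
from typing import Any, Callable, Dict


def filter_kwargs(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that `fn` can accept."""
    params = inspect.signature(fn).parameters
    # A **kwargs parameter means the callable accepts any keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {k: v for k, v in kwargs.items() if k in accepted}


def strict_adapter(x, scale=1.0):
    # Hypothetical adapter with a fixed signature.
    return x * scale


def permissive_adapter(x, **extra):
    # Hypothetical adapter that tolerates arbitrary keywords.
    return x, extra


kwargs = {"scale": 2.0, "direction": "fwd", "module_key": "linear1"}
strict_result = strict_adapter(3.0, **filter_kwargs(strict_adapter, kwargs))
_, forwarded = permissive_adapter(3.0, **filter_kwargs(permissive_adapter, kwargs))
```

Here `strict_adapter` receives only `scale`, while `permissive_adapter` still receives all three keywords.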
### Comment 2
<location> `tests/weights/test_pruning.py:79-88` </location>
<code_context>
+            torch.tensor(0.9),
+        )
+
+    def test_no_pruning(self):
+        model = nn.Linear(10, 10)
+        inp = torch.randn(10)
+        original_output = model(inp)
+
+        pruning = Pruning(importance_callback=_importance_cb, amount_to_prune=0)
+        ctx = pruning.prepare(model)
+        with ctx as hooked:
+            output = hooked(inp)
+            assert torch.allclose(output, original_output)
+
+    def test_skip_importance_cb(self):
</code_context>
<issue_to_address>
**suggestion (testing):** Add a test for negative pruning amounts.
Please add a test case for negative amount_to_prune to verify that the code correctly raises an error or handles the input as intended.
Suggested implementation:
```python
    def test_negative_pruning_amount(self):
        model = nn.Linear(10, 10)
        with pytest.raises(ValueError):
            Pruning(importance_callback=_importance_cb, amount_to_prune=-0.1)

    def test_skip_importance_cb(self):
        model = nn.Linear(10, 10)
        original_state = (model.weight.clone(), model.bias.clone())
```
Make sure that `pytest` is imported at the top of the file if it is not already:
```python
import pytest
```
Also, ensure that the `Pruning` class raises a `ValueError` for negative `amount_to_prune`. If it does not, you will need to add this validation in the `Pruning` class implementation.
</issue_to_address>
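The validation that this comment asks for could look roughly like the sketch below. It is an assumption-laden illustration rather than the actual `Pruning` constructor: `PruningConfig` and `raises_value_error` are hypothetical names, and the check assumes `amount_to_prune` is a fraction in [0, 1], which the source does not confirm.

```python
class PruningConfig:
    """Hypothetical stand-in for argument validation in a Pruning-like factory."""

    def __init__(self, amount_to_prune: float):
        # Assumption: the amount is a fraction, so anything outside [0, 1] is rejected.
        if not 0.0 <= amount_to_prune <= 1.0:
            raise ValueError(f"amount_to_prune must be in [0, 1], got {amount_to_prune}")
        self.amount_to_prune = amount_to_prune


def raises_value_error(amount: float) -> bool:
    """Return True if constructing the config with `amount` raises ValueError."""
    try:
        PruningConfig(amount)
    except ValueError:
        return True
    return False


negative_rejected = raises_value_error(-0.1)
zero_accepted = not raises_value_error(0)
half_accepted = not raises_value_error(0.5)
```

Validating in the constructor means the error surfaces before any model is prepared, which matches where the suggested test expects the `ValueError`.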
### Comment 3
<location> `tests/weights/test_pruning.py:90-98` </location>
<code_context>
+            output = hooked(inp)
+            assert torch.allclose(output, original_output)
+
+    def test_skip_importance_cb(self):
+        model = nn.Linear(10, 10)
+        original_state = (model.weight.clone(), model.bias.clone())
+
+        pruning = Pruning(importance_callback=_importance_cb_skip_weight, amount_to_prune=0.5)
+        ctx = pruning.prepare(model)
+        with ctx:
+            assert torch.allclose(model.weight, original_state[0])
+            assert not torch.allclose(model.bias, original_state[1])
</code_context>
<issue_to_address>
**suggestion (testing):** Consider testing skip_modules logic.
Please add a test for skip_modules by passing a custom function and confirming that the intended modules are excluded from pruning.
</issue_to_address>
### Comment 4
<location> `tests/weights/test_adapters.py:30-28` </location>
<code_context>
+        restored_out = default_test_model(data["input"])
+        assert torch.allclose(baseline_out, restored_out)
+
+    def test_adapter_crosslayer(self, default_test_model):
+        data = TensorDict({"input": torch.randn(4, 10)}, batch_size=4)
+        baseline_out = default_test_model(data["input"]).detach().clone()
+
+        adapters = {"linear2": (_DoubleAdapter(), "linear1", "linear2")}
+        ctx_factory = Adapters(adapters=adapters)
+
+        with ctx_factory.prepare(default_test_model) as hooked:
+            patched_data = hooked(data.clone())
+            patched_out = patched_data["output"]
+            assert not torch.allclose(baseline_out, patched_out)
+
+        restored_out = default_test_model(data["input"])
+        assert torch.allclose(baseline_out, restored_out)
</code_context>
<issue_to_address>
**suggestion (testing):** Consider adding tests for multiple adapters and directionality.
Expanding tests to cover multiple adapters, various directions, and edge cases like mismatched input/output shapes will improve test coverage and reliability.
</issue_to_address>
### Comment 5
<location> `tests/weights/test_pruning.py:15-17` </location>
<code_context>
def _importance_cb_skip_weight(parameter, parameter_name, **_):
    if parameter_name == "weight":
        return None
    return parameter
</code_context>
<issue_to_address>
**suggestion (code-quality):** We've found these issues:
- Lift code into else after jump in control flow ([`reintroduce-else`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/reintroduce-else/))
- Replace if statement with if expression ([`assign-if-exp`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/assign-if-exp/))
```suggestion
    return None if parameter_name == "weight" else parameter
```
</issue_to_address>
### Comment 6
<location> `tests/weights/test_pruning.py:21-23` </location>
<code_context>
def _importance_cb_skip_bias(parameter, parameter_name, **_):
    if parameter_name == "bias":
        return None
    return parameter
</code_context>
<issue_to_address>
**suggestion (code-quality):** We've found these issues:
- Lift code into else after jump in control flow ([`reintroduce-else`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/reintroduce-else/))
- Replace if statement with if expression ([`assign-if-exp`](https://docs.sourcery.ai/Reference/Default-Rules/refactorings/assign-if-exp/))
```suggestion
    return None if parameter_name == "bias" else parameter
```
</issue_to_address>
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main       #9      +/-   ##
==========================================
+ Coverage   94.87%   95.81%   +0.94%
==========================================
  Files          29       30       +1
  Lines        1815     1913      +98
==========================================
+ Hits         1722     1833     +111
+ Misses         93       80      -13
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Summary by Sourcery
Implement new Pruning and Adapters hook factories for weight manipulation, let hook callbacks return None, extend the caching context with cache clearing, and remove the obsolete SAE module.
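The relaxed return-type check can be illustrated with a small sketch. It assumes that a callback returning None means "leave the value unchanged", which is the convention PyTorch forward hooks follow; whether this pull request uses exactly that semantics is not stated in the summary. The names `apply_hook`, `observe_only`, and `double` are hypothetical.

```python
from typing import Any, Callable, List, Optional

seen: List[Any] = []


def apply_hook(output: Any, callback: Callable[[Any], Optional[Any]]) -> Any:
    """Run `callback` on `output`; a None result leaves `output` unchanged."""
    result = callback(output)
    return output if result is None else result


def observe_only(value: Any) -> None:
    # Side effect only: records the value and implicitly returns None.
    seen.append(value)


def double(value: Any) -> Any:
    # Returns a replacement for the hooked value.
    return value * 2


unchanged = apply_hook(5, observe_only)
replaced = apply_hook(5, double)
```

Accepting None lets the same hook mechanism serve both read-only callbacks (caching, logging) and callbacks that rewrite the value.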