Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
0a77f05
add neuroevolution main test file
liyc5929 Dec 18, 2024
5821a15
Add `SupervisedLearningProblem` neuroevolution module and basic singl…
liyc5929 Dec 20, 2024
69b5bc7
Add `ParamsAndVector` class as well as the corresponding testing codes.
liyc5929 Dec 24, 2024
fb16a3d
Update workflow settings and `ParamsAndVector` batched operators.
liyc5929 Dec 25, 2024
90bded7
Add unit test for `ParamsAndVector` module
liyc5929 Dec 25, 2024
b974241
Fix code of `std_workflow`
liyc5929 Dec 25, 2024
35a0044
Update inner state settings for `SupervisedLearningProblem`.
liyc5929 Dec 26, 2024
666cd21
Update `SupervisedLearningProblem` process framework and test.
liyc5929 Dec 26, 2024
2759f85
main_neuroevolution_yanchenli.py
liyc5929 Dec 26, 2024
2d4b289
Update `SupervisedLearningProblem` processing code.
liyc5929 Dec 26, 2024
3c24574
Add raw executable neuroevolution module support and test.
liyc5929 Dec 30, 2024
91503b2
Merge branch from `origin/main` and upload `PSO` algorithm setting.
liyc5929 Dec 30, 2024
a05eea2
Add hash setting for global data loader; Update neuroevolution demo.
liyc5929 Dec 31, 2024
e214360
Add while-loop implementation for `SupervisedLearningProblem`.
liyc5929 Jan 2, 2025
27a31ad
Fix original neuroevolution process of `SupervisedLearningProblem`.
liyc5929 Jan 2, 2025
fac2507
Update while-loop test for `SupervisedLearningProblem`.
liyc5929 Jan 3, 2025
707570c
allow TraceCond to be used with non-pure member methods
sses7757 Jan 2, 2025
ce76fef
rename wrappers for better debugging
sses7757 Jan 3, 2025
9597974
Add unit test for `SupervisedLearningProblem`
liyc5929 Jan 3, 2025
6df3403
Add single state model forward and criterion for single evaluation of…
liyc5929 Jan 3, 2025
fb181db
Add single-run and population-based neuroevolution testing framework.
liyc5929 Jan 3, 2025
6e5a193
Add model params keys to model state keys map for `SupervisedLearning…
liyc5929 Jan 3, 2025
10c6153
Add code refactoring for `SupervisedLearningProblem` and its tests.
liyc5929 Jan 3, 2025
cd3a665
Add single-run neuroevolution test for `SupervisedLearningProblem`.
liyc5929 Jan 4, 2025
e4df99c
fix ModuleBase setup to reset JIT-ed module; fix EvalMonitor
sses7757 Jan 6, 2025
e97cf7d
Add data pre-loading and neuroevolution process rectification for `Su…
liyc5929 Jan 7, 2025
4d184b2
fix use_state and JIT
sses7757 Jan 3, 2025
81b7ba4
adding cache support for TraceCond and TraceWhile
sses7757 Jan 3, 2025
5012ccc
adding TraceSwitch
sses7757 Jan 3, 2025
b2e92cf
fix JIT and control flow in-place operations
sses7757 Jan 3, 2025
b2a1e93
improve JIT script
sses7757 Jan 4, 2025
05d54d0
Merge remote-tracking branch 'origin/evoxtorch-main' into evoxtorch-d…
liyc5929 Jan 7, 2025
90f7d95
update this branch with evoxtorch-main
sses7757 Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,5 @@ cython_debug/
# Test
tests
evox
**/_future/*
**/_future/*
data
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,15 @@ python ./unit_test/algorithms/pso_variants/test_pso.py
python ./unit_test/algorithms/pso_variants/test_sl_pso_gs.py
python ./unit_test/algorithms/pso_variants/test_sl_pso_us.py

python ./unit_test/core/test_jit_util.py
python ./unit_test/core/test_module.py
python ./unit_test/core/test_jit_util.py
python ./unit_test/core/test_module.py

python ./unit_test/problems/test_hpo_wrapper.py
python ./unit_test/problems/test_supervised_learning.py

python ./unit_test/utils/test_jit_fix.py
python ./unit_test/utils/test_parameters_and_vector.py
python ./unit_test/utils/test_while.py

python ./unit_test/workflows/test_std_workflow.py
```
Expand Down
1 change: 1 addition & 0 deletions src/core/_vmap_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def batched_random_like(rand_func: Callable, like_tensor: torch.Tensor, **kwargs
return batch_rand_values



_original_rand = torch.rand
_original_randn = torch.randn
_original_randint = torch.randint
Expand Down
13 changes: 10 additions & 3 deletions src/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.train(False)
self.__static_names__ = []
self._hash_id_ = None

def eval(self):
    # Deliberately disabled: `nn.Module.eval()` toggles training mode, which is
    # ambiguous for this module system (training mode is fixed to False in
    # `__init__`). Any call is a programming error.
    assert False, "`ModuleBase.eval()` shall never be invoked to prevent ambiguity."
Expand All @@ -166,6 +167,9 @@ def setup(self, *args, **kwargs):
The static initialization can still be written in the `__init__` while the mutable initialization cannot.
Therefore, multiple calls of `setup` for multiple initializations are possible.
"""
if hasattr(self, _WRAPPING_MODULE_NAME):
wrapper: _WrapClassBase = object.__getattribute__(self, _WRAPPING_MODULE_NAME)
wrapper.__jit_module__ = None
return self

def load_state_dict(self, state_dict: Mapping[str, torch.Tensor], copy: bool = False, **kwargs):
Expand Down Expand Up @@ -263,6 +267,11 @@ def to(self, *args, **kwargs) -> "ModuleBase":
self.__setattr_inner__(k, val)
return self

def __hash__(self):
    """Return this module's hash, caching it on first use.

    The first call stores the default (identity-based) `super().__hash__()`
    value in `self._hash_id_` so that every later call returns the same
    integer. Other code may read `self._hash_id_` directly after the first
    `hash(self)` call.
    """
    if self._hash_id_ is None:
        self._hash_id_ = super().__hash__()
    return self._hash_id_

def __getattribute__(self, name):
if not tracing_or_using_state() or name == _WRAPPING_MODULE_NAME or _is_magic(name):
return super(nn.Module, self).__getattribute__(name)
Expand Down Expand Up @@ -494,9 +503,7 @@ def __repr__(self) -> str:
)

def __hash__(self) -> int:
return object.__hash__(
self.__inner_module__ if self.__jit_module__ is None else self.__jit_module__
)
return object.__hash__(self.__inner_module__)

def __format__(self, format_spec: str) -> str:
return object.__format__(
Expand Down
Empty file added src/problems/__init__.py
Empty file.
1 change: 1 addition & 0 deletions src/problems/neuroevolution/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .supervised_learning import SupervisedLearningProblem
269 changes: 269 additions & 0 deletions src/problems/neuroevolution/supervised_learning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
import copy
import types
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, Tuple, Iterable, Iterator

from ...core import Problem, jit_class, use_state, vmap, jit
from ...core.module import assign_load_state_dict


__supervised_data__: Dict[int, Dict[str, DataLoader | Iterable | Iterator | Tuple]] = None

# cSpell:words vmapped


@jit_class
class SupervisedLearningProblem(Problem):
    """The supervised learning problem to test a model's parameters or a batch of parameters with given data and criterion.

    Two evaluation paths are prepared at construction time:

    - a JIT-traced single-model forward/criterion pair, used when the
      population (batch) dimension is 1, and
    - a vmapped, JIT-traced forward/criterion pair, used for populations of
      size > 1.

    The `DataLoader` cannot be captured by JIT tracing, so it is kept in the
    module-level `__supervised_data__` registry (keyed by this instance's
    hash) and is only touched through the `@torch.jit.ignore` helper methods
    below.
    """

    def __init__(
        self,
        model: nn.Module,
        data_loader: DataLoader,
        criterion: nn.Module,
        pop_size: int | None = None,
        device: torch.device | None = None,
    ):
        """Initialize the `SupervisedLearningProblem`.

        Args:
            model (`nn.Module`): The neural network model whose parameters need to be evaluated.
            data_loader (`DataLoader`): The data loader providing the dataset for evaluation.
            criterion (`nn.Module`): The loss function used to evaluate the parameters' performance.
            pop_size (`int`, optional): The size of the population (batch size of the parameters) to be evaluated. Defaults to None for single-run mode.
            device (`torch.device`, optional): The device to run the computations on. Defaults to the current default device.

        Raises:
            `RuntimeError`: If the data loader contains no items.
        """
        super().__init__()
        device = torch.get_default_device() if device is None else device
        pop_size = 1 if pop_size is None else pop_size

        # Global data loader info registration.
        # NOTE(review): relies on `hash(self)` caching `_hash_id_` on the module
        # (see `ModuleBase.__hash__`) so that the `self._hash_id_` lookups in the
        # `_data_loader_*` helpers and `__del__` resolve to this same key — confirm.
        global __supervised_data__
        if __supervised_data__ is None:
            __supervised_data__ = {}
        instance_id = hash(self)
        if instance_id not in __supervised_data__.keys():
            __supervised_data__[instance_id] = {
                "data_loader_ref": data_loader,
                "data_loader_iter": None,
                "data_next_cache": None,
            }
        try:
            # One sample batch is needed as example inputs for JIT tracing below.
            dummy_inputs, dummy_labels = next(iter(data_loader))
        except StopIteration:
            raise RuntimeError(
                f"The `data_loader` of `{self.__class__.__name__}` must contain at least one item."
            )
        dummy_inputs: torch.Tensor = dummy_inputs.to(device=device)
        dummy_labels: torch.Tensor = dummy_labels.to(device=device)

        # Model initialization: evaluate a deep copy so the caller's model is left
        # untouched; gradients are never needed for fitness evaluation.
        inference_model = copy.deepcopy(model)
        inference_model = inference_model.to(device=device)
        for _, value in inference_model.named_parameters():
            value.requires_grad = False
        # Replace `load_state_dict` with the assigning variant so state tensors
        # can be swapped in directly instead of copied.
        inference_model.load_state_dict = types.MethodType(assign_load_state_dict, inference_model)

        # JITed and vmapped model state forward initialization.
        # The traced callables take (state_dict, inputs) and return
        # (new_state_dict, logits); dummy outputs are kept as criterion examples.
        state_forward = use_state(lambda: inference_model.forward)
        model_init_state = state_forward.init_state(clone=False)
        self._jit_state_forward, (_, dummy_single_logits) = jit(
            state_forward,
            trace=True,
            lazy=False,
            example_inputs=(model_init_state, dummy_inputs),
            return_dummy_output=True,
        )
        # Batched variant: the model state is mapped over dim 0; the inputs
        # batch is shared across the population.
        vmap_state_forward = vmap(state_forward, in_dims=(0, None))
        vmap_model_init_state = vmap_state_forward.init_state(pop_size)
        self._jit_vmap_state_forward, (_, dummy_vmap_logits) = jit(
            vmap_state_forward,
            trace=True,
            lazy=False,
            example_inputs=(vmap_model_init_state, dummy_inputs),
            return_dummy_output=True,
        )

        # Building map from model parameters key to model state key.
        # NOTE(review): matching by `torch.equal` on values can bind a parameter
        # to the wrong state key when two parameters happen to hold identical
        # tensors — confirm this is acceptable for the supported models.
        model_params = dict(inference_model.named_parameters())
        self.param_to_state_key_map: Dict[str, str] = {
            params_key: state_key
            for state_key, state_value in model_init_state.items()
            for params_key, params_value in model_params.items()
            if torch.equal(state_value, params_value)
        }

        # JITed and vmapped state criterion initialization.
        state_criterion = use_state(lambda: criterion.forward)
        criterion_init_state = state_criterion.init_state()
        self._jit_state_criterion = jit(
            state_criterion,
            trace=True,
            lazy=False,
            example_inputs=(
                criterion_init_state,
                dummy_single_logits,
                dummy_labels,
            ),
        )
        # Batched criterion: state and logits are mapped over dim 0; the labels
        # batch is shared across the population.
        vmap_state_criterion = vmap(state_criterion, in_dims=(0, 0, None))
        vmap_criterion_init_state = vmap_state_criterion.init_state(pop_size)
        self._jit_vmap_state_criterion = jit(
            vmap_state_criterion,
            trace=True,
            lazy=False,
            example_inputs=(
                vmap_criterion_init_state,
                dummy_vmap_logits,
                dummy_labels,
            ),
        )
        self.vmap_criterion_init_state = vmap_criterion_init_state

        # Model parameters and buffers registration: state entries that are not
        # mapped parameters (i.e. buffers) are kept so a full model state can be
        # rebuilt at evaluation time.
        self._model_buffers = {
            key: value
            for key, value in model_init_state.items()
            if key not in self.param_to_state_key_map
        }
        # Remember one parameter's key and rank so `evaluate` can detect the
        # extra leading population dimension on incoming parameters.
        sample_param_key, sample_state_key = next(iter(self.param_to_state_key_map.items()))
        self._sample_param_key = sample_param_key
        self._sample_param_ndim = model_init_state[sample_state_key].ndim

        # Other member variables registration
        self.criterion_init_state = criterion_init_state

    def __del__(self):
        """Drop this instance's entry from the global data-loader registry on destruction."""
        global __supervised_data__
        __supervised_data__.pop(self._hash_id_, None)
        super().__del__()

    @torch.jit.ignore
    def _data_loader_reset(self) -> None:
        """Restart iteration over the registered data loader and pre-fetch the first batch.

        Kept out of JIT (`torch.jit.ignore`) because `DataLoader` iteration is
        not traceable.
        """
        global __supervised_data__
        data_info = __supervised_data__[self._hash_id_]
        data_info["data_loader_iter"] = iter(data_info["data_loader_ref"])
        try:
            # Pre-fetch one batch so `_data_loader_has_next` can answer without
            # consuming the iterator a second time.
            data_info["data_next_cache"] = next(data_info["data_loader_iter"])
        except StopIteration:
            data_info["data_next_cache"] = None

    @torch.jit.ignore
    def _data_loader_next(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the cached batch and pre-fetch the next one (cache becomes None at the end)."""
        global __supervised_data__
        data_info = __supervised_data__[self._hash_id_]
        next_data = data_info["data_next_cache"]
        try:
            data_info["data_next_cache"] = next(data_info["data_loader_iter"])
        except StopIteration:
            data_info["data_next_cache"] = None
        return next_data

    @torch.jit.ignore
    def _data_loader_has_next(self) -> bool:
        """Whether another batch is available from the current data-loader iteration."""
        global __supervised_data__
        return __supervised_data__[self._hash_id_]["data_next_cache"] is not None

    def _vmap_evaluate(
        self,
        pop_params: Dict[str, nn.Parameter],
        num_map: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Evaluate `num_map` parameter sets at once with the vmapped model.

        Returns a tensor of shape (num_map,): each candidate's fitness, computed
        as the input-count-weighted average of the per-batch criterion results
        over the whole data loader.
        """
        # Initialize model and criterion states
        model_buffers = {  # expand dimensions for model buffers
            key: value.unsqueeze(0).expand([num_map] + list(value.shape))
            for key, value in self._model_buffers.items()
        }
        state_params = {self.param_to_state_key_map[key]: value for key, value in pop_params.items()}
        model_state = model_buffers
        model_state.update(state_params)
        # Clone the criterion state so repeated evaluations start fresh.
        criterion_state = {key: value.clone() for key, value in self.vmap_criterion_init_state.items()}

        total_result = torch.zeros(num_map, device=device)
        total_inputs = 0
        self._data_loader_reset()
        while self._data_loader_has_next():
            inputs, labels = self._data_loader_next()
            inputs = inputs.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)

            model_state, logits = self._jit_vmap_state_forward(
                model_state,
                inputs,
            )
            criterion_state, result = self._jit_vmap_state_criterion(
                criterion_state,
                logits,
                labels,
            )
            # NOTE(review): weighting by batch size assumes the criterion returns
            # a per-batch mean — confirm against the criterion in use.
            total_result += result * inputs.size(0)
            total_inputs += inputs.size(0)
        pop_fitness = total_result / total_inputs
        return pop_fitness

    def _single_evaluate(
        self,
        params: Dict[str, nn.Parameter],
        device: torch.device,
    ) -> torch.Tensor:
        """Evaluate a single parameter set (population of size 1) with the plain traced model.

        `params` still carries the leading population dimension of size 1; it is
        squeezed away here. Returns a tensor of shape (1,) with the fitness.
        """
        # Initialize model and criterion states
        model_buffers = {key: value.clone() for key, value in self._model_buffers.items()}
        # Drop the population dimension and remap parameter keys to state keys.
        params = {self.param_to_state_key_map[key]: value.squeeze(0) for key, value in params.items()}
        model_state = model_buffers
        model_state.update(params)
        criterion_state = {key: value.clone() for key, value in self.criterion_init_state.items()}

        # Calculate population fitness
        total_result = torch.tensor(0.0, device=device)
        total_inputs = 0
        self._data_loader_reset()
        while self._data_loader_has_next():
            inputs, labels = self._data_loader_next()
            inputs = inputs.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)

            model_state, logits = self._jit_state_forward(
                model_state,
                inputs,
            )
            criterion_state, result = self._jit_state_criterion(
                criterion_state,
                logits,
                labels,
            )
            # NOTE(review): same per-batch-mean assumption as `_vmap_evaluate` —
            # confirm against the criterion in use.
            total_result += result * inputs.size(0)
            total_inputs += inputs.size(0)
        fitness = total_result / total_inputs
        # Restore the leading population dimension expected by `evaluate`.
        return fitness.unsqueeze(0)

    def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
        """Evaluate the fitness of a population (batch) of model parameters.

        Dispatches to the vmapped path for populations larger than 1 and to the
        single traced path otherwise.

        Args:
            pop_params (`Dict[str, nn.Parameter]`): A dictionary of parameters where each key is a parameter name and each value is a tensor of shape (batch_size, *param_shape) representing the batched parameters of batched models.

        Returns:
            A tensor of shape (batch_size,) containing the fitness of each sample in the population.
        """
        pop_params_value = pop_params[self._sample_param_key]
        # One sample parameter's rank is enough to verify the batch dimension,
        # since all entries are expected to share the same leading batch size.
        assert (
            pop_params_value.ndim == self._sample_param_ndim + 1
        ), f"Expected exactly one batch dimension, got {pop_params_value.ndim - self._sample_param_ndim}"
        if pop_params_value.size(0) != 1:
            pop_fitness = self._vmap_evaluate(
                pop_params,
                pop_params_value.size(0),
                pop_params_value.device,
            )
        else:
            pop_fitness = self._single_evaluate(
                pop_params,
                pop_params_value.device,
            )
        return pop_fitness
5 changes: 3 additions & 2 deletions src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
__all__ = ["switch", "clamp", "clip", "maximum", "minimum", "TracingWhile", "TracingCond", "TracingSwitch"]
__all__ = ["switch", "clamp", "clip", "maximum", "minimum", "TracingWhile", "TracingCond", "TracingSwitch", "ParamsAndVector"]

from .jit_fix_operator import switch, clamp, clip, maximum, minimum
from .control_flow import TracingWhile, TracingCond, TracingSwitch
from .parameters_and_vector import ParamsAndVector


################### NOTICE ###################
Expand All @@ -12,4 +13,4 @@
# 4. Python's while loops cannot be vector-mapped directly, please use the function in this module instead.
# 5. DO NOT directly use `torch.jit.script` to JIT `torch.vmap` functions. You may get unexpected results without any warning.
#
################# END NOTICE #################
################# END NOTICE #################
2 changes: 1 addition & 1 deletion src/utils/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,4 +1473,4 @@ def vmap_switch(self, branch_idx: torch.Tensor, *x: torch.Tensor):
for fn in state_branch_fns:
fn.set_state(state_out)
# return
return final_output
return final_output
Loading