Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions openmmtools/multistate/multistatereporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ def __init__(self, storage, open_mode=None,
self._analysis_particle_indices = tuple(analysis_particle_indices)
if open_mode is not None:
self.open(open_mode)
# Flag to check whether to overwrite real time statistics file
self._overwrite_statistics = True
# TODO: Maybe we want to expose this flag to control overwriting/appending
# Flag to check whether to overwrite real time statistics file -- Defaults to append
self._overwrite_statistics = False

@property
def filepath(self):
Expand Down Expand Up @@ -266,6 +267,7 @@ def open(self, mode='r', convention='ReplicaExchange', netcdf_format='NETCDF4'):
self.close()

# Create directory if we want to write.
# TODO: We probably want to check here specifically for mode 'w' when we want to write
if mode != 'r':
for storage_path in self._storage_paths:
# normpath() transform '' to '.' for makedirs().
Expand Down
6 changes: 6 additions & 0 deletions openmmtools/multistate/multistatesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,12 @@ def create(self, thermodynamic_states: list, sampler_states, storage,
raise RuntimeError('Storage file {} already exists; cowardly '
'refusing to overwrite.'.format(self._reporter.filepath))

# Make sure the online analysis interval is a multiple of the reporter's checkpoint interval;
# this avoids having redundant iteration information in the real-time YAML files.
if self.online_analysis_interval % self._reporter.checkpoint_interval != 0:
raise ValueError(f"Online analysis interval: {self.online_analysis_interval}, must be a "
f"multiple of the checkpoint interval: {self._reporter.checkpoint_interval}")

# Make sure sampler_states is an iterable of SamplerStates.
if isinstance(sampler_states, states.SamplerState):
sampler_states = [sampler_states]
Expand Down
2 changes: 1 addition & 1 deletion openmmtools/tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def test_mcmc_move_context_cache_shallow_copy():
)
# Create temporary reporter storage file
with tempfile.NamedTemporaryFile() as storage:
reporter = multistate.MultiStateReporter(storage.name, checkpoint_interval=999999)
reporter = multistate.MultiStateReporter(storage.name, checkpoint_interval=200)
simulation.create(
thermodynamic_states=thermodynamic_states,
sampler_states=SamplerState(
Expand Down
15 changes: 10 additions & 5 deletions openmmtools/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def teardown_class(cls):
def run(self, include_unsampled_states=False):
# Create and configure simulation object
move = mmtools.mcmc.MCDisplacementMove(displacement_sigma=1.0*unit.angstroms)
simulation = self.SAMPLER(mcmc_moves=move, number_of_iterations=self.N_ITERATIONS)
simulation = self.SAMPLER(mcmc_moves=move, number_of_iterations=self.N_ITERATIONS,
online_analysis_interval=self.N_ITERATIONS)

# Define file for temporary storage.
with temporary_directory() as tmp_dir:
Expand Down Expand Up @@ -587,6 +588,7 @@ class TestBaseMultistateSampler(object):

N_SAMPLERS = 3
N_STATES = 5
# TODO: Once we migrate to pytest, SAMPLER and REPORTER should be fixtures!
SAMPLER = MultiStateSampler
REPORTER = MultiStateReporter

Expand Down Expand Up @@ -999,7 +1001,7 @@ def actual_stored_properties_check(self, additional_properties=None):
thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test)

with self.temporary_storage_path() as storage_path:
sampler = self.SAMPLER(number_of_iterations=5)
sampler = self.SAMPLER(number_of_iterations=5, online_analysis_interval=1)
reporter = self.REPORTER(storage_path, checkpoint_interval=1)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
Expand Down Expand Up @@ -1451,7 +1453,8 @@ def test_online_analysis_works(self):
sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations,
online_analysis_interval=online_interval,
online_analysis_minimum_iterations=3)
self.call_sampler_create(sampler, storage_path,
reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
unsampled_states)
# Run
Expand Down Expand Up @@ -1510,7 +1513,8 @@ def test_online_analysis_stops(self):
online_analysis_interval=online_interval,
online_analysis_minimum_iterations=0,
online_analysis_target_error=np.inf) # use infinite error to stop right away
self.call_sampler_create(sampler, storage_path,
reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
unsampled_states)
# Run
Expand Down Expand Up @@ -1570,7 +1574,8 @@ def test_real_time_analysis_yaml(self):
move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1)
sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations,
online_analysis_interval=online_interval)
self.call_sampler_create(sampler, storage_path,
reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
unsampled_states)
# Run
Expand Down