import collections
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union

from meld import interfaces
from meld.system import restraints
from meld.system import param_sampling
from meld.system import mapping
class RestraintTracker:
    """
    A data structure to keep track of restraints, groups, and collections.

    For restraints, we keep track of the dependence on scalers, ramps,
    positioners, and peak mappings. A restraint is only flagged for update
    when one of those dependencies changes value.

    For groups and collections, we keep track of which ones depend on
    parameter sampling. Only those ones are updated each step.

    Restraints are identified by ``(category, index)`` tuples. ``category``
    is one of ``"distance"``, ``"hyperbolic_distance"``, ``"torsion"``,
    ``"dist_profile"``, ``"tors_profile"`` or ``"gmm"``, and ``index`` is the
    restraint's position in the corresponding list attribute, i.e. the order
    in which restraints of that category were added.
    """

    param_manager: param_sampling.ParameterManager
    peak_mapper: mapping.PeakMapManager
    distance_restraints: List[restraints.DistanceRestraint]
    hyperbolic_distance_restraints: List[restraints.HyperbolicDistanceRestraint]
    torsion_restraints: List[restraints.TorsionRestraint]
    dist_prof_restraints: List[restraints.DistProfileRestraint]
    torsion_profile_restraints: List[restraints.TorsProfileRestraint]
    gmm_restraints: List[restraints.GMMDistanceRestraint]
    groups_with_dep: List[Tuple[restraints.RestraintGroup, int]]
    collections_with_dep: List[Tuple[restraints.SelectivelyActiveCollection, int]]
    scaler_map: DefaultDict[restraints.RestraintScaler, List[Tuple[str, int]]]
    ramp_map: DefaultDict[restraints.TimeRamp, List[Tuple[str, int]]]
    positioner_map: DefaultDict[restraints.Positioner, List[Tuple[str, int]]]
    peak_mapping_map: DefaultDict[mapping.PeakMapping, List[Tuple[str, int]]]
    scaler_values: Dict[restraints.RestraintScaler, float]
    ramp_values: Dict[restraints.TimeRamp, float]
    positioner_values: Dict[restraints.Positioner, float]
    peak_mapping_values: Dict[mapping.PeakMapping, int]
    need_update: Set[Tuple[str, int]]

    def __init__(
        self,
        param_manager: param_sampling.ParameterManager,
        peak_mapper: mapping.PeakMapManager,
    ):
        """
        Create an empty tracker.

        Args:
            param_manager: the parameter manager, stored as ``self.param_manager``
            peak_mapper: used to look up the current value of each tracked
                peak mapping from ``state.mappings``
        """
        self.param_manager = param_manager
        self.peak_mapper = peak_mapper

        # These hold lists of meld restraints in the order that they were added
        # to the system; a restraint's position is its ``index``.
        self.distance_restraints = []
        self.hyperbolic_distance_restraints = []
        self.torsion_restraints = []
        self.dist_prof_restraints = []
        self.torsion_profile_restraints = []
        self.gmm_restraints = []
        # NOTE(review): no method in this class populates these two lists;
        # presumably they are filled by the caller — confirm.
        self.groups_with_dep = []
        self.collections_with_dep = []

        # These map from scalers, ramps, etc. to the (category, index) pairs
        # of the restraints that depend on them.
        self.scaler_map = collections.defaultdict(list)
        self.ramp_map = collections.defaultdict(list)
        self.positioner_map = collections.defaultdict(list)
        self.peak_mapping_map = collections.defaultdict(list)

        # These maintain the previous values for these quantities.
        self.scaler_values = {}
        self.ramp_values = {}
        self.positioner_values = {}
        self.peak_mapping_values = {}

        # We maintain a set of restraints that need to be updated.
        self.need_update = set()

    def update(self, alpha: float, timestep: int, state: interfaces.IState):
        """
        Re-evaluate every tracked dependency and flag the affected restraints.

        Each restraint whose scaler, ramp, positioner, or peak mapping changed
        value since it was last evaluated is added to ``need_update``.

        Args:
            alpha: value passed to scalers and positioners
            timestep: value passed to ramps
            state: state whose ``mappings`` are used to evaluate peak mappings
        """
        self._update_scalers(alpha)
        self._update_ramps(timestep)
        self._update_positioners(alpha)
        self._update_peak_mappings(state)

    def get_and_reset_need_update(self) -> Set[Tuple[str, int]]:
        """
        Return the ``(category, index)`` pairs flagged for update and clear them.

        The returned set is handed to the caller; the tracker starts over with a
        fresh, empty set.
        """
        need_update = self.need_update
        self.need_update = set()
        return need_update

    def _refresh_dependencies(
        self,
        values: Dict[Any, Any],
        dependents: DefaultDict[Any, List[Tuple[str, int]]],
        compute: Callable[[Any], Any],
    ) -> None:
        """
        Recompute every tracked quantity and flag dependents of those that changed.

        Args:
            values: previous value of each tracked quantity; updated in place
            dependents: maps each tracked quantity to the ``(category, index)``
                pairs of the restraints that use it
            compute: returns the current value of a tracked quantity
        """
        for key, old_value in values.items():
            new_value = compute(key)
            if new_value != old_value:
                self.need_update.update(dependents[key])
                # Overwriting an existing key leaves the dict's size and keys
                # unchanged, so this is safe during iteration.
                values[key] = new_value

    def _update_scalers(self, alpha: float):
        """Flag restraints whose scaler value changed at ``alpha``."""
        self._refresh_dependencies(
            self.scaler_values, self.scaler_map, lambda scaler: scaler(alpha)
        )

    def _update_ramps(self, timestep: int):
        """Flag restraints whose ramp value changed at ``timestep``."""
        self._refresh_dependencies(
            self.ramp_values, self.ramp_map, lambda ramp: ramp(timestep)
        )

    def _update_positioners(self, alpha: float):
        """Flag restraints whose positioner value changed at ``alpha``."""
        self._refresh_dependencies(
            self.positioner_values,
            self.positioner_map,
            lambda positioner: positioner(alpha),
        )

    def _update_peak_mappings(self, state: interfaces.IState):
        """Flag restraints whose peak mapping resolves differently in ``state``."""
        self._refresh_dependencies(
            self.peak_mapping_values,
            self.peak_mapping_map,
            lambda peak_mapping: self._peak_mapping_value(peak_mapping, state),
        )

    def _peak_mapping_value(
        self, peak_mapping: mapping.PeakMapping, state: interfaces.IState
    ) -> int:
        """
        Return the current value of ``peak_mapping`` in ``state``.

        A ``mapping.NotMapped`` result is normalized to ``-1``. Registering a
        dependency and updating it both go through this helper, so stored values
        and freshly computed values are always directly comparable.
        """
        value = self.peak_mapper.extract_value(peak_mapping, state.mappings)
        if isinstance(value, mapping.NotMapped):
            return -1
        return value

    def _register_restraint(
        self,
        rest: Union[
            restraints.DistanceRestraint,
            restraints.HyperbolicDistanceRestraint,
            restraints.TorsionRestraint,
            restraints.DistProfileRestraint,
            restraints.TorsProfileRestraint,
            restraints.GMMDistanceRestraint,
        ],
        store: List[Any],
        category: str,
        alpha: float,
        timestep: int,
    ) -> int:
        """
        Store a restraint, flag it for update, and track its scaler and ramp.

        Args:
            rest: the restraint to track
            store: the list attribute holding restraints of this category
            category: category label used in ``(category, index)`` pairs
            alpha: current alpha, used to record the scaler's value
            timestep: current timestep, used to record the ramp's value

        Returns:
            the index of ``rest`` within ``store``
        """
        store.append(rest)
        index = len(store) - 1
        # A newly added restraint always needs its initial parameters set.
        self.need_update.add((category, index))
        self._add_scaler_dependency(rest.scaler, category, index, alpha)
        self._add_ramp_dependency(rest.ramp, category, index, timestep)
        return index

    def add_distance_restraint(
        self,
        rest: restraints.DistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Track a distance restraint.

        Besides the scaler and ramp, this also tracks the ``r1``-``r4``
        positioners and any peak-mapped atom indices of the restraint.
        """
        assert isinstance(rest, restraints.DistanceRestraint)
        category = "distance"
        index = self._register_restraint(
            rest, self.distance_restraints, category, alpha, timestep
        )
        for positioner in (rest.r1, rest.r2, rest.r3, rest.r4):
            self._add_positioner_dependency(positioner, category, index, alpha)
        for atom_index in (rest.atom_index_1, rest.atom_index_2):
            self._add_peak_mapping_dependency(atom_index, category, index, state)

    def add_hyperbolic_distance_restraint(
        self,
        rest: restraints.HyperbolicDistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Track a hyperbolic distance restraint's scaler and ramp.

        ``state`` is unused; it keeps the signature uniform with the other
        ``add_*`` methods.
        """
        assert isinstance(rest, restraints.HyperbolicDistanceRestraint)
        self._register_restraint(
            rest,
            self.hyperbolic_distance_restraints,
            "hyperbolic_distance",
            alpha,
            timestep,
        )

    def add_torsion_restraint(
        self,
        rest: restraints.TorsionRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Track a torsion restraint's scaler and ramp.

        ``state`` is unused; it keeps the signature uniform with the other
        ``add_*`` methods.
        """
        assert isinstance(rest, restraints.TorsionRestraint)
        self._register_restraint(
            rest, self.torsion_restraints, "torsion", alpha, timestep
        )

    def add_distance_profile_restraint(
        self,
        rest: restraints.DistProfileRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Track a distance-profile restraint's scaler and ramp.

        ``state`` is unused; it keeps the signature uniform with the other
        ``add_*`` methods.
        """
        assert isinstance(rest, restraints.DistProfileRestraint)
        self._register_restraint(
            rest, self.dist_prof_restraints, "dist_profile", alpha, timestep
        )

    def add_torsion_profile_restraint(
        self,
        rest: restraints.TorsProfileRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Track a torsion-profile restraint's scaler and ramp.

        ``state`` is unused; it keeps the signature uniform with the other
        ``add_*`` methods.
        """
        assert isinstance(rest, restraints.TorsProfileRestraint)
        self._register_restraint(
            rest, self.torsion_profile_restraints, "tors_profile", alpha, timestep
        )

    def add_gmm_distance_restraint(
        self,
        rest: restraints.GMMDistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Track a GMM distance restraint's scaler and ramp.

        ``state`` is unused; it keeps the signature uniform with the other
        ``add_*`` methods.
        """
        assert isinstance(rest, restraints.GMMDistanceRestraint)
        self._register_restraint(rest, self.gmm_restraints, "gmm", alpha, timestep)

    def _register_dependency(
        self,
        key: Any,
        dependents: DefaultDict[Any, List[Tuple[str, int]]],
        values: Dict[Any, Any],
        category: str,
        index: int,
        compute: Callable[[Any], Any],
    ) -> None:
        """
        Record that restraint ``(category, index)`` depends on ``key``.

        The first registration of ``key`` stores its current value. Later
        registrations only sanity-check that the value has not changed.
        """
        dependents[key].append((category, index))
        if key not in values:
            values[key] = compute(key)
        else:
            # Assumes all restraints are registered at the same
            # alpha/timestep/state, so a shared dependency must evaluate to
            # the value recorded when it was first registered.
            assert compute(key) == values[key]

    def _add_scaler_dependency(
        self,
        scaler: restraints.RestraintScaler,
        category: str,
        index: int,
        alpha: float,
    ):
        """Track ``scaler`` for restraint ``(category, index)`` unless it is constant."""
        if isinstance(scaler, restraints.ConstantScaler):
            return
        self._register_dependency(
            scaler,
            self.scaler_map,
            self.scaler_values,
            category,
            index,
            lambda s: s(alpha),
        )

    def _add_ramp_dependency(
        self, ramp: restraints.TimeRamp, category: str, index: int, timestep: int
    ):
        """Track ``ramp`` for restraint ``(category, index)`` unless it is constant."""
        if isinstance(ramp, restraints.ConstantRamp):
            return
        self._register_dependency(
            ramp,
            self.ramp_map,
            self.ramp_values,
            category,
            index,
            lambda r: r(timestep),
        )

    def _add_positioner_dependency(
        self, positioner: restraints.Positioner, category: str, index: int, alpha: float
    ):
        """Track ``positioner`` for restraint ``(category, index)`` unless it is constant."""
        if isinstance(positioner, restraints.ConstantPositioner):
            return
        self._register_dependency(
            positioner,
            self.positioner_map,
            self.positioner_values,
            category,
            index,
            lambda p: p(alpha),
        )

    def _add_peak_mapping_dependency(
        self,
        peak_mapping: Union[int, mapping.PeakMapping],
        category: str,
        index: int,
        state: interfaces.IState,
    ):
        """
        Track ``peak_mapping`` for restraint ``(category, index)``.

        A plain ``int`` is a fixed index with nothing to track. A
        ``PeakMapping`` is evaluated through ``_peak_mapping_value``, so an
        unmapped peak is stored and compared as ``-1`` consistently.
        """
        if isinstance(peak_mapping, int):
            return
        self._register_dependency(
            peak_mapping,
            self.peak_mapping_map,
            self.peak_mapping_values,
            category,
            index,
            lambda pm: self._peak_mapping_value(pm, state),
        )