# Source code for meld.runner.transform.restraints.meld.tracker

from meld import interfaces
from meld.system import restraints
from meld.system import param_sampling
from meld.system import mapping
import collections
from typing import List, Tuple, Optional, Union, Set, Dict, DefaultDict


class RestraintTracker:
    """
    A data structure to keep track of restraints, groups, and collections.

    For restraints, we keep track of the dependence on scalers, ramps,
    positioners, and peak mappings. We only update a restraint when those
    dependencies have changed.

    For groups and collections, we keep track of which ones depend on
    parameter sampling. Only those ones are updated each step.

    Restraints are identified by ``(category, index)`` pairs, where
    ``category`` is one of ``"distance"``, ``"hyperbolic_distance"``,
    ``"torsion"``, ``"dist_profile"``, ``"tors_profile"`` or ``"gmm"`` and
    ``index`` is the restraint's position in the matching list attribute.
    """

    param_manager: param_sampling.ParameterManager
    peak_mapper: mapping.PeakMapManager
    distance_restraints: List[restraints.DistanceRestraint]
    hyperbolic_distance_restraints: List[restraints.HyperbolicDistanceRestraint]
    torsion_restraints: List[restraints.TorsionRestraint]
    dist_prof_restraints: List[restraints.DistProfileRestraint]
    torsion_profile_restraints: List[restraints.TorsProfileRestraint]
    gmm_restraints: List[restraints.GMMDistanceRestraint]
    groups_with_dep: List[Tuple[restraints.RestraintGroup, int]]
    collections_with_dep: List[Tuple[restraints.SelectivelyActiveCollection, int]]
    scaler_map: DefaultDict[restraints.RestraintScaler, List[Tuple[str, int]]]
    ramp_map: DefaultDict[restraints.TimeRamp, List[Tuple[str, int]]]
    positioner_map: DefaultDict[restraints.Positioner, List[Tuple[str, int]]]
    peak_mapping_map: DefaultDict[mapping.PeakMapping, List[Tuple[str, int]]]
    scaler_values: Dict[restraints.RestraintScaler, float]
    ramp_values: Dict[restraints.TimeRamp, float]
    positioner_values: Dict[restraints.Positioner, float]
    peak_mapping_values: Dict[mapping.PeakMapping, int]
    need_update: Set[Tuple[str, int]]

    def __init__(
        self,
        param_manager: param_sampling.ParameterManager,
        peak_mapper: mapping.PeakMapManager,
    ):
        """
        Create an empty tracker.

        Args:
            param_manager: parameter manager, stored for use by callers.
            peak_mapper: used to evaluate peak mappings against a state.
        """
        self.param_manager = param_manager
        self.peak_mapper = peak_mapper

        # These hold lists of meld restraints in the order that they were added
        # to the system.
        self.distance_restraints = []
        self.hyperbolic_distance_restraints = []
        self.torsion_restraints = []
        self.dist_prof_restraints = []
        self.torsion_profile_restraints = []
        self.gmm_restraints = []
        # NOTE(review): nothing in this class appends to these two lists;
        # presumably they are populated by callers — confirm.
        self.groups_with_dep = []
        self.collections_with_dep = []

        # These map from scalers, ramps, etc to the restraints that depend on them.
        self.scaler_map = collections.defaultdict(list)
        self.ramp_map = collections.defaultdict(list)
        self.positioner_map = collections.defaultdict(list)
        self.peak_mapping_map = collections.defaultdict(list)

        # These maintain the previous values for these quantities
        self.scaler_values = {}
        self.ramp_values = {}
        self.positioner_values = {}
        self.peak_mapping_values = {}

        # We maintain a set of restraints that need to be updated.
        self.need_update = set()

    def update(self, alpha: float, timestep: int, state: interfaces.IState):
        """
        Re-evaluate every tracked dependency and flag affected restraints.

        Scalers and positioners are evaluated at ``alpha``, ramps at
        ``timestep``, and peak mappings against ``state.mappings``. Every
        restraint that depends on a value that changed since the previous
        evaluation is added to ``need_update``.
        """
        self._update_scalers(alpha)
        self._update_ramps(timestep)
        self._update_positioners(alpha)
        self._update_peak_mappings(state)

    def get_and_reset_need_update(self) -> Set[Tuple[str, int]]:
        """Return the flagged ``(category, index)`` pairs and start a new empty set."""
        need_update = self.need_update
        self.need_update = set()
        return need_update

    def _update_scalers(self, alpha: float):
        """Re-evaluate all tracked scalers at ``alpha``."""
        for scaler in self.scaler_values:
            self._store_value(
                self.scaler_values, self.scaler_map, scaler, scaler(alpha)
            )

    def _update_ramps(self, timestep: int):
        """Re-evaluate all tracked ramps at ``timestep``."""
        for ramp in self.ramp_values:
            self._store_value(self.ramp_values, self.ramp_map, ramp, ramp(timestep))

    def _update_positioners(self, alpha: float):
        """Re-evaluate all tracked positioners at ``alpha``."""
        for positioner in self.positioner_values:
            self._store_value(
                self.positioner_values,
                self.positioner_map,
                positioner,
                positioner(alpha),
            )

    def _update_peak_mappings(self, state: interfaces.IState):
        """Re-evaluate all tracked peak mappings against ``state``."""
        for peak_mapping in self.peak_mapping_values:
            self._store_value(
                self.peak_mapping_values,
                self.peak_mapping_map,
                peak_mapping,
                self._extract_peak_value(peak_mapping, state),
            )

    def _store_value(
        self, values: Dict, dependents: DefaultDict, key: object, new_value: object
    ) -> None:
        """
        Record ``new_value`` for ``key``, flagging its dependents if it changed.

        ``key`` must already be present in ``values``. Only the value of an
        existing key is replaced, so this is safe to call while iterating
        over ``values``.
        """
        if new_value != values[key]:
            self.need_update.update(dependents[key])
        values[key] = new_value

    def _extract_peak_value(
        self, peak_mapping: mapping.PeakMapping, state: interfaces.IState
    ) -> int:
        """Evaluate ``peak_mapping`` against ``state``, using -1 for NotMapped."""
        value = self.peak_mapper.extract_value(peak_mapping, state.mappings)
        if isinstance(value, mapping.NotMapped):
            return -1
        return value

    def _append_restraint(self, container: List, rest: object, category: str) -> int:
        """Append ``rest`` to ``container``, flag it for update, and return its index."""
        container.append(rest)
        index = len(container) - 1
        # A freshly added restraint always needs its initial parameters set.
        self.need_update.add((category, index))
        return index

    def add_distance_restraint(
        self,
        rest: restraints.DistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Track a distance restraint.

        Registers its scaler, its ramp, the four positioners ``r1``-``r4``,
        and both atom indices (which may be peak mappings).
        """
        assert isinstance(rest, restraints.DistanceRestraint)
        index = self._append_restraint(self.distance_restraints, rest, "distance")
        self._add_scaler_dependency(rest.scaler, "distance", index, alpha)
        self._add_ramp_dependency(rest.ramp, "distance", index, timestep)
        for positioner in (rest.r1, rest.r2, rest.r3, rest.r4):
            self._add_positioner_dependency(positioner, "distance", index, alpha)
        self._add_peak_mapping_dependency(rest.atom_index_1, "distance", index, state)
        self._add_peak_mapping_dependency(rest.atom_index_2, "distance", index, state)

    def add_hyperbolic_distance_restraint(
        self,
        rest: restraints.HyperbolicDistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """Track a hyperbolic distance restraint (scaler and ramp only; ``state`` unused)."""
        assert isinstance(rest, restraints.HyperbolicDistanceRestraint)
        index = self._append_restraint(
            self.hyperbolic_distance_restraints, rest, "hyperbolic_distance"
        )
        self._add_scaler_dependency(rest.scaler, "hyperbolic_distance", index, alpha)
        self._add_ramp_dependency(rest.ramp, "hyperbolic_distance", index, timestep)

    def add_torsion_restraint(
        self,
        rest: restraints.TorsionRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """Track a torsion restraint (scaler and ramp only; ``state`` unused)."""
        assert isinstance(rest, restraints.TorsionRestraint)
        index = self._append_restraint(self.torsion_restraints, rest, "torsion")
        self._add_scaler_dependency(rest.scaler, "torsion", index, alpha)
        self._add_ramp_dependency(rest.ramp, "torsion", index, timestep)

    def add_distance_profile_restraint(
        self,
        rest: restraints.DistProfileRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """Track a distance-profile restraint (scaler and ramp only; ``state`` unused)."""
        assert isinstance(rest, restraints.DistProfileRestraint)
        index = self._append_restraint(self.dist_prof_restraints, rest, "dist_profile")
        self._add_scaler_dependency(rest.scaler, "dist_profile", index, alpha)
        self._add_ramp_dependency(rest.ramp, "dist_profile", index, timestep)

    def add_torsion_profile_restraint(
        self,
        rest: restraints.TorsProfileRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """Track a torsion-profile restraint (scaler and ramp only; ``state`` unused)."""
        assert isinstance(rest, restraints.TorsProfileRestraint)
        index = self._append_restraint(
            self.torsion_profile_restraints, rest, "tors_profile"
        )
        self._add_scaler_dependency(rest.scaler, "tors_profile", index, alpha)
        self._add_ramp_dependency(rest.ramp, "tors_profile", index, timestep)

    def add_gmm_distance_restraint(
        self,
        rest: restraints.GMMDistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """Track a GMM distance restraint (scaler and ramp only; ``state`` unused)."""
        assert isinstance(rest, restraints.GMMDistanceRestraint)
        index = self._append_restraint(self.gmm_restraints, rest, "gmm")
        self._add_scaler_dependency(rest.scaler, "gmm", index, alpha)
        self._add_ramp_dependency(rest.ramp, "gmm", index, timestep)

    def _add_dependency(
        self,
        values: Dict,
        dependents: DefaultDict,
        key: object,
        category: str,
        index: int,
        current_value: object,
    ) -> None:
        """
        Register restraint ``(category, index)`` as depending on ``key``.

        The first registration caches ``current_value``. Later registrations
        assert that the freshly computed value matches the cached one, so that
        restraints added at the same alpha/timestep/state agree.
        """
        dependents[key].append((category, index))
        if key not in values:
            values[key] = current_value
        else:
            assert current_value == values[key], (
                f"cached value {values[key]!r} for {key!r} does not match "
                f"current value {current_value!r} when adding ({category}, {index})"
            )

    def _add_scaler_dependency(
        self,
        scaler: restraints.RestraintScaler,
        category: str,
        index: int,
        alpha: float,
    ):
        """Track ``scaler`` for a restraint unless it is a ConstantScaler."""
        if not isinstance(scaler, restraints.ConstantScaler):
            self._add_dependency(
                self.scaler_values, self.scaler_map, scaler, category, index, scaler(alpha)
            )

    def _add_ramp_dependency(
        self, ramp: restraints.TimeRamp, category: str, index: int, timestep: int
    ):
        """Track ``ramp`` for a restraint unless it is a ConstantRamp."""
        if not isinstance(ramp, restraints.ConstantRamp):
            self._add_dependency(
                self.ramp_values, self.ramp_map, ramp, category, index, ramp(timestep)
            )

    def _add_positioner_dependency(
        self, positioner: restraints.Positioner, category: str, index: int, alpha: float
    ):
        """Track ``positioner`` for a restraint unless it is a ConstantPositioner."""
        if not isinstance(positioner, restraints.ConstantPositioner):
            self._add_dependency(
                self.positioner_values,
                self.positioner_map,
                positioner,
                category,
                index,
                positioner(alpha),
            )

    def _add_peak_mapping_dependency(
        self,
        peak_mapping: Union[int, mapping.PeakMapping],
        category: str,
        index: int,
        state: interfaces.IState,
    ):
        """
        Track ``peak_mapping`` for a restraint unless it is a plain ``int`` index.

        The value is normalized (NotMapped -> -1) on every registration, not
        only the first, so the consistency check compares like with like.
        """
        if not isinstance(peak_mapping, int):
            self._add_dependency(
                self.peak_mapping_values,
                self.peak_mapping_map,
                peak_mapping,
                category,
                index,
                self._extract_peak_value(peak_mapping, state),
            )