Source code for meld.runner.transform.restraints.meld.tracker

import collections
import logging
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, Union

import numpy as np  # typing: ignore

from meld import interfaces
from meld.system import density, mapping, param_sampling, restraints, scalers

logger = logging.getLogger(__name__)


class RestraintTracker:
    """
    A data structure to keep track of restraints, groups, and collections.

    For restraints, we keep track of the dependence on scalers, ramps,
    positioners, and peak mappings. We only update a restraint when those
    dependencies have changed.

    For groups and collections, we keep track of which ones depend on
    parameter sampling. Only those ones are updated each step.

    Restraints are identified throughout by a ``(category, index)`` tuple,
    where ``category`` is a short string such as ``"distance"`` or ``"rdc"``
    and ``index`` is the position of the restraint in the matching list.
    """

    param_manager: param_sampling.ParameterManager
    peak_mapper: mapping.PeakMapManager

    # Restraints, stored per category in the order they were added.
    rdc_restraints: List[restraints.RdcRestraint]
    distance_restraints: List[restraints.DistanceRestraint]
    hyperbolic_distance_restraints: List[restraints.HyperbolicDistanceRestraint]
    torsion_restraints: List[restraints.TorsionRestraint]
    dist_prof_restraints: List[restraints.DistProfileRestraint]
    torsion_profile_restraints: List[restraints.TorsProfileRestraint]
    gmm_restraints: List[restraints.GMMDistanceRestraint]

    groups_with_dep: List[Tuple[restraints.RestraintGroup, int]]
    collections_with_dep: List[Tuple[restraints.SelectivelyActiveCollection, int]]

    # Dependency -> list of (category, index) restraints that depend on it.
    scaler_map: DefaultDict[restraints.RestraintScaler, List[Tuple[str, int]]]
    ramp_map: DefaultDict[restraints.TimeRamp, List[Tuple[str, int]]]
    positioner_map: DefaultDict[restraints.Positioner, List[Tuple[str, int]]]
    peak_mapping_map: DefaultDict[int, List[Tuple[str, int]]]

    # Most recently observed value of each dependency.
    scaler_values: Dict[restraints.RestraintScaler, float]
    ramp_values: Dict[restraints.TimeRamp, float]
    positioner_values: Dict[restraints.Positioner, float]
    peak_mapping_values: Optional[np.ndarray]

    need_update: Set[Tuple[str, int]]

    densities: List[density.DensityMap]
    density_restraints: List[restraints.DensityRestraint]
    scaler_density_map: DefaultDict[
        restraints.BlurScaler, List[Tuple[int, density.DensityMap]]
    ]
    scaler_density_values: Dict[restraints.BlurScaler, float]

    def __init__(
        self,
        param_manager: param_sampling.ParameterManager,
        peak_mapper: mapping.PeakMapManager,
    ):
        """
        Create an empty tracker.

        Args:
            param_manager: parameter manager, stored on the tracker
            peak_mapper: peak map manager used to resolve a peak mapping
                to its global peak index
        """
        self.param_manager = param_manager
        self.peak_mapper = peak_mapper

        # These hold lists of meld restraints in the order that they were added
        # to the system.
        self.distance_restraints = []
        self.rdc_restraints = []
        self.hyperbolic_distance_restraints = []
        self.torsion_restraints = []
        self.dist_prof_restraints = []
        self.torsion_profile_restraints = []
        self.gmm_restraints = []
        self.densities = []
        self.density_restraints = []
        self.groups_with_dep = []
        self.collections_with_dep = []

        # These map from scalers, ramps, etc to the restraints that depend on them.
        self.scaler_map = collections.defaultdict(list)
        self.ramp_map = collections.defaultdict(list)
        self.positioner_map = collections.defaultdict(list)
        self.peak_mapping_map = collections.defaultdict(list)
        self.scaler_density_map = collections.defaultdict(list)

        # These maintain the previous values for these quantities
        self.scaler_values = {}
        self.ramp_values = {}
        self.positioner_values = {}
        self.peak_mapping_values = None
        self.scaler_density_values = {}

        # We maintain a set of restraints that need to be updated.
        self.need_update = set()

    def update(self, alpha: float, timestep: int, state: interfaces.IState):
        """
        Re-evaluate every tracked dependency and flag affected restraints.

        Restraints whose scaler, ramp, positioner, or peak mapping changed
        since the previous evaluation are added to ``need_update``.

        Args:
            alpha: value passed to the scalers and positioners
            timestep: value passed to the ramps
            state: state whose ``mappings`` are compared to the stored ones
        """
        self._update_scalers(alpha)
        self._update_ramps(timestep)
        self._update_positioners(alpha)
        self._update_peak_mappings(state)

    def get_and_reset_need_update(self) -> Set[Tuple[str, int]]:
        """
        Return the pending ``(category, index)`` set and start a fresh one.

        Returns:
            the set of restraints flagged since the last call
        """
        need_update = self.need_update
        self.need_update = set()
        return need_update

    def density_to_update(self, alpha: float):
        """
        Find the density maps whose blur scaler value has changed.

        A refresh is also forced whenever ``alpha == 0.0``, regardless of
        whether the scaler value changed.

        Args:
            alpha: value passed to the blur scalers

        Returns:
            list of ``(index, density)`` pairs that need to be updated
        """
        to_update = []
        for scaler in self.scaler_density_values:
            old_value = self.scaler_density_values[scaler]
            new_value = scaler(alpha)
            if new_value != old_value or alpha == 0.0:
                to_update.extend(self.scaler_density_map[scaler])
                self.scaler_density_values[scaler] = new_value
        return to_update

    def _update_scalers(self, alpha: float):
        """Flag restraints whose scaler evaluates differently at ``alpha``."""
        for scaler in self.scaler_values:
            old_value = self.scaler_values[scaler]
            new_value = scaler(alpha)
            if new_value != old_value:
                self.need_update.update(self.scaler_map[scaler])
                self.scaler_values[scaler] = new_value

    def _update_ramps(self, timestep: int):
        """Flag restraints whose ramp evaluates differently at ``timestep``."""
        for ramp in self.ramp_values:
            old_value = self.ramp_values[ramp]
            new_value = ramp(timestep)
            if new_value != old_value:
                self.need_update.update(self.ramp_map[ramp])
                self.ramp_values[ramp] = new_value

    def _update_positioners(self, alpha: float):
        """Flag restraints whose positioner evaluates differently at ``alpha``."""
        for positioner in self.positioner_values:
            old_value = self.positioner_values[positioner]
            new_value = positioner(alpha)
            if new_value != old_value:
                self.need_update.update(self.positioner_map[positioner])
                self.positioner_values[positioner] = new_value

    def _update_peak_mappings(self, state: interfaces.IState):
        """
        Flag restraints whose peak mapping differs from the stored mapping.

        Args:
            state: state providing the current ``mappings`` array
        """
        # No restraint registered a peak-mapping dependency, so there is
        # nothing to compare against. Without this guard the comparison
        # against None reports every index as changed and each lookup
        # inserts an empty entry into the defaultdict.
        if self.peak_mapping_values is None:
            return

        changes = np.argwhere(self.peak_mapping_values != state.mappings)
        for change in changes:
            # argwhere yields one row per differing element; the first
            # column is the global peak index.
            global_peak_index = int(change[0])
            # .get avoids creating entries for peaks nothing depends on.
            dependents = self.peak_mapping_map.get(global_peak_index, ())
            self.need_update.update(dependents)
            self.peak_mapping_values[global_peak_index] = state.mappings[
                global_peak_index
            ]

    def add_rdc_restraint(
        self,
        rest: restraints.RdcRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Register an RDC restraint under the ``"rdc"`` category.

        The restraint is flagged for update and its scaler, ramp, and both
        atom-index peak mappings are recorded as dependencies.
        """
        assert isinstance(rest, restraints.RdcRestraint)
        self.rdc_restraints.append(rest)
        index = len(self.rdc_restraints) - 1
        self.need_update.add(("rdc", index))
        self._add_scaler_dependency(rest.scaler, "rdc", index, alpha)
        self._add_ramp_dependency(rest.ramp, "rdc", index, timestep)
        self._add_peak_mapping_dependency(rest.atom_index_1, "rdc", index, state)
        self._add_peak_mapping_dependency(rest.atom_index_2, "rdc", index, state)

    def add_density(self, index: int, density: density.DensityMap, alpha: float):
        """
        Register a density map and its blur-scaler dependency.

        Args:
            index: caller-supplied index stored alongside the density map
            density: the density map to track
            alpha: value at which the blur scaler is first evaluated
        """
        self.densities.append(density)
        self._add_scaler_density_dependency(index, density, alpha)

    def add_density_restraint(
        self,
        rest: restraints.DensityRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Register a density restraint under the ``"density"`` category.

        ``state`` is accepted for signature parity with the other ``add_*``
        methods and is not used here.
        """
        assert isinstance(rest, restraints.DensityRestraint)
        self.density_restraints.append(rest)
        index = len(self.density_restraints) - 1
        self.need_update.add(("density", index))
        self._add_scaler_dependency(rest.scaler, "density", index, alpha)
        self._add_ramp_dependency(rest.ramp, "density", index, timestep)

    def add_distance_restraint(
        self,
        rest: restraints.DistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Register a distance restraint under the ``"distance"`` category.

        Records the scaler, ramp, the four positioners ``r1``..``r4``, and
        both atom-index peak mappings as dependencies.
        """
        assert isinstance(rest, restraints.DistanceRestraint)
        self.distance_restraints.append(rest)
        index = len(self.distance_restraints) - 1
        self.need_update.add(("distance", index))
        self._add_scaler_dependency(rest.scaler, "distance", index, alpha)
        self._add_ramp_dependency(rest.ramp, "distance", index, timestep)
        for positioner in (rest.r1, rest.r2, rest.r3, rest.r4):
            self._add_positioner_dependency(positioner, "distance", index, alpha)
        for atom_index in (rest.atom_index_1, rest.atom_index_2):
            self._add_peak_mapping_dependency(atom_index, "distance", index, state)

    def add_hyperbolic_distance_restraint(
        self,
        rest: restraints.HyperbolicDistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Register a restraint under the ``"hyperbolic_distance"`` category.

        ``state`` is accepted for signature parity and is not used here.
        """
        assert isinstance(rest, restraints.HyperbolicDistanceRestraint)
        self.hyperbolic_distance_restraints.append(rest)
        index = len(self.hyperbolic_distance_restraints) - 1
        self.need_update.add(("hyperbolic_distance", index))
        self._add_scaler_dependency(rest.scaler, "hyperbolic_distance", index, alpha)
        self._add_ramp_dependency(rest.ramp, "hyperbolic_distance", index, timestep)

    def add_torsion_restraint(
        self,
        rest: restraints.TorsionRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Register a torsion restraint under the ``"torsion"`` category.

        ``state`` is accepted for signature parity and is not used here.
        """
        assert isinstance(rest, restraints.TorsionRestraint)
        self.torsion_restraints.append(rest)
        index = len(self.torsion_restraints) - 1
        self.need_update.add(("torsion", index))
        self._add_scaler_dependency(rest.scaler, "torsion", index, alpha)
        self._add_ramp_dependency(rest.ramp, "torsion", index, timestep)

    def add_distance_profile_restraint(
        self,
        rest: restraints.DistProfileRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Register a restraint under the ``"dist_profile"`` category.

        ``state`` is accepted for signature parity and is not used here.
        """
        assert isinstance(rest, restraints.DistProfileRestraint)
        self.dist_prof_restraints.append(rest)
        index = len(self.dist_prof_restraints) - 1
        self.need_update.add(("dist_profile", index))
        self._add_scaler_dependency(rest.scaler, "dist_profile", index, alpha)
        self._add_ramp_dependency(rest.ramp, "dist_profile", index, timestep)

    def add_torsion_profile_restraint(
        self,
        rest: restraints.TorsProfileRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Register a restraint under the ``"tors_profile"`` category.

        ``state`` is accepted for signature parity and is not used here.
        """
        assert isinstance(rest, restraints.TorsProfileRestraint)
        self.torsion_profile_restraints.append(rest)
        index = len(self.torsion_profile_restraints) - 1
        self.need_update.add(("tors_profile", index))
        self._add_scaler_dependency(rest.scaler, "tors_profile", index, alpha)
        self._add_ramp_dependency(rest.ramp, "tors_profile", index, timestep)

    def add_gmm_distance_restraint(
        self,
        rest: restraints.GMMDistanceRestraint,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ):
        """
        Register a GMM distance restraint under the ``"gmm"`` category.

        ``state`` is accepted for signature parity and is not used here.
        """
        assert isinstance(rest, restraints.GMMDistanceRestraint)
        self.gmm_restraints.append(rest)
        index = len(self.gmm_restraints) - 1
        self.need_update.add(("gmm", index))
        self._add_scaler_dependency(rest.scaler, "gmm", index, alpha)
        self._add_ramp_dependency(rest.ramp, "gmm", index, timestep)

    def _add_scaler_dependency(
        self,
        scaler: restraints.RestraintScaler,
        category: str,
        index: int,
        alpha: float,
    ):
        """
        Record that restraint ``(category, index)`` depends on ``scaler``.

        Constant scalers are skipped. If the scaler is already tracked, its
        value at ``alpha`` must match the stored value.
        """
        if isinstance(scaler, restraints.ConstantScaler):
            return
        self.scaler_map[scaler].append((category, index))
        if scaler not in self.scaler_values:
            self.scaler_values[scaler] = scaler(alpha)
        else:
            assert scaler(alpha) == self.scaler_values[scaler]

    def _add_ramp_dependency(
        self, ramp: restraints.TimeRamp, category: str, index: int, timestep: int
    ):
        """
        Record that restraint ``(category, index)`` depends on ``ramp``.

        Constant ramps are skipped. If the ramp is already tracked, its
        value at ``timestep`` must match the stored value.
        """
        if isinstance(ramp, restraints.ConstantRamp):
            return
        self.ramp_map[ramp].append((category, index))
        if ramp not in self.ramp_values:
            self.ramp_values[ramp] = ramp(timestep)
        else:
            assert ramp(timestep) == self.ramp_values[ramp]

    def _add_positioner_dependency(
        self, positioner: restraints.Positioner, category: str, index: int, alpha: float
    ):
        """
        Record that restraint ``(category, index)`` depends on ``positioner``.

        Constant positioners are skipped. If the positioner is already
        tracked, its value at ``alpha`` must match the stored value.
        """
        if isinstance(positioner, restraints.ConstantPositioner):
            return
        self.positioner_map[positioner].append((category, index))
        if positioner not in self.positioner_values:
            self.positioner_values[positioner] = positioner(alpha)
        else:
            assert positioner(alpha) == self.positioner_values[positioner]

    def _add_peak_mapping_dependency(
        self,
        peak_mapping: Union[int, mapping.PeakMapping],
        category: str,
        index: int,
        state: interfaces.IState,
    ):
        """
        Record that restraint ``(category, index)`` depends on a peak mapping.

        Plain integer atom indices are ignored. The first registered peak
        mapping snapshots ``state.mappings``; later registrations must see
        the same mappings.
        """
        if not isinstance(peak_mapping, mapping.PeakMapping):
            return
        global_peak_index = self.peak_mapper.get_index(peak_mapping)
        self.peak_mapping_map[global_peak_index].append((category, index))
        if self.peak_mapping_values is None:
            # Copy so later changes to the state do not alter our snapshot.
            self.peak_mapping_values = state.mappings.copy()
        else:
            assert (state.mappings == self.peak_mapping_values).all()

    def _add_scaler_density_dependency(
        self, index: int, density: density.DensityMap, alpha: float
    ):
        """
        Record that a density map depends on its blur scaler.

        Constant blur scalers are skipped. If the blur scaler is already
        tracked, its value at ``alpha`` must match the stored value.
        """
        blur_scaler = density.blur_scaler
        if isinstance(blur_scaler, scalers.ConstantBlurScaler):
            return
        self.scaler_density_map[blur_scaler].append((index, density))
        if blur_scaler not in self.scaler_density_values:
            self.scaler_density_values[blur_scaler] = blur_scaler(alpha)
        else:
            # Blur scalers are cached in scaler_density_values, not in
            # scaler_values; looking them up in the latter raises KeyError.
            assert blur_scaler(alpha) == self.scaler_density_values[blur_scaler]