Source code for meld.runner.transform.restraints.meld.meld

#
# All rights reserved
#

"""
This module implements transformers that add meld restraints
"""

import logging
from typing import List, Tuple, Union

import numpy as np  # type: ignore
import openmm as mm  # type: ignore
from meldplugin import MeldForce  # type: ignore
from openmm import app  # type: ignore

from meld import interfaces
from meld.runner import transform
from meld.runner.transform.restraints.meld.tracker import RestraintTracker
from meld.runner.transform.restraints.util import _delete_from_always_active
from meld.system import density, mapping, options, param_sampling, restraints

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Force group index assigned to the MeldForce (passed to ``setForceGroup``
# in ``MeldRestraintTransformer.add_interactions``).
FORCE_GROUP = 1


class MeldRestraintTransformer(transform.TransformerBase):
    """
    Transformer to handle MELD restraints

    Builds a single ``MeldForce`` that holds every selectable restraint
    (always-active and selectively-active) together with any density-map
    grid potentials, adds it to the OpenMM system, and pushes updated
    parameters into the simulation context on each call to ``update``.
    """

    force: MeldForce

    # Number of spline-coefficient columns handed to the plugin for each
    # profile restraint type.
    _N_DIST_PROFILE_COLUMNS = 4
    _N_TORS_PROFILE_COLUMNS = 16

    def __init__(
        self,
        param_manager: param_sampling.ParameterManager,
        mapper: mapping.PeakMapManager,
        density_manager: density.DensityManager,
        builder_info: dict,
        options: options.RunOptions,
        always_active_restraints: List[restraints.Restraint],
        selectively_active_restraints: List[restraints.SelectivelyActiveCollection],
    ) -> None:
        """
        Claim the restraints that this transformer is responsible for.

        Args:
            param_manager: used to resolve sampled parameters (e.g. the
                number of active restraints) against the current state
            mapper: used to resolve peak mappings to atom indices
            density_manager: provides the density maps to add as grid
                potentials
            builder_info: dictionary read for ``"num_alignments"`` and
                ``"alignment_scale_factor"`` when the force is built
            options: run options; not referenced by this transformer
            always_active_restraints: the ``SelectableRestraint`` instances
                in this list are claimed and handed to
                ``_delete_from_always_active`` together with the list
            selectively_active_restraints: every collection in this list is
                claimed; the list is emptied in place
        """
        self.param_manager = param_manager
        self.mapper = mapper
        self.density_manager = density_manager
        self.builder_info = builder_info

        # Track indices of restraints, groups, and collections so that we can
        # update them.
        self.tracker = RestraintTracker(param_manager, mapper)

        self.always_on = [
            rest
            for rest in always_active_restraints
            if isinstance(rest, restraints.SelectableRestraint)
        ]
        _delete_from_always_active(self.always_on, always_active_restraints)

        # Gather all of the selectively active restraints. Every entry is
        # claimed, so the caller's list is simply emptied in place rather
        # than removing the entries one at a time (which is quadratic).
        self.selective_on = list(selectively_active_restraints)
        selectively_active_restraints.clear()

        self.active = bool(self.always_on or self.selective_on)

    def add_interactions(
        self, state: interfaces.IState, system: mm.System, topology: app.Topology
    ) -> mm.System:
        """
        Build the ``MeldForce`` and add it to ``system``.

        Nothing is done when the transformer has no restraints to handle.
        All restraints are added with ``alpha = 0`` and ``timestep = 0``.

        Args:
            state: current state, used to resolve peak mappings and sampled
                parameters
            system: OpenMM system that receives the force
            topology: not referenced by this transformer

        Returns:
            the same ``system`` object that was passed in
        """
        if not self.active:
            return system

        n_alignments = self.builder_info.get("num_alignments", 0)
        rdc_scale_factor = self.builder_info.get("alignment_scale_factor", 1e-4)
        meld_force = MeldForce(n_alignments, rdc_scale_factor)

        # If we have any density maps, add them now.
        for index, dens in enumerate(self.density_manager.densities):
            self.tracker.add_density(index, dens, 0)
            blurred = _compute_density_potential(
                dens.density_data, dens.blur_scaler(0)
            )
            # TODO What to do outside of grid?
            # TODO fix numpy typemaps
            meld_force.addGridPotential(
                blurred,
                dens.origin[0],
                dens.origin[1],
                dens.origin[2],
                dens.voxel_size[0],
                dens.voxel_size[1],
                dens.voxel_size[2],
                dens.nx,
                dens.ny,
                dens.nz,
                index,
            )

        # Add all of the always-on restraints.
        if self.always_on:
            group_list = []
            for rest in self.always_on:
                rest_index = self._add_meld_restraint(rest, meld_force, 0, 0, state)
                # Each restraint goes in its own group.
                # This group does not depend on parameter sampling,
                # so we will never need to update it.
                group_list.append(meld_force.addGroup([rest_index], 1))

            # All of the always-on restraints go in a single collection.
            # This collection does not depend on parameter sampling,
            # so we will never need to update it.
            meld_force.addCollection(group_list, len(group_list))

        # Add the selectively active restraints.
        for coll in self.selective_on:
            group_indices = []
            for group in coll.groups:
                restraint_indices = [
                    self._add_meld_restraint(rest, meld_force, 0, 0, state)
                    for rest in group.restraints
                ]

                # Create the group in the meldplugin.
                group_num_active = self._handle_num_active(group.num_active, state)
                group_index = meld_force.addGroup(restraint_indices, group_num_active)
                group_indices.append(group_index)

                # If the group depends on parameter sampling, add it to the
                # tracker so that it can be updated.
                if isinstance(group.num_active, param_sampling.Parameter):
                    self.tracker.groups_with_dep.append((group, group_index))

            # Create the collection in the meldplugin.
            coll_num_active = self._handle_num_active(coll.num_active, state)
            coll_index = meld_force.addCollection(group_indices, coll_num_active)

            # If the collection depends on parameter sampling, add it to the
            # tracker so that it can be updated.
            if isinstance(coll.num_active, param_sampling.Parameter):
                self.tracker.collections_with_dep.append((coll, coll_index))

        meld_force.setForceGroup(FORCE_GROUP)
        system.addForce(meld_force)
        self.force = meld_force
        return system

    def update(
        self,
        state: interfaces.IState,
        simulation: app.Simulation,
        alpha: float,
        timestep: int,
    ) -> None:
        """
        Refresh the force parameters and push them into the context.

        Does nothing when the transformer has no restraints to handle.

        Args:
            state: current state, used to resolve mappings and parameters
            simulation: simulation whose context receives the new parameters
            alpha: current value of alpha
            timestep: current timestep
        """
        if not self.active:
            return

        self._update_densities(alpha)
        self._update_restraints(alpha, timestep, state)
        self._update_groups_collections(state)
        self.force.updateParametersInContext(simulation.context)

    def _update_densities(self, alpha):
        """Re-upload every grid potential the tracker reports as stale."""
        for index, dens in self.tracker.density_to_update(alpha):
            # NOTE(review): the blur value is computed but never used; the
            # map is selected with the raw ``alpha`` here, whereas
            # ``add_interactions`` passes ``blur_scaler(0)``. Behaviour is
            # left unchanged -- confirm which argument is intended.
            blur = dens.blur_scaler(alpha)  # noqa: F841
            blurred = _compute_density_potential(dens.density_data, alpha)
            self.force.modifyGridPotential(
                index,
                blurred,
                dens.origin[0],
                dens.origin[1],
                dens.origin[2],
                dens.voxel_size[0],
                dens.voxel_size[1],
                dens.voxel_size[2],
                dens.nx,
                dens.ny,
                dens.nz,
            )

    def _update_groups_collections(
        self,
        state: interfaces.IState,
    ) -> None:
        """Re-resolve ``num_active`` for parameter-dependent groups/collections."""
        for coll, index in self.tracker.collections_with_dep:
            num_active = self._handle_num_active(coll.num_active, state)
            self.force.modifyCollectionNumActive(index, num_active)

        for group, index in self.tracker.groups_with_dep:
            num_active = self._handle_num_active(group.num_active, state)
            self.force.modifyGroupNumActive(index, num_active)

    def _update_restraints(
        self, alpha: float, timestep: int, state: interfaces.IState
    ) -> None:
        """
        Push new parameters for every restraint the tracker flags.

        Args:
            alpha: current value of alpha
            timestep: current timestep
            state: current state, used to resolve peak mappings

        Raises:
            RuntimeError: if the tracker reports an unknown category
        """
        # Get the list of restraints to update.
        self.tracker.update(alpha, timestep, state)
        to_update = self.tracker.get_and_reset_need_update()

        for category, index in to_update:
            if category == "rdc":
                rdc_rest = self.tracker.rdc_restraints[index]
                scale = rdc_rest.scaler(alpha) * rdc_rest.ramp(timestep)
                atom_1, atom_2 = self._handle_mapping(
                    [rdc_rest.atom_index_1, rdc_rest.atom_index_2], state
                )
                self.force.modifyRDCRestraint(
                    index,
                    atom_1,
                    atom_2,
                    rdc_rest.alignment_index,
                    rdc_rest.kappa,
                    rdc_rest.d_obs,
                    rdc_rest.tolerance,
                    rdc_rest.quadratic_cut,
                    rdc_rest.force_const * scale,
                )
            elif category == "distance":
                dist_rest = self.tracker.distance_restraints[index]
                scale = dist_rest.scaler(alpha) * dist_rest.ramp(timestep)
                atom_1, atom_2 = self._handle_mapping(
                    [dist_rest.atom_index_1, dist_rest.atom_index_2], state
                )
                self.force.modifyDistanceRestraint(
                    index,
                    atom_1,
                    atom_2,
                    dist_rest.r1(alpha),
                    dist_rest.r2(alpha),
                    dist_rest.r3(alpha),
                    dist_rest.r4(alpha),
                    dist_rest.k * scale,
                )
            elif category == "hyperbolic_distance":
                hyper_rest = self.tracker.hyperbolic_distance_restraints[index]
                scale = hyper_rest.scaler(alpha) * hyper_rest.ramp(timestep)
                self.force.modifyHyperbolicDistanceRestraint(
                    index,
                    hyper_rest.atom_index_1,
                    hyper_rest.atom_index_2,
                    hyper_rest.r1,
                    hyper_rest.r2,
                    hyper_rest.r3,
                    hyper_rest.r4,
                    hyper_rest.k * scale,
                    hyper_rest.asymptote * scale,
                )
            elif category == "torsion":
                tors_rest = self.tracker.torsion_restraints[index]
                scale = tors_rest.scaler(alpha) * tors_rest.ramp(timestep)
                self.force.modifyTorsionRestraint(
                    index,
                    tors_rest.atom_index_1,
                    tors_rest.atom_index_2,
                    tors_rest.atom_index_3,
                    tors_rest.atom_index_4,
                    tors_rest.phi,
                    tors_rest.delta_phi,
                    tors_rest.k * scale,
                )
            elif category == "dist_profile":
                dist_prof_rest = self.tracker.dist_prof_restraints[index]
                scale = dist_prof_rest.scaler(alpha) * dist_prof_rest.ramp(timestep)
                self.force.modifyDistProfileRestraint(
                    index,
                    dist_prof_rest.atom_index_1,
                    dist_prof_rest.atom_index_2,
                    dist_prof_rest.r_min,
                    dist_prof_rest.r_max,
                    dist_prof_rest.n_bins,
                    *self._spline_columns(
                        dist_prof_rest.spline_params, self._N_DIST_PROFILE_COLUMNS
                    ),
                    dist_prof_rest.scale_factor * scale,
                )
            elif category == "tors_profile":
                tors_prof_rest = self.tracker.torsion_profile_restraints[index]
                scale = tors_prof_rest.scaler(alpha) * tors_prof_rest.ramp(timestep)
                self.force.modifyTorsProfileRestraint(
                    index,
                    tors_prof_rest.atom_index_1,
                    tors_prof_rest.atom_index_2,
                    tors_prof_rest.atom_index_3,
                    tors_prof_rest.atom_index_4,
                    tors_prof_rest.atom_index_5,
                    tors_prof_rest.atom_index_6,
                    tors_prof_rest.atom_index_7,
                    tors_prof_rest.atom_index_8,
                    tors_prof_rest.n_bins,
                    *self._spline_columns(
                        tors_prof_rest.spline_params, self._N_TORS_PROFILE_COLUMNS
                    ),
                    tors_prof_rest.scale_factor * scale,
                )
            elif category == "gmm":
                gmm_rest = self.tracker.gmm_restraints[index]
                scale = gmm_rest.scaler(alpha) * gmm_rest.ramp(timestep)
                nd = gmm_rest.n_distances
                nc = gmm_rest.n_components
                w = gmm_rest.weights
                m = list(gmm_rest.means.flatten())
                d, o = _setup_precisions(gmm_rest.precisions, nd, nc)
                self.force.modifyGMMRestraint(
                    index, nd, nc, scale, gmm_rest.atoms, w, m, d, o
                )
            elif category == "density":
                density_rest = self.tracker.density_restraints[index]
                self.force.modifyGridPotentialRestraint(
                    index,
                    density_rest.atom_index,
                    _compute_density_potential(density_rest.mu, alpha),
                    *self._grid_axes(density_rest),
                )
            else:
                raise RuntimeError(f"Unknown restraint category {category}")

    def _handle_num_active(self, value, state):
        """
        Resolve a ``num_active`` setting to a concrete value.

        A sampled ``Parameter`` is looked up in ``state.parameters`` and
        converted to ``int``; anything else is returned unchanged.
        """
        if isinstance(value, param_sampling.Parameter):
            return int(self.param_manager.extract_value(value, state.parameters))
        return value

    def _handle_mapping(
        self, values: List[Union[int, mapping.PeakMapping]], state: interfaces.IState
    ) -> List[int]:
        """
        Resolve a list of atom specifications to atom indices.

        ``PeakMapping`` entries are looked up in ``state.mappings``; plain
        entries are passed through.

        Returns:
            the resolved indices, or a list of ``-1`` of the same length if
            any resolved index is ``-1``
        """
        indices: List[int] = [
            self.mapper.extract_value(value, state.mappings)
            if isinstance(value, mapping.PeakMapping)
            else value
            for value in values
        ]

        # If any of the indices is un-mapped, we set them all to -1.
        if any(x == -1 for x in indices):
            indices = [-1 for _ in values]

        return indices

    @staticmethod
    def _spline_columns(spline_params, n_columns: int) -> list:
        """Return ``spline_params[:, c]`` for each ``c`` in ``range(n_columns)``."""
        return [spline_params[:, column] for column in range(n_columns)]

    @staticmethod
    def _grid_axes(rest) -> list:
        """
        Return three evenly spaced coordinate arrays for a density restraint.

        For each of the three axes the array starts at ``map_origin``, has
        ``map_dimension`` points and a spacing of ``map_gridLength``.
        """
        return [
            np.linspace(
                rest.map_origin[axis],
                rest.map_origin[axis]
                + (rest.map_dimension[axis] - 1) * rest.map_gridLength[axis],
                int(rest.map_dimension[axis]),
            )
            for axis in range(3)
        ]

    def _add_meld_restraint(
        self,
        rest,
        meld_force: MeldForce,
        alpha: float,
        timestep: int,
        state: interfaces.IState,
    ) -> int:
        """
        Add one restraint to ``meld_force`` and register it with the tracker.

        Args:
            rest: the restraint to add
            meld_force: force that receives the restraint
            alpha: value of alpha used to evaluate the restraint parameters
            timestep: timestep used to evaluate the restraint ramp
            state: current state, used to resolve peak mappings

        Returns:
            the index returned by the plugin for the new restraint

        Raises:
            RuntimeError: if the restraint type is not supported
        """
        scale = rest.scaler(alpha) * rest.ramp(timestep)

        # NOTE: the order of the isinstance checks is preserved from the
        # original implementation.
        if isinstance(rest, restraints.RdcRestraint):
            i, j = self._handle_mapping([rest.atom_index_1, rest.atom_index_2], state)
            rest_index = meld_force.addRDCRestraint(
                i,
                j,
                rest.alignment_index,
                rest.kappa,
                rest.d_obs,
                rest.tolerance,
                rest.quadratic_cut,
                rest.force_const * scale,
            )
            self.tracker.add_rdc_restraint(rest, alpha, timestep, state)

        elif isinstance(rest, restraints.DistanceRestraint):
            i, j = self._handle_mapping([rest.atom_index_1, rest.atom_index_2], state)
            rest_index = meld_force.addDistanceRestraint(
                i,
                j,
                rest.r1(alpha),
                rest.r2(alpha),
                rest.r3(alpha),
                rest.r4(alpha),
                rest.k * scale,
            )
            self.tracker.add_distance_restraint(rest, alpha, timestep, state)

        elif isinstance(rest, restraints.HyperbolicDistanceRestraint):
            rest_index = meld_force.addHyperbolicDistanceRestraint(
                rest.atom_index_1,
                rest.atom_index_2,
                rest.r1,
                rest.r2,
                rest.r3,
                rest.r4,
                rest.k * scale,
                rest.asymptote * scale,
            )
            self.tracker.add_hyperbolic_distance_restraint(rest, alpha, timestep, state)

        elif isinstance(rest, restraints.TorsionRestraint):
            rest_index = meld_force.addTorsionRestraint(
                rest.atom_index_1,
                rest.atom_index_2,
                rest.atom_index_3,
                rest.atom_index_4,
                rest.phi,
                rest.delta_phi,
                rest.k * scale,
            )
            self.tracker.add_torsion_restraint(rest, alpha, timestep, state)

        elif isinstance(rest, restraints.DistProfileRestraint):
            rest_index = meld_force.addDistProfileRestraint(
                rest.atom_index_1,
                rest.atom_index_2,
                rest.r_min,
                rest.r_max,
                rest.n_bins,
                *self._spline_columns(
                    rest.spline_params, self._N_DIST_PROFILE_COLUMNS
                ),
                rest.scale_factor * scale,
            )
            self.tracker.add_distance_profile_restraint(rest, alpha, timestep, state)

        elif isinstance(rest, restraints.TorsProfileRestraint):
            rest_index = meld_force.addTorsProfileRestraint(
                rest.atom_index_1,
                rest.atom_index_2,
                rest.atom_index_3,
                rest.atom_index_4,
                rest.atom_index_5,
                rest.atom_index_6,
                rest.atom_index_7,
                rest.atom_index_8,
                rest.n_bins,
                *self._spline_columns(
                    rest.spline_params, self._N_TORS_PROFILE_COLUMNS
                ),
                rest.scale_factor * scale,
            )
            self.tracker.add_torsion_profile_restraint(rest, alpha, timestep, state)

        elif isinstance(rest, restraints.GMMDistanceRestraint):
            nd = rest.n_distances
            nc = rest.n_components
            w = rest.weights
            m = list(rest.means.flatten())
            d, o = _setup_precisions(rest.precisions, nd, nc)
            rest_index = meld_force.addGMMRestraint(
                nd, nc, scale, rest.atoms, w, m, d, o
            )
            self.tracker.add_gmm_distance_restraint(rest, alpha, timestep, state)

        elif isinstance(rest, restraints.DensityRestraint):
            rest_index = meld_force.addGridPotentialRestraint(
                rest.atom_index,
                _compute_density_potential(rest.mu, alpha),
                *self._grid_axes(rest),
            )
            self.tracker.add_density_restraint(rest, alpha, timestep, state)

        else:
            raise RuntimeError(f"Do not know how to handle restraint {rest}")

        return rest_index
def _setup_precisions(
    precisions: np.ndarray, n_distances: int, n_conditions: int
) -> Tuple[List[float], List[float]]:
    """
    Flatten per-component precision matrices into two lists.

    ``precisions`` is indexed as ``precisions[component, row, column]``.

    NOTE: the values are passed through exactly as stored; no clamping of
    zero or very small precisions is performed here.

    Args:
        precisions: array holding one square matrix per component
        n_distances: number of rows/columns read from each matrix
        n_conditions: number of components (leading index) to read

    Returns:
        ``(diags, off_diags)`` where ``diags`` holds the diagonal entries,
        component by component, and ``off_diags`` holds the strictly
        upper-triangular entries in row-major order, component by component
    """
    components = range(n_conditions)

    diags = [
        precisions[comp, row, row]
        for comp in components
        for row in range(n_distances)
    ]
    off_diags = [
        precisions[comp, row, col]
        for comp in components
        for row in range(n_distances)
        for col in range(row + 1, n_distances)
    ]
    return diags, off_diags


def _compute_density_potential(mu, alpha):
    """
    Pick one entry of ``mu`` along its first axis and return it as float64.

    The entry index is ``int(alpha * (mu.shape[0] - 1))``, i.e. ``alpha``
    scaled by the last valid index and truncated by ``int``.
    """
    last_index = mu.shape[0] - 1
    selected = mu[int(alpha * last_index)]
    return selected.astype(np.float64)