# Source code for meld.system.restraints

#
# Copyright 2015 by Justin MacCallum, Alberto Perez, Ken Dill
# All rights reserved
#

"""
Implements all of the restraints available in MELD.

This file implements restraints and related classes for MELD.
Restraints are the primary way that "extra" forces are added
into MELD simulations.

There are several important concepts: `restraints`, `groups`,
`collections`, `scalers`, `ramps`, `positioners`, and the
`restraint manager`.

Restraints
----------
Restraints represent "extra" forces that can be added into a
MELD simulation. There are many different types of restraints.
Each restraint object has a variety of parameters that describe
the strength of the force, the atoms involved, and so on.

There are two main types of restraints, :class:`SelectableRestraint`
and :class:`NonSelectableRestraint`, which have substantially different
behavior.

:class:`NonSelectableRestraint` are "always on". They may be scaled by
scalers and ramps, but the force from each :class:`NonSelectableRestraint`
is independent of other restraints.

:class:`SelectableRestraint` have forces and energies that depend on
other :class:`SelectableRestraint`. They may be combined into
:class:`RestraintGroup` objects, which allows for the ``num_active`` lowest
energy restraints to be active at each timestep. The remaining
restraints are inactive and do not contribute their forces or
energy to the system for that timestep. This selectable nature
allows for the implementation of very flexible restraint
strategies useful for a variety of problems in structural
biology [1]_, [2]_.

The standard way to create a restraint is to call
:meth:`RestraintManager.create_restraint` with the appropriate
restraint key:

>>> r = system.restraints.create_restraint(rest_key, params...)


Groups
------
:class:`SelectableRestraint` must be part of a :class:`RestraintGroup`. Each
timestep, the restraints are sorted by energy and the ``num_active``
restraints with the lowest energy are activated for the timestep,
while the rest are ignored. It is not possible to add a
:class:`NonSelectableRestraint` to a :class:`RestraintGroup`.

:class:`RestraintGroup` objects are created by:

>>> g = system.restraints.create_restraint_group(list_of_restraints, num_active)
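
The selection rule described above can be sketched in a few lines. This is
only an illustration of "keep the ``num_active`` lowest-energy restraints",
not the code MELD runs; the function names ``select_active`` and
``group_energy`` and the energy values are hypothetical.

```python
import numpy as np


def select_active(energies, num_active):
    # Boolean mask that is True for the num_active lowest-energy restraints.
    energies = np.asarray(energies, dtype=float)
    if not 0 <= num_active <= energies.size:
        raise ValueError("num_active must be between 0 and len(energies)")
    # Ascending sort; a stable sort keeps input order for ties.
    order = np.argsort(energies, kind="stable")
    mask = np.zeros(energies.size, dtype=bool)
    mask[order[:num_active]] = True
    return mask


def group_energy(energies, num_active):
    # Only active restraints contribute; inactive ones are ignored.
    energies = np.asarray(energies, dtype=float)
    return float(energies[select_active(energies, num_active)].sum())


energies = [4.0, 0.5, 2.0, 9.0]  # hypothetical per-restraint energies
mask = select_active(energies, num_active=2)
total = group_energy(energies, num_active=2)
```

Here the second and third restraints are active, and the group energy is the
sum of their energies only.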

Collections
-----------
There are two types of collection: always on, and selectively active.

Restraints that will always be active are added to a single always on
collection. The standard ways to do this are:

>>> system.restraints.add_as_always_active(restraint)
>>> system.restraints.add_as_always_active_list(list_of_restraints)

Restraints or groups of restraints that will be selected are added
to selectively active collections. A mix of bare
:class:`SelectableRestraint` or :class:`RestraintGroup` objects may be added.
When bare restraints are added, they are automatically placed into
a group containing only that restraint, with ``num_active=1``.
The standard way to add a selectively active collection is:

>>> system.restraints.add_selectively_active_collection(
        list_of_restraints_and_groups, num_active)

Scalers
-------
Each replica in a MELD simulation has a value ``alpha`` that
runs from ``0.0`` to ``1.0``, inclusive. The lowest replica always
has ``alpha=0``, while the highest has ``alpha=1``. The strength
of restraints can be scaled by specifying a Scaler that maps
alpha into a scaling of the force constant.

Scalers are created and added to a restraint by:

>>> scaler = system.restraints.create_scaler(scaler_key, params...)
>>> r = system.restraints.create_restraint(rest_key, scaler=scaler, params...)
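
As a rough sketch of what "mapping alpha into a scaling of the force
constant" can look like, the function below interpolates linearly between two
strengths. The name ``linear_scale``, its parameters, and the choice of full
strength at ``alpha=0`` are assumptions made for illustration; they are not
taken from the scalers MELD provides.

```python
def linear_scale(
    alpha,
    alpha_min=0.0,
    alpha_max=1.0,
    strength_at_alpha_min=1.0,
    strength_at_alpha_max=0.0,
):
    # Map a replica's alpha in [0, 1] to a multiplier for the force constant.
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if not 0.0 <= alpha_min < alpha_max <= 1.0:
        raise ValueError("need 0 <= alpha_min < alpha_max <= 1")
    if alpha <= alpha_min:
        return strength_at_alpha_min
    if alpha >= alpha_max:
        return strength_at_alpha_max
    frac = (alpha - alpha_min) / (alpha_max - alpha_min)
    return strength_at_alpha_min + frac * (
        strength_at_alpha_max - strength_at_alpha_min
    )


k = 250.0  # hypothetical unscaled force constant
scaled_k = [k * linear_scale(a) for a in (0.0, 0.5, 1.0)]
```

With these assumed defaults, the lowest replica sees the full force constant
and the highest replica sees none of it.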

Ramps
-----
Ramps are similar to Scalers, except that they map the step of
the simulation into a scaling of the force constant. They are
typically used to slowly turn on forces at the start of a simulation.

Ramps are created and added to a restraint by:

>>> ramp = system.restraints.create_scaler(ramp_key, params...)
>>> r = system.restraints.create_restraint(rest_key, ramp=ramp, params...)

Note:
   Despite the name, ramps are created with the :meth:`create_scaler` method.
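
A ramp can be pictured as the same kind of interpolation, driven by the
simulation step rather than by alpha. The sketch below is a hypothetical
linear ramp; ``linear_ramp`` and its parameter names are illustrative
assumptions, not the ramp types or keys MELD defines.

```python
def linear_ramp(step, start_step, end_step, start_weight=0.0, end_weight=1.0):
    # Map the current simulation step to a multiplier for the force constant.
    if end_step <= start_step:
        raise ValueError("end_step must be greater than start_step")
    if step <= start_step:
        return start_weight
    if step >= end_step:
        return end_weight
    frac = (step - start_step) / (end_step - start_step)
    return start_weight + frac * (end_weight - start_weight)


# Turn a force on gradually over the first 1000 steps, then hold it.
weights = [linear_ramp(s, 0, 1000) for s in (0, 250, 1000, 5000)]
```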

Positioners
-----------
Positioners are used to control the position or distance in a restraint. They
function similarly to Scalers, but rather than returning a value in ``[0, 1]``, they
return a value from a defined range.

Positioners are created and added to a restraint by:

>>> positioner = system.restraints.create_scaler(pos_key, params...)
>>> r = system.restraints.create_restraint(
        rest_key, param=positioner, params...)

Note:
   Despite the name, positioners are created with the ``create_scaler`` method.
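
The restraint classes in this module evaluate positioners by calling them
with alpha (for example ``self.r1(alpha)``). A minimal callable with that
shape might look like the sketch below. The class name
``LinearPositionSketch`` and its parameters are hypothetical, and plain
floats in nanometers stand in for unit-bearing quantities.

```python
class LinearPositionSketch:
    # Hypothetical positioner: maps alpha linearly onto [pos_start, pos_end].

    def __init__(self, pos_start, pos_end, alpha_min=0.0, alpha_max=1.0):
        if not 0.0 <= alpha_min < alpha_max <= 1.0:
            raise ValueError("need 0 <= alpha_min < alpha_max <= 1")
        self.pos_start = pos_start
        self.pos_end = pos_end
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max

    def __call__(self, alpha):
        # Clamp alpha to the active window, then interpolate the position.
        a = min(max(alpha, self.alpha_min), self.alpha_max)
        frac = (a - self.alpha_min) / (self.alpha_max - self.alpha_min)
        return self.pos_start + frac * (self.pos_end - self.pos_start)


r3 = LinearPositionSketch(0.5, 1.0)  # upper flat-bottom edge, in nm
values = [r3(a) for a in (0.0, 0.5, 1.0)]
```

Unlike a scaler, the returned value is a distance drawn from the chosen
range rather than a multiplier.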

Restraint Manager
-----------------
The :class:`System` object maintains a :class:`RestraintManager` object, which is the
primary means for interacting with restraints. Generally, restraints, groups,
scalers, etc. are created through the :class:`RestraintManager`, rather than
by direct construction.

References
----------
.. [1] J.L. MacCallum, A. Perez, and K.A. Dill, Determining protein structures
       by combining semireliable data with atomistic physical models by Bayesian
       inference, PNAS, 2015, 112(22), pp. 6985--6990.
.. [2] A. Perez, J.L. MacCallum, and K.A. Dill, Accelerating molecular simulations
       of proteins using Bayesian inference on weak information, PNAS, 2015,
       112(38), pp. 11846--11851.
"""


from __future__ import annotations

from typing import Dict, List, NamedTuple, Optional, Union

import numpy as np  # type: ignore
from openmm import unit as u  # type: ignore

from meld import interfaces
from meld.system import indexing, mapping, param_sampling
from meld.system.density import DensityMap
from meld.system.scalers import (
    BlurScaler,
    ConstantPositioner,
    ConstantRamp,
    ConstantScaler,
    Positioner,
    RestraintScaler,
    ScalerRegistry,
    TimeRamp,
)
from meld.util import strip_unit

STRENGTH_AT_ALPHA_MAX = 1e-3  # default strength of restraints at alpha=1.0


class _RestraintRegistry(type):
    """
    Metaclass that maintains a registry of restraint types.

    All classes that descend from Restraint inherit _RestraintRegistry as their
    metaclass. _RestraintRegistry will automatically maintain a map between
    the class attribute '_restraint_key_' and all restraint types.

    The function get_constructor_for_key is used to get the class for the
    corresponding key.
    """

    _restraint_registry: Dict[str, type] = {}

    def __init__(cls, name, bases, attrs):
        if name in ["Restraint", "SelectableRestraint", "NonSelectableRestraint"]:
            pass  # we don't register the base classes
        else:
            try:
                key = attrs["_restraint_key_"]
            except KeyError:
                raise RuntimeError(
                    f"Restraint type {name} subclasses Restraint, "
                    "but does not set _restraint_key_"
                )
            if key in _RestraintRegistry._restraint_registry:
                raise RuntimeError(
                    "Trying to register two different classes "
                    f"with _restraint_key_ = {key}."
                )
            _RestraintRegistry._restraint_registry[key] = cls

    @classmethod
    def get_constructor_for_key(self, key):
        """Get the constructor for the restraint type matching key."""
        try:
            return _RestraintRegistry._restraint_registry[key]
        except KeyError:
            raise RuntimeError(f'Unknown restraint type "{key}".')


class Restraint(metaclass=_RestraintRegistry):
    """Abstract class for all restraints."""

    pass


class SelectableRestraint(Restraint):
    """Abstract class for selectable restraints."""

    pass


class NonSelectableRestraint(Restraint):
    """Abstract class for non-selectable restraints."""

    pass


class DistanceRestraint(SelectableRestraint):
    """
    Restrain the distance between two groups
    """

    _restraint_key_ = "distance"
    atom_index_1: Union[int, mapping.PeakMapping]
    atom_index_2: Union[int, mapping.PeakMapping]

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom1: Union[indexing.AtomIndex, mapping.PeakMapping],
        atom2: Union[indexing.AtomIndex, mapping.PeakMapping],
        r1: Union[u.Quantity, Positioner],
        r2: Union[u.Quantity, Positioner],
        r3: Union[u.Quantity, Positioner],
        r4: Union[u.Quantity, Positioner],
        k: u.Quantity,
    ):
        """
        Initialize a DistanceRestraint

        The energy is zero between ``r2`` and ``r3``. It increases
        quadratically between ``r1`` and ``r2`` and between ``r3`` and ``r4``.
        The energy increases linearly below ``r1`` and above ``r4``.

        Args:
            system: system object that restraint belongs to
            scaler: a Scaler to vary the force constant with alpha.
                If ``None``, then a constant 1.0 scaler will be used.
            ramp: a time ramp to turn restraints on at the beginning of
                the simulation
            atom1: index of atom 1
            atom2: index of atom 2
            r1: distance
            r2: distance
            r3: distance
            r4: distance
            k: force constant
        """
        if isinstance(atom1, mapping.PeakMapping):
            self.atom_index_1 = atom1
        else:
            assert isinstance(atom1, indexing.AtomIndex)
            self.atom_index_1 = int(atom1)

        if isinstance(atom2, mapping.PeakMapping):
            self.atom_index_2 = atom2
        else:
            assert isinstance(atom2, indexing.AtomIndex)
            self.atom_index_2 = int(atom2)

        # Plain quantities are wrapped so that r1..r4 are always callable
        # with alpha.
        if isinstance(r1, Positioner):
            self.r1 = r1
        else:
            self.r1 = ConstantPositioner(r1)

        if isinstance(r2, Positioner):
            self.r2 = r2
        else:
            self.r2 = ConstantPositioner(r2)

        if isinstance(r3, Positioner):
            self.r3 = r3
        else:
            self.r3 = ConstantPositioner(r3)

        if isinstance(r4, Positioner):
            self.r4 = r4
        else:
            self.r4 = ConstantPositioner(r4)

        self.k = strip_unit(k, u.kilojoule_per_mole / u.nanometer**2)
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self._check(system)

    def _check(self, system):
        for alpha in [0, 0.2, 0.4, 0.6, 0.8, 1.0]:
            if (
                self.r1(alpha) < 0
                or self.r2(alpha) < 0
                or self.r3(alpha) < 0
                or self.r4(alpha) < 0
            ):
                raise RuntimeError(
                    "r1 to r4 must be > 0. r1={} r2={} r3={} r4={}.".format(
                        self.r1(alpha), self.r2(alpha), self.r3(alpha), self.r4(alpha)
                    )
                )
            if self.r2(alpha) < self.r1(alpha):
                raise RuntimeError(
                    f"r2 must be >= r1. r1={self.r1(alpha)} r2={self.r2(alpha)}."
                )
            if self.r3(alpha) < self.r2(alpha):
                raise RuntimeError(
                    f"r3 must be >= r2. r2={self.r2(alpha)} r3={self.r3(alpha)}."
                )
            if self.r4(alpha) < self.r3(alpha):
                raise RuntimeError(
                    f"r4 must be >= r3. r3={self.r3(alpha)} r4={self.r4(alpha)}."
                )
        if self.k < 0:
            raise RuntimeError(f"k must be >= 0. k={self.k}.")


class GMMDistanceRestraint(SelectableRestraint):
    """
    Restrain multiple distances using Gaussian mixture models

    The energy has the form:

    E = w1 N1 exp(-0.5 (r-u1)^T P1 (r-u1)) + w2 N2 exp(-0.5 (r-u2)^T P2 (r-u2)) + ...

    where:
        w1, w2, ... are the weights
        N1, N2, ... are automatically calculated normalization factors
        r is the vector of distances for the atom pairs
        u1, u2, ... are the mean vectors for each component
        P1, P2, ... are the precision (inverse covariance) matrices
        for each component
    """

    _restraint_key_ = "gmm"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        n_distances: int,
        n_components: int,
        atoms: List[indexing.AtomIndex],
        weights: np.ndarray,
        means: np.ndarray,
        precisions: np.ndarray,
    ):
        """
        Initialize a GMMDistanceRestraint

        Args:
            system: system object that restraint belongs to
            scaler: A Scaler to vary the force constant with alpha.
                If ``None``, then a constant 1.0 scaler will be used.
            ramp: a time ramp to turn restraints on at the beginning of
                the simulation
            n_distances: number of distances involved in GMM; max 32
            n_components: number of mixture components; max 32
            atoms: a list of length ``2 * n_distances``
            weights: the weights for the mixture components, shape(n_components)
            means: the means of each mixture component,
                shape(n_components, n_distances)
            precisions: the precision (i.e. inverse covariance) of each
                mixture component,
                shape(n_components, n_distances, n_distances)
        """
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self.n_distances = n_distances
        self.n_components = n_components
        self.weights = weights
        self.means = means
        self.precisions = precisions
        self.atoms = None
        self._setup_atoms(atoms, system)
        self._check(system)

    @classmethod
    def from_params(
        cls,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        params: GMMParams,
    ) -> GMMDistanceRestraint:
        """
        Create a GMMDistanceRestraint from a GMMParams object.

        Args:
            system: system object that restraint belongs to
            scaler: A Scaler to vary the force constant with alpha.
                If ``None``, then a constant 1.0 scaler will be used.
            ramp: a time ramp to turn restraints on at the beginning of
                the simulation
            params: object to build restraint from
        """
        return cls(
            system,
            scaler,
            ramp,
            params.n_distances,
            params.n_components,
            params.atoms,
            params.weights,
            params.means,
            params.precisions,
        )

    def _setup_atoms(self, pair_list, system):
        self.atoms = []
        for index in pair_list:
            assert isinstance(index, indexing.AtomIndex)
            self.atoms.append(int(index))

    def _check(self, system):
        if len(self.atoms) != 2 * self.n_distances:
            raise RuntimeError("len(atoms) must be 2*n_distances")
        if self.weights.shape[0] != self.n_components:
            raise RuntimeError("weights must have shape (n_components,)")
        if self.means.shape != (self.n_components, self.n_distances):
            raise RuntimeError("means must have shape (n_components, n_distances)")
        if self.precisions.shape != (
            self.n_components,
            self.n_distances,
            self.n_distances,
        ):
            raise RuntimeError(
                "precisions must have shape (n_components, n_distances, n_distances)"
            )
        for i in range(self.n_components):
            if not np.allclose(self.precisions[i, :, :], self.precisions[i, :, :].T):
                raise RuntimeError("precision matrix must be symmetric")
        for i in range(self.n_components):
            # Perform a Cholesky decomposition on each precision matrix.
            # This will fail if the matrix is not positive definite.
            try:
                np.linalg.cholesky(self.precisions[i, :, :])
            except np.linalg.LinAlgError:
                raise RuntimeError("precision matrices must be positive definite")


class HyperbolicDistanceRestraint(SelectableRestraint):
    """
    Hyperbolic distance restraint between two atoms
    """

    _restraint_key_ = "hyperbolic"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom1: indexing.AtomIndex,
        atom2: indexing.AtomIndex,
        r1: u.Quantity,
        r2: u.Quantity,
        r3: u.Quantity,
        r4: u.Quantity,
        k: u.Quantity,
        asymptote: u.Quantity,
    ):
        """
        Initialize a HyperbolicDistanceRestraint

        There are five regions::

            I:    r < r1
            II:  r1 < r < r2
            III: r2 < r < r3
            IV:  r3 < r < r4
            V:   r4 < r

        The energy is linear in region I, quadratic in II and IV, and zero
        in III. The energy is hyperbolic in region V, with an asymptotic
        value set by the parameter asymptote. The energy will be 1/3 of the
        asymptotic value at r=r4. The distance between r3 and r4 controls
        the steepness of the potential.

        Args:
            system: the system this restraint belongs to
            scaler: scale the force constant with alpha
            ramp: ramp up restraint over time
            atom1: first atom in bond
            atom2: second atom in bond
            r1: distance
            r2: distance
            r3: distance
            r4: distance
            k: force constant
            asymptote: maximum energy in region V
        """
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp

        assert isinstance(atom1, indexing.AtomIndex)
        self.atom_index_1 = int(atom1)
        assert isinstance(atom2, indexing.AtomIndex)
        self.atom_index_2 = int(atom2)

        self.r1 = strip_unit(r1, u.nanometer)
        self.r2 = strip_unit(r2, u.nanometer)
        self.r3 = strip_unit(r3, u.nanometer)
        self.r4 = strip_unit(r4, u.nanometer)
        self.k = strip_unit(k, u.kilojoule_per_mole / u.nanometer**2)
        self.asymptote = strip_unit(asymptote, u.kilojoule_per_mole)

        self._check(system)

    def _check(self, system):
        if self.r1 < 0 or self.r2 < 0 or self.r3 < 0 or self.r4 < 0:
            raise RuntimeError(
                "r1 to r4 must be > 0. r1={} r2={} r3={} r4={}.".format(
                    self.r1, self.r2, self.r3, self.r4
                )
            )
        if self.r2 < self.r1:
            raise RuntimeError(f"r2 must be >= r1. r1={self.r1} r2={self.r2}.")
        if self.r3 < self.r2:
            raise RuntimeError(f"r3 must be >= r2. r2={self.r2} r3={self.r3}.")
        if self.r4 <= self.r3:
            raise RuntimeError(f"r4 must be > r3. r3={self.r3} r4={self.r4}.")
        if self.k < 0:
            raise RuntimeError(f"k must be >= 0. k={self.k}.")
        if self.asymptote < 0:
            raise RuntimeError(f"asymptote must be >= 0. asymptote={self.asymptote}.")


class TorsionRestraint(SelectableRestraint):
    """
    A Torsion restraint between four atoms
    """

    _restraint_key_ = "torsion"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom1: indexing.AtomIndex,
        atom2: indexing.AtomIndex,
        atom3: indexing.AtomIndex,
        atom4: indexing.AtomIndex,
        phi: u.Quantity,
        delta_phi: u.Quantity,
        k: u.Quantity,
    ):
        """
        Initialize a TorsionRestraint

        Args:
            system: the system this restraint belongs to
            scaler: scale the force with alpha
            ramp: ramp up the force over time
            atom1: index of first atom
            atom2: index of second atom
            atom3: index of third atom
            atom4: index of fourth atom
            phi: equilibrium angle in degrees
            delta_phi: flat within delta_phi, degrees
            k: force constant in :math:`kJ/mol/deg^2`
        """
        assert isinstance(atom1, indexing.AtomIndex)
        assert isinstance(atom2, indexing.AtomIndex)
        assert isinstance(atom3, indexing.AtomIndex)
        # The original listing checked atom3 twice; atom4 is the intended target.
        assert isinstance(atom4, indexing.AtomIndex)
        self.atom_index_1 = int(atom1)
        self.atom_index_2 = int(atom2)
        self.atom_index_3 = int(atom3)
        self.atom_index_4 = int(atom4)

        self.phi = strip_unit(phi, u.degree)
        self.delta_phi = strip_unit(delta_phi, u.degree)
        self.k = strip_unit(k, u.kilojoule_per_mole / u.degree**2)
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self._check()

    def _check(self):
        if (
            len(
                set(
                    [
                        self.atom_index_1,
                        self.atom_index_2,
                        self.atom_index_3,
                        self.atom_index_4,
                    ]
                )
            )
            != 4
        ):
            raise RuntimeError(
                "All four indices of a torsion restraint must be unique."
            )
        if self.phi < -180 or self.phi > 180:
            raise RuntimeError(f"-180 <= phi <= 180. phi was {self.phi}.")
        if self.delta_phi < 0 or self.delta_phi > 180:
            raise RuntimeError(f"0 <= delta_phi <= 180. delta_phi was {self.delta_phi}.")
        if self.k < 0:
            raise RuntimeError(f"k >= 0. k was {self.k}.")


class DistProfileRestraint(SelectableRestraint):
    """
    A spline-based distance profile restraint between two atoms
    """

    _restraint_key_ = "dist_prof"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom1: indexing.AtomIndex,
        atom2: indexing.AtomIndex,
        r_min: u.Quantity,
        r_max: u.Quantity,
        n_bins: int,
        spline_params: np.ndarray,
        scale_factor: u.Quantity,
    ):
        """
        Initialize a DistProfileRestraint

        Args:
            system: the system this restraint belongs to
            scaler: scale the force with alpha
            ramp: scale the force over time
            atom1: the first atom in the bond
            atom2: the second atom in the bond
            r_min: the minimum distance in the lookup table
            r_max: the maximum distance in the lookup table
            n_bins: the number of bins in the lookup table
            spline_params: the spline coefficient lookup table, shape(n_bins, 4)
            scale_factor: scale the energy
        """
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp

        assert isinstance(atom1, indexing.AtomIndex)
        assert isinstance(atom2, indexing.AtomIndex)
        self.atom_index_1 = int(atom1)
        self.atom_index_2 = int(atom2)

        self.r_min = strip_unit(r_min, u.nanometer)
        self.r_max = strip_unit(r_max, u.nanometer)
        self.n_bins = n_bins
        self.spline_params = spline_params
        self.scale_factor = strip_unit(scale_factor, u.kilojoule_per_mole)
        self._check()

    def _check(self):
        assert self.r_min >= 0.0
        assert self.r_max > self.r_min
        assert self.n_bins > 0
        assert self.spline_params.shape[0] == self.n_bins
        assert self.spline_params.shape[1] == 4


class TorsProfileRestraint(SelectableRestraint):
    """
    A spline-based restraint between two torsions over eight atoms
    """

    _restraint_key_ = "tors_prof"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom1: indexing.AtomIndex,
        atom2: indexing.AtomIndex,
        atom3: indexing.AtomIndex,
        atom4: indexing.AtomIndex,
        atom5: indexing.AtomIndex,
        atom6: indexing.AtomIndex,
        atom7: indexing.AtomIndex,
        atom8: indexing.AtomIndex,
        n_bins: int,
        spline_params: np.ndarray,
        scale_factor: u.Quantity,
    ):
        """
        Initialize a TorsProfileRestraint

        Args:
            system: the system this restraint belongs to
            scaler: scale the force with alpha
            ramp: ramp the strength of the force over time
            atom1: first atom of first torsion
            atom2: second atom of first torsion
            atom3: third atom of first torsion
            atom4: fourth atom of first torsion
            atom5: first atom of second torsion
            atom6: second atom of second torsion
            atom7: third atom of second torsion
            atom8: fourth atom of second torsion
            n_bins: number of bins in lookup
            spline_params: the spline coefficient lookup table, shape(n_bins, 16)
            scale_factor: scale the energy
        """
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp

        assert isinstance(atom1, indexing.AtomIndex)
        assert isinstance(atom2, indexing.AtomIndex)
        assert isinstance(atom3, indexing.AtomIndex)
        assert isinstance(atom4, indexing.AtomIndex)
        assert isinstance(atom5, indexing.AtomIndex)
        assert isinstance(atom6, indexing.AtomIndex)
        assert isinstance(atom7, indexing.AtomIndex)
        assert isinstance(atom8, indexing.AtomIndex)
        self.atom_index_1 = int(atom1)
        self.atom_index_2 = int(atom2)
        self.atom_index_3 = int(atom3)
        self.atom_index_4 = int(atom4)
        self.atom_index_5 = int(atom5)
        self.atom_index_6 = int(atom6)
        self.atom_index_7 = int(atom7)
        self.atom_index_8 = int(atom8)

        self.n_bins = n_bins
        self.spline_params = spline_params
        self.scale_factor = strip_unit(scale_factor, u.kilojoule_per_mole)
        self._check()

    def _check(self):
        assert self.n_bins > 0
        n_params = self.n_bins * self.n_bins
        assert self.spline_params.shape[0] == n_params
        assert self.spline_params.shape[1] == 16


class RdcRestraint(SelectableRestraint):
    """
    Residual Dipolar Coupling Restraint
    """

    _restraint_key_ = "rdc"
    atom_index_1: Union[int, mapping.PeakMapping]
    atom_index_2: Union[int, mapping.PeakMapping]

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom1: Union[indexing.AtomIndex, mapping.PeakMapping],
        atom2: Union[indexing.AtomIndex, mapping.PeakMapping],
        kappa: u.Quantity,
        d_obs: u.Quantity,
        tolerance: u.Quantity,
        force_const: u.Quantity,
        quadratic_cut: u.Quantity,
        weight: float,
        alignment_index: int,
    ):
        """
        Initialize an RdcRestraint

        Args:
            system: the system this restraint belongs to
            scaler: scale the force with alpha
            ramp: scale the force over time
            atom1: the first atom in the RDC
            atom2: the second atom in the RDC
            kappa: prefactor for RDC calculation in :math:`Hz nm^3`
            d_obs: observed dipolar coupling in :math:`Hz`
            tolerance: calculated couplings within tolerance (in :math:`Hz`)
                of d_obs will have zero energy and force
            force_const: force constant in :math:`kJ/mol/Hz^2`
            quadratic_cut: force constant becomes linear beyond this
                deviation in :math:`s^-1`
            weight: dimensionless weight to place on this restraint
            alignment_index: which alignment to use

        Note:
            Typical values for kappa are:

            - 1H - 1H: :math:`-360.3 Hz nm^3`
            - 13C - 1H: :math:`-90.6 Hz nm^3`
            - 15N - 1H: :math:`36.5 Hz nm^3`
        """
        if isinstance(atom1, mapping.PeakMapping):
            self.atom_index_1 = atom1
        else:
            assert isinstance(atom1, indexing.AtomIndex)
            self.atom_index_1 = int(atom1)

        if isinstance(atom2, mapping.PeakMapping):
            self.atom_index_2 = atom2
        else:
            assert isinstance(atom2, indexing.AtomIndex)
            self.atom_index_2 = int(atom2)

        kappa = strip_unit(kappa, u.second**-1 * u.nanometer**3)
        d_obs = strip_unit(d_obs, u.second**-1)
        tolerance = strip_unit(tolerance, u.second**-1)
        force_const = strip_unit(force_const, u.kilojoule_per_mole * u.second**2)
        quadratic_cut = strip_unit(quadratic_cut, u.second**-1)

        self.alignment_index = alignment_index
        self.kappa = float(kappa)
        self.d_obs = float(d_obs)
        self.tolerance = float(tolerance)
        self.force_const = float(force_const)
        self.quadratic_cut = quadratic_cut
        self.weight = float(weight)
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self._check(system)

    def _check(self, system):
        if self.atom_index_1 == self.atom_index_2:
            raise ValueError("atom1 and atom2 must be different")
        if self.tolerance < 0:
            raise ValueError("tolerance must be > 0")
        if self.force_const < 0:
            raise ValueError("force_constant must be > 0")
        if self.weight < 0:
            raise ValueError("weight must be > 0")
        if self.quadratic_cut <= 0:
            raise ValueError("quadratic_cut must be > 0")


class ConfinementRestraint(NonSelectableRestraint):
    """
    Confinement restraint from origin

    Confines an atom to be within radius of the origin. These restraints
    are typically set to somewhat larger than the expected radius of
    gyration of the protein and help to keep the structures compact even
    when the protein is unfolded. Typically used with a
    :class:`ConstantScaler`.
    """

    _restraint_key_ = "confine"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom_index: indexing.AtomIndex,
        radius: u.Quantity,
        force_const: u.Quantity,
    ):
        """
        Initialize a ConfinementRestraint

        Args:
            system: the system that this restraint belongs to
            scaler: scale the force with alpha
            ramp: scale the force over time
            atom_index: the index of the restrained atom
            radius: the distance to confine to
            force_const: strength of confinement
        """
        assert isinstance(atom_index, indexing.AtomIndex)
        self.atom_index = int(atom_index)
        self.radius = strip_unit(radius, u.nanometer)
        self.force_const = strip_unit(
            force_const, u.kilojoule_per_mole / u.nanometer**2
        )
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self._check(system)

    def _check(self, system):
        if self.radius < 0:
            raise ValueError("radius must be > 0")
        if self.force_const < 0:
            raise ValueError("force_constant must be > 0")


class CartesianRestraint(NonSelectableRestraint):
    """Cartesian restraint on xyz coordinates"""

    _restraint_key_ = "cartesian"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom_index: indexing.AtomIndex,
        x: u.Quantity,
        y: u.Quantity,
        z: u.Quantity,
        delta: u.Quantity,
        force_const: u.Quantity,
    ):
        """
        Initialize a CartesianRestraint

        Args:
            system: the system this restraint belongs to
            scaler: scale the force with alpha
            ramp: scale the force over time
            atom_index: the atom to restrain
            x: equilibrium x-coordinate
            y: equilibrium y-coordinate
            z: equilibrium z-coordinate
            delta: energy is zero within delta
            force_const: force constant
        """
        assert isinstance(atom_index, indexing.AtomIndex)
        self.atom_index = int(atom_index)
        self.x = strip_unit(x, u.nanometer)
        self.y = strip_unit(y, u.nanometer)
        self.z = strip_unit(z, u.nanometer)
        self.delta = strip_unit(delta, u.nanometer)
        self.force_const = strip_unit(
            force_const, u.kilojoule_per_mole / u.nanometer**2
        )
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self._check()

    def _check(self):
        if self.delta < 0:
            raise ValueError("delta must be non-negative")
        if self.force_const < 0:
            raise ValueError("force_const must be non-negative")


class YZCartesianRestraint(NonSelectableRestraint):
    """
    Cartesian restraint on yz coordinates only
    """

    _restraint_key_ = "yzcartesian"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom_index: indexing.AtomIndex,
        y: u.Quantity,
        z: u.Quantity,
        delta: u.Quantity,
        force_const: u.Quantity,
    ):
        """
        Initialize a YZCartesianRestraint

        Args:
            system: the system this restraint belongs to
            scaler: scale the force with alpha
            ramp: scale the force over time
            atom_index: the atom to restrain
            y: equilibrium y-coordinate, in nm
            z: equilibrium z-coordinate, in nm
            delta: energy is zero within delta, in nm
            force_const: force constant in :math:`kJ/mol/nm^2`
        """
        assert isinstance(atom_index, indexing.AtomIndex)
        self.atom_index = int(atom_index)
        self.y = strip_unit(y, u.nanometer)
        self.z = strip_unit(z, u.nanometer)
        self.delta = strip_unit(delta, u.nanometer)
        self.force_const = strip_unit(
            force_const, u.kilojoule_per_mole / u.nanometer**2
        )
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self._check()

    def _check(self):
        if self.delta < 0:
            raise ValueError("delta must be non-negative")
        if self.force_const < 0:
            raise ValueError("force_const must be non-negative")


class AbsoluteCOMRestraint(NonSelectableRestraint):
    """
    Restraint on the distance between a group and a point in space

    This class implements a restraint on the distance between the center
    of a group and a point in space. The weights used to calculate the
    center can be specified as ``weights``. If ``None``, then the masses
    of the atoms will be used.

    The ``dims`` parameter controls which dimensions are used to compute
    the distance. For example if ``dims='xyz'``, then the distance will be
    the normal distance in all three dimensions. If ``dims='x'``, then only
    the x-component will be considered.

    Restraints are typically added using
    ``RestraintManager.create_restraint`` with the ``'abs_com'`` key:

    >>> r = system.restraints.create_restraint(
            'abs_com', scaler=scaler, ramp=ramp,
            group=group, weights=weights, dims=dims,
            force_const=force_const, position=position)
    """

    _restraint_key_ = "abs_com"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        group: List[indexing.AtomIndex],
        weights: np.ndarray,
        dims: str,
        force_const: u.Quantity,
        position: u.Quantity,
    ):
        """
        Initialize an AbsoluteCOMRestraint

        Args:
            system: system object used for indexing
            scaler: scale the force with alpha
            ramp: scale the force over time
            group: atoms to restrain COM
            weights: Weights to use when calculating the COM. If ``None``,
                then the masses will be used.
            dims: combination of x, y, z that determines which dimensions
                are used when calculating the distance
            force_const: force constant in kJ/mol/nm^2
            position: location in space to restrain to
        """
        self.scaler: RestraintScaler = ConstantScaler() if scaler is None else scaler
        self.ramp: TimeRamp = ConstantRamp() if ramp is None else ramp
        self.dims = dims
        self._check_dims()

        self.force_const = strip_unit(
            force_const, u.kilojoule_per_mole / u.nanometer**2
        )
        if self.force_const < 0:
            raise ValueError("force_const cannot be negative")

        assert isinstance(position, u.Quantity)
        self.position = position.value_in_unit(u.nanometer)
        if len(self.position) != 3:
            raise ValueError("position should be an array of [x, y, z]")

        self.weights = weights
        self.indices = self._get_indices(group)
        self._check_weights()

    def _check_weights(self):
        if self.weights is not None:
            if len(self.indices) != len(self.weights):
                raise ValueError("weights and group have different lengths")
            for w in self.weights:
                if w < 0:
                    raise ValueError("weights must be > 0")

    def _check_dims(self):
        for c in self.dims:
            if c not in "xyz":
                raise ValueError(f'dims must be a combination of "xyz", found {c}')
        for c in "xyz":
            if self.dims.count(c) > 1:
                raise ValueError(f"{c} occurs more than once in dims")

    def _get_indices(self, group):
        indices = []
        for g in group:
            assert isinstance(g, indexing.AtomIndex)
            indices.append(int(g))
        return indices
[docs]class COMRestraint(NonSelectableRestraint): """ Restraint on the distance between two groups along selected axes This class implements a restraint on the distance between the center of two groups. The weights used to calculate the center can be specified as ``weights1`` and ``weights2``. If these are ``None``, then the masses of the atoms will be used. The ``dims`` parameter controls which dimensions are used to compute the distance. For example if ``dims='xyz'``, then the distance will be the normal distance in all three dimensions. If ``dims='x'``, then only the x-component will be considered. Restraints are typically added using ``RestraintMangager.create_restraint`` with the ``'com'`` key: >>> r = system.restraints.create_restraint('com', scaler, ramp=ramp, group1=group1, group2=group2, weights1=weights1, weights2=weights2, dims=dims, force_const=force_const, distance=distance) """ _restraint_key_ = "com"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        group1: List[indexing.AtomIndex],
        group2: List[indexing.AtomIndex],
        weights1: List[float],
        weights2: List[float],
        dims: str,
        force_const: u.Quantity,
        distance: Union[u.Quantity, Positioner],
    ):
        """
        Initialize a COMRestraint

        Args:
            system: the system this restraint belongs to
            scaler: scale the force with alpha
            ramp: scale the force over time
            group1: atoms in group1
            group2: atoms in group2
            weights1: Weights to use when calculating the COM of group1.
                If ``None``, then the atom masses will be used.
            weights2: Weights to use when calculating the COM of group2.
                If ``None``, then the atom masses will be used.
            dims: combination of x, y, z that determines which dimensions
                are used when calculating the distance
            force_const: force constant in kJ/mol/nm^2
            distance: distance between groups
        """
        # setup the scaler, ramp, and indices
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self.indices1 = self._get_indices(group1)
        self.indices2 = self._get_indices(group2)

        # setup the weights
        self.weights1 = weights1
        if self.weights1 is not None:
            if len(self.indices1) != len(self.weights1):
                raise ValueError("len(indices1) != len(weights1)")
            for w in self.weights1:
                if w < 0:
                    raise ValueError("weights1 must be >= 0")
        self.weights2 = weights2
        if self.weights2 is not None:
            if len(self.indices2) != len(self.weights2):
                raise ValueError("len(indices2) != len(weights2)")
            for w in self.weights2:
                if w < 0:
                    raise ValueError("weights2 must be >= 0")

        # setup the dimensions
        self.dims = dims
        self._check_dims()

        # setup the force constant and positioner
        self.force_const = strip_unit(
            force_const, u.kilojoule_per_mole / u.nanometer**2
        )
        if self.force_const < 0:
            raise ValueError("force constant cannot be negative")
        if isinstance(distance, Positioner):
            self.positioner = distance
        else:
            if strip_unit(distance, u.nanometer) < 0.0:
                raise ValueError("distance cannot be negative")
            self.positioner = ConstantPositioner(distance)

    def _get_indices(self, group):
        indices = []
        for g in group:
            assert isinstance(g, indexing.AtomIndex)
            indices.append(int(g))
        return indices

    def _check_dims(self):
        # check for non 'xyz'
        for c in self.dims:
            if c not in "xyz":
                raise ValueError(f'dims must be a combination of "xyz", found {c}')
        # check for repeated dimensions
        for dim in "xyz":
            count = self.dims.count(dim)
            if count > 1:
                raise ValueError(f"{dim} occurs more than once in dims")
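
# Illustrative sketch, not part of MELD: the helper below shows, in plain
# NumPy, the kind of distance that the ``dims`` and ``weights`` arguments of
# COMRestraint describe. The name ``com_distance`` and the ``masses`` argument
# are hypothetical, and MELD's own evaluation of the restraint may differ.

```python
import numpy as np


def com_distance(coords, masses, idx1, idx2, weights1=None, weights2=None, dims="xyz"):
    """Distance between the weighted centers of two atom groups.

    Hypothetical helper for illustration only. ``coords`` is an (N, 3) array,
    ``masses`` an (N,) array, ``idx1``/``idx2`` are lists of atom indices.
    """
    coords = np.asarray(coords, dtype=float)
    masses = np.asarray(masses, dtype=float)

    # Fall back to the atom masses when no explicit weights are given.
    w1 = masses[idx1] if weights1 is None else np.asarray(weights1, dtype=float)
    w2 = masses[idx2] if weights2 is None else np.asarray(weights2, dtype=float)

    # Weighted center of each group.
    center1 = np.average(coords[idx1], axis=0, weights=w1)
    center2 = np.average(coords[idx2], axis=0, weights=w2)

    # Keep only the components named in ``dims`` (e.g. "x", "yz", "xyz").
    axes = ["xyz".index(c) for c in dims]
    delta = (center1 - center2)[axes]
    return float(np.linalg.norm(delta))


coords = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [1.0, 3.0, 4.0], [1.0, 3.0, 4.0]]
masses = [1.0, 1.0, 1.0, 1.0]
d_xyz = com_distance(coords, masses, [0, 1], [2, 3], dims="xyz")
d_x = com_distance(coords, masses, [0, 1], [2, 3], dims="x")
d_z = com_distance(coords, masses, [0, 1], [2, 3], dims="z")
```

# With ``dims="x"`` the two centers above coincide, so only the y and z
# components contribute to the full ``"xyz"`` distance.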


class DensityRestraint(SelectableRestraint):
    _restraint_key_ = "density"

    def __init__(
        self,
        system: interfaces.ISystem,
        scaler: Optional[RestraintScaler],
        ramp: Optional[TimeRamp],
        atom: list,
        density: DensityMap,
        mu=None,
    ):
        self.atom_index = [int(i) for i in atom]
        self.scaler = ConstantScaler() if scaler is None else scaler
        self.ramp = ConstantRamp() if ramp is None else ramp
        self.mu = density.density_data
        self.map_origin = density.origin
        self.map_dimension = [density.nx, density.ny, density.nz]
        self.map_gridLength = density.voxel_size


class AlwaysActiveCollection:
    """
    A collection of restraints that are always on
    """

    def __init__(self):
        self._restraints = []

    @property
    def restraints(self) -> List[Restraint]:
        return self._restraints

    def add_restraint(self, restraint: Restraint):
        """
        Add a restraint

        Args:
            restraint: restraint to add
        """
        if not isinstance(restraint, Restraint):
            raise RuntimeError(
                f"Tried to add unknown restraint of type {str(type(restraint))}."
            )
        self._restraints.append(restraint)


class SelectivelyActiveCollection:
    """
    A collection of :class:`RestraintGroup` that are selectively active

    Each time step the ``num_active`` lowest energy groups will be active.
    """

    def __init__(
        self,
        restraint_list: List[Union[RestraintGroup, SelectableRestraint]],
        num_active: int,
    ):
        """
        Initialize a SelectivelyActiveCollection

        Args:
            restraint_list: list of restraints to add to collection
            num_active: number active each time step

        Note:
            ``restraint_list`` can contain both :class:`RestraintGroup` and
            :class:`SelectableRestraint`. Any :class:`SelectableRestraint`
            will be put into a singleton :class:`RestraintGroup`.
        """
        self._groups: List[RestraintGroup] = []
        if not restraint_list:
            raise RuntimeError(
                "SelectivelyActiveCollection cannot have empty restraint list."
            )
        for rest in restraint_list:
            self._add_restraint(rest)

        # Do error checking
        n_rest = len(self._groups)
        if isinstance(num_active, param_sampling.DiscreteParameter):
            if num_active.min < 0:
                raise RuntimeError("num_active must be >= 0.")
            if num_active.max > n_rest:
                raise RuntimeError(f"num_active must be <= num_groups ({n_rest}).")
        else:
            if num_active < 0:
                raise RuntimeError("num_active must be >= 0.")
            if num_active > n_rest:
                raise RuntimeError(f"num_active must be <= num_groups ({n_rest}).")
        self._num_active = num_active

    @property
    def groups(self) -> List[RestraintGroup]:
        """
        Groups in collection
        """
        return self._groups

    @property
    def num_active(self) -> int:
        """
        Number active in collection
        """
        return self._num_active

    def _add_restraint(self, restraint):
        if isinstance(restraint, RestraintGroup):
            self._groups.append(restraint)
        elif not isinstance(restraint, SelectableRestraint):
            raise RuntimeError(
                f"Cannot add restraint of type {str(type(restraint))} to "
                "SelectivelyActiveCollection"
            )
        else:
            # wrap a bare restraint in a singleton group
            group = RestraintGroup([restraint], 1)
            self._groups.append(group)


class RestraintGroup:
    """
    A group of selectable restraints

    Each timestep the ``num_active`` lowest energy restraints will be active.
    """

    def __init__(self, rest_list: List[SelectableRestraint], num_active: int):
        """
        Initialize a RestraintGroup

        Args:
            rest_list: list of :class:`SelectableRestraint` in this group
            num_active: number active each timestep
        """
        self._restraints: List[SelectableRestraint] = []
        if not rest_list:
            raise RuntimeError("rest_list cannot be empty.")
        for rest in rest_list:
            self._add_restraint(rest)

        # Do error checking
        n_rest = len(self._restraints)
        if isinstance(num_active, param_sampling.DiscreteParameter):
            if num_active.min < 0:
                raise RuntimeError("num_active must be >= 0.")
            if num_active.max > n_rest:
                raise RuntimeError(
                    f"num_active must be <= num_restraints ({n_rest})."
                )
        else:
            if num_active < 0:
                raise RuntimeError("num_active must be >= 0.")
            if num_active > n_rest:
                raise RuntimeError(
                    f"num_active must be <= num_restraints ({n_rest})."
                )
        self._num_active = num_active

    @property
    def restraints(self) -> List[SelectableRestraint]:
        """
        Restraints in the group
        """
        return self._restraints

    @property
    def num_active(self) -> int:
        """
        Number of active restraints
        """
        return self._num_active

    def _add_restraint(self, rest):
        if not isinstance(rest, SelectableRestraint):
            raise RuntimeError("Can only add SelectableRestraints to a RestraintGroup.")
        self._restraints.append(rest)
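
# Illustrative sketch, not part of MELD: the selection rule described in the
# RestraintGroup docstring (sort by energy, keep the ``num_active`` lowest)
# can be written in a few lines. The function ``select_active`` and the
# example energies are hypothetical; how MELD performs the selection during
# a simulation is not shown in this listing.

```python
from typing import List, Sequence


def select_active(energies: Sequence[float], num_active: int) -> List[int]:
    """Return the indices of the ``num_active`` lowest-energy restraints."""
    if num_active < 0 or num_active > len(energies):
        raise ValueError("num_active must be between 0 and len(energies)")
    # Sort restraint indices by their current energy, lowest first.
    order = sorted(range(len(energies)), key=lambda i: energies[i])
    # Keep the lowest ``num_active``; return them in index order.
    return sorted(order[:num_active])


# Four restraints with hypothetical energies for one timestep.
energies = [5.0, 1.0, 3.0, 0.5]
active = select_active(energies, 2)
# Only the active restraints contribute to the group's energy.
group_energy = sum(energies[i] for i in active)
```

# The same rule applies one level up: a SelectivelyActiveCollection keeps the
# ``num_active`` lowest energy groups.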


class RestraintManager:
    """
    A class to manage restraints for a System
    """

    def __init__(self, system: interfaces.ISystem):
        """
        Initialize a RestraintManager

        Args:
            system: the System to manage restraints for
        """
        self._system = system
        self._always_active = AlwaysActiveCollection()
        self._selective_collections: List[SelectivelyActiveCollection] = []

    @property
    def always_active(self) -> List[Restraint]:
        """
        Always active restraints
        """
        return self._always_active.restraints

    @property
    def selectively_active_collections(self) -> List[SelectivelyActiveCollection]:
        """
        Selectively active collections
        """
        return self._selective_collections

    def add_as_always_active(
        self, restraint: Union[NonSelectableRestraint, SelectableRestraint]
    ) -> None:
        """
        Add a restraint as always active

        Args:
            restraint: the restraint to add
        """
        self._always_active.add_restraint(restraint)

    def add_as_always_active_list(
        self, restraint_list: List[Union[NonSelectableRestraint, SelectableRestraint]]
    ) -> None:
        """
        Add a list of restraints as always active

        Args:
            restraint_list: the restraints to add
        """
        for r in restraint_list:
            self.add_as_always_active(r)

    def add_selectively_active_collection(
        self,
        rest_list: List[Union[RestraintGroup, SelectableRestraint]],
        num_active: int,
    ) -> None:
        """
        Add a selectively active collection

        Args:
            rest_list: list of restraints or restraint groups to add
            num_active: number of active groups in collection
        """
        self._selective_collections.append(
            SelectivelyActiveCollection(rest_list, num_active)
        )

    def create_restraint(
        self,
        rest_type: str,
        scaler: Optional[RestraintScaler] = None,
        ramp: Optional[TimeRamp] = None,
        **kwargs,
    ) -> Restraint:
        r"""
        Create a restraint

        Args:
            rest_type: type of restraint to add
            scaler: scale the force with alpha
            ramp: scale the force over time
            \**kwargs: passed along to restraint creation functions

        Returns:
            the new restraint
        """
        if scaler is None:
            scaler = ConstantScaler()
        else:
            if not isinstance(scaler, RestraintScaler):
                raise ValueError(
                    "scaler must be a subclass of RestraintScaler, "
                    f"you tried to add a {type(scaler)}."
                )
        if ramp is None:
            ramp = ConstantRamp()
        else:
            if not isinstance(ramp, TimeRamp):
                raise ValueError(
                    "ramp must be a subclass of TimeRamp, "
                    f"you tried to add a {type(ramp)}."
                )
        return _RestraintRegistry.get_constructor_for_key(rest_type)(
            self._system, scaler, ramp, **kwargs
        )

    def create_restraint_group(
        self, rest_list: List[SelectableRestraint], num_active: int
    ) -> RestraintGroup:
        """
        Create a restraint group

        Args:
            rest_list: restraints to include in group
            num_active: number of restraints active at each timestep

        Returns:
            the new restraint group
        """
        return RestraintGroup(rest_list, num_active)

    def create_scaler(self, scaler_type: str, **kwargs) -> RestraintScaler:
        r"""
        Create a restraint scaler

        Args:
            scaler_type: the type of scaler to create
            \**kwargs: passed along to the scaler creation functions

        Returns:
            the new restraint scaler
        """
        return ScalerRegistry.get_constructor_for_key(scaler_type)(**kwargs)


class GMMParams(NamedTuple):
    n_components: int
    n_distances: int
    atoms: List[indexing.AtomIndex]
    weights: np.ndarray
    means: np.ndarray
    precisions: np.ndarray