#
# Copyright 2015 by Justin MacCallum, Alberto Perez, Ken Dill
# All rights reserved
#
"""
Methods to read template coordinates from a system object and build elastic network restraints prior to simulation, either within a chain or between several chains.
"""
from typing import List, Optional, Union
import numpy as np # type: ignore
from scipy.spatial.distance import pdist # type: ignore
from scipy.spatial.distance import squareform # type: ignore
from meld import unit as u
from meld.system import meld_system, restraints, scalers
from meld.util import strip_unit
def _select_restrained_pairs(
    coordinates,
    chain_ids: List[int],
    cutoff: float,
    residue_separation: int,
    only_inter: bool,
) -> List[tuple]:
    """
    Pick the residue pairs that should receive an elastic network restraint.

    Args:
        coordinates: one coordinate row per residue, in the same length unit as ``cutoff``
        chain_ids: chain index of each residue, parallel to ``coordinates``
        cutoff: pairs strictly closer than this distance are candidates
        residue_separation: minimum difference in list position between two residues
            of the same chain; not applied to pairs from different chains
        only_inter: if True, discard every pair whose residues share a chain

    Returns:
        ``(i, j, distance)`` tuples with ``i < j`` indexing into ``chain_ids``,
        ordered by ``i`` and then ``j``.
    """
    n_points = len(chain_ids)
    if n_points < 2:
        return []
    # pdist returns the condensed upper triangle in exactly the row-major order that
    # np.triu_indices(n, k=1) enumerates, so these arrays line up element-wise.
    distances = pdist(coordinates)
    first, second = np.triu_indices(n_points, k=1)
    chain_arr = np.asarray(chain_ids)
    same_chain = chain_arr[first] == chain_arr[second]

    keep = distances < cutoff
    if only_inter:
        keep &= ~same_chain
    # Sequence separation is only meaningful within a chain: across chains, the last
    # residue of one chain and the first residue of the next are merely adjacent in
    # the concatenated list, not in sequence.
    keep &= (~same_chain) | ((second - first) >= residue_separation)

    return [
        (int(i), int(j), float(d))
        for i, j, d in zip(first[keep], second[keep], distances[keep])
    ]


def create_elastic_network_restraints(
    system: meld_system.System,
    chain_indices: Union[int, List],
    cutoff: u.Quantity = 1.0 * u.nanometer,
    restrained_atom: str = "CA",
    residue_separation: int = 3,
    k: u.Quantity = 2500 * u.kilojoule_per_mole / u.nanometer**2,
    only_inter: bool = False,
    scaler: Optional[scalers.RestraintScaler] = None,
    write_file: Optional[str] = None,
):
    """
    Create elastic network restraints between one named atom per residue.

    Every pair of residues from the selected chains whose ``restrained_atom`` atoms lie
    closer than ``cutoff`` in the template coordinates receives a "distance" restraint
    with ``r1 = r2 = 0``, ``r3`` equal to the template distance and ``r4`` 0.2 nm
    beyond it. Residues named HOH, WAT, W, Na+, Ca2+, K+ or Cl- are ignored so that
    explicitly solvated systems can be used.

    Args:
        system: system object used for atom indexing and template coordinates
        chain_indices: a single int or a list of integer chain ids.
            To build only intra-chain restraints (restraints within a chain), pass a
            single index and call the function once per chain.
            To also build restraints between chains (e.g. to keep a complex
            together), pass a list of indices.
        cutoff: distance cutoff. Any restrained atoms closer than the cutoff get an
            elastic network restraint between them.
        restrained_atom: atom name used for every residue. Only a single atom name is
            supported.
        residue_separation: minimum sequence separation between two restrained
            residues of the same chain. Pairs from different chains are not subject
            to this limit.
        k: force constant
        only_inter: only build restraints between chains, never within a chain.
            Useful when intra-chain elastic network restraints come from elsewhere
            (e.g. martinize for the martini builder).
        scaler: optional Scaler to vary the force constant with alpha.
            If ``None``, a constant 1.0 scaler is used.
        write_file: optional path of a text file that receives one line per
            restraint. Useful for comparing restraints generated by different
            functions.

    Returns:
        list of restraints created by ``system.restraints.create_restraint``, ordered
        by the first and then the second residue of each pair.

    Raises:
        ValueError: if the number of extracted coordinates differs from the number of
            residues.
    """
    # Convert the unit-bearing inputs to plain floats in kJ/mol/nm^2 and nm.
    k_val = strip_unit(k, u.kilojoule_per_mole / u.nanometer**2)
    cutoff_val = strip_unit(cutoff, u.nanometer)
    scaler = scalers.ConstantScaler() if scaler is None else scaler
    if isinstance(chain_indices, int):
        chain_indices = [chain_indices]

    # Explicitly solvated systems contain water/ion residues that must not be restrained.
    solvent_names = ("HOH", "WAT", "W", "Na+", "Ca2+", "K+", "Cl-")
    chains = list(system.topology.chains())
    residues = [
        residue
        for chain_index in chain_indices
        for residue in chains[chain_index].residues()
        if residue.name not in solvent_names
    ]

    # Look up every restrained atom exactly once. residue.index is used instead of a
    # chain id because chain ids convert to relative numbering, whereas openmm
    # numbers residues absolutely.
    atom_indices = [
        system.index.atom(residue.index, restrained_atom, expected_resname=residue.name)
        for residue in residues
    ]
    # NOTE(review): template coordinates are assumed to be in nm, the unit that
    # cutoff was stripped to above.
    atom_coordinates = system.template_coordinates[atom_indices]
    if len(atom_coordinates) != len(residues):
        raise ValueError(
            "Mismatch between distance matrix and number of residues. Usually this means that restrained_atom is not present in every residue of the chain."
        )

    chain_ids = [residue.chain.index for residue in residues]
    pairs = _select_restrained_pairs(
        atom_coordinates, chain_ids, cutoff_val, residue_separation, only_inter
    )

    rests = [
        system.restraints.create_restraint(
            "distance",
            scaler=scaler,
            atom1=atom_indices[i],
            atom2=atom_indices[j],
            r1=0.0 * u.nanometer,
            r2=0.0 * u.nanometer,
            r3=dist * u.nanometer,
            r4=(dist + 0.2) * u.nanometer,
            k=k_val * u.kilojoule_per_mole / u.nanometer**2,
        )
        for i, j, dist in pairs
    ]

    # Optional plain-text dump of exactly the restraints created above; mostly a
    # sanity check against other ways of generating restraints.
    if write_file:
        with open(write_file, "w") as f:
            for i, j, dist in pairs:
                f.write(
                    f"{residues[i].chain.index} {residues[i].index} {restrained_atom} {residues[i].name} {residues[j].chain.index} {residues[j].index} {restrained_atom} {residues[j].name} {dist} {k_val} \n"
                )
    return rests
def add_elastic_network_restraints(
    system: meld_system.System,
    rests: List[restraints.SelectableRestraint],
    active_fraction: float = 1.0,
    max_grp_len: int = 64,
):
    """
    Add elastic network restraints to ``system`` in groups of at most ``max_grp_len``.

    For performance reasons the restraints are not added individually. They are
    packed, in order, into restraint groups of at most ``max_grp_len`` restraints,
    and all groups are added as one selectively-active collection.

    Args:
        system: system object the restraints belong to and are added to
        rests: list of SelectableRestraint restraints (e.g. the output of
            ``create_elastic_network_restraints``); if empty, nothing is added
        active_fraction: fraction of restraint *groups* (not individual restraints)
            that must remain active; the resulting count is truncated with ``int()``
        max_grp_len: maximum number of restraints per group; must be at least 1

    Raises:
        ValueError: if ``max_grp_len`` is less than 1
    """
    if max_grp_len < 1:
        raise ValueError(f"max_grp_len must be at least 1, got {max_grp_len}")
    rests = list(rests)
    # Nothing to enforce: avoid handing an empty group/collection to the system.
    if not rests:
        return

    collection: List[
        Union[restraints.RestraintGroup, restraints.SelectableRestraint]
    ] = []
    for start in range(0, len(rests), max_grp_len):
        grp = rests[start : start + max_grp_len]
        # The second argument is presumably the group's number of active restraints;
        # len(grp) keeps each group all-or-nothing, as before.
        collection.append(system.restraints.create_restraint_group(grp, len(grp)))

    system.restraints.add_selectively_active_collection(
        collection, int(len(collection) * active_fraction)
    )