Source code for meld.runner.openmm_runner

#
# Copyright 2015 by Justin MacCallum, Alberto Perez, Ken Dill
# All rights reserved
#

import logging
import math
import random
from tkinter import E
from typing import Dict, List, Optional

import numpy as np  # type: ignore
import openmm as mm  # type: ignore
from openmm import app  # type: ignore
from openmm import unit as u  # type: ignore

from meld import interfaces
from meld.runner import transform
from meld.system import options, restraints
from meld.system.state import SystemState
from meld.util import log_timing
from meld.vault import ENERGY_GROUPS

logger = logging.getLogger(__name__)


GAS_CONSTANT = 8.314e-3


[docs]class OpenMMRunner(interfaces.IRunner): _always_on_restraints: List[restraints.Restraint] _selectable_collections: List[restraints.SelectivelyActiveCollection] _options: options.RunOptions _simulation: app.Simulation _omm_system: mm.System _topology: app.Topology _integrator: mm.LangevinIntegrator _barostat: mm.MonteCarloBarostat _timestep: int _initialized: bool _alpha: float _temperature: float _transformers: List[transform.TransformerBase] _extra_bonds: List[interfaces.ExtraBondParam] _extra_restricted_angles: List[interfaces.ExtraAngleParam] _extra_torsions: List[interfaces.ExtraTorsParam]
[docs] def __init__( self, meld_system: interfaces.ISystem, options: options.RunOptions, communicator: Optional[interfaces.ICommunicator] = None, platform: Optional[str] = None, ): self._omm_system = meld_system.omm_system self._topology = meld_system.topology self._integrator = meld_system.integrator self._barostat = meld_system.barostat self._solvation = meld_system.solvation self.builder_info = meld_system.builder_info # Default to CUDA platform platform = platform if platform else "CUDA" self.platform = platform if communicator: # Only need to figure out device id for CUDA if platform == "CUDA": self._device_id = communicator.negotiate_device_id() self._rank: Optional[int] = communicator.rank else: self._device_id = 0 self._rank = None if meld_system.temperature_scaler is None: raise RuntimeError("system does not have temparture_scaler set") else: self.temperature_scaler = meld_system.temperature_scaler self._always_on_restraints = meld_system.restraints.always_active self._selectable_collections = ( meld_system.restraints.selectively_active_collections ) self._options = options self._timestep = 0 self._initialized = False self._alpha = 0.0 self._transformers: List[transform.TransformerBase] = [] self._extra_bonds = meld_system.extra_bonds self._extra_restricted_angles = meld_system.extra_restricted_angles self._extra_torsions = meld_system.extra_torsions self._parameter_manager = meld_system.param_sampler self._mapper = meld_system.mapper self._density = meld_system.density
def prepare_for_timestep( self, state: interfaces.IState, alpha: float, timestep: int ): self._alpha = alpha self._timestep = timestep assert self.temperature_scaler is not None self._temperature = self.temperature_scaler(alpha) self._initialize_simulation(state) @log_timing(logger) def minimize_then_run(self, state: interfaces.IState) -> interfaces.IState: return self._run(state, minimize=True) @log_timing(logger) def run(self, state: interfaces.IState) -> interfaces.IState: return self._run(state, minimize=False) def get_energy(self, state: interfaces.IState) -> float: # update all of the transformers self._transformers_update(state) # set the coordinates coordinates = u.Quantity(state.positions, u.nanometer) self._simulation.context.setPositions(coordinates) # set the box vectors if self._solvation == "explicit": box_vector = state.box_vector self._simulation.context.setPeriodicBoxVectors( [box_vector[0], 0.0, 0.0], [0.0, box_vector[1], 0.0], [0.0, 0.0, box_vector[2]], ) # set the rdc alignments self._set_alignments(state) # get the energy snapshot = self._simulation.context.getState(getEnergy=True) e_potential = snapshot.getPotentialEnergy() e_potential = ( e_potential.value_in_unit(u.kilojoule / u.mole) / GAS_CONSTANT / self._temperature ) # get the log_prior for parameters being sampled log_prior = self._parameter_manager.log_prior(state.parameters, self._alpha) return e_potential - log_prior def get_group_energies(self, state: interfaces.IState) -> np.ndarray: # update all of the transformers self._transformers_update(state) # set the coordinates coordinates = u.Quantity(state.positions, u.nanometer) self._simulation.context.setPositions(coordinates) # set the box vectors if self._solvation == "explicit": box_vector = state.box_vector self._simulation.context.setPeriodicBoxVectors( [box_vector[0], 0.0, 0.0], [0.0, box_vector[1], 0.0], [0.0, 0.0, box_vector[2]], ) # set the rdc alignments self._set_alignments(state) group_energies = np.zeros(ENERGY_GROUPS) for i in range(ENERGY_GROUPS - 1): snapshot = self._simulation.context.getState(getEnergy=True, groups={i}) e_potential = snapshot.getPotentialEnergy() e_potential = ( e_potential.value_in_unit(u.kilojoule / u.mole) / GAS_CONSTANT / self._temperature ) group_energies[i] = e_potential log_prior = self._parameter_manager.log_prior(state.parameters, self._alpha) group_energies[-1] = -log_prior return group_energies def _get_forces(self, state: interfaces.IState) -> np.ndarray: # update all of the transformers self._transformers_update(state) # set the coordinates coordinates = u.Quantity(state.positions, u.nanometer) self._simulation.context.setPositions(coordinates) # set the box vectors if self._solvation == "explicit": box_vector = state.box_vector self._simulation.context.setPeriodicBoxVectors( [box_vector[0], 0.0, 0.0], [0.0, box_vector[1], 0.0], [0.0, 0.0, box_vector[2]], ) # set the rdc alignments self._set_alignments(state) # get the forces snapshot = self._simulation.context.getState(getForces=True) forces = snapshot.getForces(asNumpy=True).value_in_unit( u.kilojoule / u.mole / u.nanometer ) return forces def _get_max_force_norm(self, state: interfaces.IState) -> float: forces = self._get_forces(state) return np.max(np.linalg.norm(forces, axis=1)) def _initialize_simulation(self, state: interfaces.IState) -> None: if self._initialized: # update temperature and pressure if self.builder_info.get("has_alignments", False): self._simulation.integrator.setGlobalVariableByName( "kT", self._temperature * GAS_CONSTANT ) else: self._integrator.setTemperature(self._temperature) if self._barostat: self._simulation.context.setParameter( self._barostat.Temperature(), self._temperature ) # update all of the system transformers self._transformers_update(state) else: # we need to set the whole thing from scratch self._initialized = True _add_extras( self._omm_system, self._extra_bonds, self._extra_restricted_angles, self._extra_torsions, ) # setup the transformers self._transformers_setup() if len(self._always_on_restraints) > 0: print("Not all always on restraints were handled.") for remaining_always_on in self._always_on_restraints: print("\t", remaining_always_on) raise RuntimeError("Not all always on restraints were handled.") if len(self._selectable_collections) > 0: print("Not all selectable restraints were handled.") for remaining_selectable in self._selectable_collections: print("\t", remaining_selectable) raise RuntimeError("Not all selectable restraints were handled.") self._omm_system = self._transformers_add_interactions( state, self._omm_system, self._topology ) self._transformers_finalize(state, self._omm_system, self._topology) # setup the platform, CUDA by default and Reference for testing properties: Dict[str, str] if self.platform == "Reference": logger.info("Using Reference platform.") platform = mm.Platform.getPlatformByName("Reference") properties = {} elif self.platform == "CPU": logger.info("Using CPU platform.") platform = mm.Platform.getPlatformByName("CPU") properties = {} elif self.platform == "CUDA": logger.info("Using CUDA platform.") platform = mm.Platform.getPlatformByName("CUDA") # The plugin currently requires that we use nvcc, as # nvrtc is not able to compile code that uses the cub # library, which we use in the plugin. # We can force the use of nvcc by setting CudaCompiler. # We set it to the default value, which will reflect the # OPENMM_CUDA_COMPILER environmnet variable if set. compiler = platform.getPropertyDefaultValue("CudaCompiler") logger.debug(f"Using CUDA compiler {compiler}.") properties = { "CudaDeviceIndex": str(self._device_id), "CudaPrecision": "mixed", "CudaCompiler": compiler, } else: raise RuntimeError(f"Unknown platform {self.platform}.") # forcegroups = self._forcegroupify(sys) # create the simulation object self._simulation = _create_openmm_simulation( self._topology, self._omm_system, self._integrator, platform, properties ) # forcegroups=self._forcegroupify(sys) # self._simulation = _create_openmm_simulation( # prmtop.topology, sys, self._integrator, platform, properties # ) self._transformers_update(state) def _forcegroupify(self, system): forcegroups = {} for i in range(system.getNumForces()): # logger.info(f"{i}th force \n") force = system.getForce(i) force.setForceGroup(i) forcegroups[force] = i return forcegroups def _transformers_setup(self) -> None: trans_types = [ transform.ConfinementRestraintTransformer, transform.CartesianRestraintTransformer, transform.YZCartesianTransformer, transform.COMRestraintTransformer, transform.AbsoluteCOMRestraintTransformer, transform.MeldRestraintTransformer, transform.REST2Transformer, ] for tt in trans_types: trans = tt( self._parameter_manager, self._mapper, self._density, self.builder_info, self._options, self._always_on_restraints, self._selectable_collections, ) self._transformers.append(trans) def _transformers_add_interactions( self, state: interfaces.IState, sys, topol ) -> mm.System: for t in self._transformers: sys = t.add_interactions(state, sys, topol) return sys def _transformers_finalize(self, state: interfaces.IState, sys, topol) -> None: for t in self._transformers: t.finalize(state, sys, topol) def _transformers_update(self, state: interfaces.IState) -> None: for t in self._transformers: t.update(state, self._simulation, self._alpha, self._timestep) def _run_min_mc(self, state: interfaces.IState) -> interfaces.IState: if self._options.min_mc is not None: logger.info("Running MCMC before minimization.") logger.info(f"Starting energy {self.get_energy(state):.3f}") logger.info( f"Starting maximum force norm {self._get_max_force_norm(state):.3f}" ) state.energy = self.get_energy(state) state = self._options.min_mc.update(state, self) logger.info(f"Ending energy {self.get_energy(state):.3f}") logger.info( f"Ending maximum force norm {self._get_max_force_norm(state):.3f}" ) return state def _run_mc(self, state: interfaces.IState) -> interfaces.IState: if self._options.run_mc is not None: logger.info("Running MCMC.") logger.debug(f"Starting energy {self.get_energy(state):.3f}") state.energy = self.get_energy(state) state = self._options.run_mc.update(state, self) logger.debug(f"Ending energy {self.get_energy(state):.3f}") return state def _run(self, state: interfaces.IState, minimize: bool) -> interfaces.IState: # update the transformers to account for sampled parameters # stored in the state self._transformers_update(state) assert abs(state.alpha - self._alpha) < 1e-6 # Run Monte Carlo position updates if minimize: state = self._run_min_mc(state) else: state = self._run_mc(state) # Run Monte Carlo parameter updates state = self._run_param_mc(state) # Run Monte Carlo mapper updates state = self._run_mapper_mc(state) coordinates = u.Quantity(state.positions, u.nanometer) velocities = u.Quantity(state.velocities, u.nanometer / u.picosecond) box_vectors = u.Quantity(state.box_vector, u.nanometer) # set the positions self._simulation.context.setPositions(coordinates) # if explicit solvent, then set the box vectors if self._solvation == "explicit": self._simulation.context.setPeriodicBoxVectors( [box_vectors[0].value_in_unit(u.nanometer), 0.0, 0.0], [0.0, box_vectors[1].value_in_unit(u.nanometer), 0.0], [0.0, 0.0, box_vectors[2].value_in_unit(u.nanometer)], ) # set the rdc alignments self._set_alignments(state) # run energy minimization if minimize: logger.info("Running minimization.") pre_state = self._simulation.context.getState( getForces=True, getEnergy=True ) pre_energy = ( pre_state.getPotentialEnergy().value_in_unit(u.kilojoule_per_mole) / GAS_CONSTANT / self._temperature ) pre_forces = pre_state.getForces().value_in_unit( u.kilojoule_per_mole / u.nanometer ) pre_norm = np.max(np.linalg.norm(pre_forces, axis=1)) logger.info(f"Starting energy {pre_energy:.3f}.") logger.info(f"Starting maximum force norm {pre_norm:.3f}.") self._simulation.minimizeEnergy(maxIterations=self._options.minimize_steps) post_state = self._simulation.context.getState( getForces=True, getEnergy=True ) post_energy = ( post_state.getPotentialEnergy().value_in_unit(u.kilojoule_per_mole) / GAS_CONSTANT / self._temperature ) post_forces = post_state.getForces().value_in_unit( u.kilojoule_per_mole / u.nanometer ) post_norm = np.linalg.norm(post_forces, axis=1) post_max = np.max(post_norm) post_index = np.argmax(post_norm) logger.info(f"Ending energy {post_energy:.3f}.") logger.info( f"Ending maximum force norm {post_max:.3f} on particle {post_index}." ) # set the velocities # check to see if velocities initialized to zero if np.all(velocities._value == 0.0): logger.info( "All velocities are zero, this is likely because input files do not contain velocity info. Generating velocities from Maxwell-Boltzmann distribution" ) self._simulation.context.setVelocitiesToTemperature(self._temperature) else: self._simulation.context.setVelocities(velocities) # run timesteps self._simulation.step(self._options.timesteps) # extract coords, vels, energy and strip units if self._solvation == "implicit": snapshot = self._simulation.context.getState( getPositions=True, getVelocities=True, getEnergy=True ) elif self._solvation == "explicit": snapshot = self._simulation.context.getState( getPositions=True, getVelocities=True, getEnergy=True, enforcePeriodicBox=True, ) coordinates = snapshot.getPositions(asNumpy=True).value_in_unit(u.nanometer) velocities = snapshot.getVelocities(asNumpy=True).value_in_unit( u.nanometer / u.picosecond ) _check_for_nan(coordinates, velocities, self._rank) # if explicit solvent, the recover the box vectors if self._solvation == "explicit": box_vector = snapshot.getPeriodicBoxVectors().value_in_unit(u.nanometer) box_vector = np.array( (box_vector[0][0], box_vector[1][1], box_vector[2][2]) ) # just store zeros for implicit solvent else: box_vector = np.zeros(3) # get the energy e_potential = ( snapshot.getPotentialEnergy().value_in_unit(u.kilojoule / u.mole) / GAS_CONSTANT / self._temperature ) # store in state state.positions = coordinates state.velocities = velocities state.energy = e_potential state.box_vector = box_vector state.rdc_alignments = self._gather_alignments() return state def _gather_alignments(self): values = [] if self.builder_info.get("has_alignments", False): for i in range(self.builder_info["num_alignments"]): for j in range(5): a = self._simulation.context.getParameter(f"rdc_{i}_s{j + 1}") values.append(a) values = np.array(values, dtype=np.float64) return values def _set_alignments(self, state): if self.builder_info.get("has_alignments", False): alignments = state.rdc_alignments.reshape(-1, 5) for i in range(alignments.shape[0]): for j in range(5): self._simulation.context.setParameter( f"rdc_{i}_s{j + 1}", alignments[i, j] ) def _run_param_mc(self, state): if not self._parameter_manager.has_parameters(): return state if self._options.param_mcmc_steps is None: raise RuntimeError( "There are sampled parameters, but param_mcmc_steps is not set." ) energy = self.get_energy(state) for _ in range(self._options.param_mcmc_steps): trial_params = self._parameter_manager.sample(state.parameters) if not self._parameter_manager.is_valid(trial_params): accept = False else: trial_state = SystemState( state.positions, state.velocities, state.alpha, state.energy, state.group_energies, state.box_vector, trial_params, state.mappings, ) trial_energy = self.get_energy(trial_state) delta = trial_energy - energy if delta < 0: accept = True else: if random.random() < math.exp(-delta): accept = True else: accept = False if accept: state = trial_state energy = trial_energy # Update transfomers in case we rejected the # last MCMC move if not accept: self._transformers_update(state) return state def _run_mapper_mc(self, state): if not self._mapper.has_mappers(): return state if self._options.mapper_mcmc_steps is None: raise RuntimeError( "There are mapped atom groups, but mapper_mcmc_steps is not set." ) energy = self.get_energy(state) accept = False for _ in range(self._options.mapper_mcmc_steps): trial_mappings = self._mapper.sample(state.mappings) trial_state = SystemState( state.positions, state.velocities, state.alpha, state.energy, state.group_energies, state.box_vector, state.parameters, trial_mappings, ) trial_energy = self.get_energy(trial_state) delta = trial_energy - energy if delta < 0: accept = True else: if random.random() < math.exp(-delta): accept = True else: accept = False if accept: state = trial_state energy = trial_energy # Update transfomers in case we rejected the # last MCMC move if not accept: self._transformers_update(state) return state
def _check_for_nan( coordinates: np.ndarray, velocities: np.ndarray, rank: Optional[int] ) -> None: output_rank = 0 if rank is None else rank if np.isnan(coordinates).any(): raise RuntimeError("Coordinates for rank {} contain NaN", output_rank) if np.isnan(velocities).any(): raise RuntimeError("Velocities for rank {} contain NaN", output_rank) def _create_openmm_simulation(topology, system, integrator, platform, properties): return app.Simulation(topology, system, integrator, platform, properties) def _add_extras(system, bonds, restricted_angles, torsions): # add the extra bonds if bonds: f = [f for f in system.getForces() if isinstance(f, mm.HarmonicBondForce)][0] for bond in bonds: f.addBond(bond.i, bond.j, bond.length, bond.force_constant) # add the extra restricted_angles if restricted_angles: # create the new force for restricted angles f = mm.CustomAngleForce( "0.5 * k_ra * (theta - theta0_ra)^2 / sin(theta * 3.1459 / 180)" ) f.addPerAngleParameter("k_ra") f.addPerAngleParameter("theta0_ra") for angle in restricted_angles: f.addAngle( angle.i, angle.j, angle.k, (angle.force_constant, angle.angle), ) system.addForce(f) # add the extra torsions if torsions: f = [f for f in system.getForces() if isinstance(f, mm.PeriodicTorsionForce)][0] for tors in torsions: f.addTorsion( tors.i, tors.j, tors.k, tors.l, tors.multiplicity, tors.phase, tors.energy, )