Source code for meld.comm

#
# Copyright 2015 by Justin MacCallum, Alberto Perez, Ken Dill
# All rights reserved
#

"""
Module to handle MPI communication
"""


import contextlib
import logging
import os
import platform
import signal
import sys
import threading
import time
from collections import defaultdict, namedtuple
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, TypeVar

import numpy as np  # type: ignore

from meld import interfaces, util

logger = logging.getLogger(__name__)


# Set up exception handling to abort when there is an unhandled exception
sys_excepthook = sys.excepthook


def _mpi_excepthook(type, value, traceback):
    sys_excepthook(type, value, traceback)
    rank = _get_mpi_comm_world().rank + 1
    size = _get_mpi_comm_world().size
    node_name = f"{rank}/{size}"
    logger.critical(f"MPI node {node_name} raised exception.")
    sys.stdout.flush()
    sys.stderr.flush()
    _get_mpi_comm_world().Abort(1)


sys.excepthook = _mpi_excepthook


class MPICommunicator(interfaces.ICommunicator):
    """
    Class to handle communications between leader and workers using MPI.

    Note:
        Creating an MPI communicator will not actually initialize MPI.
        To do that, call :meth:`initialize`.
    """

    _mpi_comm: Any
    def __init__(self, n_atoms: int, n_replicas: int, timeout: int = 600):
        """
        Initialize an MPICommunicator

        Args:
            n_atoms: number of atoms
            n_replicas: number of replicas
            timeout: maximum time to wait before aborting
        """
        # We're not using n_atoms, but if we switch to more efficient
        # buffer-based MPI routines, we'll need it.
        self._n_atoms = n_atoms
        self._n_replicas = n_replicas
        self._timeout = timeout
        self._timeout_message = f"Call to {{:s}} did not complete in {timeout} seconds"
    def __getstate__(self) -> Dict[str, Any]:
        # don't pickle _mpi_comm
        return {k: v for k, v in self.__dict__.items() if k != "_mpi_comm"}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # _mpi_comm is deliberately absent after unpickling;
        # call initialize() to re-create it
        self.__dict__ = state
    def initialize(self) -> None:
        """
        Initialize and start MPI
        """
        self._mpi_comm = _get_mpi_comm_world()
        self._my_rank = self._mpi_comm.Get_rank()
        self._n_workers = self._mpi_comm.Get_size()
    def is_leader(self) -> bool:
        """
        Is this the leader node?

        Returns:
            :const:`True` if we are the leader, otherwise :const:`False`
        """
        return self._my_rank == 0
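    # Illustrative setup sketch (not part of the original API). Assumes an MPI
    # launch such as `mpirun -n 4 python run.py` with one replica per process;
    # the argument values below are hypothetical:
    #
    #     comm = MPICommunicator(n_atoms=500, n_replicas=4)
    #     comm.initialize()
    #     if comm.is_leader():
    #         ...  # rank 0 drives the replica exchange
    #     else:
    #         ...  # other ranks wait for work from the leader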
    @util.log_timing(logger)
    def barrier(self) -> None:
        """
        Wait until all workers reach this point
        """
        with _timeout(
            self._timeout, RuntimeError(self._timeout_message.format("barrier"))
        ):
            self._mpi_comm.barrier()
    @util.log_timing(logger)
    def distribute_alphas_to_workers(self, all_alphas: List[float]) -> List[float]:
        """
        Distribute alphas to workers

        Args:
            all_alphas: the alpha values to be distributed

        Returns:
            the block of alpha values for the leader
        """
        alpha_blocks = self._to_blocks(all_alphas)
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("distribute_alphas_to_workers")),
        ):
            return self._mpi_comm.scatter(alpha_blocks, root=0)
    @util.log_timing(logger)
    def receive_alphas_from_leader(self) -> List[float]:
        """
        Receive a block of alphas from the leader.

        Returns:
            the block of alpha values for this worker
        """
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("receive_alphas_from_leader")),
        ):
            return self._mpi_comm.scatter(None, root=0)
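    # Illustrative pairing (not part of the original API): scatter is a
    # collective operation, so the leader's distribute_* call and every
    # worker's receive_* call must happen in the same step. A sketch,
    # assuming four replicas on four MPI processes:
    #
    #     if comm.is_leader():
    #         my_alphas = comm.distribute_alphas_to_workers([0.0, 0.33, 0.67, 1.0])
    #     else:
    #         my_alphas = comm.receive_alphas_from_leader()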
    @util.log_timing(logger)
    def distribute_states_to_workers(
        self, all_states: Sequence[interfaces.IState]
    ) -> List[interfaces.IState]:
        """
        Distribute a block of states to each worker.

        Args:
            all_states: states to be distributed

        Returns:
            the block of states to run on the leader node
        """
        # Divide the states into blocks
        state_blocks = self._to_blocks(all_states)
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("distribute_states_to_workers")),
        ):
            return self._mpi_comm.scatter(state_blocks, root=0)
    @util.log_timing(logger)
    def receive_states_from_leader(self) -> List[interfaces.IState]:
        """
        Get the block of states to run for this step

        Returns:
            the block of states to run for this step
        """
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("receive_states_from_leader")),
        ):
            return self._mpi_comm.scatter(None, root=0)
    @util.log_timing(logger)
    def gather_states_from_workers(
        self, state_on_leader: List[interfaces.IState]
    ) -> List[interfaces.IState]:
        """
        Receive states from all workers

        Args:
            state_on_leader: the block of states on the leader after simulating

        Returns:
            a list of states, one from each replica
        """
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("gather_states_from_workers")),
        ):
            blocks = self._mpi_comm.gather(state_on_leader, root=0)
            return self._from_blocks(blocks)
    @util.log_timing(logger)
    def send_states_to_leader(self, block: Sequence[interfaces.IState]) -> None:
        """
        Send a block of states to the leader

        Args:
            block: block of states to send to the leader
        """
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("send_states_to_leader")),
        ):
            self._mpi_comm.gather(block, root=0)
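    # Illustrative pairing (not part of the original API): gather is also
    # collective, so after simulating its block each worker returns results
    # with send_states_to_leader while the leader collects everything with
    # gather_states_from_workers. Sketch:
    #
    #     if comm.is_leader():
    #         all_states = comm.gather_states_from_workers(leader_block)
    #     else:
    #         comm.send_states_to_leader(worker_block)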
    @util.log_timing(logger)
    def broadcast_all_states_to_workers(
        self, states: Sequence[interfaces.IState]
    ) -> None:
        """
        Broadcast all states to all workers.

        Args:
            states: a list of states
        """
        with _timeout(
            self._timeout,
            RuntimeError(
                self._timeout_message.format("broadcast_all_states_to_workers")
            ),
        ):
            self._mpi_comm.bcast(states, root=0)
    @util.log_timing(logger)
    def receive_all_states_from_leader(self) -> Sequence[interfaces.IState]:
        """
        Receive all states from the leader.

        Returns:
            a list of states to calculate the energy of
        """
        with _timeout(
            self._timeout,
            RuntimeError(
                self._timeout_message.format("receive_all_states_from_leader")
            ),
        ):
            return self._mpi_comm.bcast(None, root=0)
    @util.log_timing(logger)
    def gather_energies_from_workers(
        self, energies_on_leader: np.ndarray
    ) -> np.ndarray:
        """
        Receive energies from each worker.

        Args:
            energies_on_leader: the energies computed on the leader

        Returns:
            a square matrix of every state on every replica to be used for
            replica exchange

        Note:
            Each row of the output matrix represents a different Hamiltonian.
            Each column represents a different state. Each worker will compute
            multiple rows of the output matrix.
        """
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("gather_energies_from_workers")),
        ):
            energies = self._mpi_comm.gather(energies_on_leader, root=0)
            return np.concatenate(energies, axis=0)
    @util.log_timing(logger)
    def send_energies_to_leader(self, energies: np.ndarray) -> None:
        """
        Send a block of energies to the leader.

        Args:
            energies: block of energies to send to the leader

        Note:
            Each row represents a different Hamiltonian. Each column
            represents a different state.
        """
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("send_energies_to_leader")),
        ):
            self._mpi_comm.gather(energies, root=0)
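    # Illustrative pairing (not part of the original API): each process
    # computes an (n_replicas / n_workers) x n_replicas block of the energy
    # matrix (rows are Hamiltonians, columns are states). Gathering the blocks
    # and concatenating along axis 0 yields the full n_replicas x n_replicas
    # matrix used for replica exchange. Sketch:
    #
    #     if comm.is_leader():
    #         energy_matrix = comm.gather_energies_from_workers(leader_block)
    #     else:
    #         comm.send_energies_to_leader(worker_block)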
    @util.log_timing(logger)
    def negotiate_device_id(self) -> int:
        """
        Negotiate CUDA device id

        Returns:
            the cuda device id to use
        """
        with _timeout(
            self._timeout,
            RuntimeError(self._timeout_message.format("negotiate_device_id")),
        ):
            hostname = platform.node()
            try:
                env_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
                logger.info("%s found cuda devices: %s", hostname, env_visible_devices)
                visible_devices: Optional[List[int]] = [
                    int(dev) for dev in env_visible_devices.split(",")
                ]
                if not visible_devices:
                    raise RuntimeError("No cuda devices available")
                else:
                    if len(visible_devices) == 1:
                        logger.info(
                            "negotiate_device_id: visible_devices contains a single device: %d",
                            visible_devices[0],
                        )
                        device_id = 0
                        logger.info("hostname: %s, device_id: %d", hostname, device_id)
                        return device_id
            except KeyError:
                logger.info("%s CUDA_VISIBLE_DEVICES is not set.", hostname)
                visible_devices = None

            hosts = self._mpi_comm.gather(HostInfo(hostname, visible_devices), root=0)

            # the leader computes the device ids
            if self._my_rank == 0:
                if hosts[0].devices is None:
                    # if CUDA_VISIBLE_DEVICES isn't set on the leader, we assume it
                    # isn't set for any node
                    logger.info("CUDA_VISIBLE_DEVICES is not set.")
                    logger.info("Assuming each mpi process has access")
                    logger.info("to a CUDA device, where the device")
                    logger.info("numbering starts from 0.")

                    # create an empty default dict to count hosts
                    host_counts: Dict[str, int] = defaultdict(int)

                    # list of device ids
                    # this assumes that available devices for each node
                    # are numbered starting from 0
                    device_ids = []
                    for host in hosts:
                        assert host.devices is None
                        device_ids.append(host_counts[host.host_name])
                        host_counts[host.host_name] += 1
                else:
                    # CUDA_VISIBLE_DEVICES is set on the leader, so we
                    # assume it is set for all nodes
                    logger.info("CUDA_VISIBLE_DEVICES is set.")

                    # create a dict to hold the device ids available on each host
                    available_devices: Dict[str, List[int]] = {}
                    # store the available devices on each node
                    for host in hosts:
                        if host.host_name in available_devices:
                            if host.devices != available_devices[host.host_name]:
                                raise RuntimeError("GPU devices for host do not match")
                        else:
                            available_devices[host.host_name] = host.devices

                    # CUDA numbers the devices contiguously, starting from zero.
                    # For example, if `CUDA_VISIBLE_DEVICES=2,4,5`, we would
                    # access these as ids 0, 1, 2.
                    available_devices = {
                        host_name: list(range(len(devices)))
                        for host_name, devices in available_devices.items()
                    }

                    # device ids for each node
                    device_ids = []
                    for host in hosts:
                        try:
                            # pop off the first device_id for this host name
                            device_ids.append(available_devices[host.host_name].pop(0))
                        except IndexError:
                            logger.error("More mpi processes than GPUs")
                            raise RuntimeError("More mpi processes than GPUs")
            # receive device id from leader
            else:
                device_ids = []

            # do the communication
            device_id = self._mpi_comm.scatter(
                device_ids if device_ids else None, root=0
            )
            logger.info("hostname: %s, device_id: %d", hostname, device_id)
            return device_id
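    # Illustrative sketch (not part of the original API) of the assignment
    # negotiate_device_id produces when CUDA_VISIBLE_DEVICES is unset. With
    # two ranks on a hypothetical host "node1" and one on "node2", the
    # per-host counting above yields:
    #
    #     hosts = [HostInfo("node1", None), HostInfo("node1", None),
    #              HostInfo("node2", None)]
    #     # leader scatters device ids [0, 1, 0]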
    @property
    def n_replicas(self) -> int:
        """number of replicas"""
        return self._n_replicas

    @property
    def n_atoms(self) -> int:
        """number of atoms"""
        return self._n_atoms

    @property
    def n_workers(self) -> int:
        """number of workers"""
        return self._n_workers

    @property
    def rank(self) -> int:
        """rank of this worker"""
        return self._my_rank

    X = TypeVar("X")

    def _to_blocks(self, items: Sequence[X]) -> List[List[X]]:
        items = list(items)
        if len(items) % self.n_workers:
            raise ValueError("number of items must be divisible by n_workers")
        group_size = len(items) // self.n_workers
        return [items[i : i + group_size] for i in range(0, len(items), group_size)]

    def _from_blocks(self, blocks: Sequence[List[X]]) -> List[X]:
        blocks = list(blocks)
        return [item for sublist in blocks for item in sublist]
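# Illustrative sketch (not part of the original module): how `_to_blocks` and
# `_from_blocks` partition work across workers. With six items and three
# workers, the blocks are [[0, 1], [2, 3], [4, 5]], handed out in rank order;
# `_from_blocks` is the exact inverse. The helper below is hypothetical and
# exercises the pure-Python logic without making any MPI calls.
def _example_blocking(comm: MPICommunicator) -> None:
    items = list(range(comm.n_workers * 2))
    blocks = comm._to_blocks(items)  # one block per worker, in rank order
    assert comm._from_blocks(blocks) == items  # round-trips exactly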
def _get_mpi_comm_world():
    """
    Helper function to return the comm world.
    """
    try:
        from mpi4py import MPI  # type: ignore
    except ImportError:
        print()
        print("****")
        print("Error importing mpi4py.")
        print()
        print("Meld depends on mpi4py, but does not automatically install it")
        print(
            "as a dependency. See https://github.com/maccallumlab/meld/blob/master/README.md"
        )
        print("for details.")
        print("****")
        print()
        raise
    return MPI.COMM_WORLD


# NamedTuple to hold the results used by negotiate_device_id
class HostInfo(NamedTuple):
    host_name: str
    devices: Optional[List[int]]
# Adapted from interruptingcow
# https://bitbucket.org/evzijst/interruptingcow
#
# Original license below
#
# The MIT License (MIT)
#
# Copyright (c) 2012 Erik van Zijst <erik.van.zijst@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


class _StateException(Exception):
    pass


class _Quota:
    def __init__(self, seconds):
        if seconds <= 0:
            raise ValueError("Invalid timeout: %s" % seconds)
        else:
            self._timeleft = seconds
        self._depth = 0
        self._starttime = None

    def __str__(self):
        return "<Quota remaining=%s>" % self.remaining()

    def _start(self):
        if self._depth == 0:
            self._starttime = time.time()
        self._depth += 1

    def _stop(self):
        if self._depth == 1:
            self._timeleft = self.remaining()
            self._starttime = None
        self._depth -= 1

    def running(self):
        return self._depth > 0

    def remaining(self):
        if self.running():
            return max(self._timeleft - (time.time() - self._starttime), 0)
        else:
            return max(self._timeleft, 0)


def _bootstrap():
    Timer = namedtuple("Timer", "expiration exception")
    timers = []

    def handler(*args):
        exception = timers.pop().exception
        if timers:
            timeleft = timers[-1].expiration - time.time()
            if timeleft > 0:
                signal.setitimer(signal.ITIMER_REAL, timeleft)
            else:
                # the parent timer has also expired; trigger it
                handler(*args)
        raise exception

    def set_sighandler():
        current = signal.getsignal(signal.SIGALRM)
        if current == signal.SIG_DFL:
            signal.signal(signal.SIGALRM, handler)
        elif current != handler:
            raise _StateException(
                "Your process alarm handler is already in "
                "use! Interruptingcow cannot be used in "
                "programs that use SIGALRM."
            )

    def timeout(seconds, exception):
        if threading.current_thread().name != "MainThread":
            raise _StateException(
                "Interruptingcow can only be used from the MainThread."
            )
        if isinstance(seconds, _Quota):
            quota = seconds
        else:
            quota = _Quota(float(seconds))
        set_sighandler()
        seconds = quota.remaining()

        depth = len(timers)
        parenttimeleft = signal.getitimer(signal.ITIMER_REAL)[0]
        if not timers or parenttimeleft > seconds:
            try:
                quota._start()
                timers.append(Timer(time.time() + seconds, exception))
                if seconds > 0:
                    signal.setitimer(signal.ITIMER_REAL, seconds)
                    yield
                else:
                    handler()
            finally:
                quota._stop()
                if len(timers) > depth:
                    # cancel our timer
                    signal.setitimer(signal.ITIMER_REAL, 0)
                    timers.pop()
                    if timers:
                        # reinstall the parent timer
                        parenttimeleft = timers[-1].expiration - time.time()
                        if parenttimeleft > 0:
                            signal.setitimer(signal.ITIMER_REAL, parenttimeleft)
                        else:
                            # the parent timer has expired, trigger the handler
                            handler()
        else:
            # not enough time left on the parent timer
            try:
                quota._start()
                yield
            finally:
                quota._stop()

    @contextlib.contextmanager
    def timeout_context_manager(seconds, exception):
        t = timeout(seconds, exception)
        next(t)
        try:
            yield
        finally:
            # finish the inner generator so its cleanup
            # (timer cancellation) runs deterministically
            t.close()

    return timeout_context_manager


_timeout = _bootstrap()
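

# A minimal usage sketch (illustrative, not part of the original module) of
# the `_timeout` context manager built above: wrap a potentially-hanging call
# and raise the supplied exception if it does not finish in time. The function
# name `_example_timeout_usage` is hypothetical.
def _example_timeout_usage() -> None:
    with _timeout(5, RuntimeError("slept too long")):
        time.sleep(1)  # completes well within the 5 second budget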