Source code for meld.system.mapping

#
# Copyright 2015 by Justin MacCallum, Alberto Perez, Ken Dill
# All rights reserved
#

"""
Module to handle sampling the mappings of peaks to atom indices
"""

import random
from typing import Dict, List, NamedTuple, Tuple

import numpy as np  # type: ignore

from meld.system import indexing


class PeakMapping(NamedTuple):
    """
    Immutable reference to one atom slot of one peak in a named map.

    Attributes:
        map_name: name of the map the peak belongs to
        peak_id: index of the peak within that map
        atom_name: name of the atom within the group assigned to the peak
    """

    map_name: str
    peak_id: int
    atom_name: str
class PeakMapper:
    """
    Track and sample the assignment of atom groups to the peaks of one map.

    A state is an integer array with one entry per peak. Each entry holds the
    index of the atom group assigned to that peak, or -1 when the peak has no
    group assigned.

    Attributes:
        name: name of this map
        n_peaks: number of peaks, which is the length of the state array
        n_active: number of peaks assigned to a group in the initial state
        atom_names: names that every atom group must provide
        atom_groups: one dict per added group, mapping each atom name to the
            integer value of its atom index
        frozen: once True, no further atom groups may be added
    """

    name: str
    n_peaks: int
    n_active: int
    atom_names: List[str]
    atom_groups: List[Dict[str, int]]
    frozen: bool

    def __init__(self, name: str, n_peaks: int, n_active: int, atom_names: List[str]):
        """
        Create a mapper with no atom groups.

        Args:
            name: name of this map
            n_peaks: number of peaks
            n_active: number of peaks that start out assigned
            atom_names: atom names each group added later must supply

        Raises:
            ValueError: if n_peaks is not positive
            AssertionError: if n_active is not in the range 1..n_peaks
        """
        if n_peaks <= 0:
            raise ValueError("n_peaks must be > 0")
        # NOTE: kept as assertions so the exception type seen by callers is
        # unchanged; they are skipped when Python runs with -O.
        assert n_active > 0
        assert n_active <= n_peaks

        self.name = name
        self.n_peaks = n_peaks
        self.n_active = n_active
        self.atom_names = atom_names
        self.atom_groups = []
        self.frozen = False

    def add_atom_group(self, **kwargs: "indexing.AtomIndex"):
        """
        Append an atom group given as one keyword argument per atom name.

        Args:
            **kwargs: exactly the names in ``atom_names``, each an AtomIndex

        Raises:
            RuntimeError: if the mapper has been frozen
            KeyError: if an expected name is missing or an unknown name is given
            ValueError: if a value is not an AtomIndex
        """
        if self.frozen:
            raise RuntimeError(
                "Cannot add an atom group after get_initial_state or extract_value have been called."
            )

        for name in self.atom_names:
            if name not in kwargs:
                raise KeyError(f"Expected argument {name} not given.")

        for name in kwargs:
            if name not in self.atom_names:
                raise KeyError(f"Unexpected argument {name}.")

        for name, value in kwargs.items():
            if not isinstance(value, indexing.AtomIndex):
                raise ValueError(
                    f"Values should be AtomIndex, but got {type(value)} for {name}."
                )

        self.atom_groups.append({k: int(v) for k, v in kwargs.items()})

    def get_mapping(self, peak_id: int, atom_name: str) -> "PeakMapping":
        """
        Build a PeakMapping for one atom name of one peak of this map.

        Args:
            peak_id: index of the peak, in the range 0..n_peaks-1
            atom_name: one of ``atom_names``

        Returns:
            the corresponding PeakMapping

        Raises:
            KeyError: if peak_id is out of range or atom_name is unknown
        """
        if peak_id < 0:
            raise KeyError("peak_id must be >= 0.")
        if peak_id >= self.n_peaks:
            raise KeyError(f"peak_id must be <= {self.n_peaks - 1}.")
        if atom_name not in self.atom_names:
            raise KeyError(f"atom_name={atom_name} not in {self.atom_names}.")

        return PeakMapping(map_name=self.name, peak_id=peak_id, atom_name=atom_name)

    def get_initial_state(self) -> np.ndarray:
        """
        Freeze the mapper and return its initial state.

        The first ``n_active`` peaks are assigned groups 0..n_active-1 and all
        remaining peaks are unassigned (-1).

        Returns:
            integer array of length ``n_peaks``

        Raises:
            ValueError: if fewer than ``n_active`` atom groups have been added
        """
        # Freeze so we can't add more atom_groups. This must set the same
        # attribute that add_atom_group checks.
        self.frozen = True

        if self.n_active > self.n_atom_groups:
            raise ValueError("n_active must be <= n_atom_groups")

        state = -np.ones(self.n_peaks, dtype=int)
        state[: self.n_active] = np.arange(self.n_active)
        return state

    def sample(self, state: np.ndarray) -> np.ndarray:
        """
        Propose a trial state by applying one randomly chosen move.

        A peak swap is chosen with probability 0.1. A peak reassignment is
        chosen with probability 0.1, but only when there are more atom groups
        than active peaks. Otherwise a neighbour swap is performed.

        Args:
            state: current state; it is not modified

        Returns:
            a new array holding the trial state
        """
        r = random.random()

        if r < 0.1:
            return self._sample_peak_swap(state)

        # Reassignment needs unassigned groups, which only exist when there
        # are more groups than active peaks.
        if r < 0.2 and self.n_active != self.n_atom_groups:
            return self._sample_peak_reassign(state)

        return self._sample_neighbour_swap(state)

    def _sample_peak_swap(self, state: np.ndarray) -> np.ndarray:
        """Return a copy of state with the entries of two random peaks exchanged."""
        trial_state = state.copy()

        # With a single peak there is no pair to swap, and random.sample
        # would raise, so the move is a no-op.
        if self.n_peaks < 2:
            return trial_state

        i, j = random.sample(range(self.n_peaks), k=2)
        trial_state[i], trial_state[j] = trial_state[j], trial_state[i]
        return trial_state

    def _sample_neighbour_swap(self, state: np.ndarray) -> np.ndarray:
        """
        Return a copy of state with two consecutive group indices exchanged.

        If only one of the two groups is assigned to a peak, that peak takes
        the other group. If neither is assigned, the state is unchanged.
        """
        trial_state = state.copy()

        # With fewer than two groups there are no neighbours, and
        # random.randrange would raise, so the move is a no-op.
        if self.n_atom_groups < 2:
            return trial_state

        # Choose two neighbouring groups
        i = random.randrange(self.n_atom_groups - 1)
        j = i + 1

        # Identify the first peak, if any, holding each group
        peaks_i = np.argwhere(trial_state == i)
        peaks_j = np.argwhere(trial_state == j)
        peak_i = peaks_i[0] if len(peaks_i) else None
        peak_j = peaks_j[0] if len(peaks_j) else None

        if peak_i is not None and peak_j is not None:
            # Both groups are assigned, so we swap them.
            trial_state[peak_i] = j
            trial_state[peak_j] = i
        elif peak_i is not None:
            # Only group i is assigned; its peak takes group j.
            trial_state[peak_i] = j
        elif peak_j is not None:
            # Only group j is assigned; its peak takes group i.
            trial_state[peak_j] = i

        return trial_state

    def _sample_peak_reassign(self, state: np.ndarray) -> np.ndarray:
        """
        Return a copy of state with one assigned peak given an unassigned group.

        Raises:
            RuntimeError: if every atom group is already assigned to a peak
        """
        trial_state = state.copy()

        # Choose an assigned peak
        assigned_peaks = [peak[0] for peak in np.argwhere(trial_state != -1)]
        peak = random.choice(assigned_peaks)

        # Work out which groups are not assigned to any peak
        atom_groups = set(range(self.n_atom_groups))
        assigned_groups = set(trial_state)
        unassigned_groups = list(atom_groups - assigned_groups)

        # Raise an error if we have no unassigned groups, as we shouldn't
        # be calling this function in that case.
        if not unassigned_groups:
            raise RuntimeError(
                "There are no unassigned groups, so _sample_peak_reassign shouldn't be called."
            )

        trial_state[peak] = random.choice(unassigned_groups)
        return trial_state

    @property
    def n_atom_groups(self) -> int:
        """Number of atom groups added so far."""
        return len(self.atom_groups)
class PeakMapManager:
    """
    Hold several named PeakMappers and treat their states as one array.

    The combined state is the concatenation of each mapper's state, in the
    order the mappers were added.

    Attributes:
        mappers: the mappers, keyed by name
    """

    mappers: Dict[str, "PeakMapper"]
    # Maps each mapper name to its (start, stop) slice of the combined state.
    # It is None until first needed.
    _name_to_range: Dict[str, Tuple[int, int]]

    def __init__(self):
        """Create a manager with no mappers."""
        self.mappers = {}
        self._name_to_range = None

    def add_map(
        self, name: str, n_peaks: int, n_active: int, atom_names: List[str]
    ) -> "PeakMapper":
        """
        Create a PeakMapper, register it under ``name``, and return it.

        Args:
            name: unique name for the new map
            n_peaks: number of peaks in the map
            n_active: number of peaks that start out assigned
            atom_names: atom names each atom group must supply

        Returns:
            the newly created PeakMapper

        Raises:
            ValueError: if a map called ``name`` already exists
        """
        # don't allow duplicates
        if name in self.mappers:
            raise ValueError(f"Trying to insert duplicate entry for {name}.")

        mapper = PeakMapper(name, n_peaks, n_active, atom_names)
        self.mappers[name] = mapper
        return mapper

    def get_initial_state(self) -> np.ndarray:
        """
        Return the concatenated initial states of all mappers.

        Returns:
            integer array; it is empty when there are no mappers
        """
        if self._name_to_range is None:
            self._setup_name_to_range()

        # If we don't have any mappers, just return an empty array.
        if not self.mappers:
            return np.array([], dtype=int)

        # Loop through our mappers in the order they were added.
        states = [mapper.get_initial_state() for mapper in self.mappers.values()]
        return np.hstack(states)

    def extract_value(self, mapping: "PeakMapping", state: np.ndarray) -> int:
        """
        Look up the atom index that ``mapping`` refers to in ``state``.

        Args:
            mapping: the peak and atom name to resolve
            state: combined state of all mappers

        Returns:
            the stored atom index, or -1 if the peak has no group assigned

        Raises:
            KeyError: if the map is unknown or peak_id is out of range
        """
        mapper, start = self._resolve(mapping)

        group_index = state[mapping.peak_id + start]
        if group_index == -1:
            return -1
        return mapper.atom_groups[group_index][mapping.atom_name]

    def sample(self, state: np.ndarray) -> np.ndarray:
        """
        Propose a trial state by perturbing one randomly chosen mapper.

        Args:
            state: combined state of all mappers; it is not modified

        Returns:
            a new array holding the trial state
        """
        if self._name_to_range is None:
            self._setup_name_to_range()

        # With no mappers there is nothing to perturb, and random.randrange
        # would raise on an empty range.
        if not self.mappers:
            return state.copy()

        perturbed = random.randrange(0, len(self.mappers))

        pieces = []
        for i, (name, mapper) in enumerate(self.mappers.items()):
            start, stop = self._name_to_range[name]
            sub_state = state[start:stop]
            pieces.append(mapper.sample(sub_state) if i == perturbed else sub_state)

        return np.hstack(pieces)

    def get_index(self, mapping: "PeakMapping") -> int:
        """
        Return the position of the mapping's peak within the combined state.

        Args:
            mapping: the peak to locate

        Returns:
            index into the combined state array

        Raises:
            KeyError: if the map is unknown or peak_id is out of range
        """
        _, start = self._resolve(mapping)
        return mapping.peak_id + start

    def has_mappers(self) -> bool:
        """Return True if at least one mapper has been added."""
        return bool(self.mappers)

    def _resolve(self, mapping: "PeakMapping") -> Tuple["PeakMapper", int]:
        """Freeze and return the mapper for ``mapping`` with its start offset, validating peak_id."""
        if self._name_to_range is None:
            self._setup_name_to_range()

        start = self._name_to_range[mapping.map_name][0]
        mapper = self.mappers[mapping.map_name]
        mapper.frozen = True

        if mapping.map_name != mapper.name:
            raise KeyError(f"Map name {mapping.map_name} does not match {mapper.name}.")

        # Without these checks an out-of-range peak_id would silently index
        # into another mapper's slice of the combined state.
        peak_id = mapping.peak_id
        if peak_id < 0:
            raise KeyError("peak_id must be >= 0.")
        if peak_id >= mapper.n_peaks:
            raise KeyError(f"peak_id must be < {mapper.n_peaks}")

        return mapper, start

    def _setup_name_to_range(self):
        """Record each mapper's (start, stop) slice of the combined state."""
        start = 0
        self._name_to_range = {}
        for name, mapper in self.mappers.items():
            length = mapper.get_initial_state().shape[0]
            self._name_to_range[name] = (start, start + length)
            start += length