Source code for thermo.gpumd.preproc

from .common import __check_list, __check_range
from numpy import prod

__author__ = "Alexander Gabourie"
__email__ = "gabourie@stanford.edu"


#########################################
# Structure preprocessing
#########################################

def __get_group(split, pos, direction):
    """
    Finds the group an atom belongs to from one component of its position.
    Splitting is done along a single axis, as needed for NEMD.

    Args:
        split (list(float)):
            Group boundaries. The first element should be the lower boundary
            of the sim. box in the chosen direction and the last the upper.

        pos (list(float)):
            Position of the atom, indexed as [x, y, z].

        direction (str):
            Axis used for the split. 'x' and 'y' select the first and second
            component of pos; any other value selects the third.

    Returns:
        int: Index i such that split[i] <= coordinate < split[i+1], or -1
        (after printing an error message) if no such interval exists.

    """
    axis = 0 if direction == 'x' else (1 if direction == 'y' else 2)
    coord = pos[axis]
    errmsg = 'Out of bounds error: {}'.format(coord)

    # Anything below the first boundary is rejected immediately, before any
    # interval is examined.
    if len(split) > 1 and coord < split[0]:
        print(errmsg)
        return -1

    # Intervals are closed on the lower edge and open on the upper edge, so
    # a coordinate equal to the last boundary matches nothing.
    for group, (lower, upper) in enumerate(zip(split[:-1], split[1:])):
        if lower <= coord < upper:
            return group

    print(errmsg)
    return -1


def __init_index(index, info, num_atoms):
    """
    Makes sure the info dict has an entry for an atom and reports the key
    under which that entry is stored.

    Args:
        index (int):
            Index of atom in the Atoms object.

        info (dict):
            Dictionary that stores the velocity, and groups.

        num_atoms (int):
            Number of atoms in the Atoms object.

    Returns:
        int: Key of the atom's entry in info. This is -1 for the last atom
        (index == num_atoms - 1) and the unchanged index otherwise.

    """
    # The last atom is stored under -1 rather than under its own index.
    key = -1 if index == num_atoms - 1 else index
    info.setdefault(key, dict())
    return key


def __handle_end(info, num_atoms):
    """
    Makes the entry stored under key -1 also reachable under key
    num_atoms-1. Both keys refer to the same object afterwards. Works
    in-place.

    Args:
        info (dict):
            Dictionary that stores the velocity, and groups. Must already
            contain the key -1.

        num_atoms (int):
            Number of atoms in the Atoms object.

    """
    last_index = num_atoms - 1
    info[last_index] = info[-1]


def add_group_by_position(split, atoms, direction):
    """
    Assigns groups to all atoms based on its position. Only works in
    one direction as it is used for NEMD.
    Returns a bookkeeping parameter, but atoms will be updated in-place.

    Args:
        split (list(float)):
            List of boundaries. First element should be lower boundary of
            sim. box in specified direction and the last the upper.

        atoms (ase.Atoms):
            Atoms to group

        direction (str):
            Which direction the split will work.

    Returns:
        list(int): Number of atoms in each group. Atoms that fall outside
        all boundaries are recorded with group -1 in atoms.info and are not
        included in any of the counts.

    """
    info = atoms.info
    counts = [0] * (len(split) - 1)
    num_atoms = len(atoms)
    for index, atom in enumerate(atoms):
        key = __init_index(index, info, num_atoms)
        group = __get_group(split, atom.position, direction)
        info[key].setdefault('groups', []).append(group)
        # __get_group reports an out-of-bounds atom as -1. Using that as a
        # list index would silently add the atom to the LAST group's count,
        # so only in-bounds atoms are tallied.
        if group >= 0:
            counts[group] += 1
    # With no atoms there is no -1 entry to duplicate.
    if num_atoms:
        __handle_end(info, num_atoms)
    atoms.info = info
    return counts
def add_group_by_type(atoms, types):
    """
    Assigns a group to every atom according to its chemical symbol.
    Returns a bookkeeping parameter, but atoms will be updated in-place.

    Args:
        atoms (ase.Atoms):
            Atoms to group

        types (dict):
            Dictionary with types for keys and group as a value.
            Only one group allowed per atom. Assumed groups are integers
            starting at 0 and increasing in steps of 1. Ex. range(0,10).

    Returns:
        list(int): Number of atoms in each group.

    Raises:
        ValueError: If atoms contains a chemical symbol that is not a key
        of types.

    """
    symbols = list(types)
    unique_symbols = set(symbols)

    # Every symbol present in atoms needs a group assignment.
    if set(atoms.get_chemical_symbols()) - unique_symbols:
        raise ValueError('Group symbols do not match atoms symbols.')
    # Safeguard against repeated symbols among the keys of types.
    if len(unique_symbols) != len(symbols):
        raise ValueError('Group not assigned to all atom types.')

    # One counter per distinct group value found in types.
    num_groups = len({types[sym] for sym in unique_symbols})
    num_atoms = len(atoms)
    info = atoms.info
    counts = [0] * num_groups
    for index, atom in enumerate(atoms):
        key = __init_index(index, info, num_atoms)
        group = types[atom.symbol]
        counts[group] += 1
        info[key].setdefault('groups', []).append(group)
    __handle_end(info, num_atoms)
    atoms.info = info
    return counts
def set_velocities(atoms, custom=None):
    """
    Sets the 'velocity' part of the atoms to be used in GPUMD.
    Custom velocities must be provided. They must also be in
    the units of eV^(1/2) amu^(-1/2).

    Args:
        atoms (ase.Atoms):
            Atoms to assign velocities to. Updated in-place.

        custom (list(list)):
            list of len(atoms) with each element made from
            a 3-element list for [vx, vy, vz]

    Raises:
        ValueError: If no velocities are provided, if the number of
        velocities differs from the number of atoms, or if any velocity
        does not have exactly three components.

    """
    # Explicit None/length test so array-like inputs, whose truth value is
    # ambiguous, are handled as well as plain lists.
    if custom is None or len(custom) == 0:
        raise ValueError("No velocities provided.")
    num_atoms = len(atoms)
    if len(custom) != num_atoms:
        raise ValueError('Incorrect number of velocities for number of atoms.')
    # Validate every entry before modifying anything, so invalid input can
    # never leave the atoms with only some velocities assigned.
    for velocity in custom:
        if len(velocity) != 3:
            raise ValueError('Three components of velocity not provided.')

    info = atoms.info
    for index, velocity in enumerate(custom):
        key = __init_index(index, info, num_atoms)
        info[key]['velocity'] = velocity
    __handle_end(info, num_atoms)
    atoms.info = info
def __init_index2(index, info):
    """
    Creates an empty entry for index in info if none exists yet. Works
    in-place and leaves an existing entry untouched.

    Args:
        index (int):
            Key to initialize in the info dict.

        info (dict):
            Dictionary holding the per-atom entries.

    """
    # TODO merge this with other __init_index function
    if index in info:
        return
    info[index] = {}
def add_basis(atoms, index=None, mapping=None):
    """
    Stores a basis index for every atom, together with the list of unit
    cell atoms, in atoms.info. Updates atoms in-place.

    Args:
        atoms (ase.Atoms):
            Atoms to assign basis to.

        index (list(int)):
            Atom indices of those in the unit cell. Order is important.
            If not given (or empty), atoms itself is taken to be the unit
            cell and every atom becomes its own basis position.

        mapping (list(int)):
            Mapping of all atoms to the relevant basis positions. Required,
            with one entry per atom, whenever index is given.

    Raises:
        ValueError: If index is given but mapping is missing or does not
        have one entry per atom.

    """
    num_atoms = atoms.get_global_number_of_atoms()
    info = atoms.info
    unitcell = list()
    # Any previous 'unitcell' entry is replaced before the inputs are checked.
    info['unitcell'] = unitcell

    if not index:
        # if no index provided, assume atoms is unit cell
        for atom_idx in range(num_atoms):
            unitcell.append(atom_idx)
            __init_index2(atom_idx, info)
            info[atom_idx]['basis'] = atom_idx
        return

    if mapping is None or len(mapping) != num_atoms:
        raise ValueError("Full atom mapping required if index is provided.")

    unitcell.extend(index)
    for atom_idx in range(num_atoms):
        __init_index2(atom_idx, info)
        info[atom_idx]['basis'] = mapping[atom_idx]
def repeat(atoms, rep):
    """
    A wrapper of ase.Atoms.repeat that is aware of GPUMD's basis information.

    Args:
        atoms (ase.Atoms):
            Atoms to repeat. Each of its per-atom info entries must hold a
            'basis' value.

        rep (int | list(3 ints)):
            List of three positive integers or a single integer

    Returns:
        ase.Atoms: The repeated structure. For every repetition after the
        first, its info dict gets an entry per atom holding the 'basis'
        value of the corresponding atom in the original structure.

    Raises:
        ValueError: If rep does not contain exactly 1 or 3 entries.

    """
    rep = __check_list(rep, varname='rep', dtype=int)
    num_entries = len(rep)
    if num_entries == 1:
        # presumably rep is a list here, so this repeats the single entry
        # for all three directions -- TODO confirm against __check_list
        rep = rep*3
    elif num_entries != 3:
        raise ValueError("rep must be a sequence of 1 or 3 integers.")
    __check_range(rep, 2**64)

    supercell = atoms.repeat(rep)
    super_info = supercell.info
    base_info = atoms.info
    n = atoms.get_global_number_of_atoms()
    num_cells = prod(rep, dtype=int)
    # Entries 0..n-1 are not written here; only the later copies are filled.
    for cell in range(1, num_cells):
        for j in range(n):
            super_info[cell*n+j] = {'basis': base_info[j]['basis']}
    return supercell