Update FixColvars to expand usage of fix_modify commands

See https://github.com/Colvars/colvars/pull/418

Also moving inthash code to a separate file to simplify future refactoring
This commit is contained in:
Giacomo Fiorin
2024-08-06 01:05:51 +02:00
parent 72a0992054
commit 278accd9ea
8 changed files with 779 additions and 566 deletions

View File

@ -23,26 +23,33 @@
/* ----------------------------------------------------------------------
Contributing author: Axel Kohlmeyer (Temple U)
Currently maintained by: Giacomo Fiorin (NIH)
------------------------------------------------------------------------- */
#include "fix_colvars.h"
#include "inthash.h"
#include "atom.h"
#include "citeme.h"
#include "comm.h"
#include "domain.h"
#include "error.h"
#include "input.h"
#include "memory.h"
#include "modify.h"
#include "output.h"
#include "respa.h"
#include "universe.h"
#include "update.h"
#include <cstring>
#include <iostream>
#include <vector>
#include "colvarproxy_lammps.h"
#include "colvarmodule.h"
#include "colvarproxy.h"
#include "colvarproxy_lammps.h"
#include "colvarscript.h"
#include "colvars_memstream.h"
/* struct for packed data communication of coordinates and forces. */
@ -58,197 +65,11 @@ inline std::ostream & operator<< (std::ostream &out, const LAMMPS_NS::commdata &
return out;
}
/* re-usable integer hash table code with static linkage. */
/** hash table top level data structure */
typedef struct inthash_t {
struct inthash_node_t **bucket; /* array of hash nodes */
int size; /* size of the array */
int entries; /* number of entries in table */
int downshift; /* shift cound, used in hash function */
int mask; /* used to select bits for hashing */
} inthash_t;
/** hash table node data structure */
typedef struct inthash_node_t {
int data; /* data in hash node */
int key; /* key for hash lookup */
struct inthash_node_t *next; /* next node in hash chain */
} inthash_node_t;
#define HASH_FAIL -1
#define HASH_LIMIT 0.5
/* initialize new hash table */
static void inthash_init(inthash_t *tptr, int buckets);
/* lookup entry in hash table */
static int inthash_lookup(void *tptr, int key);
/* insert an entry into hash table. */
static int inthash_insert(inthash_t *tptr, int key, int data);
/* delete the hash table */
static void inthash_destroy(inthash_t *tptr);
/************************************************************************
* integer hash code:
************************************************************************/
/* inthash() - Hash function returns a hash number for a given key.
* tptr: Pointer to a hash table, key: The key to create a hash number for */
static int inthash(const inthash_t *tptr, int key) {
int hashvalue;
hashvalue = (((key*1103515249)>>tptr->downshift) & tptr->mask);
if (hashvalue < 0) {
hashvalue = 0;
}
return hashvalue;
}
/*
* rebuild_table_int() - Create new hash table when old one fills up.
*
* tptr: Pointer to a hash table
*/
static void rebuild_table_int(inthash_t *tptr) {
inthash_node_t **old_bucket, *old_hash, *tmp;
int old_size, h, i;
old_bucket=tptr->bucket;
old_size=tptr->size;
/* create a new table and rehash old buckets */
inthash_init(tptr, old_size<<1);
for (i=0; i<old_size; i++) {
old_hash=old_bucket[i];
while (old_hash) {
tmp=old_hash;
old_hash=old_hash->next;
h=inthash(tptr, tmp->key);
tmp->next=tptr->bucket[h];
tptr->bucket[h]=tmp;
tptr->entries++;
} /* while */
} /* for */
/* free memory used by old table */
free(old_bucket);
}
/*
* inthash_init() - Initialize a new hash table.
*
* tptr: Pointer to the hash table to initialize
* buckets: The number of initial buckets to create
*/
void inthash_init(inthash_t *tptr, int buckets) {
/* make sure we allocate something */
if (buckets==0)
buckets=16;
/* initialize the table */
tptr->entries=0;
tptr->size=2;
tptr->mask=1;
tptr->downshift=29;
/* ensure buckets is a power of 2 */
while (tptr->size<buckets) {
tptr->size<<=1;
tptr->mask=(tptr->mask<<1)+1;
tptr->downshift--;
} /* while */
/* allocate memory for table */
tptr->bucket=(inthash_node_t **) calloc(tptr->size, sizeof(inthash_node_t *));
}
/*
* inthash_lookup() - Lookup an entry in the hash table and return a pointer to
* it or HASH_FAIL if it wasn't found.
*
* tptr: Pointer to the hash table
* key: The key to lookup
*/
int inthash_lookup(void *ptr, int key) {
const inthash_t *tptr = (const inthash_t *) ptr;
int h;
inthash_node_t *node;
/* find the entry in the hash table */
h=inthash(tptr, key);
for (node=tptr->bucket[h]; node!=nullptr; node=node->next) {
if (node->key == key)
break;
}
/* return the entry if it exists, or HASH_FAIL */
return(node ? node->data : HASH_FAIL);
}
/*
* inthash_insert() - Insert an entry into the hash table. If the entry already
* exists return a pointer to it, otherwise return HASH_FAIL.
*
* tptr: A pointer to the hash table
* key: The key to insert into the hash table
* data: A pointer to the data to insert into the hash table
*/
int inthash_insert(inthash_t *tptr, int key, int data) {
int tmp;
inthash_node_t *node;
int h;
/* check to see if the entry exists */
if ((tmp=inthash_lookup(tptr, key)) != HASH_FAIL)
return(tmp);
/* expand the table if needed */
while (tptr->entries>=HASH_LIMIT*tptr->size)
rebuild_table_int(tptr);
/* insert the new entry */
h=inthash(tptr, key);
node=(struct inthash_node_t *) malloc(sizeof(inthash_node_t));
node->data=data;
node->key=key;
node->next=tptr->bucket[h];
tptr->bucket[h]=node;
tptr->entries++;
return HASH_FAIL;
}
/*
* inthash_destroy() - Delete the entire table, and all remaining entries.
*
*/
void inthash_destroy(inthash_t *tptr) {
inthash_node_t *node, *last;
int i;
for (i=0; i<tptr->size; i++) {
node = tptr->bucket[i];
while (node != nullptr) {
last = node;
node = node->next;
free(last);
}
}
/* free the entire array of buckets */
if (tptr->bucket != nullptr) {
free(tptr->bucket);
memset(tptr, 0, sizeof(inthash_t));
}
}
/***************************************************************/
using namespace LAMMPS_NS;
using namespace FixConst;
using namespace IntHash_NS;
// initialize static class members
int FixColvars::instances=0;
@ -287,72 +108,137 @@ FixColvars::FixColvars(LAMMPS *lmp, int narg, char **arg) :
me = comm->me;
root2root = MPI_COMM_NULL;
proxy = nullptr;
if (strcmp(arg[3], "none") == 0) {
conf_file = nullptr;
} else {
conf_file = utils::strdup(arg[3]);
}
conf_file = utils::strdup(arg[3]);
rng_seed = 1966;
unwrap_flag = 1;
inp_name = nullptr;
out_name = nullptr;
tmp_name = nullptr;
/* parse optional arguments */
int iarg = 4;
while (iarg < narg) {
// we have keyword/value pairs. check if value is missing
if (iarg+1 == narg)
error->all(FLERR,"Missing argument to keyword");
if (0 == strcmp(arg[iarg], "input")) {
inp_name = utils::strdup(arg[iarg+1]);
} else if (0 == strcmp(arg[iarg], "output")) {
out_name = utils::strdup(arg[iarg+1]);
} else if (0 == strcmp(arg[iarg], "seed")) {
rng_seed = utils::inumeric(FLERR,arg[iarg+1],false,lmp);
} else if (0 == strcmp(arg[iarg], "unwrap")) {
unwrap_flag = utils::logical(FLERR,arg[iarg+1],false,lmp);
} else if (0 == strcmp(arg[iarg], "tstat")) {
tmp_name = utils::strdup(arg[iarg+1]);
} else {
error->all(FLERR,"Unknown fix colvars parameter");
}
++iarg; ++iarg;
}
if (!out_name) out_name = utils::strdup("out");
tfix_name = nullptr;
/* initialize various state variables. */
tstat_fix = nullptr;
energy = 0.0;
nlevels_respa = 0;
init_flag = 0;
num_coords = 0;
comm_buf = nullptr;
taglist = nullptr;
force_buf = nullptr;
proxy = nullptr;
idmap = nullptr;
script_args[0] = reinterpret_cast<unsigned char *>(strdup("fix_modify"));
parse_fix_arguments(narg, arg, true);
if (!out_name) out_name = utils::strdup("out");
if (me == 0) {
#ifdef LAMMPS_BIGBIG
utils::logmesg(lmp, "colvars: Warning: cannot handle atom ids > 2147483647\n");
#endif
proxy = new colvarproxy_lammps(lmp);
proxy->init();
proxy->set_random_seed(rng_seed);
proxy->set_target_temperature(t_target);
if (conf_file) {
proxy->add_config("configfile", conf_file);
}
}
/* storage required to communicate a single coordinate or force. */
size_one = sizeof(struct commdata);
}
/*********************************
* Clean up on deleting the fix. *
*********************************/
int FixColvars::parse_fix_arguments(int narg, char **arg, bool fix_constructor)
{
int const iarg_start = fix_constructor ? 4 : 0;
int iarg = iarg_start;
while (iarg < narg) {
bool is_fix_keyword = false;
if (0 == strcmp(arg[iarg], "input")) {
inp_name = utils::strdup(arg[iarg+1]);
// input prefix is set in FixColvars::setup()
is_fix_keyword = true;
} else if (0 == strcmp(arg[iarg], "output")) {
out_name = utils::strdup(arg[iarg+1]);
// output prefix is set in FixColvars::setup()
is_fix_keyword = true;
} else if (0 == strcmp(arg[iarg], "seed")) {
rng_seed = utils::inumeric(FLERR, arg[iarg+1], false, lmp);
is_fix_keyword = true;
} else if (0 == strcmp(arg[iarg], "unwrap")) {
unwrap_flag = utils::logical(FLERR, arg[iarg+1], false, lmp);
is_fix_keyword = true;
} else if (0 == strcmp(arg[iarg], "tstat")) {
tfix_name = utils::strdup(arg[iarg+1]);
if (me == 0) set_thermostat_temperature();
is_fix_keyword = true;
}
if (is_fix_keyword) {
// Valid LAMMPS fix keyword: raise error if it has no argument
if (iarg + 1 == narg) {
if (fix_constructor) {
error->all(FLERR, ("Missing argument to keyword \""+
std::string(arg[iarg]) +"\""));
} else {
// Error code consistent with Fix::modify_param()
return 0;
}
}
} else {
if (fix_constructor) {
error->all(FLERR, "Unrecognized fix colvars argument: please note that "
"Colvars script commands are not allowed until after the "
"fix is created");
} else {
if (iarg > iarg_start) {
error->all(FLERR,
"Unrecognized fix colvars argument: please note that "
"you cannot combine fix colvars keywords and Colvars "
"script commands in the same line");
} else {
// Return negative error code to try the Colvars script commands
return -1;
}
}
}
iarg += 2;
}
return iarg;
}
FixColvars::~FixColvars()
{
delete[] conf_file;
delete[] inp_name;
delete[] out_name;
delete[] tmp_name;
delete[] tfix_name;
memory->sfree(comm_buf);
if (proxy) {
delete proxy;
inthash_t *hashtable = (inthash_t *)idmap;
inthash_destroy(hashtable);
delete hashtable;
}
if (idmap) {
inthash_destroy(idmap);
delete idmap;
}
if (root2root != MPI_COMM_NULL)
@ -374,31 +260,20 @@ int FixColvars::setmask()
return mask;
}
/* ---------------------------------------------------------------------- */
// initial checks for colvars run.
void FixColvars::init()
{
if (atom->tag_enable == 0)
error->all(FLERR,"Cannot use fix colvars without atom IDs");
error->all(FLERR, "Cannot use fix colvars without atom IDs");
if (atom->map_style == Atom::MAP_NONE)
error->all(FLERR,"Fix colvars requires an atom map, see atom_modify");
error->all(FLERR, "Fix colvars requires an atom map, see atom_modify");
if ((me == 0) && (update->whichflag == 2))
error->warning(FLERR,"Using fix colvars with minimization");
error->warning(FLERR, "Using fix colvars with minimization");
if (utils::strmatch(update->integrate_style,"^respa"))
if (utils::strmatch(update->integrate_style, "^respa"))
nlevels_respa = ((Respa *) update->integrate)->nlevels;
}
/* ---------------------------------------------------------------------- */
void FixColvars::one_time_init()
{
int i,tmp;
if (init_flag) return;
init_flag = 1;
@ -406,103 +281,188 @@ void FixColvars::one_time_init()
if (universe->nworlds > 1) {
// create inter root communicator
int color = 1;
if (me == 0) color = 0;
MPI_Comm_split(universe->uworld,color,universe->iworld,&root2root);
if (me == 0) {
color = 0;
}
MPI_Comm_split(universe->uworld, color, universe->iworld, &root2root);
if (me == 0) {
proxy->set_replicas_communicator(root2root);
}
}
}
// create and initialize the colvars proxy
void FixColvars::set_thermostat_temperature()
{
if (me == 0) {
utils::logmesg(lmp,"colvars: Creating proxy instance\n");
#ifdef LAMMPS_BIGBIG
utils::logmesg(lmp,"colvars: cannot handle atom ids > 2147483647\n");
#endif
if (inp_name) {
if (strcmp(inp_name,"NULL") == 0) {
delete[] inp_name;
inp_name = nullptr;
if (tfix_name) {
if (strcmp(tfix_name, "NULL") != 0) {
Fix *tstat_fix = modify->get_fix_by_id(tfix_name);
if (!tstat_fix) {
error->one(FLERR, "Could not find thermostat fix ID {}", tfix_name);
}
int tmp = 0;
double *tt = reinterpret_cast<double *>(tstat_fix->extract("t_target",
tmp));
if (tt) {
t_target = *tt;
} else {
error->one(FLERR, "Fix ID {} is not a thermostat fix", tfix_name);
}
}
}
// try to determine thermostat target temperature
double t_target = 0.0;
if (tmp_name) {
if (strcmp(tmp_name,"NULL") == 0) {
tstat_fix = nullptr;
} else {
tstat_fix = modify->get_fix_by_id(tmp_name);
if (!tstat_fix) error->one(FLERR, "Could not find thermostat fix ID {}", tmp_name);
double *tt = (double*) tstat_fix->extract("t_target", tmp);
if (tt) t_target = *tt;
else error->one(FLERR, "Fix ID {} is not a thermostat fix", tmp_name);
}
}
proxy = new colvarproxy_lammps(lmp,inp_name,out_name,rng_seed,t_target,root2root);
proxy->init();
proxy->add_config("configfile", conf_file);
proxy->parse_module_config();
num_coords = (proxy->modify_atom_positions()->size());
}
}
// send the list of all colvar atom IDs to all nodes.
// also initialize and build hashtable on master.
/* ---------------------------------------------------------------------- */
MPI_Bcast(&num_coords, 1, MPI_INT, 0, world);
memory->create(taglist,num_coords,"colvars:taglist");
memory->create(force_buf,3*num_coords,"colvars:force_buf");
void FixColvars::init_taglist()
{
int new_taglist_size = -1;
if (me == 0) {
// Number of atoms requested by Colvars
num_coords = static_cast<int>(proxy->modify_atom_positions()->size());
if (proxy->modified_atom_list()) {
new_taglist_size = num_coords;
proxy->reset_modified_atom_list();
} else {
new_taglist_size = -1;
}
}
// Broadcast number of colvar atoms; negative means no updates
MPI_Bcast(&new_taglist_size, 1, MPI_INT, 0, world);
if (new_taglist_size < 0) {
return;
}
num_coords = new_taglist_size;
if (taglist) {
memory->destroy(taglist);
memory->destroy(force_buf);
}
memory->create(taglist, num_coords, "colvars:taglist");
memory->create(force_buf, 3*num_coords, "colvars:force_buf");
if (me == 0) {
// Initialize and build hashtable on MPI rank 0
std::vector<int> const &tl = *(proxy->get_atom_ids());
inthash_t *hashtable=new inthash_t;
inthash_init(hashtable, num_coords);
idmap = (void *)hashtable;
for (i=0; i < num_coords; ++i) {
if (idmap) {
delete idmap;
idmap = nullptr;
}
idmap = new inthash_t;
inthash_init(idmap, num_coords);
for (int i = 0; i < num_coords; ++i) {
taglist[i] = tl[i];
inthash_insert(hashtable, tl[i], i);
inthash_insert(idmap, tl[i], i);
}
}
// Broadcast colvar atom ID list
MPI_Bcast(taglist, num_coords, MPI_LMP_TAGINT, 0, world);
}
/* ---------------------------------------------------------------------- */
int FixColvars::modify_param(int narg, char **arg)
{
if (strcmp(arg[0],"configfile") == 0) {
if (narg < 2) error->all(FLERR,"Illegal fix_modify command");
if (me == 0) {
if (! proxy)
error->one(FLERR,"Cannot use fix_modify before initialization");
return proxy->add_config_file(arg[1]) == COLVARS_OK ? 2 : 0;
}
return 2;
} else if (strcmp(arg[0],"config") == 0) {
if (narg < 2) error->all(FLERR,"Illegal fix_modify command");
if (me == 0) {
if (! proxy)
error->one(FLERR,"Cannot use fix_modify before initialization");
std::string const conf(arg[1]);
return proxy->add_config_string(conf) == COLVARS_OK ? 2 : 0;
}
return 2;
} else if (strcmp(arg[0],"load") == 0) {
if (narg < 2) error->all(FLERR,"Illegal fix_modify command");
if (me == 0) {
if (! proxy)
error->one(FLERR,"Cannot use fix_modify before initialization");
return proxy->read_state_file(arg[1]) == COLVARS_OK ? 2 : 0;
}
if (narg > 100) {
error->one(FLERR, "Too many arguments for fix_modify command");
return 2;
}
// Parse arguments to fix colvars
int return_code = parse_fix_arguments(narg, arg, false);
if (return_code >= 0) {
// A fix colvars argument was detected, return directly
return return_code;
}
// Any unknown arguments will go through the Colvars scripting interface
if (me == 0) {
int error_code = COLVARSCRIPT_OK;
colvarscript *script = proxy->script;
script->set_cmdline_main_cmd("fix_modify " + std::string(id));
for (int i = 0; i < narg; i++) {
// Substitute LAMMPS variables
// See https://github.com/lammps/lammps/commit/f9be11ac8ab460edff3709d66734d3fc2cd806dd
char *new_arg = arg[i];
int ncopy = strlen(new_arg) + 1;
char *copy = utils::strdup(new_arg);
char *work = new char[ncopy];
int nwork = ncopy;
lmp->input->substitute(copy,work,ncopy,nwork,0);
delete[] work;
new_arg = copy;
script_args[i+1] = reinterpret_cast<unsigned char *>(new_arg);
}
// Run the command through Colvars
error_code |= script->run(narg+1, script_args);
std::string const result = proxy->get_error_msgs() + script->str_result();
if (result.size()) {
std::istringstream is(result);
std::string line;
while (std::getline(is, line)) {
if (lmp->screen) fprintf(lmp->screen, "%s\n", line.c_str());
if (lmp->logfile) fprintf(lmp->logfile, "%s\n", line.c_str());
}
}
return (error_code == COLVARSCRIPT_OK) ? narg : 0;
} else {
// Return without error, don't block Fix::modify_params()
return narg;
}
return 0;
}
/* ---------------------------------------------------------------------- */
void FixColvars::setup_io()
{
if (me == 0) {
proxy->set_input_prefix(std::string(inp_name ? inp_name : ""));
if (proxy->input_prefix().size() > 0) {
proxy->log("Will read input state from file \""+
proxy->input_prefix()+".colvars.state\"");
}
proxy->set_output_prefix(std::string(out_name ? out_name : ""));
// Try to extract a restart prefix from a potential restart command
LAMMPS_NS::Output *outp = lmp->output;
if ((outp->restart_every_single > 0) &&
(outp->restart1 != nullptr)) {
proxy->set_default_restart_frequency(outp->restart_every_single);
proxy->set_restart_output_prefix(std::string(outp->restart1));
} else if ((outp->restart_every_double > 0) &&
(outp->restart2a != nullptr)) {
proxy->set_default_restart_frequency(outp->restart_every_double);
proxy->set_restart_output_prefix(std::string(outp->restart2a));
}
}
}
void FixColvars::setup(int vflag)
{
@ -514,7 +474,12 @@ void FixColvars::setup(int vflag)
MPI_Status status;
MPI_Request request;
one_time_init();
if (me == 0) {
setup_io();
proxy->parse_module_config();
}
init_taglist();
// determine size of comm buffer
nme=0;
@ -675,19 +640,6 @@ void FixColvars::post_force(int /*vflag*/)
if (me == 0) {
if (proxy->want_exit())
error->one(FLERR,"Run aborted on request from colvars module.\n");
if (!tstat_fix) {
proxy->set_target_temperature(0.0);
} else {
int tmp;
// get thermostat target temperature from corresponding fix,
// if the fix supports extraction.
double *tt = (double *) tstat_fix->extract("t_target", tmp);
if (tt)
proxy->set_target_temperature(*tt);
else
proxy->set_target_temperature(0.0);
}
}
const tagint * const tag = atom->tag;
@ -935,13 +887,17 @@ void FixColvars::end_of_step()
void FixColvars::write_restart(FILE *fp)
{
if (me == 0) {
std::string rest_text;
proxy->serialize_status(rest_text);
// TODO call write_output_files()
const char *cvm_state = rest_text.c_str();
int len = strlen(cvm_state) + 1; // need to include terminating null byte.
fwrite(&len,sizeof(int),1,fp);
fwrite(cvm_state,1,len,fp);
cvm::memory_stream ms;
if (proxy->colvars->write_state(ms)) {
int len_cv_state = ms.length();
// Will write the buffer's length twice, so that the fix can read it later, too
int len = len_cv_state + sizeof(int);
fwrite(&len, sizeof(int), 1, fp);
fwrite(&len, sizeof(int), 1, fp);
fwrite(ms.output_buffer(), 1, len_cv_state, fp);
} else {
error->all(FLERR, "Failed to write Colvars state to binary file");
}
}
}
@ -949,11 +905,13 @@ void FixColvars::write_restart(FILE *fp)
void FixColvars::restart(char *buf)
{
one_time_init();
if (me == 0) {
std::string rest_text(buf);
proxy->deserialize_status(rest_text);
// Read the buffer's length, then load it into Colvars starting right past that location
int length = *(reinterpret_cast<int *>(buf));
unsigned char *colvars_state_buffer = reinterpret_cast<unsigned char *>(buf + sizeof(int));
if (proxy->colvars->set_input_state_buffer(length, colvars_state_buffer) != COLVARS_OK) {
error->all(FLERR, "Failed to set the Colvars input state from string buffer");
}
}
}