Revert typo fix in python due to auto-formatter changing too much
@@ -37,17 +37,13 @@ StringIO (or io if in Python 3.x)
"""

import os
import numpy as np
import argparse
import time
import pickle

import os, numpy as np, argparse, time, pickle
from scipy.special import logsumexp
from mpi4py import MPI

from tqdm import tqdm
import gzip
import bz2
import gzip, bz2
try:
    # python-2
    from StringIO import StringIO as IOBuffer
@@ -56,11 +52,12 @@ except ImportError:
    from io import BytesIO as IOBuffer


#### INITIALIZE MPI ####
# (note that all output on screen will be printed only on the ROOT proc)
ROOT = 0
comm = MPI.COMM_WORLD
me = comm.rank # my proc id
me = comm.rank # my proc id
nproc = comm.size
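As context for the MPI calls that appear throughout the diff (rank/size bookkeeping, ROOT-only printing, broadcasting a Python object, and a barrier), here is a minimal standalone sketch of the same pattern; it is illustrative only and not part of the diffed file, and the variable names are assumptions:

# minimal_mpi_pattern.py -- illustrative sketch, not part of the commit
from mpi4py import MPI

ROOT = 0
comm = MPI.COMM_WORLD
me = comm.rank      # this process' id
nproc = comm.size   # total number of processes

if me == ROOT:
    print("running on %d procs" % nproc)

# build data on ROOT only, then share it with every rank
# (the script does the same with master_frametuple_dict)
data = {"payload": list(range(10))} if me == ROOT else None
data = comm.bcast(data, root=ROOT)

# split indices into contiguous chunks per rank (same idea as CHUNKSIZE_1)
chunk = len(data["payload"]) // nproc
start = me * chunk
stop = (me + 1) * chunk if me < nproc - 1 else len(data["payload"])
my_items = data["payload"][start:stop]

comm.barrier()   # wait for everyone before moving on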
@@ -80,8 +77,7 @@ def _get_nearest_temp(temps, query_temp):
    out_temp: nearest temp from the list
    """

    if isinstance(temps, list):
        temps = np.array(temps)
    if isinstance(temps, list): temps = np.array(temps)
    return temps[np.argmin(np.abs(temps-query_temp))]
@@ -99,10 +95,10 @@ def readwrite(trajfn, mode):

    if trajfn.endswith(".gz"):
        of = gzip.open(trajfn, mode)
        # return gzip.GzipFile(trajfn, mode)
        #return gzip.GzipFile(trajfn, mode)
    elif trajfn.endswith(".bz2"):
        of = bz2.open(trajfn, mode)
        # return bz2.BZ2File(trajfn, mode)
        #return bz2.BZ2File(trajfn, mode)
    else:
        of = open(trajfn, mode)
    return of
@@ -127,8 +123,8 @@ def get_replica_frames(logfn, temps, nswap, writefreq):
    """

    n_rep = len(temps)
    swap_history = np.loadtxt(logfn, skiprows=3)
    master_frametuple_dict = dict((n, []) for n in range(n_rep))
    swap_history = np.loadtxt(logfn, skiprows = 3)
    master_frametuple_dict = dict( (n, []) for n in range(n_rep) )

    # walk through the replicas
    print("Getting frames from all replicas at temperature:")
@@ -140,15 +136,15 @@ def get_replica_frames(logfn, temps, nswap, writefreq):
        if writefreq <= nswap:
            for ii, i in enumerate(rep_inds[:-1]):
                start = int(ii * nswap / writefreq)
                stop = int((ii+1) * nswap / writefreq)
                [master_frametuple_dict[n].append((i, x))
                 for x in range(start, stop)]
                stop = int( (ii+1) * nswap / writefreq)
                [master_frametuple_dict[n].append( (i,x) ) \
                for x in range(start, stop)]

        # case-2: when temps. are swapped faster than dumping frames
        else:
            nskip = int(writefreq / nswap)
            [master_frametuple_dict[n].append((i, ii))
             for ii, i in enumerate(rep_inds[0::nskip])]
            [master_frametuple_dict[n].append( (i,ii) ) \
            for ii, i in enumerate(rep_inds[0::nskip])]

    return master_frametuple_dict
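To make the data structure above concrete: get_replica_frames returns a dict keyed by temperature index, whose values are time-ordered lists of (replica, frame) tuples. A small hand-made example (the values are hypothetical, only to show the layout):

# Hypothetical 3-replica example of the mapping built above.
# Temperature 0 was held by replica 2 for frames 0-1, then by replica 0.
frametuple_dict = {
    0: [(2, 0), (2, 1), (0, 2)],
    1: [(0, 0), (1, 1), (2, 2)],
    2: [(1, 0), (0, 1), (1, 2)],
}

# Rebuilding the continuous trajectory at temperature 0 therefore means
# reading frame 0 and frame 1 from replica 2's file and frame 2 from
# replica 0's file, in that order.
for rep, frame in frametuple_dict[0]:
    print("temp 0: take frame %d from replica %d" % (frame, rep))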
@@ -165,12 +161,11 @@ def get_byte_index(rep_inds, byteindfns, intrajfns):
    """
    for n in rep_inds:
        # check if the byte indices for this traj has already been computed
        if os.path.isfile(byteindfns[n]):
            continue
        if os.path.isfile(byteindfns[n]): continue

        # extract bytes
        fobj = readwrite(intrajfns[n], "rb")
        byteinds = [[0, 0]]
        byteinds = [ [0,0] ]

        # place file pointer at first line
        nframe = 0
@@ -180,37 +175,33 @@ def get_byte_index(rep_inds, byteindfns, intrajfns):
        # status printed only for replica read on root proc
        # this assumes that each proc takes roughly the same time
        if me == ROOT:
            pb = tqdm(desc="Reading replicas", leave=True,
                      position=ROOT + 2*me,
                      unit="B/replica", unit_scale=True,
                      unit_divisor=1024)
            pb = tqdm(desc = "Reading replicas", leave = True,
                      position = ROOT + 2*me,
                      unit = "B/replica", unit_scale = True,
                      unit_divisor = 1024)

        # start crawling through the bytes
        while True:
            next_line = fobj.readline()
            if len(next_line) == 0:
                break
            if len(next_line) == 0: break
            # this will only work with lammpstrj traj format.
            # this condition essentially checks periodic recurrences
            # of the token TIMESTEP. Each time it is found,
            # we have crawled through a frame (snapshot)
            if next_line == first_line:
                nframe += 1
                byteinds.append([nframe, cur_pos])
                if me == ROOT:
                    pb.update()
                byteinds.append( [nframe, cur_pos] )
                if me == ROOT: pb.update()
            cur_pos = fobj.tell()
        if me == ROOT:
            pb.update(0)
        if me == ROOT:
            pb.close()
        if me == ROOT: pb.update(0)
        if me == ROOT: pb.close()

        # take care of the EOF
        cur_pos = fobj.tell()
        byteinds.append([nframe+1, cur_pos]) # dummy index for the EOF
        byteinds.append( [nframe+1, cur_pos] ) # dummy index for the EOF

        # write to file
        np.savetxt(byteindfns[n], np.array(byteinds), fmt="%d")
        np.savetxt(byteindfns[n], np.array(byteinds), fmt = "%d")

        # close the trajfile object
        fobj.close()
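A compact sketch of why the byte index is built: with the (frame_number, byte_offset) table saved above, any frame can later be pulled out of a replica file with a single seek/read instead of re-parsing the whole trajectory. This is the same arithmetic used later in write_reordered_traj; the helper below is illustrative and not part of the diffed file:

# Illustrative random access using a byte-index table as written above.
# byteinds[i] = [frame_number, byte_offset]; the extra dummy row at EOF
# makes byteinds[frame+1] - byteinds[frame] the byte length of a frame.
import numpy as np

def read_frame(fobj, byteinds, frame):
    start = int(byteinds[frame, 1])
    stop = int(byteinds[frame + 1, 1])
    fobj.seek(start)
    return fobj.read(stop - start)

# e.g. with fobj = readwrite("traj.0.lammpstrj.gz", "rb") and
# byteinds = np.loadtxt(".byteind_0.gz"), read_frame(fobj, byteinds, 5)
# returns the raw text of the sixth snapshot.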
@@ -256,15 +247,15 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
        of = readwrite(outtrajfns[n], "wb")

        # get frames
        abs_temp_ind = np.argmin(abs(temps - outtemps[n]))
        abs_temp_ind = np.argmin( abs(temps - outtemps[n]) )
        frametuple = frametuple_dict[abs_temp_ind][-nframes:]

        # write frames to buffer
        if me == ROOT:
            pb = tqdm(frametuple,
                      desc=("Buffering trajectories for writing"),
                      leave=True, position=ROOT + 2*me,
                      unit='frame/replica', unit_scale=True)
                      desc = ("Buffering trajectories for writing"),
                      leave = True, position = ROOT + 2*me,
                      unit = 'frame/replica', unit_scale = True)

            iterable = pb
        else:
@@ -272,23 +263,20 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,

        for i, (rep, frame) in enumerate(iterable):
            infobj = infobjs[rep]
            start_ptr = int(byte_inds[rep][frame, 1])
            stop_ptr = int(byte_inds[rep][frame+1, 1])
            start_ptr = int(byte_inds[rep][frame,1])
            stop_ptr = int(byte_inds[rep][frame+1,1])
            byte_len = stop_ptr - start_ptr
            infobj.seek(start_ptr)
            buf.write(infobj.read(byte_len))
        if me == ROOT:
            pb.close()
        if me == ROOT: pb.close()

        # write buffer to disk
        if me == ROOT:
            print("Writing buffer to file")
        if me == ROOT: print("Writing buffer to file")
        of.write(buf.getvalue())
        of.close()
        buf.close()

    for i in infobjs:
        i.close()
    for i in infobjs: i.close()

    return
@@ -337,13 +325,13 @@ def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
        pip install --user pymbar
        sudo pip install pymbar

        To install the dev. version directly from GitHub, use:
        To install the dev. version directly from github, use:
        pip install pip install git+https://github.com/choderalab/pymbar.git
        """)

    u_rn = np.loadtxt(enefn)
    ntemps = u_rn.shape[0] # number of temps.
    nframes = int(nprod / writefreq) # number of frames at each temp.
    ntemps = u_rn.shape[0] # number of temps.
    nframes = int(nprod / writefreq) # number of frames at each temp.

    # reorder the temps
    u_kn = np.zeros([ntemps, nframes], float)
@@ -353,90 +341,91 @@ def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
            u_kn[k, i] = u_rn[rep, frame]

    # prep input for pymbar
    # 1) array of frames at each temp.
    #1) array of frames at each temp.
    nframes_k = nframes * np.ones(ntemps, np.uint8)

    # 2) inverse temps. for chosen energy scale
    #2) inverse temps. for chosen energy scale
    beta_k = 1.0 / (kB * temps)

    # 3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
    #3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
    u_kln = np.zeros([ntemps, ntemps, nframes], float)
    for k in range(ntemps):
        u_kln[k] = np.outer(beta_k, u_kn[k])

    # run pymbar and extract the free energies
    print("\nRunning pymbar...")
    mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose=True)
    f_k = mbar.f_k # (1 x k array)
    mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose = True)
    f_k = mbar.f_k # (1 x k array)

    # calculate the log-weights
    print("\nExtracting log-weights...")
    log_nframes = np.log(nframes)
    logw = dict((k, np.zeros([ntemps, nframes], float)) for k in range(ntemps))
    logw = dict( (k, np.zeros([ntemps, nframes], float)) for k in range(ntemps) )
    # get log-weights to reweight to this temp.
    for k in range(ntemps):
        for n in range(nframes):
            num = -beta_k[k] * u_kn[k, n]
            denom = f_k - beta_k[k] * u_kn[k, n]
            num = -beta_k[k] * u_kn[k,n]
            denom = f_k - beta_k[k] * u_kn[k,n]
            for l in range(ntemps):
                logw[l][k, n] = num - logsumexp(denom) - log_nframes
                logw[l][k,n] = num - logsumexp(denom) - log_nframes

    return logw
#### MAIN WORKFLOW ####
if __name__ == "__main__":
    # accept user inputs
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser = argparse.ArgumentParser(description = __doc__,
                formatter_class = argparse.RawDescriptionHelpFormatter)

    parser.add_argument("prefix",
                        help="Prefix of REMD LAMMPS trajectories.\
                        help = "Prefix of REMD LAMMPS trajectories.\
                        Supply full path. Trajectories assumed to be named as \
                        <prefix>.%%d.lammpstrj. \
                        Can be in compressed (.gz or .bz2) format. \
                        This is a required argument")

    parser.add_argument("-logfn", "--logfn", default="log.lammps",
                        help="LAMMPS log file that contains swap history \
    parser.add_argument("-logfn", "--logfn", default = "log.lammps",
                        help = "LAMMPS log file that contains swap history \
                        of temperatures among replicas. \
                        Default = 'lammps.log'")

    parser.add_argument("-tfn", "--tempfn", default="temps.txt",
                        help="ascii file (readable by numpy.loadtxt) with \
    parser.add_argument("-tfn", "--tempfn", default = "temps.txt",
                        help = "ascii file (readable by numpy.loadtxt) with \
                        the temperatures used in the REMD simulation.")

    parser.add_argument("-ns", "--nswap", type=int,
                        help="Swap frequency used in LAMMPS temper command")
    parser.add_argument("-ns", "--nswap", type = int,
                        help = "Swap frequency used in LAMMPS temper command")

    parser.add_argument("-nw", "--nwrite", type=int, default=1,
                        help="Trajectory writing frequency used \
    parser.add_argument("-nw", "--nwrite", type = int, default = 1,
                        help = "Trajectory writing frequency used \
                        in LAMMPS dump command")

    parser.add_argument("-np", "--nprod", type=int, default=0,
                        help="Number of timesteps to save in the reordered\
    parser.add_argument("-np", "--nprod", type = int, default = 0,
                        help = "Number of timesteps to save in the reordered\
                        trajectories.\
                        This should be in units of the LAMMPS timestep")

    parser.add_argument("-logw", "--logw", action='store_true',
                        help="Supplying this flag \
    parser.add_argument("-logw", "--logw", action = 'store_true',
                        help = "Supplying this flag \
                        calculates *canonical* (NVT ensemble) log weights")

    parser.add_argument("-e", "--enefn",
                        help="File that has n_replica x n_frames array\
                        help = "File that has n_replica x n_frames array\
                        of total potential energies")

    parser.add_argument("-kB", "--boltzmann_const",
                        type=float, default=0.001987,
                        help="Boltzmann constant in appropriate units. \
                        type = float, default = 0.001987,
                        help = "Boltzmann constant in appropriate units. \
                        Default is kcal/mol")

    parser.add_argument("-ot", "--out_temps", nargs='+', type=np.float64,
                        help="Reorder trajectories at these temperatures.\n \
    parser.add_argument("-ot", "--out_temps", nargs = '+', type = np.float64,
                        help = "Reorder trajectories at these temperatures.\n \
                        Default is all temperatures used in the simulation")

    parser.add_argument("-od", "--outdir", default=".",
                        help="All output will be saved to this directory")
    parser.add_argument("-od", "--outdir", default = ".",
                        help = "All output will be saved to this directory")

    # parse inputs
    args = parser.parse_args()
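One detail worth spelling out from the log-weight loop in get_canonical_logw above: scipy.special.logsumexp evaluates log(sum(exp(x))) without overflowing, which is why it is used instead of the naive expression. A tiny self-contained illustration (values chosen only to show the overflow problem):

import numpy as np
from scipy.special import logsumexp

x = np.array([1000.0, 1001.0, 1002.0])
# naive np.log(np.sum(np.exp(x))) overflows to inf
print(logsumexp(x))                                    # ~1002.41, computed stably
print(x.max() + np.log(np.sum(np.exp(x - x.max()))))   # same value via the shift trick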
@@ -449,16 +438,14 @@ if __name__ == "__main__":
    nprod = args.nprod

    enefn = args.enefn
    if not enefn is None:
        enefn = os.path.abspath(enefn)
    if not enefn is None: enefn = os.path.abspath(enefn)
    get_logw = args.logw
    kB = args.boltzmann_const

    out_temps = args.out_temps
    outdir = os.path.abspath(args.outdir)
    if not os.path.isdir(outdir):
        if me == ROOT:
            os.mkdir(outdir)
        if me == ROOT: os.mkdir(outdir)

    # check that all input files are present (only on the ROOT proc)
    if me == ROOT:
@@ -478,8 +465,7 @@ if __name__ == "__main__":
        for i in range(ntemps):
            this_intrajfn = intrajfns[i]
            x = this_intrajfn + ".gz"
            if os.path.isfile(this_intrajfn):
                continue
            if os.path.isfile(this_intrajfn): continue
            elif os.path.isfile(this_intrajfn + ".gz"):
                intrajfns[i] = this_intrajfn + ".gz"
            elif os.path.isfile(this_intrajfn + ".bz2"):
@@ -490,41 +476,42 @@ if __name__ == "__main__":

    # set output filenames
    outprefix = os.path.join(outdir, traj_prefix.split('/')[-1])
    outtrajfns = ["%s.%3.2f.lammpstrj.gz" %
                  (outprefix, _get_nearest_temp(temps, t))
    outtrajfns = ["%s.%3.2f.lammpstrj.gz" % \
                  (outprefix, _get_nearest_temp(temps, t)) \
                  for t in out_temps]
    byteindfns = [os.path.join(outdir, ".byteind_%d.gz" % k)
    byteindfns = [os.path.join(outdir, ".byteind_%d.gz" % k) \
                  for k in range(ntemps)]
    frametuplefn = outprefix + '.frametuple.pickle'
    if get_logw:
        logwfn = outprefix + ".logw.pickle"

    # get a list of all frames at a particular temp visited by each replica
    # this is fast so run only on ROOT proc.
    master_frametuple_dict = {}
    if me == ROOT:
        master_frametuple_dict = get_replica_frames(logfn=logfn,
                                                    temps=temps,
                                                    nswap=nswap,
                                                    writefreq=writefreq)
        master_frametuple_dict = get_replica_frames(logfn = logfn,
                                                    temps = temps,
                                                    nswap = nswap,
                                                    writefreq = writefreq)
        # save to a pickle from the ROOT proc
        with open(frametuplefn, 'wb') as of:
            pickle.dump(master_frametuple_dict, of)

    # broadcast to all procs
    master_frametuple_dict = comm.bcast(master_frametuple_dict, root=ROOT)
    master_frametuple_dict = comm.bcast(master_frametuple_dict, root = ROOT)

    # define a chunk of replicas to process on each proc
    CHUNKSIZE_1 = int(ntemps/nproc)
    if me < nproc - 1:
        my_rep_inds = range((me*CHUNKSIZE_1), (me+1)*CHUNKSIZE_1)
        my_rep_inds = range( (me*CHUNKSIZE_1), (me+1)*CHUNKSIZE_1 )
    else:
        my_rep_inds = range((me*CHUNKSIZE_1), ntemps)
        my_rep_inds = range( (me*CHUNKSIZE_1), ntemps )

    # get byte indices from replica (un-ordered) trajs. in parallel
    get_byte_index(rep_inds=my_rep_inds,
                   byteindfns=byteindfns,
                   intrajfns=intrajfns)
    get_byte_index(rep_inds = my_rep_inds,
                   byteindfns = byteindfns,
                   intrajfns = intrajfns)

    # block until all procs have finished
    comm.barrier()
@@ -533,7 +520,7 @@ if __name__ == "__main__":
    infobjs = [readwrite(i, "rb") for i in intrajfns]

    # open all byteindex files
    byte_inds = dict((i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns))
    byte_inds = dict( (i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns) )

    # define a chunk of output trajs. to process for each proc.
    # # of reordered trajs. to write may be less than the total # of replicas
@@ -549,38 +536,38 @@ if __name__ == "__main__":
    else:
        nproc_active = nproc
    if me < nproc_active-1:
        my_temp_inds = range((me*CHUNKSIZE_2), (me+1)*CHUNKSIZE_1)
        my_temp_inds = range( (me*CHUNKSIZE_2), (me+1)*CHUNKSIZE_1 )
    else:
        my_temp_inds = range((me*CHUNKSIZE_2), n_out_temps)
        my_temp_inds = range( (me*CHUNKSIZE_2), n_out_temps)

    # retire the excess procs
    # dont' forget to close any open file objects
    if me >= nproc_active:
        for fobj in infobjs:
            fobj.close()
        for fobj in infobjs: fobj.close()
        exit()

    # write reordered trajectories to disk from active procs in parallel
    write_reordered_traj(temp_inds=my_temp_inds,
                         byte_inds=byte_inds,
                         outtemps=out_temps, temps=temps,
                         frametuple_dict=master_frametuple_dict,
                         nprod=nprod, writefreq=writefreq,
                         outtrajfns=outtrajfns,
                         infobjs=infobjs)
    write_reordered_traj(temp_inds = my_temp_inds,
                         byte_inds = byte_inds,
                         outtemps = out_temps, temps = temps,
                         frametuple_dict = master_frametuple_dict,
                         nprod = nprod, writefreq = writefreq,
                         outtrajfns = outtrajfns,
                         infobjs = infobjs)

    # calculate canonical log-weights if requested
    # usually this is very fast so retire all but the ROOT proc
    if not get_logw:
        exit()
    if not me == ROOT:
        exit()
    if not get_logw: exit()
    if not me == ROOT: exit()

    logw = get_canonical_logw(enefn = enefn, temps = temps,
                              frametuple_dict = master_frametuple_dict,
                              nprod = nprod, writefreq = writefreq,
                              kB = kB)

    logw = get_canonical_logw(enefn=enefn, temps=temps,
                              frametuple_dict=master_frametuple_dict,
                              nprod=nprod, writefreq=writefreq,
                              kB=kB)

    # save the logweights to a pickle
    with open(logwfn, 'wb') as of:
        pickle.dump(logw, of)
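Finally, for orientation, a short sketch of how the two pickles written by the script (the <outprefix>.frametuple.pickle frame map and the <outprefix>.logw.pickle canonical log-weights) might be consumed afterwards; the file names and the observable array below are hypothetical:

import pickle
import numpy as np
from scipy.special import logsumexp

# hypothetical output names following the script's outprefix convention
with open("remd.frametuple.pickle", "rb") as f:
    frametuple_dict = pickle.load(f)   # temp index -> [(replica, frame), ...]
with open("remd.logw.pickle", "rb") as f:
    logw = pickle.load(f)              # temp index -> (ntemps x nframes) log-weights

# reweighted canonical average of a per-frame observable A (ntemps x nframes)
# at target temperature index l, normalizing the weights with logsumexp
def reweighted_average(A, logw_l):
    w = np.exp(logw_l - logsumexp(logw_l))
    return float(np.sum(w * A))

A = np.random.rand(*logw[0].shape)     # placeholder observable
print(reweighted_average(A, logw[0]))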