Revert typo fix in Python due to auto-formatter changing too much

Tim Bernhard
2020-11-10 16:29:02 +01:00
parent 22e6d8283e
commit 2c65df1bc2


@@ -37,17 +37,13 @@ StringIO (or io if in Python 3.x)
"""
import os
import numpy as np
import argparse
import time
import pickle
import os, numpy as np, argparse, time, pickle
from scipy.special import logsumexp
from mpi4py import MPI
from tqdm import tqdm
import gzip
import bz2
import gzip, bz2
try:
# python-2
from StringIO import StringIO as IOBuffer
@@ -56,11 +52,12 @@ except ImportError:
from io import BytesIO as IOBuffer
#### INITIALIZE MPI ####
# (note that all output on screen will be printed only on the ROOT proc)
ROOT = 0
comm = MPI.COMM_WORLD
me = comm.rank  # my proc id
me = comm.rank # my proc id
nproc = comm.size
@@ -80,8 +77,7 @@ def _get_nearest_temp(temps, query_temp):
out_temp: nearest temp from the list
"""
if isinstance(temps, list):
temps = np.array(temps)
if isinstance(temps, list): temps = np.array(temps)
return temps[np.argmin(np.abs(temps-query_temp))]
@@ -99,10 +95,10 @@ def readwrite(trajfn, mode):
if trajfn.endswith(".gz"):
of = gzip.open(trajfn, mode)
# return gzip.GzipFile(trajfn, mode)
#return gzip.GzipFile(trajfn, mode)
elif trajfn.endswith(".bz2"):
of = bz2.open(trajfn, mode)
# return bz2.BZ2File(trajfn, mode)
#return bz2.BZ2File(trajfn, mode)
else:
of = open(trajfn, mode)
return of
@@ -127,8 +123,8 @@ def get_replica_frames(logfn, temps, nswap, writefreq):
"""
n_rep = len(temps)
swap_history = np.loadtxt(logfn, skiprows=3)
master_frametuple_dict = dict((n, []) for n in range(n_rep))
swap_history = np.loadtxt(logfn, skiprows = 3)
master_frametuple_dict = dict( (n, []) for n in range(n_rep) )
# walk through the replicas
print("Getting frames from all replicas at temperature:")
@@ -140,15 +136,15 @@ def get_replica_frames(logfn, temps, nswap, writefreq):
if writefreq <= nswap:
for ii, i in enumerate(rep_inds[:-1]):
start = int(ii * nswap / writefreq)
stop = int((ii+1) * nswap / writefreq)
[master_frametuple_dict[n].append((i, x))
for x in range(start, stop)]
stop = int( (ii+1) * nswap / writefreq)
[master_frametuple_dict[n].append( (i,x) ) \
for x in range(start, stop)]
# case-2: when temps. are swapped faster than dumping frames
else:
nskip = int(writefreq / nswap)
[master_frametuple_dict[n].append((i, ii))
for ii, i in enumerate(rep_inds[0::nskip])]
[master_frametuple_dict[n].append( (i,ii) ) \
for ii, i in enumerate(rep_inds[0::nskip])]
return master_frametuple_dict
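
For reference, both variants of this hunk implement the same bookkeeping: walk the swap history of one temperature and record which replica contributed which dump frame. A minimal standalone sketch of that logic (illustrative only; `frames_at_temp` is a hypothetical name, and it assumes `nswap` and `writefreq` divide each other evenly, as the code above does):

```python
def frames_at_temp(rep_inds, nswap, writefreq):
    """Map the per-swap replica indices seen at one temperature to
    (replica, frame) tuples in the order the frames were dumped."""
    frames = []
    if writefreq <= nswap:
        # case 1: several frames are written between consecutive swaps
        per_swap = nswap // writefreq
        for ii, rep in enumerate(rep_inds[:-1]):
            frames.extend((rep, x) for x in range(ii * per_swap, (ii + 1) * per_swap))
    else:
        # case 2: several swaps happen between consecutive frame dumps
        nskip = writefreq // nswap
        frames.extend((rep, ii) for ii, rep in enumerate(rep_inds[0::nskip]))
    return frames
```

For example, `frames_at_temp([0, 2, 1, 2], nswap=1000, writefreq=500)` yields `[(0, 0), (0, 1), (2, 2), (2, 3), (1, 4), (1, 5)]`.
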
@@ -165,12 +161,11 @@ def get_byte_index(rep_inds, byteindfns, intrajfns):
"""
for n in rep_inds:
# check if the byte indices for this traj have already been computed
if os.path.isfile(byteindfns[n]):
continue
if os.path.isfile(byteindfns[n]): continue
# extract bytes
fobj = readwrite(intrajfns[n], "rb")
byteinds = [[0, 0]]
byteinds = [ [0,0] ]
# place file pointer at first line
nframe = 0
@@ -180,37 +175,33 @@ def get_byte_index(rep_inds, byteindfns, intrajfns):
# status printed only for replica read on root proc
# this assumes that each proc takes roughly the same time
if me == ROOT:
pb = tqdm(desc="Reading replicas", leave=True,
position=ROOT + 2*me,
unit="B/replica", unit_scale=True,
unit_divisor=1024)
pb = tqdm(desc = "Reading replicas", leave = True,
position = ROOT + 2*me,
unit = "B/replica", unit_scale = True,
unit_divisor = 1024)
# start crawling through the bytes
while True:
next_line = fobj.readline()
if len(next_line) == 0:
break
if len(next_line) == 0: break
# this will only work with lammpstrj traj format.
# this condition essentially checks periodic recurrences
# of the token TIMESTEP. Each time it is found,
# we have crawled through a frame (snapshot)
if next_line == first_line:
nframe += 1
byteinds.append([nframe, cur_pos])
if me == ROOT:
pb.update()
byteinds.append( [nframe, cur_pos] )
if me == ROOT: pb.update()
cur_pos = fobj.tell()
if me == ROOT:
pb.update(0)
if me == ROOT:
pb.close()
if me == ROOT: pb.update(0)
if me == ROOT: pb.close()
# take care of the EOF
cur_pos = fobj.tell()
byteinds.append([nframe+1, cur_pos]) # dummy index for the EOF
byteinds.append( [nframe+1, cur_pos] ) # dummy index for the EOF
# write to file
np.savetxt(byteindfns[n], np.array(byteinds), fmt="%d")
np.savetxt(byteindfns[n], np.array(byteinds), fmt = "%d")
# close the trajfile object
fobj.close()
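
As the comments above note, the byte index is built by scanning a plain-text LAMMPS dump once and recording the offset at every recurrence of the first line ("ITEM: TIMESTEP"), so that any frame can later be recovered with a single seek. A self-contained sketch of that idea (`index_lammpstrj_frames` is a hypothetical helper, shown without the tqdm/MPI plumbing or the gzip/bz2 handling of `readwrite`):

```python
def index_lammpstrj_frames(path):
    """Return the byte offset at which each frame of a LAMMPS dump starts."""
    offsets = []
    with open(path, "rb") as f:
        first_line = None
        pos = f.tell()
        while True:
            line = f.readline()
            if not line:            # EOF
                break
            if first_line is None:
                first_line = line   # "ITEM: TIMESTEP" in a lammpstrj file
            if line == first_line:  # a new frame header starts here
                offsets.append(pos)
            pos = f.tell()
    return offsets
```

The files written above additionally store the frame number alongside each offset and append a dummy entry for the EOF, which the writer below relies on to delimit the final frame.
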
@@ -256,15 +247,15 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
of = readwrite(outtrajfns[n], "wb")
# get frames
abs_temp_ind = np.argmin(abs(temps - outtemps[n]))
abs_temp_ind = np.argmin( abs(temps - outtemps[n]) )
frametuple = frametuple_dict[abs_temp_ind][-nframes:]
# write frames to buffer
if me == ROOT:
pb = tqdm(frametuple,
desc=("Buffering trajectories for writing"),
leave=True, position=ROOT + 2*me,
unit='frame/replica', unit_scale=True)
desc = ("Buffering trajectories for writing"),
leave = True, position = ROOT + 2*me,
unit = 'frame/replica', unit_scale = True)
iterable = pb
else:
@@ -272,23 +263,20 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
for i, (rep, frame) in enumerate(iterable):
infobj = infobjs[rep]
start_ptr = int(byte_inds[rep][frame, 1])
stop_ptr = int(byte_inds[rep][frame+1, 1])
start_ptr = int(byte_inds[rep][frame,1])
stop_ptr = int(byte_inds[rep][frame+1,1])
byte_len = stop_ptr - start_ptr
infobj.seek(start_ptr)
buf.write(infobj.read(byte_len))
if me == ROOT:
pb.close()
if me == ROOT: pb.close()
# write buffer to disk
if me == ROOT:
print("Writing buffer to file")
if me == ROOT: print("Writing buffer to file")
of.write(buf.getvalue())
of.close()
buf.close()
for i in infobjs:
i.close()
for i in infobjs: i.close()
return
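
The writer above turns those byte indices into a single seek-and-read per frame. A compact sketch of that step (illustrative; `copy_frame` is a hypothetical name and `offsets` is assumed to include the trailing EOF entry mentioned earlier):

```python
def copy_frame(infobj, offsets, i, out):
    """Copy frame i from an open replica trajectory into an output stream."""
    start, stop = offsets[i], offsets[i + 1]   # next entry marks the frame end
    infobj.seek(start)
    out.write(infobj.read(stop - start))
```

Buffering these reads into an in-memory IOBuffer and flushing it with one write, as done above, keeps the I/O on each replica file sequential.
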
@@ -337,13 +325,13 @@ def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
pip install --user pymbar
sudo pip install pymbar
To install the dev. version directly from GitHub, use:
To install the dev. version directly from github, use:
pip install git+https://github.com/choderalab/pymbar.git
""")
u_rn = np.loadtxt(enefn)
ntemps = u_rn.shape[0]  # number of temps.
nframes = int(nprod / writefreq)  # number of frames at each temp.
ntemps = u_rn.shape[0] # number of temps.
nframes = int(nprod / writefreq) # number of frames at each temp.
# reorder the temps
u_kn = np.zeros([ntemps, nframes], float)
@@ -353,90 +341,91 @@ def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
u_kn[k, i] = u_rn[rep, frame]
# prep input for pymbar
# 1) array of frames at each temp.
#1) array of frames at each temp.
nframes_k = nframes * np.ones(ntemps, np.uint8)
# 2) inverse temps. for chosen energy scale
#2) inverse temps. for chosen energy scale
beta_k = 1.0 / (kB * temps)
# 3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
#3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
u_kln = np.zeros([ntemps, ntemps, nframes], float)
for k in range(ntemps):
u_kln[k] = np.outer(beta_k, u_kn[k])
# run pymbar and extract the free energies
print("\nRunning pymbar...")
mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose=True)
f_k = mbar.f_k  # (1 x k array)
mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose = True)
f_k = mbar.f_k # (1 x k array)
# calculate the log-weights
print("\nExtracting log-weights...")
log_nframes = np.log(nframes)
logw = dict((k, np.zeros([ntemps, nframes], float)) for k in range(ntemps))
logw = dict( (k, np.zeros([ntemps, nframes], float)) for k in range(ntemps) )
# get log-weights to reweight to this temp.
for k in range(ntemps):
for n in range(nframes):
num = -beta_k[k] * u_kn[k, n]
denom = f_k - beta_k[k] * u_kn[k, n]
num = -beta_k[k] * u_kn[k,n]
denom = f_k - beta_k[k] * u_kn[k,n]
for l in range(ntemps):
logw[l][k, n] = num - logsumexp(denom) - log_nframes
logw[l][k,n] = num - logsumexp(denom) - log_nframes
return logw
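
For reference, the canonical log-weights returned here are MBAR reweighting factors. The hunk above builds them per (replica, frame) pair; the sketch below is only the textbook expression for a sample of potential energy u reweighted to inverse temperature beta[l], assuming a uniform sample count per state, and is not a drop-in replacement for that loop:

```python
import numpy as np
from scipy.special import logsumexp

def mbar_logweight(u, beta, f, l, nsamples_per_state):
    """log w_l(u) = -beta[l]*u - logsumexp_k(f[k] - beta[k]*u) - log(N),
    with f[k] the dimensionless free energies (e.g. mbar.f_k) and a
    uniform sample count N per thermodynamic state."""
    return -beta[l] * u - logsumexp(f - beta * u) - np.log(nsamples_per_state)
```
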
#### MAIN WORKFLOW ####
if __name__ == "__main__":
# accept user inputs
parser = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
parser = argparse.ArgumentParser(description = __doc__,
formatter_class = argparse.RawDescriptionHelpFormatter)
parser.add_argument("prefix",
help="Prefix of REMD LAMMPS trajectories.\
help = "Prefix of REMD LAMMPS trajectories.\
Supply full path. Trajectories assumed to be named as \
<prefix>.%%d.lammpstrj. \
Can be in compressed (.gz or .bz2) format. \
This is a required argument")
parser.add_argument("-logfn", "--logfn", default="log.lammps",
help="LAMMPS log file that contains swap history \
parser.add_argument("-logfn", "--logfn", default = "log.lammps",
help = "LAMMPS log file that contains swap history \
of temperatures among replicas. \
Default = 'log.lammps'")
parser.add_argument("-tfn", "--tempfn", default="temps.txt",
help="ascii file (readable by numpy.loadtxt) with \
parser.add_argument("-tfn", "--tempfn", default = "temps.txt",
help = "ascii file (readable by numpy.loadtxt) with \
the temperatures used in the REMD simulation.")
parser.add_argument("-ns", "--nswap", type=int,
help="Swap frequency used in LAMMPS temper command")
parser.add_argument("-ns", "--nswap", type = int,
help = "Swap frequency used in LAMMPS temper command")
parser.add_argument("-nw", "--nwrite", type=int, default=1,
help="Trajectory writing frequency used \
parser.add_argument("-nw", "--nwrite", type = int, default = 1,
help = "Trajectory writing frequency used \
in LAMMPS dump command")
parser.add_argument("-np", "--nprod", type=int, default=0,
help="Number of timesteps to save in the reordered\
parser.add_argument("-np", "--nprod", type = int, default = 0,
help = "Number of timesteps to save in the reordered\
trajectories.\
This should be in units of the LAMMPS timestep")
parser.add_argument("-logw", "--logw", action='store_true',
help="Supplying this flag \
parser.add_argument("-logw", "--logw", action = 'store_true',
help = "Supplying this flag \
calculates *canonical* (NVT ensemble) log weights")
parser.add_argument("-e", "--enefn",
help="File that has n_replica x n_frames array\
help = "File that has n_replica x n_frames array\
of total potential energies")
parser.add_argument("-kB", "--boltzmann_const",
type=float, default=0.001987,
help="Boltzmann constant in appropriate units. \
type = float, default = 0.001987,
help = "Boltzmann constant in appropriate units. \
Default is kcal/mol")
parser.add_argument("-ot", "--out_temps", nargs='+', type=np.float64,
help="Reorder trajectories at these temperatures.\n \
parser.add_argument("-ot", "--out_temps", nargs = '+', type = np.float64,
help = "Reorder trajectories at these temperatures.\n \
Default is all temperatures used in the simulation")
parser.add_argument("-od", "--outdir", default=".",
help="All output will be saved to this directory")
parser.add_argument("-od", "--outdir", default = ".",
help = "All output will be saved to this directory")
# parse inputs
args = parser.parse_args()
@@ -449,16 +438,14 @@ if __name__ == "__main__":
nprod = args.nprod
enefn = args.enefn
if not enefn is None:
enefn = os.path.abspath(enefn)
if not enefn is None: enefn = os.path.abspath(enefn)
get_logw = args.logw
kB = args.boltzmann_const
out_temps = args.out_temps
outdir = os.path.abspath(args.outdir)
if not os.path.isdir(outdir):
if me == ROOT:
os.mkdir(outdir)
if me == ROOT: os.mkdir(outdir)
# check that all input files are present (only on the ROOT proc)
if me == ROOT:
@@ -478,8 +465,7 @@ if __name__ == "__main__":
for i in range(ntemps):
this_intrajfn = intrajfns[i]
x = this_intrajfn + ".gz"
if os.path.isfile(this_intrajfn):
continue
if os.path.isfile(this_intrajfn): continue
elif os.path.isfile(this_intrajfn + ".gz"):
intrajfns[i] = this_intrajfn + ".gz"
elif os.path.isfile(this_intrajfn + ".bz2"):
@@ -490,41 +476,42 @@ if __name__ == "__main__":
# set output filenames
outprefix = os.path.join(outdir, traj_prefix.split('/')[-1])
outtrajfns = ["%s.%3.2f.lammpstrj.gz" %
(outprefix, _get_nearest_temp(temps, t))
outtrajfns = ["%s.%3.2f.lammpstrj.gz" % \
(outprefix, _get_nearest_temp(temps, t)) \
for t in out_temps]
byteindfns = [os.path.join(outdir, ".byteind_%d.gz" % k)
byteindfns = [os.path.join(outdir, ".byteind_%d.gz" % k) \
for k in range(ntemps)]
frametuplefn = outprefix + '.frametuple.pickle'
if get_logw:
logwfn = outprefix + ".logw.pickle"
# get a list of all frames at a particular temp visited by each replica
# this is fast so run only on ROOT proc.
master_frametuple_dict = {}
if me == ROOT:
master_frametuple_dict = get_replica_frames(logfn=logfn,
temps=temps,
nswap=nswap,
writefreq=writefreq)
master_frametuple_dict = get_replica_frames(logfn = logfn,
temps = temps,
nswap = nswap,
writefreq = writefreq)
# save to a pickle from the ROOT proc
with open(frametuplefn, 'wb') as of:
pickle.dump(master_frametuple_dict, of)
# broadcast to all procs
master_frametuple_dict = comm.bcast(master_frametuple_dict, root=ROOT)
master_frametuple_dict = comm.bcast(master_frametuple_dict, root = ROOT)
# define a chunk of replicas to process on each proc
CHUNKSIZE_1 = int(ntemps/nproc)
if me < nproc - 1:
my_rep_inds = range((me*CHUNKSIZE_1), (me+1)*CHUNKSIZE_1)
my_rep_inds = range( (me*CHUNKSIZE_1), (me+1)*CHUNKSIZE_1 )
else:
my_rep_inds = range((me*CHUNKSIZE_1), ntemps)
my_rep_inds = range( (me*CHUNKSIZE_1), ntemps )
# get byte indices from replica (un-ordered) trajs. in parallel
get_byte_index(rep_inds=my_rep_inds,
byteindfns=byteindfns,
intrajfns=intrajfns)
get_byte_index(rep_inds = my_rep_inds,
byteindfns = byteindfns,
intrajfns = intrajfns)
# block until all procs have finished
comm.barrier()
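
Both the replica loop above and the output-temperature loop below hand each MPI rank a contiguous block of indices, with the last rank absorbing the remainder. A minimal sketch of that partitioning (`block_partition` is a hypothetical helper, shown only to make the chunk arithmetic explicit):

```python
def block_partition(ntasks, nproc, rank):
    """Contiguous block of task indices owned by `rank`; the last rank
    also takes any remainder when ntasks is not divisible by nproc."""
    chunk = ntasks // nproc
    if rank < nproc - 1:
        return range(rank * chunk, (rank + 1) * chunk)
    return range(rank * chunk, ntasks)
```

For example, with 10 tasks on 4 ranks the blocks are 0-1, 2-3, 4-5 and 6-9.
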
@@ -533,7 +520,7 @@ if __name__ == "__main__":
infobjs = [readwrite(i, "rb") for i in intrajfns]
# open all byteindex files
byte_inds = dict((i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns))
byte_inds = dict( (i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns) )
# define a chunk of output trajs. to process for each proc.
# # of reordered trajs. to write may be less than the total # of replicas
@@ -549,38 +536,38 @@ if __name__ == "__main__":
else:
nproc_active = nproc
if me < nproc_active-1:
my_temp_inds = range((me*CHUNKSIZE_2), (me+1)*CHUNKSIZE_1)
my_temp_inds = range( (me*CHUNKSIZE_2), (me+1)*CHUNKSIZE_1 )
else:
my_temp_inds = range((me*CHUNKSIZE_2), n_out_temps)
my_temp_inds = range( (me*CHUNKSIZE_2), n_out_temps)
# retire the excess procs
# don't forget to close any open file objects
if me >= nproc_active:
for fobj in infobjs:
fobj.close()
for fobj in infobjs: fobj.close()
exit()
# write reordered trajectories to disk from active procs in parallel
write_reordered_traj(temp_inds=my_temp_inds,
byte_inds=byte_inds,
outtemps=out_temps, temps=temps,
frametuple_dict=master_frametuple_dict,
nprod=nprod, writefreq=writefreq,
outtrajfns=outtrajfns,
infobjs=infobjs)
write_reordered_traj(temp_inds = my_temp_inds,
byte_inds = byte_inds,
outtemps = out_temps, temps = temps,
frametuple_dict = master_frametuple_dict,
nprod = nprod, writefreq = writefreq,
outtrajfns = outtrajfns,
infobjs = infobjs)
# calculate canonical log-weights if requested
# usually this is very fast so retire all but the ROOT proc
if not get_logw:
exit()
if not me == ROOT:
exit()
if not get_logw: exit()
if not me == ROOT: exit()
logw = get_canonical_logw(enefn = enefn, temps = temps,
frametuple_dict = master_frametuple_dict,
nprod = nprod, writefreq = writefreq,
kB = kB)
logw = get_canonical_logw(enefn=enefn, temps=temps,
frametuple_dict=master_frametuple_dict,
nprod=nprod, writefreq=writefreq,
kB=kB)
# save the logweights to a pickle
with open(logwfn, 'wb') as of:
pickle.dump(logw, of)
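
Downstream, the pickled canonical log-weights can be used to reweight per-frame observables to a chosen temperature. A hypothetical usage sketch (the file name follows the <outprefix>.logw.pickle convention above, and the observable is a random placeholder):

```python
import pickle
import numpy as np

with open("remd.logw.pickle", "rb") as f:   # hypothetical <outprefix>
    logw = pickle.load(f)

k = 0                                       # target temperature index
lw = logw[k].ravel()                        # log-weights over all (temp, frame) samples
w = np.exp(lw - lw.max())                   # stabilized exponentiation
w /= w.sum()                                # normalize
obs = np.random.rand(w.size)                # placeholder per-frame observable
print("reweighted <obs> at temperature index", k, "=", float(np.dot(w, obs)))
```
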