From 2c65df1bc2efd9c39aae3a3ceeca06fecf25b698 Mon Sep 17 00:00:00 2001 From: Tim Bernhard Date: Tue, 10 Nov 2020 16:29:02 +0100 Subject: [PATCH] Revert typo fix in python due to auto-formatter changing too much --- tools/replica/reorder_remd_traj.py | 231 ++++++++++++++--------------- 1 file changed, 109 insertions(+), 122 deletions(-) diff --git a/tools/replica/reorder_remd_traj.py b/tools/replica/reorder_remd_traj.py index 6eee4770ab..5033ae1e53 100644 --- a/tools/replica/reorder_remd_traj.py +++ b/tools/replica/reorder_remd_traj.py @@ -37,17 +37,13 @@ StringIO (or io if in Python 3.x) """ -import os -import numpy as np -import argparse -import time -import pickle + +import os, numpy as np, argparse, time, pickle from scipy.special import logsumexp from mpi4py import MPI from tqdm import tqdm -import gzip -import bz2 +import gzip, bz2 try: # python-2 from StringIO import StringIO as IOBuffer @@ -56,11 +52,12 @@ except ImportError: from io import BytesIO as IOBuffer + #### INITIALIZE MPI #### # (note that all output on screen will be printed only on the ROOT proc) ROOT = 0 comm = MPI.COMM_WORLD -me = comm.rank # my proc id +me = comm.rank # my proc id nproc = comm.size @@ -80,8 +77,7 @@ def _get_nearest_temp(temps, query_temp): out_temp: nearest temp from the list """ - if isinstance(temps, list): - temps = np.array(temps) + if isinstance(temps, list): temps = np.array(temps) return temps[np.argmin(np.abs(temps-query_temp))] @@ -99,10 +95,10 @@ def readwrite(trajfn, mode): if trajfn.endswith(".gz"): of = gzip.open(trajfn, mode) - # return gzip.GzipFile(trajfn, mode) + #return gzip.GzipFile(trajfn, mode) elif trajfn.endswith(".bz2"): of = bz2.open(trajfn, mode) - # return bz2.BZ2File(trajfn, mode) + #return bz2.BZ2File(trajfn, mode) else: of = open(trajfn, mode) return of @@ -127,8 +123,8 @@ def get_replica_frames(logfn, temps, nswap, writefreq): """ n_rep = len(temps) - swap_history = np.loadtxt(logfn, skiprows=3) - master_frametuple_dict = dict((n, []) for n in range(n_rep)) + swap_history = np.loadtxt(logfn, skiprows = 3) + master_frametuple_dict = dict( (n, []) for n in range(n_rep) ) # walk through the replicas print("Getting frames from all replicas at temperature:") @@ -140,15 +136,15 @@ def get_replica_frames(logfn, temps, nswap, writefreq): if writefreq <= nswap: for ii, i in enumerate(rep_inds[:-1]): start = int(ii * nswap / writefreq) - stop = int((ii+1) * nswap / writefreq) - [master_frametuple_dict[n].append((i, x)) - for x in range(start, stop)] + stop = int( (ii+1) * nswap / writefreq) + [master_frametuple_dict[n].append( (i,x) ) \ + for x in range(start, stop)] # case-2: when temps. 
are swapped faster than dumping frames else: nskip = int(writefreq / nswap) - [master_frametuple_dict[n].append((i, ii)) - for ii, i in enumerate(rep_inds[0::nskip])] + [master_frametuple_dict[n].append( (i,ii) ) \ + for ii, i in enumerate(rep_inds[0::nskip])] return master_frametuple_dict @@ -165,12 +161,11 @@ def get_byte_index(rep_inds, byteindfns, intrajfns): """ for n in rep_inds: # check if the byte indices for this traj has already been computed - if os.path.isfile(byteindfns[n]): - continue + if os.path.isfile(byteindfns[n]): continue # extract bytes fobj = readwrite(intrajfns[n], "rb") - byteinds = [[0, 0]] + byteinds = [ [0,0] ] # place file pointer at first line nframe = 0 @@ -180,37 +175,33 @@ def get_byte_index(rep_inds, byteindfns, intrajfns): # status printed only for replica read on root proc # this assumes that each proc takes roughly the same time if me == ROOT: - pb = tqdm(desc="Reading replicas", leave=True, - position=ROOT + 2*me, - unit="B/replica", unit_scale=True, - unit_divisor=1024) + pb = tqdm(desc = "Reading replicas", leave = True, + position = ROOT + 2*me, + unit = "B/replica", unit_scale = True, + unit_divisor = 1024) # start crawling through the bytes while True: next_line = fobj.readline() - if len(next_line) == 0: - break + if len(next_line) == 0: break # this will only work with lammpstrj traj format. # this condition essentially checks periodic recurrences # of the token TIMESTEP. Each time it is found, # we have crawled through a frame (snapshot) if next_line == first_line: nframe += 1 - byteinds.append([nframe, cur_pos]) - if me == ROOT: - pb.update() + byteinds.append( [nframe, cur_pos] ) + if me == ROOT: pb.update() cur_pos = fobj.tell() - if me == ROOT: - pb.update(0) - if me == ROOT: - pb.close() + if me == ROOT: pb.update(0) + if me == ROOT: pb.close() # take care of the EOF cur_pos = fobj.tell() - byteinds.append([nframe+1, cur_pos]) # dummy index for the EOF + byteinds.append( [nframe+1, cur_pos] ) # dummy index for the EOF # write to file - np.savetxt(byteindfns[n], np.array(byteinds), fmt="%d") + np.savetxt(byteindfns[n], np.array(byteinds), fmt = "%d") # close the trajfile object fobj.close() @@ -256,15 +247,15 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps, of = readwrite(outtrajfns[n], "wb") # get frames - abs_temp_ind = np.argmin(abs(temps - outtemps[n])) + abs_temp_ind = np.argmin( abs(temps - outtemps[n]) ) frametuple = frametuple_dict[abs_temp_ind][-nframes:] # write frames to buffer if me == ROOT: pb = tqdm(frametuple, - desc=("Buffering trajectories for writing"), - leave=True, position=ROOT + 2*me, - unit='frame/replica', unit_scale=True) + desc = ("Buffering trajectories for writing"), + leave = True, position = ROOT + 2*me, + unit = 'frame/replica', unit_scale = True) iterable = pb else: @@ -272,23 +263,20 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps, for i, (rep, frame) in enumerate(iterable): infobj = infobjs[rep] - start_ptr = int(byte_inds[rep][frame, 1]) - stop_ptr = int(byte_inds[rep][frame+1, 1]) + start_ptr = int(byte_inds[rep][frame,1]) + stop_ptr = int(byte_inds[rep][frame+1,1]) byte_len = stop_ptr - start_ptr infobj.seek(start_ptr) buf.write(infobj.read(byte_len)) - if me == ROOT: - pb.close() + if me == ROOT: pb.close() # write buffer to disk - if me == ROOT: - print("Writing buffer to file") + if me == ROOT: print("Writing buffer to file") of.write(buf.getvalue()) of.close() buf.close() - for i in infobjs: - i.close() + for i in infobjs: i.close() return @@ -337,13 +325,13 @@ def 
get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq, pip install --user pymbar sudo pip install pymbar - To install the dev. version directly from GitHub, use: + To install the dev. version directly from github, use: pip install pip install git+https://github.com/choderalab/pymbar.git """) u_rn = np.loadtxt(enefn) - ntemps = u_rn.shape[0] # number of temps. - nframes = int(nprod / writefreq) # number of frames at each temp. + ntemps = u_rn.shape[0] # number of temps. + nframes = int(nprod / writefreq) # number of frames at each temp. # reorder the temps u_kn = np.zeros([ntemps, nframes], float) @@ -353,90 +341,91 @@ def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq, u_kn[k, i] = u_rn[rep, frame] # prep input for pymbar - # 1) array of frames at each temp. + #1) array of frames at each temp. nframes_k = nframes * np.ones(ntemps, np.uint8) - # 2) inverse temps. for chosen energy scale + #2) inverse temps. for chosen energy scale beta_k = 1.0 / (kB * temps) - # 3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*) + #3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*) u_kln = np.zeros([ntemps, ntemps, nframes], float) for k in range(ntemps): u_kln[k] = np.outer(beta_k, u_kn[k]) # run pymbar and extract the free energies print("\nRunning pymbar...") - mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose=True) - f_k = mbar.f_k # (1 x k array) + mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose = True) + f_k = mbar.f_k # (1 x k array) # calculate the log-weights print("\nExtracting log-weights...") log_nframes = np.log(nframes) - logw = dict((k, np.zeros([ntemps, nframes], float)) for k in range(ntemps)) + logw = dict( (k, np.zeros([ntemps, nframes], float)) for k in range(ntemps) ) # get log-weights to reweight to this temp. for k in range(ntemps): for n in range(nframes): - num = -beta_k[k] * u_kn[k, n] - denom = f_k - beta_k[k] * u_kn[k, n] + num = -beta_k[k] * u_kn[k,n] + denom = f_k - beta_k[k] * u_kn[k,n] for l in range(ntemps): - logw[l][k, n] = num - logsumexp(denom) - log_nframes + logw[l][k,n] = num - logsumexp(denom) - log_nframes return logw + #### MAIN WORKFLOW #### if __name__ == "__main__": # accept user inputs - parser = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter) + parser = argparse.ArgumentParser(description = __doc__, + formatter_class = argparse.RawDescriptionHelpFormatter) parser.add_argument("prefix", - help="Prefix of REMD LAMMPS trajectories.\ + help = "Prefix of REMD LAMMPS trajectories.\ Supply full path. Trajectories assumed to be named as \ .%%d.lammpstrj. \ Can be in compressed (.gz or .bz2) format. \ This is a required argument") - parser.add_argument("-logfn", "--logfn", default="log.lammps", - help="LAMMPS log file that contains swap history \ + parser.add_argument("-logfn", "--logfn", default = "log.lammps", + help = "LAMMPS log file that contains swap history \ of temperatures among replicas. 
\ Default = 'lammps.log'") - parser.add_argument("-tfn", "--tempfn", default="temps.txt", - help="ascii file (readable by numpy.loadtxt) with \ + parser.add_argument("-tfn", "--tempfn", default = "temps.txt", + help = "ascii file (readable by numpy.loadtxt) with \ the temperatures used in the REMD simulation.") - parser.add_argument("-ns", "--nswap", type=int, - help="Swap frequency used in LAMMPS temper command") + parser.add_argument("-ns", "--nswap", type = int, + help = "Swap frequency used in LAMMPS temper command") - parser.add_argument("-nw", "--nwrite", type=int, default=1, - help="Trajectory writing frequency used \ + parser.add_argument("-nw", "--nwrite", type = int, default = 1, + help = "Trajectory writing frequency used \ in LAMMPS dump command") - parser.add_argument("-np", "--nprod", type=int, default=0, - help="Number of timesteps to save in the reordered\ + parser.add_argument("-np", "--nprod", type = int, default = 0, + help = "Number of timesteps to save in the reordered\ trajectories.\ This should be in units of the LAMMPS timestep") - parser.add_argument("-logw", "--logw", action='store_true', - help="Supplying this flag \ + parser.add_argument("-logw", "--logw", action = 'store_true', + help = "Supplying this flag \ calculates *canonical* (NVT ensemble) log weights") parser.add_argument("-e", "--enefn", - help="File that has n_replica x n_frames array\ + help = "File that has n_replica x n_frames array\ of total potential energies") parser.add_argument("-kB", "--boltzmann_const", - type=float, default=0.001987, - help="Boltzmann constant in appropriate units. \ + type = float, default = 0.001987, + help = "Boltzmann constant in appropriate units. \ Default is kcal/mol") - parser.add_argument("-ot", "--out_temps", nargs='+', type=np.float64, - help="Reorder trajectories at these temperatures.\n \ + parser.add_argument("-ot", "--out_temps", nargs = '+', type = np.float64, + help = "Reorder trajectories at these temperatures.\n \ Default is all temperatures used in the simulation") - parser.add_argument("-od", "--outdir", default=".", - help="All output will be saved to this directory") + parser.add_argument("-od", "--outdir", default = ".", + help = "All output will be saved to this directory") # parse inputs args = parser.parse_args() @@ -449,16 +438,14 @@ if __name__ == "__main__": nprod = args.nprod enefn = args.enefn - if not enefn is None: - enefn = os.path.abspath(enefn) + if not enefn is None: enefn = os.path.abspath(enefn) get_logw = args.logw kB = args.boltzmann_const out_temps = args.out_temps outdir = os.path.abspath(args.outdir) if not os.path.isdir(outdir): - if me == ROOT: - os.mkdir(outdir) + if me == ROOT: os.mkdir(outdir) # check that all input files are present (only on the ROOT proc) if me == ROOT: @@ -478,8 +465,7 @@ if __name__ == "__main__": for i in range(ntemps): this_intrajfn = intrajfns[i] x = this_intrajfn + ".gz" - if os.path.isfile(this_intrajfn): - continue + if os.path.isfile(this_intrajfn): continue elif os.path.isfile(this_intrajfn + ".gz"): intrajfns[i] = this_intrajfn + ".gz" elif os.path.isfile(this_intrajfn + ".bz2"): @@ -490,41 +476,42 @@ if __name__ == "__main__": # set output filenames outprefix = os.path.join(outdir, traj_prefix.split('/')[-1]) - outtrajfns = ["%s.%3.2f.lammpstrj.gz" % - (outprefix, _get_nearest_temp(temps, t)) + outtrajfns = ["%s.%3.2f.lammpstrj.gz" % \ + (outprefix, _get_nearest_temp(temps, t)) \ for t in out_temps] - byteindfns = [os.path.join(outdir, ".byteind_%d.gz" % k) + byteindfns = [os.path.join(outdir, 
".byteind_%d.gz" % k) \ for k in range(ntemps)] frametuplefn = outprefix + '.frametuple.pickle' if get_logw: logwfn = outprefix + ".logw.pickle" + # get a list of all frames at a particular temp visited by each replica # this is fast so run only on ROOT proc. master_frametuple_dict = {} if me == ROOT: - master_frametuple_dict = get_replica_frames(logfn=logfn, - temps=temps, - nswap=nswap, - writefreq=writefreq) + master_frametuple_dict = get_replica_frames(logfn = logfn, + temps = temps, + nswap = nswap, + writefreq = writefreq) # save to a pickle from the ROOT proc with open(frametuplefn, 'wb') as of: pickle.dump(master_frametuple_dict, of) # broadcast to all procs - master_frametuple_dict = comm.bcast(master_frametuple_dict, root=ROOT) + master_frametuple_dict = comm.bcast(master_frametuple_dict, root = ROOT) # define a chunk of replicas to process on each proc CHUNKSIZE_1 = int(ntemps/nproc) if me < nproc - 1: - my_rep_inds = range((me*CHUNKSIZE_1), (me+1)*CHUNKSIZE_1) + my_rep_inds = range( (me*CHUNKSIZE_1), (me+1)*CHUNKSIZE_1 ) else: - my_rep_inds = range((me*CHUNKSIZE_1), ntemps) + my_rep_inds = range( (me*CHUNKSIZE_1), ntemps ) # get byte indices from replica (un-ordered) trajs. in parallel - get_byte_index(rep_inds=my_rep_inds, - byteindfns=byteindfns, - intrajfns=intrajfns) + get_byte_index(rep_inds = my_rep_inds, + byteindfns = byteindfns, + intrajfns = intrajfns) # block until all procs have finished comm.barrier() @@ -533,7 +520,7 @@ if __name__ == "__main__": infobjs = [readwrite(i, "rb") for i in intrajfns] # open all byteindex files - byte_inds = dict((i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns)) + byte_inds = dict( (i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns) ) # define a chunk of output trajs. to process for each proc. # # of reordered trajs. to write may be less than the total # of replicas @@ -549,38 +536,38 @@ if __name__ == "__main__": else: nproc_active = nproc if me < nproc_active-1: - my_temp_inds = range((me*CHUNKSIZE_2), (me+1)*CHUNKSIZE_1) + my_temp_inds = range( (me*CHUNKSIZE_2), (me+1)*CHUNKSIZE_1 ) else: - my_temp_inds = range((me*CHUNKSIZE_2), n_out_temps) + my_temp_inds = range( (me*CHUNKSIZE_2), n_out_temps) # retire the excess procs # dont' forget to close any open file objects if me >= nproc_active: - for fobj in infobjs: - fobj.close() + for fobj in infobjs: fobj.close() exit() # write reordered trajectories to disk from active procs in parallel - write_reordered_traj(temp_inds=my_temp_inds, - byte_inds=byte_inds, - outtemps=out_temps, temps=temps, - frametuple_dict=master_frametuple_dict, - nprod=nprod, writefreq=writefreq, - outtrajfns=outtrajfns, - infobjs=infobjs) + write_reordered_traj(temp_inds = my_temp_inds, + byte_inds = byte_inds, + outtemps = out_temps, temps = temps, + frametuple_dict = master_frametuple_dict, + nprod = nprod, writefreq = writefreq, + outtrajfns = outtrajfns, + infobjs = infobjs) # calculate canonical log-weights if requested # usually this is very fast so retire all but the ROOT proc - if not get_logw: - exit() - if not me == ROOT: - exit() + if not get_logw: exit() + if not me == ROOT: exit() + + logw = get_canonical_logw(enefn = enefn, temps = temps, + frametuple_dict = master_frametuple_dict, + nprod = nprod, writefreq = writefreq, + kB = kB) - logw = get_canonical_logw(enefn=enefn, temps=temps, - frametuple_dict=master_frametuple_dict, - nprod=nprod, writefreq=writefreq, - kB=kB) # save the logweights to a pickle with open(logwfn, 'wb') as of: pickle.dump(logw, of) + +