From 3a1423dc489f99f0af7d9734cec99059e1861e10 Mon Sep 17 00:00:00 2001 From: Giacomo Fiorin Date: Thu, 2 Jun 2022 11:24:04 -0400 Subject: [PATCH] Update Colvars to 2022-05-24 and copy of Lepton library One bugfix for the Colvars library in the ABF method, and update of the copy of the Lepton library as per the OpenMM repository. List of relevant PR. - 483 Update Lepton via patching procedure https://github.com/Colvars/colvars/pull/483 (@giacomofiorin) - 482 Fix integer overflow in log_gradient_finite_diff and gradient_finite_diff https://github.com/Colvars/colvars/pull/482 (@HanatoK) --- doc/src/PDF/colvars-refman-lammps.pdf | Bin 1490159 -> 1490159 bytes lib/colvars/Makefile.common | 22 +- lib/colvars/Makefile.lepton.deps | 22 +- lib/colvars/colvargrid.h | 20 +- lib/colvars/colvarproxy_tcl.cpp | 2 - lib/colvars/colvars_version.h | 2 +- .../include/lepton/CompiledExpression.h | 22 +- .../include/lepton/CompiledVectorExpression.h | 145 +++ .../lepton/include/lepton/CustomFunction.h | 2 +- .../lepton/include/lepton/ExpressionProgram.h | 2 +- .../include/lepton/ExpressionTreeNode.h | 8 +- lib/colvars/lepton/include/lepton/Operation.h | 2 +- .../lepton/include/lepton/ParsedExpression.h | 20 +- lib/colvars/lepton/src/CompiledExpression.cpp | 582 +++++++++-- .../lepton/src/CompiledVectorExpression.cpp | 933 ++++++++++++++++++ lib/colvars/lepton/src/ExpressionTreeNode.cpp | 48 +- lib/colvars/lepton/src/MSVC_erfc.h | 4 +- lib/colvars/lepton/src/Operation.cpp | 128 ++- lib/colvars/lepton/src/ParsedExpression.cpp | 211 ++-- lib/colvars/lepton/src/Parser.cpp | 2 +- 20 files changed, 1942 insertions(+), 235 deletions(-) create mode 100644 lib/colvars/lepton/include/lepton/CompiledVectorExpression.h create mode 100644 lib/colvars/lepton/src/CompiledVectorExpression.cpp diff --git a/doc/src/PDF/colvars-refman-lammps.pdf b/doc/src/PDF/colvars-refman-lammps.pdf index 71eaee867a76a108e0a8f58333d6a2ecfb513e8d..c87ccbad75a3c8a7e0ed46c69af5b0be057fb44f 100644 GIT binary patch delta 144 zcmaF=GwS`%sD>8C7N!>F7M2#)7Pc1l7LFFq7OocV7M>Q~7QPn#7J(MQ7NHj57LgXw zEn<7;Fd3Rm-#8C7N!>F7M2#)7Pc1l7LFFq7OocV7M>Q~7QPn#7J(MQ7NHj57LgXw zEn<7;FqxQ7-#G; S+1XCPhLDo&(aXf9hyef|Fe(=S diff --git a/lib/colvars/Makefile.common b/lib/colvars/Makefile.common index a920c24958..31a93652ae 100644 --- a/lib/colvars/Makefile.common +++ b/lib/colvars/Makefile.common @@ -62,10 +62,13 @@ COLVARS_SRCS = \ colvar_neuralnetworkcompute.cpp LEPTON_SRCS = \ - lepton/src/CompiledExpression.cpp lepton/src/ExpressionTreeNode.cpp \ - lepton/src/ParsedExpression.cpp lepton/src/ExpressionProgram.cpp \ - lepton/src/Operation.cpp lepton/src/Parser.cpp - + lepton/src/CompiledExpression.cpp \ + lepton/src/CompiledVectorExpression.cpp \ + lepton/src/ExpressionProgram.cpp \ + lepton/src/ExpressionTreeNode.cpp \ + lepton/src/Operation.cpp \ + lepton/src/ParsedExpression.cpp \ + lepton/src/Parser.cpp # Allow to selectively turn off Lepton ifeq ($(COLVARS_LEPTON),no) @@ -93,4 +96,13 @@ Makefile.deps: $(COLVARS_SRCS) done include Makefile.deps -include Makefile.lepton.deps # Hand-generated + +Makefile.lepton.deps: $(LEPTON_SRCS) + @echo > $@ + @for src in $^ ; do \ + obj=`basename $$src .cpp`.o ; \ + $(CXX) $(CXXFLAGS) -MM $(LEPTON_INCFLAGS) \ + -MT '$$(COLVARS_OBJ_DIR)'$$obj $$src >> $@ ; \ + done + +include Makefile.lepton.deps diff --git a/lib/colvars/Makefile.lepton.deps b/lib/colvars/Makefile.lepton.deps index 93c3912384..4546339de6 100644 --- a/lib/colvars/Makefile.lepton.deps +++ b/lib/colvars/Makefile.lepton.deps @@ -1,36 +1,46 @@ -lepton/src/CompiledExpression.o: lepton/src/CompiledExpression.cpp \ + +$(COLVARS_OBJ_DIR)CompiledExpression.o: lepton/src/CompiledExpression.cpp \ lepton/include/lepton/CompiledExpression.h \ lepton/include/lepton/ExpressionTreeNode.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \ lepton/include/lepton/Exception.h \ lepton/include/lepton/ParsedExpression.h -lepton/src/ExpressionProgram.o: lepton/src/ExpressionProgram.cpp \ +$(COLVARS_OBJ_DIR)CompiledVectorExpression.o: \ + lepton/src/CompiledVectorExpression.cpp \ + lepton/include/lepton/CompiledVectorExpression.h \ + lepton/include/lepton/ExpressionTreeNode.h \ + lepton/include/lepton/windowsIncludes.h \ + lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \ + lepton/include/lepton/Exception.h \ + lepton/include/lepton/ParsedExpression.h +$(COLVARS_OBJ_DIR)ExpressionProgram.o: lepton/src/ExpressionProgram.cpp \ lepton/include/lepton/ExpressionProgram.h \ lepton/include/lepton/ExpressionTreeNode.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \ lepton/include/lepton/Exception.h \ lepton/include/lepton/ParsedExpression.h -lepton/src/ExpressionTreeNode.o: lepton/src/ExpressionTreeNode.cpp \ +$(COLVARS_OBJ_DIR)ExpressionTreeNode.o: lepton/src/ExpressionTreeNode.cpp \ lepton/include/lepton/ExpressionTreeNode.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/Exception.h lepton/include/lepton/Operation.h \ lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h -lepton/src/Operation.o: lepton/src/Operation.cpp \ +$(COLVARS_OBJ_DIR)Operation.o: lepton/src/Operation.cpp \ lepton/include/lepton/Operation.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h \ lepton/include/lepton/ExpressionTreeNode.h lepton/src/MSVC_erfc.h -lepton/src/ParsedExpression.o: lepton/src/ParsedExpression.cpp \ +$(COLVARS_OBJ_DIR)ParsedExpression.o: lepton/src/ParsedExpression.cpp \ lepton/include/lepton/ParsedExpression.h \ lepton/include/lepton/ExpressionTreeNode.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/CompiledExpression.h \ + lepton/include/lepton/CompiledVectorExpression.h \ lepton/include/lepton/ExpressionProgram.h \ lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \ lepton/include/lepton/Exception.h -lepton/src/Parser.o: lepton/src/Parser.cpp \ +$(COLVARS_OBJ_DIR)Parser.o: lepton/src/Parser.cpp \ lepton/include/lepton/Parser.h lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h \ lepton/include/lepton/ExpressionTreeNode.h \ diff --git a/lib/colvars/colvargrid.h b/lib/colvars/colvargrid.h index e0b2ec7f03..f34c5eccab 100644 --- a/lib/colvars/colvargrid.h +++ b/lib/colvars/colvargrid.h @@ -1275,7 +1275,7 @@ public: inline cvm::real log_gradient_finite_diff(const std::vector &ix0, int n = 0) { - int A0, A1, A2; + cvm::real A0, A1, A2; std::vector ix = ix0; // TODO this can be rewritten more concisely with wrap_edge() @@ -1288,7 +1288,7 @@ public: if (A0 * A1 == 0) { return 0.; // can't handle empty bins } else { - return (cvm::logn((cvm::real)A1) - cvm::logn((cvm::real)A0)) + return (cvm::logn(A1) - cvm::logn(A0)) / (widths[n] * 2.); } } else if (ix[n] > 0 && ix[n] < nx[n]-1) { // not an edge @@ -1300,7 +1300,7 @@ public: if (A0 * A1 == 0) { return 0.; // can't handle empty bins } else { - return (cvm::logn((cvm::real)A1) - cvm::logn((cvm::real)A0)) + return (cvm::logn(A1) - cvm::logn(A0)) / (widths[n] * 2.); } } else { @@ -1313,8 +1313,8 @@ public: if (A0 * A1 * A2 == 0) { return 0.; // can't handle empty bins } else { - return (-1.5 * cvm::logn((cvm::real)A0) + 2. * cvm::logn((cvm::real)A1) - - 0.5 * cvm::logn((cvm::real)A2)) * increment / widths[n]; + return (-1.5 * cvm::logn(A0) + 2. * cvm::logn(A1) + - 0.5 * cvm::logn(A2)) * increment / widths[n]; } } } @@ -1324,7 +1324,7 @@ public: inline cvm::real gradient_finite_diff(const std::vector &ix0, int n = 0) { - int A0, A1, A2; + cvm::real A0, A1, A2; std::vector ix = ix0; // FIXME this can be rewritten more concisely with wrap_edge() @@ -1337,7 +1337,7 @@ public: if (A0 * A1 == 0) { return 0.; // can't handle empty bins } else { - return cvm::real(A1 - A0) / (widths[n] * 2.); + return (A1 - A0) / (widths[n] * 2.); } } else if (ix[n] > 0 && ix[n] < nx[n]-1) { // not an edge ix[n]--; @@ -1348,7 +1348,7 @@ public: if (A0 * A1 == 0) { return 0.; // can't handle empty bins } else { - return cvm::real(A1 - A0) / (widths[n] * 2.); + return (A1 - A0) / (widths[n] * 2.); } } else { // edge: use 2nd order derivative @@ -1357,8 +1357,8 @@ public: A0 = value(ix); ix[n] += increment; A1 = value(ix); ix[n] += increment; A2 = value(ix); - return (-1.5 * cvm::real(A0) + 2. * cvm::real(A1) - - 0.5 * cvm::real(A2)) * increment / widths[n]; + return (-1.5 * A0 + 2. * A1 + - 0.5 * A2) * increment / widths[n]; } } }; diff --git a/lib/colvars/colvarproxy_tcl.cpp b/lib/colvars/colvarproxy_tcl.cpp index 33bdc9dc38..700492f0e7 100644 --- a/lib/colvars/colvarproxy_tcl.cpp +++ b/lib/colvars/colvarproxy_tcl.cpp @@ -22,9 +22,7 @@ colvarproxy_tcl::colvarproxy_tcl() { -#ifdef COLVARS_TCL tcl_interp_ = NULL; -#endif } diff --git a/lib/colvars/colvars_version.h b/lib/colvars/colvars_version.h index 2a1d449ab5..d2a48f8af7 100644 --- a/lib/colvars/colvars_version.h +++ b/lib/colvars/colvars_version.h @@ -1,3 +1,3 @@ #ifndef COLVARS_VERSION -#define COLVARS_VERSION "2022-05-09" +#define COLVARS_VERSION "2022-05-24" #endif diff --git a/lib/colvars/lepton/include/lepton/CompiledExpression.h b/lib/colvars/lepton/include/lepton/CompiledExpression.h index c7e393e93b..84ec2eb410 100644 --- a/lib/colvars/lepton/include/lepton/CompiledExpression.h +++ b/lib/colvars/lepton/include/lepton/CompiledExpression.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -40,7 +40,11 @@ #include #include #ifdef LEPTON_USE_JIT - #include "asmjit.h" +#if defined(__ARM__) || defined(__ARM64__) +#include "asmjit/a64.h" +#else +#include "asmjit/x86.h" +#endif #endif namespace Lepton { @@ -52,9 +56,9 @@ class ParsedExpression; * A CompiledExpression is a highly optimized representation of an expression for cases when you want to evaluate * it many times as quickly as possible. You should treat it as an opaque object; none of the internal representation * is visible. - * + * * A CompiledExpression is created by calling createCompiledExpression() on a ParsedExpression. - * + * * WARNING: CompiledExpression is NOT thread safe. You should never access a CompiledExpression from two threads at * the same time. */ @@ -101,9 +105,15 @@ private: std::map dummyVariables; double (*jitCode)(); #ifdef LEPTON_USE_JIT + void findPowerGroups(std::vector >& groups, std::vector >& groupPowers, std::vector& stepGroup); void generateJitCode(); - void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg, double (*function)(double)); - void generateTwoArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg1, asmjit::X86Xmm& arg2, double (*function)(double, double)); +#if defined(__ARM__) || defined(__ARM64__) + void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, double (*function)(double, double)); +#else + void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg1, asmjit::x86::Xmm& arg2, double (*function)(double, double)); +#endif std::vector constants; asmjit::JitRuntime runtime; #endif diff --git a/lib/colvars/lepton/include/lepton/CompiledVectorExpression.h b/lib/colvars/lepton/include/lepton/CompiledVectorExpression.h new file mode 100644 index 0000000000..a9dd936750 --- /dev/null +++ b/lib/colvars/lepton/include/lepton/CompiledVectorExpression.h @@ -0,0 +1,145 @@ +#ifndef LEPTON_VECTOR_EXPRESSION_H_ +#define LEPTON_VECTOR_EXPRESSION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "ExpressionTreeNode.h" +#include "windowsIncludes.h" +#include +#include +#include +#include +#include +#include +#ifdef LEPTON_USE_JIT +#if defined(__ARM__) || defined(__ARM64__) +#include "asmjit/a64.h" +#else +#include "asmjit/x86.h" +#endif +#endif + +namespace Lepton { + +class Operation; +class ParsedExpression; + +/** + * A CompiledVectorExpression is a highly optimized representation of an expression for cases when you want to evaluate + * it many times as quickly as possible. It is similar to CompiledExpression, with the extra feature that it uses the CPU's + * vector unit (AVX on x86, NEON on ARM) to evaluate the expression for multiple sets of arguments at once. It also differs + * from CompiledExpression and ParsedExpression in using single precision rather than double precision to evaluate the expression. + * You should treat it as an opaque object; none of the internal representation is visible. + * + * A CompiledVectorExpression is created by calling createCompiledVectorExpression() on a ParsedExpression. When you create + * it, you must specify the width of the vectors on which to compute the expression. The allowed widths depend on the type of + * CPU it is running on. 4 is always allowed, and 8 is allowed on x86 processors with AVX. Call getAllowedWidths() to query + * the allowed values. + * + * WARNING: CompiledVectorExpression is NOT thread safe. You should never access a CompiledVectorExpression from two threads at + * the same time. + */ + +class LEPTON_EXPORT CompiledVectorExpression { +public: + CompiledVectorExpression(); + CompiledVectorExpression(const CompiledVectorExpression& expression); + ~CompiledVectorExpression(); + CompiledVectorExpression& operator=(const CompiledVectorExpression& expression); + /** + * Get the width of the vectors on which the expression is computed. + */ + int getWidth() const; + /** + * Get the names of all variables used by this expression. + */ + const std::set& getVariables() const; + /** + * Get a pointer to the memory location where the value of a particular variable is stored. This can be used + * to set the value of the variable before calling evaluate(). + * + * @param name the name of the variable to query + * @return a pointer to N floating point values, where N is the vector width + */ + float* getVariablePointer(const std::string& name); + /** + * You can optionally specify the memory locations from which the values of variables should be read. + * This is useful, for example, when several expressions all use the same variable. You can then set + * the value of that variable in one place, and it will be seen by all of them. The location should + * be a pointer to N floating point values, where N is the vector width. + */ + void setVariableLocations(std::map& variableLocations); + /** + * Evaluate the expression. The values of all variables should have been set before calling this. + * + * @return a pointer to N floating point values, where N is the vector width + */ + const float* evaluate() const; + /** + * Get the list of vector widths that are supported on the current processor. + */ + static const std::vector& getAllowedWidths(); +private: + friend class ParsedExpression; + CompiledVectorExpression(const ParsedExpression& expression, int width); + void compileExpression(const ExpressionTreeNode& node, std::vector >& temps, int& workspaceSize); + int findTempIndex(const ExpressionTreeNode& node, std::vector >& temps); + int width; + std::map variablePointers; + std::vector > variablesToCopy; + std::vector > arguments; + std::vector target; + std::vector operation; + std::map variableIndices; + std::set variableNames; + mutable std::vector workspace; + mutable std::vector argValues; + std::map dummyVariables; + void (*jitCode)(); +#ifdef LEPTON_USE_JIT + void findPowerGroups(std::vector >& groups, std::vector >& groupPowers, std::vector& stepGroup); + void generateJitCode(); +#if defined(__ARM__) || defined(__ARM64__) + void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, float (*function)(float)); + void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, float (*function)(float, float)); +#else + void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg, float (*function)(float)); + void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg1, asmjit::x86::Ymm& arg2, float (*function)(float, float)); +#endif + std::vector constants; + asmjit::JitRuntime runtime; +#endif +}; + +} // namespace Lepton + +#endif /*LEPTON_VECTOR_EXPRESSION_H_*/ diff --git a/lib/colvars/lepton/include/lepton/CustomFunction.h b/lib/colvars/lepton/include/lepton/CustomFunction.h index fbb0ddd52a..7b6a2b6834 100644 --- a/lib/colvars/lepton/include/lepton/CustomFunction.h +++ b/lib/colvars/lepton/include/lepton/CustomFunction.h @@ -83,7 +83,7 @@ class LEPTON_EXPORT PlaceholderFunction : public CustomFunction { public: /** * Create a Placeholder function. - * + * * @param numArgs the number of arguments the function expects */ PlaceholderFunction(int numArgs) : numArgs(numArgs) { diff --git a/lib/colvars/lepton/include/lepton/ExpressionProgram.h b/lib/colvars/lepton/include/lepton/ExpressionProgram.h index a49a9094d0..e989906288 100644 --- a/lib/colvars/lepton/include/lepton/ExpressionProgram.h +++ b/lib/colvars/lepton/include/lepton/ExpressionProgram.h @@ -67,7 +67,7 @@ public: const Operation& getOperation(int index) const; /** * Change an Operation in this program. - * + * * The Operation must have been allocated on the heap with the "new" operator. * The ExpressionProgram assumes ownership of it and will delete it when it * is no longer needed. diff --git a/lib/colvars/lepton/include/lepton/ExpressionTreeNode.h b/lib/colvars/lepton/include/lepton/ExpressionTreeNode.h index bf3a9a0902..dde26103cb 100644 --- a/lib/colvars/lepton/include/lepton/ExpressionTreeNode.h +++ b/lib/colvars/lepton/include/lepton/ExpressionTreeNode.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009 Stanford University and the Authors. * + * Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -39,6 +39,7 @@ namespace Lepton { class Operation; +class ParsedExpression; /** * This class represents a node in the abstract syntax tree representation of an expression. @@ -82,11 +83,13 @@ public: */ ExpressionTreeNode(Operation* operation); ExpressionTreeNode(const ExpressionTreeNode& node); + ExpressionTreeNode(ExpressionTreeNode&& node); ExpressionTreeNode(); ~ExpressionTreeNode(); bool operator==(const ExpressionTreeNode& node) const; bool operator!=(const ExpressionTreeNode& node) const; ExpressionTreeNode& operator=(const ExpressionTreeNode& node); + ExpressionTreeNode& operator=(ExpressionTreeNode&& node); /** * Get the Operation performed by this node. */ @@ -96,8 +99,11 @@ public: */ const std::vector& getChildren() const; private: + friend class ParsedExpression; + void assignTags(std::vector& examples) const; Operation* operation; std::vector children; + mutable int tag; }; } // namespace Lepton diff --git a/lib/colvars/lepton/include/lepton/Operation.h b/lib/colvars/lepton/include/lepton/Operation.h index 1ddde0b8c0..4b8969cd59 100644 --- a/lib/colvars/lepton/include/lepton/Operation.h +++ b/lib/colvars/lepton/include/lepton/Operation.h @@ -1017,7 +1017,7 @@ public: double evaluate(double* args, const std::map& variables) const { if (isIntPower) { // Integer powers can be computed much more quickly by repeated multiplication. - + int exponent = intValue; double base = args[0]; if (exponent < 0) { diff --git a/lib/colvars/lepton/include/lepton/ParsedExpression.h b/lib/colvars/lepton/include/lepton/ParsedExpression.h index d88b3d5829..6c6526e525 100644 --- a/lib/colvars/lepton/include/lepton/ParsedExpression.h +++ b/lib/colvars/lepton/include/lepton/ParsedExpression.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009=2013 Stanford University and the Authors. * + * Portions copyright (c) 2009-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -41,6 +41,7 @@ namespace Lepton { class CompiledExpression; class ExpressionProgram; +class CompiledVectorExpression; /** * This class represents the result of parsing an expression. It provides methods for working with the @@ -102,6 +103,16 @@ public: * Create a CompiledExpression that represents the same calculation as this expression. */ CompiledExpression createCompiledExpression() const; + /** + * Create a CompiledVectorExpression that allows the expression to be evaluated efficiently + * using the CPU's vector unit. + * + * @param width the width of the vectors to evaluate it on. The allowed values + * depend on the CPU. 4 is always allowed, and 8 is allowed on + * x86 processors with AVX. Call CompiledVectorExpression::getAllowedWidths() + * to query the allowed widths on the current processor. + */ + CompiledVectorExpression createCompiledVectorExpression(int width) const; /** * Create a new ParsedExpression which is identical to this one, except that the names of some * variables have been changed. @@ -113,9 +124,10 @@ public: private: static double evaluate(const ExpressionTreeNode& node, const std::map& variables); static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map& variables); - static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); - static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); - static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable); + static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node, std::map& nodeCache); + static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node, std::map& nodeCache); + static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable, std::map& nodeCache); + static bool isConstant(const ExpressionTreeNode& node); static double getConstantValue(const ExpressionTreeNode& node); static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map& replacements); ExpressionTreeNode rootNode; diff --git a/lib/colvars/lepton/src/CompiledExpression.cpp b/lib/colvars/lepton/src/CompiledExpression.cpp index 1ad348b47d..8a0239b04f 100644 --- a/lib/colvars/lepton/src/CompiledExpression.cpp +++ b/lib/colvars/lepton/src/CompiledExpression.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -84,17 +84,17 @@ CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expr void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vector >& temps) { if (findTempIndex(node, temps) != -1) return; // We have already processed a node identical to this one. - + // Process the child nodes. - + vector args; for (int i = 0; i < node.getChildren().size(); i++) { compileExpression(node.getChildren()[i], temps); args.push_back(findTempIndex(node.getChildren()[i], temps)); } - + // Process this node. - + if (node.getOperation().getId() == Operation::VARIABLE) { variableIndices[node.getOperation().getName()] = (int) workspace.size(); variableNames.insert(node.getOperation().getName()); @@ -108,7 +108,7 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there. else { // If the arguments are sequential, we can just pass a pointer to the first one. - + bool sequential = true; for (int i = 1; i < args.size(); i++) if (args[i] != args[i-1]+1) @@ -148,30 +148,28 @@ void CompiledExpression::setVariableLocations(map& variableLoca variablePointers = variableLocations; #ifdef LEPTON_USE_JIT // Rebuild the JIT code. - + if (workspace.size() > 0) generateJitCode(); -#else +#endif // Make a list of all variables we will need to copy before evaluating the expression. - + variablesToCopy.clear(); for (map::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) { map::iterator pointer = variablePointers.find(iter->first); if (pointer != variablePointers.end()) variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second)); } -#endif } double CompiledExpression::evaluate() const { -#ifdef LEPTON_USE_JIT - return jitCode(); -#else + if (jitCode) + return jitCode(); for (int i = 0; i < variablesToCopy.size(); i++) *variablesToCopy[i].first = *variablesToCopy[i].second; // Loop over the operations and evaluate each one. - + for (int step = 0; step < operation.size(); step++) { const vector& args = arguments[step]; if (args.size() == 1) @@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const { } } return workspace[workspace.size()-1]; -#endif } #ifdef LEPTON_USE_JIT @@ -192,32 +189,78 @@ static double evaluateOperation(Operation* op, double* args) { return op->evaluate(args, dummyVariables); } +void CompiledExpression::findPowerGroups(vector >& groups, vector >& groupPowers, vector& stepGroup) { + // Identify every step that raises an argument to an integer power. + + vector stepPower(operation.size(), 0); + vector stepArg(operation.size(), -1); + for (int step = 0; step < operation.size(); step++) { + Operation& op = *operation[step]; + int power = 0; + if (op.getId() == Operation::SQUARE) + power = 2; + else if (op.getId() == Operation::CUBE) + power = 3; + else if (op.getId() == Operation::POWER_CONSTANT) { + double realPower = dynamic_cast(&op)->getValue(); + if (realPower == (int) realPower) + power = (int) realPower; + } + if (power != 0) { + stepPower[step] = power; + stepArg[step] = arguments[step][0]; + } + } + + // Find groups that operate on the same argument and whose powers have the same sign. + + stepGroup.resize(operation.size(), -1); + for (int i = 0; i < operation.size(); i++) { + if (stepGroup[i] != -1) + continue; + vector group, power; + for (int j = i; j < operation.size(); j++) { + if (stepArg[i] == stepArg[j] && stepPower[i]*stepPower[j] > 0) { + stepGroup[j] = groups.size(); + group.push_back(j); + power.push_back(stepPower[j]); + } + } + groups.push_back(group); + groupPowers.push_back(power); + } +} + +#if defined(__ARM__) || defined(__ARM64__) void CompiledExpression::generateJitCode() { CodeHolder code; - code.init(runtime.getCodeInfo()); - X86Compiler c(&code); - c.addFunc(FuncSignature0()); - vector workspaceVar(workspace.size()); + code.init(runtime.environment()); + a64::Compiler c(&code); + c.addFunc(FuncSignatureT()); + vector workspaceVar(workspace.size()); for (int i = 0; i < (int) workspaceVar.size(); i++) - workspaceVar[i] = c.newXmmSd(); - X86Gp argsPointer = c.newIntPtr(); - c.mov(argsPointer, imm_ptr(&argValues[0])); - + workspaceVar[i] = c.newVecD(); + arm::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + // Load the arguments into variables. - + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { map::iterator index = variableIndices.find(*iter); - X86Gp variablePointer = c.newIntPtr(); - c.mov(variablePointer, imm_ptr(&getVariableReference(index->first))); - c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + arm::Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm(&getVariableReference(index->first))); + c.ldr(workspaceVar[index->second], arm::ptr(variablePointer, 0)); } // Make a list of all constants that will be needed for evaluation. - + vector operationConstantIndex(operation.size(), -1); for (int step = 0; step < (int) operation.size(); step++) { // Find the constant value (if any) used by this operation. - + Operation& op = *operation[step]; double value; if (op.getId() == Operation::CONSTANT) @@ -232,11 +275,17 @@ void CompiledExpression::generateJitCode() { value = 1.0; else if (op.getId() == Operation::DELTA) value = 1.0; + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast(op).getValue(); + else + value = 1.0; + } else continue; - + // See if we already have a variable for this constant. - + for (int i = 0; i < (int) constants.size(); i++) if (value == constants[i]) { operationConstantIndex[step] = i; @@ -247,62 +296,101 @@ void CompiledExpression::generateJitCode() { constants.push_back(value); } } - + // Load constants into variables. - - vector constantVar(constants.size()); + + vector constantVar(constants.size()); if (constants.size() > 0) { - X86Gp constantsPointer = c.newIntPtr(); - c.mov(constantsPointer, imm_ptr(&constants[0])); + arm::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); for (int i = 0; i < (int) constants.size(); i++) { - constantVar[i] = c.newXmmSd(); - c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + constantVar[i] = c.newVecD(); + c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i)); } } // Evaluate the operations. + vector hasComputedPower(operation.size(), false); for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + arm::Vec multiplier = c.newVecD(); + if (powers[0] > 0) + c.fmov(multiplier, workspaceVar[arguments[step][0]]); + else { + c.fdiv(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i]%2 == 1) { + if (!hasAssigned[i]) + c.fmov(workspaceVar[target[group[i]]], multiplier); + else + c.fmul(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.fmul(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + Operation& op = *operation[step]; vector args = arguments[step]; if (args.size() == 1) { // One or more sequential arguments. Fill out the list. - + for (int i = 1; i < op.getNumArguments(); i++) args.push_back(args[0]+i); } - + // Generate instructions to execute this operation. - + switch (op.getId()) { case Operation::CONSTANT: - c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fmov(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::ADD: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::SUBTRACT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fsub(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::MULTIPLY: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::DIVIDE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fdiv(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::POWER: generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); break; case Operation::NEGATE: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fneg(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::SQRT: - c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fsqrt(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::EXP: generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); @@ -341,56 +429,63 @@ void CompiledExpression::generateJitCode() { generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); break; case Operation::STEP: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 - c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.cmge(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::DELTA: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 - c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.cmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::SQUARE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); break; case Operation::CUBE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::RECIPROCAL: - c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); - c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fdiv(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); break; case Operation::ADD_CONSTANT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); break; case Operation::MULTIPLY_CONSTANT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); + break; + case Operation::MIN: + c.fmin(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.fmax(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::ABS: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs); + c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::FLOOR: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor); + c.frintm(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::CEIL: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil); + c.frintp(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SELECT: + c.fcmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]); break; default: // Just invoke evaluateOperation(). - + for (int i = 0; i < (int) args.size(); i++) - c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) evaluateOperation)); - CCFuncCall* call = c.call(fn, FuncSignature2()); - call->setArg(0, imm_ptr(&op)); - call->setArg(1, imm_ptr(&argValues[0])); - call->setRet(0, workspaceVar[target[step]]); + c.str(workspaceVar[args[i]], arm::ptr(argsPointer, 8*i)); + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, workspaceVar[target[step]]); } } c.ret(workspaceVar[workspace.size()-1]); @@ -399,20 +494,319 @@ void CompiledExpression::generateJitCode() { runtime.add(&jitCode, &code); } -void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) { - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) function)); - CCFuncCall* call = c.call(fn, FuncSignature1()); - call->setArg(0, arg); - call->setRet(0, dest); +void CompiledExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, double (*function)(double)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg); + invoke->setRet(0, dest); } -void CompiledExpression::generateTwoArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg1, X86Xmm& arg2, double (*function)(double, double)) { - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) function)); - CCFuncCall* call = c.call(fn, FuncSignature2()); - call->setArg(0, arg1); - call->setArg(1, arg2); - call->setRet(0, dest); +void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, double (*function)(double, double)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg1); + invoke->setArg(1, arg2); + invoke->setRet(0, dest); +} +#else +void CompiledExpression::generateJitCode() { + const CpuInfo& cpu = CpuInfo::host(); + if (!cpu.hasFeature(CpuFeatures::X86::kAVX)) + return; + CodeHolder code; + code.init(runtime.environment()); + x86::Compiler c(&code); + FuncNode* funcNode = c.addFunc(FuncSignatureT()); + funcNode->frame().setAvxEnabled(); + vector workspaceVar(workspace.size()); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newXmmSd(); + x86::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + x86::Gp variablePointer = c.newIntPtr(); + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + c.mov(variablePointer, imm(&getVariableReference(index->first))); + c.vmovsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::ABS) { + long long mask = 0x7FFFFFFFFFFFFFFF; + value = *reinterpret_cast(&mask); + } + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast(op).getValue(); + else + value = 1.0; + } + else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + x86::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newXmmSd(); + c.vmovsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + x86::Xmm multiplier = c.newXmmSd(); + if (powers[0] > 0) + c.vmovsd(multiplier, workspaceVar[arguments[step][0]], workspaceVar[arguments[step][0]]); + else { + c.vdivsd(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i]%2 == 1) { + if (!hasAssigned[i]) + c.vmovsd(workspaceVar[target[group[i]]], multiplier, multiplier); + else + c.vmulsd(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.vmulsd(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0]+i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.vmovsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.vsubsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.vdivsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); + break; + case Operation::NEGATE: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vsubsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.vsqrtsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); + break; + case Operation::STEP: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.vmulsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.vdivsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::MULTIPLY_CONSTANT: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); + break; + case Operation::MIN: + c.vminsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.vmaxsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::ABS: + c.vandpd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::FLOOR: + c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(1)); + break; + case Operation::CEIL: + c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(2)); + break; + case Operation::SELECT: + { + x86::Xmm mask = c.newXmmSd(); + c.vxorps(mask, mask, mask); + c.vcmpsd(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask); + break; + } + default: + // Just invoke evaluateOperation(). + + for (int i = 0; i < (int) args.size(); i++) + c.vmovsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, workspaceVar[target[step]]); + } + } + c.ret(workspaceVar[workspace.size()-1]); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledExpression::generateSingleArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg, double (*function)(double)) { + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg); + invoke->setRet(0, dest); +} + +void CompiledExpression::generateTwoArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg1, x86::Xmm& arg2, double (*function)(double, double)) { + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg1); + invoke->setArg(1, arg2); + invoke->setRet(0, dest); } #endif +#endif diff --git a/lib/colvars/lepton/src/CompiledVectorExpression.cpp b/lib/colvars/lepton/src/CompiledVectorExpression.cpp new file mode 100644 index 0000000000..7c01a986bb --- /dev/null +++ b/lib/colvars/lepton/src/CompiledVectorExpression.cpp @@ -0,0 +1,933 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/CompiledVectorExpression.h" +#include "lepton/Operation.h" +#include "lepton/ParsedExpression.h" +#include +#include + +using namespace Lepton; +using namespace std; +#ifdef LEPTON_USE_JIT +using namespace asmjit; +#endif + +CompiledVectorExpression::CompiledVectorExpression() : jitCode(NULL) { +} + +CompiledVectorExpression::CompiledVectorExpression(const ParsedExpression& expression, int width) : jitCode(NULL), width(width) { + const vector allowedWidths = getAllowedWidths(); + if (find(allowedWidths.begin(), allowedWidths.end(), width) == allowedWidths.end()) + throw Exception("Unsupported width for vector expression: "+to_string(width)); + ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized. + vector > temps; + int workspaceSize = 0; + compileExpression(expr.getRootNode(), temps, workspaceSize); + workspace.resize(workspaceSize*width); + int maxArguments = 1; + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i]->getNumArguments() > maxArguments) + maxArguments = operation[i]->getNumArguments(); + argValues.resize(maxArguments); +#ifdef LEPTON_USE_JIT + generateJitCode(); +#endif +} + +CompiledVectorExpression::~CompiledVectorExpression() { + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i] != NULL) + delete operation[i]; +} + +CompiledVectorExpression::CompiledVectorExpression(const CompiledVectorExpression& expression) : jitCode(NULL) { + *this = expression; +} + +CompiledVectorExpression& CompiledVectorExpression::operator=(const CompiledVectorExpression& expression) { + arguments = expression.arguments; + width = expression.width; + target = expression.target; + variableIndices = expression.variableIndices; + variableNames = expression.variableNames; + workspace.resize(expression.workspace.size()); + argValues.resize(expression.argValues.size()); + operation.resize(expression.operation.size()); + for (int i = 0; i < (int) operation.size(); i++) + operation[i] = expression.operation[i]->clone(); + setVariableLocations(variablePointers); + return *this; +} + +const vector& CompiledVectorExpression::getAllowedWidths() { + static vector widths; + if (widths.size() == 0) { + widths.push_back(4); +#ifdef LEPTON_USE_JIT + const CpuInfo& cpu = CpuInfo::host(); + if (cpu.hasFeature(CpuFeatures::X86::kAVX)) + widths.push_back(8); +#endif + } + return widths; +} + +void CompiledVectorExpression::compileExpression(const ExpressionTreeNode& node, vector >& temps, int& workspaceSize) { + if (findTempIndex(node, temps) != -1) + return; // We have already processed a node identical to this one. + + // Process the child nodes. + + vector args; + for (int i = 0; i < node.getChildren().size(); i++) { + compileExpression(node.getChildren()[i], temps, workspaceSize); + args.push_back(findTempIndex(node.getChildren()[i], temps)); + } + + // Process this node. + + if (node.getOperation().getId() == Operation::VARIABLE) { + variableIndices[node.getOperation().getName()] = workspaceSize; + variableNames.insert(node.getOperation().getName()); + } + else { + int stepIndex = (int) arguments.size(); + arguments.push_back(vector()); + target.push_back(workspaceSize); + operation.push_back(node.getOperation().clone()); + if (args.size() == 0) + arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there. + else { + // If the arguments are sequential, we can just pass a pointer to the first one. + + bool sequential = true; + for (int i = 1; i < args.size(); i++) + if (args[i] != args[i - 1] + 1) + sequential = false; + if (sequential) + arguments[stepIndex].push_back(args[0]); + else + arguments[stepIndex] = args; + } + } + temps.push_back(make_pair(node, workspaceSize)); + workspaceSize++; +} + +int CompiledVectorExpression::findTempIndex(const ExpressionTreeNode& node, vector >& temps) { + for (int i = 0; i < (int) temps.size(); i++) + if (temps[i].first == node) + return i; + return -1; +} + +int CompiledVectorExpression::getWidth() const { + return width; +} + +const set& CompiledVectorExpression::getVariables() const { + return variableNames; +} + +float* CompiledVectorExpression::getVariablePointer(const string& name) { + map::iterator pointer = variablePointers.find(name); + if (pointer != variablePointers.end()) + return pointer->second; + map::iterator index = variableIndices.find(name); + if (index == variableIndices.end()) + throw Exception("getVariableReference: Unknown variable '" + name + "'"); + return &workspace[index->second*width]; +} + +void CompiledVectorExpression::setVariableLocations(map& variableLocations) { + variablePointers = variableLocations; +#ifdef LEPTON_USE_JIT + // Rebuild the JIT code. + + if (workspace.size() > 0) + generateJitCode(); +#endif + // Make a list of all variables we will need to copy before evaluating the expression. + + variablesToCopy.clear(); + for (map::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) { + map::iterator pointer = variablePointers.find(iter->first); + if (pointer != variablePointers.end()) + variablesToCopy.push_back(make_pair(&workspace[iter->second*width], pointer->second)); + } +} + +const float* CompiledVectorExpression::evaluate() const { + if (jitCode) { + jitCode(); + return &workspace[workspace.size()-width]; + } + for (int i = 0; i < variablesToCopy.size(); i++) + for (int j = 0; j < width; j++) + variablesToCopy[i].first[j] = variablesToCopy[i].second[j]; + + // Loop over the operations and evaluate each one. + + for (int step = 0; step < operation.size(); step++) { + const vector& args = arguments[step]; + if (args.size() == 1) { + for (int j = 0; j < width; j++) { + for (int i = 0; i < operation[step]->getNumArguments(); i++) + argValues[i] = workspace[(args[0]+i)*width+j]; + workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } else { + for (int j = 0; j < width; j++) { + for (int i = 0; i < args.size(); i++) + argValues[i] = workspace[args[i]*width+j]; + workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } + } + return &workspace[workspace.size()-width]; +} + +#ifdef LEPTON_USE_JIT + +static double evaluateOperation(Operation* op, double* args) { + static map dummyVariables; + return op->evaluate(args, dummyVariables); +} + +void CompiledVectorExpression::findPowerGroups(vector >& groups, vector >& groupPowers, vector& stepGroup) { + // Identify every step that raises an argument to an integer power. + + vector stepPower(operation.size(), 0); + vector stepArg(operation.size(), -1); + for (int step = 0; step < operation.size(); step++) { + Operation& op = *operation[step]; + int power = 0; + if (op.getId() == Operation::SQUARE) + power = 2; + else if (op.getId() == Operation::CUBE) + power = 3; + else if (op.getId() == Operation::POWER_CONSTANT) { + double realPower = dynamic_cast (&op)->getValue(); + if (realPower == (int) realPower) + power = (int) realPower; + } + if (power != 0) { + stepPower[step] = power; + stepArg[step] = arguments[step][0]; + } + } + + // Find groups that operate on the same argument and whose powers have the same sign. + + stepGroup.resize(operation.size(), -1); + for (int i = 0; i < operation.size(); i++) { + if (stepGroup[i] != -1) + continue; + vector group, power; + for (int j = i; j < operation.size(); j++) { + if (stepArg[i] == stepArg[j] && stepPower[i] * stepPower[j] > 0) { + stepGroup[j] = groups.size(); + group.push_back(j); + power.push_back(stepPower[j]); + } + } + groups.push_back(group); + groupPowers.push_back(power); + } +} + +#if defined(__ARM__) || defined(__ARM64__) + +void CompiledVectorExpression::generateJitCode() { + CodeHolder code; + code.init(runtime.environment()); + a64::Compiler c(&code); + c.addFunc(FuncSignatureT()); + vector workspaceVar(workspace.size()/width); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newVecQ(); + arm::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + arm::Gp variablePointer = c.newIntPtr(); + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + c.mov(variablePointer, imm(getVariablePointer(index->first))); + c.ldr(workspaceVar[index->second].s4(), arm::ptr(variablePointer, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + float value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast (op).getValue(); + else + value = 1.0; + } else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + arm::Gp constantsPointer = c.newIntPtr(); + for (int i = 0; i < (int) constants.size(); i++) { + c.mov(constantsPointer, imm(&constants[i])); + constantVar[i] = c.newVecQ(); + c.ld1r(constantVar[i].s4(), arm::ptr(constantsPointer)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + arm::Vec argReg = c.newVecS(); + arm::Vec doubleArgReg = c.newVecD(); + arm::Vec doubleResultReg = c.newVecD(); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + arm::Vec multiplier = c.newVecQ(); + if (powers[0] > 0) + c.mov(multiplier.s4(), workspaceVar[arguments[step][0]].s4()); + else { + c.fdiv(multiplier.s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[arguments[step][0]].s4()); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i] % 2 == 1) { + if (!hasAssigned[i]) + c.mov(workspaceVar[target[group[i]]].s4(), multiplier.s4()); + else + c.fmul(workspaceVar[target[group[i]]].s4(), workspaceVar[target[group[i]]].s4(), multiplier.s4()); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.fmul(multiplier.s4(), multiplier.s4(), multiplier.s4()); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0] + i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.mov(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::ADD: + c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::SUBTRACT: + c.fsub(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::MULTIPLY: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::DIVIDE: + c.fdiv(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf); + break; + case Operation::NEGATE: + c.fneg(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::SQRT: + c.fsqrt(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf); + break; + case Operation::STEP: + c.cmge(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.cmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::CUBE: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4()); + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::RECIPROCAL: + c.fdiv(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::ADD_CONSTANT: + c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::MULTIPLY_CONSTANT: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf); + break; + case Operation::MIN: + c.fmin(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::MAX: + c.fmax(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::ABS: + c.fabs(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::FLOOR: + c.frintm(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::CEIL: + c.frintp(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::SELECT: + c.fcmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]); + break; + default: + // Just invoke evaluateOperation(). + for (int element = 0; element < width; element++) { + for (int i = 0; i < (int) args.size(); i++) { + c.ins(argReg.s(0), workspaceVar[args[i]].s(element)); + c.fcvt(doubleArgReg, argReg); + c.str(doubleArgReg, arm::ptr(argsPointer, 8*i)); + } + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, doubleResultReg); + c.fcvt(argReg, doubleResultReg); + c.ins(workspaceVar[target[step]].s(element), argReg.s(0)); + } + } + } + arm::Gp resultPointer = c.newIntPtr(); + c.mov(resultPointer, imm(&workspace[workspace.size()-width])); + c.str(workspaceVar.back().s4(), arm::ptr(resultPointer, 0)); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledVectorExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, float (*function)(float)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + arm::Vec a = c.newVecS(); + arm::Vec d = c.newVecS(); + for (int element = 0; element < width; element++) { + c.ins(a.s(0), arg.s(element)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, a); + invoke->setRet(0, d); + c.ins(dest.s(element), d.s(0)); + } +} + +void CompiledVectorExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, float (*function)(float, float)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + arm::Vec a1 = c.newVecS(); + arm::Vec a2 = c.newVecS(); + arm::Vec d = c.newVecS(); + for (int element = 0; element < width; element++) { + c.ins(a1.s(0), arg1.s(element)); + c.ins(a2.s(0), arg2.s(element)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, a1); + invoke->setArg(1, a2); + invoke->setRet(0, d); + c.ins(dest.s(element), d.s(0)); + } +} +#else + +void CompiledVectorExpression::generateJitCode() { + const CpuInfo& cpu = CpuInfo::host(); + if (!cpu.hasFeature(CpuFeatures::X86::kAVX)) + return; + CodeHolder code; + code.init(runtime.environment()); + x86::Compiler c(&code); + FuncNode* funcNode = c.addFunc(FuncSignatureT()); + funcNode->frame().setAvxEnabled(); + vector workspaceVar(workspace.size()/width); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newYmmPs(); + x86::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + x86::Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm(getVariablePointer(index->first))); + if (width == 4) + c.vmovdqu(workspaceVar[index->second].xmm(), x86::ptr(variablePointer, 0, 0)); + else + c.vmovdqu(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::ABS) { + int mask = 0x7FFFFFFF; + value = *reinterpret_cast(&mask); + } + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast (op).getValue(); + else + value = 1.0; + } else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + x86::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newYmmPs(); + c.vbroadcastss(constantVar[i], x86::ptr(constantsPointer, 4*i, 0)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + x86::Ymm argReg = c.newYmm(); + x86::Ymm doubleArgReg = c.newYmm(); + x86::Ymm doubleResultReg = c.newYmm(); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + x86::Ymm multiplier = c.newYmmPs(); + if (powers[0] > 0) + c.vmovdqu(multiplier, workspaceVar[arguments[step][0]]); + else { + c.vdivps(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i] % 2 == 1) { + if (!hasAssigned[i]) + c.vmovdqu(workspaceVar[target[group[i]]], multiplier); + else + c.vmulps(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.vmulps(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0] + i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.vmovdqu(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.vsubps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.vdivps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf); + break; + case Operation::NEGATE: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vsubps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.vsqrtps(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf); + break; + case Operation::STEP: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.vmulps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.vdivps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::MULTIPLY_CONSTANT: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf); + break; + case Operation::MIN: + c.vminps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.vmaxps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::ABS: + c.vandps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::FLOOR: + c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(1)); + break; + case Operation::CEIL: + c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(2)); + break; + case Operation::SELECT: + { + x86::Ymm mask = c.newYmmPs(); + c.vxorps(mask, mask, mask); + c.vcmpps(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask); + break; + } + default: + // Just invoke evaluateOperation(). + + for (int element = 0; element < width; element++) { + for (int i = 0; i < (int) args.size(); i++) { + if (element < 4) + c.vshufps(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(element)); + else { + c.vperm2f128(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(1)); + c.vshufps(argReg, argReg, argReg, imm(element-4)); + } + c.vcvtss2sd(doubleArgReg.xmm(), doubleArgReg.xmm(), argReg.xmm()); + c.vmovsd(x86::ptr(argsPointer, 8*i, 0), doubleArgReg.xmm()); + } + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, doubleResultReg); + c.vcvtsd2ss(argReg.xmm(), argReg.xmm(), doubleResultReg.xmm()); + if (element > 3) + c.vperm2f128(argReg, argReg, argReg, imm(0)); + if (element != 0) + c.vshufps(argReg, argReg, argReg, imm(0)); + c.vblendps(workspaceVar[target[step]], workspaceVar[target[step]], argReg, 1<()); + invoke->setArg(0, a); + invoke->setRet(0, d); + if (element > 3) + c.vperm2f128(d, d, d, imm(0)); + if (element != 0) + c.vshufps(d, d, d, imm(0)); + c.vblendps(dest, dest, d, 1<()); + invoke->setArg(0, a1); + invoke->setArg(1, a2); + invoke->setRet(0, d); + if (element > 3) + c.vperm2f128(d, d, d, imm(0)); + if (element != 0) + c.vshufps(d, d, d, imm(0)); + c.vblendps(dest, dest, d, 1< using namespace Lepton; using namespace std; @@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? NULL : node.operation->clone()), children(node.getChildren()) { } +ExpressionTreeNode::ExpressionTreeNode(ExpressionTreeNode&& node) : operation(node.operation), children(move(node.children)) { + node.operation = NULL; + node.children.clear(); +} + ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) { } @@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node return *this; } +ExpressionTreeNode& ExpressionTreeNode::operator=(ExpressionTreeNode&& node) { + if (operation != NULL) + delete operation; + operation = node.operation; + children = move(node.children); + node.operation = NULL; + node.children.clear(); + return *this; +} + const Operation& ExpressionTreeNode::getOperation() const { return *operation; } @@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const { const vector& ExpressionTreeNode::getChildren() const { return children; } + +void ExpressionTreeNode::assignTags(vector& examples) const { + // Assign tag values to all nodes in a tree, such that two nodes have the same + // tag if and only if they (and all their children) are equal. This is used to + // optimize other operations. + + int numTags = examples.size(); + for (const ExpressionTreeNode& child : getChildren()) + child.assignTags(examples); + if (numTags == examples.size()) { + // All the children matched existing tags, so possibly this node does too. + + for (int i = 0; i < examples.size(); i++) { + const ExpressionTreeNode& example = *examples[i]; + bool matches = (getChildren().size() == example.getChildren().size() && getOperation() == example.getOperation()); + for (int j = 0; matches && j < getChildren().size(); j++) + if (getChildren()[j].tag != example.getChildren()[j].tag) + matches = false; + if (matches) { + tag = i; + return; + } + } + } + + // This node does not match any previous node, so assign a new tag. + + tag = examples.size(); + examples.push_back(this); +} diff --git a/lib/colvars/lepton/src/MSVC_erfc.h b/lib/colvars/lepton/src/MSVC_erfc.h index b1cd87a289..2c6b619e89 100644 --- a/lib/colvars/lepton/src/MSVC_erfc.h +++ b/lib/colvars/lepton/src/MSVC_erfc.h @@ -3,7 +3,7 @@ /* * Up to version 11 (VC++ 2012), Microsoft does not support the - * standard C99 erf() and erfc() functions so we have to fake them here. + * standard C99 erf() and erfc() functions so we have to fake them here. * These were added in version 12 (VC++ 2013), which sets _MSC_VER=1800 * (VC11 has _MSC_VER=1700). */ @@ -15,7 +15,7 @@ #endif #if defined(_MSC_VER) -#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12 +#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12 /*************************** * erf.cpp * author: Steve Strand diff --git a/lib/colvars/lepton/src/Operation.cpp b/lib/colvars/lepton/src/Operation.cpp index 78741c4814..b5a958b2f7 100644 --- a/lib/colvars/lepton/src/Operation.cpp +++ b/lib/colvars/lepton/src/Operation.cpp @@ -7,7 +7,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009-2019 Stanford University and the Authors. * + * Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -37,6 +37,12 @@ using namespace Lepton; using namespace std; +static bool isZero(const ExpressionTreeNode& node) { + if (node.getOperation().getId() != Operation::CONSTANT) + return false; + return dynamic_cast(node.getOperation()).getValue() == 0.0; +} + double Operation::Erf::evaluate(double* args, const map& variables) const { return erf(args[0]); } @@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { if (function->getNumArguments() == 0) return ExpressionTreeNode(new Operation::Constant(0.0)); - ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]); - for (int i = 1; i < getNumArguments(); i++) { - result = ExpressionTreeNode(new Operation::Add(), - result, - ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); + ExpressionTreeNode result; + bool foundTerm = false; + for (int i = 0; i < getNumArguments(); i++) { + if (!isZero(childDerivs[i])) { + if (foundTerm) + result = ExpressionTreeNode(new Operation::Add(), + result, + ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); + else { + result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]); + foundTerm = true; + } + } } - return result; + if (foundTerm) + return result; + return ExpressionTreeNode(new Operation::Constant(0.0)); } ExpressionTreeNode Operation::Add::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return childDerivs[1]; + if (isZero(childDerivs[1])) + return childDerivs[0]; return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]); } ExpressionTreeNode Operation::Subtract::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + return ExpressionTreeNode(new Operation::Negate(), childDerivs[1]); + } + if (isZero(childDerivs[1])) + return childDerivs[0]; return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]); } ExpressionTreeNode Operation::Multiply::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + return ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]); + } + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]); return ExpressionTreeNode(new Operation::Add(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0])); } ExpressionTreeNode Operation::Divide::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { - return ExpressionTreeNode(new Operation::Divide(), - ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode subexp; + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + subexp = ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])); + } + else if (isZero(childDerivs[1])) + subexp = ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]); + else + subexp = ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), - ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])), - ExpressionTreeNode(new Operation::Square(), children[1])); + ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])); + return ExpressionTreeNode(new Operation::Divide(), subexp, ExpressionTreeNode(new Operation::Square(), children[1])); } ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { @@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]); } ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(0.5), ExpressionTreeNode(new Operation::Reciprocal(), @@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Exp(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Log::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Sin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cos(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Sin(), children[0])), @@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sec(), children[0]), @@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), @@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Square(), ExpressionTreeNode(new Operation::Sec(), children[0])), @@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Square(), @@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Sqrt(), @@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::AddConstant(1.0), @@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cosh(), children[0]), @@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sinh(), children[0]), @@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Constant(1.0)), @@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))), @@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))), @@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(2.0), children[0]), @@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(3.0), ExpressionTreeNode(new Operation::Square(), children[0])), @@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::MultiplyConstant(value), childDerivs[0]); } ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(value), ExpressionTreeNode(new Operation::PowerConstant(value-1), @@ -305,22 +393,18 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); - return ExpressionTreeNode(new Operation::Subtract(), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], - ExpressionTreeNode(new Operation::AddConstant(-1), step))); + return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[1], childDerivs[0]}); } ExpressionTreeNode Operation::Max::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); - return ExpressionTreeNode(new Operation::Subtract(), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], - ExpressionTreeNode(new Operation::AddConstant(-1), step))); + return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[0], childDerivs[1]}); } ExpressionTreeNode Operation::Abs::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); ExpressionTreeNode step(new Operation::Step(), children[0]); return ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], @@ -337,9 +421,5 @@ ExpressionTreeNode Operation::Ceil::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { - vector derivChildren; - derivChildren.push_back(children[0]); - derivChildren.push_back(childDerivs[1]); - derivChildren.push_back(childDerivs[2]); - return ExpressionTreeNode(new Operation::Select(), derivChildren); + return ExpressionTreeNode(new Operation::Select(), {children[0], childDerivs[1], childDerivs[2]}); } diff --git a/lib/colvars/lepton/src/ParsedExpression.cpp b/lib/colvars/lepton/src/ParsedExpression.cpp index fd3b091d3c..f3f18fccd2 100644 --- a/lib/colvars/lepton/src/ParsedExpression.cpp +++ b/lib/colvars/lepton/src/ParsedExpression.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009 Stanford University and the Authors. * + * Portions copyright (c) 2009-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -31,6 +31,7 @@ #include "lepton/ParsedExpression.h" #include "lepton/CompiledExpression.h" +#include "lepton/CompiledVectorExpression.h" #include "lepton/ExpressionProgram.h" #include "lepton/Operation.h" #include @@ -68,9 +69,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map examples; + result.assignTags(examples); + map nodeCache; + result = precalculateConstantSubexpressions(result, nodeCache); while (true) { - ExpressionTreeNode simplified = substituteSimplerExpression(result); + examples.clear(); + result.assignTags(examples); + nodeCache.clear(); + ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache); if (simplified == result) break; result = simplified; @@ -80,9 +88,15 @@ ParsedExpression ParsedExpression::optimize() const { ParsedExpression ParsedExpression::optimize(const map& variables) const { ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables); - result = precalculateConstantSubexpressions(result); + vector examples; + result.assignTags(examples); + map nodeCache; + result = precalculateConstantSubexpressions(result, nodeCache); while (true) { - ExpressionTreeNode simplified = substituteSimplerExpression(result); + examples.clear(); + result.assignTags(examples); + nodeCache.clear(); + ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache); if (simplified == result) break; result = simplified; @@ -104,36 +118,67 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo return ExpressionTreeNode(node.getOperation().clone(), children); } -ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) { +ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node, map& nodeCache) { + auto cached = nodeCache.find(node.tag); + if (cached != nodeCache.end()) + return cached->second; vector children(node.getChildren().size()); for (int i = 0; i < (int) children.size(); i++) - children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); + children[i] = precalculateConstantSubexpressions(node.getChildren()[i], nodeCache); ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children); - if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) + if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) { + nodeCache[node.tag] = result; return result; + } for (int i = 0; i < (int) children.size(); i++) - if (children[i].getOperation().getId() != Operation::CONSTANT) + if (children[i].getOperation().getId() != Operation::CONSTANT) { + nodeCache[node.tag] = result; return result; - return ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); + } + result = ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); + nodeCache[node.tag] = result; + return result; } -ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) { +ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node, map& nodeCache) { vector children(node.getChildren().size()); - for (int i = 0; i < (int) children.size(); i++) - children[i] = substituteSimplerExpression(node.getChildren()[i]); + for (int i = 0; i < (int) children.size(); i++) { + const ExpressionTreeNode& child = node.getChildren()[i]; + auto cached = nodeCache.find(child.tag); + if (cached == nodeCache.end()) { + children[i] = substituteSimplerExpression(child, nodeCache); + nodeCache[child.tag] = children[i]; + } + else + children[i] = cached->second; + } + + // Collect some info on constant expressions in children + bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant? + bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant? + double first, second; // if yes, value of first and second child + if (first_const) + first = getConstantValue(children[0]); + if (second_const) + second = getConstantValue(children[1]); + switch (node.getOperation().getId()) { case Operation::ADD: { - double first = getConstantValue(children[0]); - double second = getConstantValue(children[1]); - if (first == 0.0) // Add 0 - return children[1]; - if (second == 0.0) // Add 0 - return children[0]; - if (first == first) // Add a constant - return ExpressionTreeNode(new Operation::AddConstant(first), children[1]); - if (second == second) // Add a constant - return ExpressionTreeNode(new Operation::AddConstant(second), children[0]); + if (first_const) { + if (first == 0.0) { // Add 0 + return children[1]; + } else { // Add a constant + return ExpressionTreeNode(new Operation::AddConstant(first), children[1]); + } + } + if (second_const) { + if (second == 0.0) { // Add 0 + return children[0]; + } else { // Add a constant + return ExpressionTreeNode(new Operation::AddConstant(second), children[0]); + } + } if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]); if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a @@ -144,34 +189,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio { if (children[0] == children[1]) return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0 - double first = getConstantValue(children[0]); - if (first == 0.0) // Subtract from 0 - return ExpressionTreeNode(new Operation::Negate(), children[1]); - double second = getConstantValue(children[1]); - if (second == 0.0) // Subtract 0 - return children[0]; - if (second == second) // Subtract a constant - return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]); + if (first_const) { + if (first == 0.0) // Subtract from 0 + return ExpressionTreeNode(new Operation::Negate(), children[1]); + } + if (second_const) { + if (second == 0.0) { // Subtract 0 + return children[0]; + } else { // Subtract a constant + return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]); + } + } if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]); break; } case Operation::MULTIPLY: - { - double first = getConstantValue(children[0]); - double second = getConstantValue(children[1]); - if (first == 0.0 || second == 0.0) // Multiply by 0 + { + if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0 return ExpressionTreeNode(new Operation::Constant(0.0)); - if (first == 1.0) // Multiply by 1 + if (first_const && first == 1.0) // Multiply by 1 return children[1]; - if (second == 1.0) // Multiply by 1 + if (second_const && second == 1.0) // Multiply by 1 return children[0]; - if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant + if (first_const) { // Multiply by a constant if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast(&children[1].getOperation())->getValue()), children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]); } - if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant + if (second_const) { // Multiply by a constant if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]); @@ -202,18 +248,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio { if (children[0] == children[1]) return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0 - double numerator = getConstantValue(children[0]); - if (numerator == 0.0) // 0 divided by something + if (first_const && first == 0.0) // 0 divided by something return ExpressionTreeNode(new Operation::Constant(0.0)); - if (numerator == 1.0) // 1 divided by something + if (first_const && first == 1.0) // 1 divided by something return ExpressionTreeNode(new Operation::Reciprocal(), children[1]); - double denominator = getConstantValue(children[1]); - if (denominator == 1.0) // Divide by 1 + if (second_const && second == 1.0) // Divide by 1 return children[0]; - if (children[1].getOperation().getId() == Operation::CONSTANT) { + if (second_const) { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply - return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]); - return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply + return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&children[0].getOperation())->getValue()/second), children[0].getChildren()[0]); + return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/second), children[0]); // Replace a divide with a multiply } if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]); @@ -229,34 +273,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio } case Operation::POWER: { - double base = getConstantValue(children[0]); - if (base == 0.0) // 0 to any power is 0 + if (first_const && first == 0.0) // 0 to any power is 0 return ExpressionTreeNode(new Operation::Constant(0.0)); - if (base == 1.0) // 1 to any power is 1 + if (first_const && first == 1.0) // 1 to any power is 1 return ExpressionTreeNode(new Operation::Constant(1.0)); - double exponent = getConstantValue(children[1]); - if (exponent == 0.0) // x^0 = 1 - return ExpressionTreeNode(new Operation::Constant(1.0)); - if (exponent == 1.0) // x^1 = x - return children[0]; - if (exponent == -1.0) // x^-1 = recip(x) - return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); - if (exponent == 2.0) // x^2 = square(x) - return ExpressionTreeNode(new Operation::Square(), children[0]); - if (exponent == 3.0) // x^3 = cube(x) - return ExpressionTreeNode(new Operation::Cube(), children[0]); - if (exponent == 0.5) // x^0.5 = sqrt(x) - return ExpressionTreeNode(new Operation::Sqrt(), children[0]); - if (exponent == exponent) // Constant power - return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]); + if (second_const) { // Constant exponent + if (second == 0.0) // x^0 = 1 + return ExpressionTreeNode(new Operation::Constant(1.0)); + if (second == 1.0) // x^1 = x + return children[0]; + if (second == -1.0) // x^-1 = recip(x) + return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); + if (second == 2.0) // x^2 = square(x) + return ExpressionTreeNode(new Operation::Square(), children[0]); + if (second == 3.0) // x^3 = cube(x) + return ExpressionTreeNode(new Operation::Cube(), children[0]); + if (second == 0.5) // x^0.5 = sqrt(x) + return ExpressionTreeNode(new Operation::Sqrt(), children[0]); + // Constant power + return ExpressionTreeNode(new Operation::PowerConstant(second), children[0]); + } break; } case Operation::NEGATE: { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); - if (children[0].getOperation().getId() == Operation::CONSTANT) // Negate a constant - return ExpressionTreeNode(new Operation::Constant(-getConstantValue(children[0]))); + if (first_const) // Negate a constant + return ExpressionTreeNode(new Operation::Constant(-first)); if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel return children[0].getChildren()[0]; break; @@ -265,7 +309,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&node.getOperation())->getValue()*dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); - if (children[0].getOperation().getId() == Operation::CONSTANT) // Multiply two constants + if (first_const) // Multiply two constants return ExpressionTreeNode(new Operation::Constant(dynamic_cast(&node.getOperation())->getValue()*getConstantValue(children[0]))); if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&node.getOperation())->getValue()), children[0].getChildren()[0]); @@ -293,20 +337,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio } ParsedExpression ParsedExpression::differentiate(const string& variable) const { - return differentiate(getRootNode(), variable); + vector examples; + getRootNode().assignTags(examples); + map nodeCache; + return differentiate(getRootNode(), variable, nodeCache); } -ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) { +ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable, map& nodeCache) { + auto cached = nodeCache.find(node.tag); + if (cached != nodeCache.end()) + return cached->second; vector childDerivs(node.getChildren().size()); for (int i = 0; i < (int) childDerivs.size(); i++) - childDerivs[i] = differentiate(node.getChildren()[i], variable); - return node.getOperation().differentiate(node.getChildren(),childDerivs, variable); + childDerivs[i] = differentiate(node.getChildren()[i], variable, nodeCache); + ExpressionTreeNode result = node.getOperation().differentiate(node.getChildren(), childDerivs, variable); + nodeCache[node.tag] = result; + return result; +} + +bool ParsedExpression::isConstant(const ExpressionTreeNode& node) { + return (node.getOperation().getId() == Operation::CONSTANT); } double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) { - if (node.getOperation().getId() == Operation::CONSTANT) - return dynamic_cast(node.getOperation()).getValue(); - return numeric_limits::quiet_NaN(); + if (node.getOperation().getId() != Operation::CONSTANT) { + throw Exception("getConstantValue called on a non-constant ExpressionNode"); + } + return dynamic_cast(node.getOperation()).getValue(); } ExpressionProgram ParsedExpression::createProgram() const { @@ -317,6 +374,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const { return CompiledExpression(*this); } +CompiledVectorExpression ParsedExpression::createCompiledVectorExpression(int width) const { + return CompiledVectorExpression(*this, width); +} + ParsedExpression ParsedExpression::renameVariables(const map& replacements) const { return ParsedExpression(renameNodeVariables(getRootNode(), replacements)); } diff --git a/lib/colvars/lepton/src/Parser.cpp b/lib/colvars/lepton/src/Parser.cpp index e284add258..47ebac464a 100644 --- a/lib/colvars/lepton/src/Parser.cpp +++ b/lib/colvars/lepton/src/Parser.cpp @@ -66,7 +66,7 @@ private: string Parser::trim(const string& expression) { // Remove leading and trailing spaces. - + int start, end; for (start = 0; start < (int) expression.size() && isspace(expression[start]); start++) ;