diff --git a/doc/src/PDF/colvars-refman-lammps.pdf b/doc/src/PDF/colvars-refman-lammps.pdf index 71eaee867a..c87ccbad75 100644 Binary files a/doc/src/PDF/colvars-refman-lammps.pdf and b/doc/src/PDF/colvars-refman-lammps.pdf differ diff --git a/lib/colvars/Makefile.common b/lib/colvars/Makefile.common index a920c24958..31a93652ae 100644 --- a/lib/colvars/Makefile.common +++ b/lib/colvars/Makefile.common @@ -62,10 +62,13 @@ COLVARS_SRCS = \ colvar_neuralnetworkcompute.cpp LEPTON_SRCS = \ - lepton/src/CompiledExpression.cpp lepton/src/ExpressionTreeNode.cpp \ - lepton/src/ParsedExpression.cpp lepton/src/ExpressionProgram.cpp \ - lepton/src/Operation.cpp lepton/src/Parser.cpp - + lepton/src/CompiledExpression.cpp \ + lepton/src/CompiledVectorExpression.cpp \ + lepton/src/ExpressionProgram.cpp \ + lepton/src/ExpressionTreeNode.cpp \ + lepton/src/Operation.cpp \ + lepton/src/ParsedExpression.cpp \ + lepton/src/Parser.cpp # Allow to selectively turn off Lepton ifeq ($(COLVARS_LEPTON),no) @@ -93,4 +96,13 @@ Makefile.deps: $(COLVARS_SRCS) done include Makefile.deps -include Makefile.lepton.deps # Hand-generated + +Makefile.lepton.deps: $(LEPTON_SRCS) + @echo > $@ + @for src in $^ ; do \ + obj=`basename $$src .cpp`.o ; \ + $(CXX) $(CXXFLAGS) -MM $(LEPTON_INCFLAGS) \ + -MT '$$(COLVARS_OBJ_DIR)'$$obj $$src >> $@ ; \ + done + +include Makefile.lepton.deps diff --git a/lib/colvars/Makefile.lepton.deps b/lib/colvars/Makefile.lepton.deps index 93c3912384..4546339de6 100644 --- a/lib/colvars/Makefile.lepton.deps +++ b/lib/colvars/Makefile.lepton.deps @@ -1,36 +1,46 @@ -lepton/src/CompiledExpression.o: lepton/src/CompiledExpression.cpp \ + +$(COLVARS_OBJ_DIR)CompiledExpression.o: lepton/src/CompiledExpression.cpp \ lepton/include/lepton/CompiledExpression.h \ lepton/include/lepton/ExpressionTreeNode.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \ lepton/include/lepton/Exception.h \ lepton/include/lepton/ParsedExpression.h -lepton/src/ExpressionProgram.o: lepton/src/ExpressionProgram.cpp \ +$(COLVARS_OBJ_DIR)CompiledVectorExpression.o: \ + lepton/src/CompiledVectorExpression.cpp \ + lepton/include/lepton/CompiledVectorExpression.h \ + lepton/include/lepton/ExpressionTreeNode.h \ + lepton/include/lepton/windowsIncludes.h \ + lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \ + lepton/include/lepton/Exception.h \ + lepton/include/lepton/ParsedExpression.h +$(COLVARS_OBJ_DIR)ExpressionProgram.o: lepton/src/ExpressionProgram.cpp \ lepton/include/lepton/ExpressionProgram.h \ lepton/include/lepton/ExpressionTreeNode.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \ lepton/include/lepton/Exception.h \ lepton/include/lepton/ParsedExpression.h -lepton/src/ExpressionTreeNode.o: lepton/src/ExpressionTreeNode.cpp \ +$(COLVARS_OBJ_DIR)ExpressionTreeNode.o: lepton/src/ExpressionTreeNode.cpp \ lepton/include/lepton/ExpressionTreeNode.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/Exception.h lepton/include/lepton/Operation.h \ lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h -lepton/src/Operation.o: lepton/src/Operation.cpp \ +$(COLVARS_OBJ_DIR)Operation.o: lepton/src/Operation.cpp \ lepton/include/lepton/Operation.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h \ lepton/include/lepton/ExpressionTreeNode.h lepton/src/MSVC_erfc.h -lepton/src/ParsedExpression.o: lepton/src/ParsedExpression.cpp \ +$(COLVARS_OBJ_DIR)ParsedExpression.o: lepton/src/ParsedExpression.cpp \ lepton/include/lepton/ParsedExpression.h \ lepton/include/lepton/ExpressionTreeNode.h \ lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/CompiledExpression.h \ + lepton/include/lepton/CompiledVectorExpression.h \ lepton/include/lepton/ExpressionProgram.h \ lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \ lepton/include/lepton/Exception.h -lepton/src/Parser.o: lepton/src/Parser.cpp \ +$(COLVARS_OBJ_DIR)Parser.o: lepton/src/Parser.cpp \ lepton/include/lepton/Parser.h lepton/include/lepton/windowsIncludes.h \ lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h \ lepton/include/lepton/ExpressionTreeNode.h \ diff --git a/lib/colvars/colvargrid.h b/lib/colvars/colvargrid.h index e0b2ec7f03..f34c5eccab 100644 --- a/lib/colvars/colvargrid.h +++ b/lib/colvars/colvargrid.h @@ -1275,7 +1275,7 @@ public: inline cvm::real log_gradient_finite_diff(const std::vector &ix0, int n = 0) { - int A0, A1, A2; + cvm::real A0, A1, A2; std::vector ix = ix0; // TODO this can be rewritten more concisely with wrap_edge() @@ -1288,7 +1288,7 @@ public: if (A0 * A1 == 0) { return 0.; // can't handle empty bins } else { - return (cvm::logn((cvm::real)A1) - cvm::logn((cvm::real)A0)) + return (cvm::logn(A1) - cvm::logn(A0)) / (widths[n] * 2.); } } else if (ix[n] > 0 && ix[n] < nx[n]-1) { // not an edge @@ -1300,7 +1300,7 @@ public: if (A0 * A1 == 0) { return 0.; // can't handle empty bins } else { - return (cvm::logn((cvm::real)A1) - cvm::logn((cvm::real)A0)) + return (cvm::logn(A1) - cvm::logn(A0)) / (widths[n] * 2.); } } else { @@ -1313,8 +1313,8 @@ public: if (A0 * A1 * A2 == 0) { return 0.; // can't handle empty bins } else { - return (-1.5 * cvm::logn((cvm::real)A0) + 2. * cvm::logn((cvm::real)A1) - - 0.5 * cvm::logn((cvm::real)A2)) * increment / widths[n]; + return (-1.5 * cvm::logn(A0) + 2. * cvm::logn(A1) + - 0.5 * cvm::logn(A2)) * increment / widths[n]; } } } @@ -1324,7 +1324,7 @@ public: inline cvm::real gradient_finite_diff(const std::vector &ix0, int n = 0) { - int A0, A1, A2; + cvm::real A0, A1, A2; std::vector ix = ix0; // FIXME this can be rewritten more concisely with wrap_edge() @@ -1337,7 +1337,7 @@ public: if (A0 * A1 == 0) { return 0.; // can't handle empty bins } else { - return cvm::real(A1 - A0) / (widths[n] * 2.); + return (A1 - A0) / (widths[n] * 2.); } } else if (ix[n] > 0 && ix[n] < nx[n]-1) { // not an edge ix[n]--; @@ -1348,7 +1348,7 @@ public: if (A0 * A1 == 0) { return 0.; // can't handle empty bins } else { - return cvm::real(A1 - A0) / (widths[n] * 2.); + return (A1 - A0) / (widths[n] * 2.); } } else { // edge: use 2nd order derivative @@ -1357,8 +1357,8 @@ public: A0 = value(ix); ix[n] += increment; A1 = value(ix); ix[n] += increment; A2 = value(ix); - return (-1.5 * cvm::real(A0) + 2. * cvm::real(A1) - - 0.5 * cvm::real(A2)) * increment / widths[n]; + return (-1.5 * A0 + 2. * A1 + - 0.5 * A2) * increment / widths[n]; } } }; diff --git a/lib/colvars/colvarproxy_tcl.cpp b/lib/colvars/colvarproxy_tcl.cpp index 33bdc9dc38..700492f0e7 100644 --- a/lib/colvars/colvarproxy_tcl.cpp +++ b/lib/colvars/colvarproxy_tcl.cpp @@ -22,9 +22,7 @@ colvarproxy_tcl::colvarproxy_tcl() { -#ifdef COLVARS_TCL tcl_interp_ = NULL; -#endif } diff --git a/lib/colvars/colvars_version.h b/lib/colvars/colvars_version.h index 2a1d449ab5..d2a48f8af7 100644 --- a/lib/colvars/colvars_version.h +++ b/lib/colvars/colvars_version.h @@ -1,3 +1,3 @@ #ifndef COLVARS_VERSION -#define COLVARS_VERSION "2022-05-09" +#define COLVARS_VERSION "2022-05-24" #endif diff --git a/lib/colvars/lepton/include/lepton/CompiledExpression.h b/lib/colvars/lepton/include/lepton/CompiledExpression.h index c7e393e93b..84ec2eb410 100644 --- a/lib/colvars/lepton/include/lepton/CompiledExpression.h +++ b/lib/colvars/lepton/include/lepton/CompiledExpression.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -40,7 +40,11 @@ #include #include #ifdef LEPTON_USE_JIT - #include "asmjit.h" +#if defined(__ARM__) || defined(__ARM64__) +#include "asmjit/a64.h" +#else +#include "asmjit/x86.h" +#endif #endif namespace Lepton { @@ -52,9 +56,9 @@ class ParsedExpression; * A CompiledExpression is a highly optimized representation of an expression for cases when you want to evaluate * it many times as quickly as possible. You should treat it as an opaque object; none of the internal representation * is visible. - * + * * A CompiledExpression is created by calling createCompiledExpression() on a ParsedExpression. - * + * * WARNING: CompiledExpression is NOT thread safe. You should never access a CompiledExpression from two threads at * the same time. */ @@ -101,9 +105,15 @@ private: std::map dummyVariables; double (*jitCode)(); #ifdef LEPTON_USE_JIT + void findPowerGroups(std::vector >& groups, std::vector >& groupPowers, std::vector& stepGroup); void generateJitCode(); - void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg, double (*function)(double)); - void generateTwoArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg1, asmjit::X86Xmm& arg2, double (*function)(double, double)); +#if defined(__ARM__) || defined(__ARM64__) + void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, double (*function)(double, double)); +#else + void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg1, asmjit::x86::Xmm& arg2, double (*function)(double, double)); +#endif std::vector constants; asmjit::JitRuntime runtime; #endif diff --git a/lib/colvars/lepton/include/lepton/CompiledVectorExpression.h b/lib/colvars/lepton/include/lepton/CompiledVectorExpression.h new file mode 100644 index 0000000000..a9dd936750 --- /dev/null +++ b/lib/colvars/lepton/include/lepton/CompiledVectorExpression.h @@ -0,0 +1,145 @@ +#ifndef LEPTON_VECTOR_EXPRESSION_H_ +#define LEPTON_VECTOR_EXPRESSION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "ExpressionTreeNode.h" +#include "windowsIncludes.h" +#include +#include +#include +#include +#include +#include +#ifdef LEPTON_USE_JIT +#if defined(__ARM__) || defined(__ARM64__) +#include "asmjit/a64.h" +#else +#include "asmjit/x86.h" +#endif +#endif + +namespace Lepton { + +class Operation; +class ParsedExpression; + +/** + * A CompiledVectorExpression is a highly optimized representation of an expression for cases when you want to evaluate + * it many times as quickly as possible. It is similar to CompiledExpression, with the extra feature that it uses the CPU's + * vector unit (AVX on x86, NEON on ARM) to evaluate the expression for multiple sets of arguments at once. It also differs + * from CompiledExpression and ParsedExpression in using single precision rather than double precision to evaluate the expression. + * You should treat it as an opaque object; none of the internal representation is visible. + * + * A CompiledVectorExpression is created by calling createCompiledVectorExpression() on a ParsedExpression. When you create + * it, you must specify the width of the vectors on which to compute the expression. The allowed widths depend on the type of + * CPU it is running on. 4 is always allowed, and 8 is allowed on x86 processors with AVX. Call getAllowedWidths() to query + * the allowed values. + * + * WARNING: CompiledVectorExpression is NOT thread safe. You should never access a CompiledVectorExpression from two threads at + * the same time. + */ + +class LEPTON_EXPORT CompiledVectorExpression { +public: + CompiledVectorExpression(); + CompiledVectorExpression(const CompiledVectorExpression& expression); + ~CompiledVectorExpression(); + CompiledVectorExpression& operator=(const CompiledVectorExpression& expression); + /** + * Get the width of the vectors on which the expression is computed. + */ + int getWidth() const; + /** + * Get the names of all variables used by this expression. + */ + const std::set& getVariables() const; + /** + * Get a pointer to the memory location where the value of a particular variable is stored. This can be used + * to set the value of the variable before calling evaluate(). + * + * @param name the name of the variable to query + * @return a pointer to N floating point values, where N is the vector width + */ + float* getVariablePointer(const std::string& name); + /** + * You can optionally specify the memory locations from which the values of variables should be read. + * This is useful, for example, when several expressions all use the same variable. You can then set + * the value of that variable in one place, and it will be seen by all of them. The location should + * be a pointer to N floating point values, where N is the vector width. + */ + void setVariableLocations(std::map& variableLocations); + /** + * Evaluate the expression. The values of all variables should have been set before calling this. + * + * @return a pointer to N floating point values, where N is the vector width + */ + const float* evaluate() const; + /** + * Get the list of vector widths that are supported on the current processor. + */ + static const std::vector& getAllowedWidths(); +private: + friend class ParsedExpression; + CompiledVectorExpression(const ParsedExpression& expression, int width); + void compileExpression(const ExpressionTreeNode& node, std::vector >& temps, int& workspaceSize); + int findTempIndex(const ExpressionTreeNode& node, std::vector >& temps); + int width; + std::map variablePointers; + std::vector > variablesToCopy; + std::vector > arguments; + std::vector target; + std::vector operation; + std::map variableIndices; + std::set variableNames; + mutable std::vector workspace; + mutable std::vector argValues; + std::map dummyVariables; + void (*jitCode)(); +#ifdef LEPTON_USE_JIT + void findPowerGroups(std::vector >& groups, std::vector >& groupPowers, std::vector& stepGroup); + void generateJitCode(); +#if defined(__ARM__) || defined(__ARM64__) + void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, float (*function)(float)); + void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, float (*function)(float, float)); +#else + void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg, float (*function)(float)); + void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg1, asmjit::x86::Ymm& arg2, float (*function)(float, float)); +#endif + std::vector constants; + asmjit::JitRuntime runtime; +#endif +}; + +} // namespace Lepton + +#endif /*LEPTON_VECTOR_EXPRESSION_H_*/ diff --git a/lib/colvars/lepton/include/lepton/CustomFunction.h b/lib/colvars/lepton/include/lepton/CustomFunction.h index fbb0ddd52a..7b6a2b6834 100644 --- a/lib/colvars/lepton/include/lepton/CustomFunction.h +++ b/lib/colvars/lepton/include/lepton/CustomFunction.h @@ -83,7 +83,7 @@ class LEPTON_EXPORT PlaceholderFunction : public CustomFunction { public: /** * Create a Placeholder function. - * + * * @param numArgs the number of arguments the function expects */ PlaceholderFunction(int numArgs) : numArgs(numArgs) { diff --git a/lib/colvars/lepton/include/lepton/ExpressionProgram.h b/lib/colvars/lepton/include/lepton/ExpressionProgram.h index a49a9094d0..e989906288 100644 --- a/lib/colvars/lepton/include/lepton/ExpressionProgram.h +++ b/lib/colvars/lepton/include/lepton/ExpressionProgram.h @@ -67,7 +67,7 @@ public: const Operation& getOperation(int index) const; /** * Change an Operation in this program. - * + * * The Operation must have been allocated on the heap with the "new" operator. * The ExpressionProgram assumes ownership of it and will delete it when it * is no longer needed. diff --git a/lib/colvars/lepton/include/lepton/ExpressionTreeNode.h b/lib/colvars/lepton/include/lepton/ExpressionTreeNode.h index bf3a9a0902..dde26103cb 100644 --- a/lib/colvars/lepton/include/lepton/ExpressionTreeNode.h +++ b/lib/colvars/lepton/include/lepton/ExpressionTreeNode.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009 Stanford University and the Authors. * + * Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -39,6 +39,7 @@ namespace Lepton { class Operation; +class ParsedExpression; /** * This class represents a node in the abstract syntax tree representation of an expression. @@ -82,11 +83,13 @@ public: */ ExpressionTreeNode(Operation* operation); ExpressionTreeNode(const ExpressionTreeNode& node); + ExpressionTreeNode(ExpressionTreeNode&& node); ExpressionTreeNode(); ~ExpressionTreeNode(); bool operator==(const ExpressionTreeNode& node) const; bool operator!=(const ExpressionTreeNode& node) const; ExpressionTreeNode& operator=(const ExpressionTreeNode& node); + ExpressionTreeNode& operator=(ExpressionTreeNode&& node); /** * Get the Operation performed by this node. */ @@ -96,8 +99,11 @@ public: */ const std::vector& getChildren() const; private: + friend class ParsedExpression; + void assignTags(std::vector& examples) const; Operation* operation; std::vector children; + mutable int tag; }; } // namespace Lepton diff --git a/lib/colvars/lepton/include/lepton/Operation.h b/lib/colvars/lepton/include/lepton/Operation.h index 1ddde0b8c0..4b8969cd59 100644 --- a/lib/colvars/lepton/include/lepton/Operation.h +++ b/lib/colvars/lepton/include/lepton/Operation.h @@ -1017,7 +1017,7 @@ public: double evaluate(double* args, const std::map& variables) const { if (isIntPower) { // Integer powers can be computed much more quickly by repeated multiplication. - + int exponent = intValue; double base = args[0]; if (exponent < 0) { diff --git a/lib/colvars/lepton/include/lepton/ParsedExpression.h b/lib/colvars/lepton/include/lepton/ParsedExpression.h index d88b3d5829..6c6526e525 100644 --- a/lib/colvars/lepton/include/lepton/ParsedExpression.h +++ b/lib/colvars/lepton/include/lepton/ParsedExpression.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009=2013 Stanford University and the Authors. * + * Portions copyright (c) 2009-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -41,6 +41,7 @@ namespace Lepton { class CompiledExpression; class ExpressionProgram; +class CompiledVectorExpression; /** * This class represents the result of parsing an expression. It provides methods for working with the @@ -102,6 +103,16 @@ public: * Create a CompiledExpression that represents the same calculation as this expression. */ CompiledExpression createCompiledExpression() const; + /** + * Create a CompiledVectorExpression that allows the expression to be evaluated efficiently + * using the CPU's vector unit. + * + * @param width the width of the vectors to evaluate it on. The allowed values + * depend on the CPU. 4 is always allowed, and 8 is allowed on + * x86 processors with AVX. Call CompiledVectorExpression::getAllowedWidths() + * to query the allowed widths on the current processor. + */ + CompiledVectorExpression createCompiledVectorExpression(int width) const; /** * Create a new ParsedExpression which is identical to this one, except that the names of some * variables have been changed. @@ -113,9 +124,10 @@ public: private: static double evaluate(const ExpressionTreeNode& node, const std::map& variables); static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map& variables); - static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); - static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); - static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable); + static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node, std::map& nodeCache); + static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node, std::map& nodeCache); + static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable, std::map& nodeCache); + static bool isConstant(const ExpressionTreeNode& node); static double getConstantValue(const ExpressionTreeNode& node); static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map& replacements); ExpressionTreeNode rootNode; diff --git a/lib/colvars/lepton/src/CompiledExpression.cpp b/lib/colvars/lepton/src/CompiledExpression.cpp index 1ad348b47d..8a0239b04f 100644 --- a/lib/colvars/lepton/src/CompiledExpression.cpp +++ b/lib/colvars/lepton/src/CompiledExpression.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -84,17 +84,17 @@ CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expr void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vector >& temps) { if (findTempIndex(node, temps) != -1) return; // We have already processed a node identical to this one. - + // Process the child nodes. - + vector args; for (int i = 0; i < node.getChildren().size(); i++) { compileExpression(node.getChildren()[i], temps); args.push_back(findTempIndex(node.getChildren()[i], temps)); } - + // Process this node. - + if (node.getOperation().getId() == Operation::VARIABLE) { variableIndices[node.getOperation().getName()] = (int) workspace.size(); variableNames.insert(node.getOperation().getName()); @@ -108,7 +108,7 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there. else { // If the arguments are sequential, we can just pass a pointer to the first one. - + bool sequential = true; for (int i = 1; i < args.size(); i++) if (args[i] != args[i-1]+1) @@ -148,30 +148,28 @@ void CompiledExpression::setVariableLocations(map& variableLoca variablePointers = variableLocations; #ifdef LEPTON_USE_JIT // Rebuild the JIT code. - + if (workspace.size() > 0) generateJitCode(); -#else +#endif // Make a list of all variables we will need to copy before evaluating the expression. - + variablesToCopy.clear(); for (map::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) { map::iterator pointer = variablePointers.find(iter->first); if (pointer != variablePointers.end()) variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second)); } -#endif } double CompiledExpression::evaluate() const { -#ifdef LEPTON_USE_JIT - return jitCode(); -#else + if (jitCode) + return jitCode(); for (int i = 0; i < variablesToCopy.size(); i++) *variablesToCopy[i].first = *variablesToCopy[i].second; // Loop over the operations and evaluate each one. - + for (int step = 0; step < operation.size(); step++) { const vector& args = arguments[step]; if (args.size() == 1) @@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const { } } return workspace[workspace.size()-1]; -#endif } #ifdef LEPTON_USE_JIT @@ -192,32 +189,78 @@ static double evaluateOperation(Operation* op, double* args) { return op->evaluate(args, dummyVariables); } +void CompiledExpression::findPowerGroups(vector >& groups, vector >& groupPowers, vector& stepGroup) { + // Identify every step that raises an argument to an integer power. + + vector stepPower(operation.size(), 0); + vector stepArg(operation.size(), -1); + for (int step = 0; step < operation.size(); step++) { + Operation& op = *operation[step]; + int power = 0; + if (op.getId() == Operation::SQUARE) + power = 2; + else if (op.getId() == Operation::CUBE) + power = 3; + else if (op.getId() == Operation::POWER_CONSTANT) { + double realPower = dynamic_cast(&op)->getValue(); + if (realPower == (int) realPower) + power = (int) realPower; + } + if (power != 0) { + stepPower[step] = power; + stepArg[step] = arguments[step][0]; + } + } + + // Find groups that operate on the same argument and whose powers have the same sign. + + stepGroup.resize(operation.size(), -1); + for (int i = 0; i < operation.size(); i++) { + if (stepGroup[i] != -1) + continue; + vector group, power; + for (int j = i; j < operation.size(); j++) { + if (stepArg[i] == stepArg[j] && stepPower[i]*stepPower[j] > 0) { + stepGroup[j] = groups.size(); + group.push_back(j); + power.push_back(stepPower[j]); + } + } + groups.push_back(group); + groupPowers.push_back(power); + } +} + +#if defined(__ARM__) || defined(__ARM64__) void CompiledExpression::generateJitCode() { CodeHolder code; - code.init(runtime.getCodeInfo()); - X86Compiler c(&code); - c.addFunc(FuncSignature0()); - vector workspaceVar(workspace.size()); + code.init(runtime.environment()); + a64::Compiler c(&code); + c.addFunc(FuncSignatureT()); + vector workspaceVar(workspace.size()); for (int i = 0; i < (int) workspaceVar.size(); i++) - workspaceVar[i] = c.newXmmSd(); - X86Gp argsPointer = c.newIntPtr(); - c.mov(argsPointer, imm_ptr(&argValues[0])); - + workspaceVar[i] = c.newVecD(); + arm::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + // Load the arguments into variables. - + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { map::iterator index = variableIndices.find(*iter); - X86Gp variablePointer = c.newIntPtr(); - c.mov(variablePointer, imm_ptr(&getVariableReference(index->first))); - c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + arm::Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm(&getVariableReference(index->first))); + c.ldr(workspaceVar[index->second], arm::ptr(variablePointer, 0)); } // Make a list of all constants that will be needed for evaluation. - + vector operationConstantIndex(operation.size(), -1); for (int step = 0; step < (int) operation.size(); step++) { // Find the constant value (if any) used by this operation. - + Operation& op = *operation[step]; double value; if (op.getId() == Operation::CONSTANT) @@ -232,11 +275,17 @@ void CompiledExpression::generateJitCode() { value = 1.0; else if (op.getId() == Operation::DELTA) value = 1.0; + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast(op).getValue(); + else + value = 1.0; + } else continue; - + // See if we already have a variable for this constant. - + for (int i = 0; i < (int) constants.size(); i++) if (value == constants[i]) { operationConstantIndex[step] = i; @@ -247,62 +296,101 @@ void CompiledExpression::generateJitCode() { constants.push_back(value); } } - + // Load constants into variables. - - vector constantVar(constants.size()); + + vector constantVar(constants.size()); if (constants.size() > 0) { - X86Gp constantsPointer = c.newIntPtr(); - c.mov(constantsPointer, imm_ptr(&constants[0])); + arm::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); for (int i = 0; i < (int) constants.size(); i++) { - constantVar[i] = c.newXmmSd(); - c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + constantVar[i] = c.newVecD(); + c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i)); } } // Evaluate the operations. + vector hasComputedPower(operation.size(), false); for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + arm::Vec multiplier = c.newVecD(); + if (powers[0] > 0) + c.fmov(multiplier, workspaceVar[arguments[step][0]]); + else { + c.fdiv(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i]%2 == 1) { + if (!hasAssigned[i]) + c.fmov(workspaceVar[target[group[i]]], multiplier); + else + c.fmul(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.fmul(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + Operation& op = *operation[step]; vector args = arguments[step]; if (args.size() == 1) { // One or more sequential arguments. Fill out the list. - + for (int i = 1; i < op.getNumArguments(); i++) args.push_back(args[0]+i); } - + // Generate instructions to execute this operation. - + switch (op.getId()) { case Operation::CONSTANT: - c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fmov(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::ADD: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::SUBTRACT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fsub(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::MULTIPLY: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::DIVIDE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fdiv(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::POWER: generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); break; case Operation::NEGATE: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fneg(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::SQRT: - c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fsqrt(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::EXP: generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); @@ -341,56 +429,63 @@ void CompiledExpression::generateJitCode() { generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); break; case Operation::STEP: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 - c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.cmge(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::DELTA: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 - c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.cmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::SQUARE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); break; case Operation::CUBE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::RECIPROCAL: - c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); - c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fdiv(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); break; case Operation::ADD_CONSTANT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); break; case Operation::MULTIPLY_CONSTANT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); + break; + case Operation::MIN: + c.fmin(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.fmax(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::ABS: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs); + c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::FLOOR: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor); + c.frintm(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::CEIL: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil); + c.frintp(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SELECT: + c.fcmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]); break; default: // Just invoke evaluateOperation(). - + for (int i = 0; i < (int) args.size(); i++) - c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) evaluateOperation)); - CCFuncCall* call = c.call(fn, FuncSignature2()); - call->setArg(0, imm_ptr(&op)); - call->setArg(1, imm_ptr(&argValues[0])); - call->setRet(0, workspaceVar[target[step]]); + c.str(workspaceVar[args[i]], arm::ptr(argsPointer, 8*i)); + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, workspaceVar[target[step]]); } } c.ret(workspaceVar[workspace.size()-1]); @@ -399,20 +494,319 @@ void CompiledExpression::generateJitCode() { runtime.add(&jitCode, &code); } -void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) { - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) function)); - CCFuncCall* call = c.call(fn, FuncSignature1()); - call->setArg(0, arg); - call->setRet(0, dest); +void CompiledExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, double (*function)(double)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg); + invoke->setRet(0, dest); } -void CompiledExpression::generateTwoArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg1, X86Xmm& arg2, double (*function)(double, double)) { - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) function)); - CCFuncCall* call = c.call(fn, FuncSignature2()); - call->setArg(0, arg1); - call->setArg(1, arg2); - call->setRet(0, dest); +void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, double (*function)(double, double)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg1); + invoke->setArg(1, arg2); + invoke->setRet(0, dest); +} +#else +void CompiledExpression::generateJitCode() { + const CpuInfo& cpu = CpuInfo::host(); + if (!cpu.hasFeature(CpuFeatures::X86::kAVX)) + return; + CodeHolder code; + code.init(runtime.environment()); + x86::Compiler c(&code); + FuncNode* funcNode = c.addFunc(FuncSignatureT()); + funcNode->frame().setAvxEnabled(); + vector workspaceVar(workspace.size()); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newXmmSd(); + x86::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + x86::Gp variablePointer = c.newIntPtr(); + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + c.mov(variablePointer, imm(&getVariableReference(index->first))); + c.vmovsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::ABS) { + long long mask = 0x7FFFFFFFFFFFFFFF; + value = *reinterpret_cast(&mask); + } + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast(op).getValue(); + else + value = 1.0; + } + else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + x86::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newXmmSd(); + c.vmovsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + x86::Xmm multiplier = c.newXmmSd(); + if (powers[0] > 0) + c.vmovsd(multiplier, workspaceVar[arguments[step][0]], workspaceVar[arguments[step][0]]); + else { + c.vdivsd(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i]%2 == 1) { + if (!hasAssigned[i]) + c.vmovsd(workspaceVar[target[group[i]]], multiplier, multiplier); + else + c.vmulsd(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.vmulsd(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0]+i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.vmovsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.vsubsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.vdivsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); + break; + case Operation::NEGATE: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vsubsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.vsqrtsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); + break; + case Operation::STEP: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.vmulsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.vdivsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::MULTIPLY_CONSTANT: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); + break; + case Operation::MIN: + c.vminsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.vmaxsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::ABS: + c.vandpd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::FLOOR: + c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(1)); + break; + case Operation::CEIL: + c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(2)); + break; + case Operation::SELECT: + { + x86::Xmm mask = c.newXmmSd(); + c.vxorps(mask, mask, mask); + c.vcmpsd(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask); + break; + } + default: + // Just invoke evaluateOperation(). + + for (int i = 0; i < (int) args.size(); i++) + c.vmovsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, workspaceVar[target[step]]); + } + } + c.ret(workspaceVar[workspace.size()-1]); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledExpression::generateSingleArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg, double (*function)(double)) { + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg); + invoke->setRet(0, dest); +} + +void CompiledExpression::generateTwoArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg1, x86::Xmm& arg2, double (*function)(double, double)) { + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg1); + invoke->setArg(1, arg2); + invoke->setRet(0, dest); } #endif +#endif diff --git a/lib/colvars/lepton/src/CompiledVectorExpression.cpp b/lib/colvars/lepton/src/CompiledVectorExpression.cpp new file mode 100644 index 0000000000..7c01a986bb --- /dev/null +++ b/lib/colvars/lepton/src/CompiledVectorExpression.cpp @@ -0,0 +1,933 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/CompiledVectorExpression.h" +#include "lepton/Operation.h" +#include "lepton/ParsedExpression.h" +#include +#include + +using namespace Lepton; +using namespace std; +#ifdef LEPTON_USE_JIT +using namespace asmjit; +#endif + +CompiledVectorExpression::CompiledVectorExpression() : jitCode(NULL) { +} + +CompiledVectorExpression::CompiledVectorExpression(const ParsedExpression& expression, int width) : jitCode(NULL), width(width) { + const vector allowedWidths = getAllowedWidths(); + if (find(allowedWidths.begin(), allowedWidths.end(), width) == allowedWidths.end()) + throw Exception("Unsupported width for vector expression: "+to_string(width)); + ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized. + vector > temps; + int workspaceSize = 0; + compileExpression(expr.getRootNode(), temps, workspaceSize); + workspace.resize(workspaceSize*width); + int maxArguments = 1; + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i]->getNumArguments() > maxArguments) + maxArguments = operation[i]->getNumArguments(); + argValues.resize(maxArguments); +#ifdef LEPTON_USE_JIT + generateJitCode(); +#endif +} + +CompiledVectorExpression::~CompiledVectorExpression() { + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i] != NULL) + delete operation[i]; +} + +CompiledVectorExpression::CompiledVectorExpression(const CompiledVectorExpression& expression) : jitCode(NULL) { + *this = expression; +} + +CompiledVectorExpression& CompiledVectorExpression::operator=(const CompiledVectorExpression& expression) { + arguments = expression.arguments; + width = expression.width; + target = expression.target; + variableIndices = expression.variableIndices; + variableNames = expression.variableNames; + workspace.resize(expression.workspace.size()); + argValues.resize(expression.argValues.size()); + operation.resize(expression.operation.size()); + for (int i = 0; i < (int) operation.size(); i++) + operation[i] = expression.operation[i]->clone(); + setVariableLocations(variablePointers); + return *this; +} + +const vector& CompiledVectorExpression::getAllowedWidths() { + static vector widths; + if (widths.size() == 0) { + widths.push_back(4); +#ifdef LEPTON_USE_JIT + const CpuInfo& cpu = CpuInfo::host(); + if (cpu.hasFeature(CpuFeatures::X86::kAVX)) + widths.push_back(8); +#endif + } + return widths; +} + +void CompiledVectorExpression::compileExpression(const ExpressionTreeNode& node, vector >& temps, int& workspaceSize) { + if (findTempIndex(node, temps) != -1) + return; // We have already processed a node identical to this one. + + // Process the child nodes. + + vector args; + for (int i = 0; i < node.getChildren().size(); i++) { + compileExpression(node.getChildren()[i], temps, workspaceSize); + args.push_back(findTempIndex(node.getChildren()[i], temps)); + } + + // Process this node. + + if (node.getOperation().getId() == Operation::VARIABLE) { + variableIndices[node.getOperation().getName()] = workspaceSize; + variableNames.insert(node.getOperation().getName()); + } + else { + int stepIndex = (int) arguments.size(); + arguments.push_back(vector()); + target.push_back(workspaceSize); + operation.push_back(node.getOperation().clone()); + if (args.size() == 0) + arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there. + else { + // If the arguments are sequential, we can just pass a pointer to the first one. + + bool sequential = true; + for (int i = 1; i < args.size(); i++) + if (args[i] != args[i - 1] + 1) + sequential = false; + if (sequential) + arguments[stepIndex].push_back(args[0]); + else + arguments[stepIndex] = args; + } + } + temps.push_back(make_pair(node, workspaceSize)); + workspaceSize++; +} + +int CompiledVectorExpression::findTempIndex(const ExpressionTreeNode& node, vector >& temps) { + for (int i = 0; i < (int) temps.size(); i++) + if (temps[i].first == node) + return i; + return -1; +} + +int CompiledVectorExpression::getWidth() const { + return width; +} + +const set& CompiledVectorExpression::getVariables() const { + return variableNames; +} + +float* CompiledVectorExpression::getVariablePointer(const string& name) { + map::iterator pointer = variablePointers.find(name); + if (pointer != variablePointers.end()) + return pointer->second; + map::iterator index = variableIndices.find(name); + if (index == variableIndices.end()) + throw Exception("getVariableReference: Unknown variable '" + name + "'"); + return &workspace[index->second*width]; +} + +void CompiledVectorExpression::setVariableLocations(map& variableLocations) { + variablePointers = variableLocations; +#ifdef LEPTON_USE_JIT + // Rebuild the JIT code. + + if (workspace.size() > 0) + generateJitCode(); +#endif + // Make a list of all variables we will need to copy before evaluating the expression. + + variablesToCopy.clear(); + for (map::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) { + map::iterator pointer = variablePointers.find(iter->first); + if (pointer != variablePointers.end()) + variablesToCopy.push_back(make_pair(&workspace[iter->second*width], pointer->second)); + } +} + +const float* CompiledVectorExpression::evaluate() const { + if (jitCode) { + jitCode(); + return &workspace[workspace.size()-width]; + } + for (int i = 0; i < variablesToCopy.size(); i++) + for (int j = 0; j < width; j++) + variablesToCopy[i].first[j] = variablesToCopy[i].second[j]; + + // Loop over the operations and evaluate each one. + + for (int step = 0; step < operation.size(); step++) { + const vector& args = arguments[step]; + if (args.size() == 1) { + for (int j = 0; j < width; j++) { + for (int i = 0; i < operation[step]->getNumArguments(); i++) + argValues[i] = workspace[(args[0]+i)*width+j]; + workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } else { + for (int j = 0; j < width; j++) { + for (int i = 0; i < args.size(); i++) + argValues[i] = workspace[args[i]*width+j]; + workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } + } + return &workspace[workspace.size()-width]; +} + +#ifdef LEPTON_USE_JIT + +static double evaluateOperation(Operation* op, double* args) { + static map dummyVariables; + return op->evaluate(args, dummyVariables); +} + +void CompiledVectorExpression::findPowerGroups(vector >& groups, vector >& groupPowers, vector& stepGroup) { + // Identify every step that raises an argument to an integer power. + + vector stepPower(operation.size(), 0); + vector stepArg(operation.size(), -1); + for (int step = 0; step < operation.size(); step++) { + Operation& op = *operation[step]; + int power = 0; + if (op.getId() == Operation::SQUARE) + power = 2; + else if (op.getId() == Operation::CUBE) + power = 3; + else if (op.getId() == Operation::POWER_CONSTANT) { + double realPower = dynamic_cast (&op)->getValue(); + if (realPower == (int) realPower) + power = (int) realPower; + } + if (power != 0) { + stepPower[step] = power; + stepArg[step] = arguments[step][0]; + } + } + + // Find groups that operate on the same argument and whose powers have the same sign. + + stepGroup.resize(operation.size(), -1); + for (int i = 0; i < operation.size(); i++) { + if (stepGroup[i] != -1) + continue; + vector group, power; + for (int j = i; j < operation.size(); j++) { + if (stepArg[i] == stepArg[j] && stepPower[i] * stepPower[j] > 0) { + stepGroup[j] = groups.size(); + group.push_back(j); + power.push_back(stepPower[j]); + } + } + groups.push_back(group); + groupPowers.push_back(power); + } +} + +#if defined(__ARM__) || defined(__ARM64__) + +void CompiledVectorExpression::generateJitCode() { + CodeHolder code; + code.init(runtime.environment()); + a64::Compiler c(&code); + c.addFunc(FuncSignatureT()); + vector workspaceVar(workspace.size()/width); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newVecQ(); + arm::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + arm::Gp variablePointer = c.newIntPtr(); + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + c.mov(variablePointer, imm(getVariablePointer(index->first))); + c.ldr(workspaceVar[index->second].s4(), arm::ptr(variablePointer, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + float value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast (op).getValue(); + else + value = 1.0; + } else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + arm::Gp constantsPointer = c.newIntPtr(); + for (int i = 0; i < (int) constants.size(); i++) { + c.mov(constantsPointer, imm(&constants[i])); + constantVar[i] = c.newVecQ(); + c.ld1r(constantVar[i].s4(), arm::ptr(constantsPointer)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + arm::Vec argReg = c.newVecS(); + arm::Vec doubleArgReg = c.newVecD(); + arm::Vec doubleResultReg = c.newVecD(); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + arm::Vec multiplier = c.newVecQ(); + if (powers[0] > 0) + c.mov(multiplier.s4(), workspaceVar[arguments[step][0]].s4()); + else { + c.fdiv(multiplier.s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[arguments[step][0]].s4()); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i] % 2 == 1) { + if (!hasAssigned[i]) + c.mov(workspaceVar[target[group[i]]].s4(), multiplier.s4()); + else + c.fmul(workspaceVar[target[group[i]]].s4(), workspaceVar[target[group[i]]].s4(), multiplier.s4()); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.fmul(multiplier.s4(), multiplier.s4(), multiplier.s4()); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0] + i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.mov(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::ADD: + c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::SUBTRACT: + c.fsub(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::MULTIPLY: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::DIVIDE: + c.fdiv(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf); + break; + case Operation::NEGATE: + c.fneg(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::SQRT: + c.fsqrt(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf); + break; + case Operation::STEP: + c.cmge(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.cmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::CUBE: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4()); + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::RECIPROCAL: + c.fdiv(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::ADD_CONSTANT: + c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::MULTIPLY_CONSTANT: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf); + break; + case Operation::MIN: + c.fmin(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::MAX: + c.fmax(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::ABS: + c.fabs(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::FLOOR: + c.frintm(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::CEIL: + c.frintp(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::SELECT: + c.fcmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]); + break; + default: + // Just invoke evaluateOperation(). + for (int element = 0; element < width; element++) { + for (int i = 0; i < (int) args.size(); i++) { + c.ins(argReg.s(0), workspaceVar[args[i]].s(element)); + c.fcvt(doubleArgReg, argReg); + c.str(doubleArgReg, arm::ptr(argsPointer, 8*i)); + } + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, doubleResultReg); + c.fcvt(argReg, doubleResultReg); + c.ins(workspaceVar[target[step]].s(element), argReg.s(0)); + } + } + } + arm::Gp resultPointer = c.newIntPtr(); + c.mov(resultPointer, imm(&workspace[workspace.size()-width])); + c.str(workspaceVar.back().s4(), arm::ptr(resultPointer, 0)); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledVectorExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, float (*function)(float)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + arm::Vec a = c.newVecS(); + arm::Vec d = c.newVecS(); + for (int element = 0; element < width; element++) { + c.ins(a.s(0), arg.s(element)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, a); + invoke->setRet(0, d); + c.ins(dest.s(element), d.s(0)); + } +} + +void CompiledVectorExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, float (*function)(float, float)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + arm::Vec a1 = c.newVecS(); + arm::Vec a2 = c.newVecS(); + arm::Vec d = c.newVecS(); + for (int element = 0; element < width; element++) { + c.ins(a1.s(0), arg1.s(element)); + c.ins(a2.s(0), arg2.s(element)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, a1); + invoke->setArg(1, a2); + invoke->setRet(0, d); + c.ins(dest.s(element), d.s(0)); + } +} +#else + +void CompiledVectorExpression::generateJitCode() { + const CpuInfo& cpu = CpuInfo::host(); + if (!cpu.hasFeature(CpuFeatures::X86::kAVX)) + return; + CodeHolder code; + code.init(runtime.environment()); + x86::Compiler c(&code); + FuncNode* funcNode = c.addFunc(FuncSignatureT()); + funcNode->frame().setAvxEnabled(); + vector workspaceVar(workspace.size()/width); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newYmmPs(); + x86::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + x86::Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm(getVariablePointer(index->first))); + if (width == 4) + c.vmovdqu(workspaceVar[index->second].xmm(), x86::ptr(variablePointer, 0, 0)); + else + c.vmovdqu(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::ABS) { + int mask = 0x7FFFFFFF; + value = *reinterpret_cast(&mask); + } + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast (op).getValue(); + else + value = 1.0; + } else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + x86::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newYmmPs(); + c.vbroadcastss(constantVar[i], x86::ptr(constantsPointer, 4*i, 0)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + x86::Ymm argReg = c.newYmm(); + x86::Ymm doubleArgReg = c.newYmm(); + x86::Ymm doubleResultReg = c.newYmm(); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + x86::Ymm multiplier = c.newYmmPs(); + if (powers[0] > 0) + c.vmovdqu(multiplier, workspaceVar[arguments[step][0]]); + else { + c.vdivps(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i] % 2 == 1) { + if (!hasAssigned[i]) + c.vmovdqu(workspaceVar[target[group[i]]], multiplier); + else + c.vmulps(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.vmulps(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0] + i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.vmovdqu(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.vsubps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.vdivps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf); + break; + case Operation::NEGATE: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vsubps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.vsqrtps(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf); + break; + case Operation::STEP: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.vmulps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.vdivps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::MULTIPLY_CONSTANT: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf); + break; + case Operation::MIN: + c.vminps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.vmaxps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::ABS: + c.vandps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::FLOOR: + c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(1)); + break; + case Operation::CEIL: + c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(2)); + break; + case Operation::SELECT: + { + x86::Ymm mask = c.newYmmPs(); + c.vxorps(mask, mask, mask); + c.vcmpps(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask); + break; + } + default: + // Just invoke evaluateOperation(). + + for (int element = 0; element < width; element++) { + for (int i = 0; i < (int) args.size(); i++) { + if (element < 4) + c.vshufps(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(element)); + else { + c.vperm2f128(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(1)); + c.vshufps(argReg, argReg, argReg, imm(element-4)); + } + c.vcvtss2sd(doubleArgReg.xmm(), doubleArgReg.xmm(), argReg.xmm()); + c.vmovsd(x86::ptr(argsPointer, 8*i, 0), doubleArgReg.xmm()); + } + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, doubleResultReg); + c.vcvtsd2ss(argReg.xmm(), argReg.xmm(), doubleResultReg.xmm()); + if (element > 3) + c.vperm2f128(argReg, argReg, argReg, imm(0)); + if (element != 0) + c.vshufps(argReg, argReg, argReg, imm(0)); + c.vblendps(workspaceVar[target[step]], workspaceVar[target[step]], argReg, 1<()); + invoke->setArg(0, a); + invoke->setRet(0, d); + if (element > 3) + c.vperm2f128(d, d, d, imm(0)); + if (element != 0) + c.vshufps(d, d, d, imm(0)); + c.vblendps(dest, dest, d, 1<()); + invoke->setArg(0, a1); + invoke->setArg(1, a2); + invoke->setRet(0, d); + if (element > 3) + c.vperm2f128(d, d, d, imm(0)); + if (element != 0) + c.vshufps(d, d, d, imm(0)); + c.vblendps(dest, dest, d, 1< using namespace Lepton; using namespace std; @@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? NULL : node.operation->clone()), children(node.getChildren()) { } +ExpressionTreeNode::ExpressionTreeNode(ExpressionTreeNode&& node) : operation(node.operation), children(move(node.children)) { + node.operation = NULL; + node.children.clear(); +} + ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) { } @@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node return *this; } +ExpressionTreeNode& ExpressionTreeNode::operator=(ExpressionTreeNode&& node) { + if (operation != NULL) + delete operation; + operation = node.operation; + children = move(node.children); + node.operation = NULL; + node.children.clear(); + return *this; +} + const Operation& ExpressionTreeNode::getOperation() const { return *operation; } @@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const { const vector& ExpressionTreeNode::getChildren() const { return children; } + +void ExpressionTreeNode::assignTags(vector& examples) const { + // Assign tag values to all nodes in a tree, such that two nodes have the same + // tag if and only if they (and all their children) are equal. This is used to + // optimize other operations. + + int numTags = examples.size(); + for (const ExpressionTreeNode& child : getChildren()) + child.assignTags(examples); + if (numTags == examples.size()) { + // All the children matched existing tags, so possibly this node does too. + + for (int i = 0; i < examples.size(); i++) { + const ExpressionTreeNode& example = *examples[i]; + bool matches = (getChildren().size() == example.getChildren().size() && getOperation() == example.getOperation()); + for (int j = 0; matches && j < getChildren().size(); j++) + if (getChildren()[j].tag != example.getChildren()[j].tag) + matches = false; + if (matches) { + tag = i; + return; + } + } + } + + // This node does not match any previous node, so assign a new tag. + + tag = examples.size(); + examples.push_back(this); +} diff --git a/lib/colvars/lepton/src/MSVC_erfc.h b/lib/colvars/lepton/src/MSVC_erfc.h index b1cd87a289..2c6b619e89 100644 --- a/lib/colvars/lepton/src/MSVC_erfc.h +++ b/lib/colvars/lepton/src/MSVC_erfc.h @@ -3,7 +3,7 @@ /* * Up to version 11 (VC++ 2012), Microsoft does not support the - * standard C99 erf() and erfc() functions so we have to fake them here. + * standard C99 erf() and erfc() functions so we have to fake them here. * These were added in version 12 (VC++ 2013), which sets _MSC_VER=1800 * (VC11 has _MSC_VER=1700). */ @@ -15,7 +15,7 @@ #endif #if defined(_MSC_VER) -#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12 +#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12 /*************************** * erf.cpp * author: Steve Strand diff --git a/lib/colvars/lepton/src/Operation.cpp b/lib/colvars/lepton/src/Operation.cpp index 78741c4814..b5a958b2f7 100644 --- a/lib/colvars/lepton/src/Operation.cpp +++ b/lib/colvars/lepton/src/Operation.cpp @@ -7,7 +7,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009-2019 Stanford University and the Authors. * + * Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -37,6 +37,12 @@ using namespace Lepton; using namespace std; +static bool isZero(const ExpressionTreeNode& node) { + if (node.getOperation().getId() != Operation::CONSTANT) + return false; + return dynamic_cast(node.getOperation()).getValue() == 0.0; +} + double Operation::Erf::evaluate(double* args, const map& variables) const { return erf(args[0]); } @@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { if (function->getNumArguments() == 0) return ExpressionTreeNode(new Operation::Constant(0.0)); - ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]); - for (int i = 1; i < getNumArguments(); i++) { - result = ExpressionTreeNode(new Operation::Add(), - result, - ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); + ExpressionTreeNode result; + bool foundTerm = false; + for (int i = 0; i < getNumArguments(); i++) { + if (!isZero(childDerivs[i])) { + if (foundTerm) + result = ExpressionTreeNode(new Operation::Add(), + result, + ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); + else { + result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]); + foundTerm = true; + } + } } - return result; + if (foundTerm) + return result; + return ExpressionTreeNode(new Operation::Constant(0.0)); } ExpressionTreeNode Operation::Add::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return childDerivs[1]; + if (isZero(childDerivs[1])) + return childDerivs[0]; return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]); } ExpressionTreeNode Operation::Subtract::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + return ExpressionTreeNode(new Operation::Negate(), childDerivs[1]); + } + if (isZero(childDerivs[1])) + return childDerivs[0]; return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]); } ExpressionTreeNode Operation::Multiply::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + return ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]); + } + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]); return ExpressionTreeNode(new Operation::Add(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0])); } ExpressionTreeNode Operation::Divide::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { - return ExpressionTreeNode(new Operation::Divide(), - ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode subexp; + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + subexp = ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])); + } + else if (isZero(childDerivs[1])) + subexp = ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]); + else + subexp = ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), - ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])), - ExpressionTreeNode(new Operation::Square(), children[1])); + ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])); + return ExpressionTreeNode(new Operation::Divide(), subexp, ExpressionTreeNode(new Operation::Square(), children[1])); } ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { @@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]); } ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(0.5), ExpressionTreeNode(new Operation::Reciprocal(), @@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Exp(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Log::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Sin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cos(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Sin(), children[0])), @@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sec(), children[0]), @@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), @@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Square(), ExpressionTreeNode(new Operation::Sec(), children[0])), @@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Square(), @@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Sqrt(), @@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::AddConstant(1.0), @@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cosh(), children[0]), @@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sinh(), children[0]), @@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Constant(1.0)), @@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))), @@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))), @@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(2.0), children[0]), @@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(3.0), ExpressionTreeNode(new Operation::Square(), children[0])), @@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::MultiplyConstant(value), childDerivs[0]); } ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(value), ExpressionTreeNode(new Operation::PowerConstant(value-1), @@ -305,22 +393,18 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); - return ExpressionTreeNode(new Operation::Subtract(), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], - ExpressionTreeNode(new Operation::AddConstant(-1), step))); + return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[1], childDerivs[0]}); } ExpressionTreeNode Operation::Max::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); - return ExpressionTreeNode(new Operation::Subtract(), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], - ExpressionTreeNode(new Operation::AddConstant(-1), step))); + return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[0], childDerivs[1]}); } ExpressionTreeNode Operation::Abs::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); ExpressionTreeNode step(new Operation::Step(), children[0]); return ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], @@ -337,9 +421,5 @@ ExpressionTreeNode Operation::Ceil::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { - vector derivChildren; - derivChildren.push_back(children[0]); - derivChildren.push_back(childDerivs[1]); - derivChildren.push_back(childDerivs[2]); - return ExpressionTreeNode(new Operation::Select(), derivChildren); + return ExpressionTreeNode(new Operation::Select(), {children[0], childDerivs[1], childDerivs[2]}); } diff --git a/lib/colvars/lepton/src/ParsedExpression.cpp b/lib/colvars/lepton/src/ParsedExpression.cpp index fd3b091d3c..f3f18fccd2 100644 --- a/lib/colvars/lepton/src/ParsedExpression.cpp +++ b/lib/colvars/lepton/src/ParsedExpression.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009 Stanford University and the Authors. * + * Portions copyright (c) 2009-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -31,6 +31,7 @@ #include "lepton/ParsedExpression.h" #include "lepton/CompiledExpression.h" +#include "lepton/CompiledVectorExpression.h" #include "lepton/ExpressionProgram.h" #include "lepton/Operation.h" #include @@ -68,9 +69,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map examples; + result.assignTags(examples); + map nodeCache; + result = precalculateConstantSubexpressions(result, nodeCache); while (true) { - ExpressionTreeNode simplified = substituteSimplerExpression(result); + examples.clear(); + result.assignTags(examples); + nodeCache.clear(); + ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache); if (simplified == result) break; result = simplified; @@ -80,9 +88,15 @@ ParsedExpression ParsedExpression::optimize() const { ParsedExpression ParsedExpression::optimize(const map& variables) const { ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables); - result = precalculateConstantSubexpressions(result); + vector examples; + result.assignTags(examples); + map nodeCache; + result = precalculateConstantSubexpressions(result, nodeCache); while (true) { - ExpressionTreeNode simplified = substituteSimplerExpression(result); + examples.clear(); + result.assignTags(examples); + nodeCache.clear(); + ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache); if (simplified == result) break; result = simplified; @@ -104,36 +118,67 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo return ExpressionTreeNode(node.getOperation().clone(), children); } -ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) { +ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node, map& nodeCache) { + auto cached = nodeCache.find(node.tag); + if (cached != nodeCache.end()) + return cached->second; vector children(node.getChildren().size()); for (int i = 0; i < (int) children.size(); i++) - children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); + children[i] = precalculateConstantSubexpressions(node.getChildren()[i], nodeCache); ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children); - if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) + if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) { + nodeCache[node.tag] = result; return result; + } for (int i = 0; i < (int) children.size(); i++) - if (children[i].getOperation().getId() != Operation::CONSTANT) + if (children[i].getOperation().getId() != Operation::CONSTANT) { + nodeCache[node.tag] = result; return result; - return ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); + } + result = ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); + nodeCache[node.tag] = result; + return result; } -ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) { +ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node, map& nodeCache) { vector children(node.getChildren().size()); - for (int i = 0; i < (int) children.size(); i++) - children[i] = substituteSimplerExpression(node.getChildren()[i]); + for (int i = 0; i < (int) children.size(); i++) { + const ExpressionTreeNode& child = node.getChildren()[i]; + auto cached = nodeCache.find(child.tag); + if (cached == nodeCache.end()) { + children[i] = substituteSimplerExpression(child, nodeCache); + nodeCache[child.tag] = children[i]; + } + else + children[i] = cached->second; + } + + // Collect some info on constant expressions in children + bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant? + bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant? + double first, second; // if yes, value of first and second child + if (first_const) + first = getConstantValue(children[0]); + if (second_const) + second = getConstantValue(children[1]); + switch (node.getOperation().getId()) { case Operation::ADD: { - double first = getConstantValue(children[0]); - double second = getConstantValue(children[1]); - if (first == 0.0) // Add 0 - return children[1]; - if (second == 0.0) // Add 0 - return children[0]; - if (first == first) // Add a constant - return ExpressionTreeNode(new Operation::AddConstant(first), children[1]); - if (second == second) // Add a constant - return ExpressionTreeNode(new Operation::AddConstant(second), children[0]); + if (first_const) { + if (first == 0.0) { // Add 0 + return children[1]; + } else { // Add a constant + return ExpressionTreeNode(new Operation::AddConstant(first), children[1]); + } + } + if (second_const) { + if (second == 0.0) { // Add 0 + return children[0]; + } else { // Add a constant + return ExpressionTreeNode(new Operation::AddConstant(second), children[0]); + } + } if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]); if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a @@ -144,34 +189,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio { if (children[0] == children[1]) return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0 - double first = getConstantValue(children[0]); - if (first == 0.0) // Subtract from 0 - return ExpressionTreeNode(new Operation::Negate(), children[1]); - double second = getConstantValue(children[1]); - if (second == 0.0) // Subtract 0 - return children[0]; - if (second == second) // Subtract a constant - return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]); + if (first_const) { + if (first == 0.0) // Subtract from 0 + return ExpressionTreeNode(new Operation::Negate(), children[1]); + } + if (second_const) { + if (second == 0.0) { // Subtract 0 + return children[0]; + } else { // Subtract a constant + return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]); + } + } if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]); break; } case Operation::MULTIPLY: - { - double first = getConstantValue(children[0]); - double second = getConstantValue(children[1]); - if (first == 0.0 || second == 0.0) // Multiply by 0 + { + if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0 return ExpressionTreeNode(new Operation::Constant(0.0)); - if (first == 1.0) // Multiply by 1 + if (first_const && first == 1.0) // Multiply by 1 return children[1]; - if (second == 1.0) // Multiply by 1 + if (second_const && second == 1.0) // Multiply by 1 return children[0]; - if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant + if (first_const) { // Multiply by a constant if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast(&children[1].getOperation())->getValue()), children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]); } - if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant + if (second_const) { // Multiply by a constant if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]); @@ -202,18 +248,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio { if (children[0] == children[1]) return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0 - double numerator = getConstantValue(children[0]); - if (numerator == 0.0) // 0 divided by something + if (first_const && first == 0.0) // 0 divided by something return ExpressionTreeNode(new Operation::Constant(0.0)); - if (numerator == 1.0) // 1 divided by something + if (first_const && first == 1.0) // 1 divided by something return ExpressionTreeNode(new Operation::Reciprocal(), children[1]); - double denominator = getConstantValue(children[1]); - if (denominator == 1.0) // Divide by 1 + if (second_const && second == 1.0) // Divide by 1 return children[0]; - if (children[1].getOperation().getId() == Operation::CONSTANT) { + if (second_const) { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply - return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]); - return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply + return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&children[0].getOperation())->getValue()/second), children[0].getChildren()[0]); + return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/second), children[0]); // Replace a divide with a multiply } if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]); @@ -229,34 +273,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio } case Operation::POWER: { - double base = getConstantValue(children[0]); - if (base == 0.0) // 0 to any power is 0 + if (first_const && first == 0.0) // 0 to any power is 0 return ExpressionTreeNode(new Operation::Constant(0.0)); - if (base == 1.0) // 1 to any power is 1 + if (first_const && first == 1.0) // 1 to any power is 1 return ExpressionTreeNode(new Operation::Constant(1.0)); - double exponent = getConstantValue(children[1]); - if (exponent == 0.0) // x^0 = 1 - return ExpressionTreeNode(new Operation::Constant(1.0)); - if (exponent == 1.0) // x^1 = x - return children[0]; - if (exponent == -1.0) // x^-1 = recip(x) - return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); - if (exponent == 2.0) // x^2 = square(x) - return ExpressionTreeNode(new Operation::Square(), children[0]); - if (exponent == 3.0) // x^3 = cube(x) - return ExpressionTreeNode(new Operation::Cube(), children[0]); - if (exponent == 0.5) // x^0.5 = sqrt(x) - return ExpressionTreeNode(new Operation::Sqrt(), children[0]); - if (exponent == exponent) // Constant power - return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]); + if (second_const) { // Constant exponent + if (second == 0.0) // x^0 = 1 + return ExpressionTreeNode(new Operation::Constant(1.0)); + if (second == 1.0) // x^1 = x + return children[0]; + if (second == -1.0) // x^-1 = recip(x) + return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); + if (second == 2.0) // x^2 = square(x) + return ExpressionTreeNode(new Operation::Square(), children[0]); + if (second == 3.0) // x^3 = cube(x) + return ExpressionTreeNode(new Operation::Cube(), children[0]); + if (second == 0.5) // x^0.5 = sqrt(x) + return ExpressionTreeNode(new Operation::Sqrt(), children[0]); + // Constant power + return ExpressionTreeNode(new Operation::PowerConstant(second), children[0]); + } break; } case Operation::NEGATE: { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); - if (children[0].getOperation().getId() == Operation::CONSTANT) // Negate a constant - return ExpressionTreeNode(new Operation::Constant(-getConstantValue(children[0]))); + if (first_const) // Negate a constant + return ExpressionTreeNode(new Operation::Constant(-first)); if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel return children[0].getChildren()[0]; break; @@ -265,7 +309,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&node.getOperation())->getValue()*dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); - if (children[0].getOperation().getId() == Operation::CONSTANT) // Multiply two constants + if (first_const) // Multiply two constants return ExpressionTreeNode(new Operation::Constant(dynamic_cast(&node.getOperation())->getValue()*getConstantValue(children[0]))); if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&node.getOperation())->getValue()), children[0].getChildren()[0]); @@ -293,20 +337,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio } ParsedExpression ParsedExpression::differentiate(const string& variable) const { - return differentiate(getRootNode(), variable); + vector examples; + getRootNode().assignTags(examples); + map nodeCache; + return differentiate(getRootNode(), variable, nodeCache); } -ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) { +ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable, map& nodeCache) { + auto cached = nodeCache.find(node.tag); + if (cached != nodeCache.end()) + return cached->second; vector childDerivs(node.getChildren().size()); for (int i = 0; i < (int) childDerivs.size(); i++) - childDerivs[i] = differentiate(node.getChildren()[i], variable); - return node.getOperation().differentiate(node.getChildren(),childDerivs, variable); + childDerivs[i] = differentiate(node.getChildren()[i], variable, nodeCache); + ExpressionTreeNode result = node.getOperation().differentiate(node.getChildren(), childDerivs, variable); + nodeCache[node.tag] = result; + return result; +} + +bool ParsedExpression::isConstant(const ExpressionTreeNode& node) { + return (node.getOperation().getId() == Operation::CONSTANT); } double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) { - if (node.getOperation().getId() == Operation::CONSTANT) - return dynamic_cast(node.getOperation()).getValue(); - return numeric_limits::quiet_NaN(); + if (node.getOperation().getId() != Operation::CONSTANT) { + throw Exception("getConstantValue called on a non-constant ExpressionNode"); + } + return dynamic_cast(node.getOperation()).getValue(); } ExpressionProgram ParsedExpression::createProgram() const { @@ -317,6 +374,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const { return CompiledExpression(*this); } +CompiledVectorExpression ParsedExpression::createCompiledVectorExpression(int width) const { + return CompiledVectorExpression(*this, width); +} + ParsedExpression ParsedExpression::renameVariables(const map& replacements) const { return ParsedExpression(renameNodeVariables(getRootNode(), replacements)); } diff --git a/lib/colvars/lepton/src/Parser.cpp b/lib/colvars/lepton/src/Parser.cpp index e284add258..47ebac464a 100644 --- a/lib/colvars/lepton/src/Parser.cpp +++ b/lib/colvars/lepton/src/Parser.cpp @@ -66,7 +66,7 @@ private: string Parser::trim(const string& expression) { // Remove leading and trailing spaces. - + int start, end; for (start = 0; start < (int) expression.size() && isspace(expression[start]); start++) ;