diff --git a/lib/lepton/include/lepton/CompiledExpression.h b/lib/lepton/include/lepton/CompiledExpression.h index 8ead5ce96f..6c940e081c 100644 --- a/lib/lepton/include/lepton/CompiledExpression.h +++ b/lib/lepton/include/lepton/CompiledExpression.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -40,7 +40,11 @@ #include #include #ifdef LEPTON_USE_JIT - #include "asmjit.h" +#if defined(__ARM__) || defined(__ARM64__) +#include "asmjit/a64.h" +#else +#include "asmjit/x86.h" +#endif #endif namespace LMP_Lepton { @@ -101,9 +105,15 @@ private: std::map dummyVariables; double (*jitCode)(); #ifdef LEPTON_USE_JIT + void findPowerGroups(std::vector >& groups, std::vector >& groupPowers, std::vector& stepGroup); void generateJitCode(); - void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg, double (*function)(double)); - void generateTwoArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg1, asmjit::X86Xmm& arg2, double (*function)(double, double)); +#if defined(__ARM__) || defined(__ARM64__) + void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, double (*function)(double, double)); +#else + void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg1, asmjit::x86::Xmm& arg2, double (*function)(double, double)); +#endif std::vector constants; 
asmjit::JitRuntime runtime; #endif diff --git a/lib/lepton/include/lepton/CompiledVectorExpression.h b/lib/lepton/include/lepton/CompiledVectorExpression.h new file mode 100644 index 0000000000..e097e3eae1 --- /dev/null +++ b/lib/lepton/include/lepton/CompiledVectorExpression.h @@ -0,0 +1,145 @@ +#ifndef LEPTON_VECTOR_EXPRESSION_H_ +#define LEPTON_VECTOR_EXPRESSION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "ExpressionTreeNode.h" +#include "windowsIncludes.h" +#include +#include +#include +#include +#include +#include +#ifdef LEPTON_USE_JIT +#if defined(__ARM__) || defined(__ARM64__) +#include "asmjit/a64.h" +#else +#include "asmjit/x86.h" +#endif +#endif + +namespace LMP_Lepton { + +class Operation; +class ParsedExpression; + +/** + * A CompiledVectorExpression is a highly optimized representation of an expression for cases when you want to evaluate + * it many times as quickly as possible. It is similar to CompiledExpression, with the extra feature that it uses the CPU's + * vector unit (AVX on x86, NEON on ARM) to evaluate the expression for multiple sets of arguments at once. It also differs + * from CompiledExpression and ParsedExpression in using single precision rather than double precision to evaluate the expression. + * You should treat it as an opaque object; none of the internal representation is visible. + * + * A CompiledVectorExpression is created by calling createCompiledVectorExpression() on a ParsedExpression. When you create + * it, you must specify the width of the vectors on which to compute the expression. The allowed widths depend on the type of + * CPU it is running on. 4 is always allowed, and 8 is allowed on x86 processors with AVX. Call getAllowedWidths() to query + * the allowed values. + * + * WARNING: CompiledVectorExpression is NOT thread safe. You should never access a CompiledVectorExpression from two threads at + * the same time. 
+ */ + +class LEPTON_EXPORT CompiledVectorExpression { +public: + CompiledVectorExpression(); + CompiledVectorExpression(const CompiledVectorExpression& expression); + ~CompiledVectorExpression(); + CompiledVectorExpression& operator=(const CompiledVectorExpression& expression); + /** + * Get the width of the vectors on which the expression is computed. + */ + int getWidth() const; + /** + * Get the names of all variables used by this expression. + */ + const std::set& getVariables() const; + /** + * Get a pointer to the memory location where the value of a particular variable is stored. This can be used + * to set the value of the variable before calling evaluate(). + * + * @param name the name of the variable to query + * @return a pointer to N floating point values, where N is the vector width + */ + float* getVariablePointer(const std::string& name); + /** + * You can optionally specify the memory locations from which the values of variables should be read. + * This is useful, for example, when several expressions all use the same variable. You can then set + * the value of that variable in one place, and it will be seen by all of them. The location should + * be a pointer to N floating point values, where N is the vector width. + */ + void setVariableLocations(std::map& variableLocations); + /** + * Evaluate the expression. The values of all variables should have been set before calling this. + * + * @return a pointer to N floating point values, where N is the vector width + */ + const float* evaluate() const; + /** + * Get the list of vector widths that are supported on the current processor. 
+ */ + static const std::vector& getAllowedWidths(); +private: + friend class ParsedExpression; + CompiledVectorExpression(const ParsedExpression& expression, int width); + void compileExpression(const ExpressionTreeNode& node, std::vector >& temps, int& workspaceSize); + int findTempIndex(const ExpressionTreeNode& node, std::vector >& temps); + int width; + std::map variablePointers; + std::vector > variablesToCopy; + std::vector > arguments; + std::vector target; + std::vector operation; + std::map variableIndices; + std::set variableNames; + mutable std::vector workspace; + mutable std::vector argValues; + std::map dummyVariables; + void (*jitCode)(); +#ifdef LEPTON_USE_JIT + void findPowerGroups(std::vector >& groups, std::vector >& groupPowers, std::vector& stepGroup); + void generateJitCode(); +#if defined(__ARM__) || defined(__ARM64__) + void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, float (*function)(float)); + void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, float (*function)(float, float)); +#else + void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg, float (*function)(float)); + void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg1, asmjit::x86::Ymm& arg2, float (*function)(float, float)); +#endif + std::vector constants; + asmjit::JitRuntime runtime; +#endif +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_VECTOR_EXPRESSION_H_*/ diff --git a/lib/lepton/include/lepton/ExpressionTreeNode.h b/lib/lepton/include/lepton/ExpressionTreeNode.h index 514cc008a9..eba791fbaa 100644 --- a/lib/lepton/include/lepton/ExpressionTreeNode.h +++ b/lib/lepton/include/lepton/ExpressionTreeNode.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. 
* * * - * Portions copyright (c) 2009 Stanford University and the Authors. * + * Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -39,6 +39,7 @@ namespace LMP_Lepton { class Operation; +class ParsedExpression; /** * This class represents a node in the abstract syntax tree representation of an expression. @@ -82,11 +83,13 @@ public: */ ExpressionTreeNode(Operation* operation); ExpressionTreeNode(const ExpressionTreeNode& node); + ExpressionTreeNode(ExpressionTreeNode&& node); ExpressionTreeNode(); ~ExpressionTreeNode(); bool operator==(const ExpressionTreeNode& node) const; bool operator!=(const ExpressionTreeNode& node) const; ExpressionTreeNode& operator=(const ExpressionTreeNode& node); + ExpressionTreeNode& operator=(ExpressionTreeNode&& node); /** * Get the Operation performed by this node. */ @@ -96,8 +99,11 @@ public: */ const std::vector& getChildren() const; private: + friend class ParsedExpression; + void assignTags(std::vector& examples) const; Operation* operation; std::vector children; + mutable int tag; }; } // namespace LMP_Lepton diff --git a/lib/lepton/include/lepton/ParsedExpression.h b/lib/lepton/include/lepton/ParsedExpression.h index 586acb4d2c..05081f677c 100644 --- a/lib/lepton/include/lepton/ParsedExpression.h +++ b/lib/lepton/include/lepton/ParsedExpression.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009=2013 Stanford University and the Authors. * + * Portions copyright (c) 2009-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -41,6 +41,7 @@ namespace LMP_Lepton { class CompiledExpression; class ExpressionProgram; +class CompiledVectorExpression; /** * This class represents the result of parsing an expression. 
It provides methods for working with the @@ -102,6 +103,16 @@ public: * Create a CompiledExpression that represents the same calculation as this expression. */ CompiledExpression createCompiledExpression() const; + /** + * Create a CompiledVectorExpression that allows the expression to be evaluated efficiently + * using the CPU's vector unit. + * + * @param width the width of the vectors to evaluate it on. The allowed values + * depend on the CPU. 4 is always allowed, and 8 is allowed on + * x86 processors with AVX. Call CompiledVectorExpression::getAllowedWidths() + * to query the allowed widths on the current processor. + */ + CompiledVectorExpression createCompiledVectorExpression(int width) const; /** * Create a new ParsedExpression which is identical to this one, except that the names of some * variables have been changed. @@ -113,9 +124,9 @@ public: private: static double evaluate(const ExpressionTreeNode& node, const std::map& variables); static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map& variables); - static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); - static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); - static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable); + static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node, std::map& nodeCache); + static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node, std::map& nodeCache); + static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable, std::map& nodeCache); static bool isConstant(const ExpressionTreeNode& node); static double getConstantValue(const ExpressionTreeNode& node); static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map& replacements); diff --git a/lib/lepton/src/CompiledExpression.cpp 
b/lib/lepton/src/CompiledExpression.cpp index c6c1543ce4..b85c3a08f7 100644 --- a/lib/lepton/src/CompiledExpression.cpp +++ b/lib/lepton/src/CompiledExpression.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -151,7 +151,7 @@ void CompiledExpression::setVariableLocations(map& variableLoca if (workspace.size() > 0) generateJitCode(); -#else +#endif // Make a list of all variables we will need to copy before evaluating the expression. variablesToCopy.clear(); @@ -160,13 +160,11 @@ void CompiledExpression::setVariableLocations(map& variableLoca if (pointer != variablePointers.end()) variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second)); } -#endif } double CompiledExpression::evaluate() const { -#ifdef LEPTON_USE_JIT - return jitCode(); -#else + if (jitCode) + return jitCode(); for (int i = 0; i < (int)variablesToCopy.size(); i++) *variablesToCopy[i].first = *variablesToCopy[i].second; @@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const { } } return workspace[workspace.size()-1]; -#endif } #ifdef LEPTON_USE_JIT @@ -192,24 +189,70 @@ static double evaluateOperation(Operation* op, double* args) { return op->evaluate(args, dummyVariables); } +void CompiledExpression::findPowerGroups(vector >& groups, vector >& groupPowers, vector& stepGroup) { + // Identify every step that raises an argument to an integer power. 
+ + vector stepPower(operation.size(), 0); + vector stepArg(operation.size(), -1); + for (int step = 0; step < (int)operation.size(); step++) { + Operation& op = *operation[step]; + int power = 0; + if (op.getId() == Operation::SQUARE) + power = 2; + else if (op.getId() == Operation::CUBE) + power = 3; + else if (op.getId() == Operation::POWER_CONSTANT) { + double realPower = dynamic_cast(&op)->getValue(); + if (realPower == (int) realPower) + power = (int) realPower; + } + if (power != 0) { + stepPower[step] = power; + stepArg[step] = arguments[step][0]; + } + } + + // Find groups that operate on the same argument and whose powers have the same sign. + + stepGroup.resize(operation.size(), -1); + for (int i = 0; i < (int)operation.size(); i++) { + if (stepGroup[i] != -1) + continue; + vector group, power; + for (int j = i; j < (int)operation.size(); j++) { + if (stepArg[i] == stepArg[j] && stepPower[i]*stepPower[j] > 0) { + stepGroup[j] = groups.size(); + group.push_back(j); + power.push_back(stepPower[j]); + } + } + groups.push_back(group); + groupPowers.push_back(power); + } +} + +#if defined(__ARM__) || defined(__ARM64__) void CompiledExpression::generateJitCode() { CodeHolder code; - code.init(runtime.getCodeInfo()); - X86Compiler c(&code); - c.addFunc(FuncSignature0()); - vector workspaceVar(workspace.size()); + code.init(runtime.environment()); + a64::Compiler c(&code); + c.addFunc(FuncSignatureT()); + vector workspaceVar(workspace.size()); for (int i = 0; i < (int) workspaceVar.size(); i++) - workspaceVar[i] = c.newXmmSd(); - X86Gp argsPointer = c.newIntPtr(); - c.mov(argsPointer, imm_ptr(&argValues[0])); + workspaceVar[i] = c.newVecD(); + arm::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); // Load the arguments into variables. 
for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { map::iterator index = variableIndices.find(*iter); - X86Gp variablePointer = c.newIntPtr(); - c.mov(variablePointer, imm_ptr(&getVariableReference(index->first))); - c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + arm::Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm(&getVariableReference(index->first))); + c.ldr(workspaceVar[index->second], arm::ptr(variablePointer, 0)); } // Make a list of all constants that will be needed for evaluation. @@ -232,6 +275,12 @@ void CompiledExpression::generateJitCode() { value = 1.0; else if (op.getId() == Operation::DELTA) value = 1.0; + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast(op).getValue(); + else + value = 1.0; + } else continue; @@ -250,19 +299,63 @@ void CompiledExpression::generateJitCode() { // Load constants into variables. - vector constantVar(constants.size()); + vector constantVar(constants.size()); if (constants.size() > 0) { - X86Gp constantsPointer = c.newIntPtr(); - c.mov(constantsPointer, imm_ptr(&constants[0])); + arm::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); for (int i = 0; i < (int) constants.size(); i++) { - constantVar[i] = c.newXmmSd(); - c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + constantVar[i] = c.newVecD(); + c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i)); } } // Evaluate the operations. + vector hasComputedPower(operation.size(), false); for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. 
+ + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + arm::Vec multiplier = c.newVecD(); + if (powers[0] > 0) + c.fmov(multiplier, workspaceVar[arguments[step][0]]); + else { + c.fdiv(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i]%2 == 1) { + if (!hasAssigned[i]) + c.fmov(workspaceVar[target[group[i]]], multiplier); + else + c.fmul(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.fmul(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. 
+ Operation& op = *operation[step]; vector args = arguments[step]; if (args.size() == 1) { @@ -276,33 +369,28 @@ void CompiledExpression::generateJitCode() { switch (op.getId()) { case Operation::CONSTANT: - c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fmov(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::ADD: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::SUBTRACT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fsub(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::MULTIPLY: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::DIVIDE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fdiv(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::POWER: generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); break; case Operation::NEGATE: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fneg(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::SQRT: - c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fsqrt(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::EXP: generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); @@ -341,56 +429,63 @@ void 
CompiledExpression::generateJitCode() { generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); break; case Operation::STEP: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 - c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.cmge(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::DELTA: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 - c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.cmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::SQUARE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); break; case Operation::CUBE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::RECIPROCAL: - c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); - c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fdiv(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); break; case 
Operation::ADD_CONSTANT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); break; case Operation::MULTIPLY_CONSTANT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); + break; + case Operation::MIN: + c.fmin(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.fmax(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::ABS: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs); + c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::FLOOR: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor); + c.frintm(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::CEIL: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil); + c.frintp(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SELECT: + c.fcmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]); break; default: // Just invoke evaluateOperation(). 
for (int i = 0; i < (int) args.size(); i++) - c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) evaluateOperation)); - CCFuncCall* call = c.call(fn, FuncSignature2()); - call->setArg(0, imm_ptr(&op)); - call->setArg(1, imm_ptr(&argValues[0])); - call->setRet(0, workspaceVar[target[step]]); + c.str(workspaceVar[args[i]], arm::ptr(argsPointer, 8*i)); + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, workspaceVar[target[step]]); } } c.ret(workspaceVar[workspace.size()-1]); @@ -399,20 +494,319 @@ void CompiledExpression::generateJitCode() { runtime.add(&jitCode, &code); } -void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) { - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) function)); - CCFuncCall* call = c.call(fn, FuncSignature1()); - call->setArg(0, arg); - call->setRet(0, dest); +void CompiledExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, double (*function)(double)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg); + invoke->setRet(0, dest); } -void CompiledExpression::generateTwoArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg1, X86Xmm& arg2, double (*function)(double, double)) { - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) function)); - CCFuncCall* call = c.call(fn, FuncSignature2()); - call->setArg(0, arg1); - call->setArg(1, arg2); - call->setRet(0, dest); +void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, double (*function)(double, double)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; 
+ c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg1); + invoke->setArg(1, arg2); + invoke->setRet(0, dest); +} +#else +void CompiledExpression::generateJitCode() { + const CpuInfo& cpu = CpuInfo::host(); + if (!cpu.hasFeature(CpuFeatures::X86::kAVX)) + return; + CodeHolder code; + code.init(runtime.environment()); + x86::Compiler c(&code); + FuncNode* funcNode = c.addFunc(FuncSignatureT()); + funcNode->frame().setAvxEnabled(); + vector workspaceVar(workspace.size()); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newXmmSd(); + x86::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + x86::Gp variablePointer = c.newIntPtr(); + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + c.mov(variablePointer, imm(&getVariableReference(index->first))); + c.vmovsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. 
+ + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::ABS) { + long long mask = 0x7FFFFFFFFFFFFFFF; + value = *reinterpret_cast(&mask); + } + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast(op).getValue(); + else + value = 1.0; + } + else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + x86::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newXmmSd(); + c.vmovsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. 
+ + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + x86::Xmm multiplier = c.newXmmSd(); + if (powers[0] > 0) + c.vmovsd(multiplier, workspaceVar[arguments[step][0]], workspaceVar[arguments[step][0]]); + else { + c.vdivsd(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < (int)powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < (int)group.size(); i++) { + if (powers[i]%2 == 1) { + if (!hasAssigned[i]) + c.vmovsd(workspaceVar[target[group[i]]], multiplier, multiplier); + else + c.vmulsd(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.vmulsd(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0]+i); + } + + // Generate instructions to execute this operation. 
+ + switch (op.getId()) { + case Operation::CONSTANT: + c.vmovsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.vsubsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.vdivsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); + break; + case Operation::NEGATE: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vsubsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.vsqrtsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos); + break; + case Operation::ATAN: + 
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); + break; + case Operation::STEP: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.vmulsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.vdivsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case 
Operation::MULTIPLY_CONSTANT: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); + break; + case Operation::MIN: + c.vminsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.vmaxsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::ABS: + c.vandpd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::FLOOR: + c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(1)); + break; + case Operation::CEIL: + c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(2)); + break; + case Operation::SELECT: + { + x86::Xmm mask = c.newXmmSd(); + c.vxorps(mask, mask, mask); + c.vcmpsd(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask); + break; + } + default: + // Just invoke evaluateOperation(). 
+ + for (int i = 0; i < (int) args.size(); i++) + c.vmovsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, workspaceVar[target[step]]); + } + } + c.ret(workspaceVar[workspace.size()-1]); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledExpression::generateSingleArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg, double (*function)(double)) { + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg); + invoke->setRet(0, dest); +} + +void CompiledExpression::generateTwoArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg1, x86::Xmm& arg2, double (*function)(double, double)) { + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg1); + invoke->setArg(1, arg2); + invoke->setRet(0, dest); } #endif +#endif diff --git a/lib/lepton/src/CompiledVectorExpression.cpp b/lib/lepton/src/CompiledVectorExpression.cpp new file mode 100644 index 0000000000..7e4dfcad9c --- /dev/null +++ b/lib/lepton/src/CompiledVectorExpression.cpp @@ -0,0 +1,933 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. 
* + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. 
* + * -------------------------------------------------------------------------- */ + +#include "lepton/CompiledVectorExpression.h" +#include "lepton/Operation.h" +#include "lepton/ParsedExpression.h" +#include +#include + +using namespace LMP_Lepton; +using namespace std; +#ifdef LEPTON_USE_JIT +using namespace asmjit; +#endif + +CompiledVectorExpression::CompiledVectorExpression() : jitCode(NULL) { +} + +CompiledVectorExpression::CompiledVectorExpression(const ParsedExpression& expression, int width) : width(width), jitCode(NULL) { + const vector allowedWidths = getAllowedWidths(); + if (find(allowedWidths.begin(), allowedWidths.end(), width) == allowedWidths.end()) + throw Exception("Unsupported width for vector expression: "+to_string(width)); + ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized. + vector > temps; + int workspaceSize = 0; + compileExpression(expr.getRootNode(), temps, workspaceSize); + workspace.resize(workspaceSize*width); + int maxArguments = 1; + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i]->getNumArguments() > maxArguments) + maxArguments = operation[i]->getNumArguments(); + argValues.resize(maxArguments); +#ifdef LEPTON_USE_JIT + generateJitCode(); +#endif +} + +CompiledVectorExpression::~CompiledVectorExpression() { + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i] != NULL) + delete operation[i]; +} + +CompiledVectorExpression::CompiledVectorExpression(const CompiledVectorExpression& expression) : jitCode(NULL) { + *this = expression; +} + +CompiledVectorExpression& CompiledVectorExpression::operator=(const CompiledVectorExpression& expression) { + arguments = expression.arguments; + width = expression.width; + target = expression.target; + variableIndices = expression.variableIndices; + variableNames = expression.variableNames; + workspace.resize(expression.workspace.size()); + argValues.resize(expression.argValues.size()); + 
operation.resize(expression.operation.size()); + for (int i = 0; i < (int) operation.size(); i++) + operation[i] = expression.operation[i]->clone(); + setVariableLocations(variablePointers); + return *this; +} + +const vector& CompiledVectorExpression::getAllowedWidths() { + static vector widths; + if (widths.size() == 0) { + widths.push_back(4); +#ifdef LEPTON_USE_JIT + const CpuInfo& cpu = CpuInfo::host(); + if (cpu.hasFeature(CpuFeatures::X86::kAVX)) + widths.push_back(8); +#endif + } + return widths; +} + +void CompiledVectorExpression::compileExpression(const ExpressionTreeNode& node, vector >& temps, int& workspaceSize) { + if (findTempIndex(node, temps) != -1) + return; // We have already processed a node identical to this one. + + // Process the child nodes. + + vector args; + for (int i = 0; i < (int)node.getChildren().size(); i++) { + compileExpression(node.getChildren()[i], temps, workspaceSize); + args.push_back(findTempIndex(node.getChildren()[i], temps)); + } + + // Process this node. + + if (node.getOperation().getId() == Operation::VARIABLE) { + variableIndices[node.getOperation().getName()] = workspaceSize; + variableNames.insert(node.getOperation().getName()); + } + else { + int stepIndex = (int) arguments.size(); + arguments.push_back(vector()); + target.push_back(workspaceSize); + operation.push_back(node.getOperation().clone()); + if (args.size() == 0) + arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there. + else { + // If the arguments are sequential, we can just pass a pointer to the first one. 
+ + bool sequential = true; + for (int i = 1; i < (int)args.size(); i++) + if (args[i] != args[i - 1] + 1) + sequential = false; + if (sequential) + arguments[stepIndex].push_back(args[0]); + else + arguments[stepIndex] = args; + } + } + temps.push_back(make_pair(node, workspaceSize)); + workspaceSize++; +} + +int CompiledVectorExpression::findTempIndex(const ExpressionTreeNode& node, vector >& temps) { + for (int i = 0; i < (int) temps.size(); i++) + if (temps[i].first == node) + return i; + return -1; +} + +int CompiledVectorExpression::getWidth() const { + return width; +} + +const set& CompiledVectorExpression::getVariables() const { + return variableNames; +} + +float* CompiledVectorExpression::getVariablePointer(const string& name) { + map::iterator pointer = variablePointers.find(name); + if (pointer != variablePointers.end()) + return pointer->second; + map::iterator index = variableIndices.find(name); + if (index == variableIndices.end()) + throw Exception("getVariablePointer: Unknown variable '" + name + "'"); // The name is in neither variablePointers nor variableIndices. + return &workspace[index->second*width]; +} + +void CompiledVectorExpression::setVariableLocations(map& variableLocations) { + variablePointers = variableLocations; +#ifdef LEPTON_USE_JIT + // Rebuild the JIT code. + + if (workspace.size() > 0) + generateJitCode(); +#endif + // Make a list of all variables we will need to copy before evaluating the expression. 
+ + variablesToCopy.clear(); + for (map::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) { + map::iterator pointer = variablePointers.find(iter->first); + if (pointer != variablePointers.end()) + variablesToCopy.push_back(make_pair(&workspace[iter->second*width], pointer->second)); + } +} + +const float* CompiledVectorExpression::evaluate() const { + if (jitCode) { + jitCode(); + return &workspace[workspace.size()-width]; + } + for (int i = 0; i < (int)variablesToCopy.size(); i++) + for (int j = 0; j < width; j++) + variablesToCopy[i].first[j] = variablesToCopy[i].second[j]; + + // Loop over the operations and evaluate each one. + + for (int step = 0; step < (int)operation.size(); step++) { + const vector& args = arguments[step]; + if (args.size() == 1) { + for (int j = 0; j < width; j++) { + for (int i = 0; i < operation[step]->getNumArguments(); i++) + argValues[i] = workspace[(args[0]+i)*width+j]; + workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } else { + for (int j = 0; j < width; j++) { + for (int i = 0; i < (int)args.size(); i++) + argValues[i] = workspace[args[i]*width+j]; + workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } + } + return &workspace[workspace.size()-width]; +} + +#ifdef LEPTON_USE_JIT + +static double evaluateOperation(Operation* op, double* args) { + static map dummyVariables; + return op->evaluate(args, dummyVariables); +} + +void CompiledVectorExpression::findPowerGroups(vector >& groups, vector >& groupPowers, vector& stepGroup) { + // Identify every step that raises an argument to an integer power. 
+ + vector stepPower(operation.size(), 0); + vector stepArg(operation.size(), -1); + for (int step = 0; step < (int)operation.size(); step++) { + Operation& op = *operation[step]; + int power = 0; + if (op.getId() == Operation::SQUARE) + power = 2; + else if (op.getId() == Operation::CUBE) + power = 3; + else if (op.getId() == Operation::POWER_CONSTANT) { + double realPower = dynamic_cast (&op)->getValue(); + if (realPower == (int) realPower) + power = (int) realPower; + } + if (power != 0) { + stepPower[step] = power; + stepArg[step] = arguments[step][0]; + } + } + + // Find groups that operate on the same argument and whose powers have the same sign. + + stepGroup.resize(operation.size(), -1); + for (int i = 0; i < (int)operation.size(); i++) { + if (stepGroup[i] != -1) + continue; + vector group, power; + for (int j = i; j < (int)operation.size(); j++) { + if (stepArg[i] == stepArg[j] && stepPower[i] * stepPower[j] > 0) { + stepGroup[j] = groups.size(); + group.push_back(j); + power.push_back(stepPower[j]); + } + } + groups.push_back(group); + groupPowers.push_back(power); + } +} + +#if defined(__ARM__) || defined(__ARM64__) + +void CompiledVectorExpression::generateJitCode() { + CodeHolder code; + code.init(runtime.environment()); + a64::Compiler c(&code); + c.addFunc(FuncSignatureT()); + vector workspaceVar(workspace.size()/width); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newVecQ(); + arm::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. 
+ + arm::Gp variablePointer = c.newIntPtr(); + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + c.mov(variablePointer, imm(getVariablePointer(index->first))); + c.ldr(workspaceVar[index->second].s4(), arm::ptr(variablePointer, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + float value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast (op).getValue(); + else + value = 1.0; + } else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + arm::Gp constantsPointer = c.newIntPtr(); + for (int i = 0; i < (int) constants.size(); i++) { + c.mov(constantsPointer, imm(&constants[i])); + constantVar[i] = c.newVecQ(); + c.ld1r(constantVar[i].s4(), arm::ptr(constantsPointer)); + } + } + + // Evaluate the operations. 
+ + vector hasComputedPower(operation.size(), false); + arm::Vec argReg = c.newVecS(); + arm::Vec doubleArgReg = c.newVecD(); + arm::Vec doubleResultReg = c.newVecD(); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + arm::Vec multiplier = c.newVecQ(); + if (powers[0] > 0) + c.mov(multiplier.s4(), workspaceVar[arguments[step][0]].s4()); + else { + c.fdiv(multiplier.s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[arguments[step][0]].s4()); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i] % 2 == 1) { + if (!hasAssigned[i]) + c.mov(workspaceVar[target[group[i]]].s4(), multiplier.s4()); + else + c.fmul(workspaceVar[target[group[i]]].s4(), workspaceVar[target[group[i]]].s4(), multiplier.s4()); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.fmul(multiplier.s4(), multiplier.s4(), multiplier.s4()); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0] + i); + } + + // Generate instructions to execute this operation. 
+ + switch (op.getId()) { + case Operation::CONSTANT: + c.mov(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::ADD: + c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::SUBTRACT: + c.fsub(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::MULTIPLY: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::DIVIDE: + c.fdiv(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf); + break; + case Operation::NEGATE: + c.fneg(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::SQRT: + c.fsqrt(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf); + break; + case 
Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf); + break; + case Operation::STEP: + c.cmge(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.cmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::CUBE: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4()); + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::RECIPROCAL: + c.fdiv(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::ADD_CONSTANT: + c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::MULTIPLY_CONSTANT: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf); + break; + case Operation::MIN: + 
c.fmin(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::MAX: + c.fmax(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::ABS: + c.fabs(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::FLOOR: + c.frintm(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::CEIL: + c.frintp(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::SELECT: + c.fcmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]); + break; + default: + // Just invoke evaluateOperation(). + for (int element = 0; element < width; element++) { + for (int i = 0; i < (int) args.size(); i++) { + c.ins(argReg.s(0), workspaceVar[args[i]].s(element)); + c.fcvt(doubleArgReg, argReg); + c.str(doubleArgReg, arm::ptr(argsPointer, 8*i)); + } + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, doubleResultReg); + c.fcvt(argReg, doubleResultReg); + c.ins(workspaceVar[target[step]].s(element), argReg.s(0)); + } + } + } + arm::Gp resultPointer = c.newIntPtr(); + c.mov(resultPointer, imm(&workspace[workspace.size()-width])); + c.str(workspaceVar.back().s4(), arm::ptr(resultPointer, 0)); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledVectorExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, float (*function)(float)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + arm::Vec a = c.newVecS(); + arm::Vec d = c.newVecS(); + for (int element = 0; element < width; element++) { + c.ins(a.s(0), arg.s(element)); + InvokeNode* 
invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, a); + invoke->setRet(0, d); + c.ins(dest.s(element), d.s(0)); + } +} + +void CompiledVectorExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, float (*function)(float, float)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + arm::Vec a1 = c.newVecS(); + arm::Vec a2 = c.newVecS(); + arm::Vec d = c.newVecS(); + for (int element = 0; element < width; element++) { + c.ins(a1.s(0), arg1.s(element)); + c.ins(a2.s(0), arg2.s(element)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, a1); + invoke->setArg(1, a2); + invoke->setRet(0, d); + c.ins(dest.s(element), d.s(0)); + } +} +#else + +void CompiledVectorExpression::generateJitCode() { + const CpuInfo& cpu = CpuInfo::host(); + if (!cpu.hasFeature(CpuFeatures::X86::kAVX)) + return; + CodeHolder code; + code.init(runtime.environment()); + x86::Compiler c(&code); + FuncNode* funcNode = c.addFunc(FuncSignatureT()); + funcNode->frame().setAvxEnabled(); + vector workspaceVar(workspace.size()/width); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newYmmPs(); + x86::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + x86::Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm(getVariablePointer(index->first))); + if (width == 4) + c.vmovdqu(workspaceVar[index->second].xmm(), x86::ptr(variablePointer, 0, 0)); + else + c.vmovdqu(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. 
+ + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::ABS) { + int mask = 0x7FFFFFFF; + value = *reinterpret_cast(&mask); + } + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast (op).getValue(); + else + value = 1.0; + } else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + x86::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newYmmPs(); + c.vbroadcastss(constantVar[i], x86::ptr(constantsPointer, 4*i, 0)); + } + } + + // Evaluate the operations. 
+ + vector hasComputedPower(operation.size(), false); + x86::Ymm argReg = c.newYmm(); + x86::Ymm doubleArgReg = c.newYmm(); + x86::Ymm doubleResultReg = c.newYmm(); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + x86::Ymm multiplier = c.newYmmPs(); + if (powers[0] > 0) + c.vmovdqu(multiplier, workspaceVar[arguments[step][0]]); + else { + c.vdivps(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < (int)powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < (int)group.size(); i++) { + if (powers[i] % 2 == 1) { + if (!hasAssigned[i]) + c.vmovdqu(workspaceVar[target[group[i]]], multiplier); + else + c.vmulps(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.vmulps(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0] + i); + } + + // Generate instructions to execute this operation. 
+ + switch (op.getId()) { + case Operation::CONSTANT: + c.vmovdqu(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.vsubps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.vdivps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf); + break; + case Operation::NEGATE: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vsubps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.vsqrtps(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], 
workspaceVar[args[0]], atanf); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf); + break; + case Operation::STEP: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.vmulps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.vdivps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::MULTIPLY_CONSTANT: + c.vmulps(workspaceVar[target[step]],
workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf); + break; + case Operation::MIN: + c.vminps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.vmaxps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::ABS: + c.vandps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::FLOOR: + c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(1)); + break; + case Operation::CEIL: + c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(2)); + break; + case Operation::SELECT: + { + x86::Ymm mask = c.newYmmPs(); + c.vxorps(mask, mask, mask); + c.vcmpps(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask); + break; + } + default: + // Just invoke evaluateOperation(). 
+ + for (int element = 0; element < width; element++) { + for (int i = 0; i < (int) args.size(); i++) { + if (element < 4) + c.vshufps(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(element)); + else { + c.vperm2f128(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(1)); + c.vshufps(argReg, argReg, argReg, imm(element-4)); + } + c.vcvtss2sd(doubleArgReg.xmm(), doubleArgReg.xmm(), argReg.xmm()); + c.vmovsd(x86::ptr(argsPointer, 8*i, 0), doubleArgReg.xmm()); + } + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, doubleResultReg); + c.vcvtsd2ss(argReg.xmm(), argReg.xmm(), doubleResultReg.xmm()); + if (element > 3) + c.vperm2f128(argReg, argReg, argReg, imm(0)); + if (element != 0) + c.vshufps(argReg, argReg, argReg, imm(0)); + c.vblendps(workspaceVar[target[step]], workspaceVar[target[step]], argReg, 1<()); + invoke->setArg(0, a); + invoke->setRet(0, d); + if (element > 3) + c.vperm2f128(d, d, d, imm(0)); + if (element != 0) + c.vshufps(d, d, d, imm(0)); + c.vblendps(dest, dest, d, 1<()); + invoke->setArg(0, a1); + invoke->setArg(1, a2); + invoke->setRet(0, d); + if (element > 3) + c.vperm2f128(d, d, d, imm(0)); + if (element != 0) + c.vshufps(d, d, d, imm(0)); + c.vblendps(dest, dest, d, 1< using namespace LMP_Lepton; using namespace std; @@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? 
NULL : node.operation->clone()), children(node.getChildren()) { } +ExpressionTreeNode::ExpressionTreeNode(ExpressionTreeNode&& node) : operation(node.operation), children(move(node.children)) { + node.operation = NULL; + node.children.clear(); +} + ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) { } @@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node return *this; } +ExpressionTreeNode& ExpressionTreeNode::operator=(ExpressionTreeNode&& node) { + if (operation != NULL) + delete operation; + operation = node.operation; + children = move(node.children); + node.operation = NULL; + node.children.clear(); + return *this; +} + const Operation& ExpressionTreeNode::getOperation() const { return *operation; } @@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const { const vector& ExpressionTreeNode::getChildren() const { return children; } + +void ExpressionTreeNode::assignTags(vector& examples) const { + // Assign tag values to all nodes in a tree, such that two nodes have the same + // tag if and only if they (and all their children) are equal. This is used to + // optimize other operations. + + int numTags = examples.size(); + for (const ExpressionTreeNode& child : getChildren()) + child.assignTags(examples); + if (numTags == (int)examples.size()) { + // All the children matched existing tags, so possibly this node does too. + + for (int i = 0; i < (int)examples.size(); i++) { + const ExpressionTreeNode& example = *examples[i]; + bool matches = (getChildren().size() == example.getChildren().size() && getOperation() == example.getOperation()); + for (int j = 0; matches && j < (int)getChildren().size(); j++) + if (getChildren()[j].tag != example.getChildren()[j].tag) + matches = false; + if (matches) { + tag = i; + return; + } + } + } + + // This node does not match any previous node, so assign a new tag. 
+ + tag = examples.size(); + examples.push_back(this); +} diff --git a/lib/lepton/src/Operation.cpp b/lib/lepton/src/Operation.cpp index bec5686a74..08deff8584 100644 --- a/lib/lepton/src/Operation.cpp +++ b/lib/lepton/src/Operation.cpp @@ -7,7 +7,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009-2019 Stanford University and the Authors. * + * Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -37,7 +37,13 @@ using namespace LMP_Lepton; using namespace std; -double Operation::Erf::evaluate(double* args, const map& ) const { +static bool isZero(const ExpressionTreeNode& node) { + if (node.getOperation().getId() != Operation::CONSTANT) + return false; + return dynamic_cast(node.getOperation()).getValue() == 0.0; +} + +double Operation::Erf::evaluate(double* args, const map&) const { return erf(args[0]); } @@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { if (function->getNumArguments() == 0) return ExpressionTreeNode(new Operation::Constant(0.0)); - ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]); - for (int i = 1; i < getNumArguments(); i++) { - result = ExpressionTreeNode(new Operation::Add(), - result, - ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); + ExpressionTreeNode result; + bool foundTerm = false; + for (int i = 0; i < getNumArguments(); i++) { + if (!isZero(childDerivs[i])) { + if (foundTerm) + result = ExpressionTreeNode(new Operation::Add(), + result, + ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), 
childDerivs[i])); + else { + result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]); + foundTerm = true; + } + } } - return result; + if (foundTerm) + return result; + return ExpressionTreeNode(new Operation::Constant(0.0)); } ExpressionTreeNode Operation::Add::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return childDerivs[1]; + if (isZero(childDerivs[1])) + return childDerivs[0]; return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]); } ExpressionTreeNode Operation::Subtract::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + return ExpressionTreeNode(new Operation::Negate(), childDerivs[1]); + } + if (isZero(childDerivs[1])) + return childDerivs[0]; return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]); } ExpressionTreeNode Operation::Multiply::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + return ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]); + } + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]); return ExpressionTreeNode(new Operation::Add(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0])); } ExpressionTreeNode Operation::Divide::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { - return ExpressionTreeNode(new Operation::Divide(), - ExpressionTreeNode(new 
Operation::Subtract(), + ExpressionTreeNode subexp; + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + subexp = ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])); + } + else if (isZero(childDerivs[1])) + subexp = ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]); + else + subexp = ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), - ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])), - ExpressionTreeNode(new Operation::Square(), children[1])); + ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])); + return ExpressionTreeNode(new Operation::Divide(), subexp, ExpressionTreeNode(new Operation::Square(), children[1])); } ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { @@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]); } ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(0.5), ExpressionTreeNode(new Operation::Reciprocal(), @@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new 
Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Exp(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Log::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Sin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cos(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Sin(), children[0])), @@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sec(), children[0]), @@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), 
ExpressionTreeNode(new Operation::Multiply(), @@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Square(), ExpressionTreeNode(new Operation::Sec(), children[0])), @@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Square(), @@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Sqrt(), @@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new 
Operation::AddConstant(1.0), @@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cosh(), children[0]), @@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sinh(), children[0]), @@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Constant(1.0)), @@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))), @@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))), @@ -264,6 +342,8 @@ 
ExpressionTreeNode Operation::Delta::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(2.0), children[0]), @@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(3.0), ExpressionTreeNode(new Operation::Square(), children[0])), @@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::MultiplyConstant(value), childDerivs[0]); } ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(value), ExpressionTreeNode(new Operation::PowerConstant(value-1), @@ -305,22 +393,18 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const 
std::vector& children, const std::vector& childDerivs, const std::string& ) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); - return ExpressionTreeNode(new Operation::Subtract(), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], - ExpressionTreeNode(new Operation::AddConstant(-1), step))); + return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[1], childDerivs[0]}); } ExpressionTreeNode Operation::Max::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); - return ExpressionTreeNode(new Operation::Subtract(), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], - ExpressionTreeNode(new Operation::AddConstant(-1), step))); + return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[0], childDerivs[1]}); } ExpressionTreeNode Operation::Abs::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); ExpressionTreeNode step(new Operation::Step(), children[0]); return ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], @@ -337,9 +421,5 @@ ExpressionTreeNode Operation::Ceil::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { - vector derivChildren; - derivChildren.push_back(children[0]); - derivChildren.push_back(childDerivs[1]); - derivChildren.push_back(childDerivs[2]); - return ExpressionTreeNode(new Operation::Select(), derivChildren); + return ExpressionTreeNode(new Operation::Select(), {children[0], childDerivs[1], childDerivs[2]}); 
} diff --git a/lib/lepton/src/ParsedExpression.cpp b/lib/lepton/src/ParsedExpression.cpp index 1417551011..a6f41ae354 100644 --- a/lib/lepton/src/ParsedExpression.cpp +++ b/lib/lepton/src/ParsedExpression.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009 Stanford University and the Authors. * + * Portions copyright (c) 2009-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -31,6 +31,7 @@ #include "lepton/ParsedExpression.h" #include "lepton/CompiledExpression.h" +#include "lepton/CompiledVectorExpression.h" #include "lepton/ExpressionProgram.h" #include "lepton/Operation.h" #include @@ -68,9 +69,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map examples; + result.assignTags(examples); + map nodeCache; + result = precalculateConstantSubexpressions(result, nodeCache); while (true) { - ExpressionTreeNode simplified = substituteSimplerExpression(result); + examples.clear(); + result.assignTags(examples); + nodeCache.clear(); + ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache); if (simplified == result) break; result = simplified; @@ -80,9 +88,15 @@ ParsedExpression ParsedExpression::optimize() const { ParsedExpression ParsedExpression::optimize(const map& variables) const { ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables); - result = precalculateConstantSubexpressions(result); + vector examples; + result.assignTags(examples); + map nodeCache; + result = precalculateConstantSubexpressions(result, nodeCache); while (true) { - ExpressionTreeNode simplified = substituteSimplerExpression(result); + examples.clear(); + result.assignTags(examples); + nodeCache.clear(); + ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache); if (simplified == result) break; result = simplified; @@ 
-104,27 +118,44 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo return ExpressionTreeNode(node.getOperation().clone(), children); } -ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) { +ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node, map& nodeCache) { + auto cached = nodeCache.find(node.tag); + if (cached != nodeCache.end()) + return cached->second; vector children(node.getChildren().size()); for (int i = 0; i < (int) children.size(); i++) - children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); + children[i] = precalculateConstantSubexpressions(node.getChildren()[i], nodeCache); ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children); - if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) + if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) { + nodeCache[node.tag] = result; return result; + } for (int i = 0; i < (int) children.size(); i++) - if (children[i].getOperation().getId() != Operation::CONSTANT) + if (children[i].getOperation().getId() != Operation::CONSTANT) { + nodeCache[node.tag] = result; return result; - return ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); + } + result = ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); + nodeCache[node.tag] = result; + return result; } -ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) { +ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node, map& nodeCache) { vector children(node.getChildren().size()); - for (int i = 0; i < (int) children.size(); i++) - children[i] = substituteSimplerExpression(node.getChildren()[i]); + for (int i = 0; i < (int) children.size(); i++) { + const 
ExpressionTreeNode& child = node.getChildren()[i]; + auto cached = nodeCache.find(child.tag); + if (cached == nodeCache.end()) { + children[i] = substituteSimplerExpression(child, nodeCache); + nodeCache[child.tag] = children[i]; + } + else + children[i] = cached->second; + } // Collect some info on constant expressions in children bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant? - bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant? + bool second_const = children.size() > 1 && isConstant(children[1]); // is second child constant? double first, second; // if yes, value of first and second child if (first_const) first = getConstantValue(children[0]); @@ -296,6 +327,12 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio return children[0].getChildren()[0]; break; } + case Operation::SELECT: + { + if (children[1] == children[2]) // Select between two identical values + return children[1]; + break; + } default: { // If operation ID is not one of the above, @@ -308,14 +345,22 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio } ParsedExpression ParsedExpression::differentiate(const string& variable) const { - return differentiate(getRootNode(), variable); + vector examples; + getRootNode().assignTags(examples); + map nodeCache; + return differentiate(getRootNode(), variable, nodeCache); } -ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) { +ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable, map& nodeCache) { + auto cached = nodeCache.find(node.tag); + if (cached != nodeCache.end()) + return cached->second; vector childDerivs(node.getChildren().size()); for (int i = 0; i < (int) childDerivs.size(); i++) - childDerivs[i] = differentiate(node.getChildren()[i], variable); - return 
node.getOperation().differentiate(node.getChildren(),childDerivs, variable); + childDerivs[i] = differentiate(node.getChildren()[i], variable, nodeCache); + ExpressionTreeNode result = node.getOperation().differentiate(node.getChildren(), childDerivs, variable); + nodeCache[node.tag] = result; + return result; } bool ParsedExpression::isConstant(const ExpressionTreeNode& node) { @@ -337,6 +382,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const { return CompiledExpression(*this); } +CompiledVectorExpression ParsedExpression::createCompiledVectorExpression(int width) const { + return CompiledVectorExpression(*this, width); +} + ParsedExpression ParsedExpression::renameVariables(const map& replacements) const { return ParsedExpression(renameNodeVariables(getRootNode(), replacements)); }