Update Colvars to 2022-05-24 and copy of Lepton library

One bugfix for the ABF method in the Colvars library, and an update of the
bundled copy of the Lepton library to match the OpenMM repository.

List of relevant PRs:

- 483 Update Lepton via patching procedure
  https://github.com/Colvars/colvars/pull/483 (@giacomofiorin)

- 482 Fix integer overflow in log_gradient_finite_diff and gradient_finite_diff
  https://github.com/Colvars/colvars/pull/482 (@HanatoK)
This commit is contained in:
Giacomo Fiorin
2022-06-02 11:24:04 -04:00
parent 30ae7fe66b
commit 3a1423dc48
20 changed files with 1942 additions and 235 deletions

View File

@ -62,10 +62,13 @@ COLVARS_SRCS = \
colvar_neuralnetworkcompute.cpp
LEPTON_SRCS = \
lepton/src/CompiledExpression.cpp lepton/src/ExpressionTreeNode.cpp \
lepton/src/ParsedExpression.cpp lepton/src/ExpressionProgram.cpp \
lepton/src/Operation.cpp lepton/src/Parser.cpp
lepton/src/CompiledExpression.cpp \
lepton/src/CompiledVectorExpression.cpp \
lepton/src/ExpressionProgram.cpp \
lepton/src/ExpressionTreeNode.cpp \
lepton/src/Operation.cpp \
lepton/src/ParsedExpression.cpp \
lepton/src/Parser.cpp
# Allow to selectively turn off Lepton
ifeq ($(COLVARS_LEPTON),no)
@ -93,4 +96,13 @@ Makefile.deps: $(COLVARS_SRCS)
done
include Makefile.deps
include Makefile.lepton.deps # Hand-generated
Makefile.lepton.deps: $(LEPTON_SRCS)
@echo > $@
@for src in $^ ; do \
obj=`basename $$src .cpp`.o ; \
$(CXX) $(CXXFLAGS) -MM $(LEPTON_INCFLAGS) \
-MT '$$(COLVARS_OBJ_DIR)'$$obj $$src >> $@ ; \
done
include Makefile.lepton.deps

View File

@ -1,36 +1,46 @@
lepton/src/CompiledExpression.o: lepton/src/CompiledExpression.cpp \
$(COLVARS_OBJ_DIR)CompiledExpression.o: lepton/src/CompiledExpression.cpp \
lepton/include/lepton/CompiledExpression.h \
lepton/include/lepton/ExpressionTreeNode.h \
lepton/include/lepton/windowsIncludes.h \
lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \
lepton/include/lepton/Exception.h \
lepton/include/lepton/ParsedExpression.h
lepton/src/ExpressionProgram.o: lepton/src/ExpressionProgram.cpp \
$(COLVARS_OBJ_DIR)CompiledVectorExpression.o: \
lepton/src/CompiledVectorExpression.cpp \
lepton/include/lepton/CompiledVectorExpression.h \
lepton/include/lepton/ExpressionTreeNode.h \
lepton/include/lepton/windowsIncludes.h \
lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \
lepton/include/lepton/Exception.h \
lepton/include/lepton/ParsedExpression.h
$(COLVARS_OBJ_DIR)ExpressionProgram.o: lepton/src/ExpressionProgram.cpp \
lepton/include/lepton/ExpressionProgram.h \
lepton/include/lepton/ExpressionTreeNode.h \
lepton/include/lepton/windowsIncludes.h \
lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \
lepton/include/lepton/Exception.h \
lepton/include/lepton/ParsedExpression.h
lepton/src/ExpressionTreeNode.o: lepton/src/ExpressionTreeNode.cpp \
$(COLVARS_OBJ_DIR)ExpressionTreeNode.o: lepton/src/ExpressionTreeNode.cpp \
lepton/include/lepton/ExpressionTreeNode.h \
lepton/include/lepton/windowsIncludes.h \
lepton/include/lepton/Exception.h lepton/include/lepton/Operation.h \
lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h
lepton/src/Operation.o: lepton/src/Operation.cpp \
$(COLVARS_OBJ_DIR)Operation.o: lepton/src/Operation.cpp \
lepton/include/lepton/Operation.h \
lepton/include/lepton/windowsIncludes.h \
lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h \
lepton/include/lepton/ExpressionTreeNode.h lepton/src/MSVC_erfc.h
lepton/src/ParsedExpression.o: lepton/src/ParsedExpression.cpp \
$(COLVARS_OBJ_DIR)ParsedExpression.o: lepton/src/ParsedExpression.cpp \
lepton/include/lepton/ParsedExpression.h \
lepton/include/lepton/ExpressionTreeNode.h \
lepton/include/lepton/windowsIncludes.h \
lepton/include/lepton/CompiledExpression.h \
lepton/include/lepton/CompiledVectorExpression.h \
lepton/include/lepton/ExpressionProgram.h \
lepton/include/lepton/Operation.h lepton/include/lepton/CustomFunction.h \
lepton/include/lepton/Exception.h
lepton/src/Parser.o: lepton/src/Parser.cpp \
$(COLVARS_OBJ_DIR)Parser.o: lepton/src/Parser.cpp \
lepton/include/lepton/Parser.h lepton/include/lepton/windowsIncludes.h \
lepton/include/lepton/CustomFunction.h lepton/include/lepton/Exception.h \
lepton/include/lepton/ExpressionTreeNode.h \

View File

@ -1275,7 +1275,7 @@ public:
inline cvm::real log_gradient_finite_diff(const std::vector<int> &ix0,
int n = 0)
{
int A0, A1, A2;
cvm::real A0, A1, A2;
std::vector<int> ix = ix0;
// TODO this can be rewritten more concisely with wrap_edge()
@ -1288,7 +1288,7 @@ public:
if (A0 * A1 == 0) {
return 0.; // can't handle empty bins
} else {
return (cvm::logn((cvm::real)A1) - cvm::logn((cvm::real)A0))
return (cvm::logn(A1) - cvm::logn(A0))
/ (widths[n] * 2.);
}
} else if (ix[n] > 0 && ix[n] < nx[n]-1) { // not an edge
@ -1300,7 +1300,7 @@ public:
if (A0 * A1 == 0) {
return 0.; // can't handle empty bins
} else {
return (cvm::logn((cvm::real)A1) - cvm::logn((cvm::real)A0))
return (cvm::logn(A1) - cvm::logn(A0))
/ (widths[n] * 2.);
}
} else {
@ -1313,8 +1313,8 @@ public:
if (A0 * A1 * A2 == 0) {
return 0.; // can't handle empty bins
} else {
return (-1.5 * cvm::logn((cvm::real)A0) + 2. * cvm::logn((cvm::real)A1)
- 0.5 * cvm::logn((cvm::real)A2)) * increment / widths[n];
return (-1.5 * cvm::logn(A0) + 2. * cvm::logn(A1)
- 0.5 * cvm::logn(A2)) * increment / widths[n];
}
}
}
@ -1324,7 +1324,7 @@ public:
inline cvm::real gradient_finite_diff(const std::vector<int> &ix0,
int n = 0)
{
int A0, A1, A2;
cvm::real A0, A1, A2;
std::vector<int> ix = ix0;
// FIXME this can be rewritten more concisely with wrap_edge()
@ -1337,7 +1337,7 @@ public:
if (A0 * A1 == 0) {
return 0.; // can't handle empty bins
} else {
return cvm::real(A1 - A0) / (widths[n] * 2.);
return (A1 - A0) / (widths[n] * 2.);
}
} else if (ix[n] > 0 && ix[n] < nx[n]-1) { // not an edge
ix[n]--;
@ -1348,7 +1348,7 @@ public:
if (A0 * A1 == 0) {
return 0.; // can't handle empty bins
} else {
return cvm::real(A1 - A0) / (widths[n] * 2.);
return (A1 - A0) / (widths[n] * 2.);
}
} else {
// edge: use 2nd order derivative
@ -1357,8 +1357,8 @@ public:
A0 = value(ix);
ix[n] += increment; A1 = value(ix);
ix[n] += increment; A2 = value(ix);
return (-1.5 * cvm::real(A0) + 2. * cvm::real(A1)
- 0.5 * cvm::real(A2)) * increment / widths[n];
return (-1.5 * A0 + 2. * A1
- 0.5 * A2) * increment / widths[n];
}
}
};

View File

@ -22,9 +22,7 @@
colvarproxy_tcl::colvarproxy_tcl()
{
#ifdef COLVARS_TCL
tcl_interp_ = NULL;
#endif
}

View File

@ -1,3 +1,3 @@
#ifndef COLVARS_VERSION
#define COLVARS_VERSION "2022-05-09"
#define COLVARS_VERSION "2022-05-24"
#endif

View File

@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2019 Stanford University and the Authors. *
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
@ -40,7 +40,11 @@
#include <utility>
#include <vector>
#ifdef LEPTON_USE_JIT
#include "asmjit.h"
#if defined(__ARM__) || defined(__ARM64__)
#include "asmjit/a64.h"
#else
#include "asmjit/x86.h"
#endif
#endif
namespace Lepton {
@ -52,9 +56,9 @@ class ParsedExpression;
* A CompiledExpression is a highly optimized representation of an expression for cases when you want to evaluate
* it many times as quickly as possible. You should treat it as an opaque object; none of the internal representation
* is visible.
*
*
* A CompiledExpression is created by calling createCompiledExpression() on a ParsedExpression.
*
*
* WARNING: CompiledExpression is NOT thread safe. You should never access a CompiledExpression from two threads at
* the same time.
*/
@ -101,9 +105,15 @@ private:
std::map<std::string, double> dummyVariables;
double (*jitCode)();
#ifdef LEPTON_USE_JIT
void findPowerGroups(std::vector<std::vector<int> >& groups, std::vector<std::vector<int> >& groupPowers, std::vector<int>& stepGroup);
void generateJitCode();
void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg, double (*function)(double));
void generateTwoArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg1, asmjit::X86Xmm& arg2, double (*function)(double, double));
#if defined(__ARM__) || defined(__ARM64__)
void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double));
void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, double (*function)(double, double));
#else
void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg, double (*function)(double));
void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg1, asmjit::x86::Xmm& arg2, double (*function)(double, double));
#endif
std::vector<double> constants;
asmjit::JitRuntime runtime;
#endif

View File

@ -0,0 +1,145 @@
#ifndef LEPTON_VECTOR_EXPRESSION_H_
#define LEPTON_VECTOR_EXPRESSION_H_
/* -------------------------------------------------------------------------- *
* Lepton *
* -------------------------------------------------------------------------- *
* This is part of the Lepton expression parser originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "ExpressionTreeNode.h"
#include "windowsIncludes.h"
#include <array>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#ifdef LEPTON_USE_JIT
#if defined(__ARM__) || defined(__ARM64__)
#include "asmjit/a64.h"
#else
#include "asmjit/x86.h"
#endif
#endif
namespace Lepton {
class Operation;
class ParsedExpression;
/**
* A CompiledVectorExpression is a highly optimized representation of an expression for cases when you want to evaluate
* it many times as quickly as possible. It is similar to CompiledExpression, with the extra feature that it uses the CPU's
* vector unit (AVX on x86, NEON on ARM) to evaluate the expression for multiple sets of arguments at once. It also differs
* from CompiledExpression and ParsedExpression in using single precision rather than double precision to evaluate the expression.
* You should treat it as an opaque object; none of the internal representation is visible.
*
* A CompiledVectorExpression is created by calling createCompiledVectorExpression() on a ParsedExpression. When you create
* it, you must specify the width of the vectors on which to compute the expression. The allowed widths depend on the type of
* CPU it is running on. 4 is always allowed, and 8 is allowed on x86 processors with AVX. Call getAllowedWidths() to query
* the allowed values.
*
* WARNING: CompiledVectorExpression is NOT thread safe. You should never access a CompiledVectorExpression from two threads at
* the same time.
*/
class LEPTON_EXPORT CompiledVectorExpression {
public:
// The default constructor, copy constructor, destructor and copy assignment operator are all
// user-declared here; their definitions are not part of this header.
CompiledVectorExpression();
CompiledVectorExpression(const CompiledVectorExpression& expression);
~CompiledVectorExpression();
CompiledVectorExpression& operator=(const CompiledVectorExpression& expression);
/**
* Get the width of the vectors on which the expression is computed.
*/
int getWidth() const;
/**
* Get the names of all variables used by this expression.
*/
const std::set<std::string>& getVariables() const;
/**
* Get a pointer to the memory location where the value of a particular variable is stored. This can be used
* to set the value of the variable before calling evaluate().
*
* @param name the name of the variable to query
* @return a pointer to N floating point values, where N is the vector width
*/
float* getVariablePointer(const std::string& name);
/**
* You can optionally specify the memory locations from which the values of variables should be read.
* This is useful, for example, when several expressions all use the same variable. You can then set
* the value of that variable in one place, and it will be seen by all of them. The location should
* be a pointer to N floating point values, where N is the vector width.
*/
void setVariableLocations(std::map<std::string, float*>& variableLocations);
/**
* Evaluate the expression. The values of all variables should have been set before calling this.
*
* @return a pointer to N floating point values, where N is the vector width
*/
const float* evaluate() const;
/**
* Get the list of vector widths that are supported on the current processor.
*/
static const std::vector<int>& getAllowedWidths();
private:
// Gives ParsedExpression access to the private constructor below; instances are obtained by
// calling createCompiledVectorExpression() on a ParsedExpression.
friend class ParsedExpression;
// Private constructor: builds the compiled form of `expression` for vectors of `width` elements.
CompiledVectorExpression(const ParsedExpression& expression, int width);
// Helpers used while compiling the expression tree. NOTE(review): `temps` appears to pair each
// already-processed node with an integer index and `workspaceSize` to be a running size that is
// updated through the reference -- confirm against the implementation file.
void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps, int& workspaceSize);
int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
// Vector width; reported by getWidth().
int width;
// Per-variable pointers keyed by variable name. NOTE(review): presumably the locations passed
// to setVariableLocations() -- confirm.
std::map<std::string, float*> variablePointers;
// Pairs of float pointers. NOTE(review): presumably (destination, source) pairs copied before
// evaluation, as in CompiledExpression -- confirm.
std::vector<std::pair<float*, float*> > variablesToCopy;
// Per-step data. NOTE(review): arguments/target/operation look like parallel arrays indexed by
// evaluation step -- confirm.
std::vector<std::vector<int> > arguments;
std::vector<int> target;
std::vector<Operation*> operation;
// Integer index associated with each variable name.
std::map<std::string, int> variableIndices;
// Set of variable names; has the type returned by getVariables().
std::set<std::string> variableNames;
// Declared mutable so that the const evaluate() method can modify them.
mutable std::vector<float> workspace;
mutable std::vector<double> argValues;
std::map<std::string, double> dummyVariables;
// Pointer to a function taking no arguments and returning nothing. NOTE(review): presumably the
// generated JIT entry point, unset when JIT is unavailable -- confirm.
void (*jitCode)();
// Everything below exists only when Lepton is built with JIT support (LEPTON_USE_JIT).
#ifdef LEPTON_USE_JIT
void findPowerGroups(std::vector<std::vector<int> >& groups, std::vector<std::vector<int> >& groupPowers, std::vector<int>& stepGroup);
void generateJitCode();
// The call-generation helpers are declared against the asmjit back end for the target
// architecture: a64 compiler / arm::Vec registers on ARM, x86 compiler / Ymm registers otherwise.
#if defined(__ARM__) || defined(__ARM64__)
void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, float (*function)(float));
void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, float (*function)(float, float));
#else
void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg, float (*function)(float));
void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg1, asmjit::x86::Ymm& arg2, float (*function)(float, float));
#endif
std::vector<float> constants;
asmjit::JitRuntime runtime;
#endif
};
} // namespace Lepton
#endif /*LEPTON_VECTOR_EXPRESSION_H_*/

View File

@ -83,7 +83,7 @@ class LEPTON_EXPORT PlaceholderFunction : public CustomFunction {
public:
/**
* Create a Placeholder function.
*
*
* @param numArgs the number of arguments the function expects
*/
PlaceholderFunction(int numArgs) : numArgs(numArgs) {

View File

@ -67,7 +67,7 @@ public:
const Operation& getOperation(int index) const;
/**
* Change an Operation in this program.
*
*
* The Operation must have been allocated on the heap with the "new" operator.
* The ExpressionProgram assumes ownership of it and will delete it when it
* is no longer needed.

View File

@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
@ -39,6 +39,7 @@
namespace Lepton {
class Operation;
class ParsedExpression;
/**
* This class represents a node in the abstract syntax tree representation of an expression.
@ -82,11 +83,13 @@ public:
*/
ExpressionTreeNode(Operation* operation);
ExpressionTreeNode(const ExpressionTreeNode& node);
ExpressionTreeNode(ExpressionTreeNode&& node);
ExpressionTreeNode();
~ExpressionTreeNode();
bool operator==(const ExpressionTreeNode& node) const;
bool operator!=(const ExpressionTreeNode& node) const;
ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
ExpressionTreeNode& operator=(ExpressionTreeNode&& node);
/**
* Get the Operation performed by this node.
*/
@ -96,8 +99,11 @@ public:
*/
const std::vector<ExpressionTreeNode>& getChildren() const;
private:
friend class ParsedExpression;
void assignTags(std::vector<const ExpressionTreeNode*>& examples) const;
Operation* operation;
std::vector<ExpressionTreeNode> children;
mutable int tag;
};
} // namespace Lepton

View File

@ -1017,7 +1017,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const {
if (isIntPower) {
// Integer powers can be computed much more quickly by repeated multiplication.
int exponent = intValue;
double base = args[0];
if (exponent < 0) {

View File

@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009=2013 Stanford University and the Authors. *
* Portions copyright (c) 2009-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
@ -41,6 +41,7 @@ namespace Lepton {
class CompiledExpression;
class ExpressionProgram;
class CompiledVectorExpression;
/**
* This class represents the result of parsing an expression. It provides methods for working with the
@ -102,6 +103,16 @@ public:
* Create a CompiledExpression that represents the same calculation as this expression.
*/
CompiledExpression createCompiledExpression() const;
/**
* Create a CompiledVectorExpression that allows the expression to be evaluated efficiently
* using the CPU's vector unit.
*
* @param width the width of the vectors to evaluate it on. The allowed values
* depend on the CPU. 4 is always allowed, and 8 is allowed on
* x86 processors with AVX. Call CompiledVectorExpression::getAllowedWidths()
* to query the allowed widths on the current processor.
*/
CompiledVectorExpression createCompiledVectorExpression(int width) const;
/**
* Create a new ParsedExpression which is identical to this one, except that the names of some
* variables have been changed.
@ -113,9 +124,10 @@ public:
private:
static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node, std::map<int, ExpressionTreeNode>& nodeCache);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node, std::map<int, ExpressionTreeNode>& nodeCache);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable, std::map<int, ExpressionTreeNode>& nodeCache);
static bool isConstant(const ExpressionTreeNode& node);
static double getConstantValue(const ExpressionTreeNode& node);
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
ExpressionTreeNode rootNode;

View File

@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2019 Stanford University and the Authors. *
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
@ -84,17 +84,17 @@ CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expr
void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps) {
if (findTempIndex(node, temps) != -1)
return; // We have already processed a node identical to this one.
// Process the child nodes.
vector<int> args;
for (int i = 0; i < node.getChildren().size(); i++) {
compileExpression(node.getChildren()[i], temps);
args.push_back(findTempIndex(node.getChildren()[i], temps));
}
// Process this node.
if (node.getOperation().getId() == Operation::VARIABLE) {
variableIndices[node.getOperation().getName()] = (int) workspace.size();
variableNames.insert(node.getOperation().getName());
@ -108,7 +108,7 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto
arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there.
else {
// If the arguments are sequential, we can just pass a pointer to the first one.
bool sequential = true;
for (int i = 1; i < args.size(); i++)
if (args[i] != args[i-1]+1)
@ -148,30 +148,28 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca
variablePointers = variableLocations;
#ifdef LEPTON_USE_JIT
// Rebuild the JIT code.
if (workspace.size() > 0)
generateJitCode();
#else
#endif
// Make a list of all variables we will need to copy before evaluating the expression.
variablesToCopy.clear();
for (map<string, int>::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) {
map<string, double*>::iterator pointer = variablePointers.find(iter->first);
if (pointer != variablePointers.end())
variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second));
}
#endif
}
double CompiledExpression::evaluate() const {
#ifdef LEPTON_USE_JIT
return jitCode();
#else
if (jitCode)
return jitCode();
for (int i = 0; i < variablesToCopy.size(); i++)
*variablesToCopy[i].first = *variablesToCopy[i].second;
// Loop over the operations and evaluate each one.
for (int step = 0; step < operation.size(); step++) {
const vector<int>& args = arguments[step];
if (args.size() == 1)
@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const {
}
}
return workspace[workspace.size()-1];
#endif
}
#ifdef LEPTON_USE_JIT
@ -192,32 +189,78 @@ static double evaluateOperation(Operation* op, double* args) {
return op->evaluate(args, dummyVariables);
}
/**
 * Find steps that raise the same argument to integer powers, so that they can be computed together.
 *
 * A step counts as an integer power when its operation is SQUARE, CUBE, or a POWER_CONSTANT whose
 * exponent is a whole number representable as an int. Steps that share their first argument and
 * whose powers have the same sign are placed in the same group.
 *
 * @param groups       appended to: one entry per group, listing the indices of the steps in it.
 *                     A step that is not an integer power contributes an empty entry.
 * @param groupPowers  appended to, parallel to groups: the power applied by each step of the group
 * @param stepGroup    resized to one entry per step (new entries start at -1); for integer-power
 *                     steps it is set to the index into groups of the group containing the step
 */
void CompiledExpression::findPowerGroups(vector<vector<int> >& groups, vector<vector<int> >& groupPowers, vector<int>& stepGroup) {
    // Largest exponent magnitude treated as an integer power. Converting a double outside the
    // range of int (or a NaN) to int is undefined behavior, so the range is checked before the
    // cast. The bound is symmetric so that the power can later be negated safely.
    const double maxIntPower = 2147483647.0;
    const int numSteps = (int) operation.size();

    // Identify every step that raises an argument to an integer power.

    vector<int> stepPower(operation.size(), 0);
    vector<int> stepArg(operation.size(), -1);
    for (int step = 0; step < numSteps; step++) {
        Operation& op = *operation[step];
        int power = 0;
        if (op.getId() == Operation::SQUARE)
            power = 2;
        else if (op.getId() == Operation::CUBE)
            power = 3;
        else if (op.getId() == Operation::POWER_CONSTANT) {
            double realPower = dynamic_cast<const Operation::PowerConstant*>(&op)->getValue();
            if (realPower >= -maxIntPower && realPower <= maxIntPower && realPower == (int) realPower)
                power = (int) realPower;
        }
        if (power != 0) {
            stepPower[step] = power;
            stepArg[step] = arguments[step][0];
        }
    }

    // Find groups that operate on the same argument and whose powers have the same sign.

    stepGroup.resize(operation.size(), -1);
    for (int i = 0; i < numSteps; i++) {
        if (stepGroup[i] != -1)
            continue;
        vector<int> group, power;
        // A step that is not an integer power can never match anything (including itself),
        // so it only contributes an empty group.
        if (stepPower[i] != 0) {
            const bool positive = (stepPower[i] > 0);
            for (int j = i; j < numSteps; j++) {
                // Compare signs directly instead of testing stepPower[i]*stepPower[j] > 0:
                // the product of two large exponents can overflow int.
                if (stepPower[j] != 0 && stepArg[i] == stepArg[j] && (stepPower[j] > 0) == positive) {
                    stepGroup[j] = (int) groups.size();
                    group.push_back(j);
                    power.push_back(stepPower[j]);
                }
            }
        }
        groups.push_back(group);
        groupPowers.push_back(power);
    }
}
#if defined(__ARM__) || defined(__ARM64__)
void CompiledExpression::generateJitCode() {
CodeHolder code;
code.init(runtime.getCodeInfo());
X86Compiler c(&code);
c.addFunc(FuncSignature0<double>());
vector<X86Xmm> workspaceVar(workspace.size());
code.init(runtime.environment());
a64::Compiler c(&code);
c.addFunc(FuncSignatureT<double>());
vector<arm::Vec> workspaceVar(workspace.size());
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newXmmSd();
X86Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm_ptr(&argValues[0]));
workspaceVar[i] = c.newVecD();
arm::Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm(&argValues[0]));
vector<vector<int> > groups, groupPowers;
vector<int> stepGroup;
findPowerGroups(groups, groupPowers, stepGroup);
// Load the arguments into variables.
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
X86Gp variablePointer = c.newIntPtr();
c.mov(variablePointer, imm_ptr(&getVariableReference(index->first)));
c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
arm::Gp variablePointer = c.newIntPtr();
c.mov(variablePointer, imm(&getVariableReference(index->first)));
c.ldr(workspaceVar[index->second], arm::ptr(variablePointer, 0));
}
// Make a list of all constants that will be needed for evaluation.
vector<int> operationConstantIndex(operation.size(), -1);
for (int step = 0; step < (int) operation.size(); step++) {
// Find the constant value (if any) used by this operation.
Operation& op = *operation[step];
double value;
if (op.getId() == Operation::CONSTANT)
@ -232,11 +275,17 @@ void CompiledExpression::generateJitCode() {
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else if (op.getId() == Operation::POWER_CONSTANT) {
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
else
value = 1.0;
}
else
continue;
// See if we already have a variable for this constant.
for (int i = 0; i < (int) constants.size(); i++)
if (value == constants[i]) {
operationConstantIndex[step] = i;
@ -247,62 +296,101 @@ void CompiledExpression::generateJitCode() {
constants.push_back(value);
}
}
// Load constants into variables.
vector<X86Xmm> constantVar(constants.size());
vector<arm::Vec> constantVar(constants.size());
if (constants.size() > 0) {
X86Gp constantsPointer = c.newIntPtr();
c.mov(constantsPointer, imm_ptr(&constants[0]));
arm::Gp constantsPointer = c.newIntPtr();
c.mov(constantsPointer, imm(&constants[0]));
for (int i = 0; i < (int) constants.size(); i++) {
constantVar[i] = c.newXmmSd();
c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
constantVar[i] = c.newVecD();
c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i));
}
}
// Evaluate the operations.
vector<bool> hasComputedPower(operation.size(), false);
for (int step = 0; step < (int) operation.size(); step++) {
if (hasComputedPower[step])
continue;
// When one or more steps involve raising the same argument to multiple integer
// powers, we can compute them all together for efficiency.
if (stepGroup[step] != -1) {
vector<int>& group = groups[stepGroup[step]];
vector<int>& powers = groupPowers[stepGroup[step]];
arm::Vec multiplier = c.newVecD();
if (powers[0] > 0)
c.fmov(multiplier, workspaceVar[arguments[step][0]]);
else {
c.fdiv(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
vector<bool> hasAssigned(group.size(), false);
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < group.size(); i++) {
if (powers[i]%2 == 1) {
if (!hasAssigned[i])
c.fmov(workspaceVar[target[group[i]]], multiplier);
else
c.fmul(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
hasAssigned[i] = true;
}
powers[i] >>= 1;
if (powers[i] != 0)
done = false;
}
if (!done)
c.fmul(multiplier, multiplier, multiplier);
}
for (int step : group)
hasComputedPower[step] = true;
continue;
}
// Evaluate the step.
Operation& op = *operation[step];
vector<int> args = arguments[step];
if (args.size() == 1) {
// One or more sequential arguments. Fill out the list.
for (int i = 1; i < op.getNumArguments(); i++)
args.push_back(args[0]+i);
}
// Generate instructions to execute this operation.
switch (op.getId()) {
case Operation::CONSTANT:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.fmov(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::ADD:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]);
c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::SUBTRACT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]);
c.fsub(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MULTIPLY:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]);
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::DIVIDE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]);
c.fdiv(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow);
break;
case Operation::NEGATE:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.fneg(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::SQRT:
c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.fsqrt(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
@ -341,56 +429,63 @@ void CompiledExpression::generateJitCode() {
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
break;
case Operation::STEP:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.cmge(workspaceVar[target[step]], workspaceVar[args[0]], imm(0));
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.cmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0));
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::SQUARE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
break;
case Operation::CUBE:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
c.fmul(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::RECIPROCAL:
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.fdiv(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
break;
case Operation::ADD_CONSTANT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::MULTIPLY_CONSTANT:
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
break;
case Operation::MIN:
c.fmin(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MAX:
c.fmax(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::ABS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::FLOOR:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor);
c.frintm(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::CEIL:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil);
c.frintp(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::SELECT:
c.fcmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0));
c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]);
break;
default:
// Just invoke evaluateOperation().
for (int i = 0; i < (int) args.size(); i++)
c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
X86Gp fn = c.newIntPtr();
c.mov(fn, imm_ptr((void*) evaluateOperation));
CCFuncCall* call = c.call(fn, FuncSignature2<double, Operation*, double*>());
call->setArg(0, imm_ptr(&op));
call->setArg(1, imm_ptr(&argValues[0]));
call->setRet(0, workspaceVar[target[step]]);
c.str(workspaceVar[args[i]], arm::ptr(argsPointer, 8*i));
arm::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) evaluateOperation));
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
invoke->setArg(0, imm(&op));
invoke->setArg(1, imm(&argValues[0]));
invoke->setRet(0, workspaceVar[target[step]]);
}
}
c.ret(workspaceVar[workspace.size()-1]);
@ -399,20 +494,319 @@ void CompiledExpression::generateJitCode() {
runtime.add(&jitCode, &code);
}
void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) {
X86Gp fn = c.newIntPtr();
c.mov(fn, imm_ptr((void*) function));
CCFuncCall* call = c.call(fn, FuncSignature1<double, double>());
call->setArg(0, arg);
call->setRet(0, dest);
void CompiledExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, double (*function)(double)) {
    // Put the callee's address in a pointer-sized virtual register.
    arm::Gp calleeAddress = c.newIntPtr();
    c.mov(calleeAddress, imm((void*) function));

    // Emit the call node: one double argument in, one double result out.
    InvokeNode* callNode;
    c.invoke(&callNode, calleeAddress, FuncSignatureT<double, double>());
    callNode->setArg(0, arg);
    callNode->setRet(0, dest);
}
void CompiledExpression::generateTwoArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg1, X86Xmm& arg2, double (*function)(double, double)) {
X86Gp fn = c.newIntPtr();
c.mov(fn, imm_ptr((void*) function));
CCFuncCall* call = c.call(fn, FuncSignature2<double, double, double>());
call->setArg(0, arg1);
call->setArg(1, arg2);
call->setRet(0, dest);
void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, double (*function)(double, double)) {
    // Put the callee's address in a pointer-sized virtual register.
    arm::Gp calleeAddress = c.newIntPtr();
    c.mov(calleeAddress, imm((void*) function));

    // Emit the call node: two double arguments in, one double result out.
    InvokeNode* callNode;
    c.invoke(&callNode, calleeAddress, FuncSignatureT<double, double, double>());
    callNode->setArg(0, arg1);
    callNode->setArg(1, arg2);
    callNode->setRet(0, dest);
}
#else
// Build the JIT-compiled evaluator for this expression with asmjit's x86 compiler.
//
// The generated function takes no arguments and returns a double. It:
//   1. loads every variable from the address given by getVariableReference(),
//   2. loads every distinct constant needed by the operations,
//   3. emits code for each step, either as inline AVX instructions, as a call
//      to a math library function, or as a call to evaluateOperation(),
//   4. returns the register associated with the last workspace slot.
// The resulting entry point is stored in jitCode through runtime.add().
// If the host CPU does not report AVX support the function returns early and
// leaves jitCode untouched.
void CompiledExpression::generateJitCode() {
const CpuInfo& cpu = CpuInfo::host();
if (!cpu.hasFeature(CpuFeatures::X86::kAVX))
return;
CodeHolder code;
code.init(runtime.environment());
x86::Compiler c(&code);
FuncNode* funcNode = c.addFunc(FuncSignatureT<double>());
funcNode->frame().setAvxEnabled();
// One virtual scalar-double register per workspace slot.
vector<x86::Xmm> workspaceVar(workspace.size());
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newXmmSd();
// Address of the argument buffer used by the evaluateOperation() fallback.
x86::Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm(&argValues[0]));
vector<vector<int> > groups, groupPowers;
vector<int> stepGroup;
findPowerGroups(groups, groupPowers, stepGroup);
// Load the arguments into variables.
x86::Gp variablePointer = c.newIntPtr();
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
c.mov(variablePointer, imm(&getVariableReference(index->first)));
c.vmovsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
}
// Make a list of all constants that will be needed for evaluation.
vector<int> operationConstantIndex(operation.size(), -1);
for (int step = 0; step < (int) operation.size(); step++) {
// Find the constant value (if any) used by this operation.
Operation& op = *operation[step];
double value;
if (op.getId() == Operation::CONSTANT)
value = dynamic_cast<Operation::Constant&>(op).getValue();
else if (op.getId() == Operation::ADD_CONSTANT)
value = dynamic_cast<Operation::AddConstant&>(op).getValue();
else if (op.getId() == Operation::MULTIPLY_CONSTANT)
value = dynamic_cast<Operation::MultiplyConstant&>(op).getValue();
else if (op.getId() == Operation::RECIPROCAL)
value = 1.0;
else if (op.getId() == Operation::STEP)
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else if (op.getId() == Operation::ABS) {
// Bit pattern with every bit set except the top one; it is ANDed with the
// argument in the ABS case below to clear the sign bit.
long long mask = 0x7FFFFFFFFFFFFFFF;
value = *reinterpret_cast<double*>(&mask);
}
else if (op.getId() == Operation::POWER_CONSTANT) {
// Steps outside any power group call pow() with the real exponent; grouped
// steps only need 1.0 (used to form the reciprocal for negative powers).
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
else
value = 1.0;
}
else
continue;
// See if we already have a variable for this constant.
for (int i = 0; i < (int) constants.size(); i++)
if (value == constants[i]) {
operationConstantIndex[step] = i;
break;
}
if (operationConstantIndex[step] == -1) {
operationConstantIndex[step] = constants.size();
constants.push_back(value);
}
}
// Load constants into variables.
vector<x86::Xmm> constantVar(constants.size());
if (constants.size() > 0) {
x86::Gp constantsPointer = c.newIntPtr();
c.mov(constantsPointer, imm(&constants[0]));
for (int i = 0; i < (int) constants.size(); i++) {
constantVar[i] = c.newXmmSd();
c.vmovsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
}
}
// Evaluate the operations.
vector<bool> hasComputedPower(operation.size(), false);
for (int step = 0; step < (int) operation.size(); step++) {
if (hasComputedPower[step])
continue;
// When one or more steps involve raising the same argument to multiple integer
// powers, we can compute them all together for efficiency.
if (stepGroup[step] != -1) {
vector<int>& group = groups[stepGroup[step]];
vector<int>& powers = groupPowers[stepGroup[step]];
x86::Xmm multiplier = c.newXmmSd();
if (powers[0] > 0)
c.vmovsd(multiplier, workspaceVar[arguments[step][0]], workspaceVar[arguments[step][0]]);
else {
// Negative powers: start from 1/x and work with the absolute exponents.
c.vdivsd(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
// Square-and-multiply: on each pass the multiplier is squared, and every
// target whose exponent has its lowest bit set is multiplied by it.
vector<bool> hasAssigned(group.size(), false);
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < group.size(); i++) {
if (powers[i]%2 == 1) {
if (!hasAssigned[i])
c.vmovsd(workspaceVar[target[group[i]]], multiplier, multiplier);
else
c.vmulsd(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
hasAssigned[i] = true;
}
powers[i] >>= 1;
if (powers[i] != 0)
done = false;
}
if (!done)
c.vmulsd(multiplier, multiplier, multiplier);
}
for (int step : group)
hasComputedPower[step] = true;
continue;
}
// Evaluate the step.
Operation& op = *operation[step];
vector<int> args = arguments[step];
if (args.size() == 1) {
// One or more sequential arguments. Fill out the list.
for (int i = 1; i < op.getNumArguments(); i++)
args.push_back(args[0]+i);
}
// Generate instructions to execute this operation.
switch (op.getId()) {
case Operation::CONSTANT:
c.vmovsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::ADD:
c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::SUBTRACT:
c.vsubsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MULTIPLY:
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::DIVIDE:
c.vdivsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow);
break;
case Operation::NEGATE:
// Zero the target, then subtract the argument from it.
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vsubsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::SQRT:
c.vsqrtsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
break;
case Operation::LOG:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log);
break;
case Operation::SIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin);
break;
case Operation::COS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos);
break;
case Operation::TAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan);
break;
case Operation::ASIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin);
break;
case Operation::ACOS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos);
break;
case Operation::ATAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan);
break;
case Operation::ATAN2:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2);
break;
case Operation::SINH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh);
break;
case Operation::COSH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh);
break;
case Operation::TANH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
break;
case Operation::STEP:
// Compare zero against the argument to build a mask, then AND the mask with
// the constant 1.0 selected for this step.
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::SQUARE:
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
break;
case Operation::CUBE:
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
c.vmulsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::RECIPROCAL:
c.vdivsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
break;
case Operation::ADD_CONSTANT:
c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::MULTIPLY_CONSTANT:
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
break;
case Operation::MIN:
c.vminsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MAX:
c.vmaxsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::ABS:
// AND with the sign-clearing mask registered in the constant list above.
c.vandpd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::FLOOR:
c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(1));
break;
case Operation::CEIL:
c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(2));
break;
case Operation::SELECT:
{
// Build a mask from comparing the first argument with zero, then blend the
// second and third arguments according to that mask.
x86::Xmm mask = c.newXmmSd();
c.vxorps(mask, mask, mask);
c.vcmpsd(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0
c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask);
break;
}
default:
// Just invoke evaluateOperation().
// The arguments are first spilled to argValues, which the callee reads.
for (int i = 0; i < (int) args.size(); i++)
c.vmovsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
x86::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) evaluateOperation));
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
invoke->setArg(0, imm(&op));
invoke->setArg(1, imm(&argValues[0]));
invoke->setRet(0, workspaceVar[target[step]]);
}
}
// The value of the expression is the last workspace slot.
c.ret(workspaceVar[workspace.size()-1]);
c.endFunc();
c.finalize();
runtime.add(&jitCode, &code);
}
void CompiledExpression::generateSingleArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg, double (*function)(double)) {
    // Put the callee's address in a pointer-sized virtual register.
    x86::Gp calleeAddress = c.newIntPtr();
    c.mov(calleeAddress, imm((void*) function));

    // Emit the call node: one double argument in, one double result out.
    InvokeNode* callNode;
    c.invoke(&callNode, calleeAddress, FuncSignatureT<double, double>());
    callNode->setArg(0, arg);
    callNode->setRet(0, dest);
}
void CompiledExpression::generateTwoArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg1, x86::Xmm& arg2, double (*function)(double, double)) {
    // Put the callee's address in a pointer-sized virtual register.
    x86::Gp calleeAddress = c.newIntPtr();
    c.mov(calleeAddress, imm((void*) function));

    // Emit the call node: two double arguments in, one double result out.
    InvokeNode* callNode;
    c.invoke(&callNode, calleeAddress, FuncSignatureT<double, double, double>());
    callNode->setArg(0, arg1);
    callNode->setArg(1, arg2);
    callNode->setRet(0, dest);
}
#endif
#endif

View File

@ -0,0 +1,933 @@
/* -------------------------------------------------------------------------- *
* Lepton *
* -------------------------------------------------------------------------- *
* This is part of the Lepton expression parser originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "lepton/CompiledVectorExpression.h"
#include "lepton/Operation.h"
#include "lepton/ParsedExpression.h"
#include <algorithm>
#include <limits>
#include <utility>
using namespace Lepton;
using namespace std;
#ifdef LEPTON_USE_JIT
using namespace asmjit;
#endif
// Default constructor: creates an object with no JIT code attached (jitCode is NULL).
// NOTE(review): width is not initialized here; confirm whether the class
// declaration gives it a default before relying on getWidth() for such objects.
CompiledVectorExpression::CompiledVectorExpression() : jitCode(NULL) {
}
CompiledVectorExpression::CompiledVectorExpression(const ParsedExpression& expression, int width) : jitCode(NULL), width(width) {
    // Refuse widths that getAllowedWidths() does not list.
    const vector<int> supported = getAllowedWidths();
    const bool widthIsSupported = (find(supported.begin(), supported.end(), width) != supported.end());
    if (!widthIsSupported)
        throw Exception("Unsupported width for vector expression: "+to_string(width));

    // Optimize again in case the caller passed an unoptimized expression,
    // then flatten the tree into the step lists.
    ParsedExpression optimized = expression.optimize();
    vector<pair<ExpressionTreeNode, int> > compiledNodes;
    int slotCount = 0;
    compileExpression(optimized.getRootNode(), compiledNodes, slotCount);

    // Every workspace slot holds `width` floats.
    workspace.resize(slotCount*width);

    // argValues must fit the arguments of the widest operation (never fewer than one).
    int widestArgumentList = 1;
    for (size_t i = 0; i < operation.size(); ++i) {
        const int count = operation[i]->getNumArguments();
        if (count > widestArgumentList)
            widestArgumentList = count;
    }
    argValues.resize(widestArgumentList);
#ifdef LEPTON_USE_JIT
    generateJitCode();
#endif
}
CompiledVectorExpression::~CompiledVectorExpression() {
    // Release the Operation objects held in the step list. Deleting a null
    // pointer is a no-op, so no explicit null check is required.
    for (size_t i = 0; i < operation.size(); ++i)
        delete operation[i];
}
// Copy constructor: starts with no JIT code and delegates all copying to operator=.
CompiledVectorExpression::CompiledVectorExpression(const CompiledVectorExpression& expression) : jitCode(NULL) {
*this = expression;
}
// Copy assignment.
//
// Copies the compiled step lists and sizes the workspace and argument buffer to
// match `expression`; every Operation is cloned so the two objects do not share
// them. The variable locations already registered on *this are kept and
// re-applied through setVariableLocations(), which also rebuilds the JIT code
// when JIT support is compiled in.
//
// Returns a reference to *this.
CompiledVectorExpression& CompiledVectorExpression::operator=(const CompiledVectorExpression& expression) {
    // Self-assignment must not delete the operations we are about to clone.
    if (this == &expression)
        return *this;
    arguments = expression.arguments;
    width = expression.width;
    target = expression.target;
    variableIndices = expression.variableIndices;
    variableNames = expression.variableNames;
    workspace.resize(expression.workspace.size());
    argValues.resize(expression.argValues.size());
    // Free the operations currently held: overwriting or truncating the
    // pointers without deleting them would leak.
    for (int i = 0; i < (int) operation.size(); i++)
        delete operation[i];
    // clear() before resize() so every slot is null until its clone exists;
    // the destructor then stays safe even if clone() throws part-way through.
    operation.clear();
    operation.resize(expression.operation.size());
    for (int i = 0; i < (int) operation.size(); i++)
        operation[i] = expression.operation[i]->clone();
    setVariableLocations(variablePointers);
    return *this;
}
// Return the vector widths that can be used to construct a CompiledVectorExpression.
//
// Width 4 is always listed. When built with LEPTON_USE_JIT, width 8 is added if
// the host CPU reports AVX support.
//
// The list is built exactly once by the initializer of a function-local static,
// which the language guarantees to be thread-safe; the previous "if empty then
// push_back" pattern could race when first called from several threads.
const vector<int>& CompiledVectorExpression::getAllowedWidths() {
    static const vector<int> widths = []() {
        vector<int> result;
        result.push_back(4);
#ifdef LEPTON_USE_JIT
        const CpuInfo& cpu = CpuInfo::host();
        if (cpu.hasFeature(CpuFeatures::X86::kAVX))
            result.push_back(8);
#endif
        return result;
    }();
    return widths;
}
void CompiledVectorExpression::compileExpression(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps, int& workspaceSize) {
if (findTempIndex(node, temps) != -1)
return; // We have already processed a node identical to this one.
// Process the child nodes.
vector<int> args;
for (int i = 0; i < node.getChildren().size(); i++) {
compileExpression(node.getChildren()[i], temps, workspaceSize);
args.push_back(findTempIndex(node.getChildren()[i], temps));
}
// Process this node.
if (node.getOperation().getId() == Operation::VARIABLE) {
variableIndices[node.getOperation().getName()] = workspaceSize;
variableNames.insert(node.getOperation().getName());
}
else {
int stepIndex = (int) arguments.size();
arguments.push_back(vector<int>());
target.push_back(workspaceSize);
operation.push_back(node.getOperation().clone());
if (args.size() == 0)
arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there.
else {
// If the arguments are sequential, we can just pass a pointer to the first one.
bool sequential = true;
for (int i = 1; i < args.size(); i++)
if (args[i] != args[i - 1] + 1)
sequential = false;
if (sequential)
arguments[stepIndex].push_back(args[0]);
else
arguments[stepIndex] = args;
}
}
temps.push_back(make_pair(node, workspaceSize));
workspaceSize++;
}
int CompiledVectorExpression::findTempIndex(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps) {
    // Linear scan; returns the position of the first entry equal to `node`, or -1.
    int position = 0;
    for (auto entry = temps.begin(); entry != temps.end(); ++entry, ++position) {
        if (entry->first == node)
            return position;
    }
    return -1;
}
// Return the vector width stored in this expression (the number of floats per workspace slot).
int CompiledVectorExpression::getWidth() const {
return width;
}
// Return the names of all variables recorded while compiling the expression.
const set<string>& CompiledVectorExpression::getVariables() const {
return variableNames;
}
// Return a pointer to the storage holding the values of variable `name`.
//
// If a location was registered for the variable through setVariableLocations(),
// that pointer is returned. Otherwise the pointer refers to the variable's slot
// inside the internal workspace.
//
// Throws Exception if the expression does not use a variable with this name.
float* CompiledVectorExpression::getVariablePointer(const string& name) {
    // A registered external location takes precedence over the workspace slot.
    map<string, float*>::iterator pointer = variablePointers.find(name);
    if (pointer != variablePointers.end())
        return pointer->second;
    map<string, int>::iterator index = variableIndices.find(name);
    if (index == variableIndices.end())
        throw Exception("getVariablePointer: Unknown variable '" + name + "'");
    // Each slot spans `width` consecutive floats.
    return &workspace[index->second*width];
}
void CompiledVectorExpression::setVariableLocations(map<string, float*>& variableLocations) {
    variablePointers = variableLocations;
#ifdef LEPTON_USE_JIT
    // The generated code embeds variable addresses, so it has to be regenerated.
    if (workspace.size() > 0)
        generateJitCode();
#endif
    // Rebuild the (workspace slot, external location) pairs that evaluate()
    // copies before running the steps.
    variablesToCopy.clear();
    for (const auto& variable : variableIndices) {
        const auto location = variablePointers.find(variable.first);
        if (location == variablePointers.end())
            continue;
        variablesToCopy.push_back(make_pair(&workspace[variable.second*width], location->second));
    }
}
const float* CompiledVectorExpression::evaluate() const {
    // Fast path: run the JIT-compiled code when it exists.
    if (jitCode) {
        jitCode();
        return &workspace[workspace.size()-width];
    }

    // Bring externally stored variable values into the workspace.
    for (size_t v = 0; v < variablesToCopy.size(); ++v) {
        for (int lane = 0; lane < width; ++lane)
            variablesToCopy[v].first[lane] = variablesToCopy[v].second[lane];
    }

    // Interpret the steps one by one, evaluating every lane separately.
    for (size_t step = 0; step < operation.size(); ++step) {
        const vector<int>& args = arguments[step];
        // A single stored index means the arguments sit in consecutive slots
        // starting at that index.
        const bool consecutive = (args.size() == 1);
        const int argumentCount = consecutive ? operation[step]->getNumArguments() : (int) args.size();
        for (int lane = 0; lane < width; ++lane) {
            for (int a = 0; a < argumentCount; ++a) {
                const int slot = consecutive ? args[0]+a : args[a];
                argValues[a] = workspace[slot*width+lane];
            }
            workspace[target[step]*width+lane] = operation[step]->evaluate(&argValues[0], dummyVariables);
        }
    }

    // The result occupies the final `width` floats of the workspace.
    return &workspace[workspace.size()-width];
}
#ifdef LEPTON_USE_JIT
// Helper invoked from JIT-generated code for operations that have no inline
// implementation: evaluates `op` on the argument array `args`, passing an
// empty variable map, and returns the result.
static double evaluateOperation(Operation* op, double* args) {
static map<string, double> dummyVariables;
return op->evaluate(args, dummyVariables);
}
// Identify steps that raise the same argument to integer powers of the same sign,
// so that the JIT generator can compute them together.
//
// groups       receives one entry per group: the step indices in that group.
//              Steps that are not integer powers contribute an empty entry.
// groupPowers  receives, for each group, the integer power of each member step.
// stepGroup    receives, for each step, the index of its group, or -1 if the
//              step is not an integer power.
void CompiledVectorExpression::findPowerGroups(vector<vector<int> >& groups, vector<vector<int> >& groupPowers, vector<int>& stepGroup) {
    const int numSteps = (int) operation.size();
    // Converting a double outside the range of int (or NaN) to int is undefined
    // behavior, so exponents are range-checked before the cast. The bound is
    // symmetric so that the power can later be negated safely.
    const double maxIntegerPower = (double) numeric_limits<int>::max();

    // Identify every step that raises an argument to an integer power.
    vector<int> stepPower(numSteps, 0);
    vector<int> stepArg(numSteps, -1);
    for (int step = 0; step < numSteps; step++) {
        Operation& op = *operation[step];
        int power = 0;
        if (op.getId() == Operation::SQUARE)
            power = 2;
        else if (op.getId() == Operation::CUBE)
            power = 3;
        else if (op.getId() == Operation::POWER_CONSTANT) {
            double realPower = dynamic_cast<const Operation::PowerConstant*> (&op)->getValue();
            // NaN fails both range comparisons and is therefore left as power 0.
            if (realPower >= -maxIntegerPower && realPower <= maxIntegerPower && realPower == (int) realPower)
                power = (int) realPower;
        }
        if (power != 0) {
            stepPower[step] = power;
            stepArg[step] = arguments[step][0];
        }
    }

    // Find groups that operate on the same argument and whose powers have the same sign.
    stepGroup.resize(numSteps, -1);
    for (int i = 0; i < numSteps; i++) {
        if (stepGroup[i] != -1)
            continue;
        vector<int> group, power;
        for (int j = i; j < numSteps; j++) {
            // Compare signs directly; multiplying the two powers could overflow int.
            const bool sameSign = stepPower[i] != 0 && stepPower[j] != 0 &&
                                  ((stepPower[i] > 0) == (stepPower[j] > 0));
            if (stepArg[i] == stepArg[j] && sameSign) {
                stepGroup[j] = groups.size();
                group.push_back(j);
                power.push_back(stepPower[j]);
            }
        }
        // An entry is appended even when empty so that group indices stay aligned.
        groups.push_back(group);
        groupPowers.push_back(power);
    }
}
#if defined(__ARM__) || defined(__ARM64__)
// Build the JIT-compiled evaluator for this expression with asmjit's AArch64 compiler.
//
// The generated function takes no arguments and returns nothing. Each workspace
// slot is represented by a vector register used as four single-precision lanes.
// The function:
//   1. loads every variable from the address returned by getVariablePointer(),
//   2. loads every distinct constant needed by the operations, replicated to all lanes,
//   3. emits code for each step, either as inline vector instructions, as
//      per-element calls to a math library function, or as per-element calls to
//      evaluateOperation(),
//   4. stores the last workspace register to the final `width` floats of workspace.
// The resulting entry point is stored in jitCode through runtime.add().
void CompiledVectorExpression::generateJitCode() {
CodeHolder code;
code.init(runtime.environment());
a64::Compiler c(&code);
c.addFunc(FuncSignatureT<void>());
// One virtual vector register per workspace slot.
vector<arm::Vec> workspaceVar(workspace.size()/width);
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newVecQ();
// Address of the argument buffer used by the evaluateOperation() fallback.
arm::Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm(&argValues[0]));
vector<vector<int> > groups, groupPowers;
vector<int> stepGroup;
findPowerGroups(groups, groupPowers, stepGroup);
// Load the arguments into variables.
arm::Gp variablePointer = c.newIntPtr();
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
c.mov(variablePointer, imm(getVariablePointer(index->first)));
c.ldr(workspaceVar[index->second].s4(), arm::ptr(variablePointer, 0));
}
// Make a list of all constants that will be needed for evaluation.
vector<int> operationConstantIndex(operation.size(), -1);
for (int step = 0; step < (int) operation.size(); step++) {
// Find the constant value (if any) used by this operation.
Operation& op = *operation[step];
float value;
if (op.getId() == Operation::CONSTANT)
value = dynamic_cast<Operation::Constant&> (op).getValue();
else if (op.getId() == Operation::ADD_CONSTANT)
value = dynamic_cast<Operation::AddConstant&> (op).getValue();
else if (op.getId() == Operation::MULTIPLY_CONSTANT)
value = dynamic_cast<Operation::MultiplyConstant&> (op).getValue();
else if (op.getId() == Operation::RECIPROCAL)
value = 1.0;
else if (op.getId() == Operation::STEP)
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else if (op.getId() == Operation::POWER_CONSTANT) {
// Steps outside any power group call powf() with the real exponent; grouped
// steps only need 1.0 (used to form the reciprocal for negative powers).
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&> (op).getValue();
else
value = 1.0;
} else
continue;
// See if we already have a variable for this constant.
for (int i = 0; i < (int) constants.size(); i++)
if (value == constants[i]) {
operationConstantIndex[step] = i;
break;
}
if (operationConstantIndex[step] == -1) {
operationConstantIndex[step] = constants.size();
constants.push_back(value);
}
}
// Load constants into variables.
vector<arm::Vec> constantVar(constants.size());
if (constants.size() > 0) {
arm::Gp constantsPointer = c.newIntPtr();
for (int i = 0; i < (int) constants.size(); i++) {
c.mov(constantsPointer, imm(&constants[i]));
constantVar[i] = c.newVecQ();
// ld1r loads a single element and replicates it into every lane.
c.ld1r(constantVar[i].s4(), arm::ptr(constantsPointer));
}
}
// Evaluate the operations.
vector<bool> hasComputedPower(operation.size(), false);
// Scratch registers for the float <-> double conversions in the fallback path.
arm::Vec argReg = c.newVecS();
arm::Vec doubleArgReg = c.newVecD();
arm::Vec doubleResultReg = c.newVecD();
for (int step = 0; step < (int) operation.size(); step++) {
if (hasComputedPower[step])
continue;
// When one or more steps involve raising the same argument to multiple integer
// powers, we can compute them all together for efficiency.
if (stepGroup[step] != -1) {
vector<int>& group = groups[stepGroup[step]];
vector<int>& powers = groupPowers[stepGroup[step]];
arm::Vec multiplier = c.newVecQ();
if (powers[0] > 0)
c.mov(multiplier.s4(), workspaceVar[arguments[step][0]].s4());
else {
// Negative powers: start from 1/x and work with the absolute exponents.
c.fdiv(multiplier.s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[arguments[step][0]].s4());
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
// Square-and-multiply: on each pass the multiplier is squared, and every
// target whose exponent has its lowest bit set is multiplied by it.
vector<bool> hasAssigned(group.size(), false);
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < group.size(); i++) {
if (powers[i] % 2 == 1) {
if (!hasAssigned[i])
c.mov(workspaceVar[target[group[i]]].s4(), multiplier.s4());
else
c.fmul(workspaceVar[target[group[i]]].s4(), workspaceVar[target[group[i]]].s4(), multiplier.s4());
hasAssigned[i] = true;
}
powers[i] >>= 1;
if (powers[i] != 0)
done = false;
}
if (!done)
c.fmul(multiplier.s4(), multiplier.s4(), multiplier.s4());
}
for (int step : group)
hasComputedPower[step] = true;
continue;
}
// Evaluate the step.
Operation& op = *operation[step];
vector<int> args = arguments[step];
if (args.size() == 1) {
// One or more sequential arguments. Fill out the list.
for (int i = 1; i < op.getNumArguments(); i++)
args.push_back(args[0] + i);
}
// Generate instructions to execute this operation.
switch (op.getId()) {
case Operation::CONSTANT:
c.mov(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4());
break;
case Operation::ADD:
c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::SUBTRACT:
c.fsub(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::MULTIPLY:
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::DIVIDE:
c.fdiv(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf);
break;
case Operation::NEGATE:
c.fneg(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::SQRT:
c.fsqrt(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf);
break;
case Operation::LOG:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf);
break;
case Operation::SIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf);
break;
case Operation::COS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf);
break;
case Operation::TAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf);
break;
case Operation::ASIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf);
break;
case Operation::ACOS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf);
break;
case Operation::ATAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf);
break;
case Operation::ATAN2:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f);
break;
case Operation::SINH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf);
break;
case Operation::COSH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf);
break;
case Operation::TANH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf);
break;
case Operation::STEP:
// Compare the argument with zero to build a per-lane mask, then AND the mask
// with the constant 1.0 selected for this step.
c.cmge(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
c.cmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::SQUARE:
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::CUBE:
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4());
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::RECIPROCAL:
c.fdiv(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::ADD_CONSTANT:
c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4());
break;
case Operation::MULTIPLY_CONSTANT:
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4());
break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf);
break;
case Operation::MIN:
c.fmin(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::MAX:
c.fmax(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
break;
case Operation::ABS:
c.fabs(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::FLOOR:
c.frintm(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::CEIL:
c.frintp(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
break;
case Operation::SELECT:
// Build a mask from comparing the first argument with zero, then use it to
// choose per lane between the third and second arguments.
c.fcmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]);
break;
default:
// Just invoke evaluateOperation().
// It works on doubles, so each lane is handled separately: the lane's
// arguments are widened and written to argValues, the helper is called, and
// the result is narrowed and inserted back into the target lane.
for (int element = 0; element < width; element++) {
for (int i = 0; i < (int) args.size(); i++) {
c.ins(argReg.s(0), workspaceVar[args[i]].s(element));
c.fcvt(doubleArgReg, argReg);
c.str(doubleArgReg, arm::ptr(argsPointer, 8*i));
}
arm::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) evaluateOperation));
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
invoke->setArg(0, imm(&op));
invoke->setArg(1, imm(&argValues[0]));
invoke->setRet(0, doubleResultReg);
c.fcvt(argReg, doubleResultReg);
c.ins(workspaceVar[target[step]].s(element), argReg.s(0));
}
}
}
// Write the final workspace register to the last `width` floats of workspace,
// which is where evaluate() reads the result from.
arm::Gp resultPointer = c.newIntPtr();
c.mov(resultPointer, imm(&workspace[workspace.size()-width]));
c.str(workspaceVar.back().s4(), arm::ptr(resultPointer, 0));
c.endFunc();
c.finalize();
runtime.add(&jitCode, &code);
}
void CompiledVectorExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, float (*function)(float)) {
    // Apply a scalar float function to each of the `width` lanes of `arg`,
    // storing every result in the matching lane of `dest`.  One call to
    // `function` is emitted per lane.
    arm::Gp callee = c.newIntPtr();
    c.mov(callee, imm((void*) function));
    arm::Vec scalarIn = c.newVecS();
    arm::Vec scalarOut = c.newVecS();
    int lane = 0;
    while (lane < width) {
        // Copy the current lane into lane 0 of the scalar argument register.
        c.ins(scalarIn.s(0), arg.s(lane));
        InvokeNode* call;
        c.invoke(&call, callee, FuncSignatureT<float, float>());
        call->setArg(0, scalarIn);
        call->setRet(0, scalarOut);
        // Insert the returned scalar into the corresponding lane of dest.
        c.ins(dest.s(lane), scalarOut.s(0));
        ++lane;
    }
}
void CompiledVectorExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, float (*function)(float, float)) {
    // Apply a scalar two-argument float function lane by lane: lane k of
    // `dest` receives function(arg1[k], arg2[k]) for each of the `width` lanes.
    arm::Gp callee = c.newIntPtr();
    c.mov(callee, imm((void*) function));
    arm::Vec first = c.newVecS();
    arm::Vec second = c.newVecS();
    arm::Vec scalarOut = c.newVecS();
    int lane = 0;
    while (lane < width) {
        // Copy the current lane of each input into lane 0 of its scalar register.
        c.ins(first.s(0), arg1.s(lane));
        c.ins(second.s(0), arg2.s(lane));
        InvokeNode* call;
        c.invoke(&call, callee, FuncSignatureT<float, float, float>());
        call->setArg(0, first);
        call->setArg(1, second);
        call->setRet(0, scalarOut);
        // Insert the returned scalar into the corresponding lane of dest.
        c.ins(dest.s(lane), scalarOut.s(0));
        ++lane;
    }
}
#else
// Build the x86 JIT routine that evaluates the whole expression for `width`
// lanes at once using AVX instructions.  Each workspace slot and each constant
// is held in a YMM virtual register; the operations are emitted in program
// order and the last workspace register is stored back to `workspace` at the end.
void CompiledVectorExpression::generateJitCode() {
// The generated code uses AVX instructions; without AVX, return without emitting anything.
const CpuInfo& cpu = CpuInfo::host();
if (!cpu.hasFeature(CpuFeatures::X86::kAVX))
return;
CodeHolder code;
code.init(runtime.environment());
x86::Compiler c(&code);
FuncNode* funcNode = c.addFunc(FuncSignatureT<void>());
funcNode->frame().setAvxEnabled();
// One virtual register per group of `width` workspace floats.
vector<x86::Ymm> workspaceVar(workspace.size()/width);
for (int i = 0; i < (int) workspaceVar.size(); i++)
workspaceVar[i] = c.newYmmPs();
x86::Gp argsPointer = c.newIntPtr();
c.mov(argsPointer, imm(&argValues[0]));
vector<vector<int> > groups, groupPowers;
vector<int> stepGroup;
findPowerGroups(groups, groupPowers, stepGroup);
// Load the arguments into variables.
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
map<string, int>::iterator index = variableIndices.find(*iter);
x86::Gp variablePointer = c.newIntPtr();
c.mov(variablePointer, imm(getVariablePointer(index->first)));
// With width 4 only the low 128 bits (the xmm half) are loaded.
if (width == 4)
c.vmovdqu(workspaceVar[index->second].xmm(), x86::ptr(variablePointer, 0, 0));
else
c.vmovdqu(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
}
// Make a list of all constants that will be needed for evaluation.
vector<int> operationConstantIndex(operation.size(), -1);
for (int step = 0; step < (int) operation.size(); step++) {
// Find the constant value (if any) used by this operation.
Operation& op = *operation[step];
double value;
if (op.getId() == Operation::CONSTANT)
value = dynamic_cast<Operation::Constant&> (op).getValue();
else if (op.getId() == Operation::ADD_CONSTANT)
value = dynamic_cast<Operation::AddConstant&> (op).getValue();
else if (op.getId() == Operation::MULTIPLY_CONSTANT)
value = dynamic_cast<Operation::MultiplyConstant&> (op).getValue();
else if (op.getId() == Operation::RECIPROCAL)
value = 1.0;
else if (op.getId() == Operation::STEP)
value = 1.0;
else if (op.getId() == Operation::DELTA)
value = 1.0;
else if (op.getId() == Operation::ABS) {
// ABS is emitted as a bitwise AND (see the ABS case below), so its "constant"
// is the float whose bit pattern is 0x7FFFFFFF, i.e. every bit except the sign.
// NOTE(review): that bit pattern is a NaN, so the `value == constants[i]` search
// below can never match it and every ABS step appends its own copy -- confirm
// this is acceptable.
int mask = 0x7FFFFFFF;
value = *reinterpret_cast<float*>(&mask);
}
else if (op.getId() == Operation::POWER_CONSTANT) {
// Steps outside a power group call powf with the real exponent; grouped steps
// only need 1.0, which is used to form the reciprocal for negative exponents.
if (stepGroup[step] == -1)
value = dynamic_cast<Operation::PowerConstant&> (op).getValue();
else
value = 1.0;
} else
continue;
// See if we already have a variable for this constant.
for (int i = 0; i < (int) constants.size(); i++)
if (value == constants[i]) {
operationConstantIndex[step] = i;
break;
}
if (operationConstantIndex[step] == -1) {
operationConstantIndex[step] = constants.size();
constants.push_back(value);
}
}
// Load constants into variables.
vector<x86::Ymm> constantVar(constants.size());
if (constants.size() > 0) {
x86::Gp constantsPointer = c.newIntPtr();
c.mov(constantsPointer, imm(&constants[0]));
for (int i = 0; i < (int) constants.size(); i++) {
constantVar[i] = c.newYmmPs();
// Constants are read at a 4-byte stride and broadcast to every lane.
c.vbroadcastss(constantVar[i], x86::ptr(constantsPointer, 4*i, 0));
}
}
// Evaluate the operations.
vector<bool> hasComputedPower(operation.size(), false);
x86::Ymm argReg = c.newYmm();
x86::Ymm doubleArgReg = c.newYmm();
x86::Ymm doubleResultReg = c.newYmm();
for (int step = 0; step < (int) operation.size(); step++) {
if (hasComputedPower[step])
continue;
// When one or more steps involve raising the same argument to multiple integer
// powers, we can compute them all together for efficiency.
if (stepGroup[step] != -1) {
vector<int>& group = groups[stepGroup[step]];
vector<int>& powers = groupPowers[stepGroup[step]];
x86::Ymm multiplier = c.newYmmPs();
if (powers[0] > 0)
c.vmovdqu(multiplier, workspaceVar[arguments[step][0]]);
else {
// Negative exponents: start from 1/x and flip the sign of every exponent.
// `powers` is a reference, so this modifies groupPowers in place.
c.vdivps(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
for (int i = 0; i < powers.size(); i++)
powers[i] = -powers[i];
}
// Exponentiation by squaring: on each pass `multiplier` is the base raised to
// the next power of two; it is multiplied into every target whose exponent has
// the current low bit set, then the exponents are shifted right.
vector<bool> hasAssigned(group.size(), false);
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < group.size(); i++) {
if (powers[i] % 2 == 1) {
if (!hasAssigned[i])
c.vmovdqu(workspaceVar[target[group[i]]], multiplier);
else
c.vmulps(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
hasAssigned[i] = true;
}
powers[i] >>= 1;
if (powers[i] != 0)
done = false;
}
if (!done)
c.vmulps(multiplier, multiplier, multiplier);
}
// Mark every step of the group as done (this `step` shadows the outer loop variable).
for (int step : group)
hasComputedPower[step] = true;
continue;
}
// Evaluate the step.
Operation& op = *operation[step];
vector<int> args = arguments[step];
if (args.size() == 1) {
// One or more sequential arguments. Fill out the list.
for (int i = 1; i < op.getNumArguments(); i++)
args.push_back(args[0] + i);
}
// Generate instructions to execute this operation.
switch (op.getId()) {
case Operation::CONSTANT:
c.vmovdqu(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::ADD:
c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::SUBTRACT:
c.vsubps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MULTIPLY:
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::DIVIDE:
c.vdivps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::POWER:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf);
break;
case Operation::NEGATE:
// Computed as 0 - x: zero the target, then subtract the argument from it.
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vsubps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::SQRT:
c.vsqrtps(workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::EXP:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf);
break;
case Operation::LOG:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf);
break;
case Operation::SIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf);
break;
case Operation::COS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf);
break;
case Operation::TAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf);
break;
case Operation::ASIN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf);
break;
case Operation::ACOS:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf);
break;
case Operation::ATAN:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf);
break;
case Operation::ATAN2:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f);
break;
case Operation::SINH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf);
break;
case Operation::COSH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf);
break;
case Operation::TANH:
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf);
break;
case Operation::STEP:
// Build the mask (0 <= x), then AND it with the constant 1.0 to get 1.0 or 0.0.
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::DELTA:
// Build the mask (0 == x), then AND it with the constant 1.0 to get 1.0 or 0.0.
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
break;
case Operation::SQUARE:
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
break;
case Operation::CUBE:
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
c.vmulps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
break;
case Operation::RECIPROCAL:
c.vdivps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
break;
case Operation::ADD_CONSTANT:
c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::MULTIPLY_CONSTANT:
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::POWER_CONSTANT:
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf);
break;
case Operation::MIN:
c.vminps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::MAX:
c.vmaxps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
break;
case Operation::ABS:
// AND with the 0x7FFFFFFF mask constant clears the sign bit of every lane.
c.vandps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
break;
case Operation::FLOOR:
// Rounding-control immediate 1 selects rounding toward negative infinity.
c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(1));
break;
case Operation::CEIL:
// Rounding-control immediate 2 selects rounding toward positive infinity.
c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(2));
break;
case Operation::SELECT:
{
// mask = (0 == condition); the blend takes args[2] where the mask is set and args[1] elsewhere.
x86::Ymm mask = c.newYmmPs();
c.vxorps(mask, mask, mask);
c.vcmpps(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0
c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask);
break;
}
default:
// Just invoke evaluateOperation().
for (int element = 0; element < width; element++) {
for (int i = 0; i < (int) args.size(); i++) {
// Bring lane `element` of the argument into lane 0 (lanes 4-7 sit in the upper
// 128-bit half, which is brought down first), widen it to double and store it
// in the argument buffer.
if (element < 4)
c.vshufps(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(element));
else {
c.vperm2f128(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(1));
c.vshufps(argReg, argReg, argReg, imm(element-4));
}
c.vcvtss2sd(doubleArgReg.xmm(), doubleArgReg.xmm(), argReg.xmm());
c.vmovsd(x86::ptr(argsPointer, 8*i, 0), doubleArgReg.xmm());
}
x86::Gp fn = c.newIntPtr();
c.mov(fn, imm((void*) evaluateOperation));
InvokeNode* invoke;
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
invoke->setArg(0, imm(&op));
invoke->setArg(1, imm(&argValues[0]));
invoke->setRet(0, doubleResultReg);
// Narrow the result to float, replicate it across the register, and blend only
// lane `element` into the target.
c.vcvtsd2ss(argReg.xmm(), argReg.xmm(), doubleResultReg.xmm());
if (element > 3)
c.vperm2f128(argReg, argReg, argReg, imm(0));
if (element != 0)
c.vshufps(argReg, argReg, argReg, imm(0));
c.vblendps(workspaceVar[target[step]], workspaceVar[target[step]], argReg, 1<<element);
}
}
}
// Store the last workspace register into the final `width` floats of `workspace`.
x86::Gp resultPointer = c.newIntPtr();
c.mov(resultPointer, imm(&workspace[workspace.size()-width]));
if (width == 4)
c.vmovdqu(x86::ptr(resultPointer, 0, 0), workspaceVar.back().xmm());
else
c.vmovdqu(x86::ptr(resultPointer, 0, 0), workspaceVar.back());
c.endFunc();
c.finalize();
runtime.add(&jitCode, &code);
}
void CompiledVectorExpression::generateSingleArgCall(x86::Compiler& c, x86::Ymm& dest, x86::Ymm& arg, float (*function)(float)) {
    // Apply a scalar float function to each of the `width` lanes of `arg`,
    // blending every result into the matching lane of `dest`.  One call to
    // `function` is emitted per lane.
    x86::Gp callee = c.newIntPtr();
    c.mov(callee, imm((void*) function));
    x86::Ymm scalarIn = c.newYmm();
    x86::Ymm scalarOut = c.newYmm();
    for (int lane = 0; lane < width; ++lane) {
        // Lanes 4-7 sit in the upper 128-bit half of the register.
        const bool upperHalf = (lane > 3);

        // Bring the current lane into lane 0 of the scalar argument register.
        if (upperHalf) {
            c.vperm2f128(scalarIn, arg, arg, imm(1));
            c.vshufps(scalarIn, scalarIn, scalarIn, imm(lane-4));
        }
        else
            c.vshufps(scalarIn, arg, arg, imm(lane));

        InvokeNode* call;
        c.invoke(&call, callee, FuncSignatureT<float, float>());
        call->setArg(0, scalarIn);
        call->setRet(0, scalarOut);

        // Replicate the returned scalar so it is present in the destination
        // lane, then blend only that lane into dest.
        if (upperHalf)
            c.vperm2f128(scalarOut, scalarOut, scalarOut, imm(0));
        if (lane > 0)
            c.vshufps(scalarOut, scalarOut, scalarOut, imm(0));
        c.vblendps(dest, dest, scalarOut, 1<<lane);
    }
}
void CompiledVectorExpression::generateTwoArgCall(x86::Compiler& c, x86::Ymm& dest, x86::Ymm& arg1, x86::Ymm& arg2, float (*function)(float, float)) {
    // Apply a scalar two-argument float function lane by lane: lane k of
    // `dest` receives function(arg1[k], arg2[k]) for each of the `width` lanes.
    x86::Gp callee = c.newIntPtr();
    c.mov(callee, imm((void*) function));
    x86::Ymm first = c.newYmm();
    x86::Ymm second = c.newYmm();
    x86::Ymm scalarOut = c.newYmm();
    for (int lane = 0; lane < width; ++lane) {
        // Lanes 4-7 sit in the upper 128-bit half of the register.
        const bool upperHalf = (lane > 3);

        // Bring the current lane of each input into lane 0 of its scalar register.
        if (upperHalf) {
            c.vperm2f128(first, arg1, arg1, imm(1));
            c.vperm2f128(second, arg2, arg2, imm(1));
            c.vshufps(first, first, first, imm(lane-4));
            c.vshufps(second, second, second, imm(lane-4));
        }
        else {
            c.vshufps(first, arg1, arg1, imm(lane));
            c.vshufps(second, arg2, arg2, imm(lane));
        }

        InvokeNode* call;
        c.invoke(&call, callee, FuncSignatureT<float, float, float>());
        call->setArg(0, first);
        call->setArg(1, second);
        call->setRet(0, scalarOut);

        // Replicate the returned scalar so it is present in the destination
        // lane, then blend only that lane into dest.
        if (upperHalf)
            c.vperm2f128(scalarOut, scalarOut, scalarOut, imm(0));
        if (lane > 0)
            c.vshufps(scalarOut, scalarOut, scalarOut, imm(0));
        c.vblendps(dest, dest, scalarOut, 1<<lane);
    }
}
#endif
#endif

View File

@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
@ -32,6 +32,7 @@
#include "lepton/ExpressionTreeNode.h"
#include "lepton/Exception.h"
#include "lepton/Operation.h"
#include <utility>
using namespace Lepton;
using namespace std;
@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati
ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node)
        : operation(node.operation != NULL ? node.operation->clone() : NULL),
          children(node.getChildren()) {
    // Copy construction: the operation (if any) is cloned so this node owns
    // its own instance, and the list of children is copied.
}
ExpressionTreeNode::ExpressionTreeNode(ExpressionTreeNode&& node) : operation(NULL), children() {
    // Move construction: start out empty and exchange state with the source,
    // so this node takes over the source's operation and children while the
    // source is left with no operation and no children.
    std::swap(operation, node.operation);
    children.swap(node.children);
}
ExpressionTreeNode::ExpressionTreeNode()
        : operation(NULL) {
    // Default construction yields an empty node: no operation, no children.
}
@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node
return *this;
}
ExpressionTreeNode& ExpressionTreeNode::operator=(ExpressionTreeNode&& node) {
    // Move assignment: take over the source's operation and children, leaving
    // the source with no operation and no children.
    //
    // The source's state is detached into locals *before* any of this node's
    // old state is released.  This keeps two cases safe:
    //  - self-assignment (x = std::move(x)): the operation is no longer
    //    deleted out from under the node, and its state is preserved;
    //  - assigning from one of this node's own descendants
    //    (x = std::move(x.children[i])): replacing `children` destroys `node`
    //    itself, so `node` must not be touched after that point.
    Operation* stolenOperation = node.operation;
    vector<ExpressionTreeNode> stolenChildren(move(node.children));
    node.operation = NULL;
    node.children.clear();

    delete operation; // deleting a null pointer is a no-op
    operation = stolenOperation;
    children = move(stolenChildren);
    return *this;
}
const Operation& ExpressionTreeNode::getOperation() const {
    // Returns a reference to the operation held by this node.  The pointer is
    // dereferenced without a check, so the node must actually hold an operation.
    const Operation& held = *operation;
    return held;
}
@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const {
const vector<ExpressionTreeNode>& ExpressionTreeNode::getChildren() const {
    // Read-only access to this node's list of child nodes.
    return this->children;
}
void ExpressionTreeNode::assignTags(vector<const ExpressionTreeNode*>& examples) const {
    // Give every node in the tree a tag such that two nodes share a tag if and
    // only if they (and all their children) are equal.  `examples[t]` is the
    // first node that received tag t.  This is used to optimize other operations.

    const size_t tagsBeforeChildren = examples.size();
    const vector<ExpressionTreeNode>& kids = getChildren();
    for (size_t k = 0; k < kids.size(); k++)
        kids[k].assignTags(examples);

    // Only if no child introduced a new tag can this node equal an earlier one.
    if (examples.size() == tagsBeforeChildren) {
        for (size_t candidate = 0; candidate < examples.size(); candidate++) {
            const ExpressionTreeNode& other = *examples[candidate];
            if (kids.size() != other.getChildren().size() || !(getOperation() == other.getOperation()))
                continue;

            // Same operation and arity: compare the children tag by tag.
            bool sameChildTags = true;
            for (size_t k = 0; sameChildTags && k < kids.size(); k++)
                sameChildTags = (kids[k].tag == other.getChildren()[k].tag);
            if (sameChildTags) {
                tag = (int) candidate;
                return;
            }
        }
    }

    // No earlier node matches this one, so it becomes the example for a new tag.
    tag = examples.size();
    examples.push_back(this);
}

View File

@ -3,7 +3,7 @@
/*
* Up to version 11 (VC++ 2012), Microsoft does not support the
* standard C99 erf() and erfc() functions so we have to fake them here.
* standard C99 erf() and erfc() functions so we have to fake them here.
* These were added in version 12 (VC++ 2013), which sets _MSC_VER=1800
* (VC11 has _MSC_VER=1700).
*/
@ -15,7 +15,7 @@
#endif
#if defined(_MSC_VER)
#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12
#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12
/***************************
* erf.cpp
* author: Steve Strand

View File

@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
@ -37,6 +37,12 @@
using namespace Lepton;
using namespace std;
static bool isZero(const ExpressionTreeNode& node) {
if (node.getOperation().getId() != Operation::CONSTANT)
return false;
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue() == 0.0;
}
double Operation::Erf::evaluate(double* args, const map<string, double>& variables) const {
    // The error function of the single argument; `variables` is not used.
    const double x = args[0];
    return erf(x);
}
@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi
ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // Chain rule for a custom function f(a_0, ..., a_{n-1}):
    //     df/dv = sum over i of (df/da_i) * (da_i/dv)
    // where df/da_i is represented by Custom(*this, i) applied to the same children.
    // Terms whose child derivative is the constant zero are skipped; if every
    // term is skipped (or the function takes no arguments) the result is the
    // constant 0.
    if (function->getNumArguments() == 0)
        return ExpressionTreeNode(new Operation::Constant(0.0));
    ExpressionTreeNode result;
    bool foundTerm = false;
    for (int i = 0; i < getNumArguments(); i++) {
        if (!isZero(childDerivs[i])) {
            if (foundTerm)
                result = ExpressionTreeNode(new Operation::Add(),
                                            result,
                                            ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]));
            else {
                // First non-zero term: it starts the sum.
                result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]);
                foundTerm = true;
            }
        }
    }
    if (foundTerm)
        return result;
    return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // (u + v)' = u' + v'.  A term whose derivative is the constant zero is
    // dropped, so the sum collapses to the other term.
    const ExpressionTreeNode& du = childDerivs[0];
    const ExpressionTreeNode& dv = childDerivs[1];
    if (isZero(du))
        return dv;
    if (!isZero(dv))
        return ExpressionTreeNode(new Operation::Add(), du, dv);
    return du;
}
ExpressionTreeNode Operation::Subtract::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // (u - v)' = u' - v', simplified when either derivative is the constant zero:
    //   u' == 0 and v' == 0  ->  0
    //   u' == 0              ->  -v'
    //   v' == 0              ->  u'
    const ExpressionTreeNode& du = childDerivs[0];
    const ExpressionTreeNode& dv = childDerivs[1];
    if (!isZero(du)) {
        if (isZero(dv))
            return du;
        return ExpressionTreeNode(new Operation::Subtract(), du, dv);
    }
    if (!isZero(dv))
        return ExpressionTreeNode(new Operation::Negate(), dv);
    return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Multiply::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // Product rule (u v)' = u v' + v u', keeping only the terms whose
    // derivative factor is not the constant zero.
    const ExpressionTreeNode& du = childDerivs[0];
    const ExpressionTreeNode& dv = childDerivs[1];
    if (!isZero(du)) {
        if (isZero(dv))
            return ExpressionTreeNode(new Operation::Multiply(), children[1], du);
        return ExpressionTreeNode(new Operation::Add(),
                                  ExpressionTreeNode(new Operation::Multiply(), children[0], dv),
                                  ExpressionTreeNode(new Operation::Multiply(), children[1], du));
    }
    if (!isZero(dv))
        return ExpressionTreeNode(new Operation::Multiply(), children[0], dv);
    return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Divide::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // Quotient rule (u / v)' = (v u' - u v') / v^2.  The numerator is
    // simplified when either derivative is the constant zero:
    //   u' == 0 and v' == 0  ->  the whole derivative is 0
    //   u' == 0              ->  numerator is -(u v')
    //   v' == 0              ->  numerator is v u'
    ExpressionTreeNode subexp;
    if (isZero(childDerivs[0])) {
        if (isZero(childDerivs[1]))
            return ExpressionTreeNode(new Operation::Constant(0.0));
        subexp = ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]));
    }
    else if (isZero(childDerivs[1]))
        subexp = ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]);
    else
        subexp = ExpressionTreeNode(new Operation::Subtract(),
                                    ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]),
                                    ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]));
    return ExpressionTreeNode(new Operation::Divide(), subexp, ExpressionTreeNode(new Operation::Square(), children[1]));
}
ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode Operation::Negate::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // (-u)' = -u'; a constant-zero u' gives the constant 0 directly.
    const ExpressionTreeNode& du = childDerivs[0];
    if (!isZero(du))
        return ExpressionTreeNode(new Operation::Negate(), du);
    return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(0.5),
ExpressionTreeNode(new Operation::Reciprocal(),
@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Exp::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // exp(u)' = exp(u) * u'; a constant-zero u' gives the constant 0 directly.
    const ExpressionTreeNode& du = childDerivs[0];
    if (!isZero(du))
        return ExpressionTreeNode(new Operation::Multiply(),
                                  ExpressionTreeNode(new Operation::Exp(), children[0]),
                                  du);
    return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Log::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // log(u)' = (1 / u) * u'; a constant-zero u' gives the constant 0 directly.
    const ExpressionTreeNode& du = childDerivs[0];
    if (!isZero(du))
        return ExpressionTreeNode(new Operation::Multiply(),
                                  ExpressionTreeNode(new Operation::Reciprocal(), children[0]),
                                  du);
    return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Sin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    // sin(u)' = cos(u) * u'; a constant-zero u' gives the constant 0 directly.
    const ExpressionTreeNode& du = childDerivs[0];
    if (!isZero(du))
        return ExpressionTreeNode(new Operation::Multiply(),
                                  ExpressionTreeNode(new Operation::Cos(), children[0]),
                                  du);
    return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Sin(), children[0])),
@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Sec(), children[0]),
@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Multiply(),
@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Square(),
ExpressionTreeNode(new Operation::Sec(), children[0])),
@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Square(),
@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Sqrt(),
@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Reciprocal(),
@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::AddConstant(1.0),
@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Cosh(),
children[0]),
@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Sinh(),
children[0]),
@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Constant(1.0)),
@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))),
@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Erfc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))),
@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
children[0]),
@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector<Expression
}
ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(3.0),
ExpressionTreeNode(new Operation::Square(), children[0])),
@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Reciprocal(),
@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector<Expre
}
// Differentiate a child scaled by this operation's stored constant (`value`).
//   children     operand subtrees; not read here.
//   childDerivs  derivatives of the operands; only element 0 is read.
//   variable     name of the differentiation variable; not read here.
// Returns a Constant(0.0) node when isZero() accepts childDerivs[0]; otherwise a
// MultiplyConstant(value) node whose single child is childDerivs[0].
ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    const ExpressionTreeNode& innerDeriv = childDerivs[0];
    if (!isZero(innerDeriv))
        return ExpressionTreeNode(new Operation::MultiplyConstant(value), innerDeriv);
    // Nothing to scale: emit a literal zero rather than a product node.
    return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(value),
ExpressionTreeNode(new Operation::PowerConstant(value-1),
@ -305,22 +393,18 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<Exp
// Derivative of the two-argument Min operation.
// `step` is a Step node applied to (children[0] - children[1]); both return statements
// below use it to pick between childDerivs[0] and childDerivs[1].
// NOTE(review): this body holds two consecutive unconditional returns, so the second one
// cannot execute as written. The "@ -a,b +c,d @@" headers around it suggest this listing is
// a diff whose +/- markers were dropped; presumably the Subtract/Multiply return is the
// removed version and the Select return is its replacement -- confirm against upstream.
ExpressionTreeNode Operation::Min::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
ExpressionTreeNode step(new Operation::Step(),
ExpressionTreeNode(new Operation::Subtract(), children[0], children[1]));
// Arithmetic form: Subtract( childDerivs[1] * step , childDerivs[0] * AddConstant(-1)(step) ).
return ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[0],
ExpressionTreeNode(new Operation::AddConstant(-1), step)));
// Select form: a Select node with children (step, childDerivs[1], childDerivs[0]).
return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[1], childDerivs[0]});
}
// Derivative of the two-argument Max operation.
// `step` is a Step node applied to (children[0] - children[1]); both return statements
// below use it to pick between childDerivs[0] and childDerivs[1] (argument order is the
// mirror image of Min::differentiate).
// NOTE(review): this body holds two consecutive unconditional returns, so the second one
// cannot execute as written. The "@ -a,b +c,d @@" headers around it suggest this listing is
// a diff whose +/- markers were dropped; presumably the Subtract/Multiply return is the
// removed version and the Select return is its replacement -- confirm against upstream.
ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
ExpressionTreeNode step(new Operation::Step(),
ExpressionTreeNode(new Operation::Subtract(), children[0], children[1]));
// Arithmetic form: Subtract( childDerivs[0] * step , childDerivs[1] * AddConstant(-1)(step) ).
return ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step),
ExpressionTreeNode(new Operation::Multiply(), childDerivs[1],
ExpressionTreeNode(new Operation::AddConstant(-1), step)));
// Select form: a Select node with children (step, childDerivs[0], childDerivs[1]).
return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[0], childDerivs[1]});
}
ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
ExpressionTreeNode step(new Operation::Step(), children[0]);
return ExpressionTreeNode(new Operation::Multiply(),
childDerivs[0],
@ -337,9 +421,5 @@ ExpressionTreeNode Operation::Ceil::differentiate(const std::vector<ExpressionTr
}
// Derivative of a three-argument Select node.
//   children     the Select operands; only children[0] (the condition) is read.
//   childDerivs  derivatives of the operands; elements 1 and 2 are read.
//   variable     name of the differentiation variable; not read here.
// Returns a new Select node with children (children[0], childDerivs[1], childDerivs[2]):
// the condition is carried over unchanged and each value branch is replaced by its
// derivative.
// The listing previously built this child list with three push_back calls, returned, and
// then repeated the same construction in a second, unreachable return; the two forms
// produce the same node, so only one is kept.
ExpressionTreeNode Operation::Select::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
    const std::vector<ExpressionTreeNode> derivChildren{children[0], childDerivs[1], childDerivs[2]};
    return ExpressionTreeNode(new Operation::Select(), derivChildren);
}

View File

@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Portions copyright (c) 2009-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
@ -31,6 +31,7 @@
#include "lepton/ParsedExpression.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CompiledVectorExpression.h"
#include "lepton/ExpressionProgram.h"
#include "lepton/Operation.h"
#include <limits>
@ -68,9 +69,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri
}
ParsedExpression ParsedExpression::optimize() const {
ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode());
ExpressionTreeNode result = getRootNode();
vector<const ExpressionTreeNode*> examples;
result.assignTags(examples);
map<int, ExpressionTreeNode> nodeCache;
result = precalculateConstantSubexpressions(result, nodeCache);
while (true) {
ExpressionTreeNode simplified = substituteSimplerExpression(result);
examples.clear();
result.assignTags(examples);
nodeCache.clear();
ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache);
if (simplified == result)
break;
result = simplified;
@ -80,9 +88,15 @@ ParsedExpression ParsedExpression::optimize() const {
ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const {
ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables);
result = precalculateConstantSubexpressions(result);
vector<const ExpressionTreeNode*> examples;
result.assignTags(examples);
map<int, ExpressionTreeNode> nodeCache;
result = precalculateConstantSubexpressions(result, nodeCache);
while (true) {
ExpressionTreeNode simplified = substituteSimplerExpression(result);
examples.clear();
result.assignTags(examples);
nodeCache.clear();
ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache);
if (simplified == result)
break;
result = simplified;
@ -104,36 +118,67 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo
return ExpressionTreeNode(node.getOperation().clone(), children);
}
// Replace subtrees whose children are all constants with a single Constant node.
// NOTE(review): this span shows two signatures (with and without `nodeCache`) and several
// statements in paired one-line / braced forms. That is consistent with a diff listing whose
// +/- markers were dropped, and it is not compilable as written. The comments below describe
// what each visible statement does; presumably the `nodeCache` variants are the newer ones --
// confirm against upstream.
ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) {
ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node, map<int, ExpressionTreeNode>& nodeCache) {
// Cache lookup keyed by node.tag: on a hit, return the stored result without recursing.
auto cached = nodeCache.find(node.tag);
if (cached != nodeCache.end())
return cached->second;
// Process every child recursively (uncached call, then the cache-threading call).
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++)
children[i] = precalculateConstantSubexpressions(node.getChildren()[i]);
children[i] = precalculateConstantSubexpressions(node.getChildren()[i], nodeCache);
// Rebuild this node from a clone of its operation and the processed children.
ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children);
// VARIABLE and CUSTOM operations are returned as rebuilt, never evaluated here.
if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM)
if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) {
nodeCache[node.tag] = result;
return result;
}
// Any processed child that is not a CONSTANT means this node cannot be folded.
for (int i = 0; i < (int) children.size(); i++)
if (children[i].getOperation().getId() != Operation::CONSTANT)
if (children[i].getOperation().getId() != Operation::CONSTANT) {
nodeCache[node.tag] = result;
return result;
return ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
}
// All children are constants: evaluate with an empty variable map and store the value
// as a Constant node, recording it in the cache under node.tag before returning.
result = ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
nodeCache[node.tag] = result;
return result;
}
ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) {
ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node, map<int, ExpressionTreeNode>& nodeCache) {
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++)
children[i] = substituteSimplerExpression(node.getChildren()[i]);
for (int i = 0; i < (int) children.size(); i++) {
const ExpressionTreeNode& child = node.getChildren()[i];
auto cached = nodeCache.find(child.tag);
if (cached == nodeCache.end()) {
children[i] = substituteSimplerExpression(child, nodeCache);
nodeCache[child.tag] = children[i];
}
else
children[i] = cached->second;
}
// Collect some info on constant expressions in children
bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant?
bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant?
double first, second; // if yes, value of first and second child
if (first_const)
first = getConstantValue(children[0]);
if (second_const)
second = getConstantValue(children[1]);
switch (node.getOperation().getId()) {
case Operation::ADD:
{
double first = getConstantValue(children[0]);
double second = getConstantValue(children[1]);
if (first == 0.0) // Add 0
return children[1];
if (second == 0.0) // Add 0
return children[0];
if (first == first) // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
if (second == second) // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
if (first_const) {
if (first == 0.0) { // Add 0
return children[1];
} else { // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
}
}
if (second_const) {
if (second == 0.0) { // Add 0
return children[0];
} else { // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b
return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a
@ -144,34 +189,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0
double first = getConstantValue(children[0]);
if (first == 0.0) // Subtract from 0
return ExpressionTreeNode(new Operation::Negate(), children[1]);
double second = getConstantValue(children[1]);
if (second == 0.0) // Subtract 0
return children[0];
if (second == second) // Subtract a constant
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
if (first_const) {
if (first == 0.0) // Subtract from 0
return ExpressionTreeNode(new Operation::Negate(), children[1]);
}
if (second_const) {
if (second == 0.0) { // Subtract 0
return children[0];
} else { // Subtract a constant
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b
return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]);
break;
}
case Operation::MULTIPLY:
{
double first = getConstantValue(children[0]);
double second = getConstantValue(children[1]);
if (first == 0.0 || second == 0.0) // Multiply by 0
{
if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0
return ExpressionTreeNode(new Operation::Constant(0.0));
if (first == 1.0) // Multiply by 1
if (first_const && first == 1.0) // Multiply by 1
return children[1];
if (second == 1.0) // Multiply by 1
if (second_const && second == 1.0) // Multiply by 1
return children[0];
if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant
if (first_const) { // Multiply by a constant
if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast<const Operation::MultiplyConstant*>(&children[1].getOperation())->getValue()), children[1].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]);
}
if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant
if (second_const) { // Multiply by a constant
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]);
@ -202,18 +248,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything by itself is 1
double numerator = getConstantValue(children[0]);
if (numerator == 0.0) // 0 divided by something
if (first_const && first == 0.0) // 0 divided by something
return ExpressionTreeNode(new Operation::Constant(0.0));
if (numerator == 1.0) // 1 divided by something
if (first_const && first == 1.0) // 1 divided by something
return ExpressionTreeNode(new Operation::Reciprocal(), children[1]);
double denominator = getConstantValue(children[1]);
if (denominator == 1.0) // Divide by 1
if (second_const && second == 1.0) // Divide by 1
return children[0];
if (children[1].getOperation().getId() == Operation::CONSTANT) {
if (second_const) {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/second), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/second), children[0]); // Replace a divide with a multiply
}
if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]);
@ -229,34 +273,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
}
case Operation::POWER:
{
double base = getConstantValue(children[0]);
if (base == 0.0) // 0 to any power is 0
if (first_const && first == 0.0) // 0 to any power is 0
return ExpressionTreeNode(new Operation::Constant(0.0));
if (base == 1.0) // 1 to any power is 1
if (first_const && first == 1.0) // 1 to any power is 1
return ExpressionTreeNode(new Operation::Constant(1.0));
double exponent = getConstantValue(children[1]);
if (exponent == 0.0) // x^0 = 1
return ExpressionTreeNode(new Operation::Constant(1.0));
if (exponent == 1.0) // x^1 = x
return children[0];
if (exponent == -1.0) // x^-1 = recip(x)
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
if (exponent == 2.0) // x^2 = square(x)
return ExpressionTreeNode(new Operation::Square(), children[0]);
if (exponent == 3.0) // x^3 = cube(x)
return ExpressionTreeNode(new Operation::Cube(), children[0]);
if (exponent == 0.5) // x^0.5 = sqrt(x)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
if (exponent == exponent) // Constant power
return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]);
if (second_const) { // Constant exponent
if (second == 0.0) // x^0 = 1
return ExpressionTreeNode(new Operation::Constant(1.0));
if (second == 1.0) // x^1 = x
return children[0];
if (second == -1.0) // x^-1 = recip(x)
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
if (second == 2.0) // x^2 = square(x)
return ExpressionTreeNode(new Operation::Square(), children[0]);
if (second == 3.0) // x^3 = cube(x)
return ExpressionTreeNode(new Operation::Cube(), children[0]);
if (second == 0.5) // x^0.5 = sqrt(x)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
// Constant power
return ExpressionTreeNode(new Operation::PowerConstant(second), children[0]);
}
break;
}
case Operation::NEGATE:
{
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Negate a constant
return ExpressionTreeNode(new Operation::Constant(-getConstantValue(children[0])));
if (first_const) // Negate a constant
return ExpressionTreeNode(new Operation::Constant(-first));
if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return children[0].getChildren()[0];
break;
@ -265,7 +309,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Multiply two constants
if (first_const) // Multiply two constants
return ExpressionTreeNode(new Operation::Constant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*getConstantValue(children[0])));
if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()), children[0].getChildren()[0]);
@ -293,20 +337,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
}
// Differentiate the whole expression with respect to `variable`, starting at the root node.
// NOTE(review): the first statement returns unconditionally through the two-argument
// overload, so the four statements after it cannot execute as written. The hunk headers
// around this span suggest a diff listing with its +/- markers dropped; presumably the first
// return is the removed line and the cache-based statements are its replacement -- confirm
// against upstream.
ParsedExpression ParsedExpression::differentiate(const string& variable) const {
return differentiate(getRootNode(), variable);
// assignTags() is defined elsewhere; presumably it fills in the node.tag values that the
// three-argument overload uses as nodeCache keys -- TODO confirm.
vector<const ExpressionTreeNode*> examples;
getRootNode().assignTags(examples);
// Fresh, empty cache handed to the recursive overload.
map<int, ExpressionTreeNode> nodeCache;
return differentiate(getRootNode(), variable, nodeCache);
}
// Recursively differentiate the subtree rooted at `node` with respect to `variable`.
// NOTE(review): this span shows two signatures (with and without `nodeCache`) and paired
// uncached / cached statements. That is consistent with a diff listing whose +/- markers
// were dropped, and it is not compilable as written. Presumably the `nodeCache` variants
// are the newer ones -- confirm against upstream.
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) {
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable, map<int, ExpressionTreeNode>& nodeCache) {
// Cache lookup keyed by node.tag: on a hit, return the stored derivative.
auto cached = nodeCache.find(node.tag);
if (cached != nodeCache.end())
return cached->second;
// Differentiate each child first; the operation combines them afterwards.
vector<ExpressionTreeNode> childDerivs(node.getChildren().size());
for (int i = 0; i < (int) childDerivs.size(); i++)
childDerivs[i] = differentiate(node.getChildren()[i], variable);
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable);
childDerivs[i] = differentiate(node.getChildren()[i], variable, nodeCache);
// Delegate to the node's operation, then record the result under node.tag before returning.
ExpressionTreeNode result = node.getOperation().differentiate(node.getChildren(), childDerivs, variable);
nodeCache[node.tag] = result;
return result;
}
bool ParsedExpression::isConstant(const ExpressionTreeNode& node) {
return (node.getOperation().getId() == Operation::CONSTANT);
}
// Return the numeric value held by a CONSTANT node.
// NOTE(review): two different treatments of a non-constant node appear back to back. As
// written, the first three statements always return (the stored value, or a quiet NaN), so
// the throwing check and the final return after them cannot execute. The hunk headers around
// this span suggest a diff listing with its +/- markers dropped; presumably the NaN form is
// the removed version and the throwing form its replacement -- confirm against upstream.
// Other code visible in this listing includes both callers that test the result for NaN
// (e.g. `first == first`) and callers that check isConstant() before calling.
double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) {
// NaN-sentinel form: value for CONSTANT nodes, quiet NaN for everything else.
if (node.getOperation().getId() == Operation::CONSTANT)
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue();
return numeric_limits<double>::quiet_NaN();
// Throwing form: reject non-constant nodes with an Exception, then return the stored value.
if (node.getOperation().getId() != Operation::CONSTANT) {
throw Exception("getConstantValue called on a non-constant ExpressionNode");
}
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue();
}
ExpressionProgram ParsedExpression::createProgram() const {
@ -317,6 +374,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const {
return CompiledExpression(*this);
}
// Build a CompiledVectorExpression from this parsed expression.
//   width  forwarded unchanged as the second constructor argument.
// Returns the freshly constructed CompiledVectorExpression by value.
CompiledVectorExpression ParsedExpression::createCompiledVectorExpression(int width) const {
    const ParsedExpression& self = *this;
    return CompiledVectorExpression(self, width);
}
// Produce a new ParsedExpression from the tree that renameNodeVariables() returns for this
// expression's root node.
//   replacements  string-to-string map passed through to renameNodeVariables().
// This object is not modified (the method is const).
ParsedExpression ParsedExpression::renameVariables(const map<string, string>& replacements) const {
    ExpressionTreeNode renamedRoot = renameNodeVariables(getRootNode(), replacements);
    return ParsedExpression(renamedRoot);
}

View File

@ -66,7 +66,7 @@ private:
string Parser::trim(const string& expression) {
// Remove leading and trailing spaces.
int start, end;
for (start = 0; start < (int) expression.size() && isspace(expression[start]); start++)
;