update Lepton to current master branch
This commit is contained in:
@ -9,7 +9,7 @@
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2013-2019 Stanford University and the Authors. *
|
||||
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
@ -40,7 +40,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#ifdef LEPTON_USE_JIT
|
||||
#include "asmjit.h"
|
||||
#if defined(__ARM__) || defined(__ARM64__)
|
||||
#include "asmjit/a64.h"
|
||||
#else
|
||||
#include "asmjit/x86.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace LMP_Lepton {
|
||||
@ -101,9 +105,15 @@ private:
|
||||
std::map<std::string, double> dummyVariables;
|
||||
double (*jitCode)();
|
||||
#ifdef LEPTON_USE_JIT
|
||||
void findPowerGroups(std::vector<std::vector<int> >& groups, std::vector<std::vector<int> >& groupPowers, std::vector<int>& stepGroup);
|
||||
void generateJitCode();
|
||||
void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg, double (*function)(double));
|
||||
void generateTwoArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg1, asmjit::X86Xmm& arg2, double (*function)(double, double));
|
||||
#if defined(__ARM__) || defined(__ARM64__)
|
||||
void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double));
|
||||
void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, double (*function)(double, double));
|
||||
#else
|
||||
void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg, double (*function)(double));
|
||||
void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg1, asmjit::x86::Xmm& arg2, double (*function)(double, double));
|
||||
#endif
|
||||
std::vector<double> constants;
|
||||
asmjit::JitRuntime runtime;
|
||||
#endif
|
||||
|
||||
145
lib/lepton/include/lepton/CompiledVectorExpression.h
Normal file
145
lib/lepton/include/lepton/CompiledVectorExpression.h
Normal file
@ -0,0 +1,145 @@
|
||||
#ifndef LEPTON_VECTOR_EXPRESSION_H_
|
||||
#define LEPTON_VECTOR_EXPRESSION_H_
|
||||
|
||||
/* -------------------------------------------------------------------------- *
|
||||
* Lepton *
|
||||
* -------------------------------------------------------------------------- *
|
||||
* This is part of the Lepton expression parser originating from *
|
||||
* Simbios, the NIH National Center for Physics-Based Simulation of *
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a *
|
||||
* copy of this software and associated documentation files (the "Software"), *
|
||||
* to deal in the Software without restriction, including without limitation *
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
|
||||
* and/or sell copies of the Software, and to permit persons to whom the *
|
||||
* Software is furnished to do so, subject to the following conditions: *
|
||||
* *
|
||||
* The above copyright notice and this permission notice shall be included in *
|
||||
* all copies or substantial portions of the Software. *
|
||||
* *
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
|
||||
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
|
||||
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
|
||||
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
|
||||
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
#include "ExpressionTreeNode.h"
|
||||
#include "windowsIncludes.h"
|
||||
#include <array>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#ifdef LEPTON_USE_JIT
|
||||
#if defined(__ARM__) || defined(__ARM64__)
|
||||
#include "asmjit/a64.h"
|
||||
#else
|
||||
#include "asmjit/x86.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace LMP_Lepton {
|
||||
|
||||
class Operation;
|
||||
class ParsedExpression;
|
||||
|
||||
/**
|
||||
* A CompiledVectorExpression is a highly optimized representation of an expression for cases when you want to evaluate
|
||||
* it many times as quickly as possible. It is similar to CompiledExpression, with the extra feature that it uses the CPU's
|
||||
* vector unit (AVX on x86, NEON on ARM) to evaluate the expression for multiple sets of arguments at once. It also differs
|
||||
* from CompiledExpression and ParsedExpression in using single precision rather than double precision to evaluate the expression.
|
||||
* You should treat it as an opaque object; none of the internal representation is visible.
|
||||
*
|
||||
* A CompiledVectorExpression is created by calling createCompiledVectorExpression() on a ParsedExpression. When you create
|
||||
* it, you must specify the width of the vectors on which to compute the expression. The allowed widths depend on the type of
|
||||
* CPU it is running on. 4 is always allowed, and 8 is allowed on x86 processors with AVX. Call getAllowedWidths() to query
|
||||
* the allowed values.
|
||||
*
|
||||
* WARNING: CompiledVectorExpression is NOT thread safe. You should never access a CompiledVectorExpression from two threads at
|
||||
* the same time.
|
||||
*/
|
||||
|
||||
class LEPTON_EXPORT CompiledVectorExpression {
|
||||
public:
|
||||
CompiledVectorExpression();
|
||||
CompiledVectorExpression(const CompiledVectorExpression& expression);
|
||||
~CompiledVectorExpression();
|
||||
CompiledVectorExpression& operator=(const CompiledVectorExpression& expression);
|
||||
/**
|
||||
* Get the width of the vectors on which the expression is computed.
|
||||
*/
|
||||
int getWidth() const;
|
||||
/**
|
||||
* Get the names of all variables used by this expression.
|
||||
*/
|
||||
const std::set<std::string>& getVariables() const;
|
||||
/**
|
||||
* Get a pointer to the memory location where the value of a particular variable is stored. This can be used
|
||||
* to set the value of the variable before calling evaluate().
|
||||
*
|
||||
* @param name the name of the variable to query
|
||||
* @return a pointer to N floating point values, where N is the vector width
|
||||
*/
|
||||
float* getVariablePointer(const std::string& name);
|
||||
/**
|
||||
* You can optionally specify the memory locations from which the values of variables should be read.
|
||||
* This is useful, for example, when several expressions all use the same variable. You can then set
|
||||
* the value of that variable in one place, and it will be seen by all of them. The location should
|
||||
* be a pointer to N floating point values, where N is the vector width.
|
||||
*/
|
||||
void setVariableLocations(std::map<std::string, float*>& variableLocations);
|
||||
/**
|
||||
* Evaluate the expression. The values of all variables should have been set before calling this.
|
||||
*
|
||||
* @return a pointer to N floating point values, where N is the vector width
|
||||
*/
|
||||
const float* evaluate() const;
|
||||
/**
|
||||
* Get the list of vector widths that are supported on the current processor.
|
||||
*/
|
||||
static const std::vector<int>& getAllowedWidths();
|
||||
private:
|
||||
friend class ParsedExpression;
|
||||
CompiledVectorExpression(const ParsedExpression& expression, int width);
|
||||
void compileExpression(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps, int& workspaceSize);
|
||||
int findTempIndex(const ExpressionTreeNode& node, std::vector<std::pair<ExpressionTreeNode, int> >& temps);
|
||||
int width;
|
||||
std::map<std::string, float*> variablePointers;
|
||||
std::vector<std::pair<float*, float*> > variablesToCopy;
|
||||
std::vector<std::vector<int> > arguments;
|
||||
std::vector<int> target;
|
||||
std::vector<Operation*> operation;
|
||||
std::map<std::string, int> variableIndices;
|
||||
std::set<std::string> variableNames;
|
||||
mutable std::vector<float> workspace;
|
||||
mutable std::vector<double> argValues;
|
||||
std::map<std::string, double> dummyVariables;
|
||||
void (*jitCode)();
|
||||
#ifdef LEPTON_USE_JIT
|
||||
void findPowerGroups(std::vector<std::vector<int> >& groups, std::vector<std::vector<int> >& groupPowers, std::vector<int>& stepGroup);
|
||||
void generateJitCode();
|
||||
#if defined(__ARM__) || defined(__ARM64__)
|
||||
void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, float (*function)(float));
|
||||
void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, float (*function)(float, float));
|
||||
#else
|
||||
void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg, float (*function)(float));
|
||||
void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg1, asmjit::x86::Ymm& arg2, float (*function)(float, float));
|
||||
#endif
|
||||
std::vector<float> constants;
|
||||
asmjit::JitRuntime runtime;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace LMP_Lepton
|
||||
|
||||
#endif /*LEPTON_VECTOR_EXPRESSION_H_*/
|
||||
@ -9,7 +9,7 @@
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2009 Stanford University and the Authors. *
|
||||
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
@ -39,6 +39,7 @@
|
||||
namespace LMP_Lepton {
|
||||
|
||||
class Operation;
|
||||
class ParsedExpression;
|
||||
|
||||
/**
|
||||
* This class represents a node in the abstract syntax tree representation of an expression.
|
||||
@ -82,11 +83,13 @@ public:
|
||||
*/
|
||||
ExpressionTreeNode(Operation* operation);
|
||||
ExpressionTreeNode(const ExpressionTreeNode& node);
|
||||
ExpressionTreeNode(ExpressionTreeNode&& node);
|
||||
ExpressionTreeNode();
|
||||
~ExpressionTreeNode();
|
||||
bool operator==(const ExpressionTreeNode& node) const;
|
||||
bool operator!=(const ExpressionTreeNode& node) const;
|
||||
ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
|
||||
ExpressionTreeNode& operator=(ExpressionTreeNode&& node);
|
||||
/**
|
||||
* Get the Operation performed by this node.
|
||||
*/
|
||||
@ -96,8 +99,11 @@ public:
|
||||
*/
|
||||
const std::vector<ExpressionTreeNode>& getChildren() const;
|
||||
private:
|
||||
friend class ParsedExpression;
|
||||
void assignTags(std::vector<const ExpressionTreeNode*>& examples) const;
|
||||
Operation* operation;
|
||||
std::vector<ExpressionTreeNode> children;
|
||||
mutable int tag;
|
||||
};
|
||||
|
||||
} // namespace LMP_Lepton
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2009=2013 Stanford University and the Authors. *
|
||||
* Portions copyright (c) 2009-2022 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
@ -41,6 +41,7 @@ namespace LMP_Lepton {
|
||||
|
||||
class CompiledExpression;
|
||||
class ExpressionProgram;
|
||||
class CompiledVectorExpression;
|
||||
|
||||
/**
|
||||
* This class represents the result of parsing an expression. It provides methods for working with the
|
||||
@ -102,6 +103,16 @@ public:
|
||||
* Create a CompiledExpression that represents the same calculation as this expression.
|
||||
*/
|
||||
CompiledExpression createCompiledExpression() const;
|
||||
/**
|
||||
* Create a CompiledVectorExpression that allows the expression to be evaluated efficiently
|
||||
* using the CPU's vector unit.
|
||||
*
|
||||
* @param width the width of the vectors to evaluate it on. The allowed values
|
||||
* depend on the CPU. 4 is always allowed, and 8 is allowed on
|
||||
* x86 processors with AVX. Call CompiledVectorExpression::getAllowedWidths()
|
||||
* to query the allowed widths on the current processor.
|
||||
*/
|
||||
CompiledVectorExpression createCompiledVectorExpression(int width) const;
|
||||
/**
|
||||
* Create a new ParsedExpression which is identical to this one, except that the names of some
|
||||
* variables have been changed.
|
||||
@ -113,9 +124,9 @@ public:
|
||||
private:
|
||||
static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
|
||||
static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
|
||||
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
|
||||
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
|
||||
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
|
||||
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node, std::map<int, ExpressionTreeNode>& nodeCache);
|
||||
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node, std::map<int, ExpressionTreeNode>& nodeCache);
|
||||
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable, std::map<int, ExpressionTreeNode>& nodeCache);
|
||||
static bool isConstant(const ExpressionTreeNode& node);
|
||||
static double getConstantValue(const ExpressionTreeNode& node);
|
||||
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2013-2019 Stanford University and the Authors. *
|
||||
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
@ -151,7 +151,7 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca
|
||||
|
||||
if (workspace.size() > 0)
|
||||
generateJitCode();
|
||||
#else
|
||||
#endif
|
||||
// Make a list of all variables we will need to copy before evaluating the expression.
|
||||
|
||||
variablesToCopy.clear();
|
||||
@ -160,13 +160,11 @@ void CompiledExpression::setVariableLocations(map<string, double*>& variableLoca
|
||||
if (pointer != variablePointers.end())
|
||||
variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
double CompiledExpression::evaluate() const {
|
||||
#ifdef LEPTON_USE_JIT
|
||||
return jitCode();
|
||||
#else
|
||||
if (jitCode)
|
||||
return jitCode();
|
||||
for (int i = 0; i < (int)variablesToCopy.size(); i++)
|
||||
*variablesToCopy[i].first = *variablesToCopy[i].second;
|
||||
|
||||
@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const {
|
||||
}
|
||||
}
|
||||
return workspace[workspace.size()-1];
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef LEPTON_USE_JIT
|
||||
@ -192,24 +189,70 @@ static double evaluateOperation(Operation* op, double* args) {
|
||||
return op->evaluate(args, dummyVariables);
|
||||
}
|
||||
|
||||
void CompiledExpression::findPowerGroups(vector<vector<int> >& groups, vector<vector<int> >& groupPowers, vector<int>& stepGroup) {
|
||||
// Identify every step that raises an argument to an integer power.
|
||||
|
||||
vector<int> stepPower(operation.size(), 0);
|
||||
vector<int> stepArg(operation.size(), -1);
|
||||
for (int step = 0; step < (int)operation.size(); step++) {
|
||||
Operation& op = *operation[step];
|
||||
int power = 0;
|
||||
if (op.getId() == Operation::SQUARE)
|
||||
power = 2;
|
||||
else if (op.getId() == Operation::CUBE)
|
||||
power = 3;
|
||||
else if (op.getId() == Operation::POWER_CONSTANT) {
|
||||
double realPower = dynamic_cast<const Operation::PowerConstant*>(&op)->getValue();
|
||||
if (realPower == (int) realPower)
|
||||
power = (int) realPower;
|
||||
}
|
||||
if (power != 0) {
|
||||
stepPower[step] = power;
|
||||
stepArg[step] = arguments[step][0];
|
||||
}
|
||||
}
|
||||
|
||||
// Find groups that operate on the same argument and whose powers have the same sign.
|
||||
|
||||
stepGroup.resize(operation.size(), -1);
|
||||
for (int i = 0; i < (int)operation.size(); i++) {
|
||||
if (stepGroup[i] != -1)
|
||||
continue;
|
||||
vector<int> group, power;
|
||||
for (int j = i; j < (int)operation.size(); j++) {
|
||||
if (stepArg[i] == stepArg[j] && stepPower[i]*stepPower[j] > 0) {
|
||||
stepGroup[j] = groups.size();
|
||||
group.push_back(j);
|
||||
power.push_back(stepPower[j]);
|
||||
}
|
||||
}
|
||||
groups.push_back(group);
|
||||
groupPowers.push_back(power);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__ARM__) || defined(__ARM64__)
|
||||
void CompiledExpression::generateJitCode() {
|
||||
CodeHolder code;
|
||||
code.init(runtime.getCodeInfo());
|
||||
X86Compiler c(&code);
|
||||
c.addFunc(FuncSignature0<double>());
|
||||
vector<X86Xmm> workspaceVar(workspace.size());
|
||||
code.init(runtime.environment());
|
||||
a64::Compiler c(&code);
|
||||
c.addFunc(FuncSignatureT<double>());
|
||||
vector<arm::Vec> workspaceVar(workspace.size());
|
||||
for (int i = 0; i < (int) workspaceVar.size(); i++)
|
||||
workspaceVar[i] = c.newXmmSd();
|
||||
X86Gp argsPointer = c.newIntPtr();
|
||||
c.mov(argsPointer, imm_ptr(&argValues[0]));
|
||||
workspaceVar[i] = c.newVecD();
|
||||
arm::Gp argsPointer = c.newIntPtr();
|
||||
c.mov(argsPointer, imm(&argValues[0]));
|
||||
vector<vector<int> > groups, groupPowers;
|
||||
vector<int> stepGroup;
|
||||
findPowerGroups(groups, groupPowers, stepGroup);
|
||||
|
||||
// Load the arguments into variables.
|
||||
|
||||
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
|
||||
map<string, int>::iterator index = variableIndices.find(*iter);
|
||||
X86Gp variablePointer = c.newIntPtr();
|
||||
c.mov(variablePointer, imm_ptr(&getVariableReference(index->first)));
|
||||
c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
|
||||
arm::Gp variablePointer = c.newIntPtr();
|
||||
c.mov(variablePointer, imm(&getVariableReference(index->first)));
|
||||
c.ldr(workspaceVar[index->second], arm::ptr(variablePointer, 0));
|
||||
}
|
||||
|
||||
// Make a list of all constants that will be needed for evaluation.
|
||||
@ -232,6 +275,12 @@ void CompiledExpression::generateJitCode() {
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::DELTA)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::POWER_CONSTANT) {
|
||||
if (stepGroup[step] == -1)
|
||||
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
|
||||
else
|
||||
value = 1.0;
|
||||
}
|
||||
else
|
||||
continue;
|
||||
|
||||
@ -250,19 +299,63 @@ void CompiledExpression::generateJitCode() {
|
||||
|
||||
// Load constants into variables.
|
||||
|
||||
vector<X86Xmm> constantVar(constants.size());
|
||||
vector<arm::Vec> constantVar(constants.size());
|
||||
if (constants.size() > 0) {
|
||||
X86Gp constantsPointer = c.newIntPtr();
|
||||
c.mov(constantsPointer, imm_ptr(&constants[0]));
|
||||
arm::Gp constantsPointer = c.newIntPtr();
|
||||
c.mov(constantsPointer, imm(&constants[0]));
|
||||
for (int i = 0; i < (int) constants.size(); i++) {
|
||||
constantVar[i] = c.newXmmSd();
|
||||
c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
|
||||
constantVar[i] = c.newVecD();
|
||||
c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i));
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate the operations.
|
||||
|
||||
vector<bool> hasComputedPower(operation.size(), false);
|
||||
for (int step = 0; step < (int) operation.size(); step++) {
|
||||
if (hasComputedPower[step])
|
||||
continue;
|
||||
|
||||
// When one or more steps involve raising the same argument to multiple integer
|
||||
// powers, we can compute them all together for efficiency.
|
||||
|
||||
if (stepGroup[step] != -1) {
|
||||
vector<int>& group = groups[stepGroup[step]];
|
||||
vector<int>& powers = groupPowers[stepGroup[step]];
|
||||
arm::Vec multiplier = c.newVecD();
|
||||
if (powers[0] > 0)
|
||||
c.fmov(multiplier, workspaceVar[arguments[step][0]]);
|
||||
else {
|
||||
c.fdiv(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
|
||||
for (int i = 0; i < powers.size(); i++)
|
||||
powers[i] = -powers[i];
|
||||
}
|
||||
vector<bool> hasAssigned(group.size(), false);
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
done = true;
|
||||
for (int i = 0; i < group.size(); i++) {
|
||||
if (powers[i]%2 == 1) {
|
||||
if (!hasAssigned[i])
|
||||
c.fmov(workspaceVar[target[group[i]]], multiplier);
|
||||
else
|
||||
c.fmul(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
|
||||
hasAssigned[i] = true;
|
||||
}
|
||||
powers[i] >>= 1;
|
||||
if (powers[i] != 0)
|
||||
done = false;
|
||||
}
|
||||
if (!done)
|
||||
c.fmul(multiplier, multiplier, multiplier);
|
||||
}
|
||||
for (int step : group)
|
||||
hasComputedPower[step] = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Evaluate the step.
|
||||
|
||||
Operation& op = *operation[step];
|
||||
vector<int> args = arguments[step];
|
||||
if (args.size() == 1) {
|
||||
@ -276,33 +369,28 @@ void CompiledExpression::generateJitCode() {
|
||||
|
||||
switch (op.getId()) {
|
||||
case Operation::CONSTANT:
|
||||
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
c.fmov(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::ADD:
|
||||
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]);
|
||||
c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::SUBTRACT:
|
||||
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]);
|
||||
c.fsub(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::MULTIPLY:
|
||||
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]);
|
||||
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::DIVIDE:
|
||||
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]);
|
||||
c.fdiv(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::POWER:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow);
|
||||
break;
|
||||
case Operation::NEGATE:
|
||||
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.fneg(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::SQRT:
|
||||
c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.fsqrt(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::EXP:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
|
||||
@ -341,56 +429,63 @@ void CompiledExpression::generateJitCode() {
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
|
||||
break;
|
||||
case Operation::STEP:
|
||||
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
|
||||
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
c.cmge(workspaceVar[target[step]], workspaceVar[args[0]], imm(0));
|
||||
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::DELTA:
|
||||
c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
|
||||
c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
c.cmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0));
|
||||
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::SQUARE:
|
||||
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::CUBE:
|
||||
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
|
||||
c.fmul(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::RECIPROCAL:
|
||||
c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.fdiv(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::ADD_CONSTANT:
|
||||
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::MULTIPLY_CONSTANT:
|
||||
c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::POWER_CONSTANT:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
|
||||
break;
|
||||
case Operation::MIN:
|
||||
c.fmin(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::MAX:
|
||||
c.fmax(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::ABS:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
|
||||
c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::FLOOR:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor);
|
||||
c.frintm(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::CEIL:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil);
|
||||
c.frintp(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::SELECT:
|
||||
c.fcmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0));
|
||||
c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]);
|
||||
break;
|
||||
default:
|
||||
// Just invoke evaluateOperation().
|
||||
|
||||
for (int i = 0; i < (int) args.size(); i++)
|
||||
c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
|
||||
X86Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm_ptr((void*) evaluateOperation));
|
||||
CCFuncCall* call = c.call(fn, FuncSignature2<double, Operation*, double*>());
|
||||
call->setArg(0, imm_ptr(&op));
|
||||
call->setArg(1, imm_ptr(&argValues[0]));
|
||||
call->setRet(0, workspaceVar[target[step]]);
|
||||
c.str(workspaceVar[args[i]], arm::ptr(argsPointer, 8*i));
|
||||
arm::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) evaluateOperation));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
|
||||
invoke->setArg(0, imm(&op));
|
||||
invoke->setArg(1, imm(&argValues[0]));
|
||||
invoke->setRet(0, workspaceVar[target[step]]);
|
||||
}
|
||||
}
|
||||
c.ret(workspaceVar[workspace.size()-1]);
|
||||
@ -399,20 +494,319 @@ void CompiledExpression::generateJitCode() {
|
||||
runtime.add(&jitCode, &code);
|
||||
}
|
||||
|
||||
void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) {
|
||||
X86Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm_ptr((void*) function));
|
||||
CCFuncCall* call = c.call(fn, FuncSignature1<double, double>());
|
||||
call->setArg(0, arg);
|
||||
call->setRet(0, dest);
|
||||
void CompiledExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, double (*function)(double)) {
|
||||
arm::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) function));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<double, double>());
|
||||
invoke->setArg(0, arg);
|
||||
invoke->setRet(0, dest);
|
||||
}
|
||||
|
||||
void CompiledExpression::generateTwoArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg1, X86Xmm& arg2, double (*function)(double, double)) {
|
||||
X86Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm_ptr((void*) function));
|
||||
CCFuncCall* call = c.call(fn, FuncSignature2<double, double, double>());
|
||||
call->setArg(0, arg1);
|
||||
call->setArg(1, arg2);
|
||||
call->setRet(0, dest);
|
||||
void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, double (*function)(double, double)) {
|
||||
arm::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) function));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<double, double, double>());
|
||||
invoke->setArg(0, arg1);
|
||||
invoke->setArg(1, arg2);
|
||||
invoke->setRet(0, dest);
|
||||
}
|
||||
#else
|
||||
void CompiledExpression::generateJitCode() {
|
||||
const CpuInfo& cpu = CpuInfo::host();
|
||||
if (!cpu.hasFeature(CpuFeatures::X86::kAVX))
|
||||
return;
|
||||
CodeHolder code;
|
||||
code.init(runtime.environment());
|
||||
x86::Compiler c(&code);
|
||||
FuncNode* funcNode = c.addFunc(FuncSignatureT<double>());
|
||||
funcNode->frame().setAvxEnabled();
|
||||
vector<x86::Xmm> workspaceVar(workspace.size());
|
||||
for (int i = 0; i < (int) workspaceVar.size(); i++)
|
||||
workspaceVar[i] = c.newXmmSd();
|
||||
x86::Gp argsPointer = c.newIntPtr();
|
||||
c.mov(argsPointer, imm(&argValues[0]));
|
||||
vector<vector<int> > groups, groupPowers;
|
||||
vector<int> stepGroup;
|
||||
findPowerGroups(groups, groupPowers, stepGroup);
|
||||
|
||||
// Load the arguments into variables.
|
||||
|
||||
x86::Gp variablePointer = c.newIntPtr();
|
||||
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
|
||||
map<string, int>::iterator index = variableIndices.find(*iter);
|
||||
c.mov(variablePointer, imm(&getVariableReference(index->first)));
|
||||
c.vmovsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
|
||||
}
|
||||
|
||||
// Make a list of all constants that will be needed for evaluation.
|
||||
|
||||
vector<int> operationConstantIndex(operation.size(), -1);
|
||||
for (int step = 0; step < (int) operation.size(); step++) {
|
||||
// Find the constant value (if any) used by this operation.
|
||||
|
||||
Operation& op = *operation[step];
|
||||
double value;
|
||||
if (op.getId() == Operation::CONSTANT)
|
||||
value = dynamic_cast<Operation::Constant&>(op).getValue();
|
||||
else if (op.getId() == Operation::ADD_CONSTANT)
|
||||
value = dynamic_cast<Operation::AddConstant&>(op).getValue();
|
||||
else if (op.getId() == Operation::MULTIPLY_CONSTANT)
|
||||
value = dynamic_cast<Operation::MultiplyConstant&>(op).getValue();
|
||||
else if (op.getId() == Operation::RECIPROCAL)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::STEP)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::DELTA)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::ABS) {
|
||||
long long mask = 0x7FFFFFFFFFFFFFFF;
|
||||
value = *reinterpret_cast<double*>(&mask);
|
||||
}
|
||||
else if (op.getId() == Operation::POWER_CONSTANT) {
|
||||
if (stepGroup[step] == -1)
|
||||
value = dynamic_cast<Operation::PowerConstant&>(op).getValue();
|
||||
else
|
||||
value = 1.0;
|
||||
}
|
||||
else
|
||||
continue;
|
||||
|
||||
// See if we already have a variable for this constant.
|
||||
|
||||
for (int i = 0; i < (int) constants.size(); i++)
|
||||
if (value == constants[i]) {
|
||||
operationConstantIndex[step] = i;
|
||||
break;
|
||||
}
|
||||
if (operationConstantIndex[step] == -1) {
|
||||
operationConstantIndex[step] = constants.size();
|
||||
constants.push_back(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Load constants into variables.
|
||||
|
||||
vector<x86::Xmm> constantVar(constants.size());
|
||||
if (constants.size() > 0) {
|
||||
x86::Gp constantsPointer = c.newIntPtr();
|
||||
c.mov(constantsPointer, imm(&constants[0]));
|
||||
for (int i = 0; i < (int) constants.size(); i++) {
|
||||
constantVar[i] = c.newXmmSd();
|
||||
c.vmovsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate the operations.
|
||||
|
||||
vector<bool> hasComputedPower(operation.size(), false);
|
||||
for (int step = 0; step < (int) operation.size(); step++) {
|
||||
if (hasComputedPower[step])
|
||||
continue;
|
||||
|
||||
// When one or more steps involve raising the same argument to multiple integer
|
||||
// powers, we can compute them all together for efficiency.
|
||||
|
||||
if (stepGroup[step] != -1) {
|
||||
vector<int>& group = groups[stepGroup[step]];
|
||||
vector<int>& powers = groupPowers[stepGroup[step]];
|
||||
x86::Xmm multiplier = c.newXmmSd();
|
||||
if (powers[0] > 0)
|
||||
c.vmovsd(multiplier, workspaceVar[arguments[step][0]], workspaceVar[arguments[step][0]]);
|
||||
else {
|
||||
c.vdivsd(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
|
||||
for (int i = 0; i < (int)powers.size(); i++)
|
||||
powers[i] = -powers[i];
|
||||
}
|
||||
vector<bool> hasAssigned(group.size(), false);
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
done = true;
|
||||
for (int i = 0; i < (int)group.size(); i++) {
|
||||
if (powers[i]%2 == 1) {
|
||||
if (!hasAssigned[i])
|
||||
c.vmovsd(workspaceVar[target[group[i]]], multiplier, multiplier);
|
||||
else
|
||||
c.vmulsd(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
|
||||
hasAssigned[i] = true;
|
||||
}
|
||||
powers[i] >>= 1;
|
||||
if (powers[i] != 0)
|
||||
done = false;
|
||||
}
|
||||
if (!done)
|
||||
c.vmulsd(multiplier, multiplier, multiplier);
|
||||
}
|
||||
for (int step : group)
|
||||
hasComputedPower[step] = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Evaluate the step.
|
||||
|
||||
Operation& op = *operation[step];
|
||||
vector<int> args = arguments[step];
|
||||
if (args.size() == 1) {
|
||||
// One or more sequential arguments. Fill out the list.
|
||||
|
||||
for (int i = 1; i < op.getNumArguments(); i++)
|
||||
args.push_back(args[0]+i);
|
||||
}
|
||||
|
||||
// Generate instructions to execute this operation.
|
||||
|
||||
switch (op.getId()) {
|
||||
case Operation::CONSTANT:
|
||||
c.vmovsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::ADD:
|
||||
c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::SUBTRACT:
|
||||
c.vsubsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::MULTIPLY:
|
||||
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::DIVIDE:
|
||||
c.vdivsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::POWER:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow);
|
||||
break;
|
||||
case Operation::NEGATE:
|
||||
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.vsubsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::SQRT:
|
||||
c.vsqrtsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::EXP:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
|
||||
break;
|
||||
case Operation::LOG:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log);
|
||||
break;
|
||||
case Operation::SIN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin);
|
||||
break;
|
||||
case Operation::COS:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos);
|
||||
break;
|
||||
case Operation::TAN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan);
|
||||
break;
|
||||
case Operation::ASIN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin);
|
||||
break;
|
||||
case Operation::ACOS:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos);
|
||||
break;
|
||||
case Operation::ATAN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan);
|
||||
break;
|
||||
case Operation::ATAN2:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2);
|
||||
break;
|
||||
case Operation::SINH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh);
|
||||
break;
|
||||
case Operation::COSH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh);
|
||||
break;
|
||||
case Operation::TANH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
|
||||
break;
|
||||
case Operation::STEP:
|
||||
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
|
||||
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::DELTA:
|
||||
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
|
||||
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::SQUARE:
|
||||
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::CUBE:
|
||||
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
|
||||
c.vmulsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::RECIPROCAL:
|
||||
c.vdivsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::ADD_CONSTANT:
|
||||
c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::MULTIPLY_CONSTANT:
|
||||
c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::POWER_CONSTANT:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow);
|
||||
break;
|
||||
case Operation::MIN:
|
||||
c.vminsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::MAX:
|
||||
c.vmaxsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::ABS:
|
||||
c.vandpd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::FLOOR:
|
||||
c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(1));
|
||||
break;
|
||||
case Operation::CEIL:
|
||||
c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(2));
|
||||
break;
|
||||
case Operation::SELECT:
|
||||
{
|
||||
x86::Xmm mask = c.newXmmSd();
|
||||
c.vxorps(mask, mask, mask);
|
||||
c.vcmpsd(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0
|
||||
c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
// Just invoke evaluateOperation().
|
||||
|
||||
for (int i = 0; i < (int) args.size(); i++)
|
||||
c.vmovsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
|
||||
x86::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) evaluateOperation));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
|
||||
invoke->setArg(0, imm(&op));
|
||||
invoke->setArg(1, imm(&argValues[0]));
|
||||
invoke->setRet(0, workspaceVar[target[step]]);
|
||||
}
|
||||
}
|
||||
c.ret(workspaceVar[workspace.size()-1]);
|
||||
c.endFunc();
|
||||
c.finalize();
|
||||
runtime.add(&jitCode, &code);
|
||||
}
|
||||
|
||||
void CompiledExpression::generateSingleArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg, double (*function)(double)) {
|
||||
x86::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) function));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<double, double>());
|
||||
invoke->setArg(0, arg);
|
||||
invoke->setRet(0, dest);
|
||||
}
|
||||
|
||||
void CompiledExpression::generateTwoArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg1, x86::Xmm& arg2, double (*function)(double, double)) {
|
||||
x86::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) function));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<double, double, double>());
|
||||
invoke->setArg(0, arg1);
|
||||
invoke->setArg(1, arg2);
|
||||
invoke->setRet(0, dest);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
933
lib/lepton/src/CompiledVectorExpression.cpp
Normal file
933
lib/lepton/src/CompiledVectorExpression.cpp
Normal file
@ -0,0 +1,933 @@
|
||||
/* -------------------------------------------------------------------------- *
|
||||
* Lepton *
|
||||
* -------------------------------------------------------------------------- *
|
||||
* This is part of the Lepton expression parser originating from *
|
||||
* Simbios, the NIH National Center for Physics-Based Simulation of *
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a *
|
||||
* copy of this software and associated documentation files (the "Software"), *
|
||||
* to deal in the Software without restriction, including without limitation *
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
|
||||
* and/or sell copies of the Software, and to permit persons to whom the *
|
||||
* Software is furnished to do so, subject to the following conditions: *
|
||||
* *
|
||||
* The above copyright notice and this permission notice shall be included in *
|
||||
* all copies or substantial portions of the Software. *
|
||||
* *
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
|
||||
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
|
||||
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
|
||||
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
|
||||
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
#include "lepton/CompiledVectorExpression.h"
|
||||
#include "lepton/Operation.h"
|
||||
#include "lepton/ParsedExpression.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
using namespace LMP_Lepton;
|
||||
using namespace std;
|
||||
#ifdef LEPTON_USE_JIT
|
||||
using namespace asmjit;
|
||||
#endif
|
||||
|
||||
CompiledVectorExpression::CompiledVectorExpression() : jitCode(NULL) {
|
||||
}
|
||||
|
||||
CompiledVectorExpression::CompiledVectorExpression(const ParsedExpression& expression, int width) : width(width), jitCode(NULL) {
|
||||
const vector<int> allowedWidths = getAllowedWidths();
|
||||
if (find(allowedWidths.begin(), allowedWidths.end(), width) == allowedWidths.end())
|
||||
throw Exception("Unsupported width for vector expression: "+to_string(width));
|
||||
ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized.
|
||||
vector<pair<ExpressionTreeNode, int> > temps;
|
||||
int workspaceSize = 0;
|
||||
compileExpression(expr.getRootNode(), temps, workspaceSize);
|
||||
workspace.resize(workspaceSize*width);
|
||||
int maxArguments = 1;
|
||||
for (int i = 0; i < (int) operation.size(); i++)
|
||||
if (operation[i]->getNumArguments() > maxArguments)
|
||||
maxArguments = operation[i]->getNumArguments();
|
||||
argValues.resize(maxArguments);
|
||||
#ifdef LEPTON_USE_JIT
|
||||
generateJitCode();
|
||||
#endif
|
||||
}
|
||||
|
||||
CompiledVectorExpression::~CompiledVectorExpression() {
|
||||
for (int i = 0; i < (int) operation.size(); i++)
|
||||
if (operation[i] != NULL)
|
||||
delete operation[i];
|
||||
}
|
||||
|
||||
CompiledVectorExpression::CompiledVectorExpression(const CompiledVectorExpression& expression) : jitCode(NULL) {
|
||||
*this = expression;
|
||||
}
|
||||
|
||||
CompiledVectorExpression& CompiledVectorExpression::operator=(const CompiledVectorExpression& expression) {
|
||||
arguments = expression.arguments;
|
||||
width = expression.width;
|
||||
target = expression.target;
|
||||
variableIndices = expression.variableIndices;
|
||||
variableNames = expression.variableNames;
|
||||
workspace.resize(expression.workspace.size());
|
||||
argValues.resize(expression.argValues.size());
|
||||
operation.resize(expression.operation.size());
|
||||
for (int i = 0; i < (int) operation.size(); i++)
|
||||
operation[i] = expression.operation[i]->clone();
|
||||
setVariableLocations(variablePointers);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const vector<int>& CompiledVectorExpression::getAllowedWidths() {
|
||||
static vector<int> widths;
|
||||
if (widths.size() == 0) {
|
||||
widths.push_back(4);
|
||||
#ifdef LEPTON_USE_JIT
|
||||
const CpuInfo& cpu = CpuInfo::host();
|
||||
if (cpu.hasFeature(CpuFeatures::X86::kAVX))
|
||||
widths.push_back(8);
|
||||
#endif
|
||||
}
|
||||
return widths;
|
||||
}
|
||||
|
||||
void CompiledVectorExpression::compileExpression(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps, int& workspaceSize) {
|
||||
if (findTempIndex(node, temps) != -1)
|
||||
return; // We have already processed a node identical to this one.
|
||||
|
||||
// Process the child nodes.
|
||||
|
||||
vector<int> args;
|
||||
for (int i = 0; i < (int)node.getChildren().size(); i++) {
|
||||
compileExpression(node.getChildren()[i], temps, workspaceSize);
|
||||
args.push_back(findTempIndex(node.getChildren()[i], temps));
|
||||
}
|
||||
|
||||
// Process this node.
|
||||
|
||||
if (node.getOperation().getId() == Operation::VARIABLE) {
|
||||
variableIndices[node.getOperation().getName()] = workspaceSize;
|
||||
variableNames.insert(node.getOperation().getName());
|
||||
}
|
||||
else {
|
||||
int stepIndex = (int) arguments.size();
|
||||
arguments.push_back(vector<int>());
|
||||
target.push_back(workspaceSize);
|
||||
operation.push_back(node.getOperation().clone());
|
||||
if (args.size() == 0)
|
||||
arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there.
|
||||
else {
|
||||
// If the arguments are sequential, we can just pass a pointer to the first one.
|
||||
|
||||
bool sequential = true;
|
||||
for (int i = 1; i < (int)args.size(); i++)
|
||||
if (args[i] != args[i - 1] + 1)
|
||||
sequential = false;
|
||||
if (sequential)
|
||||
arguments[stepIndex].push_back(args[0]);
|
||||
else
|
||||
arguments[stepIndex] = args;
|
||||
}
|
||||
}
|
||||
temps.push_back(make_pair(node, workspaceSize));
|
||||
workspaceSize++;
|
||||
}
|
||||
|
||||
int CompiledVectorExpression::findTempIndex(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps) {
|
||||
for (int i = 0; i < (int) temps.size(); i++)
|
||||
if (temps[i].first == node)
|
||||
return i;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int CompiledVectorExpression::getWidth() const {
|
||||
return width;
|
||||
}
|
||||
|
||||
const set<string>& CompiledVectorExpression::getVariables() const {
|
||||
return variableNames;
|
||||
}
|
||||
|
||||
float* CompiledVectorExpression::getVariablePointer(const string& name) {
|
||||
map<string, float*>::iterator pointer = variablePointers.find(name);
|
||||
if (pointer != variablePointers.end())
|
||||
return pointer->second;
|
||||
map<string, int>::iterator index = variableIndices.find(name);
|
||||
if (index == variableIndices.end())
|
||||
throw Exception("getVariableReference: Unknown variable '" + name + "'");
|
||||
return &workspace[index->second*width];
|
||||
}
|
||||
|
||||
void CompiledVectorExpression::setVariableLocations(map<string, float*>& variableLocations) {
|
||||
variablePointers = variableLocations;
|
||||
#ifdef LEPTON_USE_JIT
|
||||
// Rebuild the JIT code.
|
||||
|
||||
if (workspace.size() > 0)
|
||||
generateJitCode();
|
||||
#endif
|
||||
// Make a list of all variables we will need to copy before evaluating the expression.
|
||||
|
||||
variablesToCopy.clear();
|
||||
for (map<string, int>::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) {
|
||||
map<string, float*>::iterator pointer = variablePointers.find(iter->first);
|
||||
if (pointer != variablePointers.end())
|
||||
variablesToCopy.push_back(make_pair(&workspace[iter->second*width], pointer->second));
|
||||
}
|
||||
}
|
||||
|
||||
const float* CompiledVectorExpression::evaluate() const {
|
||||
if (jitCode) {
|
||||
jitCode();
|
||||
return &workspace[workspace.size()-width];
|
||||
}
|
||||
for (int i = 0; i < (int)variablesToCopy.size(); i++)
|
||||
for (int j = 0; j < width; j++)
|
||||
variablesToCopy[i].first[j] = variablesToCopy[i].second[j];
|
||||
|
||||
// Loop over the operations and evaluate each one.
|
||||
|
||||
for (int step = 0; step < (int)operation.size(); step++) {
|
||||
const vector<int>& args = arguments[step];
|
||||
if (args.size() == 1) {
|
||||
for (int j = 0; j < width; j++) {
|
||||
for (int i = 0; i < operation[step]->getNumArguments(); i++)
|
||||
argValues[i] = workspace[(args[0]+i)*width+j];
|
||||
workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < width; j++) {
|
||||
for (int i = 0; i < (int)args.size(); i++)
|
||||
argValues[i] = workspace[args[i]*width+j];
|
||||
workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables);
|
||||
}
|
||||
}
|
||||
}
|
||||
return &workspace[workspace.size()-width];
|
||||
}
|
||||
|
||||
#ifdef LEPTON_USE_JIT
|
||||
|
||||
static double evaluateOperation(Operation* op, double* args) {
|
||||
static map<string, double> dummyVariables;
|
||||
return op->evaluate(args, dummyVariables);
|
||||
}
|
||||
|
||||
void CompiledVectorExpression::findPowerGroups(vector<vector<int> >& groups, vector<vector<int> >& groupPowers, vector<int>& stepGroup) {
|
||||
// Identify every step that raises an argument to an integer power.
|
||||
|
||||
vector<int> stepPower(operation.size(), 0);
|
||||
vector<int> stepArg(operation.size(), -1);
|
||||
for (int step = 0; step < (int)operation.size(); step++) {
|
||||
Operation& op = *operation[step];
|
||||
int power = 0;
|
||||
if (op.getId() == Operation::SQUARE)
|
||||
power = 2;
|
||||
else if (op.getId() == Operation::CUBE)
|
||||
power = 3;
|
||||
else if (op.getId() == Operation::POWER_CONSTANT) {
|
||||
double realPower = dynamic_cast<const Operation::PowerConstant*> (&op)->getValue();
|
||||
if (realPower == (int) realPower)
|
||||
power = (int) realPower;
|
||||
}
|
||||
if (power != 0) {
|
||||
stepPower[step] = power;
|
||||
stepArg[step] = arguments[step][0];
|
||||
}
|
||||
}
|
||||
|
||||
// Find groups that operate on the same argument and whose powers have the same sign.
|
||||
|
||||
stepGroup.resize(operation.size(), -1);
|
||||
for (int i = 0; i < (int)operation.size(); i++) {
|
||||
if (stepGroup[i] != -1)
|
||||
continue;
|
||||
vector<int> group, power;
|
||||
for (int j = i; j < (int)operation.size(); j++) {
|
||||
if (stepArg[i] == stepArg[j] && stepPower[i] * stepPower[j] > 0) {
|
||||
stepGroup[j] = groups.size();
|
||||
group.push_back(j);
|
||||
power.push_back(stepPower[j]);
|
||||
}
|
||||
}
|
||||
groups.push_back(group);
|
||||
groupPowers.push_back(power);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__ARM__) || defined(__ARM64__)
|
||||
|
||||
void CompiledVectorExpression::generateJitCode() {
|
||||
CodeHolder code;
|
||||
code.init(runtime.environment());
|
||||
a64::Compiler c(&code);
|
||||
c.addFunc(FuncSignatureT<void>());
|
||||
vector<arm::Vec> workspaceVar(workspace.size()/width);
|
||||
for (int i = 0; i < (int) workspaceVar.size(); i++)
|
||||
workspaceVar[i] = c.newVecQ();
|
||||
arm::Gp argsPointer = c.newIntPtr();
|
||||
c.mov(argsPointer, imm(&argValues[0]));
|
||||
vector<vector<int> > groups, groupPowers;
|
||||
vector<int> stepGroup;
|
||||
findPowerGroups(groups, groupPowers, stepGroup);
|
||||
|
||||
// Load the arguments into variables.
|
||||
|
||||
arm::Gp variablePointer = c.newIntPtr();
|
||||
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
|
||||
map<string, int>::iterator index = variableIndices.find(*iter);
|
||||
c.mov(variablePointer, imm(getVariablePointer(index->first)));
|
||||
c.ldr(workspaceVar[index->second].s4(), arm::ptr(variablePointer, 0));
|
||||
}
|
||||
|
||||
// Make a list of all constants that will be needed for evaluation.
|
||||
|
||||
vector<int> operationConstantIndex(operation.size(), -1);
|
||||
for (int step = 0; step < (int) operation.size(); step++) {
|
||||
// Find the constant value (if any) used by this operation.
|
||||
|
||||
Operation& op = *operation[step];
|
||||
float value;
|
||||
if (op.getId() == Operation::CONSTANT)
|
||||
value = dynamic_cast<Operation::Constant&> (op).getValue();
|
||||
else if (op.getId() == Operation::ADD_CONSTANT)
|
||||
value = dynamic_cast<Operation::AddConstant&> (op).getValue();
|
||||
else if (op.getId() == Operation::MULTIPLY_CONSTANT)
|
||||
value = dynamic_cast<Operation::MultiplyConstant&> (op).getValue();
|
||||
else if (op.getId() == Operation::RECIPROCAL)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::STEP)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::DELTA)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::POWER_CONSTANT) {
|
||||
if (stepGroup[step] == -1)
|
||||
value = dynamic_cast<Operation::PowerConstant&> (op).getValue();
|
||||
else
|
||||
value = 1.0;
|
||||
} else
|
||||
continue;
|
||||
|
||||
// See if we already have a variable for this constant.
|
||||
|
||||
for (int i = 0; i < (int) constants.size(); i++)
|
||||
if (value == constants[i]) {
|
||||
operationConstantIndex[step] = i;
|
||||
break;
|
||||
}
|
||||
if (operationConstantIndex[step] == -1) {
|
||||
operationConstantIndex[step] = constants.size();
|
||||
constants.push_back(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Load constants into variables.
|
||||
|
||||
vector<arm::Vec> constantVar(constants.size());
|
||||
if (constants.size() > 0) {
|
||||
arm::Gp constantsPointer = c.newIntPtr();
|
||||
for (int i = 0; i < (int) constants.size(); i++) {
|
||||
c.mov(constantsPointer, imm(&constants[i]));
|
||||
constantVar[i] = c.newVecQ();
|
||||
c.ld1r(constantVar[i].s4(), arm::ptr(constantsPointer));
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate the operations.
|
||||
|
||||
vector<bool> hasComputedPower(operation.size(), false);
|
||||
arm::Vec argReg = c.newVecS();
|
||||
arm::Vec doubleArgReg = c.newVecD();
|
||||
arm::Vec doubleResultReg = c.newVecD();
|
||||
for (int step = 0; step < (int) operation.size(); step++) {
|
||||
if (hasComputedPower[step])
|
||||
continue;
|
||||
|
||||
// When one or more steps involve raising the same argument to multiple integer
|
||||
// powers, we can compute them all together for efficiency.
|
||||
|
||||
if (stepGroup[step] != -1) {
|
||||
vector<int>& group = groups[stepGroup[step]];
|
||||
vector<int>& powers = groupPowers[stepGroup[step]];
|
||||
arm::Vec multiplier = c.newVecQ();
|
||||
if (powers[0] > 0)
|
||||
c.mov(multiplier.s4(), workspaceVar[arguments[step][0]].s4());
|
||||
else {
|
||||
c.fdiv(multiplier.s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[arguments[step][0]].s4());
|
||||
for (int i = 0; i < powers.size(); i++)
|
||||
powers[i] = -powers[i];
|
||||
}
|
||||
vector<bool> hasAssigned(group.size(), false);
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
done = true;
|
||||
for (int i = 0; i < group.size(); i++) {
|
||||
if (powers[i] % 2 == 1) {
|
||||
if (!hasAssigned[i])
|
||||
c.mov(workspaceVar[target[group[i]]].s4(), multiplier.s4());
|
||||
else
|
||||
c.fmul(workspaceVar[target[group[i]]].s4(), workspaceVar[target[group[i]]].s4(), multiplier.s4());
|
||||
hasAssigned[i] = true;
|
||||
}
|
||||
powers[i] >>= 1;
|
||||
if (powers[i] != 0)
|
||||
done = false;
|
||||
}
|
||||
if (!done)
|
||||
c.fmul(multiplier.s4(), multiplier.s4(), multiplier.s4());
|
||||
}
|
||||
for (int step : group)
|
||||
hasComputedPower[step] = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Evaluate the step.
|
||||
|
||||
Operation& op = *operation[step];
|
||||
vector<int> args = arguments[step];
|
||||
if (args.size() == 1) {
|
||||
// One or more sequential arguments. Fill out the list.
|
||||
|
||||
for (int i = 1; i < op.getNumArguments(); i++)
|
||||
args.push_back(args[0] + i);
|
||||
}
|
||||
|
||||
// Generate instructions to execute this operation.
|
||||
|
||||
switch (op.getId()) {
|
||||
case Operation::CONSTANT:
|
||||
c.mov(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4());
|
||||
break;
|
||||
case Operation::ADD:
|
||||
c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
|
||||
break;
|
||||
case Operation::SUBTRACT:
|
||||
c.fsub(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
|
||||
break;
|
||||
case Operation::MULTIPLY:
|
||||
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
|
||||
break;
|
||||
case Operation::DIVIDE:
|
||||
c.fdiv(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
|
||||
break;
|
||||
case Operation::POWER:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf);
|
||||
break;
|
||||
case Operation::NEGATE:
|
||||
c.fneg(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
|
||||
break;
|
||||
case Operation::SQRT:
|
||||
c.fsqrt(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
|
||||
break;
|
||||
case Operation::EXP:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf);
|
||||
break;
|
||||
case Operation::LOG:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf);
|
||||
break;
|
||||
case Operation::SIN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf);
|
||||
break;
|
||||
case Operation::COS:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf);
|
||||
break;
|
||||
case Operation::TAN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf);
|
||||
break;
|
||||
case Operation::ASIN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf);
|
||||
break;
|
||||
case Operation::ACOS:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf);
|
||||
break;
|
||||
case Operation::ATAN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf);
|
||||
break;
|
||||
case Operation::ATAN2:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f);
|
||||
break;
|
||||
case Operation::SINH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf);
|
||||
break;
|
||||
case Operation::COSH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf);
|
||||
break;
|
||||
case Operation::TANH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf);
|
||||
break;
|
||||
case Operation::STEP:
|
||||
c.cmge(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
|
||||
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::DELTA:
|
||||
c.cmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
|
||||
c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::SQUARE:
|
||||
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4());
|
||||
break;
|
||||
case Operation::CUBE:
|
||||
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4());
|
||||
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
|
||||
break;
|
||||
case Operation::RECIPROCAL:
|
||||
c.fdiv(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[args[0]].s4());
|
||||
break;
|
||||
case Operation::ADD_CONSTANT:
|
||||
c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4());
|
||||
break;
|
||||
case Operation::MULTIPLY_CONSTANT:
|
||||
c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4());
|
||||
break;
|
||||
case Operation::POWER_CONSTANT:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf);
|
||||
break;
|
||||
case Operation::MIN:
|
||||
c.fmin(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
|
||||
break;
|
||||
case Operation::MAX:
|
||||
c.fmax(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4());
|
||||
break;
|
||||
case Operation::ABS:
|
||||
c.fabs(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
|
||||
break;
|
||||
case Operation::FLOOR:
|
||||
c.frintm(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
|
||||
break;
|
||||
case Operation::CEIL:
|
||||
c.frintp(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4());
|
||||
break;
|
||||
case Operation::SELECT:
|
||||
c.fcmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0));
|
||||
c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]);
|
||||
break;
|
||||
default:
|
||||
// Just invoke evaluateOperation().
|
||||
for (int element = 0; element < width; element++) {
|
||||
for (int i = 0; i < (int) args.size(); i++) {
|
||||
c.ins(argReg.s(0), workspaceVar[args[i]].s(element));
|
||||
c.fcvt(doubleArgReg, argReg);
|
||||
c.str(doubleArgReg, arm::ptr(argsPointer, 8*i));
|
||||
}
|
||||
arm::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) evaluateOperation));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
|
||||
invoke->setArg(0, imm(&op));
|
||||
invoke->setArg(1, imm(&argValues[0]));
|
||||
invoke->setRet(0, doubleResultReg);
|
||||
c.fcvt(argReg, doubleResultReg);
|
||||
c.ins(workspaceVar[target[step]].s(element), argReg.s(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
arm::Gp resultPointer = c.newIntPtr();
|
||||
c.mov(resultPointer, imm(&workspace[workspace.size()-width]));
|
||||
c.str(workspaceVar.back().s4(), arm::ptr(resultPointer, 0));
|
||||
c.endFunc();
|
||||
c.finalize();
|
||||
runtime.add(&jitCode, &code);
|
||||
}
|
||||
|
||||
void CompiledVectorExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, float (*function)(float)) {
|
||||
arm::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) function));
|
||||
arm::Vec a = c.newVecS();
|
||||
arm::Vec d = c.newVecS();
|
||||
for (int element = 0; element < width; element++) {
|
||||
c.ins(a.s(0), arg.s(element));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<float, float>());
|
||||
invoke->setArg(0, a);
|
||||
invoke->setRet(0, d);
|
||||
c.ins(dest.s(element), d.s(0));
|
||||
}
|
||||
}
|
||||
|
||||
void CompiledVectorExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, float (*function)(float, float)) {
|
||||
arm::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) function));
|
||||
arm::Vec a1 = c.newVecS();
|
||||
arm::Vec a2 = c.newVecS();
|
||||
arm::Vec d = c.newVecS();
|
||||
for (int element = 0; element < width; element++) {
|
||||
c.ins(a1.s(0), arg1.s(element));
|
||||
c.ins(a2.s(0), arg2.s(element));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<float, float, float>());
|
||||
invoke->setArg(0, a1);
|
||||
invoke->setArg(1, a2);
|
||||
invoke->setRet(0, d);
|
||||
c.ins(dest.s(element), d.s(0));
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
||||
void CompiledVectorExpression::generateJitCode() {
|
||||
const CpuInfo& cpu = CpuInfo::host();
|
||||
if (!cpu.hasFeature(CpuFeatures::X86::kAVX))
|
||||
return;
|
||||
CodeHolder code;
|
||||
code.init(runtime.environment());
|
||||
x86::Compiler c(&code);
|
||||
FuncNode* funcNode = c.addFunc(FuncSignatureT<void>());
|
||||
funcNode->frame().setAvxEnabled();
|
||||
vector<x86::Ymm> workspaceVar(workspace.size()/width);
|
||||
for (int i = 0; i < (int) workspaceVar.size(); i++)
|
||||
workspaceVar[i] = c.newYmmPs();
|
||||
x86::Gp argsPointer = c.newIntPtr();
|
||||
c.mov(argsPointer, imm(&argValues[0]));
|
||||
vector<vector<int> > groups, groupPowers;
|
||||
vector<int> stepGroup;
|
||||
findPowerGroups(groups, groupPowers, stepGroup);
|
||||
|
||||
// Load the arguments into variables.
|
||||
|
||||
for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
|
||||
map<string, int>::iterator index = variableIndices.find(*iter);
|
||||
x86::Gp variablePointer = c.newIntPtr();
|
||||
c.mov(variablePointer, imm(getVariablePointer(index->first)));
|
||||
if (width == 4)
|
||||
c.vmovdqu(workspaceVar[index->second].xmm(), x86::ptr(variablePointer, 0, 0));
|
||||
else
|
||||
c.vmovdqu(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
|
||||
}
|
||||
|
||||
// Make a list of all constants that will be needed for evaluation.
|
||||
|
||||
vector<int> operationConstantIndex(operation.size(), -1);
|
||||
for (int step = 0; step < (int) operation.size(); step++) {
|
||||
// Find the constant value (if any) used by this operation.
|
||||
|
||||
Operation& op = *operation[step];
|
||||
double value;
|
||||
if (op.getId() == Operation::CONSTANT)
|
||||
value = dynamic_cast<Operation::Constant&> (op).getValue();
|
||||
else if (op.getId() == Operation::ADD_CONSTANT)
|
||||
value = dynamic_cast<Operation::AddConstant&> (op).getValue();
|
||||
else if (op.getId() == Operation::MULTIPLY_CONSTANT)
|
||||
value = dynamic_cast<Operation::MultiplyConstant&> (op).getValue();
|
||||
else if (op.getId() == Operation::RECIPROCAL)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::STEP)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::DELTA)
|
||||
value = 1.0;
|
||||
else if (op.getId() == Operation::ABS) {
|
||||
int mask = 0x7FFFFFFF;
|
||||
value = *reinterpret_cast<float*>(&mask);
|
||||
}
|
||||
else if (op.getId() == Operation::POWER_CONSTANT) {
|
||||
if (stepGroup[step] == -1)
|
||||
value = dynamic_cast<Operation::PowerConstant&> (op).getValue();
|
||||
else
|
||||
value = 1.0;
|
||||
} else
|
||||
continue;
|
||||
|
||||
// See if we already have a variable for this constant.
|
||||
|
||||
for (int i = 0; i < (int) constants.size(); i++)
|
||||
if (value == constants[i]) {
|
||||
operationConstantIndex[step] = i;
|
||||
break;
|
||||
}
|
||||
if (operationConstantIndex[step] == -1) {
|
||||
operationConstantIndex[step] = constants.size();
|
||||
constants.push_back(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Load constants into variables.
|
||||
|
||||
vector<x86::Ymm> constantVar(constants.size());
|
||||
if (constants.size() > 0) {
|
||||
x86::Gp constantsPointer = c.newIntPtr();
|
||||
c.mov(constantsPointer, imm(&constants[0]));
|
||||
for (int i = 0; i < (int) constants.size(); i++) {
|
||||
constantVar[i] = c.newYmmPs();
|
||||
c.vbroadcastss(constantVar[i], x86::ptr(constantsPointer, 4*i, 0));
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate the operations.
|
||||
|
||||
vector<bool> hasComputedPower(operation.size(), false);
|
||||
x86::Ymm argReg = c.newYmm();
|
||||
x86::Ymm doubleArgReg = c.newYmm();
|
||||
x86::Ymm doubleResultReg = c.newYmm();
|
||||
for (int step = 0; step < (int) operation.size(); step++) {
|
||||
if (hasComputedPower[step])
|
||||
continue;
|
||||
|
||||
// When one or more steps involve raising the same argument to multiple integer
|
||||
// powers, we can compute them all together for efficiency.
|
||||
|
||||
if (stepGroup[step] != -1) {
|
||||
vector<int>& group = groups[stepGroup[step]];
|
||||
vector<int>& powers = groupPowers[stepGroup[step]];
|
||||
x86::Ymm multiplier = c.newYmmPs();
|
||||
if (powers[0] > 0)
|
||||
c.vmovdqu(multiplier, workspaceVar[arguments[step][0]]);
|
||||
else {
|
||||
c.vdivps(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]);
|
||||
for (int i = 0; i < (int)powers.size(); i++)
|
||||
powers[i] = -powers[i];
|
||||
}
|
||||
vector<bool> hasAssigned(group.size(), false);
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
done = true;
|
||||
for (int i = 0; i < (int)group.size(); i++) {
|
||||
if (powers[i] % 2 == 1) {
|
||||
if (!hasAssigned[i])
|
||||
c.vmovdqu(workspaceVar[target[group[i]]], multiplier);
|
||||
else
|
||||
c.vmulps(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier);
|
||||
hasAssigned[i] = true;
|
||||
}
|
||||
powers[i] >>= 1;
|
||||
if (powers[i] != 0)
|
||||
done = false;
|
||||
}
|
||||
if (!done)
|
||||
c.vmulps(multiplier, multiplier, multiplier);
|
||||
}
|
||||
for (int step : group)
|
||||
hasComputedPower[step] = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Evaluate the step.
|
||||
|
||||
Operation& op = *operation[step];
|
||||
vector<int> args = arguments[step];
|
||||
if (args.size() == 1) {
|
||||
// One or more sequential arguments. Fill out the list.
|
||||
|
||||
for (int i = 1; i < op.getNumArguments(); i++)
|
||||
args.push_back(args[0] + i);
|
||||
}
|
||||
|
||||
// Generate instructions to execute this operation.
|
||||
|
||||
switch (op.getId()) {
|
||||
case Operation::CONSTANT:
|
||||
c.vmovdqu(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::ADD:
|
||||
c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::SUBTRACT:
|
||||
c.vsubps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::MULTIPLY:
|
||||
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::DIVIDE:
|
||||
c.vdivps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::POWER:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf);
|
||||
break;
|
||||
case Operation::NEGATE:
|
||||
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.vsubps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::SQRT:
|
||||
c.vsqrtps(workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::EXP:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf);
|
||||
break;
|
||||
case Operation::LOG:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf);
|
||||
break;
|
||||
case Operation::SIN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf);
|
||||
break;
|
||||
case Operation::COS:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf);
|
||||
break;
|
||||
case Operation::TAN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf);
|
||||
break;
|
||||
case Operation::ASIN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf);
|
||||
break;
|
||||
case Operation::ACOS:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf);
|
||||
break;
|
||||
case Operation::ATAN:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf);
|
||||
break;
|
||||
case Operation::ATAN2:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f);
|
||||
break;
|
||||
case Operation::SINH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf);
|
||||
break;
|
||||
case Operation::COSH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf);
|
||||
break;
|
||||
case Operation::TANH:
|
||||
generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf);
|
||||
break;
|
||||
case Operation::STEP:
|
||||
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
|
||||
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::DELTA:
|
||||
c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]);
|
||||
c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OQ = 0
|
||||
c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::SQUARE:
|
||||
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::CUBE:
|
||||
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]);
|
||||
c.vmulps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::RECIPROCAL:
|
||||
c.vdivps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]);
|
||||
break;
|
||||
case Operation::ADD_CONSTANT:
|
||||
c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::MULTIPLY_CONSTANT:
|
||||
c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::POWER_CONSTANT:
|
||||
generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf);
|
||||
break;
|
||||
case Operation::MIN:
|
||||
c.vminps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::MAX:
|
||||
c.vmaxps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]);
|
||||
break;
|
||||
case Operation::ABS:
|
||||
c.vandps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]);
|
||||
break;
|
||||
case Operation::FLOOR:
|
||||
c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(1));
|
||||
break;
|
||||
case Operation::CEIL:
|
||||
c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(2));
|
||||
break;
|
||||
case Operation::SELECT:
|
||||
{
|
||||
x86::Ymm mask = c.newYmmPs();
|
||||
c.vxorps(mask, mask, mask);
|
||||
c.vcmpps(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0
|
||||
c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
// Just invoke evaluateOperation().
|
||||
|
||||
for (int element = 0; element < width; element++) {
|
||||
for (int i = 0; i < (int) args.size(); i++) {
|
||||
if (element < 4)
|
||||
c.vshufps(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(element));
|
||||
else {
|
||||
c.vperm2f128(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(1));
|
||||
c.vshufps(argReg, argReg, argReg, imm(element-4));
|
||||
}
|
||||
c.vcvtss2sd(doubleArgReg.xmm(), doubleArgReg.xmm(), argReg.xmm());
|
||||
c.vmovsd(x86::ptr(argsPointer, 8*i, 0), doubleArgReg.xmm());
|
||||
}
|
||||
x86::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) evaluateOperation));
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<double, Operation*, double*>());
|
||||
invoke->setArg(0, imm(&op));
|
||||
invoke->setArg(1, imm(&argValues[0]));
|
||||
invoke->setRet(0, doubleResultReg);
|
||||
c.vcvtsd2ss(argReg.xmm(), argReg.xmm(), doubleResultReg.xmm());
|
||||
if (element > 3)
|
||||
c.vperm2f128(argReg, argReg, argReg, imm(0));
|
||||
if (element != 0)
|
||||
c.vshufps(argReg, argReg, argReg, imm(0));
|
||||
c.vblendps(workspaceVar[target[step]], workspaceVar[target[step]], argReg, 1<<element);
|
||||
}
|
||||
}
|
||||
}
|
||||
x86::Gp resultPointer = c.newIntPtr();
|
||||
c.mov(resultPointer, imm(&workspace[workspace.size()-width]));
|
||||
if (width == 4)
|
||||
c.vmovdqu(x86::ptr(resultPointer, 0, 0), workspaceVar.back().xmm());
|
||||
else
|
||||
c.vmovdqu(x86::ptr(resultPointer, 0, 0), workspaceVar.back());
|
||||
c.endFunc();
|
||||
c.finalize();
|
||||
runtime.add(&jitCode, &code);
|
||||
}
|
||||
|
||||
void CompiledVectorExpression::generateSingleArgCall(x86::Compiler& c, x86::Ymm& dest, x86::Ymm& arg, float (*function)(float)) {
|
||||
x86::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) function));
|
||||
x86::Ymm a = c.newYmm();
|
||||
x86::Ymm d = c.newYmm();
|
||||
for (int element = 0; element < width; element++) {
|
||||
if (element < 4)
|
||||
c.vshufps(a, arg, arg, imm(element));
|
||||
else {
|
||||
c.vperm2f128(a, arg, arg, imm(1));
|
||||
c.vshufps(a, a, a, imm(element-4));
|
||||
}
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<float, float>());
|
||||
invoke->setArg(0, a);
|
||||
invoke->setRet(0, d);
|
||||
if (element > 3)
|
||||
c.vperm2f128(d, d, d, imm(0));
|
||||
if (element != 0)
|
||||
c.vshufps(d, d, d, imm(0));
|
||||
c.vblendps(dest, dest, d, 1<<element);
|
||||
}
|
||||
}
|
||||
|
||||
void CompiledVectorExpression::generateTwoArgCall(x86::Compiler& c, x86::Ymm& dest, x86::Ymm& arg1, x86::Ymm& arg2, float (*function)(float, float)) {
|
||||
x86::Gp fn = c.newIntPtr();
|
||||
c.mov(fn, imm((void*) function));
|
||||
x86::Ymm a1 = c.newYmm();
|
||||
x86::Ymm a2 = c.newYmm();
|
||||
x86::Ymm d = c.newYmm();
|
||||
for (int element = 0; element < width; element++) {
|
||||
if (element < 4) {
|
||||
c.vshufps(a1, arg1, arg1, imm(element));
|
||||
c.vshufps(a2, arg2, arg2, imm(element));
|
||||
}
|
||||
else {
|
||||
c.vperm2f128(a1, arg1, arg1, imm(1));
|
||||
c.vperm2f128(a2, arg2, arg2, imm(1));
|
||||
c.vshufps(a1, a1, a1, imm(element-4));
|
||||
c.vshufps(a2, a2, a2, imm(element-4));
|
||||
}
|
||||
InvokeNode* invoke;
|
||||
c.invoke(&invoke, fn, FuncSignatureT<float, float, float>());
|
||||
invoke->setArg(0, a1);
|
||||
invoke->setArg(1, a2);
|
||||
invoke->setRet(0, d);
|
||||
if (element > 3)
|
||||
c.vperm2f128(d, d, d, imm(0));
|
||||
if (element != 0)
|
||||
c.vshufps(d, d, d, imm(0));
|
||||
c.vblendps(dest, dest, d, 1<<element);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@ -6,7 +6,7 @@
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
|
||||
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
@ -32,6 +32,7 @@
|
||||
#include "lepton/ExpressionTreeNode.h"
|
||||
#include "lepton/Exception.h"
|
||||
#include "lepton/Operation.h"
|
||||
#include <utility>
|
||||
|
||||
using namespace LMP_Lepton;
|
||||
using namespace std;
|
||||
@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati
|
||||
ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? NULL : node.operation->clone()), children(node.getChildren()) {
|
||||
}
|
||||
|
||||
ExpressionTreeNode::ExpressionTreeNode(ExpressionTreeNode&& node) : operation(node.operation), children(move(node.children)) {
|
||||
node.operation = NULL;
|
||||
node.children.clear();
|
||||
}
|
||||
|
||||
ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) {
|
||||
}
|
||||
|
||||
@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node
|
||||
return *this;
|
||||
}
|
||||
|
||||
ExpressionTreeNode& ExpressionTreeNode::operator=(ExpressionTreeNode&& node) {
|
||||
if (operation != NULL)
|
||||
delete operation;
|
||||
operation = node.operation;
|
||||
children = move(node.children);
|
||||
node.operation = NULL;
|
||||
node.children.clear();
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Operation& ExpressionTreeNode::getOperation() const {
|
||||
return *operation;
|
||||
}
|
||||
@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const {
|
||||
const vector<ExpressionTreeNode>& ExpressionTreeNode::getChildren() const {
|
||||
return children;
|
||||
}
|
||||
|
||||
void ExpressionTreeNode::assignTags(vector<const ExpressionTreeNode*>& examples) const {
|
||||
// Assign tag values to all nodes in a tree, such that two nodes have the same
|
||||
// tag if and only if they (and all their children) are equal. This is used to
|
||||
// optimize other operations.
|
||||
|
||||
int numTags = examples.size();
|
||||
for (const ExpressionTreeNode& child : getChildren())
|
||||
child.assignTags(examples);
|
||||
if (numTags == (int)examples.size()) {
|
||||
// All the children matched existing tags, so possibly this node does too.
|
||||
|
||||
for (int i = 0; i < (int)examples.size(); i++) {
|
||||
const ExpressionTreeNode& example = *examples[i];
|
||||
bool matches = (getChildren().size() == example.getChildren().size() && getOperation() == example.getOperation());
|
||||
for (int j = 0; matches && j < (int)getChildren().size(); j++)
|
||||
if (getChildren()[j].tag != example.getChildren()[j].tag)
|
||||
matches = false;
|
||||
if (matches) {
|
||||
tag = i;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This node does not match any previous node, so assign a new tag.
|
||||
|
||||
tag = examples.size();
|
||||
examples.push_back(this);
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
|
||||
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
@ -37,7 +37,13 @@
|
||||
using namespace LMP_Lepton;
|
||||
using namespace std;
|
||||
|
||||
double Operation::Erf::evaluate(double* args, const map<string, double>& ) const {
|
||||
static bool isZero(const ExpressionTreeNode& node) {
|
||||
if (node.getOperation().getId() != Operation::CONSTANT)
|
||||
return false;
|
||||
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue() == 0.0;
|
||||
}
|
||||
|
||||
double Operation::Erf::evaluate(double* args, const map<string, double>&) const {
|
||||
return erf(args[0]);
|
||||
}
|
||||
|
||||
@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi
|
||||
ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (function->getNumArguments() == 0)
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]);
|
||||
for (int i = 1; i < getNumArguments(); i++) {
|
||||
result = ExpressionTreeNode(new Operation::Add(),
|
||||
result,
|
||||
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]));
|
||||
ExpressionTreeNode result;
|
||||
bool foundTerm = false;
|
||||
for (int i = 0; i < getNumArguments(); i++) {
|
||||
if (!isZero(childDerivs[i])) {
|
||||
if (foundTerm)
|
||||
result = ExpressionTreeNode(new Operation::Add(),
|
||||
result,
|
||||
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]));
|
||||
else {
|
||||
result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]);
|
||||
foundTerm = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
if (foundTerm)
|
||||
return result;
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& , const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return childDerivs[1];
|
||||
if (isZero(childDerivs[1]))
|
||||
return childDerivs[0];
|
||||
return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]);
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Subtract::differentiate(const std::vector<ExpressionTreeNode>& , const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0])) {
|
||||
if (isZero(childDerivs[1]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Negate(), childDerivs[1]);
|
||||
}
|
||||
if (isZero(childDerivs[1]))
|
||||
return childDerivs[0];
|
||||
return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]);
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Multiply::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0])) {
|
||||
if (isZero(childDerivs[1]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]);
|
||||
}
|
||||
if (isZero(childDerivs[1]))
|
||||
return ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]);
|
||||
return ExpressionTreeNode(new Operation::Add(),
|
||||
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]),
|
||||
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]));
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Divide::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
return ExpressionTreeNode(new Operation::Divide(),
|
||||
ExpressionTreeNode(new Operation::Subtract(),
|
||||
ExpressionTreeNode subexp;
|
||||
if (isZero(childDerivs[0])) {
|
||||
if (isZero(childDerivs[1]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
subexp = ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]));
|
||||
}
|
||||
else if (isZero(childDerivs[1]))
|
||||
subexp = ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]);
|
||||
else
|
||||
subexp = ExpressionTreeNode(new Operation::Subtract(),
|
||||
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]),
|
||||
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])),
|
||||
ExpressionTreeNode(new Operation::Square(), children[1]));
|
||||
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]));
|
||||
return ExpressionTreeNode(new Operation::Divide(), subexp, ExpressionTreeNode(new Operation::Square(), children[1]));
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionT
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Negate::differentiate(const std::vector<ExpressionTreeNode>& , const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]);
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::MultiplyConstant(0.5),
|
||||
ExpressionTreeNode(new Operation::Reciprocal(),
|
||||
@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTr
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Exp::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Exp(), children[0]),
|
||||
childDerivs[0]);
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Log::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Reciprocal(), children[0]),
|
||||
childDerivs[0]);
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Sin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Cos(), children[0]),
|
||||
childDerivs[0]);
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Negate(),
|
||||
ExpressionTreeNode(new Operation::Sin(), children[0])),
|
||||
@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTre
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Sec(), children[0]),
|
||||
@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTre
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Negate(),
|
||||
ExpressionTreeNode(new Operation::Multiply(),
|
||||
@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTre
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Square(),
|
||||
ExpressionTreeNode(new Operation::Sec(), children[0])),
|
||||
@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTre
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Negate(),
|
||||
ExpressionTreeNode(new Operation::Square(),
|
||||
@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTre
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Reciprocal(),
|
||||
ExpressionTreeNode(new Operation::Sqrt(),
|
||||
@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTr
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Negate(),
|
||||
ExpressionTreeNode(new Operation::Reciprocal(),
|
||||
@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTr
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Reciprocal(),
|
||||
ExpressionTreeNode(new Operation::AddConstant(1.0),
|
||||
@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector<ExpressionT
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Cosh(),
|
||||
children[0]),
|
||||
@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTr
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Sinh(),
|
||||
children[0]),
|
||||
@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTr
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Subtract(),
|
||||
ExpressionTreeNode(new Operation::Constant(1.0)),
|
||||
@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))),
|
||||
@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTre
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Erfc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))),
|
||||
@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector<ExpressionT
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
|
||||
children[0]),
|
||||
@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector<Expression
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::MultiplyConstant(3.0),
|
||||
ExpressionTreeNode(new Operation::Square(), children[0])),
|
||||
@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTr
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::Negate(),
|
||||
ExpressionTreeNode(new Operation::Reciprocal(),
|
||||
@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector<Expre
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector<ExpressionTreeNode>& , const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::MultiplyConstant(value),
|
||||
childDerivs[0]);
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
ExpressionTreeNode(new Operation::MultiplyConstant(value),
|
||||
ExpressionTreeNode(new Operation::PowerConstant(value-1),
|
||||
@ -305,22 +393,18 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<Exp
|
||||
ExpressionTreeNode Operation::Min::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
ExpressionTreeNode step(new Operation::Step(),
|
||||
ExpressionTreeNode(new Operation::Subtract(), children[0], children[1]));
|
||||
return ExpressionTreeNode(new Operation::Subtract(),
|
||||
ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step),
|
||||
ExpressionTreeNode(new Operation::Multiply(), childDerivs[0],
|
||||
ExpressionTreeNode(new Operation::AddConstant(-1), step)));
|
||||
return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[1], childDerivs[0]});
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
ExpressionTreeNode step(new Operation::Step(),
|
||||
ExpressionTreeNode(new Operation::Subtract(), children[0], children[1]));
|
||||
return ExpressionTreeNode(new Operation::Subtract(),
|
||||
ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step),
|
||||
ExpressionTreeNode(new Operation::Multiply(), childDerivs[1],
|
||||
ExpressionTreeNode(new Operation::AddConstant(-1), step)));
|
||||
return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[0], childDerivs[1]});
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
if (isZero(childDerivs[0]))
|
||||
return ExpressionTreeNode(new Operation::Constant(0.0));
|
||||
ExpressionTreeNode step(new Operation::Step(), children[0]);
|
||||
return ExpressionTreeNode(new Operation::Multiply(),
|
||||
childDerivs[0],
|
||||
@ -337,9 +421,5 @@ ExpressionTreeNode Operation::Ceil::differentiate(const std::vector<ExpressionTr
|
||||
}
|
||||
|
||||
ExpressionTreeNode Operation::Select::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& ) const {
|
||||
vector<ExpressionTreeNode> derivChildren;
|
||||
derivChildren.push_back(children[0]);
|
||||
derivChildren.push_back(childDerivs[1]);
|
||||
derivChildren.push_back(childDerivs[2]);
|
||||
return ExpressionTreeNode(new Operation::Select(), derivChildren);
|
||||
return ExpressionTreeNode(new Operation::Select(), {children[0], childDerivs[1], childDerivs[2]});
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
* Biological Structures at Stanford, funded under the NIH Roadmap for *
|
||||
* Medical Research, grant U54 GM072970. See https://simtk.org. *
|
||||
* *
|
||||
* Portions copyright (c) 2009 Stanford University and the Authors. *
|
||||
* Portions copyright (c) 2009-2022 Stanford University and the Authors. *
|
||||
* Authors: Peter Eastman *
|
||||
* Contributors: *
|
||||
* *
|
||||
@ -31,6 +31,7 @@
|
||||
|
||||
#include "lepton/ParsedExpression.h"
|
||||
#include "lepton/CompiledExpression.h"
|
||||
#include "lepton/CompiledVectorExpression.h"
|
||||
#include "lepton/ExpressionProgram.h"
|
||||
#include "lepton/Operation.h"
|
||||
#include <limits>
|
||||
@ -68,9 +69,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri
|
||||
}
|
||||
|
||||
ParsedExpression ParsedExpression::optimize() const {
|
||||
ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode());
|
||||
ExpressionTreeNode result = getRootNode();
|
||||
vector<const ExpressionTreeNode*> examples;
|
||||
result.assignTags(examples);
|
||||
map<int, ExpressionTreeNode> nodeCache;
|
||||
result = precalculateConstantSubexpressions(result, nodeCache);
|
||||
while (true) {
|
||||
ExpressionTreeNode simplified = substituteSimplerExpression(result);
|
||||
examples.clear();
|
||||
result.assignTags(examples);
|
||||
nodeCache.clear();
|
||||
ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache);
|
||||
if (simplified == result)
|
||||
break;
|
||||
result = simplified;
|
||||
@ -80,9 +88,15 @@ ParsedExpression ParsedExpression::optimize() const {
|
||||
|
||||
ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const {
|
||||
ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables);
|
||||
result = precalculateConstantSubexpressions(result);
|
||||
vector<const ExpressionTreeNode*> examples;
|
||||
result.assignTags(examples);
|
||||
map<int, ExpressionTreeNode> nodeCache;
|
||||
result = precalculateConstantSubexpressions(result, nodeCache);
|
||||
while (true) {
|
||||
ExpressionTreeNode simplified = substituteSimplerExpression(result);
|
||||
examples.clear();
|
||||
result.assignTags(examples);
|
||||
nodeCache.clear();
|
||||
ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache);
|
||||
if (simplified == result)
|
||||
break;
|
||||
result = simplified;
|
||||
@ -104,27 +118,44 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo
|
||||
return ExpressionTreeNode(node.getOperation().clone(), children);
|
||||
}
|
||||
|
||||
ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) {
|
||||
ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node, map<int, ExpressionTreeNode>& nodeCache) {
|
||||
auto cached = nodeCache.find(node.tag);
|
||||
if (cached != nodeCache.end())
|
||||
return cached->second;
|
||||
vector<ExpressionTreeNode> children(node.getChildren().size());
|
||||
for (int i = 0; i < (int) children.size(); i++)
|
||||
children[i] = precalculateConstantSubexpressions(node.getChildren()[i]);
|
||||
children[i] = precalculateConstantSubexpressions(node.getChildren()[i], nodeCache);
|
||||
ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children);
|
||||
if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM)
|
||||
if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) {
|
||||
nodeCache[node.tag] = result;
|
||||
return result;
|
||||
}
|
||||
for (int i = 0; i < (int) children.size(); i++)
|
||||
if (children[i].getOperation().getId() != Operation::CONSTANT)
|
||||
if (children[i].getOperation().getId() != Operation::CONSTANT) {
|
||||
nodeCache[node.tag] = result;
|
||||
return result;
|
||||
return ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
|
||||
}
|
||||
result = ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
|
||||
nodeCache[node.tag] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) {
|
||||
ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node, map<int, ExpressionTreeNode>& nodeCache) {
|
||||
vector<ExpressionTreeNode> children(node.getChildren().size());
|
||||
for (int i = 0; i < (int) children.size(); i++)
|
||||
children[i] = substituteSimplerExpression(node.getChildren()[i]);
|
||||
for (int i = 0; i < (int) children.size(); i++) {
|
||||
const ExpressionTreeNode& child = node.getChildren()[i];
|
||||
auto cached = nodeCache.find(child.tag);
|
||||
if (cached == nodeCache.end()) {
|
||||
children[i] = substituteSimplerExpression(child, nodeCache);
|
||||
nodeCache[child.tag] = children[i];
|
||||
}
|
||||
else
|
||||
children[i] = cached->second;
|
||||
}
|
||||
|
||||
// Collect some info on constant expressions in children
|
||||
bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant?
|
||||
bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant?
|
||||
bool second_const = children.size() > 1 && isConstant(children[1]); // is second child constant?
|
||||
double first, second; // if yes, value of first and second child
|
||||
if (first_const)
|
||||
first = getConstantValue(children[0]);
|
||||
@ -296,6 +327,12 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
|
||||
return children[0].getChildren()[0];
|
||||
break;
|
||||
}
|
||||
case Operation::SELECT:
|
||||
{
|
||||
if (children[1] == children[2]) // Select between two identical values
|
||||
return children[1];
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
// If operation ID is not one of the above,
|
||||
@ -308,14 +345,22 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
|
||||
}
|
||||
|
||||
ParsedExpression ParsedExpression::differentiate(const string& variable) const {
|
||||
return differentiate(getRootNode(), variable);
|
||||
vector<const ExpressionTreeNode*> examples;
|
||||
getRootNode().assignTags(examples);
|
||||
map<int, ExpressionTreeNode> nodeCache;
|
||||
return differentiate(getRootNode(), variable, nodeCache);
|
||||
}
|
||||
|
||||
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) {
|
||||
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable, map<int, ExpressionTreeNode>& nodeCache) {
|
||||
auto cached = nodeCache.find(node.tag);
|
||||
if (cached != nodeCache.end())
|
||||
return cached->second;
|
||||
vector<ExpressionTreeNode> childDerivs(node.getChildren().size());
|
||||
for (int i = 0; i < (int) childDerivs.size(); i++)
|
||||
childDerivs[i] = differentiate(node.getChildren()[i], variable);
|
||||
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable);
|
||||
childDerivs[i] = differentiate(node.getChildren()[i], variable, nodeCache);
|
||||
ExpressionTreeNode result = node.getOperation().differentiate(node.getChildren(), childDerivs, variable);
|
||||
nodeCache[node.tag] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
bool ParsedExpression::isConstant(const ExpressionTreeNode& node) {
|
||||
@ -337,6 +382,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const {
|
||||
return CompiledExpression(*this);
|
||||
}
|
||||
|
||||
CompiledVectorExpression ParsedExpression::createCompiledVectorExpression(int width) const {
|
||||
return CompiledVectorExpression(*this, width);
|
||||
}
|
||||
|
||||
ParsedExpression ParsedExpression::renameVariables(const map<string, string>& replacements) const {
|
||||
return ParsedExpression(renameNodeVariables(getRootNode(), replacements));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user