From 517a2e5e2644440fde99438e00b0a699c072b414 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 14:18:39 -0500 Subject: [PATCH 01/79] import Lepton library with namespace and header changed to LMP_Lepton --- lib/README | 2 + lib/lepton/Install.py | 1 + lib/lepton/LICENSE | 20 + lib/lepton/Makefile.lammps.empty | 7 + lib/lepton/Makefile.mpi | 34 + lib/lepton/Makefile.serial | 34 + lib/lepton/README.md | 43 + lib/lepton/include/LMP_Lepton.h | 43 + .../include/lepton/CompiledExpression.h | 114 ++ lib/lepton/include/lepton/CustomFunction.h | 109 ++ lib/lepton/include/lepton/Exception.h | 59 + lib/lepton/include/lepton/ExpressionProgram.h | 103 ++ .../include/lepton/ExpressionTreeNode.h | 105 ++ lib/lepton/include/lepton/Operation.h | 1193 +++++++++++++++++ lib/lepton/include/lepton/ParsedExpression.h | 131 ++ lib/lepton/include/lepton/Parser.h | 77 ++ lib/lepton/include/lepton/windowsIncludes.h | 41 + lib/lepton/src/CompiledExpression.cpp | 418 ++++++ lib/lepton/src/ExpressionProgram.cpp | 110 ++ lib/lepton/src/ExpressionTreeNode.cpp | 107 ++ lib/lepton/src/MSVC_erfc.h | 91 ++ lib/lepton/src/Operation.cpp | 345 +++++ lib/lepton/src/ParsedExpression.cpp | 379 ++++++ lib/lepton/src/Parser.cpp | 409 ++++++ 24 files changed, 3975 insertions(+) create mode 120000 lib/lepton/Install.py create mode 100644 lib/lepton/LICENSE create mode 100644 lib/lepton/Makefile.lammps.empty create mode 100644 lib/lepton/Makefile.mpi create mode 100644 lib/lepton/Makefile.serial create mode 100644 lib/lepton/README.md create mode 100644 lib/lepton/include/LMP_Lepton.h create mode 100644 lib/lepton/include/lepton/CompiledExpression.h create mode 100644 lib/lepton/include/lepton/CustomFunction.h create mode 100644 lib/lepton/include/lepton/Exception.h create mode 100644 lib/lepton/include/lepton/ExpressionProgram.h create mode 100644 lib/lepton/include/lepton/ExpressionTreeNode.h create mode 100644 lib/lepton/include/lepton/Operation.h create mode 100644 lib/lepton/include/lepton/ParsedExpression.h create mode 100644 lib/lepton/include/lepton/Parser.h create mode 100644 lib/lepton/include/lepton/windowsIncludes.h create mode 100644 lib/lepton/src/CompiledExpression.cpp create mode 100644 lib/lepton/src/ExpressionProgram.cpp create mode 100644 lib/lepton/src/ExpressionTreeNode.cpp create mode 100644 lib/lepton/src/MSVC_erfc.h create mode 100644 lib/lepton/src/Operation.cpp create mode 100644 lib/lepton/src/ParsedExpression.cpp create mode 100644 lib/lepton/src/Parser.cpp diff --git a/lib/README b/lib/README index ab71e6763c..255077bb1b 100644 --- a/lib/README +++ b/lib/README @@ -33,6 +33,8 @@ kim hooks to the KIM library, used by KIM package from Ryan Elliott and Ellad Tadmor (U Minn) kokkos Kokkos package for GPU and many-core acceleration from Kokkos development team (Sandia) +lepton Lepton library for fast evaluation of mathematical + expressions from a string. Imported from OpenMM. linalg set of BLAS and LAPACK routines needed by ATC package from Axel Kohlmeyer (Temple U) mdi hooks to the MDI library, used by MDI package diff --git a/lib/lepton/Install.py b/lib/lepton/Install.py new file mode 120000 index 0000000000..ffe709d44c --- /dev/null +++ b/lib/lepton/Install.py @@ -0,0 +1 @@ +../Install.py \ No newline at end of file diff --git a/lib/lepton/LICENSE b/lib/lepton/LICENSE new file mode 100644 index 0000000000..6359209705 --- /dev/null +++ b/lib/lepton/LICENSE @@ -0,0 +1,20 @@ +Portions copyright (c) 2009-2019 Stanford University and the Authors. +Authors: Peter Eastman and OpenMM contributors + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/lib/lepton/Makefile.lammps.empty b/lib/lepton/Makefile.lammps.empty new file mode 100644 index 0000000000..9e74c23b1d --- /dev/null +++ b/lib/lepton/Makefile.lammps.empty @@ -0,0 +1,7 @@ +# Settings that the LAMMPS build will import when this package library is used +# The default settings assume that HDF5 support is integrated into the standard +# distribution and search paths and thus only needs to link the HDF5 library. + +lepton_SYSINC = +lepton_SYSLIB = +lepton_SYSPATH = diff --git a/lib/lepton/Makefile.mpi b/lib/lepton/Makefile.mpi new file mode 100644 index 0000000000..045ea61015 --- /dev/null +++ b/lib/lepton/Makefile.mpi @@ -0,0 +1,34 @@ +EXTRAMAKE=Makefile.lammps.empty + +CC=mpicxx + +# -DH5_NO_DEPRECATED_SYMBOLS is required here to ensure we are using +# the v1.8 API when HDF5 is configured to default to using the v1.6 API. +CXXFLAGS=-D_DEFAULT_SOURCE -O2 -Wall -fPIC -std=c++11 +INC=-I include +AR=ar +ARFLAGS=rc +# need to build two libraries to not break compatibility and to support Install.py +LIB=liblepton.a +SRC=src/CompiledExpression.cpp src/ExpressionProgram.cpp src/ExpressionTreeNode.cpp src/Operation.cpp src/ParsedExpression.cpp src/Parser.cpp +OBJ=$(SRC:src/%.cpp=build/%.o) + +all: $(LIB) Makefile.lammps + +build: + mkdir -p build + +build/%.o: src/%.cpp build + $(CXX) $(INC) $(CXXFLAGS) -c $< -o $@ + +Makefile.lammps: + cp $(EXTRAMAKE) $@ + +.PHONY: all lib clean + +$(LIB) : $(OBJ) + $(AR) $(ARFLAGS) $@ $^ + +clean: + rm -f build/*.o $(LIB) *~ + diff --git a/lib/lepton/Makefile.serial b/lib/lepton/Makefile.serial new file mode 100644 index 0000000000..58151e49c2 --- /dev/null +++ b/lib/lepton/Makefile.serial @@ -0,0 +1,34 @@ +EXTRAMAKE=Makefile.lammps.empty + +CC=g++ + +# -DH5_NO_DEPRECATED_SYMBOLS is required here to ensure we are using +# the v1.8 API when HDF5 is configured to default to using the v1.6 API. +CXXFLAGS=-D_DEFAULT_SOURCE -O2 -Wall -fPIC -std=c++11 +INC=-I include +AR=ar +ARFLAGS=rc +# need to build two libraries to not break compatibility and to support Install.py +LIB=liblepton.a +SRC=src/CompiledExpression.cpp src/ExpressionProgram.cpp src/ExpressionTreeNode.cpp src/Operation.cpp src/ParsedExpression.cpp src/Parser.cpp +OBJ=$(SRC:src/%.cpp=build/%.o) + +all: $(LIB) Makefile.lammps + +build: + mkdir -p build + +build/%.o: src/%.cpp build + $(CXX) $(INC) $(CXXFLAGS) -c $< -o $@ + +Makefile.lammps: + cp $(EXTRAMAKE) $@ + +.PHONY: all lib clean + +$(LIB) : $(OBJ) + $(AR) $(ARFLAGS) $@ $^ + +clean: + rm -f build/*.o $(LIB) *~ + diff --git a/lib/lepton/README.md b/lib/lepton/README.md new file mode 100644 index 0000000000..d2e4240c92 --- /dev/null +++ b/lib/lepton/README.md @@ -0,0 +1,43 @@ +This directory contains the lepton library from the OpenMM software +which allows to efficiently evaluate mathematical expressions from +strings. This library is used with the LEPTON package that support +force styles within LAMMPS that make use of this library. + +You can type "make lib-lepton" from the src directory to see help on how +to build this library via make commands, or you can do the same thing +by typing "python Install.py" from within this directory, or you can +do it manually by following the instructions below. + +--------------------- + +Lepton (short for “lightweight expression parser”) is a C++ library for +parsing, evaluating, differentiating, and analyzing mathematical +expressions. It takes expressions in the form of text strings, then +converts them to an internal representation suitable for evaluation or +analysis. Here are some of its major features: + +- Support for a large number of standard mathematical functions and operations. +- Support for user defined custom functions. +- A variety of optimizations for automatically simplifying expressions. +- Computing analytic derivatives. +- Representing parsed expressions in two different forms (tree or program) suitable for + further analysis or processing. + +Lepton was originally created for use in the [OpenMM project](https://openmm.org) +ch5md is developed by Pierre de Buyl and is released under the 3-clause BSD +license that can be found in the file LICENSE. + +To use the h5md dump style in lammps, execute +make -f Makefile.h5cc +in this directory then +make yes-h5md +in the src directory of LAMMPS to rebuild LAMMPS. + +Note that you must have the h5cc compiler installed to use +Makefile.h5cc. It should be part + +If HDF5 is not in a standard system location, edit Makefile.lammps accordingly. + +In the case of 2015 and more recent debian and ubuntu systems where concurrent +serial and mpi are possible, use the full platform depedent path, i.e. +`HDF5_PATH=/usr/lib/x86_64-linux-gnu/hdf5/serial` diff --git a/lib/lepton/include/LMP_Lepton.h b/lib/lepton/include/LMP_Lepton.h new file mode 100644 index 0000000000..73b6b6fa38 --- /dev/null +++ b/lib/lepton/include/LMP_Lepton.h @@ -0,0 +1,43 @@ +#ifndef LMP_LEPTON_H_ +#define LMP_LEPTON_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/CompiledExpression.h" +#include "lepton/CustomFunction.h" +#include "lepton/ExpressionProgram.h" +#include "lepton/ExpressionTreeNode.h" +#include "lepton/Operation.h" +#include "lepton/ParsedExpression.h" +#include "lepton/Parser.h" + +#endif /*LMP_LEPTON_H_*/ diff --git a/lib/lepton/include/lepton/CompiledExpression.h b/lib/lepton/include/lepton/CompiledExpression.h new file mode 100644 index 0000000000..bf076b0d8b --- /dev/null +++ b/lib/lepton/include/lepton/CompiledExpression.h @@ -0,0 +1,114 @@ +#ifndef LEPTON_COMPILED_EXPRESSION_H_ +#define LEPTON_COMPILED_EXPRESSION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "ExpressionTreeNode.h" +#include "windowsIncludes.h" +#include +#include +#include +#include +#include +#ifdef LEPTON_USE_JIT + #include "asmjit.h" +#endif + +namespace LMP_Lepton { + +class Operation; +class ParsedExpression; + +/** + * A CompiledExpression is a highly optimized representation of an expression for cases when you want to evaluate + * it many times as quickly as possible. You should treat it as an opaque object; none of the internal representation + * is visible. + * + * A CompiledExpression is created by calling createCompiledExpression() on a ParsedExpression. + * + * WARNING: CompiledExpression is NOT thread safe. You should never access a CompiledExpression from two threads at + * the same time. + */ + +class LEPTON_EXPORT CompiledExpression { +public: + CompiledExpression(); + CompiledExpression(const CompiledExpression& expression); + ~CompiledExpression(); + CompiledExpression& operator=(const CompiledExpression& expression); + /** + * Get the names of all variables used by this expression. + */ + const std::set& getVariables() const; + /** + * Get a reference to the memory location where the value of a particular variable is stored. This can be used + * to set the value of the variable before calling evaluate(). + */ + double& getVariableReference(const std::string& name); + /** + * You can optionally specify the memory locations from which the values of variables should be read. + * This is useful, for example, when several expressions all use the same variable. You can then set + * the value of that variable in one place, and it will be seen by all of them. + */ + void setVariableLocations(std::map& variableLocations); + /** + * Evaluate the expression. The values of all variables should have been set before calling this. + */ + double evaluate() const; +private: + friend class ParsedExpression; + CompiledExpression(const ParsedExpression& expression); + void compileExpression(const ExpressionTreeNode& node, std::vector >& temps); + int findTempIndex(const ExpressionTreeNode& node, std::vector >& temps); + std::map variablePointers; + std::vector > variablesToCopy; + std::vector > arguments; + std::vector target; + std::vector operation; + std::map variableIndices; + std::set variableNames; + mutable std::vector workspace; + mutable std::vector argValues; + std::map dummyVariables; + double (*jitCode)(); +#ifdef LEPTON_USE_JIT + void generateJitCode(); + void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg1, asmjit::X86Xmm& arg2, double (*function)(double, double)); + std::vector constants; + asmjit::JitRuntime runtime; +#endif +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_COMPILED_EXPRESSION_H_*/ diff --git a/lib/lepton/include/lepton/CustomFunction.h b/lib/lepton/include/lepton/CustomFunction.h new file mode 100644 index 0000000000..b8cbee8c96 --- /dev/null +++ b/lib/lepton/include/lepton/CustomFunction.h @@ -0,0 +1,109 @@ +#ifndef LEPTON_CUSTOM_FUNCTION_H_ +#define LEPTON_CUSTOM_FUNCTION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "windowsIncludes.h" + +namespace LMP_Lepton { + +/** + * This class is the interface for defining your own function that may be included in expressions. + * To use it, create a concrete subclass that implements all of the virtual methods for each new function + * you want to define. Then when you call Parser::parse() to parse an expression, pass a map of + * function names to CustomFunction objects. + */ + +class LEPTON_EXPORT CustomFunction { +public: + virtual ~CustomFunction() { + } + /** + * Get the number of arguments this function expects. + */ + virtual int getNumArguments() const = 0; + /** + * Evaluate the function. + * + * @param arguments the array of argument values + */ + virtual double evaluate(const double* arguments) const = 0; + /** + * Evaluate a derivative of the function. + * + * @param arguments the array of argument values + * @param derivOrder an array specifying the number of times the function has been differentiated + * with respect to each of its arguments. For example, the array {0, 2} indicates + * a second derivative with respect to the second argument. + */ + virtual double evaluateDerivative(const double* arguments, const int* derivOrder) const = 0; + /** + * Create a new duplicate of this object on the heap using the "new" operator. + */ + virtual CustomFunction* clone() const = 0; +}; + +/** + * This class is an implementation of CustomFunction that does no computation. It just returns + * 0 for the value and derivatives. This is useful when using the parser to analyze expressions + * rather than to evaluate them. You can just create PlaceholderFunctions to represent any custom + * functions that may appear in expressions. + */ + +class LEPTON_EXPORT PlaceholderFunction : public CustomFunction { +public: + /** + * Create a Placeholder function. + * + * @param numArgs the number of arguments the function expects + */ + PlaceholderFunction(int numArgs) : numArgs(numArgs) { + } + int getNumArguments() const { + return numArgs; + } + double evaluate(const double* arguments) const { + return 0.0; + } + double evaluateDerivative(const double* arguments, const int* derivOrder) const { + return 0.0; + } + CustomFunction* clone() const { + return new PlaceholderFunction(numArgs); + }; +private: + int numArgs; +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_CUSTOM_FUNCTION_H_*/ diff --git a/lib/lepton/include/lepton/Exception.h b/lib/lepton/include/lepton/Exception.h new file mode 100644 index 0000000000..413b08f52e --- /dev/null +++ b/lib/lepton/include/lepton/Exception.h @@ -0,0 +1,59 @@ +#ifndef LEPTON_EXCEPTION_H_ +#define LEPTON_EXCEPTION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include +#include + +namespace LMP_Lepton { + +/** + * This class is used for all exceptions thrown by Lepton. + */ + +class Exception : public std::exception { +public: + Exception(const std::string& message) : message(message) { + } + ~Exception() throw() { + } + const char* what() const throw() { + return message.c_str(); + } +private: + std::string message; +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_EXCEPTION_H_*/ diff --git a/lib/lepton/include/lepton/ExpressionProgram.h b/lib/lepton/include/lepton/ExpressionProgram.h new file mode 100644 index 0000000000..4fba4051e4 --- /dev/null +++ b/lib/lepton/include/lepton/ExpressionProgram.h @@ -0,0 +1,103 @@ +#ifndef LEPTON_EXPRESSION_PROGRAM_H_ +#define LEPTON_EXPRESSION_PROGRAM_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009-2018 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "ExpressionTreeNode.h" +#include "windowsIncludes.h" +#include +#include +#include + +namespace LMP_Lepton { + +class ParsedExpression; + +/** + * An ExpressionProgram is a linear sequence of Operations for evaluating an expression. The evaluation + * is done with a stack. The arguments to each Operation are first taken off the stack in order, then it is + * evaluated and the result is pushed back onto the stack. At the end, the stack contains a single value, + * which is the value of the expression. + * + * An ExpressionProgram is created by calling createProgram() on a ParsedExpression. + */ + +class LEPTON_EXPORT ExpressionProgram { +public: + ExpressionProgram(); + ExpressionProgram(const ExpressionProgram& program); + ~ExpressionProgram(); + ExpressionProgram& operator=(const ExpressionProgram& program); + /** + * Get the number of Operations that make up this program. + */ + int getNumOperations() const; + /** + * Get an Operation in this program. + */ + const Operation& getOperation(int index) const; + /** + * Change an Operation in this program. + * + * The Operation must have been allocated on the heap with the "new" operator. + * The ExpressionProgram assumes ownership of it and will delete it when it + * is no longer needed. + */ + void setOperation(int index, Operation* operation); + /** + * Get the size of the stack needed to execute this program. This is the largest number of elements present + * on the stack at any point during evaluation. + */ + int getStackSize() const; + /** + * Evaluate the expression. If the expression involves any variables, this method will throw an exception. + */ + double evaluate() const; + /** + * Evaluate the expression. + * + * @param variables a map specifying the values of all variables that appear in the expression. If any + * variable appears in the expression but is not included in this map, an exception + * will be thrown. + */ + double evaluate(const std::map& variables) const; +private: + friend class ParsedExpression; + ExpressionProgram(const ParsedExpression& expression); + void buildProgram(const ExpressionTreeNode& node); + std::vector operations; + int maxArgs, stackSize; +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_EXPRESSION_PROGRAM_H_*/ diff --git a/lib/lepton/include/lepton/ExpressionTreeNode.h b/lib/lepton/include/lepton/ExpressionTreeNode.h new file mode 100644 index 0000000000..514cc008a9 --- /dev/null +++ b/lib/lepton/include/lepton/ExpressionTreeNode.h @@ -0,0 +1,105 @@ +#ifndef LEPTON_EXPRESSION_TREE_NODE_H_ +#define LEPTON_EXPRESSION_TREE_NODE_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "windowsIncludes.h" +#include +#include + +namespace LMP_Lepton { + +class Operation; + +/** + * This class represents a node in the abstract syntax tree representation of an expression. + * Each node is defined by an Operation and a set of children. When the expression is + * evaluated, each child is first evaluated in order, then the resulting values are passed + * as the arguments to the Operation's evaluate() method. + */ + +class LEPTON_EXPORT ExpressionTreeNode { +public: + /** + * Create a new ExpressionTreeNode. + * + * @param operation the operation for this node. The ExpressionTreeNode takes over ownership + * of this object, and deletes it when the node is itself deleted. + * @param children the children of this node + */ + ExpressionTreeNode(Operation* operation, const std::vector& children); + /** + * Create a new ExpressionTreeNode with two children. + * + * @param operation the operation for this node. The ExpressionTreeNode takes over ownership + * of this object, and deletes it when the node is itself deleted. + * @param child1 the first child of this node + * @param child2 the second child of this node + */ + ExpressionTreeNode(Operation* operation, const ExpressionTreeNode& child1, const ExpressionTreeNode& child2); + /** + * Create a new ExpressionTreeNode with one child. + * + * @param operation the operation for this node. The ExpressionTreeNode takes over ownership + * of this object, and deletes it when the node is itself deleted. + * @param child the child of this node + */ + ExpressionTreeNode(Operation* operation, const ExpressionTreeNode& child); + /** + * Create a new ExpressionTreeNode with no children. + * + * @param operation the operation for this node. The ExpressionTreeNode takes over ownership + * of this object, and deletes it when the node is itself deleted. + */ + ExpressionTreeNode(Operation* operation); + ExpressionTreeNode(const ExpressionTreeNode& node); + ExpressionTreeNode(); + ~ExpressionTreeNode(); + bool operator==(const ExpressionTreeNode& node) const; + bool operator!=(const ExpressionTreeNode& node) const; + ExpressionTreeNode& operator=(const ExpressionTreeNode& node); + /** + * Get the Operation performed by this node. + */ + const Operation& getOperation() const; + /** + * Get this node's child nodes. + */ + const std::vector& getChildren() const; +private: + Operation* operation; + std::vector children; +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_EXPRESSION_TREE_NODE_H_*/ diff --git a/lib/lepton/include/lepton/Operation.h b/lib/lepton/include/lepton/Operation.h new file mode 100644 index 0000000000..b27b25d3d8 --- /dev/null +++ b/lib/lepton/include/lepton/Operation.h @@ -0,0 +1,1193 @@ +#ifndef LEPTON_OPERATION_H_ +#define LEPTON_OPERATION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009-2019 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "windowsIncludes.h" +#include "CustomFunction.h" +#include "Exception.h" +#include +#include +#include +#include +#include +#include + +namespace LMP_Lepton { + +class ExpressionTreeNode; + +/** + * An Operation represents a single step in the evaluation of an expression, such as a function, + * an operator, or a constant value. Each Operation takes some number of values as arguments + * and produces a single value. + * + * This is an abstract class with subclasses for specific operations. + */ + +class LEPTON_EXPORT Operation { +public: + virtual ~Operation() { + } + /** + * This enumeration lists all Operation subclasses. This is provided so that switch statements + * can be used when processing or analyzing parsed expressions. + */ + enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG, + SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, ATAN2, SINH, COSH, TANH, ERF, ERFC, STEP, DELTA, SQUARE, CUBE, RECIPROCAL, + ADD_CONSTANT, MULTIPLY_CONSTANT, POWER_CONSTANT, MIN, MAX, ABS, FLOOR, CEIL, SELECT}; + /** + * Get the name of this Operation. + */ + virtual std::string getName() const = 0; + /** + * Get this Operation's ID. + */ + virtual Id getId() const = 0; + /** + * Get the number of arguments this operation expects. + */ + virtual int getNumArguments() const = 0; + /** + * Create a clone of this Operation. + */ + virtual Operation* clone() const = 0; + /** + * Perform the computation represented by this operation. + * + * @param args the array of arguments + * @param variables a map containing the values of all variables + * @return the result of performing the computation. + */ + virtual double evaluate(double* args, const std::map& variables) const = 0; + /** + * Return an ExpressionTreeNode which represents the analytic derivative of this Operation with respect to a variable. + * + * @param children the child nodes + * @param childDerivs the derivatives of the child nodes with respect to the variable + * @param variable the variable with respect to which the derivate should be taken + */ + virtual ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const = 0; + /** + * Get whether this operation should be displayed with infix notation. + */ + virtual bool isInfixOperator() const { + return false; + } + /** + * Get whether this is a symmetric binary operation, such that exchanging its arguments + * does not affect the result. + */ + virtual bool isSymmetric() const { + return false; + } + virtual bool operator!=(const Operation& op) const { + return op.getId() != getId(); + } + virtual bool operator==(const Operation& op) const { + return !(*this != op); + } + class Constant; + class Variable; + class Custom; + class Add; + class Subtract; + class Multiply; + class Divide; + class Power; + class Negate; + class Sqrt; + class Exp; + class Log; + class Sin; + class Cos; + class Sec; + class Csc; + class Tan; + class Cot; + class Asin; + class Acos; + class Atan; + class Atan2; + class Sinh; + class Cosh; + class Tanh; + class Erf; + class Erfc; + class Step; + class Delta; + class Square; + class Cube; + class Reciprocal; + class AddConstant; + class MultiplyConstant; + class PowerConstant; + class Min; + class Max; + class Abs; + class Floor; + class Ceil; + class Select; +}; + +class LEPTON_EXPORT Operation::Constant : public Operation { +public: + Constant(double value) : value(value) { + } + std::string getName() const { + std::stringstream name; + name << value; + return name.str(); + } + Id getId() const { + return CONSTANT; + } + int getNumArguments() const { + return 0; + } + Operation* clone() const { + return new Constant(value); + } + double evaluate(double* args, const std::map& variables) const { + return value; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + double getValue() const { + return value; + } + bool operator!=(const Operation& op) const { + const Constant* o = dynamic_cast(&op); + return (o == NULL || o->value != value); + } +private: + double value; +}; + +class LEPTON_EXPORT Operation::Variable : public Operation { +public: + Variable(const std::string& name) : name(name) { + } + std::string getName() const { + return name; + } + Id getId() const { + return VARIABLE; + } + int getNumArguments() const { + return 0; + } + Operation* clone() const { + return new Variable(name); + } + double evaluate(double* args, const std::map& variables) const { + std::map::const_iterator iter = variables.find(name); + if (iter == variables.end()) + throw Exception("No value specified for variable "+name); + return iter->second; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + bool operator!=(const Operation& op) const { + const Variable* o = dynamic_cast(&op); + return (o == NULL || o->name != name); + } +private: + std::string name; +}; + +class LEPTON_EXPORT Operation::Custom : public Operation { +public: + Custom(const std::string& name, CustomFunction* function) : name(name), function(function), isDerivative(false), derivOrder(function->getNumArguments(), 0) { + } + Custom(const std::string& name, CustomFunction* function, const std::vector& derivOrder) : name(name), function(function), isDerivative(false), derivOrder(derivOrder) { + for (int order : derivOrder) + if (order != 0) + isDerivative = true; + } + Custom(const Custom& base, int derivIndex) : name(base.name), function(base.function->clone()), isDerivative(true), derivOrder(base.derivOrder) { + derivOrder[derivIndex]++; + } + ~Custom() { + delete function; + } + std::string getName() const { + return name; + } + Id getId() const { + return CUSTOM; + } + int getNumArguments() const { + return function->getNumArguments(); + } + Operation* clone() const { + Custom* clone = new Custom(name, function->clone()); + clone->isDerivative = isDerivative; + clone->derivOrder = derivOrder; + return clone; + } + double evaluate(double* args, const std::map& variables) const { + if (isDerivative) + return function->evaluateDerivative(args, &derivOrder[0]); + return function->evaluate(args); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + const std::vector& getDerivOrder() const { + return derivOrder; + } + bool operator!=(const Operation& op) const { + const Custom* o = dynamic_cast(&op); + return (o == NULL || o->name != name || o->isDerivative != isDerivative || o->derivOrder != derivOrder); + } +private: + std::string name; + CustomFunction* function; + bool isDerivative; + std::vector derivOrder; +}; + +class LEPTON_EXPORT Operation::Add : public Operation { +public: + Add() { + } + std::string getName() const { + return "+"; + } + Id getId() const { + return ADD; + } + int getNumArguments() const { + return 2; + } + Operation* clone() const { + return new Add(); + } + double evaluate(double* args, const std::map& variables) const { + return args[0]+args[1]; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + bool isInfixOperator() const { + return true; + } + bool isSymmetric() const { + return true; + } +}; + +class LEPTON_EXPORT Operation::Subtract : public Operation { +public: + Subtract() { + } + std::string getName() const { + return "-"; + } + Id getId() const { + return SUBTRACT; + } + int getNumArguments() const { + return 2; + } + Operation* clone() const { + return new Subtract(); + } + double evaluate(double* args, const std::map& variables) const { + return args[0]-args[1]; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + bool isInfixOperator() const { + return true; + } +}; + +class LEPTON_EXPORT Operation::Multiply : public Operation { +public: + Multiply() { + } + std::string getName() const { + return "*"; + } + Id getId() const { + return MULTIPLY; + } + int getNumArguments() const { + return 2; + } + Operation* clone() const { + return new Multiply(); + } + double evaluate(double* args, const std::map& variables) const { + return args[0]*args[1]; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + bool isInfixOperator() const { + return true; + } + bool isSymmetric() const { + return true; + } +}; + +class LEPTON_EXPORT Operation::Divide : public Operation { +public: + Divide() { + } + std::string getName() const { + return "/"; + } + Id getId() const { + return DIVIDE; + } + int getNumArguments() const { + return 2; + } + Operation* clone() const { + return new Divide(); + } + double evaluate(double* args, const std::map& variables) const { + return args[0]/args[1]; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + bool isInfixOperator() const { + return true; + } +}; + +class LEPTON_EXPORT Operation::Power : public Operation { +public: + Power() { + } + std::string getName() const { + return "^"; + } + Id getId() const { + return POWER; + } + int getNumArguments() const { + return 2; + } + Operation* clone() const { + return new Power(); + } + double evaluate(double* args, const std::map& variables) const { + return std::pow(args[0], args[1]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + bool isInfixOperator() const { + return true; + } +}; + +class LEPTON_EXPORT Operation::Negate : public Operation { +public: + Negate() { + } + std::string getName() const { + return "-"; + } + Id getId() const { + return NEGATE; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Negate(); + } + double evaluate(double* args, const std::map& variables) const { + return -args[0]; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Sqrt : public Operation { +public: + Sqrt() { + } + std::string getName() const { + return "sqrt"; + } + Id getId() const { + return SQRT; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Sqrt(); + } + double evaluate(double* args, const std::map& variables) const { + return std::sqrt(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Exp : public Operation { +public: + Exp() { + } + std::string getName() const { + return "exp"; + } + Id getId() const { + return EXP; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Exp(); + } + double evaluate(double* args, const std::map& variables) const { + return std::exp(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Log : public Operation { +public: + Log() { + } + std::string getName() const { + return "log"; + } + Id getId() const { + return LOG; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Log(); + } + double evaluate(double* args, const std::map& variables) const { + return std::log(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Sin : public Operation { +public: + Sin() { + } + std::string getName() const { + return "sin"; + } + Id getId() const { + return SIN; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Sin(); + } + double evaluate(double* args, const std::map& variables) const { + return std::sin(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Cos : public Operation { +public: + Cos() { + } + std::string getName() const { + return "cos"; + } + Id getId() const { + return COS; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Cos(); + } + double evaluate(double* args, const std::map& variables) const { + return std::cos(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Sec : public Operation { +public: + Sec() { + } + std::string getName() const { + return "sec"; + } + Id getId() const { + return SEC; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Sec(); + } + double evaluate(double* args, const std::map& variables) const { + return 1.0/std::cos(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Csc : public Operation { +public: + Csc() { + } + std::string getName() const { + return "csc"; + } + Id getId() const { + return CSC; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Csc(); + } + double evaluate(double* args, const std::map& variables) const { + return 1.0/std::sin(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Tan : public Operation { +public: + Tan() { + } + std::string getName() const { + return "tan"; + } + Id getId() const { + return TAN; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Tan(); + } + double evaluate(double* args, const std::map& variables) const { + return std::tan(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Cot : public Operation { +public: + Cot() { + } + std::string getName() const { + return "cot"; + } + Id getId() const { + return COT; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Cot(); + } + double evaluate(double* args, const std::map& variables) const { + return 1.0/std::tan(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Asin : public Operation { +public: + Asin() { + } + std::string getName() const { + return "asin"; + } + Id getId() const { + return ASIN; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Asin(); + } + double evaluate(double* args, const std::map& variables) const { + return std::asin(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Acos : public Operation { +public: + Acos() { + } + std::string getName() const { + return "acos"; + } + Id getId() const { + return ACOS; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Acos(); + } + double evaluate(double* args, const std::map& variables) const { + return std::acos(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Atan : public Operation { +public: + Atan() { + } + std::string getName() const { + return "atan"; + } + Id getId() const { + return ATAN; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Atan(); + } + double evaluate(double* args, const std::map& variables) const { + return std::atan(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Atan2 : public Operation { +public: + Atan2() { + } + std::string getName() const { + return "atan2"; + } + Id getId() const { + return ATAN2; + } + int getNumArguments() const { + return 2; + } + Operation* clone() const { + return new Atan2(); + } + double evaluate(double* args, const std::map& variables) const { + return std::atan2(args[0], args[1]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Sinh : public Operation { +public: + Sinh() { + } + std::string getName() const { + return "sinh"; + } + Id getId() const { + return SINH; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Sinh(); + } + double evaluate(double* args, const std::map& variables) const { + return std::sinh(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Cosh : public Operation { +public: + Cosh() { + } + std::string getName() const { + return "cosh"; + } + Id getId() const { + return COSH; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Cosh(); + } + double evaluate(double* args, const std::map& variables) const { + return std::cosh(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Tanh : public Operation { +public: + Tanh() { + } + std::string getName() const { + return "tanh"; + } + Id getId() const { + return TANH; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Tanh(); + } + double evaluate(double* args, const std::map& variables) const { + return std::tanh(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Erf : public Operation { +public: + Erf() { + } + std::string getName() const { + return "erf"; + } + Id getId() const { + return ERF; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Erf(); + } + double evaluate(double* args, const std::map& variables) const; + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Erfc : public Operation { +public: + Erfc() { + } + std::string getName() const { + return "erfc"; + } + Id getId() const { + return ERFC; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Erfc(); + } + double evaluate(double* args, const std::map& variables) const; + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Step : public Operation { +public: + Step() { + } + std::string getName() const { + return "step"; + } + Id getId() const { + return STEP; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Step(); + } + double evaluate(double* args, const std::map& variables) const { + return (args[0] >= 0.0 ? 1.0 : 0.0); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Delta : public Operation { +public: + Delta() { + } + std::string getName() const { + return "delta"; + } + Id getId() const { + return DELTA; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Delta(); + } + double evaluate(double* args, const std::map& variables) const { + return (args[0] == 0.0 ? 1.0 : 0.0); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Square : public Operation { +public: + Square() { + } + std::string getName() const { + return "square"; + } + Id getId() const { + return SQUARE; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Square(); + } + double evaluate(double* args, const std::map& variables) const { + return args[0]*args[0]; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Cube : public Operation { +public: + Cube() { + } + std::string getName() const { + return "cube"; + } + Id getId() const { + return CUBE; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Cube(); + } + double evaluate(double* args, const std::map& variables) const { + return args[0]*args[0]*args[0]; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Reciprocal : public Operation { +public: + Reciprocal() { + } + std::string getName() const { + return "recip"; + } + Id getId() const { + return RECIPROCAL; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Reciprocal(); + } + double evaluate(double* args, const std::map& variables) const { + return 1.0/args[0]; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::AddConstant : public Operation { +public: + AddConstant(double value) : value(value) { + } + std::string getName() const { + std::stringstream name; + name << value << "+"; + return name.str(); + } + Id getId() const { + return ADD_CONSTANT; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new AddConstant(value); + } + double evaluate(double* args, const std::map& variables) const { + return args[0]+value; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + double getValue() const { + return value; + } + bool operator!=(const Operation& op) const { + const AddConstant* o = dynamic_cast(&op); + return (o == NULL || o->value != value); + } +private: + double value; +}; + +class LEPTON_EXPORT Operation::MultiplyConstant : public Operation { +public: + MultiplyConstant(double value) : value(value) { + } + std::string getName() const { + std::stringstream name; + name << value << "*"; + return name.str(); + } + Id getId() const { + return MULTIPLY_CONSTANT; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new MultiplyConstant(value); + } + double evaluate(double* args, const std::map& variables) const { + return args[0]*value; + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + double getValue() const { + return value; + } + bool operator!=(const Operation& op) const { + const MultiplyConstant* o = dynamic_cast(&op); + return (o == NULL || o->value != value); + } +private: + double value; +}; + +class LEPTON_EXPORT Operation::PowerConstant : public Operation { +public: + PowerConstant(double value) : value(value) { + intValue = (int) value; + isIntPower = (intValue == value); + } + std::string getName() const { + std::stringstream name; + name << "^" << value; + return name.str(); + } + Id getId() const { + return POWER_CONSTANT; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new PowerConstant(value); + } + double evaluate(double* args, const std::map& variables) const { + if (isIntPower) { + // Integer powers can be computed much more quickly by repeated multiplication. + + int exponent = intValue; + double base = args[0]; + if (exponent < 0) { + exponent = -exponent; + base = 1.0/base; + } + double result = 1.0; + while (exponent != 0) { + if ((exponent&1) == 1) + result *= base; + base *= base; + exponent = exponent>>1; + } + return result; + } + else + return std::pow(args[0], value); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; + double getValue() const { + return value; + } + bool operator!=(const Operation& op) const { + const PowerConstant* o = dynamic_cast(&op); + return (o == NULL || o->value != value); + } + bool isInfixOperator() const { + return true; + } +private: + double value; + int intValue; + bool isIntPower; +}; + +class LEPTON_EXPORT Operation::Min : public Operation { +public: + Min() { + } + std::string getName() const { + return "min"; + } + Id getId() const { + return MIN; + } + int getNumArguments() const { + return 2; + } + Operation* clone() const { + return new Min(); + } + double evaluate(double* args, const std::map& variables) const { + // parens around (std::min) are workaround for horrible microsoft max/min macro trouble + return (std::min)(args[0], args[1]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Max : public Operation { +public: + Max() { + } + std::string getName() const { + return "max"; + } + Id getId() const { + return MAX; + } + int getNumArguments() const { + return 2; + } + Operation* clone() const { + return new Max(); + } + double evaluate(double* args, const std::map& variables) const { + // parens around (std::min) are workaround for horrible microsoft max/min macro trouble + return (std::max)(args[0], args[1]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Abs : public Operation { +public: + Abs() { + } + std::string getName() const { + return "abs"; + } + Id getId() const { + return ABS; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Abs(); + } + double evaluate(double* args, const std::map& variables) const { + return std::abs(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Floor : public Operation { +public: + + Floor() { + } + std::string getName() const { + return "floor"; + } + Id getId() const { + return FLOOR; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Floor(); + } + double evaluate(double* args, const std::map& variables) const { + return std::floor(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Ceil : public Operation { +public: + Ceil() { + } + std::string getName() const { + return "ceil"; + } + Id getId() const { + return CEIL; + } + int getNumArguments() const { + return 1; + } + Operation* clone() const { + return new Ceil(); + } + double evaluate(double* args, const std::map& variables) const { + return std::ceil(args[0]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +class LEPTON_EXPORT Operation::Select : public Operation { +public: + Select() { + } + std::string getName() const { + return "select"; + } + Id getId() const { + return SELECT; + } + int getNumArguments() const { + return 3; + } + Operation* clone() const { + return new Select(); + } + double evaluate(double* args, const std::map& variables) const { + return (args[0] != 0.0 ? args[1] : args[2]); + } + ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_OPERATION_H_*/ diff --git a/lib/lepton/include/lepton/ParsedExpression.h b/lib/lepton/include/lepton/ParsedExpression.h new file mode 100644 index 0000000000..586acb4d2c --- /dev/null +++ b/lib/lepton/include/lepton/ParsedExpression.h @@ -0,0 +1,131 @@ +#ifndef LEPTON_PARSED_EXPRESSION_H_ +#define LEPTON_PARSED_EXPRESSION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009=2013 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "ExpressionTreeNode.h" +#include "windowsIncludes.h" +#include +#include + +namespace LMP_Lepton { + +class CompiledExpression; +class ExpressionProgram; + +/** + * This class represents the result of parsing an expression. It provides methods for working with the + * expression in various ways, such as evaluating it, getting the tree representation of the expresson, etc. + */ + +class LEPTON_EXPORT ParsedExpression { +public: + /** + * Create an uninitialized ParsedExpression. This exists so that ParsedExpressions can be put in STL containers. + * Doing anything with it will produce an exception. + */ + ParsedExpression(); + /** + * Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class + * to parse expression. + */ + ParsedExpression(const ExpressionTreeNode& rootNode); + /** + * Get the root node of the expression's abstract syntax tree. + */ + const ExpressionTreeNode& getRootNode() const; + /** + * Evaluate the expression. If the expression involves any variables, this method will throw an exception. + */ + double evaluate() const; + /** + * Evaluate the expression. + * + * @param variables a map specifying the values of all variables that appear in the expression. If any + * variable appears in the expression but is not included in this map, an exception + * will be thrown. + */ + double evaluate(const std::map& variables) const; + /** + * Create a new ParsedExpression which produces the same result as this one, but is faster to evaluate. + */ + ParsedExpression optimize() const; + /** + * Create a new ParsedExpression which produces the same result as this one, but is faster to evaluate. + * + * @param variables a map specifying values for a subset of variables that appear in the expression. + * All occurrences of these variables in the expression are replaced with the values + * specified. + */ + ParsedExpression optimize(const std::map& variables) const; + /** + * Create a new ParsedExpression which is the analytic derivative of this expression with respect to a + * particular variable. + * + * @param variable the variable with respect to which the derivate should be taken + */ + ParsedExpression differentiate(const std::string& variable) const; + /** + * Create an ExpressionProgram that represents the same calculation as this expression. + */ + ExpressionProgram createProgram() const; + /** + * Create a CompiledExpression that represents the same calculation as this expression. + */ + CompiledExpression createCompiledExpression() const; + /** + * Create a new ParsedExpression which is identical to this one, except that the names of some + * variables have been changed. + * + * @param replacements a map whose keys are the names of variables, and whose values are the + * new names to replace them with + */ + ParsedExpression renameVariables(const std::map& replacements) const; +private: + static double evaluate(const ExpressionTreeNode& node, const std::map& variables); + static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map& variables); + static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); + static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); + static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable); + static bool isConstant(const ExpressionTreeNode& node); + static double getConstantValue(const ExpressionTreeNode& node); + static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map& replacements); + ExpressionTreeNode rootNode; +}; + +LEPTON_EXPORT std::ostream& operator<<(std::ostream& out, const ExpressionTreeNode& node); + +LEPTON_EXPORT std::ostream& operator<<(std::ostream& out, const ParsedExpression& exp); + +} // namespace LMP_Lepton + +#endif /*LEPTON_PARSED_EXPRESSION_H_*/ diff --git a/lib/lepton/include/lepton/Parser.h b/lib/lepton/include/lepton/Parser.h new file mode 100644 index 0000000000..9eefe3f59e --- /dev/null +++ b/lib/lepton/include/lepton/Parser.h @@ -0,0 +1,77 @@ +#ifndef LEPTON_PARSER_H_ +#define LEPTON_PARSER_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "windowsIncludes.h" +#include +#include +#include + +namespace LMP_Lepton { + +class CustomFunction; +class ExpressionTreeNode; +class Operation; +class ParsedExpression; +class ParseToken; + +/** + * This class provides the main interface for parsing expressions. + */ + +class LEPTON_EXPORT Parser { +public: + /** + * Parse a mathematical expression and return a representation of it as an abstract syntax tree. + */ + static ParsedExpression parse(const std::string& expression); + /** + * Parse a mathematical expression and return a representation of it as an abstract syntax tree. + * + * @param customFunctions a map specifying user defined functions that may appear in the expression. + * The key are function names, and the values are corresponding CustomFunction objects. + */ + static ParsedExpression parse(const std::string& expression, const std::map& customFunctions); +private: + static std::string trim(const std::string& expression); + static std::vector tokenize(const std::string& expression); + static ParseToken getNextToken(const std::string& expression, int start); + static ExpressionTreeNode parsePrecedence(const std::vector& tokens, int& pos, const std::map& customFunctions, + const std::map& subexpressionDefs, int precedence); + static Operation* getOperatorOperation(const std::string& name); + static Operation* getFunctionOperation(const std::string& name, const std::map& customFunctions); +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_PARSER_H_*/ diff --git a/lib/lepton/include/lepton/windowsIncludes.h b/lib/lepton/include/lepton/windowsIncludes.h new file mode 100644 index 0000000000..798229850e --- /dev/null +++ b/lib/lepton/include/lepton/windowsIncludes.h @@ -0,0 +1,41 @@ +#ifndef LEPTON_WINDOW_INCLUDE_H_ +#define LEPTON_WINDOW_INCLUDE_H_ + +/* + * Shared libraries are messy in Visual Studio. We have to distinguish three + * cases: + * (1) this header is being used to build the Lepton shared library + * (dllexport) + * (2) this header is being used by a *client* of the Lepton shared + * library (dllimport) + * (3) we are building the Lepton static library, or the client is + * being compiled with the expectation of linking with the + * Lepton static library (nothing special needed) + * In the CMake script for building this library, we define one of the symbols + * Lepton_BUILDING_{SHARED|STATIC}_LIBRARY + * Client code normally has no special symbol defined, in which case we'll + * assume it wants to use the shared library. However, if the client defines + * the symbol LEPTON_USE_STATIC_LIBRARIES we'll suppress the dllimport so + * that the client code can be linked with static libraries. Note that + * the client symbol is not library dependent, while the library symbols + * affect only the Lepton library, meaning that other libraries can + * be clients of this one. However, we are assuming all-static or all-shared. + */ + +#ifdef _MSC_VER + // We don't want to hear about how sprintf is "unsafe". + #pragma warning(disable:4996) + // Keep MS VC++ quiet about lack of dll export of private members. + #pragma warning(disable:4251) + #if defined(LEPTON_BUILDING_SHARED_LIBRARY) + #define LEPTON_EXPORT __declspec(dllexport) + #elif defined(LEPTON_BUILDING_STATIC_LIBRARY) || defined(LEPTON_USE_STATIC_LIBRARIES) + #define LEPTON_EXPORT + #else + #define LEPTON_EXPORT __declspec(dllimport) // i.e., a client of a shared library + #endif +#else + #define LEPTON_EXPORT // Linux, Mac +#endif + +#endif // LEPTON_WINDOW_INCLUDE_H_ diff --git a/lib/lepton/src/CompiledExpression.cpp b/lib/lepton/src/CompiledExpression.cpp new file mode 100644 index 0000000000..7805c4674d --- /dev/null +++ b/lib/lepton/src/CompiledExpression.cpp @@ -0,0 +1,418 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/CompiledExpression.h" +#include "lepton/Operation.h" +#include "lepton/ParsedExpression.h" +#include + +using namespace LMP_Lepton; +using namespace std; +#ifdef LEPTON_USE_JIT + using namespace asmjit; +#endif + +CompiledExpression::CompiledExpression() : jitCode(NULL) { +} + +CompiledExpression::CompiledExpression(const ParsedExpression& expression) : jitCode(NULL) { + ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized. + vector > temps; + compileExpression(expr.getRootNode(), temps); + int maxArguments = 1; + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i]->getNumArguments() > maxArguments) + maxArguments = operation[i]->getNumArguments(); + argValues.resize(maxArguments); +#ifdef LEPTON_USE_JIT + generateJitCode(); +#endif +} + +CompiledExpression::~CompiledExpression() { + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i] != NULL) + delete operation[i]; +} + +CompiledExpression::CompiledExpression(const CompiledExpression& expression) : jitCode(NULL) { + *this = expression; +} + +CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expression) { + arguments = expression.arguments; + target = expression.target; + variableIndices = expression.variableIndices; + variableNames = expression.variableNames; + workspace.resize(expression.workspace.size()); + argValues.resize(expression.argValues.size()); + operation.resize(expression.operation.size()); + for (int i = 0; i < (int) operation.size(); i++) + operation[i] = expression.operation[i]->clone(); + setVariableLocations(variablePointers); + return *this; +} + +void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vector >& temps) { + if (findTempIndex(node, temps) != -1) + return; // We have already processed a node identical to this one. + + // Process the child nodes. + + vector args; + for (int i = 0; i < node.getChildren().size(); i++) { + compileExpression(node.getChildren()[i], temps); + args.push_back(findTempIndex(node.getChildren()[i], temps)); + } + + // Process this node. + + if (node.getOperation().getId() == Operation::VARIABLE) { + variableIndices[node.getOperation().getName()] = (int) workspace.size(); + variableNames.insert(node.getOperation().getName()); + } + else { + int stepIndex = (int) arguments.size(); + arguments.push_back(vector()); + target.push_back((int) workspace.size()); + operation.push_back(node.getOperation().clone()); + if (args.size() == 0) + arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there. + else { + // If the arguments are sequential, we can just pass a pointer to the first one. + + bool sequential = true; + for (int i = 1; i < args.size(); i++) + if (args[i] != args[i-1]+1) + sequential = false; + if (sequential) + arguments[stepIndex].push_back(args[0]); + else + arguments[stepIndex] = args; + } + } + temps.push_back(make_pair(node, (int) workspace.size())); + workspace.push_back(0.0); +} + +int CompiledExpression::findTempIndex(const ExpressionTreeNode& node, vector >& temps) { + for (int i = 0; i < (int) temps.size(); i++) + if (temps[i].first == node) + return i; + return -1; +} + +const set& CompiledExpression::getVariables() const { + return variableNames; +} + +double& CompiledExpression::getVariableReference(const string& name) { + map::iterator pointer = variablePointers.find(name); + if (pointer != variablePointers.end()) + return *pointer->second; + map::iterator index = variableIndices.find(name); + if (index == variableIndices.end()) + throw Exception("getVariableReference: Unknown variable '"+name+"'"); + return workspace[index->second]; +} + +void CompiledExpression::setVariableLocations(map& variableLocations) { + variablePointers = variableLocations; +#ifdef LEPTON_USE_JIT + // Rebuild the JIT code. + + if (workspace.size() > 0) + generateJitCode(); +#else + // Make a list of all variables we will need to copy before evaluating the expression. + + variablesToCopy.clear(); + for (map::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) { + map::iterator pointer = variablePointers.find(iter->first); + if (pointer != variablePointers.end()) + variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second)); + } +#endif +} + +double CompiledExpression::evaluate() const { +#ifdef LEPTON_USE_JIT + return jitCode(); +#else + for (int i = 0; i < variablesToCopy.size(); i++) + *variablesToCopy[i].first = *variablesToCopy[i].second; + + // Loop over the operations and evaluate each one. + + for (int step = 0; step < operation.size(); step++) { + const vector& args = arguments[step]; + if (args.size() == 1) + workspace[target[step]] = operation[step]->evaluate(&workspace[args[0]], dummyVariables); + else { + for (int i = 0; i < args.size(); i++) + argValues[i] = workspace[args[i]]; + workspace[target[step]] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } + return workspace[workspace.size()-1]; +#endif +} + +#ifdef LEPTON_USE_JIT +static double evaluateOperation(Operation* op, double* args) { + static map dummyVariables; + return op->evaluate(args, dummyVariables); +} + +void CompiledExpression::generateJitCode() { + CodeHolder code; + code.init(runtime.getCodeInfo()); + X86Compiler c(&code); + c.addFunc(FuncSignature0()); + vector workspaceVar(workspace.size()); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newXmmSd(); + X86Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm_ptr(&argValues[0])); + + // Load the arguments into variables. + + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + X86Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm_ptr(&getVariableReference(index->first))); + c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + X86Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm_ptr(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newXmmSd(); + c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + } + } + + // Evaluate the operations. + + for (int step = 0; step < (int) operation.size(); step++) { + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0]+i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); + break; + case Operation::NEGATE: + c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); + c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); + break; + case Operation::STEP: + c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); + c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); + c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 + c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::MULTIPLY_CONSTANT: + c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ABS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs); + break; + case Operation::FLOOR: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor); + break; + case Operation::CEIL: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil); + break; + default: + // Just invoke evaluateOperation(). + + for (int i = 0; i < (int) args.size(); i++) + c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); + X86Gp fn = c.newIntPtr(); + c.mov(fn, imm_ptr((void*) evaluateOperation)); + CCFuncCall* call = c.call(fn, FuncSignature2()); + call->setArg(0, imm_ptr(&op)); + call->setArg(1, imm_ptr(&argValues[0])); + call->setRet(0, workspaceVar[target[step]]); + } + } + c.ret(workspaceVar[workspace.size()-1]); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) { + X86Gp fn = c.newIntPtr(); + c.mov(fn, imm_ptr((void*) function)); + CCFuncCall* call = c.call(fn, FuncSignature1()); + call->setArg(0, arg); + call->setRet(0, dest); +} + +void CompiledExpression::generateTwoArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg1, X86Xmm& arg2, double (*function)(double, double)) { + X86Gp fn = c.newIntPtr(); + c.mov(fn, imm_ptr((void*) function)); + CCFuncCall* call = c.call(fn, FuncSignature2()); + call->setArg(0, arg1); + call->setArg(1, arg2); + call->setRet(0, dest); +} +#endif diff --git a/lib/lepton/src/ExpressionProgram.cpp b/lib/lepton/src/ExpressionProgram.cpp new file mode 100644 index 0000000000..74c545287b --- /dev/null +++ b/lib/lepton/src/ExpressionProgram.cpp @@ -0,0 +1,110 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009-2018 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/ExpressionProgram.h" +#include "lepton/Operation.h" +#include "lepton/ParsedExpression.h" + +using namespace LMP_Lepton; +using namespace std; + +ExpressionProgram::ExpressionProgram() : maxArgs(0), stackSize(0) { +} + +ExpressionProgram::ExpressionProgram(const ParsedExpression& expression) : maxArgs(0), stackSize(0) { + buildProgram(expression.getRootNode()); + int currentStackSize = 0; + for (int i = 0; i < (int) operations.size(); i++) { + int args = operations[i]->getNumArguments(); + if (args > maxArgs) + maxArgs = args; + currentStackSize += 1-args; + if (currentStackSize > stackSize) + stackSize = currentStackSize; + } +} + +ExpressionProgram::~ExpressionProgram() { + for (int i = 0; i < (int) operations.size(); i++) + delete operations[i]; +} + +ExpressionProgram::ExpressionProgram(const ExpressionProgram& program) { + *this = program; +} + +ExpressionProgram& ExpressionProgram::operator=(const ExpressionProgram& program) { + maxArgs = program.maxArgs; + stackSize = program.stackSize; + operations.resize(program.operations.size()); + for (int i = 0; i < (int) operations.size(); i++) + operations[i] = program.operations[i]->clone(); + return *this; +} + +void ExpressionProgram::buildProgram(const ExpressionTreeNode& node) { + for (int i = (int) node.getChildren().size()-1; i >= 0; i--) + buildProgram(node.getChildren()[i]); + operations.push_back(node.getOperation().clone()); +} + +int ExpressionProgram::getNumOperations() const { + return (int) operations.size(); +} + +const Operation& ExpressionProgram::getOperation(int index) const { + return *operations[index]; +} + +void ExpressionProgram::setOperation(int index, Operation* operation) { + delete operations[index]; + operations[index] = operation; +} + +int ExpressionProgram::getStackSize() const { + return stackSize; +} + +double ExpressionProgram::evaluate() const { + return evaluate(map()); +} + +double ExpressionProgram::evaluate(const std::map& variables) const { + vector stack(stackSize+1); + int stackPointer = stackSize; + for (int i = 0; i < (int) operations.size(); i++) { + int numArgs = operations[i]->getNumArguments(); + double result = operations[i]->evaluate(&stack[stackPointer], variables); + stackPointer += numArgs-1; + stack[stackPointer] = result; + } + return stack[stackSize-1]; +} diff --git a/lib/lepton/src/ExpressionTreeNode.cpp b/lib/lepton/src/ExpressionTreeNode.cpp new file mode 100644 index 0000000000..e4fbbc6f50 --- /dev/null +++ b/lib/lepton/src/ExpressionTreeNode.cpp @@ -0,0 +1,107 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009-2015 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/ExpressionTreeNode.h" +#include "lepton/Exception.h" +#include "lepton/Operation.h" + +using namespace LMP_Lepton; +using namespace std; + +ExpressionTreeNode::ExpressionTreeNode(Operation* operation, const vector& children) : operation(operation), children(children) { + if (operation->getNumArguments() != children.size()) + throw Exception("wrong number of arguments to function: "+operation->getName()); +} + +ExpressionTreeNode::ExpressionTreeNode(Operation* operation, const ExpressionTreeNode& child1, const ExpressionTreeNode& child2) : operation(operation) { + children.push_back(child1); + children.push_back(child2); + if (operation->getNumArguments() != children.size()) + throw Exception("wrong number of arguments to function: "+operation->getName()); +} + +ExpressionTreeNode::ExpressionTreeNode(Operation* operation, const ExpressionTreeNode& child) : operation(operation) { + children.push_back(child); + if (operation->getNumArguments() != children.size()) + throw Exception("wrong number of arguments to function: "+operation->getName()); +} + +ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operation) { + if (operation->getNumArguments() != children.size()) + throw Exception("wrong number of arguments to function: "+operation->getName()); +} + +ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? NULL : node.operation->clone()), children(node.getChildren()) { +} + +ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) { +} + +ExpressionTreeNode::~ExpressionTreeNode() { + if (operation != NULL) + delete operation; +} + +bool ExpressionTreeNode::operator!=(const ExpressionTreeNode& node) const { + if (node.getOperation() != getOperation()) + return true; + if (getOperation().isSymmetric() && getChildren().size() == 2) { + if (getChildren()[0] == node.getChildren()[0] && getChildren()[1] == node.getChildren()[1]) + return false; + if (getChildren()[0] == node.getChildren()[1] && getChildren()[1] == node.getChildren()[0]) + return false; + return true; + } + for (int i = 0; i < (int) getChildren().size(); i++) + if (getChildren()[i] != node.getChildren()[i]) + return true; + return false; +} + +bool ExpressionTreeNode::operator==(const ExpressionTreeNode& node) const { + return !(*this != node); +} + +ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node) { + if (operation != NULL) + delete operation; + operation = node.getOperation().clone(); + children = node.getChildren(); + return *this; +} + +const Operation& ExpressionTreeNode::getOperation() const { + return *operation; +} + +const vector& ExpressionTreeNode::getChildren() const { + return children; +} diff --git a/lib/lepton/src/MSVC_erfc.h b/lib/lepton/src/MSVC_erfc.h new file mode 100644 index 0000000000..2c6b619e89 --- /dev/null +++ b/lib/lepton/src/MSVC_erfc.h @@ -0,0 +1,91 @@ +#ifndef LEPTON_MSVC_ERFC_H_ +#define LEPTON_MSVC_ERFC_H_ + +/* + * Up to version 11 (VC++ 2012), Microsoft does not support the + * standard C99 erf() and erfc() functions so we have to fake them here. + * These were added in version 12 (VC++ 2013), which sets _MSC_VER=1800 + * (VC11 has _MSC_VER=1700). + */ + +#if defined(_MSC_VER) || defined(__MINGW32__) +#if !defined(M_PI) +#define M_PI 3.14159265358979323846264338327950288 +#endif +#endif + +#if defined(_MSC_VER) +#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12 +/*************************** +* erf.cpp +* author: Steve Strand +* written: 29-Jan-04 +***************************/ + +#include + +static const double rel_error= 1E-12; //calculate 12 significant figures +//you can adjust rel_error to trade off between accuracy and speed +//but don't ask for > 15 figures (assuming usual 52 bit mantissa in a double) + +static double erfc(double x); + +static double erf(double x) +//erf(x) = 2/sqrt(pi)*integral(exp(-t^2),t,0,x) +// = 2/sqrt(pi)*[x - x^3/3 + x^5/5*2! - x^7/7*3! + ...] +// = 1-erfc(x) +{ + static const double two_sqrtpi= 1.128379167095512574; // 2/sqrt(pi) + if (fabs(x) > 2.2) { + return 1.0 - erfc(x); //use continued fraction when fabs(x) > 2.2 + } + double sum= x, term= x, xsqr= x*x; + int j= 1; + do { + term*= xsqr/j; + sum-= term/(2*j+1); + ++j; + term*= xsqr/j; + sum+= term/(2*j+1); + ++j; + } while (fabs(term)/sum > rel_error); + return two_sqrtpi*sum; +} + + +static double erfc(double x) +//erfc(x) = 2/sqrt(pi)*integral(exp(-t^2),t,x,inf) +// = exp(-x^2)/sqrt(pi) * [1/x+ (1/2)/x+ (2/2)/x+ (3/2)/x+ (4/2)/x+ ...] +// = 1-erf(x) +//expression inside [] is a continued fraction so '+' means add to denominator only +{ + static const double one_sqrtpi= 0.564189583547756287; // 1/sqrt(pi) + if (fabs(x) < 2.2) { + return 1.0 - erf(x); //use series when fabs(x) < 2.2 + } + // Don't look for x==0 here! + if (x < 0) { //continued fraction only valid for x>0 + return 2.0 - erfc(-x); + } + double a=1, b=x; //last two convergent numerators + double c=x, d=x*x+0.5; //last two convergent denominators + double q1, q2= b/d; //last two convergents (a/c and b/d) + double n= 1.0, t; + do { + t= a*n+b*x; + a= b; + b= t; + t= c*n+d*x; + c= d; + d= t; + n+= 0.5; + q1= q2; + q2= b/d; + } while (fabs(q1-q2)/q2 > rel_error); + return one_sqrtpi*exp(-x*x)*q2; +} + +#endif // _MSC_VER <= 1700 +#endif // _MSC_VER + +#endif // LEPTON_MSVC_ERFC_H_ diff --git a/lib/lepton/src/Operation.cpp b/lib/lepton/src/Operation.cpp new file mode 100644 index 0000000000..512f5db321 --- /dev/null +++ b/lib/lepton/src/Operation.cpp @@ -0,0 +1,345 @@ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009-2019 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/Operation.h" +#include "lepton/ExpressionTreeNode.h" +#include "MSVC_erfc.h" + +using namespace LMP_Lepton; +using namespace std; + +double Operation::Erf::evaluate(double* args, const map& variables) const { + return erf(args[0]); +} + +double Operation::Erfc::evaluate(double* args, const map& variables) const { + return erfc(args[0]); +} + +ExpressionTreeNode Operation::Constant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Constant(0.0)); +} + +ExpressionTreeNode Operation::Variable::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (variable == name) + return ExpressionTreeNode(new Operation::Constant(1.0)); + return ExpressionTreeNode(new Operation::Constant(0.0)); +} + +ExpressionTreeNode Operation::Custom::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + if (function->getNumArguments() == 0) + return ExpressionTreeNode(new Operation::Constant(0.0)); + ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]); + for (int i = 1; i < getNumArguments(); i++) { + result = ExpressionTreeNode(new Operation::Add(), + result, + ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); + } + return result; +} + +ExpressionTreeNode Operation::Add::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]); +} + +ExpressionTreeNode Operation::Subtract::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]); +} + +ExpressionTreeNode Operation::Multiply::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Add(), + ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]), + ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0])); +} + +ExpressionTreeNode Operation::Divide::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Divide(), + ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), + ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])), + ExpressionTreeNode(new Operation::Square(), children[1])); +} + +ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Add(), + ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Multiply(), + children[1], + ExpressionTreeNode(new Operation::Power(), + children[0], ExpressionTreeNode(new Operation::AddConstant(-1.0), children[1]))), + childDerivs[0]), + ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Log(), children[0]), + ExpressionTreeNode(new Operation::Power(), children[0], children[1])), + childDerivs[1])); +} + +ExpressionTreeNode Operation::Negate::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]); +} + +ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::MultiplyConstant(0.5), + ExpressionTreeNode(new Operation::Reciprocal(), + ExpressionTreeNode(new Operation::Sqrt(), children[0]))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Exp::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Exp(), children[0]), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Log::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Reciprocal(), children[0]), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Sin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Cos(), children[0]), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Negate(), + ExpressionTreeNode(new Operation::Sin(), children[0])), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Sec::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Sec(), children[0]), + ExpressionTreeNode(new Operation::Tan(), children[0])), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Csc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Negate(), + ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Csc(), children[0]), + ExpressionTreeNode(new Operation::Cot(), children[0]))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Tan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Square(), + ExpressionTreeNode(new Operation::Sec(), children[0])), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Cot::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Negate(), + ExpressionTreeNode(new Operation::Square(), + ExpressionTreeNode(new Operation::Csc(), children[0]))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Asin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Reciprocal(), + ExpressionTreeNode(new Operation::Sqrt(), + ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode(new Operation::Constant(1.0)), + ExpressionTreeNode(new Operation::Square(), children[0])))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Acos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Negate(), + ExpressionTreeNode(new Operation::Reciprocal(), + ExpressionTreeNode(new Operation::Sqrt(), + ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode(new Operation::Constant(1.0)), + ExpressionTreeNode(new Operation::Square(), children[0]))))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Atan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Reciprocal(), + ExpressionTreeNode(new Operation::AddConstant(1.0), + ExpressionTreeNode(new Operation::Square(), children[0]))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Atan2::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Divide(), + ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), + ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])), + ExpressionTreeNode(new Operation::Add(), + ExpressionTreeNode(new Operation::Square(), children[0]), + ExpressionTreeNode(new Operation::Square(), children[1]))); +} + +ExpressionTreeNode Operation::Sinh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Cosh(), + children[0]), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Cosh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Sinh(), + children[0]), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Tanh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode(new Operation::Constant(1.0)), + ExpressionTreeNode(new Operation::Square(), + ExpressionTreeNode(new Operation::Tanh(), children[0]))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Erf::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))), + ExpressionTreeNode(new Operation::Exp(), + ExpressionTreeNode(new Operation::Negate(), + ExpressionTreeNode(new Operation::Square(), children[0])))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Erfc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))), + ExpressionTreeNode(new Operation::Exp(), + ExpressionTreeNode(new Operation::Negate(), + ExpressionTreeNode(new Operation::Square(), children[0])))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Step::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Constant(0.0)); +} + +ExpressionTreeNode Operation::Delta::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Constant(0.0)); +} + +ExpressionTreeNode Operation::Square::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::MultiplyConstant(2.0), + children[0]), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Cube::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::MultiplyConstant(3.0), + ExpressionTreeNode(new Operation::Square(), children[0])), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::Negate(), + ExpressionTreeNode(new Operation::Reciprocal(), + ExpressionTreeNode(new Operation::Square(), children[0]))), + childDerivs[0]); +} + +ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return childDerivs[0]; +} + +ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::MultiplyConstant(value), + childDerivs[0]); +} + +ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Multiply(), + ExpressionTreeNode(new Operation::MultiplyConstant(value), + ExpressionTreeNode(new Operation::PowerConstant(value-1), + children[0])), + childDerivs[0]); +} + +ExpressionTreeNode Operation::Min::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + ExpressionTreeNode step(new Operation::Step(), + ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); + return ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step), + ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], + ExpressionTreeNode(new Operation::AddConstant(-1), step))); +} + +ExpressionTreeNode Operation::Max::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + ExpressionTreeNode step(new Operation::Step(), + ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); + return ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step), + ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], + ExpressionTreeNode(new Operation::AddConstant(-1), step))); +} + +ExpressionTreeNode Operation::Abs::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + ExpressionTreeNode step(new Operation::Step(), children[0]); + return ExpressionTreeNode(new Operation::Multiply(), + childDerivs[0], + ExpressionTreeNode(new Operation::AddConstant(-1), + ExpressionTreeNode(new Operation::MultiplyConstant(2), step))); +} + +ExpressionTreeNode Operation::Floor::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Constant(0.0)); +} + +ExpressionTreeNode Operation::Ceil::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + return ExpressionTreeNode(new Operation::Constant(0.0)); +} + +ExpressionTreeNode Operation::Select::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { + vector derivChildren; + derivChildren.push_back(children[0]); + derivChildren.push_back(childDerivs[1]); + derivChildren.push_back(childDerivs[2]); + return ExpressionTreeNode(new Operation::Select(), derivChildren); +} diff --git a/lib/lepton/src/ParsedExpression.cpp b/lib/lepton/src/ParsedExpression.cpp new file mode 100644 index 0000000000..13ebbf2dc2 --- /dev/null +++ b/lib/lepton/src/ParsedExpression.cpp @@ -0,0 +1,379 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/ParsedExpression.h" +#include "lepton/CompiledExpression.h" +#include "lepton/ExpressionProgram.h" +#include "lepton/Operation.h" +#include +#include + +using namespace LMP_Lepton; +using namespace std; + +ParsedExpression::ParsedExpression() : rootNode(ExpressionTreeNode()) { +} + +ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) { +} + +const ExpressionTreeNode& ParsedExpression::getRootNode() const { + if (&rootNode.getOperation() == NULL) + throw Exception("Illegal call to an initialized ParsedExpression"); + return rootNode; +} + +double ParsedExpression::evaluate() const { + return evaluate(getRootNode(), map()); +} + +double ParsedExpression::evaluate(const map& variables) const { + return evaluate(getRootNode(), variables); +} + +double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map& variables) { + int numArgs = (int) node.getChildren().size(); + vector args(max(numArgs, 1)); + for (int i = 0; i < numArgs; i++) + args[i] = evaluate(node.getChildren()[i], variables); + return node.getOperation().evaluate(&args[0], variables); +} + +ParsedExpression ParsedExpression::optimize() const { + ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode()); + while (true) { + ExpressionTreeNode simplified = substituteSimplerExpression(result); + if (simplified == result) + break; + result = simplified; + } + return ParsedExpression(result); +} + +ParsedExpression ParsedExpression::optimize(const map& variables) const { + ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables); + result = precalculateConstantSubexpressions(result); + while (true) { + ExpressionTreeNode simplified = substituteSimplerExpression(result); + if (simplified == result) + break; + result = simplified; + } + return ParsedExpression(result); +} + +ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNode& node, const map& variables) { + if (node.getOperation().getId() == Operation::VARIABLE) { + const Operation::Variable& var = dynamic_cast(node.getOperation()); + map::const_iterator iter = variables.find(var.getName()); + if (iter == variables.end()) + return node; + return ExpressionTreeNode(new Operation::Constant(iter->second)); + } + vector children(node.getChildren().size()); + for (int i = 0; i < (int) children.size(); i++) + children[i] = preevaluateVariables(node.getChildren()[i], variables); + return ExpressionTreeNode(node.getOperation().clone(), children); +} + +ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) { + vector children(node.getChildren().size()); + for (int i = 0; i < (int) children.size(); i++) + children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); + ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children); + if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) + return result; + for (int i = 0; i < (int) children.size(); i++) + if (children[i].getOperation().getId() != Operation::CONSTANT) + return result; + return ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); +} + +ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) { + vector children(node.getChildren().size()); + for (int i = 0; i < (int) children.size(); i++) + children[i] = substituteSimplerExpression(node.getChildren()[i]); + + // Collect some info on constant expressions in children + bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant? + bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant? + double first, second; // if yes, value of first and second child + if (first_const) + first = getConstantValue(children[0]); + if (second_const) + second = getConstantValue(children[1]); + + switch (node.getOperation().getId()) { + case Operation::ADD: + { + if (first_const) { + if (first == 0.0) { // Add 0 + return children[1]; + } else { // Add a constant + return ExpressionTreeNode(new Operation::AddConstant(first), children[1]); + } + } + if (second_const) { + if (second == 0.0) { // Add 0 + return children[0]; + } else { // Add a constant + return ExpressionTreeNode(new Operation::AddConstant(second), children[0]); + } + } + if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b + return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]); + if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a + return ExpressionTreeNode(new Operation::Subtract(), children[1], children[0].getChildren()[0]); + break; + } + case Operation::SUBTRACT: + { + if (children[0] == children[1]) + return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0 + if (first_const) { + if (first == 0.0) // Subtract from 0 + return ExpressionTreeNode(new Operation::Negate(), children[1]); + } + if (second_const) { + if (second == 0.0) { // Subtract 0 + return children[0]; + } else { // Subtract a constant + return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]); + } + } + if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b + return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]); + break; + } + case Operation::MULTIPLY: + { + if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0 + return ExpressionTreeNode(new Operation::Constant(0.0)); + if (first_const && first == 1.0) // Multiply by 1 + return children[1]; + if (second_const && second == 1.0) // Multiply by 1 + return children[0]; + if (first_const) { // Multiply by a constant + if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one + return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast(&children[1].getOperation())->getValue()), children[1].getChildren()[0]); + return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]); + } + if (second_const) { // Multiply by a constant + if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one + return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); + return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]); + } + if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel + return ExpressionTreeNode(new Operation::Multiply(), children[0].getChildren()[0], children[1].getChildren()[0]); + if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Negate the constant + return ExpressionTreeNode(new Operation::Multiply(), children[0].getChildren()[0], ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&children[1].getOperation())->getValue()), children[1].getChildren()[0])); + if (children[1].getOperation().getId() == Operation::NEGATE && children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Negate the constant + return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]), children[1].getChildren()[0]); + if (children[0].getOperation().getId() == Operation::NEGATE) // Pull the negation out so it can possibly be optimized further + return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0].getChildren()[0], children[1])); + if (children[1].getOperation().getId() == Operation::NEGATE) // Pull the negation out so it can possibly be optimized further + return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], children[1].getChildren()[0])); + if (children[1].getOperation().getId() == Operation::RECIPROCAL) // a*(1/b) = a/b + return ExpressionTreeNode(new Operation::Divide(), children[0], children[1].getChildren()[0]); + if (children[0].getOperation().getId() == Operation::RECIPROCAL) // (1/a)*b = b/a + return ExpressionTreeNode(new Operation::Divide(), children[1], children[0].getChildren()[0]); + if (children[0] == children[1]) + return ExpressionTreeNode(new Operation::Square(), children[0]); // x*x = square(x) + if (children[0].getOperation().getId() == Operation::SQUARE && children[0].getChildren()[0] == children[1]) + return ExpressionTreeNode(new Operation::Cube(), children[1]); // x*x*x = cube(x) + if (children[1].getOperation().getId() == Operation::SQUARE && children[1].getChildren()[0] == children[0]) + return ExpressionTreeNode(new Operation::Cube(), children[0]); // x*x*x = cube(x) + break; + } + case Operation::DIVIDE: + { + if (children[0] == children[1]) + return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0 + if (first_const && first == 0.0) // 0 divided by something + return ExpressionTreeNode(new Operation::Constant(0.0)); + if (first_const && first == 1.0) // 1 divided by something + return ExpressionTreeNode(new Operation::Reciprocal(), children[1]); + if (second_const && second == 1.0) // Divide by 1 + return children[0]; + if (second_const) { + if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply + return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&children[0].getOperation())->getValue()/second), children[0].getChildren()[0]); + return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/second), children[0]); // Replace a divide with a multiply + } + if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel + return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]); + if (children[1].getOperation().getId() == Operation::NEGATE && children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Negate the constant + return ExpressionTreeNode(new Operation::Divide(), ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]), children[1].getChildren()[0]); + if (children[0].getOperation().getId() == Operation::NEGATE) // Pull the negation out so it can possibly be optimized further + return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1])); + if (children[1].getOperation().getId() == Operation::NEGATE) // Pull the negation out so it can possibly be optimized further + return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Divide(), children[0], children[1].getChildren()[0])); + if (children[1].getOperation().getId() == Operation::RECIPROCAL) // a/(1/b) = a*b + return ExpressionTreeNode(new Operation::Multiply(), children[0], children[1].getChildren()[0]); + break; + } + case Operation::POWER: + { + if (first_const && first == 0.0) // 0 to any power is 0 + return ExpressionTreeNode(new Operation::Constant(0.0)); + if (first_const && first == 1.0) // 1 to any power is 1 + return ExpressionTreeNode(new Operation::Constant(1.0)); + if (second_const) { // Constant exponent + if (second == 0.0) // x^0 = 1 + return ExpressionTreeNode(new Operation::Constant(1.0)); + if (second == 1.0) // x^1 = x + return children[0]; + if (second == -1.0) // x^-1 = recip(x) + return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); + if (second == 2.0) // x^2 = square(x) + return ExpressionTreeNode(new Operation::Square(), children[0]); + if (second == 3.0) // x^3 = cube(x) + return ExpressionTreeNode(new Operation::Cube(), children[0]); + if (second == 0.5) // x^0.5 = sqrt(x) + return ExpressionTreeNode(new Operation::Sqrt(), children[0]); + // Constant power + return ExpressionTreeNode(new Operation::PowerConstant(second), children[0]); + } + break; + } + case Operation::NEGATE: + { + if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply + return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); + if (first_const) // Negate a constant + return ExpressionTreeNode(new Operation::Constant(-first)); + if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel + return children[0].getChildren()[0]; + break; + } + case Operation::MULTIPLY_CONSTANT: + { + if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one + return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&node.getOperation())->getValue()*dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); + if (first_const) // Multiply two constants + return ExpressionTreeNode(new Operation::Constant(dynamic_cast(&node.getOperation())->getValue()*getConstantValue(children[0]))); + if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply + return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&node.getOperation())->getValue()), children[0].getChildren()[0]); + break; + } + case Operation::SQRT: + { + if (children[0].getOperation().getId() == Operation::SQUARE) // sqrt(square(x)) = abs(x) + return ExpressionTreeNode(new Operation::Abs(), children[0].getChildren()[0]); + } + case Operation::SQUARE: + { + if (children[0].getOperation().getId() == Operation::SQRT) // square(sqrt(x)) = x + return children[0].getChildren()[0]; + } + default: + { + // If operation ID is not one of the above, + // we don't substitute a simpler expression. + break; + } + + } + return ExpressionTreeNode(node.getOperation().clone(), children); +} + +ParsedExpression ParsedExpression::differentiate(const string& variable) const { + return differentiate(getRootNode(), variable); +} + +ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) { + vector childDerivs(node.getChildren().size()); + for (int i = 0; i < (int) childDerivs.size(); i++) + childDerivs[i] = differentiate(node.getChildren()[i], variable); + return node.getOperation().differentiate(node.getChildren(),childDerivs, variable); +} + +bool ParsedExpression::isConstant(const ExpressionTreeNode& node) { + return (node.getOperation().getId() == Operation::CONSTANT); +} + +double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) { + if (node.getOperation().getId() != Operation::CONSTANT) { + throw Exception("getConstantValue called on a non-constant ExpressionNode"); + } + return dynamic_cast(node.getOperation()).getValue(); +} + +ExpressionProgram ParsedExpression::createProgram() const { + return ExpressionProgram(*this); +} + +CompiledExpression ParsedExpression::createCompiledExpression() const { + return CompiledExpression(*this); +} + +ParsedExpression ParsedExpression::renameVariables(const map& replacements) const { + return ParsedExpression(renameNodeVariables(getRootNode(), replacements)); +} + +ExpressionTreeNode ParsedExpression::renameNodeVariables(const ExpressionTreeNode& node, const map& replacements) { + if (node.getOperation().getId() == Operation::VARIABLE) { + map::const_iterator replace = replacements.find(node.getOperation().getName()); + if (replace != replacements.end()) + return ExpressionTreeNode(new Operation::Variable(replace->second)); + } + vector children; + for (int i = 0; i < (int) node.getChildren().size(); i++) + children.push_back(renameNodeVariables(node.getChildren()[i], replacements)); + return ExpressionTreeNode(node.getOperation().clone(), children); +} + +ostream& LMP_Lepton::operator<<(ostream& out, const ExpressionTreeNode& node) { + if (node.getOperation().isInfixOperator() && node.getChildren().size() == 2) { + out << "(" << node.getChildren()[0] << ")" << node.getOperation().getName() << "(" << node.getChildren()[1] << ")"; + } + else if (node.getOperation().isInfixOperator() && node.getChildren().size() == 1) { + out << "(" << node.getChildren()[0] << ")" << node.getOperation().getName(); + } + else { + out << node.getOperation().getName(); + if (node.getChildren().size() > 0) { + out << "("; + for (int i = 0; i < (int) node.getChildren().size(); i++) { + if (i > 0) + out << ", "; + out << node.getChildren()[i]; + } + out << ")"; + } + } + return out; +} + +ostream& LMP_Lepton::operator<<(ostream& out, const ParsedExpression& exp) { + out << exp.getRootNode(); + return out; +} diff --git a/lib/lepton/src/Parser.cpp b/lib/lepton/src/Parser.cpp new file mode 100644 index 0000000000..c0b4c185e8 --- /dev/null +++ b/lib/lepton/src/Parser.cpp @@ -0,0 +1,409 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2009-2019 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/Parser.h" +#include "lepton/CustomFunction.h" +#include "lepton/Exception.h" +#include "lepton/ExpressionTreeNode.h" +#include "lepton/Operation.h" +#include "lepton/ParsedExpression.h" +#include +#include + +using namespace LMP_Lepton; +using namespace std; + +static const string Digits = "0123456789"; +static const string Operators = "+-*/^"; +static const bool LeftAssociative[] = {true, true, true, true, false}; +static const int Precedence[] = {0, 0, 1, 1, 3}; +static const Operation::Id OperationId[] = {Operation::ADD, Operation::SUBTRACT, Operation::MULTIPLY, Operation::DIVIDE, Operation::POWER}; + +class LMP_Lepton::ParseToken { +public: + enum Type {Number, Operator, Variable, Function, LeftParen, RightParen, Comma, Whitespace}; + + ParseToken(string text, Type type) : text(text), type(type) { + } + const string& getText() const { + return text; + } + Type getType() const { + return type; + } +private: + string text; + Type type; +}; + +string Parser::trim(const string& expression) { + // Remove leading and trailing spaces. + + int start, end; + for (start = 0; start < (int) expression.size() && isspace(expression[start]); start++) + ; + for (end = (int) expression.size()-1; end > start && isspace(expression[end]); end--) + ; + if (start == end && isspace(expression[end])) + return ""; + return expression.substr(start, end-start+1); +} + +ParseToken Parser::getNextToken(const string& expression, int start) { + char c = expression[start]; + if (c == '(') + return ParseToken("(", ParseToken::LeftParen); + if (c == ')') + return ParseToken(")", ParseToken::RightParen); + if (c == ',') + return ParseToken(",", ParseToken::Comma); + if (Operators.find(c) != string::npos) + return ParseToken(string(1, c), ParseToken::Operator); + if (isspace(c)) { + // White space + + for (int pos = start+1; pos < (int) expression.size(); pos++) { + if (!isspace(expression[pos])) + return ParseToken(expression.substr(start, pos-start), ParseToken::Whitespace); + } + return ParseToken(expression.substr(start, string::npos), ParseToken::Whitespace); + } + if (c == '.' || Digits.find(c) != string::npos) { + // A number + + bool foundDecimal = (c == '.'); + bool foundExp = false; + int pos; + for (pos = start+1; pos < (int) expression.size(); pos++) { + c = expression[pos]; + if (Digits.find(c) != string::npos) + continue; + if (c == '.' && !foundDecimal) { + foundDecimal = true; + continue; + } + if ((c == 'e' || c == 'E') && !foundExp) { + foundExp = true; + if (pos < (int) expression.size()-1 && (expression[pos+1] == '-' || expression[pos+1] == '+')) + pos++; + continue; + } + break; + } + return ParseToken(expression.substr(start, pos-start), ParseToken::Number); + } + + // A variable, function, or left parenthesis + + for (int pos = start; pos < (int) expression.size(); pos++) { + c = expression[pos]; + if (c == '(') + return ParseToken(expression.substr(start, pos-start+1), ParseToken::Function); + if (Operators.find(c) != string::npos || c == ',' || c == ')' || isspace(c)) + return ParseToken(expression.substr(start, pos-start), ParseToken::Variable); + } + return ParseToken(expression.substr(start, string::npos), ParseToken::Variable); +} + +vector Parser::tokenize(const string& expression) { + vector tokens; + int pos = 0; + while (pos < (int) expression.size()) { + ParseToken token = getNextToken(expression, pos); + if (token.getType() != ParseToken::Whitespace) + tokens.push_back(token); + pos += (int) token.getText().size(); + } + return tokens; +} + +ParsedExpression Parser::parse(const string& expression) { + return parse(expression, map()); +} + +ParsedExpression Parser::parse(const string& expression, const map& customFunctions) { + try { + // First split the expression into subexpressions. + + string primaryExpression = expression; + vector subexpressions; + while (true) { + string::size_type pos = primaryExpression.find_last_of(';'); + if (pos == string::npos) + break; + string sub = trim(primaryExpression.substr(pos+1)); + if (sub.size() > 0) + subexpressions.push_back(sub); + primaryExpression = primaryExpression.substr(0, pos); + } + + // Parse the subexpressions. + + map subexpDefs; + for (int i = 0; i < (int) subexpressions.size(); i++) { + string::size_type equalsPos = subexpressions[i].find('='); + if (equalsPos == string::npos) + throw Exception("subexpression does not specify a name"); + string name = trim(subexpressions[i].substr(0, equalsPos)); + if (name.size() == 0) + throw Exception("subexpression does not specify a name"); + vector tokens = tokenize(subexpressions[i].substr(equalsPos+1)); + int pos = 0; + subexpDefs[name] = parsePrecedence(tokens, pos, customFunctions, subexpDefs, 0); + if (pos != tokens.size()) + throw Exception("unexpected text at end of subexpression: "+tokens[pos].getText()); + } + + // Now parse the primary expression. + + vector tokens = tokenize(primaryExpression); + int pos = 0; + ExpressionTreeNode result = parsePrecedence(tokens, pos, customFunctions, subexpDefs, 0); + if (pos != tokens.size()) + throw Exception("unexpected text at end of expression: "+tokens[pos].getText()); + return ParsedExpression(result); + } + catch (Exception& ex) { + throw Exception("Parse error in expression \""+expression+"\": "+ex.what()); + } +} + +ExpressionTreeNode Parser::parsePrecedence(const vector& tokens, int& pos, const map& customFunctions, + const map& subexpressionDefs, int precedence) { + if (pos == tokens.size()) + throw Exception("unexpected end of expression"); + + // Parse the next value (number, variable, function, parenthesized expression) + + ParseToken token = tokens[pos]; + ExpressionTreeNode result; + if (token.getType() == ParseToken::Number) { + double value; + stringstream(token.getText()) >> value; + result = ExpressionTreeNode(new Operation::Constant(value)); + pos++; + } + else if (token.getType() == ParseToken::Variable) { + map::const_iterator subexp = subexpressionDefs.find(token.getText()); + if (subexp == subexpressionDefs.end()) { + Operation* op = new Operation::Variable(token.getText()); + result = ExpressionTreeNode(op); + } + else + result = subexp->second; + pos++; + } + else if (token.getType() == ParseToken::LeftParen) { + pos++; + result = parsePrecedence(tokens, pos, customFunctions, subexpressionDefs, 0); + if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen) + throw Exception("unbalanced parentheses"); + pos++; + } + else if (token.getType() == ParseToken::Function) { + pos++; + vector args; + bool moreArgs; + do { + args.push_back(parsePrecedence(tokens, pos, customFunctions, subexpressionDefs, 0)); + moreArgs = (pos < (int) tokens.size() && tokens[pos].getType() == ParseToken::Comma); + if (moreArgs) + pos++; + } while (moreArgs); + if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen) + throw Exception("unbalanced parentheses"); + pos++; + Operation* op = getFunctionOperation(token.getText(), customFunctions); + try { + result = ExpressionTreeNode(op, args); + } + catch (...) { + delete op; + throw; + } + } + else if (token.getType() == ParseToken::Operator && token.getText() == "-") { + pos++; + ExpressionTreeNode toNegate = parsePrecedence(tokens, pos, customFunctions, subexpressionDefs, 2); + result = ExpressionTreeNode(new Operation::Negate(), toNegate); + } + else + throw Exception("unexpected token: "+token.getText()); + + // Now deal with the next binary operator. + + while (pos < (int) tokens.size() && tokens[pos].getType() == ParseToken::Operator) { + token = tokens[pos]; + int opIndex = (int) Operators.find(token.getText()); + int opPrecedence = Precedence[opIndex]; + if (opPrecedence < precedence) + return result; + pos++; + ExpressionTreeNode arg = parsePrecedence(tokens, pos, customFunctions, subexpressionDefs, LeftAssociative[opIndex] ? opPrecedence+1 : opPrecedence); + Operation* op = getOperatorOperation(token.getText()); + try { + result = ExpressionTreeNode(op, result, arg); + } + catch (...) { + delete op; + throw; + } + } + return result; +} + +Operation* Parser::getOperatorOperation(const std::string& name) { + switch (OperationId[Operators.find(name)]) { + case Operation::ADD: + return new Operation::Add(); + case Operation::SUBTRACT: + return new Operation::Subtract(); + case Operation::MULTIPLY: + return new Operation::Multiply(); + case Operation::DIVIDE: + return new Operation::Divide(); + case Operation::POWER: + return new Operation::Power(); + default: + throw Exception("unknown operator"); + } +} + +Operation* Parser::getFunctionOperation(const std::string& name, const map& customFunctions) { + + static map opMap; + if (opMap.size() == 0) { + opMap["sqrt"] = Operation::SQRT; + opMap["exp"] = Operation::EXP; + opMap["log"] = Operation::LOG; + opMap["sin"] = Operation::SIN; + opMap["cos"] = Operation::COS; + opMap["sec"] = Operation::SEC; + opMap["csc"] = Operation::CSC; + opMap["tan"] = Operation::TAN; + opMap["cot"] = Operation::COT; + opMap["asin"] = Operation::ASIN; + opMap["acos"] = Operation::ACOS; + opMap["atan"] = Operation::ATAN; + opMap["atan2"] = Operation::ATAN2; + opMap["sinh"] = Operation::SINH; + opMap["cosh"] = Operation::COSH; + opMap["tanh"] = Operation::TANH; + opMap["erf"] = Operation::ERF; + opMap["erfc"] = Operation::ERFC; + opMap["step"] = Operation::STEP; + opMap["delta"] = Operation::DELTA; + opMap["square"] = Operation::SQUARE; + opMap["cube"] = Operation::CUBE; + opMap["recip"] = Operation::RECIPROCAL; + opMap["min"] = Operation::MIN; + opMap["max"] = Operation::MAX; + opMap["abs"] = Operation::ABS; + opMap["floor"] = Operation::FLOOR; + opMap["ceil"] = Operation::CEIL; + opMap["select"] = Operation::SELECT; + } + string trimmed = name.substr(0, name.size()-1); + + // First check custom functions. + + map::const_iterator custom = customFunctions.find(trimmed); + if (custom != customFunctions.end()) + return new Operation::Custom(trimmed, custom->second->clone()); + + // Now try standard functions. + + map::const_iterator iter = opMap.find(trimmed); + if (iter == opMap.end()) + throw Exception("unknown function: "+trimmed); + switch (iter->second) { + case Operation::SQRT: + return new Operation::Sqrt(); + case Operation::EXP: + return new Operation::Exp(); + case Operation::LOG: + return new Operation::Log(); + case Operation::SIN: + return new Operation::Sin(); + case Operation::COS: + return new Operation::Cos(); + case Operation::SEC: + return new Operation::Sec(); + case Operation::CSC: + return new Operation::Csc(); + case Operation::TAN: + return new Operation::Tan(); + case Operation::COT: + return new Operation::Cot(); + case Operation::ASIN: + return new Operation::Asin(); + case Operation::ACOS: + return new Operation::Acos(); + case Operation::ATAN: + return new Operation::Atan(); + case Operation::ATAN2: + return new Operation::Atan2(); + case Operation::SINH: + return new Operation::Sinh(); + case Operation::COSH: + return new Operation::Cosh(); + case Operation::TANH: + return new Operation::Tanh(); + case Operation::ERF: + return new Operation::Erf(); + case Operation::ERFC: + return new Operation::Erfc(); + case Operation::STEP: + return new Operation::Step(); + case Operation::DELTA: + return new Operation::Delta(); + case Operation::SQUARE: + return new Operation::Square(); + case Operation::CUBE: + return new Operation::Cube(); + case Operation::RECIPROCAL: + return new Operation::Reciprocal(); + case Operation::MIN: + return new Operation::Min(); + case Operation::MAX: + return new Operation::Max(); + case Operation::ABS: + return new Operation::Abs(); + case Operation::FLOOR: + return new Operation::Floor(); + case Operation::CEIL: + return new Operation::Ceil(); + case Operation::SELECT: + return new Operation::Select(); + default: + throw Exception("unknown function"); + } +} From c44e87d87a35116100713d6e4612af750cdb5cd9 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 18:28:04 -0500 Subject: [PATCH 02/79] avoid name conflict with COLVARS package --- lib/lepton/Makefile.mpi | 6 +----- lib/lepton/Makefile.serial | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/lib/lepton/Makefile.mpi b/lib/lepton/Makefile.mpi index 045ea61015..934188ba50 100644 --- a/lib/lepton/Makefile.mpi +++ b/lib/lepton/Makefile.mpi @@ -1,15 +1,11 @@ EXTRAMAKE=Makefile.lammps.empty CC=mpicxx - -# -DH5_NO_DEPRECATED_SYMBOLS is required here to ensure we are using -# the v1.8 API when HDF5 is configured to default to using the v1.6 API. CXXFLAGS=-D_DEFAULT_SOURCE -O2 -Wall -fPIC -std=c++11 INC=-I include AR=ar ARFLAGS=rc -# need to build two libraries to not break compatibility and to support Install.py -LIB=liblepton.a +LIB=liblmplepton.a SRC=src/CompiledExpression.cpp src/ExpressionProgram.cpp src/ExpressionTreeNode.cpp src/Operation.cpp src/ParsedExpression.cpp src/Parser.cpp OBJ=$(SRC:src/%.cpp=build/%.o) diff --git a/lib/lepton/Makefile.serial b/lib/lepton/Makefile.serial index 58151e49c2..23d9b3dd57 100644 --- a/lib/lepton/Makefile.serial +++ b/lib/lepton/Makefile.serial @@ -1,15 +1,11 @@ EXTRAMAKE=Makefile.lammps.empty CC=g++ - -# -DH5_NO_DEPRECATED_SYMBOLS is required here to ensure we are using -# the v1.8 API when HDF5 is configured to default to using the v1.6 API. CXXFLAGS=-D_DEFAULT_SOURCE -O2 -Wall -fPIC -std=c++11 INC=-I include AR=ar ARFLAGS=rc -# need to build two libraries to not break compatibility and to support Install.py -LIB=liblepton.a +LIB=liblmplepton.a SRC=src/CompiledExpression.cpp src/ExpressionProgram.cpp src/ExpressionTreeNode.cpp src/Operation.cpp src/ParsedExpression.cpp src/Parser.cpp OBJ=$(SRC:src/%.cpp=build/%.o) From 5f934e3eae9a53f51c4eb0b79f6476c437cefde6 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 18:28:35 -0500 Subject: [PATCH 03/79] add LEPTON package build system support for CMake --- cmake/CMakeLists.txt | 5 +++-- cmake/Modules/Packages/LEPTON.cmake | 8 ++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 cmake/Modules/Packages/LEPTON.cmake diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ec142af426..bea9a31197 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -251,6 +251,7 @@ set(STANDARD_PACKAGES KSPACE LATBOLTZ LATTE + LEPTON MACHDYN MANIFOLD MANYBODY @@ -266,7 +267,7 @@ set(STANDARD_PACKAGES ML-QUIP ML-RANN ML-SNAP - ML-POD + ML-POD MOFFF MOLECULE MOLFILE @@ -517,7 +518,7 @@ else() endif() foreach(PKG_WITH_INCL KSPACE PYTHON ML-IAP VORONOI COLVARS ML-HDNNP MDI MOLFILE NETCDF - PLUMED QMMM ML-QUIP SCAFACOS MACHDYN VTK KIM LATTE MSCG COMPRESS ML-PACE) + PLUMED QMMM ML-QUIP SCAFACOS MACHDYN VTK KIM LATTE MSCG COMPRESS ML-PACE LEPTON) if(PKG_${PKG_WITH_INCL}) include(Packages/${PKG_WITH_INCL}) endif() diff --git a/cmake/Modules/Packages/LEPTON.cmake b/cmake/Modules/Packages/LEPTON.cmake new file mode 100644 index 0000000000..05241de592 --- /dev/null +++ b/cmake/Modules/Packages/LEPTON.cmake @@ -0,0 +1,8 @@ +set(LEPTON_SOURCE_DIR ${LAMMPS_LIB_SOURCE_DIR}/lepton) + +file(GLOB LEPTON_SOURCES ${LEPTON_SOURCE_DIR}/src/[^.]*.cpp) +add_library(lmplepton STATIC ${LEPTON_SOURCES}) +target_compile_definitions(lmplepton PRIVATE -DLEPTON_BUILDING_STATIC_LIBRARY) +set_target_properties(lmplepton PROPERTIES OUTPUT_NAME lammps_lmplepton${LAMMPS_MACHINE}) +target_include_directories(lmplepton PUBLIC ${LEPTON_SOURCE_DIR}/include) +target_link_libraries(lammps PRIVATE lmplepton) From 76a84d7865d2724dc75781514d317632a9de7812 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 18:28:57 -0500 Subject: [PATCH 04/79] add pair style lepton --- src/LEPTON/pair_lepton.cpp | 273 +++++++++++++++++++++++++++++++++++++ src/LEPTON/pair_lepton.h | 61 +++++++++ 2 files changed, 334 insertions(+) create mode 100644 src/LEPTON/pair_lepton.cpp create mode 100644 src/LEPTON/pair_lepton.h diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp new file mode 100644 index 0000000000..7ec12bc988 --- /dev/null +++ b/src/LEPTON/pair_lepton.cpp @@ -0,0 +1,273 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing authors: Axel Kohlmeyer (Temple U) +------------------------------------------------------------------------- */ + +#include "pair_lepton.h" + +#include "atom.h" +#include "error.h" +#include "force.h" +#include "memory.h" +#include "neigh_list.h" +#include "update.h" + +#include +#include + +using namespace LAMMPS_NS; + +/* ---------------------------------------------------------------------- */ + +PairLepton::PairLepton(LAMMPS *lmp) : Pair(lmp), cut(nullptr), type2expression(nullptr) +{ + respa_enable = 0; + single_enable = 1; + writedata = 1; + restartinfo = 0; + reinitflag = 0; + cut_global = 0.0; + centroidstressflag = CENTROID_SAME; +} + +/* ---------------------------------------------------------------------- */ + +PairLepton::~PairLepton() +{ + if (allocated) { + memory->destroy(cut); + memory->destroy(cutsq); + memory->destroy(setflag); + memory->destroy(type2expression); + } +} + +/* ---------------------------------------------------------------------- */ + +void PairLepton::compute(int eflag, int vflag) +{ + int i, j, ii, jj, inum, jnum, itype, jtype; + double xtmp, ytmp, ztmp, delx, dely, delz, evdwl, fpair; + double rsq, factor_lj; + int *ilist, *jlist, *numneigh, **firstneigh; + + evdwl = 0.0; + ev_init(eflag, vflag); + + double **x = atom->x; + double **f = atom->f; + int *type = atom->type; + int nlocal = atom->nlocal; + double *special_lj = force->special_lj; + int newton_pair = force->newton_pair; + + inum = list->inum; + ilist = list->ilist; + numneigh = list->numneigh; + firstneigh = list->firstneigh; + + std::vector force; + std::vector epot; + for (const auto &expr : expressions) { + force.emplace_back( + LMP_Lepton::Parser::parse(expr).differentiate("r").createCompiledExpression()); + if (eflag) epot.emplace_back(LMP_Lepton::Parser::parse(expr).createCompiledExpression()); + } + + // loop over neighbors of my atoms + + for (ii = 0; ii < inum; ii++) { + i = ilist[ii]; + xtmp = x[i][0]; + ytmp = x[i][1]; + ztmp = x[i][2]; + itype = type[i]; + jlist = firstneigh[i]; + jnum = numneigh[i]; + + for (jj = 0; jj < jnum; jj++) { + j = jlist[jj]; + const double factor_lj = special_lj[sbmask(j)]; + j &= NEIGHMASK; + + delx = xtmp - x[j][0]; + dely = ytmp - x[j][1]; + delz = ztmp - x[j][2]; + rsq = delx * delx + dely * dely + delz * delz; + jtype = type[j]; + + if (rsq < cutsq[itype][jtype]) { + const double r = sqrt(rsq); + const int idx = type2expression[itype][jtype]; + double &r_for = force[idx].getVariableReference("r"); + r_for = r; + fpair = -force[idx].evaluate() / r; + fpair *= factor_lj; + + f[i][0] += delx * fpair; + f[i][1] += dely * fpair; + f[i][2] += delz * fpair; + if (newton_pair || j < nlocal) { + f[j][0] -= delx * fpair; + f[j][1] -= dely * fpair; + f[j][2] -= delz * fpair; + } + + if (eflag) { + double &r_pot = epot[idx].getVariableReference("r"); + r_pot = r; + evdwl = factor_lj * epot[idx].evaluate(); + } else + evdwl = 0.0; + + if (evflag) ev_tally(i, j, nlocal, newton_pair, evdwl, 0.0, fpair, delx, dely, delz); + } + } + } + + if (vflag_fdotr) virial_fdotr_compute(); +} + +/* ---------------------------------------------------------------------- + allocate all arrays +------------------------------------------------------------------------- */ + +void PairLepton::allocate() +{ + allocated = 1; + int np1 = atom->ntypes + 1; + + memory->create(setflag, np1, np1, "pair:setflag"); + for (int i = 1; i < np1; i++) + for (int j = i; j < np1; j++) setflag[i][j] = 0; + + memory->create(cut, np1, np1, "pair:cut"); + memory->create(cutsq, np1, np1, "pair:cutsq"); + memory->create(type2expression, np1, np1, "pair:type2expression"); +} + +/* ---------------------------------------------------------------------- + global settings +------------------------------------------------------------------------- */ + +void PairLepton::settings(int narg, char **arg) +{ + if (narg != 1) error->all(FLERR, "Illegal pair_style command"); + + cut_global = utils::numeric(FLERR, arg[0], false, lmp); +} + +/* ---------------------------------------------------------------------- + set coeffs for all type pairs +------------------------------------------------------------------------- */ + +void PairLepton::coeff(int narg, char **arg) +{ + if (narg < 3 || narg > 4) error->all(FLERR, "Incorrect number of args for pair coefficients"); + if (!allocated) allocate(); + + int ilo, ihi, jlo, jhi; + utils::bounds(FLERR, arg[0], 1, atom->ntypes, ilo, ihi, error); + utils::bounds(FLERR, arg[1], 1, atom->ntypes, jlo, jhi, error); + + std::string exp_one = arg[2]; + double cut_one = cut_global; + if (narg == 4) cut_one = utils::numeric(FLERR, arg[3], false, lmp); + + // check if the expression can be parsed and evaluated as needed without error + try { + auto epot = LMP_Lepton::Parser::parse(exp_one).createCompiledExpression(); + auto force = LMP_Lepton::Parser::parse(exp_one).differentiate("r").createCompiledExpression(); + double &r_pot = epot.getVariableReference("r"); + double &r_for = force.getVariableReference("r"); + r_for = r_pot = 1.0; + epot.evaluate(); + force.evaluate(); + } catch (std::exception &e) { + error->all(FLERR, e.what()); + } + + std::size_t idx = 0; + for (const auto &exp : expressions) { + if (exp == exp_one) break; + ++idx; + } + + // not found, add to list + if ((expressions.size() == 0) || (idx == expressions.size())) expressions.push_back(exp_one); + + int count = 0; + for (int i = ilo; i <= ihi; i++) { + for (int j = MAX(jlo, i); j <= jhi; j++) { + cut[i][j] = cut_one; + setflag[i][j] = 1; + type2expression[i][j] = idx; + count++; + } + } + + if (count == 0) error->all(FLERR, "Incorrect args for pair coefficients"); +} + +/* ---------------------------------------------------------------------- */ + +double PairLepton::init_one(int i, int j) +{ + if (setflag[i][j] == 0) error->all(FLERR, "All pair coeffs are not set"); + + cut[j][i] = cut[i][j]; + type2expression[j][i] = type2expression[i][j]; + + return cut[i][j]; +} + +/* ---------------------------------------------------------------------- + proc 0 writes to data file +------------------------------------------------------------------------- */ + +void PairLepton::write_data(FILE *fp) +{ + for (int i = 1; i <= atom->ntypes; i++) + fprintf(fp, "%d '%s' %g\n", i, expressions[type2expression[i][i]].c_str(), cut[i][i]); +} + +/* ---------------------------------------------------------------------- + proc 0 writes all pairs to data file +------------------------------------------------------------------------- */ + +void PairLepton::write_data_all(FILE *fp) +{ + for (int i = 1; i <= atom->ntypes; i++) + for (int j = i; j <= atom->ntypes; j++) + fprintf(fp, "%d %d '%s' %g\n", i, j, expressions[type2expression[i][j]].c_str(), cut[i][j]); +} + +/* ---------------------------------------------------------------------- */ + +double PairLepton::single(int /* i */, int /* j */, int itype, int jtype, double rsq, + double /* factor_coul */, double factor_lj, double &fforce) +{ + auto expr = expressions[type2expression[itype][jtype]]; + auto epot = LMP_Lepton::Parser::parse(expr).createCompiledExpression(); + auto force = LMP_Lepton::Parser::parse(expr).differentiate("r").createCompiledExpression(); + + double r = sqrt(rsq); + double &r_pot = epot.getVariableReference("r"); + double &r_for = force.getVariableReference("r"); + + r_pot = r_for = r; + fforce = -force.evaluate() / r; + return epot.evaluate(); +} diff --git a/src/LEPTON/pair_lepton.h b/src/LEPTON/pair_lepton.h new file mode 100644 index 0000000000..a4f3b0c084 --- /dev/null +++ b/src/LEPTON/pair_lepton.h @@ -0,0 +1,61 @@ +/* -*- c++ -*- ---------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. + + Pair zero is a dummy pair interaction useful for requiring a + force cutoff distance in the absence of pair-interactions or + with hybrid/overlay if a larger force cutoff distance is required. + + This can be used in conjunction with bond/create to create bonds + that are longer than the cutoff of a given force field, or to + calculate radial distribution functions for models without + pair interactions. + +------------------------------------------------------------------------- */ + +#ifdef PAIR_CLASS +// clang-format off +PairStyle(lepton,PairLepton); +// clang-format on +#else + +#ifndef LMP_PAIR_LEPTON_H +#define LMP_PAIR_LEPTON_H + +#include "pair.h" + +namespace LAMMPS_NS { + +class PairLepton : public Pair { + public: + PairLepton(class LAMMPS *); + ~PairLepton() override; + void compute(int, int) override; + void settings(int, char **) override; + void coeff(int, char **) override; + double init_one(int, int) override; + void write_data(FILE *) override; + void write_data_all(FILE *) override; + double single(int, int, int, int, double, double, double, double &) override; + + protected: + std::vector expressions; + double **cut; + int **type2expression; + double cut_global; + + virtual void allocate(); +}; + +} // namespace LAMMPS_NS + +#endif +#endif From 2cf1793a93996576e84e4668abc9e27b332db112 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 18:29:11 -0500 Subject: [PATCH 05/79] add unit test for pair style lepton --- .../force-styles/tests/mol-pair-lepton.yaml | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 unittest/force-styles/tests/mol-pair-lepton.yaml diff --git a/unittest/force-styles/tests/mol-pair-lepton.yaml b/unittest/force-styles/tests/mol-pair-lepton.yaml new file mode 100644 index 0000000000..eaed580b02 --- /dev/null +++ b/unittest/force-styles/tests/mol-pair-lepton.yaml @@ -0,0 +1,98 @@ +--- +lammps_version: 3 Nov 2022 +tags: generated +date_generated: Wed Dec 21 18:26:23 2022 +epsilon: 5e-14 +skip_tests: +prerequisites: ! | + atom full + pair lepton +pre_commands: ! | + variable write_data_pair index ij +post_commands: ! "" +input_file: in.fourmol +pair_style: lepton 8.0 +pair_coeff: ! | + * * "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.015;sig=3.1" + 1 1 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.02;sig=2.5" + 1 2 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.01;sig=1.75" + 1 3 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.02;sig=2.85" + 1 4*5 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.0173205;sig=2.8" + 2 2 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.005;sig=1.0" + 2 3 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.01;sig=2.1" + 2 4 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.005;sig=0.5" + 2 5 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.00866025;sig=2.05" + 3 3 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.02;sig=3.2" + 3 4 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.0173205;sig=3.15" + 3 5 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.0173205;sig=3.15" +extract: ! "" +natoms: 29 +init_vdwl: 749.2370315373564 +init_coul: 0 +init_stress: ! |2- + 2.1793853434038242e+03 2.1988955172192768e+03 4.6653977523326257e+03 -7.5956547636050584e+02 2.4751536734032868e+01 6.6652028436400667e+02 +init_forces: ! |2 + 1 -2.3333390280895909e+01 2.6994567613322647e+02 3.3272827850356805e+02 + 2 1.5828554630414899e+02 1.3025008843535872e+02 -1.8629682358935722e+02 + 3 -1.3528903738169063e+02 -3.8704313358319996e+02 -1.4568978437133106e+02 + 4 -7.8711096705893366e+00 2.1350518625373534e+00 -5.5954532185548143e+00 + 5 -2.5176757268228540e+00 -4.0521510681020230e+00 1.2152704057877019e+01 + 6 -8.3190662465252137e+02 9.6394149462625592e+02 1.1509093566509246e+03 + 7 5.8203388932513604e+01 -3.3608997951626793e+02 -1.7179617996573040e+03 + 8 1.4451392284291526e+02 -1.0927475861088995e+02 3.9990593492420442e+02 + 9 7.9156945283097443e+01 8.5273009783986538e+01 3.5032175698445189e+02 + 10 5.3118875219105371e+02 -6.1040990859419412e+02 -1.8355872642619292e+02 + 11 -2.3530157267965524e+00 -5.9077640073819726e+00 -9.6590723955414290e+00 + 12 1.7527155146800425e+01 1.0633119523437513e+01 -7.9254398064483160e+00 + 13 8.0986409579532950e+00 -3.2098088264781550e+00 -1.4896399843793842e-01 + 14 -3.3852721292265158e+00 6.8636181241903627e-01 -8.7507190862499868e+00 + 15 -2.0454999188605297e-01 8.4846165523049901e+00 3.0131615419406712e+00 + 16 4.6326310311812108e+02 -3.3087715736498188e+02 -1.1893024561782554e+03 + 17 -4.5334300923766733e+02 3.1554283255882575e+02 1.2058417793481203e+03 + 18 -1.8862623280672661e-02 -3.3402010907951661e-02 3.1000479299095263e-02 + 19 3.1843079640570047e-04 -2.3918627818763423e-04 1.7427252638513439e-03 + 20 -9.9760831209706009e-04 -1.0209184826753086e-03 3.6910972636601454e-04 + 21 -7.1566125273265186e+01 -8.1615678329920655e+01 2.2589561408339890e+02 + 22 -1.0808835729977497e+02 -2.6193787235943894e+01 -1.6957904943161401e+02 + 23 1.7964455474779490e+02 1.0782097695276948e+02 -5.6305786479140629e+01 + 24 3.6591406576584546e+01 -2.1181587621785579e+02 1.1218301872572377e+02 + 25 -1.4851489147738798e+02 2.3907118122949061e+01 -1.2485634873166291e+02 + 26 1.1191129453598218e+02 1.8789774664223384e+02 1.2650137204319906e+01 + 27 5.1810388677546001e+01 -2.2705458321213797e+02 9.0849111082069669e+01 + 28 -1.8041307121444069e+02 7.7534042932772905e+01 -1.2206956760706598e+02 + 29 1.2861057254925012e+02 1.4952711274394568e+02 3.1216025556267880e+01 +run_vdwl: 719.4432816774653 +run_coul: 0 +run_stress: ! |2- + 2.1330153957371017e+03 2.1547728168285516e+03 4.3976497417710116e+03 -7.3873328448298525e+02 4.1743821105368760e+01 6.2788012209191072e+02 +run_forces: ! |2 + 1 -2.0299419751359164e+01 2.6686193378823020e+02 3.2358785870694010e+02 + 2 1.5298617928491225e+02 1.2596516341409203e+02 -1.7961292655338619e+02 + 3 -1.3353630652439830e+02 -3.7923748696131315e+02 -1.4291839793625815e+02 + 4 -7.8374717836161762e+00 2.1276610789823409e+00 -5.5845014473820616e+00 + 5 -2.5014258630866721e+00 -4.0250131424704412e+00 1.2103512372025637e+01 + 6 -8.0681462887292457e+02 9.2165637136761688e+02 1.0270795806932783e+03 + 7 5.5780279349903516e+01 -3.1117530951561662e+02 -1.5746991292869018e+03 + 8 1.3452983055535049e+02 -1.0064659350255911e+02 3.8851791558207651e+02 + 9 7.6746213883425980e+01 8.2501469877402130e+01 3.3944351200617888e+02 + 10 5.2128033527695595e+02 -5.9920098848285863e+02 -1.8126029815043339e+02 + 11 -2.3573118090915250e+00 -5.8616944550888368e+00 -9.6049808811326240e+00 + 12 1.7503975847822900e+01 1.0626930310560816e+01 -8.0603160272054986e+00 + 13 8.0530313322973104e+00 -3.1756495170399108e+00 -1.4618315664740528e-01 + 14 -3.3416065168069768e+00 6.6492606336082150e-01 -8.6345131440469700e+00 + 15 -2.2253843262374914e-01 8.5025661635348762e+00 3.0369735873081622e+00 + 16 4.3476311264989465e+02 -3.1171086735551415e+02 -1.1135217194927448e+03 + 17 -4.2469846140777139e+02 2.9615411776780593e+02 1.1302573488400665e+03 + 18 -1.8849981672825911e-02 -3.3371636477421307e-02 3.0986293443778734e-02 + 19 3.0940277774414016e-04 -2.4634536455373038e-04 1.7433360008861014e-03 + 20 -9.8648131277150747e-04 -1.0112587134526948e-03 3.6932948773965417e-04 + 21 -7.0490745283106378e+01 -7.9749153581142139e+01 2.2171003384646431e+02 + 22 -1.0638717908920071e+02 -2.5949502163177975e+01 -1.6645589526812273e+02 + 23 1.7686797710735033e+02 1.0571018898885515e+02 -5.5243337084099373e+01 + 24 3.8206017656281375e+01 -2.1022820141992960e+02 1.1260711266189014e+02 + 25 -1.4918881473530880e+02 2.3762151395876508e+01 -1.2549188139143085e+02 + 26 1.1097059498808308e+02 1.8645503634228518e+02 1.2861559677865248e+01 + 27 5.0800844984832125e+01 -2.2296588090685469e+02 8.8607367716323253e+01 + 28 -1.7694190504288886e+02 7.6029945485182026e+01 -1.1950518150242071e+02 + 29 1.2614894925528141e+02 1.4694250820033548e+02 3.0893386672863034e+01 +... From 6c5a698be48fa4f8d53b864485ee21273cc338be Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 19:24:28 -0500 Subject: [PATCH 06/79] try to speed up compute kernel --- src/LEPTON/pair_lepton.cpp | 106 ++++++++++++++++++++++--------------- src/LEPTON/pair_lepton.h | 5 +- 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp index 7ec12bc988..46377a341a 100644 --- a/src/LEPTON/pair_lepton.cpp +++ b/src/LEPTON/pair_lepton.cpp @@ -58,86 +58,104 @@ PairLepton::~PairLepton() void PairLepton::compute(int eflag, int vflag) { - int i, j, ii, jj, inum, jnum, itype, jtype; - double xtmp, ytmp, ztmp, delx, dely, delz, evdwl, fpair; - double rsq, factor_lj; - int *ilist, *jlist, *numneigh, **firstneigh; - - evdwl = 0.0; ev_init(eflag, vflag); + if (evflag) { + if (eflag) { + if (force->newton_pair) + eval<1, 1, 1>(); + else + eval<1, 1, 0>(); + } else { + if (force->newton_pair) + eval<1, 0, 1>(); + else + eval<1, 0, 0>(); + } + } else { + if (force->newton_pair) + eval<0, 0, 1>(); + else + eval<0, 0, 0>(); + } + if (vflag_fdotr) virial_fdotr_compute(); +} - double **x = atom->x; - double **f = atom->f; - int *type = atom->type; - int nlocal = atom->nlocal; - double *special_lj = force->special_lj; - int newton_pair = force->newton_pair; +/* ---------------------------------------------------------------------- */ +template void PairLepton::eval() +{ + const double *const *const x = atom->x; + double *const *const f = atom->f; + const int *const type = atom->type; + const int nlocal = atom->nlocal; + const double *const special_lj = force->special_lj; - inum = list->inum; - ilist = list->ilist; - numneigh = list->numneigh; - firstneigh = list->firstneigh; + const int inum = list->inum; + const int *const ilist = list->ilist; + const int *const numneigh = list->numneigh; + const int *const *const firstneigh = list->firstneigh; std::vector force; std::vector epot; for (const auto &expr : expressions) { force.emplace_back( LMP_Lepton::Parser::parse(expr).differentiate("r").createCompiledExpression()); - if (eflag) epot.emplace_back(LMP_Lepton::Parser::parse(expr).createCompiledExpression()); + if (EFLAG) epot.emplace_back(LMP_Lepton::Parser::parse(expr).createCompiledExpression()); } // loop over neighbors of my atoms - for (ii = 0; ii < inum; ii++) { - i = ilist[ii]; - xtmp = x[i][0]; - ytmp = x[i][1]; - ztmp = x[i][2]; - itype = type[i]; - jlist = firstneigh[i]; - jnum = numneigh[i]; + for (int ii = 0; ii < inum; ii++) { + const int i = ilist[ii]; + const double xtmp = x[i][0]; + const double ytmp = x[i][1]; + const double ztmp = x[i][2]; + const int itype = type[i]; + const int *jlist = firstneigh[i]; + const int jnum = numneigh[i]; + double fxtmp, fytmp, fztmp; + fxtmp = fytmp = fztmp = 0.0; - for (jj = 0; jj < jnum; jj++) { - j = jlist[jj]; + for (int jj = 0; jj < jnum; jj++) { + int j = jlist[jj]; const double factor_lj = special_lj[sbmask(j)]; j &= NEIGHMASK; + const int jtype = type[j]; - delx = xtmp - x[j][0]; - dely = ytmp - x[j][1]; - delz = ztmp - x[j][2]; - rsq = delx * delx + dely * dely + delz * delz; - jtype = type[j]; + const double delx = xtmp - x[j][0]; + const double dely = ytmp - x[j][1]; + const double delz = ztmp - x[j][2]; + const double rsq = delx * delx + dely * dely + delz * delz; if (rsq < cutsq[itype][jtype]) { const double r = sqrt(rsq); const int idx = type2expression[itype][jtype]; double &r_for = force[idx].getVariableReference("r"); r_for = r; - fpair = -force[idx].evaluate() / r; - fpair *= factor_lj; + const double fpair = -force[idx].evaluate() / r * factor_lj; - f[i][0] += delx * fpair; - f[i][1] += dely * fpair; - f[i][2] += delz * fpair; - if (newton_pair || j < nlocal) { + fxtmp += delx * fpair; + fytmp += dely * fpair; + fztmp += delz * fpair; + if (NEWTON_PAIR || (j < nlocal)) { f[j][0] -= delx * fpair; f[j][1] -= dely * fpair; f[j][2] -= delz * fpair; } - if (eflag) { + double evdwl = 0.0; + if (EFLAG) { double &r_pot = epot[idx].getVariableReference("r"); r_pot = r; evdwl = factor_lj * epot[idx].evaluate(); - } else - evdwl = 0.0; + } - if (evflag) ev_tally(i, j, nlocal, newton_pair, evdwl, 0.0, fpair, delx, dely, delz); + if (EVFLAG) ev_tally(i, j, nlocal, NEWTON_PAIR, evdwl, 0.0, fpair, delx, dely, delz); } } + f[i][0] += fxtmp; + f[i][1] += fytmp; + f[i][2] += fztmp; } - - if (vflag_fdotr) virial_fdotr_compute(); } /* ---------------------------------------------------------------------- diff --git a/src/LEPTON/pair_lepton.h b/src/LEPTON/pair_lepton.h index a4f3b0c084..dd1f40d1d3 100644 --- a/src/LEPTON/pair_lepton.h +++ b/src/LEPTON/pair_lepton.h @@ -43,7 +43,7 @@ class PairLepton : public Pair { void coeff(int, char **) override; double init_one(int, int) override; void write_data(FILE *) override; - void write_data_all(FILE *) override; + void write_data_all(FILE *) override; double single(int, int, int, int, double, double, double, double &) override; protected: @@ -52,6 +52,9 @@ class PairLepton : public Pair { int **type2expression; double cut_global; + private: + template void eval(); + virtual void allocate(); }; From 969ac572562396b9840a1fce1d1df2ccd9564065 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 21:16:23 -0500 Subject: [PATCH 07/79] make expression string compact and easier restartable by removing quotes and whitespace --- src/LEPTON/pair_lepton.cpp | 12 +++++++---- .../force-styles/tests/mol-pair-lepton.yaml | 20 +++++++++---------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp index 46377a341a..f596a901f0 100644 --- a/src/LEPTON/pair_lepton.cpp +++ b/src/LEPTON/pair_lepton.cpp @@ -200,11 +200,15 @@ void PairLepton::coeff(int narg, char **arg) utils::bounds(FLERR, arg[0], 1, atom->ntypes, ilo, ihi, error); utils::bounds(FLERR, arg[1], 1, atom->ntypes, jlo, jhi, error); - std::string exp_one = arg[2]; double cut_one = cut_global; if (narg == 4) cut_one = utils::numeric(FLERR, arg[3], false, lmp); - // check if the expression can be parsed and evaluated as needed without error + // remove whitespace and quotes from expression string and then + // check if the expression can be parsed and evaluated without error + std::string exp_one; + for (const auto &c : std::string(arg[2])) + if (!isspace(c) && (c != '"') && (c != '\'')) exp_one.push_back(c); + try { auto epot = LMP_Lepton::Parser::parse(exp_one).createCompiledExpression(); auto force = LMP_Lepton::Parser::parse(exp_one).differentiate("r").createCompiledExpression(); @@ -258,7 +262,7 @@ double PairLepton::init_one(int i, int j) void PairLepton::write_data(FILE *fp) { for (int i = 1; i <= atom->ntypes; i++) - fprintf(fp, "%d '%s' %g\n", i, expressions[type2expression[i][i]].c_str(), cut[i][i]); + fprintf(fp, "%d %s %g\n", i, expressions[type2expression[i][i]].c_str(), cut[i][i]); } /* ---------------------------------------------------------------------- @@ -269,7 +273,7 @@ void PairLepton::write_data_all(FILE *fp) { for (int i = 1; i <= atom->ntypes; i++) for (int j = i; j <= atom->ntypes; j++) - fprintf(fp, "%d %d '%s' %g\n", i, j, expressions[type2expression[i][j]].c_str(), cut[i][j]); + fprintf(fp, "%d %d %s %g\n", i, j, expressions[type2expression[i][j]].c_str(), cut[i][j]); } /* ---------------------------------------------------------------------- */ diff --git a/unittest/force-styles/tests/mol-pair-lepton.yaml b/unittest/force-styles/tests/mol-pair-lepton.yaml index eaed580b02..56640e3834 100644 --- a/unittest/force-styles/tests/mol-pair-lepton.yaml +++ b/unittest/force-styles/tests/mol-pair-lepton.yaml @@ -14,17 +14,17 @@ input_file: in.fourmol pair_style: lepton 8.0 pair_coeff: ! | * * "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.015;sig=3.1" - 1 1 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.02;sig=2.5" + 1 1 '4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.02;sig=2.5' 1 2 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.01;sig=1.75" - 1 3 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.02;sig=2.85" - 1 4*5 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.0173205;sig=2.8" - 2 2 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.005;sig=1.0" - 2 3 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.01;sig=2.1" - 2 4 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.005;sig=0.5" - 2 5 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.00866025;sig=2.05" - 3 3 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.02;sig=3.2" - 3 4 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.0173205;sig=3.15" - 3 5 "4.0*eps*((sig/r)^12 - (sig/r)^6);eps=0.0173205;sig=3.15" + 1 3 '4.0*eps*((sig/r)^12-(sig/r)^6); eps=0.02;sig=2.85' + 1 4*5 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.0173205; sig=2.8" + 2 2 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.005;sig=1.0" + 2 3 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.01;sig=2.1" + 2 4 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.005;sig=0.5" + 2 5 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.00866025;sig=2.05" + 3 3 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.02;sig=3.2" + 3 4 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.0173205;sig=3.15" + 3 5 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.0173205;sig=3.15" extract: ! "" natoms: 29 init_vdwl: 749.2370315373564 From c64066eb21c0c21d4ac8b230f4347c7626bb59ff Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 21:16:59 -0500 Subject: [PATCH 08/79] simplify processing of expressions --- src/LEPTON/pair_lepton.cpp | 27 ++++--- .../force-styles/tests/mol-pair-lepton.yaml | 78 +++++++++---------- 2 files changed, 54 insertions(+), 51 deletions(-) diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp index f596a901f0..11587d637c 100644 --- a/src/LEPTON/pair_lepton.cpp +++ b/src/LEPTON/pair_lepton.cpp @@ -24,9 +24,11 @@ #include "neigh_list.h" #include "update.h" -#include +#include #include +#include "LMP_Lepton.h" + using namespace LAMMPS_NS; /* ---------------------------------------------------------------------- */ @@ -93,13 +95,14 @@ template void PairLepton::eval() const int *const ilist = list->ilist; const int *const numneigh = list->numneigh; const int *const *const firstneigh = list->firstneigh; + double fxtmp, fytmp, fztmp; std::vector force; std::vector epot; for (const auto &expr : expressions) { - force.emplace_back( - LMP_Lepton::Parser::parse(expr).differentiate("r").createCompiledExpression()); - if (EFLAG) epot.emplace_back(LMP_Lepton::Parser::parse(expr).createCompiledExpression()); + auto parsed = LMP_Lepton::Parser::parse(expr); + force.emplace_back(parsed.differentiate("r").createCompiledExpression()); + if (EFLAG) epot.emplace_back(parsed.createCompiledExpression()); } // loop over neighbors of my atoms @@ -112,7 +115,6 @@ template void PairLepton::eval() const int itype = type[i]; const int *jlist = firstneigh[i]; const int jnum = numneigh[i]; - double fxtmp, fytmp, fztmp; fxtmp = fytmp = fztmp = 0.0; for (int jj = 0; jj < jnum; jj++) { @@ -210,8 +212,9 @@ void PairLepton::coeff(int narg, char **arg) if (!isspace(c) && (c != '"') && (c != '\'')) exp_one.push_back(c); try { - auto epot = LMP_Lepton::Parser::parse(exp_one).createCompiledExpression(); - auto force = LMP_Lepton::Parser::parse(exp_one).differentiate("r").createCompiledExpression(); + auto parsed = LMP_Lepton::Parser::parse(exp_one); + auto epot = parsed.createCompiledExpression(); + auto force = parsed.differentiate("r").createCompiledExpression(); double &r_pot = epot.getVariableReference("r"); double &r_for = force.getVariableReference("r"); r_for = r_pot = 1.0; @@ -281,15 +284,15 @@ void PairLepton::write_data_all(FILE *fp) double PairLepton::single(int /* i */, int /* j */, int itype, int jtype, double rsq, double /* factor_coul */, double factor_lj, double &fforce) { - auto expr = expressions[type2expression[itype][jtype]]; - auto epot = LMP_Lepton::Parser::parse(expr).createCompiledExpression(); - auto force = LMP_Lepton::Parser::parse(expr).differentiate("r").createCompiledExpression(); + auto parsed = LMP_Lepton::Parser::parse(expressions[type2expression[itype][jtype]]); + auto epot = parsed.createCompiledExpression(); + auto force = parsed.differentiate("r").createCompiledExpression(); double r = sqrt(rsq); double &r_pot = epot.getVariableReference("r"); double &r_for = force.getVariableReference("r"); r_pot = r_for = r; - fforce = -force.evaluate() / r; - return epot.evaluate(); + fforce = -force.evaluate() / r * factor_lj; + return epot.evaluate() * factor_lj; } diff --git a/unittest/force-styles/tests/mol-pair-lepton.yaml b/unittest/force-styles/tests/mol-pair-lepton.yaml index 56640e3834..d58780b1fe 100644 --- a/unittest/force-styles/tests/mol-pair-lepton.yaml +++ b/unittest/force-styles/tests/mol-pair-lepton.yaml @@ -1,7 +1,7 @@ --- lammps_version: 3 Nov 2022 tags: generated -date_generated: Wed Dec 21 18:26:23 2022 +date_generated: Wed Dec 21 20:29:34 2022 epsilon: 5e-14 skip_tests: prerequisites: ! | @@ -30,65 +30,65 @@ natoms: 29 init_vdwl: 749.2370315373564 init_coul: 0 init_stress: ! |2- - 2.1793853434038242e+03 2.1988955172192768e+03 4.6653977523326257e+03 -7.5956547636050584e+02 2.4751536734032868e+01 6.6652028436400667e+02 + 2.1793853434038242e+03 2.1988955172192768e+03 4.6653977523326257e+03 -7.5956547636050584e+02 2.4751536734032861e+01 6.6652028436400667e+02 init_forces: ! |2 - 1 -2.3333390280895909e+01 2.6994567613322647e+02 3.3272827850356805e+02 + 1 -2.3333390280895912e+01 2.6994567613322641e+02 3.3272827850356805e+02 2 1.5828554630414899e+02 1.3025008843535872e+02 -1.8629682358935722e+02 - 3 -1.3528903738169063e+02 -3.8704313358319996e+02 -1.4568978437133106e+02 - 4 -7.8711096705893366e+00 2.1350518625373534e+00 -5.5954532185548143e+00 - 5 -2.5176757268228540e+00 -4.0521510681020230e+00 1.2152704057877019e+01 - 6 -8.3190662465252137e+02 9.6394149462625592e+02 1.1509093566509246e+03 - 7 5.8203388932513604e+01 -3.3608997951626793e+02 -1.7179617996573040e+03 - 8 1.4451392284291526e+02 -1.0927475861088995e+02 3.9990593492420442e+02 + 3 -1.3528903738169066e+02 -3.8704313358319990e+02 -1.4568978437133106e+02 + 4 -7.8711096705893366e+00 2.1350518625373538e+00 -5.5954532185548134e+00 + 5 -2.5176757268228540e+00 -4.0521510681020239e+00 1.2152704057877019e+01 + 6 -8.3190662465252137e+02 9.6394149462625603e+02 1.1509093566509248e+03 + 7 5.8203388932513583e+01 -3.3608997951626793e+02 -1.7179617996573040e+03 + 8 1.4451392284291535e+02 -1.0927475861088995e+02 3.9990593492420442e+02 9 7.9156945283097443e+01 8.5273009783986538e+01 3.5032175698445189e+02 - 10 5.3118875219105371e+02 -6.1040990859419412e+02 -1.8355872642619292e+02 - 11 -2.3530157267965524e+00 -5.9077640073819726e+00 -9.6590723955414290e+00 - 12 1.7527155146800425e+01 1.0633119523437513e+01 -7.9254398064483160e+00 - 13 8.0986409579532950e+00 -3.2098088264781550e+00 -1.4896399843793842e-01 - 14 -3.3852721292265158e+00 6.8636181241903627e-01 -8.7507190862499868e+00 - 15 -2.0454999188605297e-01 8.4846165523049901e+00 3.0131615419406712e+00 + 10 5.3118875219105360e+02 -6.1040990859419412e+02 -1.8355872642619292e+02 + 11 -2.3530157267965532e+00 -5.9077640073819717e+00 -9.6590723955414290e+00 + 12 1.7527155146800425e+01 1.0633119523437511e+01 -7.9254398064483169e+00 + 13 8.0986409579532967e+00 -3.2098088264781546e+00 -1.4896399843793839e-01 + 14 -3.3852721292265153e+00 6.8636181241903649e-01 -8.7507190862499868e+00 + 15 -2.0454999188605300e-01 8.4846165523049883e+00 3.0131615419406712e+00 16 4.6326310311812108e+02 -3.3087715736498188e+02 -1.1893024561782554e+03 - 17 -4.5334300923766733e+02 3.1554283255882575e+02 1.2058417793481203e+03 - 18 -1.8862623280672661e-02 -3.3402010907951661e-02 3.1000479299095263e-02 - 19 3.1843079640570047e-04 -2.3918627818763423e-04 1.7427252638513439e-03 - 20 -9.9760831209706009e-04 -1.0209184826753086e-03 3.6910972636601454e-04 + 17 -4.5334300923766727e+02 3.1554283255882569e+02 1.2058417793481203e+03 + 18 -1.8862623280672661e-02 -3.3402010907951661e-02 3.1000479299095260e-02 + 19 3.1843079640570047e-04 -2.3918627818763426e-04 1.7427252638513439e-03 + 20 -9.9760831209706009e-04 -1.0209184826753090e-03 3.6910972636601454e-04 21 -7.1566125273265186e+01 -8.1615678329920655e+01 2.2589561408339890e+02 - 22 -1.0808835729977497e+02 -2.6193787235943894e+01 -1.6957904943161401e+02 - 23 1.7964455474779490e+02 1.0782097695276948e+02 -5.6305786479140629e+01 + 22 -1.0808835729977498e+02 -2.6193787235943887e+01 -1.6957904943161401e+02 + 23 1.7964455474779487e+02 1.0782097695276950e+02 -5.6305786479140636e+01 24 3.6591406576584546e+01 -2.1181587621785579e+02 1.1218301872572377e+02 25 -1.4851489147738798e+02 2.3907118122949061e+01 -1.2485634873166291e+02 - 26 1.1191129453598218e+02 1.8789774664223384e+02 1.2650137204319906e+01 + 26 1.1191129453598219e+02 1.8789774664223384e+02 1.2650137204319904e+01 27 5.1810388677546001e+01 -2.2705458321213797e+02 9.0849111082069669e+01 28 -1.8041307121444069e+02 7.7534042932772905e+01 -1.2206956760706598e+02 29 1.2861057254925012e+02 1.4952711274394568e+02 3.1216025556267880e+01 run_vdwl: 719.4432816774653 run_coul: 0 run_stress: ! |2- - 2.1330153957371017e+03 2.1547728168285516e+03 4.3976497417710116e+03 -7.3873328448298525e+02 4.1743821105368760e+01 6.2788012209191072e+02 + 2.1330153957371017e+03 2.1547728168285516e+03 4.3976497417710125e+03 -7.3873328448298525e+02 4.1743821105370067e+01 6.2788012209191027e+02 run_forces: ! |2 - 1 -2.0299419751359164e+01 2.6686193378823020e+02 3.2358785870694010e+02 + 1 -2.0299419751359164e+01 2.6686193378823020e+02 3.2358785870694015e+02 2 1.5298617928491225e+02 1.2596516341409203e+02 -1.7961292655338619e+02 - 3 -1.3353630652439830e+02 -3.7923748696131315e+02 -1.4291839793625815e+02 + 3 -1.3353630652439830e+02 -3.7923748696131315e+02 -1.4291839793625817e+02 4 -7.8374717836161762e+00 2.1276610789823409e+00 -5.5845014473820616e+00 - 5 -2.5014258630866721e+00 -4.0250131424704412e+00 1.2103512372025637e+01 + 5 -2.5014258630866735e+00 -4.0250131424704412e+00 1.2103512372025639e+01 6 -8.0681462887292457e+02 9.2165637136761688e+02 1.0270795806932783e+03 - 7 5.5780279349903516e+01 -3.1117530951561662e+02 -1.5746991292869018e+03 + 7 5.5780279349903523e+01 -3.1117530951561656e+02 -1.5746991292869018e+03 8 1.3452983055535049e+02 -1.0064659350255911e+02 3.8851791558207651e+02 - 9 7.6746213883425980e+01 8.2501469877402130e+01 3.3944351200617888e+02 + 9 7.6746213883425980e+01 8.2501469877402130e+01 3.3944351200617882e+02 10 5.2128033527695595e+02 -5.9920098848285863e+02 -1.8126029815043339e+02 - 11 -2.3573118090915250e+00 -5.8616944550888368e+00 -9.6049808811326240e+00 - 12 1.7503975847822900e+01 1.0626930310560816e+01 -8.0603160272054986e+00 - 13 8.0530313322973104e+00 -3.1756495170399108e+00 -1.4618315664740528e-01 - 14 -3.3416065168069768e+00 6.6492606336082150e-01 -8.6345131440469700e+00 - 15 -2.2253843262374914e-01 8.5025661635348762e+00 3.0369735873081622e+00 + 11 -2.3573118090915246e+00 -5.8616944550888359e+00 -9.6049808811326205e+00 + 12 1.7503975847822900e+01 1.0626930310560814e+01 -8.0603160272054968e+00 + 13 8.0530313322973104e+00 -3.1756495170399117e+00 -1.4618315664740528e-01 + 14 -3.3416065168069773e+00 6.6492606336082150e-01 -8.6345131440469700e+00 + 15 -2.2253843262374914e-01 8.5025661635348779e+00 3.0369735873081622e+00 16 4.3476311264989465e+02 -3.1171086735551415e+02 -1.1135217194927448e+03 - 17 -4.2469846140777139e+02 2.9615411776780593e+02 1.1302573488400665e+03 - 18 -1.8849981672825911e-02 -3.3371636477421307e-02 3.0986293443778734e-02 - 19 3.0940277774414016e-04 -2.4634536455373038e-04 1.7433360008861014e-03 - 20 -9.8648131277150747e-04 -1.0112587134526948e-03 3.6932948773965417e-04 + 17 -4.2469846140777133e+02 2.9615411776780593e+02 1.1302573488400669e+03 + 18 -1.8849981672825908e-02 -3.3371636477421307e-02 3.0986293443778727e-02 + 19 3.0940277774414027e-04 -2.4634536455373044e-04 1.7433360008861016e-03 + 20 -9.8648131277150790e-04 -1.0112587134526946e-03 3.6932948773965417e-04 21 -7.0490745283106378e+01 -7.9749153581142139e+01 2.2171003384646431e+02 - 22 -1.0638717908920071e+02 -2.5949502163177975e+01 -1.6645589526812273e+02 - 23 1.7686797710735033e+02 1.0571018898885515e+02 -5.5243337084099373e+01 + 22 -1.0638717908920071e+02 -2.5949502163177968e+01 -1.6645589526812276e+02 + 23 1.7686797710735027e+02 1.0571018898885514e+02 -5.5243337084099387e+01 24 3.8206017656281375e+01 -2.1022820141992960e+02 1.1260711266189014e+02 25 -1.4918881473530880e+02 2.3762151395876508e+01 -1.2549188139143085e+02 26 1.1097059498808308e+02 1.8645503634228518e+02 1.2861559677865248e+01 From 8511aae2110870d9a45ffc58d5bb965e523ef6f0 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 21:18:20 -0500 Subject: [PATCH 09/79] add OPENMP package version of pair style lepton --- src/OPENMP/pair_lepton_omp.cpp | 168 +++++++++++++++++++++++++++++++++ src/OPENMP/pair_lepton_omp.h | 48 ++++++++++ 2 files changed, 216 insertions(+) create mode 100644 src/OPENMP/pair_lepton_omp.cpp create mode 100644 src/OPENMP/pair_lepton_omp.h diff --git a/src/OPENMP/pair_lepton_omp.cpp b/src/OPENMP/pair_lepton_omp.cpp new file mode 100644 index 0000000000..35f8e2c89d --- /dev/null +++ b/src/OPENMP/pair_lepton_omp.cpp @@ -0,0 +1,168 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + This software is distributed under the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing author: Axel Kohlmeyer (Temple U) +------------------------------------------------------------------------- */ + +#include "pair_lepton_omp.h" + +#include "atom.h" +#include "comm.h" +#include "force.h" +#include "neigh_list.h" +#include "suffix.h" + +#include + +#include "LMP_Lepton.h" +#include "omp_compat.h" +using namespace LAMMPS_NS; + +/* ---------------------------------------------------------------------- */ + +PairLeptonOMP::PairLeptonOMP(LAMMPS *lmp) : PairLepton(lmp), ThrOMP(lmp, THR_PAIR) +{ + suffix_flag |= Suffix::OMP; + respa_enable = 0; +} + +/* ---------------------------------------------------------------------- */ + +void PairLeptonOMP::compute(int eflag, int vflag) +{ + ev_init(eflag, vflag); + + const int nall = atom->nlocal + atom->nghost; + const int nthreads = comm->nthreads; + const int inum = list->inum; + +#if defined(_OPENMP) +#pragma omp parallel LMP_DEFAULT_NONE LMP_SHARED(eflag, vflag) +#endif + { + int ifrom, ito, tid; + + loop_setup_thr(ifrom, ito, tid, inum, nthreads); + ThrData *thr = fix->get_thr(tid); + thr->timer(Timer::START); + ev_setup_thr(eflag, vflag, nall, eatom, vatom, nullptr, thr); + + if (evflag) { + if (eflag) { + if (force->newton_pair) + eval<1, 1, 1>(ifrom, ito, thr); + else + eval<1, 1, 0>(ifrom, ito, thr); + } else { + if (force->newton_pair) + eval<1, 0, 1>(ifrom, ito, thr); + else + eval<1, 0, 0>(ifrom, ito, thr); + } + } else { + if (force->newton_pair) + eval<0, 0, 1>(ifrom, ito, thr); + else + eval<0, 0, 0>(ifrom, ito, thr); + } + + thr->timer(Timer::PAIR); + reduce_thr(this, eflag, vflag, thr); + } // end of omp parallel region +} + +template +void PairLeptonOMP::eval(int iifrom, int iito, ThrData *const thr) +{ + const auto *_noalias const x = (dbl3_t *) atom->x[0]; + auto *_noalias const f = (dbl3_t *) thr->get_f()[0]; + const int *_noalias const type = atom->type; + const int nlocal = atom->nlocal; + const double *_noalias const special_lj = force->special_lj; + + const int *const ilist = list->ilist; + const int *const numneigh = list->numneigh; + const int *const *const firstneigh = list->firstneigh; + double fxtmp, fytmp, fztmp; + + std::vector force; + std::vector epot; + for (const auto &expr : expressions) { + auto parsed = LMP_Lepton::Parser::parse(expr); + force.emplace_back(parsed.differentiate("r").createCompiledExpression()); + if (EFLAG) epot.emplace_back(parsed.createCompiledExpression()); + } + + // loop over neighbors of my atoms + + for (int ii = iifrom; ii < iito; ++ii) { + const int i = ilist[ii]; + const double xtmp = x[i].x; + const double ytmp = x[i].y; + const double ztmp = x[i].z; + const int itype = type[i]; + const int *jlist = firstneigh[i]; + const int jnum = numneigh[i]; + fxtmp = fytmp = fztmp = 0.0; + + for (int jj = 0; jj < jnum; jj++) { + int j = jlist[jj]; + const double factor_lj = special_lj[sbmask(j)]; + j &= NEIGHMASK; + const int jtype = type[j]; + + const double delx = xtmp - x[j].x; + const double dely = ytmp - x[j].y; + const double delz = ztmp - x[j].z; + const double rsq = delx * delx + dely * dely + delz * delz; + + if (rsq < cutsq[itype][jtype]) { + const double r = sqrt(rsq); + const int idx = type2expression[itype][jtype]; + double &r_for = force[idx].getVariableReference("r"); + r_for = r; + const double fpair = -force[idx].evaluate() / r * factor_lj; + + fxtmp += delx * fpair; + fytmp += dely * fpair; + fztmp += delz * fpair; + if (NEWTON_PAIR || j < nlocal) { + f[j].x -= delx * fpair; + f[j].y -= dely * fpair; + f[j].z -= delz * fpair; + } + + double evdwl = 0.0; + if (EFLAG) { + double &r_pot = epot[idx].getVariableReference("r"); + r_pot = r; + evdwl = factor_lj * epot[idx].evaluate(); + } + + if (EVFLAG) + ev_tally_thr(this, i, j, nlocal, NEWTON_PAIR, evdwl, 0.0, fpair, delx, dely, delz, thr); + } + } + f[i].x += fxtmp; + f[i].y += fytmp; + f[i].z += fztmp; + } +} + +/* ---------------------------------------------------------------------- */ + +double PairLeptonOMP::memory_usage() +{ + double bytes = memory_usage_thr(); + bytes += PairLepton::memory_usage(); + + return bytes; +} diff --git a/src/OPENMP/pair_lepton_omp.h b/src/OPENMP/pair_lepton_omp.h new file mode 100644 index 0000000000..7d658dba1c --- /dev/null +++ b/src/OPENMP/pair_lepton_omp.h @@ -0,0 +1,48 @@ +/* -*- c++ -*- ---------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing author: Axel Kohlmeyer (Temple U) +------------------------------------------------------------------------- */ + +#ifdef PAIR_CLASS +// clang-format off +PairStyle(lepton/omp,PairLeptonOMP); +// clang-format on +#else + +#ifndef LMP_PAIR_LEPTON_OMP_H +#define LMP_PAIR_LEPTON_OMP_H + +#include "pair_lepton.h" +#include "thr_omp.h" + +namespace LAMMPS_NS { + +class PairLeptonOMP : public PairLepton, public ThrOMP { + + public: + PairLeptonOMP(class LAMMPS *); + + void compute(int, int) override; + double memory_usage() override; + + private: + template + void eval(int ifrom, int ito, ThrData *const thr); +}; + +} // namespace LAMMPS_NS + +#endif +#endif From cf0fb7f5dff01191f590eac84e18519cc93e1027 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 21:25:41 -0500 Subject: [PATCH 10/79] build system updates for presets and dependencies --- cmake/presets/all_off.cmake | 1 + cmake/presets/all_on.cmake | 1 + cmake/presets/mingw-cross.cmake | 1 + cmake/presets/most.cmake | 1 + cmake/presets/nolib.cmake | 1 + cmake/presets/windows.cmake | 1 + src/.gitignore | 3 ++ src/Depend.sh | 4 ++ src/LEPTON/Install.sh | 68 +++++++++++++++++++++++++++++++++ src/Makefile | 5 ++- 10 files changed, 85 insertions(+), 1 deletion(-) create mode 100755 src/LEPTON/Install.sh diff --git a/cmake/presets/all_off.cmake b/cmake/presets/all_off.cmake index 9127305528..3d5ee95b3d 100644 --- a/cmake/presets/all_off.cmake +++ b/cmake/presets/all_off.cmake @@ -44,6 +44,7 @@ set(ALL_PACKAGES KSPACE LATBOLTZ LATTE + LEPTON MACHDYN MANIFOLD MANYBODY diff --git a/cmake/presets/all_on.cmake b/cmake/presets/all_on.cmake index 0a001bdc56..474051f6ec 100644 --- a/cmake/presets/all_on.cmake +++ b/cmake/presets/all_on.cmake @@ -46,6 +46,7 @@ set(ALL_PACKAGES KSPACE LATBOLTZ LATTE + LEPTON MACHDYN MANIFOLD MANYBODY diff --git a/cmake/presets/mingw-cross.cmake b/cmake/presets/mingw-cross.cmake index 2d74657394..6c6170acd3 100644 --- a/cmake/presets/mingw-cross.cmake +++ b/cmake/presets/mingw-cross.cmake @@ -36,6 +36,7 @@ set(WIN_PACKAGES INTERLAYER KSPACE LATTE + LEPTON MACHDYN MANIFOLD MANYBODY diff --git a/cmake/presets/most.cmake b/cmake/presets/most.cmake index 5dd9a2b78b..0d63140506 100644 --- a/cmake/presets/most.cmake +++ b/cmake/presets/most.cmake @@ -35,6 +35,7 @@ set(ALL_PACKAGES GRANULAR INTERLAYER KSPACE + LEPTON MACHDYN MANYBODY MC diff --git a/cmake/presets/nolib.cmake b/cmake/presets/nolib.cmake index b6567ad617..b022d4bb55 100644 --- a/cmake/presets/nolib.cmake +++ b/cmake/presets/nolib.cmake @@ -13,6 +13,7 @@ set(PACKAGES_WITH_LIB KOKKOS LATBOLTZ LATTE + LEPTON MACHDYN MDI MESONT diff --git a/cmake/presets/windows.cmake b/cmake/presets/windows.cmake index 21be0efefb..68bc5c4335 100644 --- a/cmake/presets/windows.cmake +++ b/cmake/presets/windows.cmake @@ -29,6 +29,7 @@ set(WIN_PACKAGES GRANULAR INTERLAYER KSPACE + LEPTON MANIFOLD MANYBODY MC diff --git a/src/.gitignore b/src/.gitignore index 19f4d924b0..bc77a647a2 100644 --- a/src/.gitignore +++ b/src/.gitignore @@ -54,6 +54,9 @@ /pair_kim.cpp /pair_kim.h +/pair_lepton.cpp +/pair_lepton.h + /pair_pace.cpp /pair_pace.h diff --git a/src/Depend.sh b/src/Depend.sh index 90dfdbba7a..50b34db899 100755 --- a/src/Depend.sh +++ b/src/Depend.sh @@ -118,6 +118,10 @@ if (test $1 = "KSPACE") then depend FEP fi +if (test $1 = "LEPTON") then + depend OPENMP +fi + if (test $1 = "MANYBODY") then depend ATC depend GPU diff --git a/src/LEPTON/Install.sh b/src/LEPTON/Install.sh new file mode 100755 index 0000000000..937ddf28bc --- /dev/null +++ b/src/LEPTON/Install.sh @@ -0,0 +1,68 @@ +# Install/unInstall package files in LAMMPS +# mode = 0/1/2 for uninstall/install/update + +mode=$1 + +# arg1 = file, arg2 = file it depends on + +# enforce using portable C locale +LC_ALL=C +export LC_ALL + +action () { + if (test $mode = 0) then + rm -f ../$1 + elif (! cmp -s $1 ../$1) then + if (test -z "$2" || test -e ../$2) then + cp $1 .. + if (test $mode = 2) then + echo " updating src/$1" + fi + fi + elif (test -n "$2") then + if (test ! -e ../$2) then + rm -f ../$1 + fi + fi +} + +# all package files with no dependencies + +for file in *.cpp *.h; do + test -f ${file} && action $file +done + +# edit 2 Makefile.package files to include/exclude package info + +if (test $1 = 1) then + + if (test -e ../Makefile.package) then + sed -i -e 's/[^ \t]*lepton[^ \t]* //g' ../Makefile.package + sed -i -e 's|^PKG_INC =[ \t]*|&-I..\/..\/lib\/lepton\/include |' ../Makefile.package + sed -i -e 's|^PKG_PATH =[ \t]*|&-L..\/..\/lib\/lepton$(LIBOBJDIR) |' ../Makefile.package + sed -i -e 's|^PKG_LIB =[ \t]*|&-llmplepton |' ../Makefile.package + sed -i -e 's|^PKG_SYSINC =[ \t]*|&$(lepton_SYSINC) |' ../Makefile.package + sed -i -e 's|^PKG_SYSLIB =[ \t]*|&$(lepton_SYSLIB) |' ../Makefile.package + sed -i -e 's|^PKG_SYSPATH =[ \t]*|&$(lepton_SYSPATH) |' ../Makefile.package + fi + + if (test -e ../Makefile.package.settings) then + sed -i -e '/^[ \t]*include.*lepton.*$/d' ../Makefile.package.settings + # multiline form needed for BSD sed on Macs + sed -i -e '4 i \ +include ..\/..\/lib\/lepton\/Makefile.lammps +' ../Makefile.package.settings + + fi + +elif (test $1 = 0) then + + if (test -e ../Makefile.package) then + sed -i -e 's/[^ \t]*lepton[^ \t]* //g' ../Makefile.package + fi + + if (test -e ../Makefile.package.settings) then + sed -i -e '/^[ \t]*include.*lepton.*$/d' ../Makefile.package.settings + fi + +fi diff --git a/src/Makefile b/src/Makefile index 6f3ece5376..494c64699e 100644 --- a/src/Makefile +++ b/src/Makefile @@ -88,6 +88,7 @@ PACKAGE = \ kspace \ latboltz \ latte \ + lepton \ machdyn \ manifold \ manybody \ @@ -212,6 +213,7 @@ PACKLIB = \ kim \ kokkos \ latte \ + lepton \ mpiio \ mscg \ poems \ @@ -224,6 +226,7 @@ PACKLIB = \ h5md \ ml-hdnnp \ latboltz \ + lepton \ mdi \ mesont \ molfile \ @@ -240,7 +243,7 @@ PACKLIB = \ PACKSYS = compress latboltz mpiio python -PACKINT = atc awpmd colvars electrode gpu kokkos mesont ml-pod poems +PACKINT = atc awpmd colvars electrode gpu kokkos lepton mesont ml-pod poems PACKEXT = \ adios \ From e2f9d594840d8c7cec0e354135a797350ac4a8c2 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 21:34:56 -0500 Subject: [PATCH 11/79] whitespace fixes --- .../include/lepton/CompiledExpression.h | 4 +- lib/lepton/include/lepton/CustomFunction.h | 2 +- lib/lepton/include/lepton/ExpressionProgram.h | 2 +- lib/lepton/include/lepton/Operation.h | 2 +- lib/lepton/src/CompiledExpression.cpp | 44 +++++++++---------- lib/lepton/src/MSVC_erfc.h | 4 +- lib/lepton/src/ParsedExpression.cpp | 6 +-- lib/lepton/src/Parser.cpp | 2 +- 8 files changed, 33 insertions(+), 33 deletions(-) diff --git a/lib/lepton/include/lepton/CompiledExpression.h b/lib/lepton/include/lepton/CompiledExpression.h index bf076b0d8b..8ead5ce96f 100644 --- a/lib/lepton/include/lepton/CompiledExpression.h +++ b/lib/lepton/include/lepton/CompiledExpression.h @@ -52,9 +52,9 @@ class ParsedExpression; * A CompiledExpression is a highly optimized representation of an expression for cases when you want to evaluate * it many times as quickly as possible. You should treat it as an opaque object; none of the internal representation * is visible. - * + * * A CompiledExpression is created by calling createCompiledExpression() on a ParsedExpression. - * + * * WARNING: CompiledExpression is NOT thread safe. You should never access a CompiledExpression from two threads at * the same time. */ diff --git a/lib/lepton/include/lepton/CustomFunction.h b/lib/lepton/include/lepton/CustomFunction.h index b8cbee8c96..c4b9932cd8 100644 --- a/lib/lepton/include/lepton/CustomFunction.h +++ b/lib/lepton/include/lepton/CustomFunction.h @@ -83,7 +83,7 @@ class LEPTON_EXPORT PlaceholderFunction : public CustomFunction { public: /** * Create a Placeholder function. - * + * * @param numArgs the number of arguments the function expects */ PlaceholderFunction(int numArgs) : numArgs(numArgs) { diff --git a/lib/lepton/include/lepton/ExpressionProgram.h b/lib/lepton/include/lepton/ExpressionProgram.h index 4fba4051e4..3737cf8082 100644 --- a/lib/lepton/include/lepton/ExpressionProgram.h +++ b/lib/lepton/include/lepton/ExpressionProgram.h @@ -67,7 +67,7 @@ public: const Operation& getOperation(int index) const; /** * Change an Operation in this program. - * + * * The Operation must have been allocated on the heap with the "new" operator. * The ExpressionProgram assumes ownership of it and will delete it when it * is no longer needed. diff --git a/lib/lepton/include/lepton/Operation.h b/lib/lepton/include/lepton/Operation.h index b27b25d3d8..848910f6e4 100644 --- a/lib/lepton/include/lepton/Operation.h +++ b/lib/lepton/include/lepton/Operation.h @@ -1017,7 +1017,7 @@ public: double evaluate(double* args, const std::map& variables) const { if (isIntPower) { // Integer powers can be computed much more quickly by repeated multiplication. - + int exponent = intValue; double base = args[0]; if (exponent < 0) { diff --git a/lib/lepton/src/CompiledExpression.cpp b/lib/lepton/src/CompiledExpression.cpp index 7805c4674d..67cf196ebb 100644 --- a/lib/lepton/src/CompiledExpression.cpp +++ b/lib/lepton/src/CompiledExpression.cpp @@ -84,17 +84,17 @@ CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expr void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vector >& temps) { if (findTempIndex(node, temps) != -1) return; // We have already processed a node identical to this one. - + // Process the child nodes. - + vector args; for (int i = 0; i < node.getChildren().size(); i++) { compileExpression(node.getChildren()[i], temps); args.push_back(findTempIndex(node.getChildren()[i], temps)); } - + // Process this node. - + if (node.getOperation().getId() == Operation::VARIABLE) { variableIndices[node.getOperation().getName()] = (int) workspace.size(); variableNames.insert(node.getOperation().getName()); @@ -108,7 +108,7 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there. else { // If the arguments are sequential, we can just pass a pointer to the first one. - + bool sequential = true; for (int i = 1; i < args.size(); i++) if (args[i] != args[i-1]+1) @@ -148,12 +148,12 @@ void CompiledExpression::setVariableLocations(map& variableLoca variablePointers = variableLocations; #ifdef LEPTON_USE_JIT // Rebuild the JIT code. - + if (workspace.size() > 0) generateJitCode(); #else // Make a list of all variables we will need to copy before evaluating the expression. - + variablesToCopy.clear(); for (map::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) { map::iterator pointer = variablePointers.find(iter->first); @@ -171,7 +171,7 @@ double CompiledExpression::evaluate() const { *variablesToCopy[i].first = *variablesToCopy[i].second; // Loop over the operations and evaluate each one. - + for (int step = 0; step < operation.size(); step++) { const vector& args = arguments[step]; if (args.size() == 1) @@ -202,9 +202,9 @@ void CompiledExpression::generateJitCode() { workspaceVar[i] = c.newXmmSd(); X86Gp argsPointer = c.newIntPtr(); c.mov(argsPointer, imm_ptr(&argValues[0])); - + // Load the arguments into variables. - + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { map::iterator index = variableIndices.find(*iter); X86Gp variablePointer = c.newIntPtr(); @@ -213,11 +213,11 @@ void CompiledExpression::generateJitCode() { } // Make a list of all constants that will be needed for evaluation. - + vector operationConstantIndex(operation.size(), -1); for (int step = 0; step < (int) operation.size(); step++) { // Find the constant value (if any) used by this operation. - + Operation& op = *operation[step]; double value; if (op.getId() == Operation::CONSTANT) @@ -234,9 +234,9 @@ void CompiledExpression::generateJitCode() { value = 1.0; else continue; - + // See if we already have a variable for this constant. - + for (int i = 0; i < (int) constants.size(); i++) if (value == constants[i]) { operationConstantIndex[step] = i; @@ -247,9 +247,9 @@ void CompiledExpression::generateJitCode() { constants.push_back(value); } } - + // Load constants into variables. - + vector constantVar(constants.size()); if (constants.size() > 0) { X86Gp constantsPointer = c.newIntPtr(); @@ -259,21 +259,21 @@ void CompiledExpression::generateJitCode() { c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); } } - + // Evaluate the operations. - + for (int step = 0; step < (int) operation.size(); step++) { Operation& op = *operation[step]; vector args = arguments[step]; if (args.size() == 1) { // One or more sequential arguments. Fill out the list. - + for (int i = 1; i < op.getNumArguments(); i++) args.push_back(args[0]+i); } - + // Generate instructions to execute this operation. - + switch (op.getId()) { case Operation::CONSTANT: c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); @@ -382,7 +382,7 @@ void CompiledExpression::generateJitCode() { break; default: // Just invoke evaluateOperation(). - + for (int i = 0; i < (int) args.size(); i++) c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); X86Gp fn = c.newIntPtr(); diff --git a/lib/lepton/src/MSVC_erfc.h b/lib/lepton/src/MSVC_erfc.h index 2c6b619e89..b1cd87a289 100644 --- a/lib/lepton/src/MSVC_erfc.h +++ b/lib/lepton/src/MSVC_erfc.h @@ -3,7 +3,7 @@ /* * Up to version 11 (VC++ 2012), Microsoft does not support the - * standard C99 erf() and erfc() functions so we have to fake them here. + * standard C99 erf() and erfc() functions so we have to fake them here. * These were added in version 12 (VC++ 2013), which sets _MSC_VER=1800 * (VC11 has _MSC_VER=1700). */ @@ -15,7 +15,7 @@ #endif #if defined(_MSC_VER) -#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12 +#if _MSC_VER <= 1700 // 1700 is VC11, 1800 is VC12 /*************************** * erf.cpp * author: Steve Strand diff --git a/lib/lepton/src/ParsedExpression.cpp b/lib/lepton/src/ParsedExpression.cpp index 13ebbf2dc2..c6092a2dc2 100644 --- a/lib/lepton/src/ParsedExpression.cpp +++ b/lib/lepton/src/ParsedExpression.cpp @@ -121,10 +121,10 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio vector children(node.getChildren().size()); for (int i = 0; i < (int) children.size(); i++) children[i] = substituteSimplerExpression(node.getChildren()[i]); - + // Collect some info on constant expressions in children bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant? - bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant? + bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant? double first, second; // if yes, value of first and second child if (first_const) first = getConstantValue(children[0]); @@ -174,7 +174,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio break; } case Operation::MULTIPLY: - { + { if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0 return ExpressionTreeNode(new Operation::Constant(0.0)); if (first_const && first == 1.0) // Multiply by 1 diff --git a/lib/lepton/src/Parser.cpp b/lib/lepton/src/Parser.cpp index c0b4c185e8..e7d87ba289 100644 --- a/lib/lepton/src/Parser.cpp +++ b/lib/lepton/src/Parser.cpp @@ -66,7 +66,7 @@ private: string Parser::trim(const string& expression) { // Remove leading and trailing spaces. - + int start, end; for (start = 0; start < (int) expression.size() && isspace(expression[start]); start++) ; From 4293771ae8cb5e406f67007ea4e2ad378a7c8718 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 21:53:25 -0500 Subject: [PATCH 12/79] silence compiler warnings --- lib/lepton/include/lepton/CustomFunction.h | 4 +- lib/lepton/include/lepton/Operation.h | 78 ++++++++++---------- lib/lepton/src/CompiledExpression.cpp | 10 +-- lib/lepton/src/ExpressionTreeNode.cpp | 8 +- lib/lepton/src/Operation.cpp | 86 +++++++++++----------- lib/lepton/src/ParsedExpression.cpp | 2 + lib/lepton/src/Parser.cpp | 10 +-- 7 files changed, 100 insertions(+), 98 deletions(-) diff --git a/lib/lepton/include/lepton/CustomFunction.h b/lib/lepton/include/lepton/CustomFunction.h index c4b9932cd8..4b8121a87f 100644 --- a/lib/lepton/include/lepton/CustomFunction.h +++ b/lib/lepton/include/lepton/CustomFunction.h @@ -91,10 +91,10 @@ public: int getNumArguments() const { return numArgs; } - double evaluate(const double* arguments) const { + double evaluate(const double* ) const { return 0.0; } - double evaluateDerivative(const double* arguments, const int* derivOrder) const { + double evaluateDerivative(const double* , const int* ) const { return 0.0; } CustomFunction* clone() const { diff --git a/lib/lepton/include/lepton/Operation.h b/lib/lepton/include/lepton/Operation.h index 848910f6e4..bde9cfe37f 100644 --- a/lib/lepton/include/lepton/Operation.h +++ b/lib/lepton/include/lepton/Operation.h @@ -177,7 +177,7 @@ public: Operation* clone() const { return new Constant(value); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* , const std::map& ) const { return value; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -208,7 +208,7 @@ public: Operation* clone() const { return new Variable(name); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* , const std::map& variables) const { std::map::const_iterator iter = variables.find(name); if (iter == variables.end()) throw Exception("No value specified for variable "+name); @@ -253,7 +253,7 @@ public: clone->derivOrder = derivOrder; return clone; } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { if (isDerivative) return function->evaluateDerivative(args, &derivOrder[0]); return function->evaluate(args); @@ -289,7 +289,7 @@ public: Operation* clone() const { return new Add(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return args[0]+args[1]; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -317,7 +317,7 @@ public: Operation* clone() const { return new Subtract(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return args[0]-args[1]; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -342,7 +342,7 @@ public: Operation* clone() const { return new Multiply(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return args[0]*args[1]; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -370,7 +370,7 @@ public: Operation* clone() const { return new Divide(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return args[0]/args[1]; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -395,7 +395,7 @@ public: Operation* clone() const { return new Power(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::pow(args[0], args[1]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -420,7 +420,7 @@ public: Operation* clone() const { return new Negate(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return -args[0]; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -442,7 +442,7 @@ public: Operation* clone() const { return new Sqrt(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::sqrt(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -464,7 +464,7 @@ public: Operation* clone() const { return new Exp(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::exp(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -486,7 +486,7 @@ public: Operation* clone() const { return new Log(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::log(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -508,7 +508,7 @@ public: Operation* clone() const { return new Sin(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::sin(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -530,7 +530,7 @@ public: Operation* clone() const { return new Cos(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::cos(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -552,7 +552,7 @@ public: Operation* clone() const { return new Sec(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return 1.0/std::cos(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -574,7 +574,7 @@ public: Operation* clone() const { return new Csc(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return 1.0/std::sin(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -596,7 +596,7 @@ public: Operation* clone() const { return new Tan(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::tan(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -618,7 +618,7 @@ public: Operation* clone() const { return new Cot(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return 1.0/std::tan(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -640,7 +640,7 @@ public: Operation* clone() const { return new Asin(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::asin(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -662,7 +662,7 @@ public: Operation* clone() const { return new Acos(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::acos(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -684,7 +684,7 @@ public: Operation* clone() const { return new Atan(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::atan(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -706,7 +706,7 @@ public: Operation* clone() const { return new Atan2(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::atan2(args[0], args[1]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -728,7 +728,7 @@ public: Operation* clone() const { return new Sinh(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::sinh(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -750,7 +750,7 @@ public: Operation* clone() const { return new Cosh(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::cosh(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -772,7 +772,7 @@ public: Operation* clone() const { return new Tanh(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::tanh(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -834,7 +834,7 @@ public: Operation* clone() const { return new Step(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return (args[0] >= 0.0 ? 1.0 : 0.0); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -856,7 +856,7 @@ public: Operation* clone() const { return new Delta(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return (args[0] == 0.0 ? 1.0 : 0.0); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -878,7 +878,7 @@ public: Operation* clone() const { return new Square(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return args[0]*args[0]; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -900,7 +900,7 @@ public: Operation* clone() const { return new Cube(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return args[0]*args[0]*args[0]; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -922,7 +922,7 @@ public: Operation* clone() const { return new Reciprocal(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return 1.0/args[0]; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -946,7 +946,7 @@ public: Operation* clone() const { return new AddConstant(value); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return args[0]+value; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -979,7 +979,7 @@ public: Operation* clone() const { return new MultiplyConstant(value); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return args[0]*value; } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -1014,7 +1014,7 @@ public: Operation* clone() const { return new PowerConstant(value); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { if (isIntPower) { // Integer powers can be computed much more quickly by repeated multiplication. @@ -1069,7 +1069,7 @@ public: Operation* clone() const { return new Min(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { // parens around (std::min) are workaround for horrible microsoft max/min macro trouble return (std::min)(args[0], args[1]); } @@ -1092,7 +1092,7 @@ public: Operation* clone() const { return new Max(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { // parens around (std::min) are workaround for horrible microsoft max/min macro trouble return (std::max)(args[0], args[1]); } @@ -1115,7 +1115,7 @@ public: Operation* clone() const { return new Abs(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::abs(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -1138,7 +1138,7 @@ public: Operation* clone() const { return new Floor(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::floor(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -1160,7 +1160,7 @@ public: Operation* clone() const { return new Ceil(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map& ) const { return std::ceil(args[0]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; @@ -1182,7 +1182,7 @@ public: Operation* clone() const { return new Select(); } - double evaluate(double* args, const std::map& variables) const { + double evaluate(double* args, const std::map&) const { return (args[0] != 0.0 ? args[1] : args[2]); } ExpressionTreeNode differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const; diff --git a/lib/lepton/src/CompiledExpression.cpp b/lib/lepton/src/CompiledExpression.cpp index 67cf196ebb..c6c1543ce4 100644 --- a/lib/lepton/src/CompiledExpression.cpp +++ b/lib/lepton/src/CompiledExpression.cpp @@ -88,7 +88,7 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto // Process the child nodes. vector args; - for (int i = 0; i < node.getChildren().size(); i++) { + for (int i = 0; i < (int)node.getChildren().size(); i++) { compileExpression(node.getChildren()[i], temps); args.push_back(findTempIndex(node.getChildren()[i], temps)); } @@ -110,7 +110,7 @@ void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vecto // If the arguments are sequential, we can just pass a pointer to the first one. bool sequential = true; - for (int i = 1; i < args.size(); i++) + for (int i = 1; i < (int)args.size(); i++) if (args[i] != args[i-1]+1) sequential = false; if (sequential) @@ -167,17 +167,17 @@ double CompiledExpression::evaluate() const { #ifdef LEPTON_USE_JIT return jitCode(); #else - for (int i = 0; i < variablesToCopy.size(); i++) + for (int i = 0; i < (int)variablesToCopy.size(); i++) *variablesToCopy[i].first = *variablesToCopy[i].second; // Loop over the operations and evaluate each one. - for (int step = 0; step < operation.size(); step++) { + for (int step = 0; step < (int)operation.size(); step++) { const vector& args = arguments[step]; if (args.size() == 1) workspace[target[step]] = operation[step]->evaluate(&workspace[args[0]], dummyVariables); else { - for (int i = 0; i < args.size(); i++) + for (int i = 0; i < (int)args.size(); i++) argValues[i] = workspace[args[i]]; workspace[target[step]] = operation[step]->evaluate(&argValues[0], dummyVariables); } diff --git a/lib/lepton/src/ExpressionTreeNode.cpp b/lib/lepton/src/ExpressionTreeNode.cpp index e4fbbc6f50..90020aa373 100644 --- a/lib/lepton/src/ExpressionTreeNode.cpp +++ b/lib/lepton/src/ExpressionTreeNode.cpp @@ -37,25 +37,25 @@ using namespace LMP_Lepton; using namespace std; ExpressionTreeNode::ExpressionTreeNode(Operation* operation, const vector& children) : operation(operation), children(children) { - if (operation->getNumArguments() != children.size()) + if (operation->getNumArguments() != (int)children.size()) throw Exception("wrong number of arguments to function: "+operation->getName()); } ExpressionTreeNode::ExpressionTreeNode(Operation* operation, const ExpressionTreeNode& child1, const ExpressionTreeNode& child2) : operation(operation) { children.push_back(child1); children.push_back(child2); - if (operation->getNumArguments() != children.size()) + if (operation->getNumArguments() != (int)children.size()) throw Exception("wrong number of arguments to function: "+operation->getName()); } ExpressionTreeNode::ExpressionTreeNode(Operation* operation, const ExpressionTreeNode& child) : operation(operation) { children.push_back(child); - if (operation->getNumArguments() != children.size()) + if (operation->getNumArguments() != (int)children.size()) throw Exception("wrong number of arguments to function: "+operation->getName()); } ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operation) { - if (operation->getNumArguments() != children.size()) + if (operation->getNumArguments() != (int)children.size()) throw Exception("wrong number of arguments to function: "+operation->getName()); } diff --git a/lib/lepton/src/Operation.cpp b/lib/lepton/src/Operation.cpp index 512f5db321..bec5686a74 100644 --- a/lib/lepton/src/Operation.cpp +++ b/lib/lepton/src/Operation.cpp @@ -37,25 +37,25 @@ using namespace LMP_Lepton; using namespace std; -double Operation::Erf::evaluate(double* args, const map& variables) const { +double Operation::Erf::evaluate(double* args, const map& ) const { return erf(args[0]); } -double Operation::Erfc::evaluate(double* args, const map& variables) const { +double Operation::Erfc::evaluate(double* args, const map& ) const { return erfc(args[0]); } -ExpressionTreeNode Operation::Constant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Constant::differentiate(const std::vector& , const std::vector& , const std::string& ) const { return ExpressionTreeNode(new Operation::Constant(0.0)); } -ExpressionTreeNode Operation::Variable::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Variable::differentiate(const std::vector& , const std::vector& , const std::string& variable) const { if (variable == name) return ExpressionTreeNode(new Operation::Constant(1.0)); return ExpressionTreeNode(new Operation::Constant(0.0)); } -ExpressionTreeNode Operation::Custom::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Custom::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { if (function->getNumArguments() == 0) return ExpressionTreeNode(new Operation::Constant(0.0)); ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]); @@ -67,21 +67,21 @@ ExpressionTreeNode Operation::Custom::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Add::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]); } -ExpressionTreeNode Operation::Subtract::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Subtract::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]); } -ExpressionTreeNode Operation::Multiply::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Multiply::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Add(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0])); } -ExpressionTreeNode Operation::Divide::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Divide::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Divide(), ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), @@ -89,7 +89,7 @@ ExpressionTreeNode Operation::Divide::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Add(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), @@ -104,11 +104,11 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Negate::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]); } -ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(0.5), ExpressionTreeNode(new Operation::Reciprocal(), @@ -116,32 +116,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Exp::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Exp(), children[0]), childDerivs[0]); } -ExpressionTreeNode Operation::Log::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Log::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), children[0]), childDerivs[0]); } -ExpressionTreeNode Operation::Sin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Sin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cos(), children[0]), childDerivs[0]); } -ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Sin(), children[0])), childDerivs[0]); } -ExpressionTreeNode Operation::Sec::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Sec::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sec(), children[0]), @@ -149,7 +149,7 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Csc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), @@ -158,14 +158,14 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Tan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Square(), ExpressionTreeNode(new Operation::Sec(), children[0])), childDerivs[0]); } -ExpressionTreeNode Operation::Cot::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Cot::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Square(), @@ -173,7 +173,7 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Asin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Sqrt(), @@ -183,7 +183,7 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Acos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -194,7 +194,7 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Atan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::AddConstant(1.0), @@ -202,7 +202,7 @@ ExpressionTreeNode Operation::Atan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Atan2::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Divide(), ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), @@ -212,21 +212,21 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Sinh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cosh(), children[0]), childDerivs[0]); } -ExpressionTreeNode Operation::Cosh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Cosh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sinh(), children[0]), childDerivs[0]); } -ExpressionTreeNode Operation::Tanh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Tanh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Constant(1.0)), @@ -235,7 +235,7 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Erf::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))), @@ -245,7 +245,7 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Erfc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))), @@ -255,29 +255,29 @@ ExpressionTreeNode Operation::Erfc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Step::differentiate(const std::vector& , const std::vector& , const std::string& ) const { return ExpressionTreeNode(new Operation::Constant(0.0)); } -ExpressionTreeNode Operation::Delta::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Delta::differentiate(const std::vector& , const std::vector& , const std::string& ) const { return ExpressionTreeNode(new Operation::Constant(0.0)); } -ExpressionTreeNode Operation::Square::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Square::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(2.0), children[0]), childDerivs[0]); } -ExpressionTreeNode Operation::Cube::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Cube::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(3.0), ExpressionTreeNode(new Operation::Square(), children[0])), childDerivs[0]); } -ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -285,16 +285,16 @@ ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { return childDerivs[0]; } -ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::MultiplyConstant(value), childDerivs[0]); } -ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(value), ExpressionTreeNode(new Operation::PowerConstant(value-1), @@ -302,7 +302,7 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Min::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); return ExpressionTreeNode(new Operation::Subtract(), @@ -311,7 +311,7 @@ ExpressionTreeNode Operation::Min::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Max::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); return ExpressionTreeNode(new Operation::Subtract(), @@ -320,7 +320,7 @@ ExpressionTreeNode Operation::Max::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Abs::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { ExpressionTreeNode step(new Operation::Step(), children[0]); return ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], @@ -328,15 +328,15 @@ ExpressionTreeNode Operation::Abs::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Floor::differentiate(const std::vector& , const std::vector& , const std::string& ) const { return ExpressionTreeNode(new Operation::Constant(0.0)); } -ExpressionTreeNode Operation::Ceil::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Ceil::differentiate(const std::vector& , const std::vector& , const std::string& ) const { return ExpressionTreeNode(new Operation::Constant(0.0)); } -ExpressionTreeNode Operation::Select::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& variable) const { +ExpressionTreeNode Operation::Select::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { vector derivChildren; derivChildren.push_back(children[0]); derivChildren.push_back(childDerivs[1]); diff --git a/lib/lepton/src/ParsedExpression.cpp b/lib/lepton/src/ParsedExpression.cpp index c6092a2dc2..1417551011 100644 --- a/lib/lepton/src/ParsedExpression.cpp +++ b/lib/lepton/src/ParsedExpression.cpp @@ -288,11 +288,13 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio { if (children[0].getOperation().getId() == Operation::SQUARE) // sqrt(square(x)) = abs(x) return ExpressionTreeNode(new Operation::Abs(), children[0].getChildren()[0]); + break; } case Operation::SQUARE: { if (children[0].getOperation().getId() == Operation::SQRT) // square(sqrt(x)) = x return children[0].getChildren()[0]; + break; } default: { diff --git a/lib/lepton/src/Parser.cpp b/lib/lepton/src/Parser.cpp index e7d87ba289..d094b8e4e4 100644 --- a/lib/lepton/src/Parser.cpp +++ b/lib/lepton/src/Parser.cpp @@ -178,7 +178,7 @@ ParsedExpression Parser::parse(const string& expression, const map tokens = tokenize(subexpressions[i].substr(equalsPos+1)); int pos = 0; subexpDefs[name] = parsePrecedence(tokens, pos, customFunctions, subexpDefs, 0); - if (pos != tokens.size()) + if (pos != (int)tokens.size()) throw Exception("unexpected text at end of subexpression: "+tokens[pos].getText()); } @@ -187,7 +187,7 @@ ParsedExpression Parser::parse(const string& expression, const map tokens = tokenize(primaryExpression); int pos = 0; ExpressionTreeNode result = parsePrecedence(tokens, pos, customFunctions, subexpDefs, 0); - if (pos != tokens.size()) + if (pos != (int)tokens.size()) throw Exception("unexpected text at end of expression: "+tokens[pos].getText()); return ParsedExpression(result); } @@ -198,7 +198,7 @@ ParsedExpression Parser::parse(const string& expression, const map& tokens, int& pos, const map& customFunctions, const map& subexpressionDefs, int precedence) { - if (pos == tokens.size()) + if (pos == (int)tokens.size()) throw Exception("unexpected end of expression"); // Parse the next value (number, variable, function, parenthesized expression) @@ -224,7 +224,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector& tokens, int else if (token.getType() == ParseToken::LeftParen) { pos++; result = parsePrecedence(tokens, pos, customFunctions, subexpressionDefs, 0); - if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen) + if (pos == (int)tokens.size() || tokens[pos].getType() != ParseToken::RightParen) throw Exception("unbalanced parentheses"); pos++; } @@ -238,7 +238,7 @@ ExpressionTreeNode Parser::parsePrecedence(const vector& tokens, int if (moreArgs) pos++; } while (moreArgs); - if (pos == tokens.size() || tokens[pos].getType() != ParseToken::RightParen) + if (pos == (int)tokens.size() || tokens[pos].getType() != ParseToken::RightParen) throw Exception("unbalanced parentheses"); pos++; Operation* op = getFunctionOperation(token.getText(), customFunctions); From 46f514d2cabb20150502520546843d106ce98b58 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 21 Dec 2022 22:30:05 -0500 Subject: [PATCH 13/79] add support for writing binary restart files --- src/LEPTON/pair_lepton.cpp | 100 ++++++++++++++++++++++++++++++++++++- src/LEPTON/pair_lepton.h | 4 ++ 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp index 11587d637c..35264962e2 100644 --- a/src/LEPTON/pair_lepton.cpp +++ b/src/LEPTON/pair_lepton.cpp @@ -18,6 +18,7 @@ #include "pair_lepton.h" #include "atom.h" +#include "comm.h" #include "error.h" #include "force.h" #include "memory.h" @@ -38,7 +39,7 @@ PairLepton::PairLepton(LAMMPS *lmp) : Pair(lmp), cut(nullptr), type2expression(n respa_enable = 0; single_enable = 1; writedata = 1; - restartinfo = 0; + restartinfo = 1; reinitflag = 0; cut_global = 0.0; centroidstressflag = CENTROID_SAME; @@ -258,6 +259,103 @@ double PairLepton::init_one(int i, int j) return cut[i][j]; } +/* ---------------------------------------------------------------------- + proc 0 writes to restart file +------------------------------------------------------------------------- */ + +void PairLepton::write_restart(FILE *fp) +{ + write_restart_settings(fp); + + for (int i = 1; i <= atom->ntypes; i++) + for (int j = i; j <= atom->ntypes; j++) { + fwrite(&setflag[i][j], sizeof(int), 1, fp); + if (setflag[i][j]) { + fwrite(&cut[i][j], sizeof(double), 1, fp); + fwrite(&type2expression[i][j], sizeof(int), 1, fp); + } + } + + int num = expressions.size(); + int maxlen = 0; + for (const auto &exp : expressions) maxlen = MAX(maxlen, (int) exp.size()); + ++maxlen; + + fwrite(&num, sizeof(int), 1, fp); + fwrite(&maxlen, sizeof(int), 1, fp); + for (const auto &exp : expressions) { + int n = exp.size() + 1; + fwrite(&n, sizeof(int), 1, fp); + fwrite(exp.c_str(), sizeof(char), n, fp); + } +} + +/* ---------------------------------------------------------------------- + proc 0 reads from restart file, bcasts +------------------------------------------------------------------------- */ + +void PairLepton::read_restart(FILE *fp) +{ + read_restart_settings(fp); + + allocate(); + expressions.clear(); + + const int me = comm->me; + for (int i = 1; i <= atom->ntypes; i++) + for (int j = i; j <= atom->ntypes; j++) { + if (me == 0) utils::sfread(FLERR, &setflag[i][j], sizeof(int), 1, fp, nullptr, error); + MPI_Bcast(&setflag[i][j], 1, MPI_INT, 0, world); + if (setflag[i][j]) { + if (me == 0) { + utils::sfread(FLERR, &cut[i][j], sizeof(double), 1, fp, nullptr, error); + utils::sfread(FLERR, &type2expression[i][j], sizeof(int), 1, fp, nullptr, error); + } + MPI_Bcast(&cut[i][j], 1, MPI_DOUBLE, 0, world); + MPI_Bcast(&type2expression[i][j], 1, MPI_INT, 0, world); + } + } + + int num, maxlen, len; + if (me == 0) { + utils::sfread(FLERR, &num, sizeof(int), 1, fp, nullptr, error); + utils::sfread(FLERR, &maxlen, sizeof(int), 1, fp, nullptr, error); + } + MPI_Bcast(&num, 1, MPI_INT, 0, world); + MPI_Bcast(&maxlen, 1, MPI_INT, 0, world); + char *buf = new char[maxlen]; + + for (int i = 0; i < num; ++i) { + if (me == 0) { + utils::sfread(FLERR, &len, sizeof(int), 1, fp, nullptr, error); + utils::sfread(FLERR, buf, sizeof(char), len, fp, nullptr, error); + } + MPI_Bcast(buf, maxlen, MPI_CHAR, 0, world); + expressions.push_back(buf); + } + + delete[] buf; +} + +/* ---------------------------------------------------------------------- + proc 0 writes to restart file +------------------------------------------------------------------------- */ + +void PairLepton::write_restart_settings(FILE *fp) +{ + fwrite(&cut_global, sizeof(double), 1, fp); +} + +/* ---------------------------------------------------------------------- + proc 0 reads from restart file, bcasts +------------------------------------------------------------------------- */ + +void PairLepton::read_restart_settings(FILE *fp) +{ + if (comm->me == 0) { utils::sfread(FLERR, &cut_global, sizeof(double), 1, fp, nullptr, error); } + MPI_Bcast(&cut_global, 1, MPI_DOUBLE, 0, world); +} + /* ---------------------------------------------------------------------- proc 0 writes to data file ------------------------------------------------------------------------- */ diff --git a/src/LEPTON/pair_lepton.h b/src/LEPTON/pair_lepton.h index dd1f40d1d3..c6d62c1adc 100644 --- a/src/LEPTON/pair_lepton.h +++ b/src/LEPTON/pair_lepton.h @@ -42,6 +42,10 @@ class PairLepton : public Pair { void settings(int, char **) override; void coeff(int, char **) override; double init_one(int, int) override; + void write_restart(FILE *) override; + void read_restart(FILE *) override; + void write_restart_settings(FILE *) override; + void read_restart_settings(FILE *) override; void write_data(FILE *) override; void write_data_all(FILE *) override; double single(int, int, int, int, double, double, double, double &) override; From 966211bb53ccd5abe91de6e01cba00e6414fe74e Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 02:13:28 -0500 Subject: [PATCH 14/79] avoid conflicting names --- src/LEPTON/pair_lepton.cpp | 42 ++++++++++++++++++---------------- src/LEPTON/pair_lepton.h | 6 ++--- src/OPENMP/pair_lepton_omp.cpp | 16 ++++++------- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp index 35264962e2..c2f8b35b9f 100644 --- a/src/LEPTON/pair_lepton.cpp +++ b/src/LEPTON/pair_lepton.cpp @@ -26,6 +26,7 @@ #include "update.h" #include +#include #include #include "LMP_Lepton.h" @@ -98,12 +99,12 @@ template void PairLepton::eval() const int *const *const firstneigh = list->firstneigh; double fxtmp, fytmp, fztmp; - std::vector force; - std::vector epot; + std::vector pairforce; + std::vector pairpot; for (const auto &expr : expressions) { auto parsed = LMP_Lepton::Parser::parse(expr); - force.emplace_back(parsed.differentiate("r").createCompiledExpression()); - if (EFLAG) epot.emplace_back(parsed.createCompiledExpression()); + pairforce.emplace_back(parsed.differentiate("r").createCompiledExpression()); + if (EFLAG) pairpot.emplace_back(parsed.createCompiledExpression()); } // loop over neighbors of my atoms @@ -132,9 +133,9 @@ template void PairLepton::eval() if (rsq < cutsq[itype][jtype]) { const double r = sqrt(rsq); const int idx = type2expression[itype][jtype]; - double &r_for = force[idx].getVariableReference("r"); + double &r_for = pairforce[idx].getVariableReference("r"); r_for = r; - const double fpair = -force[idx].evaluate() / r * factor_lj; + const double fpair = -pairforce[idx].evaluate() / r * factor_lj; fxtmp += delx * fpair; fytmp += dely * fpair; @@ -147,9 +148,9 @@ template void PairLepton::eval() double evdwl = 0.0; if (EFLAG) { - double &r_pot = epot[idx].getVariableReference("r"); + double &r_pot = pairpot[idx].getVariableReference("r"); r_pot = r; - evdwl = factor_lj * epot[idx].evaluate(); + evdwl = factor_lj * pairpot[idx].evaluate(); } if (EVFLAG) ev_tally(i, j, nlocal, NEWTON_PAIR, evdwl, 0.0, fpair, delx, dely, delz); @@ -214,13 +215,13 @@ void PairLepton::coeff(int narg, char **arg) try { auto parsed = LMP_Lepton::Parser::parse(exp_one); - auto epot = parsed.createCompiledExpression(); - auto force = parsed.differentiate("r").createCompiledExpression(); - double &r_pot = epot.getVariableReference("r"); - double &r_for = force.getVariableReference("r"); + auto pairpot = parsed.createCompiledExpression(); + auto pairforce = parsed.differentiate("r").createCompiledExpression(); + double &r_pot = pairpot.getVariableReference("r"); + double &r_for = pairforce.getVariableReference("r"); r_for = r_pot = 1.0; - epot.evaluate(); - force.evaluate(); + pairpot.evaluate(); + pairforce.evaluate(); } catch (std::exception &e) { error->all(FLERR, e.what()); } @@ -323,6 +324,7 @@ void PairLepton::read_restart(FILE *fp) } MPI_Bcast(&num, 1, MPI_INT, 0, world); MPI_Bcast(&maxlen, 1, MPI_INT, 0, world); + char *buf = new char[maxlen]; for (int i = 0; i < num; ++i) { @@ -383,14 +385,14 @@ double PairLepton::single(int /* i */, int /* j */, int itype, int jtype, double double /* factor_coul */, double factor_lj, double &fforce) { auto parsed = LMP_Lepton::Parser::parse(expressions[type2expression[itype][jtype]]); - auto epot = parsed.createCompiledExpression(); - auto force = parsed.differentiate("r").createCompiledExpression(); + auto pairpot = parsed.createCompiledExpression(); + auto pairforce = parsed.differentiate("r").createCompiledExpression(); double r = sqrt(rsq); - double &r_pot = epot.getVariableReference("r"); - double &r_for = force.getVariableReference("r"); + double &r_pot = pairpot.getVariableReference("r"); + double &r_for = pairforce.getVariableReference("r"); r_pot = r_for = r; - fforce = -force.evaluate() / r * factor_lj; - return epot.evaluate() * factor_lj; + fforce = -pairforce.evaluate() / r * factor_lj; + return pairpot.evaluate() * factor_lj; } diff --git a/src/LEPTON/pair_lepton.h b/src/LEPTON/pair_lepton.h index c6d62c1adc..6a55695bc3 100644 --- a/src/LEPTON/pair_lepton.h +++ b/src/LEPTON/pair_lepton.h @@ -56,13 +56,11 @@ class PairLepton : public Pair { int **type2expression; double cut_global; + virtual void allocate(); + private: template void eval(); - - virtual void allocate(); }; - } // namespace LAMMPS_NS - #endif #endif diff --git a/src/OPENMP/pair_lepton_omp.cpp b/src/OPENMP/pair_lepton_omp.cpp index 35f8e2c89d..2acb9e8b9f 100644 --- a/src/OPENMP/pair_lepton_omp.cpp +++ b/src/OPENMP/pair_lepton_omp.cpp @@ -93,12 +93,12 @@ void PairLeptonOMP::eval(int iifrom, int iito, ThrData *const thr) const int *const *const firstneigh = list->firstneigh; double fxtmp, fytmp, fztmp; - std::vector force; - std::vector epot; + std::vector pairforce; + std::vector pairpot; for (const auto &expr : expressions) { auto parsed = LMP_Lepton::Parser::parse(expr); - force.emplace_back(parsed.differentiate("r").createCompiledExpression()); - if (EFLAG) epot.emplace_back(parsed.createCompiledExpression()); + pairforce.emplace_back(parsed.differentiate("r").createCompiledExpression()); + if (EFLAG) pairpot.emplace_back(parsed.createCompiledExpression()); } // loop over neighbors of my atoms @@ -127,9 +127,9 @@ void PairLeptonOMP::eval(int iifrom, int iito, ThrData *const thr) if (rsq < cutsq[itype][jtype]) { const double r = sqrt(rsq); const int idx = type2expression[itype][jtype]; - double &r_for = force[idx].getVariableReference("r"); + double &r_for = pairforce[idx].getVariableReference("r"); r_for = r; - const double fpair = -force[idx].evaluate() / r * factor_lj; + const double fpair = -pairforce[idx].evaluate() / r * factor_lj; fxtmp += delx * fpair; fytmp += dely * fpair; @@ -142,9 +142,9 @@ void PairLeptonOMP::eval(int iifrom, int iito, ThrData *const thr) double evdwl = 0.0; if (EFLAG) { - double &r_pot = epot[idx].getVariableReference("r"); + double &r_pot = pairpot[idx].getVariableReference("r"); r_pot = r; - evdwl = factor_lj * epot[idx].evaluate(); + evdwl = factor_lj * pairpot[idx].evaluate(); } if (EVFLAG) From 5da8242690f837f8380c691b87cb05b70fe4df3e Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 02:13:51 -0500 Subject: [PATCH 15/79] add bond style lepton --- src/LEPTON/bond_lepton.cpp | 322 +++++++++++++++++++ src/LEPTON/bond_lepton.h | 52 +++ src/OPENMP/bond_lepton_omp.cpp | 148 +++++++++ src/OPENMP/bond_lepton_omp.h | 44 +++ unittest/force-styles/tests/bond-lepton.yaml | 90 ++++++ 5 files changed, 656 insertions(+) create mode 100644 src/LEPTON/bond_lepton.cpp create mode 100644 src/LEPTON/bond_lepton.h create mode 100644 src/OPENMP/bond_lepton_omp.cpp create mode 100644 src/OPENMP/bond_lepton_omp.h create mode 100644 unittest/force-styles/tests/bond-lepton.yaml diff --git a/src/LEPTON/bond_lepton.cpp b/src/LEPTON/bond_lepton.cpp new file mode 100644 index 0000000000..ad2ac0af12 --- /dev/null +++ b/src/LEPTON/bond_lepton.cpp @@ -0,0 +1,322 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing author: Axel Kohlmeyer (Temple U) +------------------------------------------------------------------------- */ + +#include "bond_lepton.h" + +#include "atom.h" +#include "comm.h" +#include "error.h" +#include "force.h" +#include "memory.h" +#include "neighbor.h" + +#include +#include +#include + +#include "LMP_Lepton.h" + +using namespace LAMMPS_NS; + +/* ---------------------------------------------------------------------- */ + +BondLepton::BondLepton(LAMMPS *_lmp) : Bond(_lmp), r0(nullptr), type2expression(nullptr) +{ + writedata = 1; + reinitflag = 0; +} + +/* ---------------------------------------------------------------------- */ + +BondLepton::~BondLepton() +{ + if (allocated) { + memory->destroy(setflag); + memory->destroy(r0); + memory->destroy(type2expression); + } +} + +/* ---------------------------------------------------------------------- */ + +void BondLepton::compute(int eflag, int vflag) +{ + ev_init(eflag, vflag); + ev_init(eflag, vflag); + if (evflag) { + if (eflag) { + if (force->newton_bond) + eval<1, 1, 1>(); + else + eval<1, 1, 0>(); + } else { + if (force->newton_bond) + eval<1, 0, 1>(); + else + eval<1, 0, 0>(); + } + } else { + if (force->newton_bond) + eval<0, 0, 1>(); + else + eval<0, 0, 0>(); + } +} + +/* ---------------------------------------------------------------------- */ +template void BondLepton::eval() +{ + std::vector bondforce; + std::vector bondpot; + for (const auto &expr : expressions) { + auto parsed = LMP_Lepton::Parser::parse(expr); + bondforce.emplace_back(parsed.differentiate("r").createCompiledExpression()); + if (EFLAG) bondpot.emplace_back(parsed.createCompiledExpression()); + } + + const double *const *const x = atom->x; + double *const *const f = atom->f; + const int *const *const bondlist = neighbor->bondlist; + const int nbondlist = neighbor->nbondlist; + const int nlocal = atom->nlocal; + + for (int n = 0; n < nbondlist; n++) { + const int i1 = bondlist[n][0]; + const int i2 = bondlist[n][1]; + const int type = bondlist[n][2]; + + const double delx = x[i1][0] - x[i2][0]; + const double dely = x[i1][1] - x[i2][1]; + const double delz = x[i1][2] - x[i2][2]; + + const double rsq = delx * delx + dely * dely + delz * delz; + const double r = sqrt(rsq); + const double dr = r - r0[type]; + const int idx = type2expression[type]; + + // force and energy + + double fbond = 0.0; + if (r > 0.0) { + double &r_for = bondforce[idx].getVariableReference("r"); + r_for = dr; + fbond = -bondforce[idx].evaluate() / r; + } + + // apply force to each of 2 atoms + + if (NEWTON_BOND || i1 < nlocal) { + f[i1][0] += delx * fbond; + f[i1][1] += dely * fbond; + f[i1][2] += delz * fbond; + } + + if (NEWTON_BOND || i2 < nlocal) { + f[i2][0] -= delx * fbond; + f[i2][1] -= dely * fbond; + f[i2][2] -= delz * fbond; + } + + double ebond = 0.0; + if (EFLAG) { + double &r_pot = bondpot[idx].getVariableReference("r"); + r_pot = dr; + ebond = bondpot[idx].evaluate(); + } + if (EVFLAG) ev_tally(i1, i2, nlocal, NEWTON_BOND, ebond, fbond, delx, dely, delz); + } +} + +/* ---------------------------------------------------------------------- */ + +void BondLepton::allocate() +{ + allocated = 1; + const int np1 = atom->nbondtypes + 1; + + memory->create(r0, np1, "bond:r0"); + memory->create(type2expression, np1, "bond:type2expression"); + memory->create(setflag, np1, "bond:setflag"); + for (int i = 1; i < np1; i++) setflag[i] = 0; +} + +/* ---------------------------------------------------------------------- + set coeffs for one type +------------------------------------------------------------------------- */ + +void BondLepton::coeff(int narg, char **arg) +{ + if (narg != 3) error->all(FLERR, "Incorrect number of args for bond coefficients"); + if (!allocated) allocate(); + + int ilo, ihi; + utils::bounds(FLERR, arg[0], 1, atom->nbondtypes, ilo, ihi, error); + + double r0_one = utils::numeric(FLERR, arg[1], false, lmp); + + // remove whitespace and quotes from expression string and then + // check if the expression can be parsed and evaluated without error + std::string exp_one; + for (const auto &c : std::string(arg[2])) + if (!isspace(c) && (c != '"') && (c != '\'')) exp_one.push_back(c); + + try { + auto parsed = LMP_Lepton::Parser::parse(exp_one); + auto bondpot = parsed.createCompiledExpression(); + auto bondforce = parsed.differentiate("r").createCompiledExpression(); + double &r_pot = bondpot.getVariableReference("r"); + double &r_for = bondforce.getVariableReference("r"); + r_for = r_pot = 1.0; + bondpot.evaluate(); + bondforce.evaluate(); + } catch (std::exception &e) { + error->all(FLERR, e.what()); + } + + std::size_t idx = 0; + for (const auto &exp : expressions) { + if (exp == exp_one) break; + ++idx; + } + + // not found, add to list + if ((expressions.size() == 0) || (idx == expressions.size())) expressions.push_back(exp_one); + + int count = 0; + for (int i = ilo; i <= ihi; i++) { + r0[i] = r0_one; + type2expression[i] = idx; + setflag[i] = 1; + count++; + } + + if (count == 0) error->all(FLERR, "Incorrect args for bond coefficients"); +} + +/* ---------------------------------------------------------------------- + return an equilbrium bond length +------------------------------------------------------------------------- */ + +double BondLepton::equilibrium_distance(int i) +{ + return r0[i]; +} + +/* ---------------------------------------------------------------------- + proc 0 writes to restart file +------------------------------------------------------------------------- */ + +void BondLepton::write_restart(FILE *fp) +{ + fwrite(&r0[1], sizeof(double), atom->nbondtypes, fp); + fwrite(&type2expression[1], sizeof(int), atom->nbondtypes, fp); + + int num = expressions.size(); + int maxlen = 0; + for (const auto &exp : expressions) maxlen = MAX(maxlen, (int) exp.size()); + ++maxlen; + + fwrite(&num, sizeof(int), 1, fp); + fwrite(&maxlen, sizeof(int), 1, fp); + for (const auto &exp : expressions) { + int n = exp.size() + 1; + fwrite(&n, sizeof(int), 1, fp); + fwrite(exp.c_str(), sizeof(char), n, fp); + } +} + +/* ---------------------------------------------------------------------- + proc 0 reads from restart file, bcasts +------------------------------------------------------------------------- */ + +void BondLepton::read_restart(FILE *fp) +{ + allocate(); + + if (comm->me == 0) { + utils::sfread(FLERR, &r0[1], sizeof(double), atom->nbondtypes, fp, nullptr, error); + utils::sfread(FLERR, &type2expression[1], sizeof(int), atom->nbondtypes, fp, nullptr, error); + } + MPI_Bcast(&r0[1], atom->nbondtypes, MPI_DOUBLE, 0, world); + MPI_Bcast(&type2expression[1], atom->nbondtypes, MPI_INT, 0, world); + for (int i = 1; i <= atom->nbondtypes; i++) setflag[i] = 1; + + int num, maxlen, len; + if (comm->me == 0) { + utils::sfread(FLERR, &num, sizeof(int), 1, fp, nullptr, error); + utils::sfread(FLERR, &maxlen, sizeof(int), 1, fp, nullptr, error); + } + MPI_Bcast(&num, 1, MPI_INT, 0, world); + MPI_Bcast(&maxlen, 1, MPI_INT, 0, world); + + char *buf = new char[maxlen]; + + for (int i = 0; i < num; ++i) { + if (comm->me == 0) { + utils::sfread(FLERR, &len, sizeof(int), 1, fp, nullptr, error); + utils::sfread(FLERR, buf, sizeof(char), len, fp, nullptr, error); + } + MPI_Bcast(buf, maxlen, MPI_CHAR, 0, world); + expressions.push_back(buf); + } + + delete[] buf; +} + +/* ---------------------------------------------------------------------- + proc 0 writes to data file +------------------------------------------------------------------------- */ + +void BondLepton::write_data(FILE *fp) +{ + for (int i = 1; i <= atom->nbondtypes; i++) + fprintf(fp, "%d %g %s\n", i, r0[i], expressions[type2expression[i]].c_str()); +} + +/* ---------------------------------------------------------------------- */ + +double BondLepton::single(int type, double rsq, int /*i*/, int /*j*/, double &fforce) +{ + const double r = sqrt(rsq); + const double dr = r - r0[type]; + + auto parsed = LMP_Lepton::Parser::parse(expressions[type2expression[type]]); + auto bondpot = parsed.createCompiledExpression(); + auto bondforce = parsed.differentiate("r").createCompiledExpression(); + double &r_for = bondforce.getVariableReference("r"); + double &r_pot = bondpot.getVariableReference("r"); + r_for = r_pot = dr; + + // force and energy + + fforce = 0.0; + double ebond = 0.0; + if (r > 0.0) { + fforce = -bondforce.evaluate() / r; + ebond = bondpot.evaluate(); + } + return ebond; +} + +/* ---------------------------------------------------------------------- */ + +void *BondLepton::extract(const char *str, int &dim) +{ + dim = 1; + if (strcmp(str, "r0") == 0) return (void *) r0; + return nullptr; +} diff --git a/src/LEPTON/bond_lepton.h b/src/LEPTON/bond_lepton.h new file mode 100644 index 0000000000..5c430b7e63 --- /dev/null +++ b/src/LEPTON/bond_lepton.h @@ -0,0 +1,52 @@ +/* -*- c++ -*- ---------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +#ifdef BOND_CLASS +// clang-format off +BondStyle(lepton,BondLepton); +// clang-format on +#else + +#ifndef LMP_BOND_LEPTON_H +#define LMP_BOND_LEPTON_H + +#include "bond.h" + +namespace LAMMPS_NS { + +class BondLepton : public Bond { + public: + BondLepton(class LAMMPS *); + ~BondLepton() override; + void compute(int, int) override; + void coeff(int, char **) override; + double equilibrium_distance(int) override; + void write_restart(FILE *) override; + void read_restart(FILE *) override; + void write_data(FILE *) override; + double single(int, double, int, int, double &) override; + void *extract(const char *, int &) override; + + protected: + std::vector expressions; + double *r0; + int *type2expression; + + virtual void allocate(); + + private: + template void eval(); +}; +} // namespace LAMMPS_NS +#endif +#endif diff --git a/src/OPENMP/bond_lepton_omp.cpp b/src/OPENMP/bond_lepton_omp.cpp new file mode 100644 index 0000000000..3171aaa51c --- /dev/null +++ b/src/OPENMP/bond_lepton_omp.cpp @@ -0,0 +1,148 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing author: Axel Kohlmeyer (Temple U) +------------------------------------------------------------------------- */ + +#include "bond_lepton_omp.h" +#include "atom.h" +#include "comm.h" +#include "force.h" +#include "neighbor.h" +#include "omp_compat.h" + +#include + +#include "LMP_Lepton.h" +#include "suffix.h" +using namespace LAMMPS_NS; + +/* ---------------------------------------------------------------------- */ + +BondLeptonOMP::BondLeptonOMP(class LAMMPS *lmp) : BondLepton(lmp), ThrOMP(lmp, THR_BOND) +{ + suffix_flag |= Suffix::OMP; +} + +/* ---------------------------------------------------------------------- */ + +void BondLeptonOMP::compute(int eflag, int vflag) +{ + ev_init(eflag, vflag); + + const int nall = atom->nlocal + atom->nghost; + const int nthreads = comm->nthreads; + const int inum = neighbor->nbondlist; + +#if defined(_OPENMP) +#pragma omp parallel LMP_DEFAULT_NONE LMP_SHARED(eflag, vflag) +#endif + { + int ifrom, ito, tid; + + loop_setup_thr(ifrom, ito, tid, inum, nthreads); + ThrData *thr = fix->get_thr(tid); + thr->timer(Timer::START); + ev_setup_thr(eflag, vflag, nall, eatom, vatom, nullptr, thr); + + if (inum > 0) { + if (evflag) { + if (eflag) { + if (force->newton_bond) + eval<1, 1, 1>(ifrom, ito, thr); + else + eval<1, 1, 0>(ifrom, ito, thr); + } else { + if (force->newton_bond) + eval<1, 0, 1>(ifrom, ito, thr); + else + eval<1, 0, 0>(ifrom, ito, thr); + } + } else { + if (force->newton_bond) + eval<0, 0, 1>(ifrom, ito, thr); + else + eval<0, 0, 0>(ifrom, ito, thr); + } + } + thr->timer(Timer::BOND); + reduce_thr(this, eflag, vflag, thr); + } // end of omp parallel region +} + +/* ---------------------------------------------------------------------- */ + +template +void BondLeptonOMP::eval(int nfrom, int nto, ThrData *const thr) +{ + std::vector bondforce; + std::vector bondpot; + for (const auto &expr : expressions) { + auto parsed = LMP_Lepton::Parser::parse(expr); + bondforce.emplace_back(parsed.differentiate("r").createCompiledExpression()); + if (EFLAG) bondpot.emplace_back(parsed.createCompiledExpression()); + } + + const auto *_noalias const x = (dbl3_t *) atom->x[0]; + auto *_noalias const f = (dbl3_t *) thr->get_f()[0]; + const int3_t *_noalias const bondlist = (int3_t *) neighbor->bondlist[0]; + const int nlocal = atom->nlocal; + + for (int n = nfrom; n < nto; n++) { + const int i1 = bondlist[n].a; + const int i2 = bondlist[n].b; + const int type = bondlist[n].t; + + const double delx = x[i1].x - x[i2].x; + const double dely = x[i1].y - x[i2].y; + const double delz = x[i1].z - x[i2].z; + + const double rsq = delx * delx + dely * dely + delz * delz; + const double r = sqrt(rsq); + const double dr = r - r0[type]; + const int idx = type2expression[type]; + + // force and energy + + double fbond = 0.0; + if (r > 0.0) { + double &r_for = bondforce[idx].getVariableReference("r"); + r_for = dr; + fbond = -bondforce[idx].evaluate() / r; + } + + // apply force to each of 2 atoms + + if (NEWTON_BOND || i1 < nlocal) { + f[i1].x += delx * fbond; + f[i1].y += dely * fbond; + f[i1].z += delz * fbond; + } + + if (NEWTON_BOND || i2 < nlocal) { + f[i2].x -= delx * fbond; + f[i2].y -= dely * fbond; + f[i2].z -= delz * fbond; + } + + double ebond = 0.0; + if (EFLAG) { + double &r_pot = bondpot[idx].getVariableReference("r"); + r_pot = dr; + ebond = bondpot[idx].evaluate(); + } + if (EVFLAG) + ev_tally_thr(this, i1, i2, nlocal, NEWTON_BOND, ebond, fbond, delx, dely, delz, thr); + } +} diff --git a/src/OPENMP/bond_lepton_omp.h b/src/OPENMP/bond_lepton_omp.h new file mode 100644 index 0000000000..bdcc36434e --- /dev/null +++ b/src/OPENMP/bond_lepton_omp.h @@ -0,0 +1,44 @@ +/* -*- c++ -*- ---------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing author: Axel Kohlmeyer (Temple U) +------------------------------------------------------------------------- */ + +#ifdef BOND_CLASS +// clang-format off +BondStyle(lepton/omp,BondLeptonOMP); +// clang-format on +#else + +#ifndef LMP_BOND_LEPTON_OMP_H +#define LMP_BOND_LEPTON_OMP_H + +#include "bond_lepton.h" +#include "thr_omp.h" + +namespace LAMMPS_NS { + +class BondLeptonOMP : public BondLepton, public ThrOMP { + + public: + BondLeptonOMP(class LAMMPS *lmp); + void compute(int, int) override; + + private: + template + void eval(int ifrom, int ito, ThrData *const thr); +}; +} // namespace LAMMPS_NS +#endif +#endif diff --git a/unittest/force-styles/tests/bond-lepton.yaml b/unittest/force-styles/tests/bond-lepton.yaml new file mode 100644 index 0000000000..a220dd8b0c --- /dev/null +++ b/unittest/force-styles/tests/bond-lepton.yaml @@ -0,0 +1,90 @@ +--- +lammps_version: 3 Nov 2022 +tags: generated +date_generated: Thu Dec 22 02:03:59 2022 +epsilon: 2.5e-13 +skip_tests: +prerequisites: ! | + atom full + bond lepton +pre_commands: ! "" +post_commands: ! "" +input_file: in.fourmol +bond_style: lepton +bond_coeff: ! | + 1 1.5 "k*r^2; k=250.0" + 2 1.1 "k2*r^2 + k3*r^3 + k4*r^4; k2=300.0; k3=-100.0; k4=50.0" + 3 1.3 "k*r^2; k=350.0" + 4 1.2 "k*(r-0.2)^2; k=500.0" + 5 1.0 "k*r^2; k=450.0" +equilibrium: 5 1.5 1.1 1.3 1.2 1 +extract: ! | + r0 1 +natoms: 29 +init_energy: 38.295825321689215 +init_stress: ! |- + -4.7778964706834920e+01 -9.3066674567350432e+01 3.4789470658440035e+02 -3.0023920169312170e+01 -8.0421418879842847e+01 5.8592449335969732e+01 +init_forces: ! |2 + 1 -5.9149914305071416e+00 -3.7728809612345245e+01 -2.7769433362963369e+01 + 2 -9.4281609567839944e+00 -7.7586487054273015e+00 1.1096676787527940e+01 + 3 3.2211742366572125e+01 2.7682361264425523e+01 -7.0109911672970497e+00 + 4 4.9260777576375503e+00 -1.3809750102765932e+00 3.4951785613141868e+00 + 5 -1.2606902198593501e+00 -1.9373397933007170e+00 6.4372463095041841e+00 + 6 -3.8858476307965482e+01 6.8567296300319640e+01 1.9889888806614337e+02 + 7 7.5297927100028144e+00 -3.8622600737556944e+01 -1.9268793182212875e+02 + 8 1.3018665172824681e+01 -1.2902789438539877e+01 3.2406676637830003e+00 + 9 7.4343536239661590e-01 8.0072549738604493e-01 3.2899591078538779e+00 + 10 6.1558871886113291e+00 -2.2419470219698296e+00 1.0080175092279852e+01 + 11 -3.7020922615305768e-01 -9.1704102274126453e-01 -1.5046795827370363e+00 + 12 5.2437190958790678e+00 3.4225915524442998e+00 -2.5523597276998897e+00 + 13 -1.1277007635800260e+01 4.4610677459696646e+00 2.1195215396108269e-01 + 14 2.9813926585641828e+00 -6.0667387499775116e-01 7.7317115100728788e+00 + 15 2.5872825164662799e-01 -9.9415365173790704e+00 -3.5428115826174169e+00 + 16 5.2775953236493464e+01 -3.1855535724919463e+01 -1.6524229620195118e+02 + 17 -5.8735858023559175e+01 4.0959855098908882e+01 1.5582804819495431e+02 + 18 -9.0963607969319646e+00 -4.3343406270234155e+00 -1.7623055551859267e+01 + 19 1.2597490501067170e+01 8.0591915019111742e+00 1.5261489294231819e+01 + 20 -3.5011297041352050e+00 -3.7248508748877587e+00 2.3615662576274494e+00 + 21 -1.5332952658285048e+00 5.9630208068632040e-01 -7.4967230017303281e+00 + 22 4.2380253233105529e+00 1.0270453290850614e+00 6.6489894421385651e+00 + 23 -2.7047300574820481e+00 -1.6233474097713818e+00 8.4773355959176278e-01 + 24 -6.6588083188726532e+00 3.5110922792825918e+00 -6.5625174267043489e+00 + 25 7.9844426562464141e+00 -1.2853795683286129e+00 6.7123710742192300e+00 + 26 -1.3256343373737607e+00 -2.2257127109539789e+00 -1.4985364751488087e-01 + 27 6.6999960289138851e+00 6.3808952243186141e+00 2.0100808779497248e+00 + 28 -8.8466157439236681e-01 3.8018717064230995e-01 -5.9857060538593476e-01 + 29 -5.8153344545215182e+00 -6.7610823949609244e+00 -1.4115102725637900e+00 +run_energy: 37.78424389351509 +run_stress: ! |- + -4.6127506998693484e+01 -9.2129732247211749e+01 3.4548310342284810e+02 -2.9841348469661163e+01 -7.8434962689387717e+01 5.9253167412123155e+01 +run_forces: ! |2 + 1 -5.8451208652159004e+00 -3.7483084455000643e+01 -2.7706576989352534e+01 + 2 -9.4646964278974774e+00 -7.8058897724822449e+00 1.1098831256058579e+01 + 3 3.1827086102630346e+01 2.7573911030624821e+01 -6.9576662575837211e+00 + 4 5.1502169659901655e+00 -1.4367546726785101e+00 3.6631301025186187e+00 + 5 -1.2208420775139264e+00 -1.8781699435112362e+00 6.2332639085051911e+00 + 6 -3.8491523409043303e+01 6.8063273218541468e+01 1.9723141045830272e+02 + 7 7.4838209349394775e+00 -3.8394258853636330e+01 -1.9092625515909930e+02 + 8 1.2676329319901857e+01 -1.2475162287097550e+01 3.3659783337736577e+00 + 9 6.8845241565874460e-01 7.3814593866184031e-01 3.0434095400342533e+00 + 10 6.2545583994797553e+00 -2.9600470917047201e+00 9.4247125735981765e+00 + 11 -1.9554747834212524e-01 -4.8434314068172696e-01 -7.9452259566032057e-01 + 12 5.2092795750960841e+00 3.1431929551776721e+00 -3.1346654851373348e+00 + 13 -1.1496483840617872e+01 4.5245217971580018e+00 2.1348220240918236e-01 + 14 3.1913399826660909e+00 -6.3760720126489068e-01 8.2740980433927742e+00 + 15 2.7338564489784484e-01 -9.7206665011069671e+00 -3.4841809697094543e+00 + 16 5.2461611410918316e+01 -3.1639255494702798e+01 -1.6483607587596811e+02 + 17 -5.8501866653548078e+01 4.0872194473703807e+01 1.5529162691391761e+02 + 18 -7.0990354207248405e+00 -2.4743922643289666e+00 -1.7824398936159682e+01 + 19 1.2019842510974870e+01 7.7105128268768715e+00 1.4523712108141252e+01 + 20 -4.9208070902500296e+00 -5.2361205625479048e+00 3.3006868280184283e+00 + 21 -1.8548628650934149e+00 2.7467524264262122e-01 -6.7601469408617412e+00 + 22 3.9136757840663186e+00 9.5561415744904055e-01 6.1181929861632272e+00 + 23 -2.0588129189729036e+00 -1.2302894000916618e+00 6.4195395469851357e-01 + 24 -5.7681973234153086e+00 2.0209144998436366e+00 -5.2864044021513967e+00 + 25 6.3696975292216704e+00 -1.0109756418053095e+00 5.3564043759405795e+00 + 26 -6.0150020580636188e-01 -1.0099388580383271e+00 -6.9999973789182365e-02 + 27 6.8467535469188450e+00 5.7500299184200578e+00 2.2775780974490298e+00 + 28 -1.3929430925479587e+00 5.9772788540443345e-01 -9.4056106886485980e-01 + 29 -5.4538104543708865e+00 -6.3477578038244911e+00 -1.3370170285841700e+00 +... From 4cbe8b353be849f1b62302ac6698f81cdce8a406 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 05:37:59 -0500 Subject: [PATCH 16/79] move shared functionality to utility function added to Lepton library --- lib/lepton/include/LMP_Lepton.h | 8 ++++++++ lib/lepton/src/Utils.cpp | 31 +++++++++++++++++++++++++++++++ src/LEPTON/bond_lepton.cpp | 12 +++++------- src/LEPTON/pair_lepton.cpp | 7 +------ 4 files changed, 45 insertions(+), 13 deletions(-) create mode 100644 lib/lepton/src/Utils.cpp diff --git a/lib/lepton/include/LMP_Lepton.h b/lib/lepton/include/LMP_Lepton.h index 73b6b6fa38..d277bd2761 100644 --- a/lib/lepton/include/LMP_Lepton.h +++ b/lib/lepton/include/LMP_Lepton.h @@ -32,6 +32,8 @@ * USE OR OTHER DEALINGS IN THE SOFTWARE. * * -------------------------------------------------------------------------- */ +#include + #include "lepton/CompiledExpression.h" #include "lepton/CustomFunction.h" #include "lepton/ExpressionProgram.h" @@ -40,4 +42,10 @@ #include "lepton/ParsedExpression.h" #include "lepton/Parser.h" +// utility functions +namespace LMP_Lepton +{ + /// remove whitespace and quotes from expression string + std::string condense(const std::string &); +} #endif /*LMP_LEPTON_H_*/ diff --git a/lib/lepton/src/Utils.cpp b/lib/lepton/src/Utils.cpp new file mode 100644 index 0000000000..839da6cda2 --- /dev/null +++ b/lib/lepton/src/Utils.cpp @@ -0,0 +1,31 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing author: Axel Kohlmeyer (Temple U) +------------------------------------------------------------------------- */ + +#include "LMP_Lepton.h" + +#include + +/// remove whitespace and quotes from expression string +std::string LMP_Lepton::condense(const std::string & in) +{ + std::string out; + for (const auto &c : in) + if (!isspace(c) && (c != '"') && (c != '\'')) out.push_back(c); + return out; +} + + diff --git a/src/LEPTON/bond_lepton.cpp b/src/LEPTON/bond_lepton.cpp index ad2ac0af12..b19a7ae50c 100644 --- a/src/LEPTON/bond_lepton.cpp +++ b/src/LEPTON/bond_lepton.cpp @@ -24,9 +24,7 @@ #include "memory.h" #include "neighbor.h" -#include #include -#include #include "LMP_Lepton.h" @@ -170,10 +168,7 @@ void BondLepton::coeff(int narg, char **arg) // remove whitespace and quotes from expression string and then // check if the expression can be parsed and evaluated without error - std::string exp_one; - for (const auto &c : std::string(arg[2])) - if (!isspace(c) && (c != '"') && (c != '\'')) exp_one.push_back(c); - + std::string exp_one = LMP_Lepton::condense(arg[2]); try { auto parsed = LMP_Lepton::Parser::parse(exp_one); auto bondpot = parsed.createCompiledExpression(); @@ -317,6 +312,9 @@ double BondLepton::single(int type, double rsq, int /*i*/, int /*j*/, double &ff void *BondLepton::extract(const char *str, int &dim) { dim = 1; - if (strcmp(str, "r0") == 0) return (void *) r0; + if (str) { + std::string keyword(str); + if (keyword == "r0") return (void *) r0; + } return nullptr; } diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp index c2f8b35b9f..4a53e6985c 100644 --- a/src/LEPTON/pair_lepton.cpp +++ b/src/LEPTON/pair_lepton.cpp @@ -25,9 +25,7 @@ #include "neigh_list.h" #include "update.h" -#include #include -#include #include "LMP_Lepton.h" @@ -209,10 +207,7 @@ void PairLepton::coeff(int narg, char **arg) // remove whitespace and quotes from expression string and then // check if the expression can be parsed and evaluated without error - std::string exp_one; - for (const auto &c : std::string(arg[2])) - if (!isspace(c) && (c != '"') && (c != '\'')) exp_one.push_back(c); - + std::string exp_one = LMP_Lepton::condense(arg[2]); try { auto parsed = LMP_Lepton::Parser::parse(exp_one); auto pairpot = parsed.createCompiledExpression(); From 48c23788f2d0ae8b697c7aa6e1f25720ce14e583 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 07:10:48 -0500 Subject: [PATCH 17/79] handle pair_modify shift and enforce the bond lepton has zero energy at r0 --- src/LEPTON/bond_lepton.cpp | 15 ++++++++++----- src/LEPTON/bond_lepton.h | 1 + src/LEPTON/pair_lepton.cpp | 27 +++++++++++++++++++++++---- src/LEPTON/pair_lepton.h | 1 + src/OPENMP/bond_lepton_omp.cpp | 2 +- src/OPENMP/pair_lepton_omp.cpp | 3 ++- 6 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/LEPTON/bond_lepton.cpp b/src/LEPTON/bond_lepton.cpp index b19a7ae50c..7ebca00012 100644 --- a/src/LEPTON/bond_lepton.cpp +++ b/src/LEPTON/bond_lepton.cpp @@ -32,7 +32,8 @@ using namespace LAMMPS_NS; /* ---------------------------------------------------------------------- */ -BondLepton::BondLepton(LAMMPS *_lmp) : Bond(_lmp), r0(nullptr), type2expression(nullptr) +BondLepton::BondLepton(LAMMPS *_lmp) : + Bond(_lmp), r0(nullptr), type2expression(nullptr), offset(nullptr) { writedata = 1; reinitflag = 0; @@ -46,6 +47,7 @@ BondLepton::~BondLepton() memory->destroy(setflag); memory->destroy(r0); memory->destroy(type2expression); + memory->destroy(offset); } } @@ -133,7 +135,7 @@ template void BondLepton::eval() if (EFLAG) { double &r_pot = bondpot[idx].getVariableReference("r"); r_pot = dr; - ebond = bondpot[idx].evaluate(); + ebond = bondpot[idx].evaluate() - offset[type]; } if (EVFLAG) ev_tally(i1, i2, nlocal, NEWTON_BOND, ebond, fbond, delx, dely, delz); } @@ -148,6 +150,7 @@ void BondLepton::allocate() memory->create(r0, np1, "bond:r0"); memory->create(type2expression, np1, "bond:type2expression"); + memory->create(offset, np1, "bond:offset"); memory->create(setflag, np1, "bond:setflag"); for (int i = 1; i < np1; i++) setflag[i] = 0; } @@ -169,14 +172,15 @@ void BondLepton::coeff(int narg, char **arg) // remove whitespace and quotes from expression string and then // check if the expression can be parsed and evaluated without error std::string exp_one = LMP_Lepton::condense(arg[2]); + double offset_one = 0.0; try { auto parsed = LMP_Lepton::Parser::parse(exp_one); auto bondpot = parsed.createCompiledExpression(); auto bondforce = parsed.differentiate("r").createCompiledExpression(); double &r_pot = bondpot.getVariableReference("r"); double &r_for = bondforce.getVariableReference("r"); - r_for = r_pot = 1.0; - bondpot.evaluate(); + r_for = r_pot = r0_one; + offset_one = bondpot.evaluate(); bondforce.evaluate(); } catch (std::exception &e) { error->all(FLERR, e.what()); @@ -195,6 +199,7 @@ void BondLepton::coeff(int narg, char **arg) for (int i = ilo; i <= ihi; i++) { r0[i] = r0_one; type2expression[i] = idx; + offset[i] = offset_one; setflag[i] = 1; count++; } @@ -302,7 +307,7 @@ double BondLepton::single(int type, double rsq, int /*i*/, int /*j*/, double &ff double ebond = 0.0; if (r > 0.0) { fforce = -bondforce.evaluate() / r; - ebond = bondpot.evaluate(); + ebond = bondpot.evaluate() - offset[type]; } return ebond; } diff --git a/src/LEPTON/bond_lepton.h b/src/LEPTON/bond_lepton.h index 5c430b7e63..e91dda3187 100644 --- a/src/LEPTON/bond_lepton.h +++ b/src/LEPTON/bond_lepton.h @@ -41,6 +41,7 @@ class BondLepton : public Bond { std::vector expressions; double *r0; int *type2expression; + double *offset; virtual void allocate(); diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp index 4a53e6985c..c159db6388 100644 --- a/src/LEPTON/pair_lepton.cpp +++ b/src/LEPTON/pair_lepton.cpp @@ -33,7 +33,8 @@ using namespace LAMMPS_NS; /* ---------------------------------------------------------------------- */ -PairLepton::PairLepton(LAMMPS *lmp) : Pair(lmp), cut(nullptr), type2expression(nullptr) +PairLepton::PairLepton(LAMMPS *lmp) : + Pair(lmp), cut(nullptr), type2expression(nullptr), offset(nullptr) { respa_enable = 0; single_enable = 1; @@ -53,6 +54,7 @@ PairLepton::~PairLepton() memory->destroy(cutsq); memory->destroy(setflag); memory->destroy(type2expression); + memory->destroy(offset); } } @@ -148,7 +150,8 @@ template void PairLepton::eval() if (EFLAG) { double &r_pot = pairpot[idx].getVariableReference("r"); r_pot = r; - evdwl = factor_lj * pairpot[idx].evaluate(); + evdwl = pairpot[idx].evaluate() - offset[itype][jtype]; + evdwl *= factor_lj; } if (EVFLAG) ev_tally(i, j, nlocal, NEWTON_PAIR, evdwl, 0.0, fpair, delx, dely, delz); @@ -176,6 +179,7 @@ void PairLepton::allocate() memory->create(cut, np1, np1, "pair:cut"); memory->create(cutsq, np1, np1, "pair:cutsq"); memory->create(type2expression, np1, np1, "pair:type2expression"); + memory->create(offset, np1, np1, "pair:offset"); } /* ---------------------------------------------------------------------- @@ -249,8 +253,18 @@ double PairLepton::init_one(int i, int j) { if (setflag[i][j] == 0) error->all(FLERR, "All pair coeffs are not set"); + if (offset_flag) { + auto parsed = LMP_Lepton::Parser::parse(expressions[type2expression[i][j]]); + auto pairpot = parsed.createCompiledExpression(); + double &r_pot = pairpot.getVariableReference("r"); + r_pot = 1.0; + offset[i][j] = pairpot.evaluate(); + } else + offset[i][j] = 0.0; + cut[j][i] = cut[i][j]; type2expression[j][i] = type2expression[i][j]; + offset[j][i] = offset[i][j]; return cut[i][j]; } @@ -341,6 +355,7 @@ void PairLepton::read_restart(FILE *fp) void PairLepton::write_restart_settings(FILE *fp) { fwrite(&cut_global, sizeof(double), 1, fp); + fwrite(&offset_flag, sizeof(int), 1, fp); } /* ---------------------------------------------------------------------- @@ -349,8 +364,12 @@ void PairLepton::write_restart_settings(FILE *fp) void PairLepton::read_restart_settings(FILE *fp) { - if (comm->me == 0) { utils::sfread(FLERR, &cut_global, sizeof(double), 1, fp, nullptr, error); } + if (comm->me == 0) { + utils::sfread(FLERR, &cut_global, sizeof(double), 1, fp, nullptr, error); + utils::sfread(FLERR, &offset_flag, sizeof(int), 1, fp, nullptr, error); + } MPI_Bcast(&cut_global, 1, MPI_DOUBLE, 0, world); + MPI_Bcast(&offset_flag, 1, MPI_INT, 0, world); } /* ---------------------------------------------------------------------- @@ -389,5 +408,5 @@ double PairLepton::single(int /* i */, int /* j */, int itype, int jtype, double r_pot = r_for = r; fforce = -pairforce.evaluate() / r * factor_lj; - return pairpot.evaluate() * factor_lj; + return (pairpot.evaluate() - offset[itype][jtype]) * factor_lj; } diff --git a/src/LEPTON/pair_lepton.h b/src/LEPTON/pair_lepton.h index 6a55695bc3..02105d1f27 100644 --- a/src/LEPTON/pair_lepton.h +++ b/src/LEPTON/pair_lepton.h @@ -54,6 +54,7 @@ class PairLepton : public Pair { std::vector expressions; double **cut; int **type2expression; + double **offset; double cut_global; virtual void allocate(); diff --git a/src/OPENMP/bond_lepton_omp.cpp b/src/OPENMP/bond_lepton_omp.cpp index 3171aaa51c..560256076f 100644 --- a/src/OPENMP/bond_lepton_omp.cpp +++ b/src/OPENMP/bond_lepton_omp.cpp @@ -140,7 +140,7 @@ void BondLeptonOMP::eval(int nfrom, int nto, ThrData *const thr) if (EFLAG) { double &r_pot = bondpot[idx].getVariableReference("r"); r_pot = dr; - ebond = bondpot[idx].evaluate(); + ebond = bondpot[idx].evaluate() - offset[type]; } if (EVFLAG) ev_tally_thr(this, i1, i2, nlocal, NEWTON_BOND, ebond, fbond, delx, dely, delz, thr); diff --git a/src/OPENMP/pair_lepton_omp.cpp b/src/OPENMP/pair_lepton_omp.cpp index 2acb9e8b9f..b3abfdd34b 100644 --- a/src/OPENMP/pair_lepton_omp.cpp +++ b/src/OPENMP/pair_lepton_omp.cpp @@ -144,7 +144,8 @@ void PairLeptonOMP::eval(int iifrom, int iito, ThrData *const thr) if (EFLAG) { double &r_pot = pairpot[idx].getVariableReference("r"); r_pot = r; - evdwl = factor_lj * pairpot[idx].evaluate(); + evdwl = pairpot[idx].evaluate() - offset[itype][jtype]; + evdwl *= factor_lj; } if (EVFLAG) From 28659295580f4ddc6dd980c62e8e99e451562412 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 07:11:04 -0500 Subject: [PATCH 18/79] update for added source --- lib/lepton/Makefile.mpi | 9 ++++++++- lib/lepton/Makefile.serial | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/lib/lepton/Makefile.mpi b/lib/lepton/Makefile.mpi index 934188ba50..3d9cc49310 100644 --- a/lib/lepton/Makefile.mpi +++ b/lib/lepton/Makefile.mpi @@ -6,7 +6,14 @@ INC=-I include AR=ar ARFLAGS=rc LIB=liblmplepton.a -SRC=src/CompiledExpression.cpp src/ExpressionProgram.cpp src/ExpressionTreeNode.cpp src/Operation.cpp src/ParsedExpression.cpp src/Parser.cpp +SRC=src/CompiledExpression.cpp \ + src/ExpressionProgram.cpp \ + src/ExpressionTreeNode.cpp \ + src/Operation.cpp \ + src/ParsedExpression.cpp \ + src/Parser.cpp \ + src/Utils.cpp + OBJ=$(SRC:src/%.cpp=build/%.o) all: $(LIB) Makefile.lammps diff --git a/lib/lepton/Makefile.serial b/lib/lepton/Makefile.serial index 23d9b3dd57..c83951774e 100644 --- a/lib/lepton/Makefile.serial +++ b/lib/lepton/Makefile.serial @@ -6,7 +6,14 @@ INC=-I include AR=ar ARFLAGS=rc LIB=liblmplepton.a -SRC=src/CompiledExpression.cpp src/ExpressionProgram.cpp src/ExpressionTreeNode.cpp src/Operation.cpp src/ParsedExpression.cpp src/Parser.cpp +SRC=src/CompiledExpression.cpp \ + src/ExpressionProgram.cpp \ + src/ExpressionTreeNode.cpp \ + src/Operation.cpp \ + src/ParsedExpression.cpp \ + src/Parser.cpp \ + src/Utils.cpp + OBJ=$(SRC:src/%.cpp=build/%.o) all: $(LIB) Makefile.lammps From ab72e95d0a5ce7bfa8057d8b6cc0b7c426ae728b Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 07:26:00 -0500 Subject: [PATCH 19/79] restart offset for bond style lepton --- src/LEPTON/bond_lepton.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/LEPTON/bond_lepton.cpp b/src/LEPTON/bond_lepton.cpp index 7ebca00012..3f764e4456 100644 --- a/src/LEPTON/bond_lepton.cpp +++ b/src/LEPTON/bond_lepton.cpp @@ -224,6 +224,7 @@ void BondLepton::write_restart(FILE *fp) { fwrite(&r0[1], sizeof(double), atom->nbondtypes, fp); fwrite(&type2expression[1], sizeof(int), atom->nbondtypes, fp); + fwrite(&offset[1], sizeof(double), atom->nbondtypes, fp); int num = expressions.size(); int maxlen = 0; @@ -250,9 +251,11 @@ void BondLepton::read_restart(FILE *fp) if (comm->me == 0) { utils::sfread(FLERR, &r0[1], sizeof(double), atom->nbondtypes, fp, nullptr, error); utils::sfread(FLERR, &type2expression[1], sizeof(int), atom->nbondtypes, fp, nullptr, error); + utils::sfread(FLERR, &offset[1], sizeof(double), atom->nbondtypes, fp, nullptr, error); } MPI_Bcast(&r0[1], atom->nbondtypes, MPI_DOUBLE, 0, world); MPI_Bcast(&type2expression[1], atom->nbondtypes, MPI_INT, 0, world); + MPI_Bcast(&offset[1], atom->nbondtypes, MPI_DOUBLE, 0, world); for (int i = 1; i <= atom->nbondtypes; i++) setflag[i] = 1; int num, maxlen, len; From a5ecef708f6cf85c4ac8cf6c27dcc7502334b49a Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 09:59:55 -0500 Subject: [PATCH 20/79] correctly compute offsets. update unit test files. --- src/LEPTON/bond_lepton.cpp | 2 +- src/LEPTON/pair_lepton.cpp | 2 +- unittest/force-styles/tests/bond-lepton.yaml | 9 ++++----- unittest/force-styles/tests/mol-pair-lepton.yaml | 12 ++++++------ unittest/force-styles/tests/mol-pair-lj_cut.yaml | 9 +++++---- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/LEPTON/bond_lepton.cpp b/src/LEPTON/bond_lepton.cpp index 3f764e4456..fb55571ca1 100644 --- a/src/LEPTON/bond_lepton.cpp +++ b/src/LEPTON/bond_lepton.cpp @@ -179,7 +179,7 @@ void BondLepton::coeff(int narg, char **arg) auto bondforce = parsed.differentiate("r").createCompiledExpression(); double &r_pot = bondpot.getVariableReference("r"); double &r_for = bondforce.getVariableReference("r"); - r_for = r_pot = r0_one; + r_for = r_pot = 0.0; offset_one = bondpot.evaluate(); bondforce.evaluate(); } catch (std::exception &e) { diff --git a/src/LEPTON/pair_lepton.cpp b/src/LEPTON/pair_lepton.cpp index c159db6388..39b5ade806 100644 --- a/src/LEPTON/pair_lepton.cpp +++ b/src/LEPTON/pair_lepton.cpp @@ -257,7 +257,7 @@ double PairLepton::init_one(int i, int j) auto parsed = LMP_Lepton::Parser::parse(expressions[type2expression[i][j]]); auto pairpot = parsed.createCompiledExpression(); double &r_pot = pairpot.getVariableReference("r"); - r_pot = 1.0; + r_pot = cut[i][j]; offset[i][j] = pairpot.evaluate(); } else offset[i][j] = 0.0; diff --git a/unittest/force-styles/tests/bond-lepton.yaml b/unittest/force-styles/tests/bond-lepton.yaml index a220dd8b0c..32d0d1453c 100644 --- a/unittest/force-styles/tests/bond-lepton.yaml +++ b/unittest/force-styles/tests/bond-lepton.yaml @@ -1,7 +1,6 @@ --- -lammps_version: 3 Nov 2022 -tags: generated -date_generated: Thu Dec 22 02:03:59 2022 +lammps_version: 22 Dec 2022 +date_generated: Thu Dec 22 09:47:41 2022 epsilon: 2.5e-13 skip_tests: prerequisites: ! | @@ -21,7 +20,7 @@ equilibrium: 5 1.5 1.1 1.3 1.2 1 extract: ! | r0 1 natoms: 29 -init_energy: 38.295825321689215 +init_energy: -1.7041746783107878 init_stress: ! |- -4.7778964706834920e+01 -9.3066674567350432e+01 3.4789470658440035e+02 -3.0023920169312170e+01 -8.0421418879842847e+01 5.8592449335969732e+01 init_forces: ! |2 @@ -54,7 +53,7 @@ init_forces: ! |2 27 6.6999960289138851e+00 6.3808952243186141e+00 2.0100808779497248e+00 28 -8.8466157439236681e-01 3.8018717064230995e-01 -5.9857060538593476e-01 29 -5.8153344545215182e+00 -6.7610823949609244e+00 -1.4115102725637900e+00 -run_energy: 37.78424389351509 +run_energy: -2.215756106484914 run_stress: ! |- -4.6127506998693484e+01 -9.2129732247211749e+01 3.4548310342284810e+02 -2.9841348469661163e+01 -7.8434962689387717e+01 5.9253167412123155e+01 run_forces: ! |2 diff --git a/unittest/force-styles/tests/mol-pair-lepton.yaml b/unittest/force-styles/tests/mol-pair-lepton.yaml index d58780b1fe..81d8286588 100644 --- a/unittest/force-styles/tests/mol-pair-lepton.yaml +++ b/unittest/force-styles/tests/mol-pair-lepton.yaml @@ -1,7 +1,6 @@ --- -lammps_version: 3 Nov 2022 -tags: generated -date_generated: Wed Dec 21 20:29:34 2022 +lammps_version: 22 Dec 2022 +date_generated: Thu Dec 22 09:57:30 2022 epsilon: 5e-14 skip_tests: prerequisites: ! | @@ -9,7 +8,8 @@ prerequisites: ! | pair lepton pre_commands: ! | variable write_data_pair index ij -post_commands: ! "" +post_commands: ! | + pair_modify shift yes input_file: in.fourmol pair_style: lepton 8.0 pair_coeff: ! | @@ -27,7 +27,7 @@ pair_coeff: ! | 3 5 "4.0*eps*((sig/r)^12-(sig/r)^6);eps=0.0173205;sig=3.15" extract: ! "" natoms: 29 -init_vdwl: 749.2370315373564 +init_vdwl: 749.2468149791969 init_coul: 0 init_stress: ! |2- 2.1793853434038242e+03 2.1988955172192768e+03 4.6653977523326257e+03 -7.5956547636050584e+02 2.4751536734032861e+01 6.6652028436400667e+02 @@ -61,7 +61,7 @@ init_forces: ! |2 27 5.1810388677546001e+01 -2.2705458321213797e+02 9.0849111082069669e+01 28 -1.8041307121444069e+02 7.7534042932772905e+01 -1.2206956760706598e+02 29 1.2861057254925012e+02 1.4952711274394568e+02 3.1216025556267880e+01 -run_vdwl: 719.4432816774653 +run_vdwl: 719.4530651193046 run_coul: 0 run_stress: ! |2- 2.1330153957371017e+03 2.1547728168285516e+03 4.3976497417710125e+03 -7.3873328448298525e+02 4.1743821105370067e+01 6.2788012209191027e+02 diff --git a/unittest/force-styles/tests/mol-pair-lj_cut.yaml b/unittest/force-styles/tests/mol-pair-lj_cut.yaml index 58bb0abf08..68bba170fe 100644 --- a/unittest/force-styles/tests/mol-pair-lj_cut.yaml +++ b/unittest/force-styles/tests/mol-pair-lj_cut.yaml @@ -1,6 +1,6 @@ --- -lammps_version: 17 Feb 2022 -date_generated: Fri Mar 18 22:17:31 2022 +lammps_version: 22 Dec 2022 +date_generated: Thu Dec 22 09:53:54 2022 epsilon: 5e-14 skip_tests: prerequisites: ! | @@ -9,6 +9,7 @@ prerequisites: ! | pre_commands: ! "" post_commands: ! | pair_modify mix arithmetic + pair_modify shift yes input_file: in.fourmol pair_style: lj/cut 8.0 pair_coeff: ! | @@ -22,7 +23,7 @@ extract: ! | epsilon 2 sigma 2 natoms: 29 -init_vdwl: 749.2372261744105 +init_vdwl: 749.2470096189502 init_coul: 0 init_stress: ! |2- 2.1793857186503233e+03 2.1988957679770601e+03 4.6653994738862330e+03 -7.5956544622684294e+02 2.4751393539192360e+01 6.6652061873806701e+02 @@ -56,7 +57,7 @@ init_forces: ! |2 27 5.1810412832327984e+01 -2.2705468907750401e+02 9.0849153441059272e+01 28 -1.8041315533250560e+02 7.7534079082878250e+01 -1.2206962452216491e+02 29 1.2861063251415729e+02 1.4952718246094855e+02 3.1216040111076961e+01 -run_vdwl: 719.4434555542921 +run_vdwl: 719.4532389988314 run_coul: 0 run_stress: ! |2- 2.1330157554553721e+03 2.1547730555430498e+03 4.3976512412988704e+03 -7.3873325485023690e+02 4.1743707190786367e+01 6.2788040986774604e+02 From d9b1e318e88e05d801dbfdf6e0f0ccbb4b411e5d Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 15:48:24 -0500 Subject: [PATCH 21/79] add documentation for LEPTON package and lepton pair and bond style --- doc/src/Build_extras.rst | 45 +++++++++++++++++++++ doc/src/Commands_bond.rst | 1 + doc/src/Commands_pair.rst | 1 + doc/src/Packages_details.rst | 38 +++++++++++++++++ doc/src/Packages_list.rst | 5 +++ doc/src/bond_style.rst | 3 +- doc/src/pair_style.rst | 1 + doc/utils/sphinx-config/false_positives.txt | 3 ++ 8 files changed, 96 insertions(+), 1 deletion(-) diff --git a/doc/src/Build_extras.rst b/doc/src/Build_extras.rst index 910dccbb8a..976e6e723d 100644 --- a/doc/src/Build_extras.rst +++ b/doc/src/Build_extras.rst @@ -43,6 +43,7 @@ This is the list of packages that may require additional steps. * :ref:`KIM ` * :ref:`KOKKOS ` * :ref:`LATTE ` + * :ref:`LEPTON ` * :ref:`MACHDYN ` * :ref:`MDI ` * :ref:`MESONT ` @@ -873,6 +874,50 @@ library. ---------- +.. _lepton: + +LEPTON package +-------------- + +To build with this package, you must build the Lepton library which is +included in the LAMMPS source distribution in the ``lib/lepton`` folder. + +.. tabs:: + + .. tab:: CMake build + + This is the recommended build procedure for using Lepton in + LAMMPS. No additional settings are normally needed besides + ``-D PKG_LEPTON=yes``. + + .. tab:: Traditional make + + Before building LAMMPS, one must build the Lepton library in lib/lepton. + + This can be done manually in the same folder by using or adapting + one of the provided Makefiles: for example, ``Makefile.serial`` for + the GNU C++ compiler, or ``Makefile.mpi`` for the MPI compiler wrapper. + The Lepton library is written in C++-11 and thus the C++ compiler + may need to be instructed to enable support for that. + + In general, it is safer to use build setting consistent with the + rest of LAMMPS. This is best carried out from the LAMMPS src + directory using a command like these, which simply invokes the + ``lib/lepton/Install.py`` script with the specified args: + + .. code-block:: bash + + $ make lib-lepton # print help message + $ make lib-lepton args="-m serial" # build with GNU g++ compiler (settings as with "make serial") + $ make lib-lepton args="-m mpi" # build with default MPI compiler (settings as with "make mpi") + + The "machine" argument of the "-m" flag is used to find a + Makefile.machine to use as build recipe. + + The build should produce a ``build`` folder and the library ``lib/lepton/liblmplepton.a`` + +---------- + .. _mliap: ML-IAP package diff --git a/doc/src/Commands_bond.rst b/doc/src/Commands_bond.rst index ac2d5882fb..f5e5edcc5a 100644 --- a/doc/src/Commands_bond.rst +++ b/doc/src/Commands_bond.rst @@ -44,6 +44,7 @@ OPT. * :doc:`harmonic (iko) ` * :doc:`harmonic/shift (o) ` * :doc:`harmonic/shift/cut (o) ` + * :doc:`lepton (o) ` * :doc:`mesocnt ` * :doc:`mm3 ` * :doc:`morse (o) ` diff --git a/doc/src/Commands_pair.rst b/doc/src/Commands_pair.rst index c6d54f0683..8a5f05d095 100644 --- a/doc/src/Commands_pair.rst +++ b/doc/src/Commands_pair.rst @@ -134,6 +134,7 @@ OPT. * :doc:`lcbop ` * :doc:`lebedeva/z ` * :doc:`lennard/mdf ` + * :doc:`lepton (o) ` * :doc:`line/lj ` * :doc:`lj/charmm/coul/charmm (giko) ` * :doc:`lj/charmm/coul/charmm/implicit (ko) ` diff --git a/doc/src/Packages_details.rst b/doc/src/Packages_details.rst index 74ddb066c6..96ab174a10 100644 --- a/doc/src/Packages_details.rst +++ b/doc/src/Packages_details.rst @@ -68,6 +68,7 @@ page gives those details. * :ref:`KSPACE ` * :ref:`LATBOLTZ ` * :ref:`LATTE ` + * :ref:`LEPTON ` * :ref:`MACHDYN ` * :ref:`MANIFOLD ` * :ref:`MANYBODY ` @@ -1384,6 +1385,43 @@ This package has :ref:`specific installation instructions ` on the :doc:` ---------- +.. _PKG-LEPTON: + +LEPTON package +-------------- + +**Contents:** + +Styles for pair, bond forces that evaluate the potential function from a +string using the `Lepton mathematical expression parser +`_. Lepton is a C++ library that is +bundled with `OpenMM `_ and can be used for +parsing, evaluating, differentiating, and analyzing mathematical +expressions. This is a more lightweight and efficient alternative +for evaluating custom potential function to an embedded Python +interpreter as used in the :ref:`PYTHON package `. +On the other hand, since the potentials are evaluated form analytical +expressions, they are more accurate than what can be done with +:ref:`tabulated potentials `. Using the runtime evaluation +comes with a significant increase in runtime. + +**Authors:** Axel Kohlmeyer (Temple U). Lepton itself is developed +by Peter Eastman at Stanford University. + +**Install:** + +This package has :ref:`specific installation instructions ` on +the :doc:`Build extras ` page. + +**Supporting info:** + +* src/LEPTON: filenames -> commands +* lib/lepton/README.md +* :doc:`pair_style lepton ` +* :doc:`bond_style lepton ` + +---------- + .. _PKG-MACHDYN: MACHDYN package diff --git a/doc/src/Packages_list.rst b/doc/src/Packages_list.rst index ac0ba7728e..b1483bd954 100644 --- a/doc/src/Packages_list.rst +++ b/doc/src/Packages_list.rst @@ -238,6 +238,11 @@ whether an extra library is needed to build and use the package: - :doc:`fix latte ` - latte - ext + * - :ref:`LEPTON ` + - evaluate strings as potential function + - :doc:`pair_style lepton ` + - PACKAGES/lepton + - int * - :ref:`MACHDYN ` - smoothed Mach dynamics - `SMD User Guide `_ diff --git a/doc/src/bond_style.rst b/doc/src/bond_style.rst index 9197e6c4eb..23b89d00a2 100644 --- a/doc/src/bond_style.rst +++ b/doc/src/bond_style.rst @@ -10,7 +10,7 @@ Syntax bond_style style args -* style = *none* or *zero* or *hybrid* or *bpm/rotational* or *bpm/spring* or *class2* or *fene* or *fene/expand* or *fene/nm* or *gaussian* or *gromos* or *harmonic* or *harmonic/shift* or *harmonic/shift/cut* or *morse* or *nonlinear* or *oxdna/fene* or *oxdena2/fene* or *oxrna2/fene* or *quartic* or *special* or *table* +* style = *none* or *zero* or *hybrid* or *bpm/rotational* or *bpm/spring* or *class2* or *fene* or *fene/expand* or *fene/nm* or *gaussian* or *gromos* or *harmonic* or *harmonic/shift* or *harmonic/shift/cut* or *lepton* or *morse* or *nonlinear* or *oxdna/fene* or *oxdena2/fene* or *oxrna2/fene* or *quartic* or *special* or *table* * args = none for any style except *hybrid* @@ -95,6 +95,7 @@ accelerated styles exist. * :doc:`harmonic ` - harmonic bond * :doc:`harmonic/shift ` - shifted harmonic bond * :doc:`harmonic/shift/cut ` - shifted harmonic bond with a cutoff +* :doc:`lepton ` - bond potential from evaluating a string * :doc:`mesocnt ` - Harmonic bond wrapper with parameterization presets for nanotubes * :doc:`mm3 ` - MM3 anharmonic bond * :doc:`morse ` - Morse bond diff --git a/doc/src/pair_style.rst b/doc/src/pair_style.rst index 48daf34f17..ac8888f8ad 100644 --- a/doc/src/pair_style.rst +++ b/doc/src/pair_style.rst @@ -212,6 +212,7 @@ accelerated styles exist. * :doc:`lcbop ` - long-range bond-order potential (LCBOP) * :doc:`lebedeva/z ` - Lebedeva interlayer potential for graphene with normals along z-axis * :doc:`lennard/mdf ` - LJ potential in A/B form with a taper function +* :doc:`lepton ` - pair potential from evaluating a string * :doc:`line/lj ` - LJ potential between line segments * :doc:`list ` - potential between pairs of atoms explicitly listed in an input file * :doc:`lj/charmm/coul/charmm ` - CHARMM potential with cutoff Coulomb diff --git a/doc/utils/sphinx-config/false_positives.txt b/doc/utils/sphinx-config/false_positives.txt index 7cb409d040..c879bdb244 100644 --- a/doc/utils/sphinx-config/false_positives.txt +++ b/doc/utils/sphinx-config/false_positives.txt @@ -553,6 +553,7 @@ corotate corotation corotational correlator +Cosecant cosineshifted cossq costheta @@ -585,6 +586,7 @@ Crozier Cryst Crystallogr Csanyi +csc csg csh cshrc @@ -3274,6 +3276,7 @@ Simul simulations Sinkovits Sinnott +sinh sinusoid sinusoidally SiO From 91c498c4130109c5c5b9b179e1710da5f616d4e1 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 16:32:17 -0500 Subject: [PATCH 22/79] suppres explicit exports/import in Lepton lib --- cmake/Modules/Packages/LEPTON.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Modules/Packages/LEPTON.cmake b/cmake/Modules/Packages/LEPTON.cmake index 05241de592..a1a74f3aa9 100644 --- a/cmake/Modules/Packages/LEPTON.cmake +++ b/cmake/Modules/Packages/LEPTON.cmake @@ -2,7 +2,7 @@ set(LEPTON_SOURCE_DIR ${LAMMPS_LIB_SOURCE_DIR}/lepton) file(GLOB LEPTON_SOURCES ${LEPTON_SOURCE_DIR}/src/[^.]*.cpp) add_library(lmplepton STATIC ${LEPTON_SOURCES}) -target_compile_definitions(lmplepton PRIVATE -DLEPTON_BUILDING_STATIC_LIBRARY) +target_compile_definitions(lmplepton PUBLIC -DLEPTON_BUILDING_STATIC_LIBRARY=1) set_target_properties(lmplepton PROPERTIES OUTPUT_NAME lammps_lmplepton${LAMMPS_MACHINE}) target_include_directories(lmplepton PUBLIC ${LEPTON_SOURCE_DIR}/include) target_link_libraries(lammps PRIVATE lmplepton) From ca27fb3a981f59a3fb68c151a2113992f007bdb9 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 22:16:17 -0500 Subject: [PATCH 23/79] update Lepton to current master branch --- .../include/lepton/CompiledExpression.h | 18 +- .../include/lepton/CompiledVectorExpression.h | 145 +++ .../include/lepton/ExpressionTreeNode.h | 8 +- lib/lepton/include/lepton/ParsedExpression.h | 19 +- lib/lepton/src/CompiledExpression.cpp | 542 ++++++++-- lib/lepton/src/CompiledVectorExpression.cpp | 933 ++++++++++++++++++ lib/lepton/src/ExpressionTreeNode.cpp | 48 +- lib/lepton/src/Operation.cpp | 130 ++- lib/lepton/src/ParsedExpression.cpp | 85 +- 9 files changed, 1801 insertions(+), 127 deletions(-) create mode 100644 lib/lepton/include/lepton/CompiledVectorExpression.h create mode 100644 lib/lepton/src/CompiledVectorExpression.cpp diff --git a/lib/lepton/include/lepton/CompiledExpression.h b/lib/lepton/include/lepton/CompiledExpression.h index 8ead5ce96f..6c940e081c 100644 --- a/lib/lepton/include/lepton/CompiledExpression.h +++ b/lib/lepton/include/lepton/CompiledExpression.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -40,7 +40,11 @@ #include #include #ifdef LEPTON_USE_JIT - #include "asmjit.h" +#if defined(__ARM__) || defined(__ARM64__) +#include "asmjit/a64.h" +#else +#include "asmjit/x86.h" +#endif #endif namespace LMP_Lepton { @@ -101,9 +105,15 @@ private: std::map dummyVariables; double (*jitCode)(); #ifdef LEPTON_USE_JIT + void findPowerGroups(std::vector >& groups, std::vector >& groupPowers, std::vector& stepGroup); void generateJitCode(); - void generateSingleArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg, double (*function)(double)); - void generateTwoArgCall(asmjit::X86Compiler& c, asmjit::X86Xmm& dest, asmjit::X86Xmm& arg1, asmjit::X86Xmm& arg2, double (*function)(double, double)); +#if defined(__ARM__) || defined(__ARM64__) + void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, double (*function)(double, double)); +#else + void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg, double (*function)(double)); + void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Xmm& dest, asmjit::x86::Xmm& arg1, asmjit::x86::Xmm& arg2, double (*function)(double, double)); +#endif std::vector constants; asmjit::JitRuntime runtime; #endif diff --git a/lib/lepton/include/lepton/CompiledVectorExpression.h b/lib/lepton/include/lepton/CompiledVectorExpression.h new file mode 100644 index 0000000000..e097e3eae1 --- /dev/null +++ b/lib/lepton/include/lepton/CompiledVectorExpression.h @@ -0,0 +1,145 @@ +#ifndef LEPTON_VECTOR_EXPRESSION_H_ +#define LEPTON_VECTOR_EXPRESSION_H_ + +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "ExpressionTreeNode.h" +#include "windowsIncludes.h" +#include +#include +#include +#include +#include +#include +#ifdef LEPTON_USE_JIT +#if defined(__ARM__) || defined(__ARM64__) +#include "asmjit/a64.h" +#else +#include "asmjit/x86.h" +#endif +#endif + +namespace LMP_Lepton { + +class Operation; +class ParsedExpression; + +/** + * A CompiledVectorExpression is a highly optimized representation of an expression for cases when you want to evaluate + * it many times as quickly as possible. It is similar to CompiledExpression, with the extra feature that it uses the CPU's + * vector unit (AVX on x86, NEON on ARM) to evaluate the expression for multiple sets of arguments at once. It also differs + * from CompiledExpression and ParsedExpression in using single precision rather than double precision to evaluate the expression. + * You should treat it as an opaque object; none of the internal representation is visible. + * + * A CompiledVectorExpression is created by calling createCompiledVectorExpression() on a ParsedExpression. When you create + * it, you must specify the width of the vectors on which to compute the expression. The allowed widths depend on the type of + * CPU it is running on. 4 is always allowed, and 8 is allowed on x86 processors with AVX. Call getAllowedWidths() to query + * the allowed values. + * + * WARNING: CompiledVectorExpression is NOT thread safe. You should never access a CompiledVectorExpression from two threads at + * the same time. + */ + +class LEPTON_EXPORT CompiledVectorExpression { +public: + CompiledVectorExpression(); + CompiledVectorExpression(const CompiledVectorExpression& expression); + ~CompiledVectorExpression(); + CompiledVectorExpression& operator=(const CompiledVectorExpression& expression); + /** + * Get the width of the vectors on which the expression is computed. + */ + int getWidth() const; + /** + * Get the names of all variables used by this expression. + */ + const std::set& getVariables() const; + /** + * Get a pointer to the memory location where the value of a particular variable is stored. This can be used + * to set the value of the variable before calling evaluate(). + * + * @param name the name of the variable to query + * @return a pointer to N floating point values, where N is the vector width + */ + float* getVariablePointer(const std::string& name); + /** + * You can optionally specify the memory locations from which the values of variables should be read. + * This is useful, for example, when several expressions all use the same variable. You can then set + * the value of that variable in one place, and it will be seen by all of them. The location should + * be a pointer to N floating point values, where N is the vector width. + */ + void setVariableLocations(std::map& variableLocations); + /** + * Evaluate the expression. The values of all variables should have been set before calling this. + * + * @return a pointer to N floating point values, where N is the vector width + */ + const float* evaluate() const; + /** + * Get the list of vector widths that are supported on the current processor. + */ + static const std::vector& getAllowedWidths(); +private: + friend class ParsedExpression; + CompiledVectorExpression(const ParsedExpression& expression, int width); + void compileExpression(const ExpressionTreeNode& node, std::vector >& temps, int& workspaceSize); + int findTempIndex(const ExpressionTreeNode& node, std::vector >& temps); + int width; + std::map variablePointers; + std::vector > variablesToCopy; + std::vector > arguments; + std::vector target; + std::vector operation; + std::map variableIndices; + std::set variableNames; + mutable std::vector workspace; + mutable std::vector argValues; + std::map dummyVariables; + void (*jitCode)(); +#ifdef LEPTON_USE_JIT + void findPowerGroups(std::vector >& groups, std::vector >& groupPowers, std::vector& stepGroup); + void generateJitCode(); +#if defined(__ARM__) || defined(__ARM64__) + void generateSingleArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg, float (*function)(float)); + void generateTwoArgCall(asmjit::a64::Compiler& c, asmjit::arm::Vec& dest, asmjit::arm::Vec& arg1, asmjit::arm::Vec& arg2, float (*function)(float, float)); +#else + void generateSingleArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg, float (*function)(float)); + void generateTwoArgCall(asmjit::x86::Compiler& c, asmjit::x86::Ymm& dest, asmjit::x86::Ymm& arg1, asmjit::x86::Ymm& arg2, float (*function)(float, float)); +#endif + std::vector constants; + asmjit::JitRuntime runtime; +#endif +}; + +} // namespace LMP_Lepton + +#endif /*LEPTON_VECTOR_EXPRESSION_H_*/ diff --git a/lib/lepton/include/lepton/ExpressionTreeNode.h b/lib/lepton/include/lepton/ExpressionTreeNode.h index 514cc008a9..eba791fbaa 100644 --- a/lib/lepton/include/lepton/ExpressionTreeNode.h +++ b/lib/lepton/include/lepton/ExpressionTreeNode.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009 Stanford University and the Authors. * + * Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -39,6 +39,7 @@ namespace LMP_Lepton { class Operation; +class ParsedExpression; /** * This class represents a node in the abstract syntax tree representation of an expression. @@ -82,11 +83,13 @@ public: */ ExpressionTreeNode(Operation* operation); ExpressionTreeNode(const ExpressionTreeNode& node); + ExpressionTreeNode(ExpressionTreeNode&& node); ExpressionTreeNode(); ~ExpressionTreeNode(); bool operator==(const ExpressionTreeNode& node) const; bool operator!=(const ExpressionTreeNode& node) const; ExpressionTreeNode& operator=(const ExpressionTreeNode& node); + ExpressionTreeNode& operator=(ExpressionTreeNode&& node); /** * Get the Operation performed by this node. */ @@ -96,8 +99,11 @@ public: */ const std::vector& getChildren() const; private: + friend class ParsedExpression; + void assignTags(std::vector& examples) const; Operation* operation; std::vector children; + mutable int tag; }; } // namespace LMP_Lepton diff --git a/lib/lepton/include/lepton/ParsedExpression.h b/lib/lepton/include/lepton/ParsedExpression.h index 586acb4d2c..05081f677c 100644 --- a/lib/lepton/include/lepton/ParsedExpression.h +++ b/lib/lepton/include/lepton/ParsedExpression.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009=2013 Stanford University and the Authors. * + * Portions copyright (c) 2009-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -41,6 +41,7 @@ namespace LMP_Lepton { class CompiledExpression; class ExpressionProgram; +class CompiledVectorExpression; /** * This class represents the result of parsing an expression. It provides methods for working with the @@ -102,6 +103,16 @@ public: * Create a CompiledExpression that represents the same calculation as this expression. */ CompiledExpression createCompiledExpression() const; + /** + * Create a CompiledVectorExpression that allows the expression to be evaluated efficiently + * using the CPU's vector unit. + * + * @param width the width of the vectors to evaluate it on. The allowed values + * depend on the CPU. 4 is always allowed, and 8 is allowed on + * x86 processors with AVX. Call CompiledVectorExpression::getAllowedWidths() + * to query the allowed widths on the current processor. + */ + CompiledVectorExpression createCompiledVectorExpression(int width) const; /** * Create a new ParsedExpression which is identical to this one, except that the names of some * variables have been changed. @@ -113,9 +124,9 @@ public: private: static double evaluate(const ExpressionTreeNode& node, const std::map& variables); static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map& variables); - static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); - static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); - static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable); + static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node, std::map& nodeCache); + static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node, std::map& nodeCache); + static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable, std::map& nodeCache); static bool isConstant(const ExpressionTreeNode& node); static double getConstantValue(const ExpressionTreeNode& node); static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map& replacements); diff --git a/lib/lepton/src/CompiledExpression.cpp b/lib/lepton/src/CompiledExpression.cpp index c6c1543ce4..b85c3a08f7 100644 --- a/lib/lepton/src/CompiledExpression.cpp +++ b/lib/lepton/src/CompiledExpression.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2013-2019 Stanford University and the Authors. * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -151,7 +151,7 @@ void CompiledExpression::setVariableLocations(map& variableLoca if (workspace.size() > 0) generateJitCode(); -#else +#endif // Make a list of all variables we will need to copy before evaluating the expression. variablesToCopy.clear(); @@ -160,13 +160,11 @@ void CompiledExpression::setVariableLocations(map& variableLoca if (pointer != variablePointers.end()) variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second)); } -#endif } double CompiledExpression::evaluate() const { -#ifdef LEPTON_USE_JIT - return jitCode(); -#else + if (jitCode) + return jitCode(); for (int i = 0; i < (int)variablesToCopy.size(); i++) *variablesToCopy[i].first = *variablesToCopy[i].second; @@ -183,7 +181,6 @@ double CompiledExpression::evaluate() const { } } return workspace[workspace.size()-1]; -#endif } #ifdef LEPTON_USE_JIT @@ -192,24 +189,70 @@ static double evaluateOperation(Operation* op, double* args) { return op->evaluate(args, dummyVariables); } +void CompiledExpression::findPowerGroups(vector >& groups, vector >& groupPowers, vector& stepGroup) { + // Identify every step that raises an argument to an integer power. + + vector stepPower(operation.size(), 0); + vector stepArg(operation.size(), -1); + for (int step = 0; step < (int)operation.size(); step++) { + Operation& op = *operation[step]; + int power = 0; + if (op.getId() == Operation::SQUARE) + power = 2; + else if (op.getId() == Operation::CUBE) + power = 3; + else if (op.getId() == Operation::POWER_CONSTANT) { + double realPower = dynamic_cast(&op)->getValue(); + if (realPower == (int) realPower) + power = (int) realPower; + } + if (power != 0) { + stepPower[step] = power; + stepArg[step] = arguments[step][0]; + } + } + + // Find groups that operate on the same argument and whose powers have the same sign. + + stepGroup.resize(operation.size(), -1); + for (int i = 0; i < (int)operation.size(); i++) { + if (stepGroup[i] != -1) + continue; + vector group, power; + for (int j = i; j < (int)operation.size(); j++) { + if (stepArg[i] == stepArg[j] && stepPower[i]*stepPower[j] > 0) { + stepGroup[j] = groups.size(); + group.push_back(j); + power.push_back(stepPower[j]); + } + } + groups.push_back(group); + groupPowers.push_back(power); + } +} + +#if defined(__ARM__) || defined(__ARM64__) void CompiledExpression::generateJitCode() { CodeHolder code; - code.init(runtime.getCodeInfo()); - X86Compiler c(&code); - c.addFunc(FuncSignature0()); - vector workspaceVar(workspace.size()); + code.init(runtime.environment()); + a64::Compiler c(&code); + c.addFunc(FuncSignatureT()); + vector workspaceVar(workspace.size()); for (int i = 0; i < (int) workspaceVar.size(); i++) - workspaceVar[i] = c.newXmmSd(); - X86Gp argsPointer = c.newIntPtr(); - c.mov(argsPointer, imm_ptr(&argValues[0])); + workspaceVar[i] = c.newVecD(); + arm::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); // Load the arguments into variables. for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { map::iterator index = variableIndices.find(*iter); - X86Gp variablePointer = c.newIntPtr(); - c.mov(variablePointer, imm_ptr(&getVariableReference(index->first))); - c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + arm::Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm(&getVariableReference(index->first))); + c.ldr(workspaceVar[index->second], arm::ptr(variablePointer, 0)); } // Make a list of all constants that will be needed for evaluation. @@ -232,6 +275,12 @@ void CompiledExpression::generateJitCode() { value = 1.0; else if (op.getId() == Operation::DELTA) value = 1.0; + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast(op).getValue(); + else + value = 1.0; + } else continue; @@ -250,19 +299,63 @@ void CompiledExpression::generateJitCode() { // Load constants into variables. - vector constantVar(constants.size()); + vector constantVar(constants.size()); if (constants.size() > 0) { - X86Gp constantsPointer = c.newIntPtr(); - c.mov(constantsPointer, imm_ptr(&constants[0])); + arm::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); for (int i = 0; i < (int) constants.size(); i++) { - constantVar[i] = c.newXmmSd(); - c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + constantVar[i] = c.newVecD(); + c.ldr(constantVar[i], arm::ptr(constantsPointer, 8*i)); } } // Evaluate the operations. + vector hasComputedPower(operation.size(), false); for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + arm::Vec multiplier = c.newVecD(); + if (powers[0] > 0) + c.fmov(multiplier, workspaceVar[arguments[step][0]]); + else { + c.fdiv(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i]%2 == 1) { + if (!hasAssigned[i]) + c.fmov(workspaceVar[target[group[i]]], multiplier); + else + c.fmul(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.fmul(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + Operation& op = *operation[step]; vector args = arguments[step]; if (args.size() == 1) { @@ -276,33 +369,28 @@ void CompiledExpression::generateJitCode() { switch (op.getId()) { case Operation::CONSTANT: - c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fmov(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::ADD: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::SUBTRACT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fsub(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::MULTIPLY: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::DIVIDE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]); + c.fdiv(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::POWER: generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); break; case Operation::NEGATE: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fneg(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::SQRT: - c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fsqrt(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::EXP: generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); @@ -341,56 +429,63 @@ void CompiledExpression::generateJitCode() { generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); break; case Operation::STEP: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 - c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.cmge(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::DELTA: - c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]); - c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 - c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.cmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); break; case Operation::SQUARE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); break; case Operation::CUBE: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.fmul(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::RECIPROCAL: - c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); - c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]); + c.fdiv(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); break; case Operation::ADD_CONSTANT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fadd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); break; case Operation::MULTIPLY_CONSTANT: - c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]); - c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + c.fmul(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); + break; + case Operation::MIN: + c.fmin(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.fmax(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); break; case Operation::ABS: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs); + c.fabs(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::FLOOR: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor); + c.frintm(workspaceVar[target[step]], workspaceVar[args[0]]); break; case Operation::CEIL: - generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil); + c.frintp(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SELECT: + c.fcmeq(workspaceVar[target[step]], workspaceVar[args[0]], imm(0)); + c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]); break; default: // Just invoke evaluateOperation(). for (int i = 0; i < (int) args.size(); i++) - c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) evaluateOperation)); - CCFuncCall* call = c.call(fn, FuncSignature2()); - call->setArg(0, imm_ptr(&op)); - call->setArg(1, imm_ptr(&argValues[0])); - call->setRet(0, workspaceVar[target[step]]); + c.str(workspaceVar[args[i]], arm::ptr(argsPointer, 8*i)); + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, workspaceVar[target[step]]); } } c.ret(workspaceVar[workspace.size()-1]); @@ -399,20 +494,319 @@ void CompiledExpression::generateJitCode() { runtime.add(&jitCode, &code); } -void CompiledExpression::generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) { - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) function)); - CCFuncCall* call = c.call(fn, FuncSignature1()); - call->setArg(0, arg); - call->setRet(0, dest); +void CompiledExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, double (*function)(double)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg); + invoke->setRet(0, dest); } -void CompiledExpression::generateTwoArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg1, X86Xmm& arg2, double (*function)(double, double)) { - X86Gp fn = c.newIntPtr(); - c.mov(fn, imm_ptr((void*) function)); - CCFuncCall* call = c.call(fn, FuncSignature2()); - call->setArg(0, arg1); - call->setArg(1, arg2); - call->setRet(0, dest); +void CompiledExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, double (*function)(double, double)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg1); + invoke->setArg(1, arg2); + invoke->setRet(0, dest); +} +#else +void CompiledExpression::generateJitCode() { + const CpuInfo& cpu = CpuInfo::host(); + if (!cpu.hasFeature(CpuFeatures::X86::kAVX)) + return; + CodeHolder code; + code.init(runtime.environment()); + x86::Compiler c(&code); + FuncNode* funcNode = c.addFunc(FuncSignatureT()); + funcNode->frame().setAvxEnabled(); + vector workspaceVar(workspace.size()); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newXmmSd(); + x86::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + x86::Gp variablePointer = c.newIntPtr(); + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + c.mov(variablePointer, imm(&getVariableReference(index->first))); + c.vmovsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast(op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::ABS) { + long long mask = 0x7FFFFFFFFFFFFFFF; + value = *reinterpret_cast(&mask); + } + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast(op).getValue(); + else + value = 1.0; + } + else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + x86::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newXmmSd(); + c.vmovsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + x86::Xmm multiplier = c.newXmmSd(); + if (powers[0] > 0) + c.vmovsd(multiplier, workspaceVar[arguments[step][0]], workspaceVar[arguments[step][0]]); + else { + c.vdivsd(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < (int)powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < (int)group.size(); i++) { + if (powers[i]%2 == 1) { + if (!hasAssigned[i]) + c.vmovsd(workspaceVar[target[group[i]]], multiplier, multiplier); + else + c.vmulsd(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.vmulsd(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0]+i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.vmovsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.vsubsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.vdivsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], pow); + break; + case Operation::NEGATE: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vsubsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.vsqrtsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh); + break; + case Operation::STEP: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.vmulsd(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.vdivsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.vaddsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::MULTIPLY_CONSTANT: + c.vmulsd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], pow); + break; + case Operation::MIN: + c.vminsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.vmaxsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::ABS: + c.vandpd(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::FLOOR: + c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(1)); + break; + case Operation::CEIL: + c.vroundsd(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]], imm(2)); + break; + case Operation::SELECT: + { + x86::Xmm mask = c.newXmmSd(); + c.vxorps(mask, mask, mask); + c.vcmpsd(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask); + break; + } + default: + // Just invoke evaluateOperation(). + + for (int i = 0; i < (int) args.size(); i++) + c.vmovsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]); + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, workspaceVar[target[step]]); + } + } + c.ret(workspaceVar[workspace.size()-1]); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledExpression::generateSingleArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg, double (*function)(double)) { + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg); + invoke->setRet(0, dest); +} + +void CompiledExpression::generateTwoArgCall(x86::Compiler& c, x86::Xmm& dest, x86::Xmm& arg1, x86::Xmm& arg2, double (*function)(double, double)) { + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, arg1); + invoke->setArg(1, arg2); + invoke->setRet(0, dest); } #endif +#endif diff --git a/lib/lepton/src/CompiledVectorExpression.cpp b/lib/lepton/src/CompiledVectorExpression.cpp new file mode 100644 index 0000000000..7e4dfcad9c --- /dev/null +++ b/lib/lepton/src/CompiledVectorExpression.cpp @@ -0,0 +1,933 @@ +/* -------------------------------------------------------------------------- * + * Lepton * + * -------------------------------------------------------------------------- * + * This is part of the Lepton expression parser originating from * + * Simbios, the NIH National Center for Physics-Based Simulation of * + * Biological Structures at Stanford, funded under the NIH Roadmap for * + * Medical Research, grant U54 GM072970. See https://simtk.org. * + * * + * Portions copyright (c) 2013-2022 Stanford University and the Authors. * + * Authors: Peter Eastman * + * Contributors: * + * * + * Permission is hereby granted, free of charge, to any person obtaining a * + * copy of this software and associated documentation files (the "Software"), * + * to deal in the Software without restriction, including without limitation * + * the rights to use, copy, modify, merge, publish, distribute, sublicense, * + * and/or sell copies of the Software, and to permit persons to whom the * + * Software is furnished to do so, subject to the following conditions: * + * * + * The above copyright notice and this permission notice shall be included in * + * all copies or substantial portions of the Software. * + * * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * + * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * + * USE OR OTHER DEALINGS IN THE SOFTWARE. * + * -------------------------------------------------------------------------- */ + +#include "lepton/CompiledVectorExpression.h" +#include "lepton/Operation.h" +#include "lepton/ParsedExpression.h" +#include +#include + +using namespace LMP_Lepton; +using namespace std; +#ifdef LEPTON_USE_JIT +using namespace asmjit; +#endif + +CompiledVectorExpression::CompiledVectorExpression() : jitCode(NULL) { +} + +CompiledVectorExpression::CompiledVectorExpression(const ParsedExpression& expression, int width) : width(width), jitCode(NULL) { + const vector allowedWidths = getAllowedWidths(); + if (find(allowedWidths.begin(), allowedWidths.end(), width) == allowedWidths.end()) + throw Exception("Unsupported width for vector expression: "+to_string(width)); + ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized. + vector > temps; + int workspaceSize = 0; + compileExpression(expr.getRootNode(), temps, workspaceSize); + workspace.resize(workspaceSize*width); + int maxArguments = 1; + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i]->getNumArguments() > maxArguments) + maxArguments = operation[i]->getNumArguments(); + argValues.resize(maxArguments); +#ifdef LEPTON_USE_JIT + generateJitCode(); +#endif +} + +CompiledVectorExpression::~CompiledVectorExpression() { + for (int i = 0; i < (int) operation.size(); i++) + if (operation[i] != NULL) + delete operation[i]; +} + +CompiledVectorExpression::CompiledVectorExpression(const CompiledVectorExpression& expression) : jitCode(NULL) { + *this = expression; +} + +CompiledVectorExpression& CompiledVectorExpression::operator=(const CompiledVectorExpression& expression) { + arguments = expression.arguments; + width = expression.width; + target = expression.target; + variableIndices = expression.variableIndices; + variableNames = expression.variableNames; + workspace.resize(expression.workspace.size()); + argValues.resize(expression.argValues.size()); + operation.resize(expression.operation.size()); + for (int i = 0; i < (int) operation.size(); i++) + operation[i] = expression.operation[i]->clone(); + setVariableLocations(variablePointers); + return *this; +} + +const vector& CompiledVectorExpression::getAllowedWidths() { + static vector widths; + if (widths.size() == 0) { + widths.push_back(4); +#ifdef LEPTON_USE_JIT + const CpuInfo& cpu = CpuInfo::host(); + if (cpu.hasFeature(CpuFeatures::X86::kAVX)) + widths.push_back(8); +#endif + } + return widths; +} + +void CompiledVectorExpression::compileExpression(const ExpressionTreeNode& node, vector >& temps, int& workspaceSize) { + if (findTempIndex(node, temps) != -1) + return; // We have already processed a node identical to this one. + + // Process the child nodes. + + vector args; + for (int i = 0; i < (int)node.getChildren().size(); i++) { + compileExpression(node.getChildren()[i], temps, workspaceSize); + args.push_back(findTempIndex(node.getChildren()[i], temps)); + } + + // Process this node. + + if (node.getOperation().getId() == Operation::VARIABLE) { + variableIndices[node.getOperation().getName()] = workspaceSize; + variableNames.insert(node.getOperation().getName()); + } + else { + int stepIndex = (int) arguments.size(); + arguments.push_back(vector()); + target.push_back(workspaceSize); + operation.push_back(node.getOperation().clone()); + if (args.size() == 0) + arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there. + else { + // If the arguments are sequential, we can just pass a pointer to the first one. + + bool sequential = true; + for (int i = 1; i < (int)args.size(); i++) + if (args[i] != args[i - 1] + 1) + sequential = false; + if (sequential) + arguments[stepIndex].push_back(args[0]); + else + arguments[stepIndex] = args; + } + } + temps.push_back(make_pair(node, workspaceSize)); + workspaceSize++; +} + +int CompiledVectorExpression::findTempIndex(const ExpressionTreeNode& node, vector >& temps) { + for (int i = 0; i < (int) temps.size(); i++) + if (temps[i].first == node) + return i; + return -1; +} + +int CompiledVectorExpression::getWidth() const { + return width; +} + +const set& CompiledVectorExpression::getVariables() const { + return variableNames; +} + +float* CompiledVectorExpression::getVariablePointer(const string& name) { + map::iterator pointer = variablePointers.find(name); + if (pointer != variablePointers.end()) + return pointer->second; + map::iterator index = variableIndices.find(name); + if (index == variableIndices.end()) + throw Exception("getVariableReference: Unknown variable '" + name + "'"); + return &workspace[index->second*width]; +} + +void CompiledVectorExpression::setVariableLocations(map& variableLocations) { + variablePointers = variableLocations; +#ifdef LEPTON_USE_JIT + // Rebuild the JIT code. + + if (workspace.size() > 0) + generateJitCode(); +#endif + // Make a list of all variables we will need to copy before evaluating the expression. + + variablesToCopy.clear(); + for (map::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) { + map::iterator pointer = variablePointers.find(iter->first); + if (pointer != variablePointers.end()) + variablesToCopy.push_back(make_pair(&workspace[iter->second*width], pointer->second)); + } +} + +const float* CompiledVectorExpression::evaluate() const { + if (jitCode) { + jitCode(); + return &workspace[workspace.size()-width]; + } + for (int i = 0; i < (int)variablesToCopy.size(); i++) + for (int j = 0; j < width; j++) + variablesToCopy[i].first[j] = variablesToCopy[i].second[j]; + + // Loop over the operations and evaluate each one. + + for (int step = 0; step < (int)operation.size(); step++) { + const vector& args = arguments[step]; + if (args.size() == 1) { + for (int j = 0; j < width; j++) { + for (int i = 0; i < operation[step]->getNumArguments(); i++) + argValues[i] = workspace[(args[0]+i)*width+j]; + workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } else { + for (int j = 0; j < width; j++) { + for (int i = 0; i < (int)args.size(); i++) + argValues[i] = workspace[args[i]*width+j]; + workspace[target[step]*width+j] = operation[step]->evaluate(&argValues[0], dummyVariables); + } + } + } + return &workspace[workspace.size()-width]; +} + +#ifdef LEPTON_USE_JIT + +static double evaluateOperation(Operation* op, double* args) { + static map dummyVariables; + return op->evaluate(args, dummyVariables); +} + +void CompiledVectorExpression::findPowerGroups(vector >& groups, vector >& groupPowers, vector& stepGroup) { + // Identify every step that raises an argument to an integer power. + + vector stepPower(operation.size(), 0); + vector stepArg(operation.size(), -1); + for (int step = 0; step < (int)operation.size(); step++) { + Operation& op = *operation[step]; + int power = 0; + if (op.getId() == Operation::SQUARE) + power = 2; + else if (op.getId() == Operation::CUBE) + power = 3; + else if (op.getId() == Operation::POWER_CONSTANT) { + double realPower = dynamic_cast (&op)->getValue(); + if (realPower == (int) realPower) + power = (int) realPower; + } + if (power != 0) { + stepPower[step] = power; + stepArg[step] = arguments[step][0]; + } + } + + // Find groups that operate on the same argument and whose powers have the same sign. + + stepGroup.resize(operation.size(), -1); + for (int i = 0; i < (int)operation.size(); i++) { + if (stepGroup[i] != -1) + continue; + vector group, power; + for (int j = i; j < (int)operation.size(); j++) { + if (stepArg[i] == stepArg[j] && stepPower[i] * stepPower[j] > 0) { + stepGroup[j] = groups.size(); + group.push_back(j); + power.push_back(stepPower[j]); + } + } + groups.push_back(group); + groupPowers.push_back(power); + } +} + +#if defined(__ARM__) || defined(__ARM64__) + +void CompiledVectorExpression::generateJitCode() { + CodeHolder code; + code.init(runtime.environment()); + a64::Compiler c(&code); + c.addFunc(FuncSignatureT()); + vector workspaceVar(workspace.size()/width); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newVecQ(); + arm::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + arm::Gp variablePointer = c.newIntPtr(); + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + c.mov(variablePointer, imm(getVariablePointer(index->first))); + c.ldr(workspaceVar[index->second].s4(), arm::ptr(variablePointer, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + float value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast (op).getValue(); + else + value = 1.0; + } else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + arm::Gp constantsPointer = c.newIntPtr(); + for (int i = 0; i < (int) constants.size(); i++) { + c.mov(constantsPointer, imm(&constants[i])); + constantVar[i] = c.newVecQ(); + c.ld1r(constantVar[i].s4(), arm::ptr(constantsPointer)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + arm::Vec argReg = c.newVecS(); + arm::Vec doubleArgReg = c.newVecD(); + arm::Vec doubleResultReg = c.newVecD(); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + arm::Vec multiplier = c.newVecQ(); + if (powers[0] > 0) + c.mov(multiplier.s4(), workspaceVar[arguments[step][0]].s4()); + else { + c.fdiv(multiplier.s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[arguments[step][0]].s4()); + for (int i = 0; i < powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < group.size(); i++) { + if (powers[i] % 2 == 1) { + if (!hasAssigned[i]) + c.mov(workspaceVar[target[group[i]]].s4(), multiplier.s4()); + else + c.fmul(workspaceVar[target[group[i]]].s4(), workspaceVar[target[group[i]]].s4(), multiplier.s4()); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.fmul(multiplier.s4(), multiplier.s4(), multiplier.s4()); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0] + i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.mov(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::ADD: + c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::SUBTRACT: + c.fsub(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::MULTIPLY: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::DIVIDE: + c.fdiv(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf); + break; + case Operation::NEGATE: + c.fneg(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::SQRT: + c.fsqrt(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf); + break; + case Operation::STEP: + c.cmge(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.cmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.and_(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::CUBE: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[0]].s4()); + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::RECIPROCAL: + c.fdiv(workspaceVar[target[step]].s4(), constantVar[operationConstantIndex[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::ADD_CONSTANT: + c.fadd(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::MULTIPLY_CONSTANT: + c.fmul(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), constantVar[operationConstantIndex[step]].s4()); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf); + break; + case Operation::MIN: + c.fmin(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::MAX: + c.fmax(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), workspaceVar[args[1]].s4()); + break; + case Operation::ABS: + c.fabs(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::FLOOR: + c.frintm(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::CEIL: + c.frintp(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4()); + break; + case Operation::SELECT: + c.fcmeq(workspaceVar[target[step]].s4(), workspaceVar[args[0]].s4(), imm(0)); + c.bsl(workspaceVar[target[step]], workspaceVar[args[2]], workspaceVar[args[1]]); + break; + default: + // Just invoke evaluateOperation(). + for (int element = 0; element < width; element++) { + for (int i = 0; i < (int) args.size(); i++) { + c.ins(argReg.s(0), workspaceVar[args[i]].s(element)); + c.fcvt(doubleArgReg, argReg); + c.str(doubleArgReg, arm::ptr(argsPointer, 8*i)); + } + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, doubleResultReg); + c.fcvt(argReg, doubleResultReg); + c.ins(workspaceVar[target[step]].s(element), argReg.s(0)); + } + } + } + arm::Gp resultPointer = c.newIntPtr(); + c.mov(resultPointer, imm(&workspace[workspace.size()-width])); + c.str(workspaceVar.back().s4(), arm::ptr(resultPointer, 0)); + c.endFunc(); + c.finalize(); + runtime.add(&jitCode, &code); +} + +void CompiledVectorExpression::generateSingleArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg, float (*function)(float)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + arm::Vec a = c.newVecS(); + arm::Vec d = c.newVecS(); + for (int element = 0; element < width; element++) { + c.ins(a.s(0), arg.s(element)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, a); + invoke->setRet(0, d); + c.ins(dest.s(element), d.s(0)); + } +} + +void CompiledVectorExpression::generateTwoArgCall(a64::Compiler& c, arm::Vec& dest, arm::Vec& arg1, arm::Vec& arg2, float (*function)(float, float)) { + arm::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) function)); + arm::Vec a1 = c.newVecS(); + arm::Vec a2 = c.newVecS(); + arm::Vec d = c.newVecS(); + for (int element = 0; element < width; element++) { + c.ins(a1.s(0), arg1.s(element)); + c.ins(a2.s(0), arg2.s(element)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, a1); + invoke->setArg(1, a2); + invoke->setRet(0, d); + c.ins(dest.s(element), d.s(0)); + } +} +#else + +void CompiledVectorExpression::generateJitCode() { + const CpuInfo& cpu = CpuInfo::host(); + if (!cpu.hasFeature(CpuFeatures::X86::kAVX)) + return; + CodeHolder code; + code.init(runtime.environment()); + x86::Compiler c(&code); + FuncNode* funcNode = c.addFunc(FuncSignatureT()); + funcNode->frame().setAvxEnabled(); + vector workspaceVar(workspace.size()/width); + for (int i = 0; i < (int) workspaceVar.size(); i++) + workspaceVar[i] = c.newYmmPs(); + x86::Gp argsPointer = c.newIntPtr(); + c.mov(argsPointer, imm(&argValues[0])); + vector > groups, groupPowers; + vector stepGroup; + findPowerGroups(groups, groupPowers, stepGroup); + + // Load the arguments into variables. + + for (set::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) { + map::iterator index = variableIndices.find(*iter); + x86::Gp variablePointer = c.newIntPtr(); + c.mov(variablePointer, imm(getVariablePointer(index->first))); + if (width == 4) + c.vmovdqu(workspaceVar[index->second].xmm(), x86::ptr(variablePointer, 0, 0)); + else + c.vmovdqu(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0)); + } + + // Make a list of all constants that will be needed for evaluation. + + vector operationConstantIndex(operation.size(), -1); + for (int step = 0; step < (int) operation.size(); step++) { + // Find the constant value (if any) used by this operation. + + Operation& op = *operation[step]; + double value; + if (op.getId() == Operation::CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::ADD_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::MULTIPLY_CONSTANT) + value = dynamic_cast (op).getValue(); + else if (op.getId() == Operation::RECIPROCAL) + value = 1.0; + else if (op.getId() == Operation::STEP) + value = 1.0; + else if (op.getId() == Operation::DELTA) + value = 1.0; + else if (op.getId() == Operation::ABS) { + int mask = 0x7FFFFFFF; + value = *reinterpret_cast(&mask); + } + else if (op.getId() == Operation::POWER_CONSTANT) { + if (stepGroup[step] == -1) + value = dynamic_cast (op).getValue(); + else + value = 1.0; + } else + continue; + + // See if we already have a variable for this constant. + + for (int i = 0; i < (int) constants.size(); i++) + if (value == constants[i]) { + operationConstantIndex[step] = i; + break; + } + if (operationConstantIndex[step] == -1) { + operationConstantIndex[step] = constants.size(); + constants.push_back(value); + } + } + + // Load constants into variables. + + vector constantVar(constants.size()); + if (constants.size() > 0) { + x86::Gp constantsPointer = c.newIntPtr(); + c.mov(constantsPointer, imm(&constants[0])); + for (int i = 0; i < (int) constants.size(); i++) { + constantVar[i] = c.newYmmPs(); + c.vbroadcastss(constantVar[i], x86::ptr(constantsPointer, 4*i, 0)); + } + } + + // Evaluate the operations. + + vector hasComputedPower(operation.size(), false); + x86::Ymm argReg = c.newYmm(); + x86::Ymm doubleArgReg = c.newYmm(); + x86::Ymm doubleResultReg = c.newYmm(); + for (int step = 0; step < (int) operation.size(); step++) { + if (hasComputedPower[step]) + continue; + + // When one or more steps involve raising the same argument to multiple integer + // powers, we can compute them all together for efficiency. + + if (stepGroup[step] != -1) { + vector& group = groups[stepGroup[step]]; + vector& powers = groupPowers[stepGroup[step]]; + x86::Ymm multiplier = c.newYmmPs(); + if (powers[0] > 0) + c.vmovdqu(multiplier, workspaceVar[arguments[step][0]]); + else { + c.vdivps(multiplier, constantVar[operationConstantIndex[step]], workspaceVar[arguments[step][0]]); + for (int i = 0; i < (int)powers.size(); i++) + powers[i] = -powers[i]; + } + vector hasAssigned(group.size(), false); + bool done = false; + while (!done) { + done = true; + for (int i = 0; i < (int)group.size(); i++) { + if (powers[i] % 2 == 1) { + if (!hasAssigned[i]) + c.vmovdqu(workspaceVar[target[group[i]]], multiplier); + else + c.vmulps(workspaceVar[target[group[i]]], workspaceVar[target[group[i]]], multiplier); + hasAssigned[i] = true; + } + powers[i] >>= 1; + if (powers[i] != 0) + done = false; + } + if (!done) + c.vmulps(multiplier, multiplier, multiplier); + } + for (int step : group) + hasComputedPower[step] = true; + continue; + } + + // Evaluate the step. + + Operation& op = *operation[step]; + vector args = arguments[step]; + if (args.size() == 1) { + // One or more sequential arguments. Fill out the list. + + for (int i = 1; i < op.getNumArguments(); i++) + args.push_back(args[0] + i); + } + + // Generate instructions to execute this operation. + + switch (op.getId()) { + case Operation::CONSTANT: + c.vmovdqu(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::ADD: + c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::SUBTRACT: + c.vsubps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MULTIPLY: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::DIVIDE: + c.vdivps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::POWER: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], powf); + break; + case Operation::NEGATE: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vsubps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::SQRT: + c.vsqrtps(workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::EXP: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], expf); + break; + case Operation::LOG: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], logf); + break; + case Operation::SIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinf); + break; + case Operation::COS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosf); + break; + case Operation::TAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanf); + break; + case Operation::ASIN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asinf); + break; + case Operation::ACOS: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acosf); + break; + case Operation::ATAN: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atanf); + break; + case Operation::ATAN2: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]], atan2f); + break; + case Operation::SINH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinhf); + break; + case Operation::COSH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], coshf); + break; + case Operation::TANH: + generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanhf); + break; + case Operation::STEP: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::DELTA: + c.vxorps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[target[step]]); + c.vcmpps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vandps(workspaceVar[target[step]], workspaceVar[target[step]], constantVar[operationConstantIndex[step]]); + break; + case Operation::SQUARE: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + break; + case Operation::CUBE: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[0]]); + c.vmulps(workspaceVar[target[step]], workspaceVar[target[step]], workspaceVar[args[0]]); + break; + case Operation::RECIPROCAL: + c.vdivps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]], workspaceVar[args[0]]); + break; + case Operation::ADD_CONSTANT: + c.vaddps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::MULTIPLY_CONSTANT: + c.vmulps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::POWER_CONSTANT: + generateTwoArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]], powf); + break; + case Operation::MIN: + c.vminps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::MAX: + c.vmaxps(workspaceVar[target[step]], workspaceVar[args[0]], workspaceVar[args[1]]); + break; + case Operation::ABS: + c.vandps(workspaceVar[target[step]], workspaceVar[args[0]], constantVar[operationConstantIndex[step]]); + break; + case Operation::FLOOR: + c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(1)); + break; + case Operation::CEIL: + c.vroundps(workspaceVar[target[step]], workspaceVar[args[0]], imm(2)); + break; + case Operation::SELECT: + { + x86::Ymm mask = c.newYmmPs(); + c.vxorps(mask, mask, mask); + c.vcmpps(mask, mask, workspaceVar[args[0]], imm(0)); // Comparison mode is _CMP_EQ_OQ = 0 + c.vblendvps(workspaceVar[target[step]], workspaceVar[args[1]], workspaceVar[args[2]], mask); + break; + } + default: + // Just invoke evaluateOperation(). + + for (int element = 0; element < width; element++) { + for (int i = 0; i < (int) args.size(); i++) { + if (element < 4) + c.vshufps(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(element)); + else { + c.vperm2f128(argReg, workspaceVar[args[i]], workspaceVar[args[i]], imm(1)); + c.vshufps(argReg, argReg, argReg, imm(element-4)); + } + c.vcvtss2sd(doubleArgReg.xmm(), doubleArgReg.xmm(), argReg.xmm()); + c.vmovsd(x86::ptr(argsPointer, 8*i, 0), doubleArgReg.xmm()); + } + x86::Gp fn = c.newIntPtr(); + c.mov(fn, imm((void*) evaluateOperation)); + InvokeNode* invoke; + c.invoke(&invoke, fn, FuncSignatureT()); + invoke->setArg(0, imm(&op)); + invoke->setArg(1, imm(&argValues[0])); + invoke->setRet(0, doubleResultReg); + c.vcvtsd2ss(argReg.xmm(), argReg.xmm(), doubleResultReg.xmm()); + if (element > 3) + c.vperm2f128(argReg, argReg, argReg, imm(0)); + if (element != 0) + c.vshufps(argReg, argReg, argReg, imm(0)); + c.vblendps(workspaceVar[target[step]], workspaceVar[target[step]], argReg, 1<()); + invoke->setArg(0, a); + invoke->setRet(0, d); + if (element > 3) + c.vperm2f128(d, d, d, imm(0)); + if (element != 0) + c.vshufps(d, d, d, imm(0)); + c.vblendps(dest, dest, d, 1<()); + invoke->setArg(0, a1); + invoke->setArg(1, a2); + invoke->setRet(0, d); + if (element > 3) + c.vperm2f128(d, d, d, imm(0)); + if (element != 0) + c.vshufps(d, d, d, imm(0)); + c.vblendps(dest, dest, d, 1< using namespace LMP_Lepton; using namespace std; @@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? NULL : node.operation->clone()), children(node.getChildren()) { } +ExpressionTreeNode::ExpressionTreeNode(ExpressionTreeNode&& node) : operation(node.operation), children(move(node.children)) { + node.operation = NULL; + node.children.clear(); +} + ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) { } @@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node return *this; } +ExpressionTreeNode& ExpressionTreeNode::operator=(ExpressionTreeNode&& node) { + if (operation != NULL) + delete operation; + operation = node.operation; + children = move(node.children); + node.operation = NULL; + node.children.clear(); + return *this; +} + const Operation& ExpressionTreeNode::getOperation() const { return *operation; } @@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const { const vector& ExpressionTreeNode::getChildren() const { return children; } + +void ExpressionTreeNode::assignTags(vector& examples) const { + // Assign tag values to all nodes in a tree, such that two nodes have the same + // tag if and only if they (and all their children) are equal. This is used to + // optimize other operations. + + int numTags = examples.size(); + for (const ExpressionTreeNode& child : getChildren()) + child.assignTags(examples); + if (numTags == (int)examples.size()) { + // All the children matched existing tags, so possibly this node does too. + + for (int i = 0; i < (int)examples.size(); i++) { + const ExpressionTreeNode& example = *examples[i]; + bool matches = (getChildren().size() == example.getChildren().size() && getOperation() == example.getOperation()); + for (int j = 0; matches && j < (int)getChildren().size(); j++) + if (getChildren()[j].tag != example.getChildren()[j].tag) + matches = false; + if (matches) { + tag = i; + return; + } + } + } + + // This node does not match any previous node, so assign a new tag. + + tag = examples.size(); + examples.push_back(this); +} diff --git a/lib/lepton/src/Operation.cpp b/lib/lepton/src/Operation.cpp index bec5686a74..08deff8584 100644 --- a/lib/lepton/src/Operation.cpp +++ b/lib/lepton/src/Operation.cpp @@ -7,7 +7,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009-2019 Stanford University and the Authors. * + * Portions copyright (c) 2009-2021 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -37,7 +37,13 @@ using namespace LMP_Lepton; using namespace std; -double Operation::Erf::evaluate(double* args, const map& ) const { +static bool isZero(const ExpressionTreeNode& node) { + if (node.getOperation().getId() != Operation::CONSTANT) + return false; + return dynamic_cast(node.getOperation()).getValue() == 0.0; +} + +double Operation::Erf::evaluate(double* args, const map&) const { return erf(args[0]); } @@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { if (function->getNumArguments() == 0) return ExpressionTreeNode(new Operation::Constant(0.0)); - ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]); - for (int i = 1; i < getNumArguments(); i++) { - result = ExpressionTreeNode(new Operation::Add(), - result, - ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); + ExpressionTreeNode result; + bool foundTerm = false; + for (int i = 0; i < getNumArguments(); i++) { + if (!isZero(childDerivs[i])) { + if (foundTerm) + result = ExpressionTreeNode(new Operation::Add(), + result, + ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); + else { + result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]); + foundTerm = true; + } + } } - return result; + if (foundTerm) + return result; + return ExpressionTreeNode(new Operation::Constant(0.0)); } ExpressionTreeNode Operation::Add::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return childDerivs[1]; + if (isZero(childDerivs[1])) + return childDerivs[0]; return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]); } ExpressionTreeNode Operation::Subtract::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + return ExpressionTreeNode(new Operation::Negate(), childDerivs[1]); + } + if (isZero(childDerivs[1])) + return childDerivs[0]; return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]); } ExpressionTreeNode Operation::Multiply::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + return ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]); + } + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]); return ExpressionTreeNode(new Operation::Add(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0])); } ExpressionTreeNode Operation::Divide::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { - return ExpressionTreeNode(new Operation::Divide(), - ExpressionTreeNode(new Operation::Subtract(), + ExpressionTreeNode subexp; + if (isZero(childDerivs[0])) { + if (isZero(childDerivs[1])) + return ExpressionTreeNode(new Operation::Constant(0.0)); + subexp = ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])); + } + else if (isZero(childDerivs[1])) + subexp = ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]); + else + subexp = ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), - ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])), - ExpressionTreeNode(new Operation::Square(), children[1])); + ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])); + return ExpressionTreeNode(new Operation::Divide(), subexp, ExpressionTreeNode(new Operation::Square(), children[1])); } ExpressionTreeNode Operation::Power::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { @@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]); } ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(0.5), ExpressionTreeNode(new Operation::Reciprocal(), @@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Exp(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Log::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Sin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cos(), children[0]), childDerivs[0]); } ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Sin(), children[0])), @@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sec(), children[0]), @@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), @@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Square(), ExpressionTreeNode(new Operation::Sec(), children[0])), @@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Square(), @@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Sqrt(), @@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::AddConstant(1.0), @@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Cosh(), children[0]), @@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Sinh(), children[0]), @@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Constant(1.0)), @@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))), @@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))), @@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(2.0), children[0]), @@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(3.0), ExpressionTreeNode(new Operation::Square(), children[0])), @@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Reciprocal(), @@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector& , const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::MultiplyConstant(value), childDerivs[0]); } ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::MultiplyConstant(value), ExpressionTreeNode(new Operation::PowerConstant(value-1), @@ -305,22 +393,18 @@ ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); - return ExpressionTreeNode(new Operation::Subtract(), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], step), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], - ExpressionTreeNode(new Operation::AddConstant(-1), step))); + return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[1], childDerivs[0]}); } ExpressionTreeNode Operation::Max::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { ExpressionTreeNode step(new Operation::Step(), ExpressionTreeNode(new Operation::Subtract(), children[0], children[1])); - return ExpressionTreeNode(new Operation::Subtract(), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], step), - ExpressionTreeNode(new Operation::Multiply(), childDerivs[1], - ExpressionTreeNode(new Operation::AddConstant(-1), step))); + return ExpressionTreeNode(new Operation::Select(), {step, childDerivs[0], childDerivs[1]}); } ExpressionTreeNode Operation::Abs::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { + if (isZero(childDerivs[0])) + return ExpressionTreeNode(new Operation::Constant(0.0)); ExpressionTreeNode step(new Operation::Step(), children[0]); return ExpressionTreeNode(new Operation::Multiply(), childDerivs[0], @@ -337,9 +421,5 @@ ExpressionTreeNode Operation::Ceil::differentiate(const std::vector& children, const std::vector& childDerivs, const std::string& ) const { - vector derivChildren; - derivChildren.push_back(children[0]); - derivChildren.push_back(childDerivs[1]); - derivChildren.push_back(childDerivs[2]); - return ExpressionTreeNode(new Operation::Select(), derivChildren); + return ExpressionTreeNode(new Operation::Select(), {children[0], childDerivs[1], childDerivs[2]}); } diff --git a/lib/lepton/src/ParsedExpression.cpp b/lib/lepton/src/ParsedExpression.cpp index 1417551011..a6f41ae354 100644 --- a/lib/lepton/src/ParsedExpression.cpp +++ b/lib/lepton/src/ParsedExpression.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2009 Stanford University and the Authors. * + * Portions copyright (c) 2009-2022 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -31,6 +31,7 @@ #include "lepton/ParsedExpression.h" #include "lepton/CompiledExpression.h" +#include "lepton/CompiledVectorExpression.h" #include "lepton/ExpressionProgram.h" #include "lepton/Operation.h" #include @@ -68,9 +69,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map examples; + result.assignTags(examples); + map nodeCache; + result = precalculateConstantSubexpressions(result, nodeCache); while (true) { - ExpressionTreeNode simplified = substituteSimplerExpression(result); + examples.clear(); + result.assignTags(examples); + nodeCache.clear(); + ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache); if (simplified == result) break; result = simplified; @@ -80,9 +88,15 @@ ParsedExpression ParsedExpression::optimize() const { ParsedExpression ParsedExpression::optimize(const map& variables) const { ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables); - result = precalculateConstantSubexpressions(result); + vector examples; + result.assignTags(examples); + map nodeCache; + result = precalculateConstantSubexpressions(result, nodeCache); while (true) { - ExpressionTreeNode simplified = substituteSimplerExpression(result); + examples.clear(); + result.assignTags(examples); + nodeCache.clear(); + ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache); if (simplified == result) break; result = simplified; @@ -104,27 +118,44 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo return ExpressionTreeNode(node.getOperation().clone(), children); } -ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) { +ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node, map& nodeCache) { + auto cached = nodeCache.find(node.tag); + if (cached != nodeCache.end()) + return cached->second; vector children(node.getChildren().size()); for (int i = 0; i < (int) children.size(); i++) - children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); + children[i] = precalculateConstantSubexpressions(node.getChildren()[i], nodeCache); ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children); - if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) + if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) { + nodeCache[node.tag] = result; return result; + } for (int i = 0; i < (int) children.size(); i++) - if (children[i].getOperation().getId() != Operation::CONSTANT) + if (children[i].getOperation().getId() != Operation::CONSTANT) { + nodeCache[node.tag] = result; return result; - return ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); + } + result = ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); + nodeCache[node.tag] = result; + return result; } -ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) { +ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node, map& nodeCache) { vector children(node.getChildren().size()); - for (int i = 0; i < (int) children.size(); i++) - children[i] = substituteSimplerExpression(node.getChildren()[i]); + for (int i = 0; i < (int) children.size(); i++) { + const ExpressionTreeNode& child = node.getChildren()[i]; + auto cached = nodeCache.find(child.tag); + if (cached == nodeCache.end()) { + children[i] = substituteSimplerExpression(child, nodeCache); + nodeCache[child.tag] = children[i]; + } + else + children[i] = cached->second; + } // Collect some info on constant expressions in children bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant? - bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant? + bool second_const = children.size() > 1 && isConstant(children[1]); // is second child constant? double first, second; // if yes, value of first and second child if (first_const) first = getConstantValue(children[0]); @@ -296,6 +327,12 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio return children[0].getChildren()[0]; break; } + case Operation::SELECT: + { + if (children[1] == children[2]) // Select between two identical values + return children[1]; + break; + } default: { // If operation ID is not one of the above, @@ -308,14 +345,22 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio } ParsedExpression ParsedExpression::differentiate(const string& variable) const { - return differentiate(getRootNode(), variable); + vector examples; + getRootNode().assignTags(examples); + map nodeCache; + return differentiate(getRootNode(), variable, nodeCache); } -ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) { +ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable, map& nodeCache) { + auto cached = nodeCache.find(node.tag); + if (cached != nodeCache.end()) + return cached->second; vector childDerivs(node.getChildren().size()); for (int i = 0; i < (int) childDerivs.size(); i++) - childDerivs[i] = differentiate(node.getChildren()[i], variable); - return node.getOperation().differentiate(node.getChildren(),childDerivs, variable); + childDerivs[i] = differentiate(node.getChildren()[i], variable, nodeCache); + ExpressionTreeNode result = node.getOperation().differentiate(node.getChildren(), childDerivs, variable); + nodeCache[node.tag] = result; + return result; } bool ParsedExpression::isConstant(const ExpressionTreeNode& node) { @@ -337,6 +382,10 @@ CompiledExpression ParsedExpression::createCompiledExpression() const { return CompiledExpression(*this); } +CompiledVectorExpression ParsedExpression::createCompiledVectorExpression(int width) const { + return CompiledVectorExpression(*this, width); +} + ParsedExpression ParsedExpression::renameVariables(const map& replacements) const { return ParsedExpression(renameNodeVariables(getRootNode(), replacements)); } From e59f99b44079e5bf60f458905b43ad7ec756fbf8 Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Thu, 22 Dec 2022 22:50:01 -0500 Subject: [PATCH 24/79] add support for JIT compilation --- cmake/Modules/Packages/LEPTON.cmake | 23 +- lib/lepton/asmjit/LICENSE.md | 17 + lib/lepton/asmjit/a64.h | 62 + lib/lepton/asmjit/arm.h | 62 + lib/lepton/asmjit/arm/a64archtraits_p.h | 81 + lib/lepton/asmjit/arm/a64assembler.cpp | 5115 ++++++++++++++++++++ lib/lepton/asmjit/arm/a64assembler.h | 72 + lib/lepton/asmjit/arm/a64builder.cpp | 51 + lib/lepton/asmjit/arm/a64builder.h | 57 + lib/lepton/asmjit/arm/a64compiler.cpp | 60 + lib/lepton/asmjit/arm/a64compiler.h | 247 + lib/lepton/asmjit/arm/a64emithelper.cpp | 464 ++ lib/lepton/asmjit/arm/a64emithelper_p.h | 50 + lib/lepton/asmjit/arm/a64emitter.h | 1228 +++++ lib/lepton/asmjit/arm/a64formatter.cpp | 298 ++ lib/lepton/asmjit/arm/a64formatter_p.h | 59 + lib/lepton/asmjit/arm/a64func.cpp | 189 + lib/lepton/asmjit/arm/a64func_p.h | 33 + lib/lepton/asmjit/arm/a64globals.h | 1894 ++++++++ lib/lepton/asmjit/arm/a64instapi.cpp | 278 ++ lib/lepton/asmjit/arm/a64instapi_p.h | 41 + lib/lepton/asmjit/arm/a64instdb.cpp | 1957 ++++++++ lib/lepton/asmjit/arm/a64instdb.h | 74 + lib/lepton/asmjit/arm/a64instdb_p.h | 876 ++++ lib/lepton/asmjit/arm/a64operand.cpp | 85 + lib/lepton/asmjit/arm/a64operand.h | 312 ++ lib/lepton/asmjit/arm/a64rapass.cpp | 852 ++++ lib/lepton/asmjit/arm/a64rapass_p.h | 105 + lib/lepton/asmjit/arm/a64utils.h | 179 + lib/lepton/asmjit/arm/armformatter.cpp | 143 + lib/lepton/asmjit/arm/armformatter_p.h | 44 + lib/lepton/asmjit/arm/armglobals.h | 21 + lib/lepton/asmjit/arm/armoperand.h | 621 +++ lib/lepton/asmjit/asmjit-scope-begin.h | 17 + lib/lepton/asmjit/asmjit-scope-end.h | 9 + lib/lepton/asmjit/asmjit.h | 33 + lib/lepton/asmjit/core.h | 1861 +++++++ lib/lepton/asmjit/core/api-build_p.h | 55 + lib/lepton/asmjit/core/api-config.h | 613 +++ lib/lepton/asmjit/core/archcommons.h | 229 + lib/lepton/asmjit/core/archtraits.cpp | 160 + lib/lepton/asmjit/core/archtraits.h | 290 ++ lib/lepton/asmjit/core/assembler.cpp | 406 ++ lib/lepton/asmjit/core/assembler.h | 129 + lib/lepton/asmjit/core/builder.cpp | 889 ++++ lib/lepton/asmjit/core/builder.h | 1391 ++++++ lib/lepton/asmjit/core/codebuffer.h | 113 + lib/lepton/asmjit/core/codeholder.cpp | 1149 +++++ lib/lepton/asmjit/core/codeholder.h | 1035 ++++ lib/lepton/asmjit/core/codewriter.cpp | 175 + lib/lepton/asmjit/core/codewriter_p.h | 179 + lib/lepton/asmjit/core/compiler.cpp | 582 +++ lib/lepton/asmjit/core/compiler.h | 737 +++ lib/lepton/asmjit/core/compilerdefs.h | 173 + lib/lepton/asmjit/core/constpool.cpp | 363 ++ lib/lepton/asmjit/core/constpool.h | 250 + lib/lepton/asmjit/core/cpuinfo.cpp | 1162 +++++ lib/lepton/asmjit/core/cpuinfo.h | 813 ++++ lib/lepton/asmjit/core/emithelper.cpp | 323 ++ lib/lepton/asmjit/core/emithelper_p.h | 58 + lib/lepton/asmjit/core/emitter.cpp | 333 ++ lib/lepton/asmjit/core/emitter.h | 741 +++ lib/lepton/asmjit/core/emitterutils.cpp | 129 + lib/lepton/asmjit/core/emitterutils_p.h | 89 + lib/lepton/asmjit/core/environment.cpp | 46 + lib/lepton/asmjit/core/environment.h | 508 ++ lib/lepton/asmjit/core/errorhandler.cpp | 14 + lib/lepton/asmjit/core/errorhandler.h | 228 + lib/lepton/asmjit/core/formatter.cpp | 584 +++ lib/lepton/asmjit/core/formatter.h | 247 + lib/lepton/asmjit/core/formatter_p.h | 34 + lib/lepton/asmjit/core/func.cpp | 286 ++ lib/lepton/asmjit/core/func.h | 1445 ++++++ lib/lepton/asmjit/core/funcargscontext.cpp | 293 ++ lib/lepton/asmjit/core/funcargscontext_p.h | 199 + lib/lepton/asmjit/core/globals.cpp | 133 + lib/lepton/asmjit/core/globals.h | 393 ++ lib/lepton/asmjit/core/inst.cpp | 113 + lib/lepton/asmjit/core/inst.h | 772 +++ lib/lepton/asmjit/core/jitallocator.cpp | 1242 +++++ lib/lepton/asmjit/core/jitallocator.h | 261 + lib/lepton/asmjit/core/jitruntime.cpp | 80 + lib/lepton/asmjit/core/jitruntime.h | 89 + lib/lepton/asmjit/core/logger.cpp | 69 + lib/lepton/asmjit/core/logger.h | 198 + lib/lepton/asmjit/core/misc_p.h | 33 + lib/lepton/asmjit/core/operand.cpp | 132 + lib/lepton/asmjit/core/operand.h | 1611 ++++++ lib/lepton/asmjit/core/osutils.cpp | 84 + lib/lepton/asmjit/core/osutils.h | 61 + lib/lepton/asmjit/core/osutils_p.h | 68 + lib/lepton/asmjit/core/raassignment_p.h | 418 ++ lib/lepton/asmjit/core/rabuilders_p.h | 612 +++ lib/lepton/asmjit/core/radefs_p.h | 1204 +++++ lib/lepton/asmjit/core/ralocal.cpp | 1166 +++++ lib/lepton/asmjit/core/ralocal_p.h | 254 + lib/lepton/asmjit/core/rapass.cpp | 1969 ++++++++ lib/lepton/asmjit/core/rapass_p.h | 1183 +++++ lib/lepton/asmjit/core/rastack.cpp | 184 + lib/lepton/asmjit/core/rastack_p.h | 171 + lib/lepton/asmjit/core/string.cpp | 559 +++ lib/lepton/asmjit/core/string.h | 372 ++ lib/lepton/asmjit/core/support.cpp | 494 ++ lib/lepton/asmjit/core/support.h | 1773 +++++++ lib/lepton/asmjit/core/target.cpp | 14 + lib/lepton/asmjit/core/target.h | 53 + lib/lepton/asmjit/core/type.cpp | 74 + lib/lepton/asmjit/core/type.h | 419 ++ lib/lepton/asmjit/core/virtmem.cpp | 722 +++ lib/lepton/asmjit/core/virtmem.h | 242 + lib/lepton/asmjit/core/zone.cpp | 353 ++ lib/lepton/asmjit/core/zone.h | 615 +++ lib/lepton/asmjit/core/zonehash.cpp | 309 ++ lib/lepton/asmjit/core/zonehash.h | 186 + lib/lepton/asmjit/core/zonelist.cpp | 163 + lib/lepton/asmjit/core/zonelist.h | 209 + lib/lepton/asmjit/core/zonestack.cpp | 176 + lib/lepton/asmjit/core/zonestack.h | 239 + lib/lepton/asmjit/core/zonestring.h | 120 + lib/lepton/asmjit/core/zonetree.cpp | 99 + lib/lepton/asmjit/core/zonetree.h | 380 ++ lib/lepton/asmjit/core/zonevector.cpp | 356 ++ lib/lepton/asmjit/core/zonevector.h | 690 +++ lib/lepton/asmjit/x86.h | 93 + lib/lepton/asmjit/x86/x86archtraits_p.h | 148 + lib/lepton/asmjit/x86/x86assembler.cpp | 5110 +++++++++++++++++++ lib/lepton/asmjit/x86/x86assembler.h | 685 +++ lib/lepton/asmjit/x86/x86builder.cpp | 52 + lib/lepton/asmjit/x86/x86builder.h | 351 ++ lib/lepton/asmjit/x86/x86compiler.cpp | 61 + lib/lepton/asmjit/x86/x86compiler.h | 721 +++ lib/lepton/asmjit/x86/x86emithelper.cpp | 619 +++ lib/lepton/asmjit/x86/x86emithelper_p.h | 60 + lib/lepton/asmjit/x86/x86emitter.h | 4315 +++++++++++++++++ lib/lepton/asmjit/x86/x86formatter.cpp | 944 ++++ lib/lepton/asmjit/x86/x86formatter_p.h | 58 + lib/lepton/asmjit/x86/x86func.cpp | 503 ++ lib/lepton/asmjit/x86/x86func_p.h | 33 + lib/lepton/asmjit/x86/x86globals.h | 2169 +++++++++ lib/lepton/asmjit/x86/x86instapi.cpp | 1732 +++++++ lib/lepton/asmjit/x86/x86instapi_p.h | 41 + lib/lepton/asmjit/x86/x86instdb.cpp | 4427 +++++++++++++++++ lib/lepton/asmjit/x86/x86instdb.h | 563 +++ lib/lepton/asmjit/x86/x86instdb_p.h | 311 ++ lib/lepton/asmjit/x86/x86opcode_p.h | 436 ++ lib/lepton/asmjit/x86/x86operand.cpp | 231 + lib/lepton/asmjit/x86/x86operand.h | 1085 +++++ lib/lepton/asmjit/x86/x86rapass.cpp | 1509 ++++++ lib/lepton/asmjit/x86/x86rapass_p.h | 94 + 149 files changed, 81486 insertions(+), 2 deletions(-) create mode 100644 lib/lepton/asmjit/LICENSE.md create mode 100644 lib/lepton/asmjit/a64.h create mode 100644 lib/lepton/asmjit/arm.h create mode 100644 lib/lepton/asmjit/arm/a64archtraits_p.h create mode 100644 lib/lepton/asmjit/arm/a64assembler.cpp create mode 100644 lib/lepton/asmjit/arm/a64assembler.h create mode 100644 lib/lepton/asmjit/arm/a64builder.cpp create mode 100644 lib/lepton/asmjit/arm/a64builder.h create mode 100644 lib/lepton/asmjit/arm/a64compiler.cpp create mode 100644 lib/lepton/asmjit/arm/a64compiler.h create mode 100644 lib/lepton/asmjit/arm/a64emithelper.cpp create mode 100644 lib/lepton/asmjit/arm/a64emithelper_p.h create mode 100644 lib/lepton/asmjit/arm/a64emitter.h create mode 100644 lib/lepton/asmjit/arm/a64formatter.cpp create mode 100644 lib/lepton/asmjit/arm/a64formatter_p.h create mode 100644 lib/lepton/asmjit/arm/a64func.cpp create mode 100644 lib/lepton/asmjit/arm/a64func_p.h create mode 100644 lib/lepton/asmjit/arm/a64globals.h create mode 100644 lib/lepton/asmjit/arm/a64instapi.cpp create mode 100644 lib/lepton/asmjit/arm/a64instapi_p.h create mode 100644 lib/lepton/asmjit/arm/a64instdb.cpp create mode 100644 lib/lepton/asmjit/arm/a64instdb.h create mode 100644 lib/lepton/asmjit/arm/a64instdb_p.h create mode 100644 lib/lepton/asmjit/arm/a64operand.cpp create mode 100644 lib/lepton/asmjit/arm/a64operand.h create mode 100644 lib/lepton/asmjit/arm/a64rapass.cpp create mode 100644 lib/lepton/asmjit/arm/a64rapass_p.h create mode 100644 lib/lepton/asmjit/arm/a64utils.h create mode 100644 lib/lepton/asmjit/arm/armformatter.cpp create mode 100644 lib/lepton/asmjit/arm/armformatter_p.h create mode 100644 lib/lepton/asmjit/arm/armglobals.h create mode 100644 lib/lepton/asmjit/arm/armoperand.h create mode 100644 lib/lepton/asmjit/asmjit-scope-begin.h create mode 100644 lib/lepton/asmjit/asmjit-scope-end.h create mode 100644 lib/lepton/asmjit/asmjit.h create mode 100644 lib/lepton/asmjit/core.h create mode 100644 lib/lepton/asmjit/core/api-build_p.h create mode 100644 lib/lepton/asmjit/core/api-config.h create mode 100644 lib/lepton/asmjit/core/archcommons.h create mode 100644 lib/lepton/asmjit/core/archtraits.cpp create mode 100644 lib/lepton/asmjit/core/archtraits.h create mode 100644 lib/lepton/asmjit/core/assembler.cpp create mode 100644 lib/lepton/asmjit/core/assembler.h create mode 100644 lib/lepton/asmjit/core/builder.cpp create mode 100644 lib/lepton/asmjit/core/builder.h create mode 100644 lib/lepton/asmjit/core/codebuffer.h create mode 100644 lib/lepton/asmjit/core/codeholder.cpp create mode 100644 lib/lepton/asmjit/core/codeholder.h create mode 100644 lib/lepton/asmjit/core/codewriter.cpp create mode 100644 lib/lepton/asmjit/core/codewriter_p.h create mode 100644 lib/lepton/asmjit/core/compiler.cpp create mode 100644 lib/lepton/asmjit/core/compiler.h create mode 100644 lib/lepton/asmjit/core/compilerdefs.h create mode 100644 lib/lepton/asmjit/core/constpool.cpp create mode 100644 lib/lepton/asmjit/core/constpool.h create mode 100644 lib/lepton/asmjit/core/cpuinfo.cpp create mode 100644 lib/lepton/asmjit/core/cpuinfo.h create mode 100644 lib/lepton/asmjit/core/emithelper.cpp create mode 100644 lib/lepton/asmjit/core/emithelper_p.h create mode 100644 lib/lepton/asmjit/core/emitter.cpp create mode 100644 lib/lepton/asmjit/core/emitter.h create mode 100644 lib/lepton/asmjit/core/emitterutils.cpp create mode 100644 lib/lepton/asmjit/core/emitterutils_p.h create mode 100644 lib/lepton/asmjit/core/environment.cpp create mode 100644 lib/lepton/asmjit/core/environment.h create mode 100644 lib/lepton/asmjit/core/errorhandler.cpp create mode 100644 lib/lepton/asmjit/core/errorhandler.h create mode 100644 lib/lepton/asmjit/core/formatter.cpp create mode 100644 lib/lepton/asmjit/core/formatter.h create mode 100644 lib/lepton/asmjit/core/formatter_p.h create mode 100644 lib/lepton/asmjit/core/func.cpp create mode 100644 lib/lepton/asmjit/core/func.h create mode 100644 lib/lepton/asmjit/core/funcargscontext.cpp create mode 100644 lib/lepton/asmjit/core/funcargscontext_p.h create mode 100644 lib/lepton/asmjit/core/globals.cpp create mode 100644 lib/lepton/asmjit/core/globals.h create mode 100644 lib/lepton/asmjit/core/inst.cpp create mode 100644 lib/lepton/asmjit/core/inst.h create mode 100644 lib/lepton/asmjit/core/jitallocator.cpp create mode 100644 lib/lepton/asmjit/core/jitallocator.h create mode 100644 lib/lepton/asmjit/core/jitruntime.cpp create mode 100644 lib/lepton/asmjit/core/jitruntime.h create mode 100644 lib/lepton/asmjit/core/logger.cpp create mode 100644 lib/lepton/asmjit/core/logger.h create mode 100644 lib/lepton/asmjit/core/misc_p.h create mode 100644 lib/lepton/asmjit/core/operand.cpp create mode 100644 lib/lepton/asmjit/core/operand.h create mode 100644 lib/lepton/asmjit/core/osutils.cpp create mode 100644 lib/lepton/asmjit/core/osutils.h create mode 100644 lib/lepton/asmjit/core/osutils_p.h create mode 100644 lib/lepton/asmjit/core/raassignment_p.h create mode 100644 lib/lepton/asmjit/core/rabuilders_p.h create mode 100644 lib/lepton/asmjit/core/radefs_p.h create mode 100644 lib/lepton/asmjit/core/ralocal.cpp create mode 100644 lib/lepton/asmjit/core/ralocal_p.h create mode 100644 lib/lepton/asmjit/core/rapass.cpp create mode 100644 lib/lepton/asmjit/core/rapass_p.h create mode 100644 lib/lepton/asmjit/core/rastack.cpp create mode 100644 lib/lepton/asmjit/core/rastack_p.h create mode 100644 lib/lepton/asmjit/core/string.cpp create mode 100644 lib/lepton/asmjit/core/string.h create mode 100644 lib/lepton/asmjit/core/support.cpp create mode 100644 lib/lepton/asmjit/core/support.h create mode 100644 lib/lepton/asmjit/core/target.cpp create mode 100644 lib/lepton/asmjit/core/target.h create mode 100644 lib/lepton/asmjit/core/type.cpp create mode 100644 lib/lepton/asmjit/core/type.h create mode 100644 lib/lepton/asmjit/core/virtmem.cpp create mode 100644 lib/lepton/asmjit/core/virtmem.h create mode 100644 lib/lepton/asmjit/core/zone.cpp create mode 100644 lib/lepton/asmjit/core/zone.h create mode 100644 lib/lepton/asmjit/core/zonehash.cpp create mode 100644 lib/lepton/asmjit/core/zonehash.h create mode 100644 lib/lepton/asmjit/core/zonelist.cpp create mode 100644 lib/lepton/asmjit/core/zonelist.h create mode 100644 lib/lepton/asmjit/core/zonestack.cpp create mode 100644 lib/lepton/asmjit/core/zonestack.h create mode 100644 lib/lepton/asmjit/core/zonestring.h create mode 100644 lib/lepton/asmjit/core/zonetree.cpp create mode 100644 lib/lepton/asmjit/core/zonetree.h create mode 100644 lib/lepton/asmjit/core/zonevector.cpp create mode 100644 lib/lepton/asmjit/core/zonevector.h create mode 100644 lib/lepton/asmjit/x86.h create mode 100644 lib/lepton/asmjit/x86/x86archtraits_p.h create mode 100644 lib/lepton/asmjit/x86/x86assembler.cpp create mode 100644 lib/lepton/asmjit/x86/x86assembler.h create mode 100644 lib/lepton/asmjit/x86/x86builder.cpp create mode 100644 lib/lepton/asmjit/x86/x86builder.h create mode 100644 lib/lepton/asmjit/x86/x86compiler.cpp create mode 100644 lib/lepton/asmjit/x86/x86compiler.h create mode 100644 lib/lepton/asmjit/x86/x86emithelper.cpp create mode 100644 lib/lepton/asmjit/x86/x86emithelper_p.h create mode 100644 lib/lepton/asmjit/x86/x86emitter.h create mode 100644 lib/lepton/asmjit/x86/x86formatter.cpp create mode 100644 lib/lepton/asmjit/x86/x86formatter_p.h create mode 100644 lib/lepton/asmjit/x86/x86func.cpp create mode 100644 lib/lepton/asmjit/x86/x86func_p.h create mode 100644 lib/lepton/asmjit/x86/x86globals.h create mode 100644 lib/lepton/asmjit/x86/x86instapi.cpp create mode 100644 lib/lepton/asmjit/x86/x86instapi_p.h create mode 100644 lib/lepton/asmjit/x86/x86instdb.cpp create mode 100644 lib/lepton/asmjit/x86/x86instdb.h create mode 100644 lib/lepton/asmjit/x86/x86instdb_p.h create mode 100644 lib/lepton/asmjit/x86/x86opcode_p.h create mode 100644 lib/lepton/asmjit/x86/x86operand.cpp create mode 100644 lib/lepton/asmjit/x86/x86operand.h create mode 100644 lib/lepton/asmjit/x86/x86rapass.cpp create mode 100644 lib/lepton/asmjit/x86/x86rapass_p.h diff --git a/cmake/Modules/Packages/LEPTON.cmake b/cmake/Modules/Packages/LEPTON.cmake index a1a74f3aa9..28ade58636 100644 --- a/cmake/Modules/Packages/LEPTON.cmake +++ b/cmake/Modules/Packages/LEPTON.cmake @@ -1,8 +1,27 @@ set(LEPTON_SOURCE_DIR ${LAMMPS_LIB_SOURCE_DIR}/lepton) file(GLOB LEPTON_SOURCES ${LEPTON_SOURCE_DIR}/src/[^.]*.cpp) -add_library(lmplepton STATIC ${LEPTON_SOURCES}) -target_compile_definitions(lmplepton PUBLIC -DLEPTON_BUILDING_STATIC_LIBRARY=1) + +if((CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86_64") OR + (CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "amd64") OR + (CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64")) + option(LEPTON_ENABLE_JIT "Enable Just-In-Time compiler for Lepton" ON) +else() + option(LEPTON_ENABLE_JIT "Enable Just-In-Time compiler for Lepton" OFF) +endif() + +if(LEPTON_ENABLE_JIT) + file(GLOB ASMJIT_SOURCES ${LEPTON_SOURCE_DIR}/asmjit/*/[^.]*.cpp) +endif() + +add_library(lmplepton STATIC ${LEPTON_SOURCES} ${ASMJIT_SOURCES}) set_target_properties(lmplepton PROPERTIES OUTPUT_NAME lammps_lmplepton${LAMMPS_MACHINE}) +target_compile_definitions(lmplepton PUBLIC -DLEPTON_BUILDING_STATIC_LIBRARY=1) target_include_directories(lmplepton PUBLIC ${LEPTON_SOURCE_DIR}/include) + +if(LEPTON_ENABLE_JIT) + target_compile_definitions(lmplepton PUBLIC "LEPTON_USE_JIT=1;ASMJIT_BUILD_X86=1;ASMJIT_EMBED=1;ASMJIT_BUILD_RELEASE=1") + target_include_directories(lmplepton PUBLIC ${LEPTON_SOURCE_DIR}) +endif() + target_link_libraries(lammps PRIVATE lmplepton) diff --git a/lib/lepton/asmjit/LICENSE.md b/lib/lepton/asmjit/LICENSE.md new file mode 100644 index 0000000000..020a569dbd --- /dev/null +++ b/lib/lepton/asmjit/LICENSE.md @@ -0,0 +1,17 @@ +Copyright (c) 2008-2020 The AsmJit Authors + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + +1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. +2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. +3. This notice may not be removed or altered from any source distribution. diff --git a/lib/lepton/asmjit/a64.h b/lib/lepton/asmjit/a64.h new file mode 100644 index 0000000000..ea4d304f05 --- /dev/null +++ b/lib/lepton/asmjit/a64.h @@ -0,0 +1,62 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_A64_H_INCLUDED +#define ASMJIT_A64_H_INCLUDED + +//! \addtogroup asmjit_a64 +//! +//! ### Emitters +//! +//! - \ref a64::Assembler - AArch64 assembler (must read, provides examples). +//! - \ref a64::Builder - AArch64 builder. +//! - \ref a64::Compiler - AArch64 compiler. +//! - \ref a64::Emitter - AArch64 emitter (abstract). +//! +//! ### Supported Instructions +//! +//! - Emitters: +//! - \ref a64::EmitterExplicitT - Provides all instructions that use explicit +//! operands, provides also utility functions. The member functions provided +//! are part of all ARM/AArch64 emitters. +//! +//! - Instruction representation: +//! - \ref a64::Inst::Id - instruction identifiers. +//! +//! ### Register Operands +//! +//! - \ref arm::Reg - Base class for any AArch32/AArch64 register. +//! - \ref arm::Gp - General purpose register: +//! - \ref arm::GpW - 32-bit register. +//! - \ref arm::GpX - 64-bit register. +//! - \ref arm::Vec - Vector (SIMD) register: +//! - \ref arm::VecB - 8-bit SIMD register (AArch64 only). +//! - \ref arm::VecH - 16-bit SIMD register (AArch64 only). +//! - \ref arm::VecS - 32-bit SIMD register. +//! - \ref arm::VecD - 64-bit SIMD register. +//! - \ref arm::VecV - 128-bit SIMD register. +//! +//! ### Memory Operands +//! +//! - \ref arm::Mem - AArch32/AArch64 memory operand that provides support for all ARM addressing features +//! including base, index, pre/post increment, and ARM-specific shift addressing and index extending. +//! +//! ### Other +//! +//! - \ref arm::Shift - Shift operation and value. +//! - \ref a64::Utils - Utilities that can help during code generation for AArch64. + +#include "./arm.h" +#include "./arm/a64assembler.h" +#include "./arm/a64builder.h" +#include "./arm/a64compiler.h" +#include "./arm/a64emitter.h" +#include "./arm/a64globals.h" +#include "./arm/a64instdb.h" +#include "./arm/a64operand.h" +#include "./arm/a64utils.h" + +#endif // ASMJIT_A64_H_INCLUDED + diff --git a/lib/lepton/asmjit/arm.h b/lib/lepton/asmjit/arm.h new file mode 100644 index 0000000000..57ffa815b8 --- /dev/null +++ b/lib/lepton/asmjit/arm.h @@ -0,0 +1,62 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_H_INCLUDED +#define ASMJIT_ARM_H_INCLUDED + +//! \addtogroup asmjit_arm +//! +//! ### Namespaces +//! +//! - \ref arm - arm namespace provides common functionality for both AArch32 and AArch64 backends. +//! - \ref a64 - a64 namespace provides support for AArch64 architecture. In addition it includes +//! \ref arm namespace, so you can only use a single namespace when targeting AArch64 architecture. +//! +//! ### Emitters +//! +//! - AArch64 +//! - \ref a64::Assembler - AArch64 assembler (must read, provides examples). +//! - \ref a64::Builder - AArch64 builder. +//! - \ref a64::Compiler - AArch64 compiler. +//! - \ref a64::Emitter - AArch64 emitter (abstract). +//! +//! ### Supported Instructions +//! +//! - AArch64: +//! - Emitters: +//! - \ref a64::EmitterExplicitT - Provides all instructions that use explicit operands, provides also +//! utility functions. The member functions provided are part of all AArch64 emitters. +//! - Instruction representation: +//! - \ref a64::Inst::Id - instruction identifiers. +//! +//! ### Register Operands +//! +//! - \ref arm::Reg - Base class for any AArch32/AArch64 register. +//! - \ref arm::Gp - General purpose register: +//! - \ref arm::GpW - 32-bit register. +//! - \ref arm::GpX - 64-bit register. +//! - \ref arm::Vec - Vector (SIMD) register: +//! - \ref arm::VecB - 8-bit SIMD register (AArch64 only). +//! - \ref arm::VecH - 16-bit SIMD register (AArch64 only). +//! - \ref arm::VecS - 32-bit SIMD register. +//! - \ref arm::VecD - 64-bit SIMD register. +//! - \ref arm::VecV - 128-bit SIMD register. +//! +//! ### Memory Operands +//! +//! - \ref arm::Mem - AArch32/AArch64 memory operand that provides support for all ARM addressing features +//! including base, index, pre/post increment, and ARM-specific shift addressing and index extending. +//! +//! ### Other +//! +//! - \ref arm::Shift - Shift operation and value (both AArch32 and AArch64). +//! - \ref arm::DataType - Data type that is part of an instruction in AArch32 mode. +//! - \ref a64::Utils - Utilities that can help during code generation for AArch64. + +#include "./core.h" +#include "./arm/armglobals.h" +#include "./arm/armoperand.h" + +#endif // ASMJIT_ARM_H_INCLUDED diff --git a/lib/lepton/asmjit/arm/a64archtraits_p.h b/lib/lepton/asmjit/arm/a64archtraits_p.h new file mode 100644 index 0000000000..87559c71d5 --- /dev/null +++ b/lib/lepton/asmjit/arm/a64archtraits_p.h @@ -0,0 +1,81 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64ARCHTRAITS_P_H_INCLUDED +#define ASMJIT_ARM_A64ARCHTRAITS_P_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/misc_p.h" +#include "../core/type.h" +#include "../arm/a64operand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \cond INTERNAL +//! \addtogroup asmjit_a64 +//! \{ + +static const constexpr ArchTraits a64ArchTraits = { + // SP/FP/LR/PC. + Gp::kIdSp, Gp::kIdFp, Gp::kIdLr, 0xFF, + + // Reserved. + { 0, 0, 0 }, + + // HW stack alignment (AArch64 requires stack aligned to 64 bytes). + 16, + + // Min/max stack offset - byte addressing is the worst, VecQ addressing the best. + 4095, 65520, + + // Instruction hints [Gp, Vec, ExtraVirt2, ExtraVirt3]. + {{ + InstHints::kPushPop, + InstHints::kPushPop, + InstHints::kNoHints, + InstHints::kNoHints + }}, + + // RegInfo. + #define V(index) OperandSignature{arm::RegTraits::kSignature} + {{ ASMJIT_LOOKUP_TABLE_32(V, 0) }}, + #undef V + + // RegTypeToTypeId. + #define V(index) TypeId(arm::RegTraits::kTypeId) + {{ ASMJIT_LOOKUP_TABLE_32(V, 0) }}, + #undef V + + // TypeIdToRegType. + #define V(index) (index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kInt8) ? RegType::kARM_GpW : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kUInt8) ? RegType::kARM_GpW : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kInt16) ? RegType::kARM_GpW : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kUInt16) ? RegType::kARM_GpW : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kInt32) ? RegType::kARM_GpW : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kUInt32) ? RegType::kARM_GpW : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kInt64) ? RegType::kARM_GpX : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kUInt64) ? RegType::kARM_GpX : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kIntPtr) ? RegType::kARM_GpX : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kUIntPtr) ? RegType::kARM_GpX : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kFloat32) ? RegType::kARM_VecS : \ + index + uint32_t(TypeId::_kBaseStart) == uint32_t(TypeId::kFloat64) ? RegType::kARM_VecD : RegType::kNone) + {{ ASMJIT_LOOKUP_TABLE_32(V, 0) }}, + #undef V + + // Word names of 8-bit, 16-bit, 32-bit, and 64-bit quantities. + { + ArchTypeNameId::kByte, + ArchTypeNameId::kHWord, + ArchTypeNameId::kWord, + ArchTypeNameId::kXWord + } +}; + +//! \} +//! \endcond + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_A64ARCHTRAITS_P_H_INCLUDED diff --git a/lib/lepton/asmjit/arm/a64assembler.cpp b/lib/lepton/asmjit/arm/a64assembler.cpp new file mode 100644 index 0000000000..485f05f491 --- /dev/null +++ b/lib/lepton/asmjit/arm/a64assembler.cpp @@ -0,0 +1,5115 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#include "../core/api-build_p.h" +#if !defined(ASMJIT_NO_AARCH64) + +#include "../core/codewriter_p.h" +#include "../core/cpuinfo.h" +#include "../core/emitterutils_p.h" +#include "../core/formatter.h" +#include "../core/logger.h" +#include "../core/misc_p.h" +#include "../core/support.h" +#include "../arm/armformatter_p.h" +#include "../arm/a64assembler.h" +#include "../arm/a64emithelper_p.h" +#include "../arm/a64instdb_p.h" +#include "../arm/a64utils.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +// a64::Assembler - Cond +// ===================== + +static inline uint32_t condCodeToOpcodeCond(uint32_t cond) noexcept { + return (uint32_t(cond) - 2u) & 0xFu; +} + +// a64::Assembler - Bits +// ===================== + +template +static inline constexpr uint32_t B(const T& index) noexcept { return uint32_t(1u) << uint32_t(index); } + +static constexpr uint32_t kSP = Gp::kIdSp; +static constexpr uint32_t kZR = Gp::kIdZr; +static constexpr uint32_t kWX = InstDB::kWX; + +// a64::Assembler - ShiftOpToLdStOptMap +// ==================================== + +// Table that maps ShiftOp to OPT part in LD/ST (register) opcode. +#define VALUE(index) index == uint32_t(ShiftOp::kUXTW) ? 2u : \ + index == uint32_t(ShiftOp::kLSL) ? 3u : \ + index == uint32_t(ShiftOp::kSXTW) ? 6u : \ + index == uint32_t(ShiftOp::kSXTX) ? 7u : 0xFF +static const uint8_t armShiftOpToLdStOptMap[] = { ASMJIT_LOOKUP_TABLE_16(VALUE, 0) }; +#undef VALUE + +static inline constexpr uint32_t diff(RegType a, RegType b) noexcept { + return uint32_t(a) - uint32_t(b); +} + +// asmjit::a64::Assembler - SizeOp +// =============================== + +//! Struct that contains Size (2 bits), Q flag, and S (scalar) flag. These values +//! are used to encode Q, Size, and Scalar fields in an opcode. +struct SizeOp { + enum : uint8_t { + k128BitShift = 0, + kScalarShift = 1, + kSizeShift = 2, + + kQ = uint8_t(1u << k128BitShift), + kS = uint8_t(1u << kScalarShift), + + k00 = uint8_t(0 << kSizeShift), + k01 = uint8_t(1 << kSizeShift), + k10 = uint8_t(2 << kSizeShift), + k11 = uint8_t(3 << kSizeShift), + + k00Q = k00 | kQ, + k01Q = k01 | kQ, + k10Q = k10 | kQ, + k11Q = k11 | kQ, + + k00S = k00 | kS, + k01S = k01 | kS, + k10S = k10 | kS, + k11S = k11 | kS, + + kInvalid = 0xFFu, + + // Masks used by SizeOpMap. + kSzQ = (0x3u << kSizeShift) | kQ, + kSzS = (0x3u << kSizeShift) | kS, + kSzQS = (0x3u << kSizeShift) | kQ | kS + }; + + uint8_t value; + + inline bool isValid() const noexcept { return value != kInvalid; } + inline void makeInvalid() noexcept { value = kInvalid; } + + inline uint32_t q() const noexcept { return (value >> k128BitShift) & 0x1u; } + inline uint32_t qs() const noexcept { return ((value >> k128BitShift) | (value >> kScalarShift)) & 0x1u; } + inline uint32_t scalar() const noexcept { return (value >> kScalarShift) & 0x1u; } + inline uint32_t size() const noexcept { return (value >> kSizeShift) & 0x3u; } + + inline void decrementSize() noexcept { + ASMJIT_ASSERT(size() > 0); + value = uint8_t(value - (1u << kSizeShift)); + } +}; + +struct SizeOpTable { + enum TableId : uint8_t { + kTableBin = 0, + kTableAny, + kCount + }; + + // 40 elements for each combination. + SizeOp array[(uint32_t(RegType::kARM_VecV) - uint32_t(RegType::kARM_VecB) + 1) * 8]; +}; + +#define VALUE_BIN(x) { \ + x == (((uint32_t(RegType::kARM_VecD) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeNone)) ? SizeOp::k00 : \ + x == (((uint32_t(RegType::kARM_VecV) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeNone)) ? SizeOp::k00Q : \ + x == (((uint32_t(RegType::kARM_VecD) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeB )) ? SizeOp::k00 : \ + x == (((uint32_t(RegType::kARM_VecV) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeB )) ? SizeOp::k00Q : SizeOp::kInvalid \ +} + +#define VALUE_ANY(x) { \ + x == (((uint32_t(RegType::kARM_VecB) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeNone)) ? SizeOp::k00S : \ + x == (((uint32_t(RegType::kARM_VecH) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeNone)) ? SizeOp::k01S : \ + x == (((uint32_t(RegType::kARM_VecS) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeNone)) ? SizeOp::k10S : \ + x == (((uint32_t(RegType::kARM_VecD) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeNone)) ? SizeOp::k11S : \ + x == (((uint32_t(RegType::kARM_VecD) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeB )) ? SizeOp::k00 : \ + x == (((uint32_t(RegType::kARM_VecV) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeB )) ? SizeOp::k00Q : \ + x == (((uint32_t(RegType::kARM_VecD) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeH )) ? SizeOp::k01 : \ + x == (((uint32_t(RegType::kARM_VecV) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeH )) ? SizeOp::k01Q : \ + x == (((uint32_t(RegType::kARM_VecD) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeS )) ? SizeOp::k10 : \ + x == (((uint32_t(RegType::kARM_VecV) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeS )) ? SizeOp::k10Q : \ + x == (((uint32_t(RegType::kARM_VecD) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeD )) ? SizeOp::k11S : \ + x == (((uint32_t(RegType::kARM_VecV) - uint32_t(RegType::kARM_VecB)) << 3) | (Vec::kElementTypeD )) ? SizeOp::k11Q : SizeOp::kInvalid \ +} + +static const SizeOpTable sizeOpTable[SizeOpTable::kCount] = { + {{ ASMJIT_LOOKUP_TABLE_40(VALUE_BIN, 0) }}, + {{ ASMJIT_LOOKUP_TABLE_40(VALUE_ANY, 0) }} +}; + +#undef VALUE_ANY +#undef VALUE_BIN + +struct SizeOpMap { + uint8_t tableId; + uint8_t sizeOpMask; + uint16_t acceptMask; +}; + +static const constexpr SizeOpMap sizeOpMap[InstDB::kVO_Count] = { + { // kVO_V_B: + SizeOpTable::kTableBin, SizeOp::kQ , uint16_t(B(SizeOp::k00) | B(SizeOp::k00Q)) + }, + + { // kVO_V_BH: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k00Q) | B(SizeOp::k01) | B(SizeOp::k01Q)) + }, + + { // kVO_V_BH_4S: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k00Q) | B(SizeOp::k01) | B(SizeOp::k01Q) | B(SizeOp::k10Q)) + }, + + { // kVO_V_BHS: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k00Q) | B(SizeOp::k01) | B(SizeOp::k01Q) | B(SizeOp::k10) | B(SizeOp::k10Q)) + }, + + { // kVO_V_BHS_D2: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k00Q) | B(SizeOp::k01) | B(SizeOp::k01Q) | B(SizeOp::k10) | B(SizeOp::k10Q) | B(SizeOp::k11Q)) + }, + + { // kVO_V_HS: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k01) | B(SizeOp::k01Q) | B(SizeOp::k10) | B(SizeOp::k10Q)) + }, + + { // kVO_V_S: + SizeOpTable::kTableAny, SizeOp::kQ , uint16_t(B(SizeOp::k10) | B(SizeOp::k10Q)) + }, + + { // kVO_V_B8H4: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k01)) + }, + + { // kVO_V_B8H4S2: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k01) | B(SizeOp::k10)) + }, + + { // kVO_V_B8D1: + SizeOpTable::kTableAny, SizeOp::kSzQ , uint16_t(B(SizeOp::k00) | B(SizeOp::k11S)) + }, + + { // kVO_V_H4S2: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k01) | B(SizeOp::k10)) + }, + + { // kVO_V_B16: + SizeOpTable::kTableBin, SizeOp::kQ , uint16_t(B(SizeOp::k00Q)) + }, + + { // kVO_V_B16H8: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00Q) | B(SizeOp::k01Q)) + }, + + { // kVO_V_B16H8S4: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00Q) | B(SizeOp::k01Q) | B(SizeOp::k10Q)) + }, + + { // kVO_V_B16D2: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00Q) | B(SizeOp::k11Q)) + }, + + { // kVO_V_H8S4: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k01Q) | B(SizeOp::k10Q)) + }, + + { // kVO_V_S4: + SizeOpTable::kTableAny, 0 , uint16_t(B(SizeOp::k10Q)) + }, + + { // kVO_V_D2: + SizeOpTable::kTableAny, 0 , uint16_t(B(SizeOp::k11Q)) + }, + + { // kVO_SV_BHS: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k00Q) | B(SizeOp::k00S) | B(SizeOp::k01) | B(SizeOp::k01Q) | B(SizeOp::k01S) | B(SizeOp::k10) | B(SizeOp::k10Q) | B(SizeOp::k10S)) + }, + + { // kVO_SV_B8H4S2: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k00S) | B(SizeOp::k01) | B(SizeOp::k01S) | B(SizeOp::k10) | B(SizeOp::k10S)) + }, + + { // kVO_SV_HS: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k01) | B(SizeOp::k01Q) | B(SizeOp::k01S) | B(SizeOp::k10) | B(SizeOp::k10Q) | B(SizeOp::k10S)) + }, + + { // kVO_V_Any: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k00Q) | B(SizeOp::k01) | B(SizeOp::k01Q) | B(SizeOp::k10) | B(SizeOp::k10Q) | B(SizeOp::k11S) | B(SizeOp::k11Q)) + }, + + { // kVO_SV_Any: + SizeOpTable::kTableAny, SizeOp::kSzQS, uint16_t(B(SizeOp::k00) | B(SizeOp::k00Q) | B(SizeOp::k00S) | + B(SizeOp::k01) | B(SizeOp::k01Q) | B(SizeOp::k01S) | + B(SizeOp::k10) | B(SizeOp::k10Q) | B(SizeOp::k10S) | + B(SizeOp::k11) | B(SizeOp::k11Q) | B(SizeOp::k11S)) + } +}; + +static const Operand_& significantSimdOp(const Operand_& o0, const Operand_& o1, uint32_t instFlags) noexcept { + return !(instFlags & InstDB::kInstFlagLong) ? o0 : o1; +} + +static inline SizeOp armElementTypeToSizeOp(uint32_t vecOpType, RegType regType, uint32_t elementType) noexcept { + // Instruction data or Assembler is wrong if this triggers an assertion failure. + ASMJIT_ASSERT(vecOpType < InstDB::kVO_Count); + // ElementType uses 3 bits in the operand signature, it should never overflow. + ASMJIT_ASSERT(elementType <= 0x7u); + + const SizeOpMap& map = sizeOpMap[vecOpType]; + const SizeOpTable& table = sizeOpTable[map.tableId]; + + size_t index = (Support::min(diff(regType, RegType::kARM_VecB), diff(RegType::kARM_VecV, RegType::kARM_VecB) + 1) << 3) | elementType; + SizeOp op = table.array[index]; + SizeOp modifiedOp { uint8_t(op.value & map.sizeOpMask) }; + + if (!Support::bitTest(map.acceptMask, op.value)) + modifiedOp.makeInvalid(); + + return modifiedOp; +} + +// a64::Assembler - Immediate Encoding Utilities (Integral) +// ======================================================== + +using Utils::LogicalImm; + +struct HalfWordImm { + uint32_t hw; + uint32_t inv; + uint32_t imm; +}; + +struct LMHImm { + uint32_t lm; + uint32_t h; + uint32_t maxRmId; +}; + +static inline uint32_t countZeroHalfWords64(uint64_t imm) noexcept { + return uint32_t((imm & 0x000000000000FFFFu) == 0) + + uint32_t((imm & 0x00000000FFFF0000u) == 0) + + uint32_t((imm & 0x0000FFFF00000000u) == 0) + + uint32_t((imm & 0xFFFF000000000000u) == 0) ; +} + +static uint32_t encodeMovSequence32(uint32_t out[2], uint32_t imm, uint32_t rd, uint32_t x) noexcept { + ASMJIT_ASSERT(rd <= 31); + + uint32_t kMovZ = 0b01010010100000000000000000000000 | (x << 31); + uint32_t kMovN = 0b00010010100000000000000000000000; + uint32_t kMovK = 0b01110010100000000000000000000000; + + if ((imm & 0xFFFF0000u) == 0x00000000u) { + out[0] = kMovZ | (0 << 21) | ((imm & 0xFFFFu) << 5) | rd; + return 1; + } + + if ((imm & 0xFFFF0000u) == 0xFFFF0000u) { + out[0] = kMovN | (0 << 21) | ((~imm & 0xFFFFu) << 5) | rd; + return 1; + } + + if ((imm & 0x0000FFFFu) == 0x00000000u) { + out[0] = kMovZ | (1 << 21) | ((imm >> 16) << 5) | rd; + return 1; + } + + if ((imm & 0x0000FFFFu) == 0x0000FFFFu) { + out[0] = kMovN | (1 << 21) | ((~imm >> 16) << 5) | rd; + return 1; + } + + out[0] = kMovZ | (0 << 21) | ((imm & 0xFFFFu) << 5) | rd; + out[1] = kMovK | (1 << 21) | ((imm >> 16) << 5) | rd; + return 2; +} + +static uint32_t encodeMovSequence64(uint32_t out[4], uint64_t imm, uint32_t rd, uint32_t x) noexcept { + ASMJIT_ASSERT(rd <= 31); + + uint32_t kMovZ = 0b11010010100000000000000000000000; + uint32_t kMovN = 0b10010010100000000000000000000000; + uint32_t kMovK = 0b11110010100000000000000000000000; + + if (imm <= 0xFFFFFFFFu) + return encodeMovSequence32(out, uint32_t(imm), rd, x); + + uint32_t zhw = countZeroHalfWords64( imm); + uint32_t ohw = countZeroHalfWords64(~imm); + + if (zhw >= ohw) { + uint32_t op = kMovZ; + uint32_t count = 0; + + for (uint32_t hwIndex = 0; hwIndex < 4; hwIndex++, imm >>= 16) { + uint32_t hwImm = uint32_t(imm & 0xFFFFu); + if (hwImm == 0) + continue; + + out[count++] = op | (hwIndex << 21) | (hwImm << 5) | rd; + op = kMovK; + } + + // This should not happen - zero should be handled by encodeMovSequence32(). + ASMJIT_ASSERT(count > 0); + + return count; + } + else { + uint32_t op = kMovN; + uint32_t count = 0; + uint32_t negMask = 0xFFFFu; + + for (uint32_t hwIndex = 0; hwIndex < 4; hwIndex++, imm >>= 16) { + uint32_t hwImm = uint32_t(imm & 0xFFFFu); + if (hwImm == 0xFFFFu) + continue; + + out[count++] = op | (hwIndex << 21) | ((hwImm ^ negMask) << 5) | rd; + op = kMovK; + negMask = 0; + } + + if (count == 0) { + out[count++] = kMovN | ((0xFFFF ^ negMask) << 5) | rd; + } + + return count; + } +} + +static inline bool encodeLMH(uint32_t sizeField, uint32_t elementIndex, LMHImm* out) noexcept { + if (sizeField != 1 && sizeField != 2) + return false; + + uint32_t hShift = 3u - sizeField; + uint32_t lmShift = sizeField - 1u; + uint32_t maxElementIndex = 15u >> sizeField; + + out->h = elementIndex >> hShift; + out->lm = (elementIndex << lmShift) & 0x3u; + out->maxRmId = (8u << sizeField) - 1; + + return elementIndex <= maxElementIndex; +} + +// [.......A|B.......|.......C|D.......|.......E|F.......|.......G|H.......] +static inline uint32_t encodeImm64ByteMaskToImm8(uint64_t imm) noexcept { + return uint32_t(((imm >> (7 - 0)) & 0b00000011) | // [.......G|H.......] + ((imm >> (23 - 2)) & 0b00001100) | // [.......E|F.......] + ((imm >> (39 - 4)) & 0b00110000) | // [.......C|D.......] + ((imm >> (55 - 6)) & 0b11000000)); // [.......A|B.......] +} + +// a64::Assembler - Opcode +// ======================= + +//! Helper class to store and manipulate ARM opcode. +struct Opcode { + uint32_t v; + + enum Bits : uint32_t { + kN = (1u << 22), + kQ = (1u << 30), + kX = (1u << 31) + }; + + // -------------------------------------------------------------------------- + // [Opcode Builder] + // -------------------------------------------------------------------------- + + inline uint32_t get() const noexcept { return v; } + inline void reset(uint32_t value) noexcept { v = value; } + + inline bool hasQ() const noexcept { return (v & kQ) != 0; } + inline bool hasX() const noexcept { return (v & kX) != 0; } + + template + inline Opcode& addImm(T value, uint32_t bitIndex) noexcept { return operator|=(uint32_t(value) << bitIndex); } + + template + inline Opcode& xorImm(T value, uint32_t bitIndex) noexcept { return operator^=(uint32_t(value) << bitIndex); } + + template + inline Opcode& addIf(T value, const Condition& condition) noexcept { return operator|=(condition ? uint32_t(value) : uint32_t(0)); } + + inline Opcode& addLogicalImm(const LogicalImm& logicalImm) noexcept { + addImm(logicalImm.n, 22); + addImm(logicalImm.r, 16); + addImm(logicalImm.s, 10); + return *this; + } + + inline Opcode& addReg(uint32_t id, uint32_t bitIndex) noexcept { return operator|=((id & 31u) << bitIndex); } + inline Opcode& addReg(const Operand_& op, uint32_t bitIndex) noexcept { return addReg(op.id(), bitIndex); } + + inline Opcode& operator=(uint32_t x) noexcept { v = x; return *this; } + inline Opcode& operator&=(uint32_t x) noexcept { v &= x; return *this; } + inline Opcode& operator|=(uint32_t x) noexcept { v |= x; return *this; } + inline Opcode& operator^=(uint32_t x) noexcept { v ^= x; return *this; } + + inline uint32_t operator&(uint32_t x) const noexcept { return v & x; } + inline uint32_t operator|(uint32_t x) const noexcept { return v | x; } + inline uint32_t operator^(uint32_t x) const noexcept { return v ^ x; } +}; + +// a64::Assembler - Signature Utilities +// ==================================== + +// TODO: [ARM] Deprecate matchSignature. +static inline bool matchSignature(const Operand_& o0, const Operand_& o1, uint32_t instFlags) noexcept { + if (!(instFlags & (InstDB::kInstFlagLong | InstDB::kInstFlagNarrow))) + return o0.signature() == o1.signature(); + + // TODO: [ARM] Something smart to validate this. + return true; +} + +static inline bool matchSignature(const Operand_& o0, const Operand_& o1, const Operand_& o2, uint32_t instFlags) noexcept { + return matchSignature(o0, o1, instFlags) && o1.signature() == o2.signature(); +} + +static inline bool matchSignature(const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3, uint32_t instFlags) noexcept { + return matchSignature(o0, o1, instFlags) && o1.signature() == o2.signature() && o2.signature() == o3.signature();; +} + +// Memory must be either: +// 1. Absolute address, which will be converted to relative. +// 2. Relative displacement (Label). +// 3. Base register + either offset or index. +static inline bool armCheckMemBaseIndexRel(const Mem& mem) noexcept { + // Allowed base types (Nothing, Label, and GpX). + constexpr uint32_t kBaseMask = B(0) | + B(RegType::kLabelTag) | + B(RegType::kARM_GpX); + + // Allowed index types (Nothing, GpW, and GpX). + constexpr uint32_t kIndexMask = B(0) | + B(RegType::kARM_GpW) | + B(RegType::kARM_GpX) ; + + RegType baseType = mem.baseType(); + RegType indexType = mem.indexType(); + + if (!Support::bitTest(kBaseMask, baseType)) + return false; + + if (baseType > RegType::kLabelTag) { + // Index allows either GpW or GpX. + if (!Support::bitTest(kIndexMask, indexType)) + return false; + + if (indexType == RegType::kNone) + return true; + else + return !mem.hasOffset(); + } + else { + // No index register allowed if this is a PC relative address (literal). + return indexType == RegType::kNone; + } +} + +struct EncodeFpOpcodeBits { + uint32_t sizeMask; + uint32_t mask[3]; +}; + +static inline bool pickFpOpcode(const Vec& reg, uint32_t sOp, uint32_t sHf, uint32_t vOp, uint32_t vHf, Opcode* opcode, uint32_t* szOut) noexcept { + static constexpr uint32_t kQBitIndex = 30; + + static const EncodeFpOpcodeBits szBits[InstDB::kHF_Count] = { + { B(2) | B(1) , { 0u , 0u, B(22) } }, + { B(2) | B(1) | B(0), { 0u , 0u, 0u } }, + { B(2) | B(1) | B(0), { B(23) | B(22) , 0u, B(22) } }, + { B(2) | B(1) | B(0), { B(22) | B(20) | B(19) , 0u, B(22) } }, + { B(2) | B(1) | B(0), { B(22) | B(21) | B(15) | B(14), 0u, B(22) } }, + { B(2) | B(1) | B(0), { B(23) , 0u, B(22) } } + }; + + if (!reg.hasElementType()) { + // Scalar operation [HSD]. + uint32_t sz = diff(reg.type(), RegType::kARM_VecH); + if (sz > 2u || !Support::bitTest(szBits[sHf].sizeMask, sz)) + return false; + + opcode->reset(szBits[sHf].mask[sz] ^ sOp); + *szOut = sz; + return sOp != 0; + } + else { + // Vector operation [HSD]. + uint32_t q = diff(reg.type(), RegType::kARM_VecD); + uint32_t sz = reg.elementType() - Vec::kElementTypeH; + + if (q > 1u || sz > 2u || !Support::bitTest(szBits[vHf].sizeMask, sz)) + return false; + + opcode->reset(szBits[vHf].mask[sz] ^ (vOp | (q << kQBitIndex))); + *szOut = sz; + return vOp != 0; + } +} + +static inline bool pickFpOpcode(const Vec& reg, uint32_t sOp, uint32_t sHf, uint32_t vOp, uint32_t vHf, Opcode* opcode) noexcept { + uint32_t sz; + return pickFpOpcode(reg, sOp, sHf, vOp, vHf, opcode, &sz); +} + +// a64::Assembler - Operand Checks +// =============================== + +// Checks whether all operands have the same signature. +static inline bool checkSignature(const Operand_& o0, const Operand_& o1) noexcept { + return o0.signature() == o1.signature(); +} + +static inline bool checkSignature(const Operand_& o0, const Operand_& o1, const Operand_& o2) noexcept { + return o0.signature() == o1.signature() && + o1.signature() == o2.signature(); +} + +static inline bool checkSignature(const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3) noexcept { + return o0.signature() == o1.signature() && + o1.signature() == o2.signature() && + o2.signature() == o3.signature(); +} + +// Checks whether the register is GP register of the allowed types. +// +// Allowed is a 2-bit mask, where the first bits allows GpW and the second bit +// allows GpX. These bits are usually stored within the instruction, but could +// be also hardcoded in the assembler for instructions where GP types are not +// selectable. +static inline bool checkGpType(const Operand_& op, uint32_t allowed) noexcept { + RegType type = op.as().type(); + return Support::bitTest(allowed << uint32_t(RegType::kARM_GpW), type); +} + +static inline bool checkGpType(const Operand_& op, uint32_t allowed, uint32_t* x) noexcept { + // NOTE: We set 'x' to one only when GpW is allowed, otherwise the X is part + // of the opcode and we cannot set it. This is why this works without requiring + // additional logic. + RegType type = op.as().type(); + *x = diff(type, RegType::kARM_GpW) & allowed; + return Support::bitTest(allowed << uint32_t(RegType::kARM_GpW), type); +} + +static inline bool checkGpType(const Operand_& o0, const Operand_& o1, uint32_t allowed, uint32_t* x) noexcept { + return checkGpType(o0, allowed, x) && checkSignature(o0, o1); +} + +static inline bool checkGpType(const Operand_& o0, const Operand_& o1, const Operand_& o2, uint32_t allowed, uint32_t* x) noexcept { + return checkGpType(o0, allowed, x) && checkSignature(o0, o1, o2); +} + +static inline bool checkGpId(const Operand_& op, uint32_t hiId = kZR) noexcept { + uint32_t id = op.as().id(); + return id < 31u || id == hiId; +} + +static inline bool checkGpId(const Operand_& o0, const Operand_& o1, uint32_t hiId = kZR) noexcept { + uint32_t id0 = o0.as().id(); + uint32_t id1 = o1.as().id(); + + return (id0 < 31u || id0 == hiId) && (id1 < 31u || id1 == hiId); +} + +static inline bool checkGpId(const Operand_& o0, const Operand_& o1, const Operand_& o2, uint32_t hiId = kZR) noexcept { + uint32_t id0 = o0.as().id(); + uint32_t id1 = o1.as().id(); + uint32_t id2 = o2.as().id(); + + return (id0 < 31u || id0 == hiId) && (id1 < 31u || id1 == hiId) && (id2 < 31u || id2 == hiId); +} + +static inline bool checkVecId(const Operand_& op) noexcept { + uint32_t id = op.as().id(); + return id <= 31u; +} + +static inline bool checkVecId(const Operand_& o0, const Operand_& o1) noexcept { + uint32_t id0 = o0.as().id(); + uint32_t id1 = o1.as().id(); + + return (id0 | id1) <= 31u; +} + +/* Unused at the moment. +static inline bool checkVecId(const Operand_& o0, const Operand_& o1, const Operand_& o2) noexcept { + uint32_t id0 = o0.as().id(); + uint32_t id1 = o1.as().id(); + uint32_t id2 = o2.as().id(); + + return (id0 | id1 | id2) <= 31u; +} + +static inline bool checkVecId(const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3) noexcept { + uint32_t id0 = o0.as().id(); + uint32_t id1 = o1.as().id(); + uint32_t id2 = o2.as().id(); + uint32_t id3 = o3.as().id(); + + return (id0 | id1 | id2 | id3) <= 31u; +} +*/ + +static inline bool checkMemBase(const Mem& mem) noexcept { + return mem.baseType() == RegType::kARM_GpX && mem.baseId() <= 31; +} + +static inline bool checkEven(const Operand_& o0, const Operand_& o1) noexcept { + return ((o0.id() | o1.id()) & 1) == 0; +} + +static inline bool checkConsecutive(const Operand_& o0, const Operand_& o1) noexcept { + return ((o0.id() + 1u) & 0x1Fu) == o1.id(); +} + +static inline bool checkConsecutive(const Operand_& o0, const Operand_& o1, const Operand_& o2) noexcept { + return ((o0.id() + 1u) & 0x1Fu) == o1.id() && + ((o0.id() + 2u) & 0x1Fu) == o2.id(); +} + +static inline bool checkConsecutive(const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3) noexcept { + return ((o0.id() + 1u) & 0x1Fu) == o1.id() && + ((o0.id() + 2u) & 0x1Fu) == o2.id() && + ((o0.id() + 3u) & 0x1Fu) == o3.id(); +} + +// a64::Assembler - CheckReg +// ========================= + +#define V(index) (index == uint32_t(RegType::kARM_GpW) ? Gp::kIdZr : \ + index == uint32_t(RegType::kARM_GpX) ? Gp::kIdZr : \ + index == uint32_t(RegType::kARM_VecB) ? 31u : \ + index == uint32_t(RegType::kARM_VecH) ? 31u : \ + index == uint32_t(RegType::kARM_VecS) ? 31u : \ + index == uint32_t(RegType::kARM_VecD) ? 31u : \ + index == uint32_t(RegType::kARM_VecV) ? 31u : 0) +static const Support::Array commonHiRegIdOfType = {{ + ASMJIT_LOOKUP_TABLE_32(V, 0) +}}; +#undef V + +static inline bool checkValidRegs(const Operand_& o0) noexcept { + return ((o0.id() < 31) | (o0.id() == commonHiRegIdOfType[o0.as().type()])); +} + +static inline bool checkValidRegs(const Operand_& o0, const Operand_& o1) noexcept { + return ((o0.id() < 31) | (o0.id() == commonHiRegIdOfType[o0.as().type()])) & + ((o1.id() < 31) | (o1.id() == commonHiRegIdOfType[o1.as().type()])) ; +} + +static inline bool checkValidRegs(const Operand_& o0, const Operand_& o1, const Operand_& o2) noexcept { + return ((o0.id() < 31) | (o0.id() == commonHiRegIdOfType[o0.as().type()])) & + ((o1.id() < 31) | (o1.id() == commonHiRegIdOfType[o1.as().type()])) & + ((o2.id() < 31) | (o2.id() == commonHiRegIdOfType[o2.as().type()])) ; +} + +static inline bool checkValidRegs(const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3) noexcept { + return ((o0.id() < 31) | (o0.id() == commonHiRegIdOfType[o0.as().type()])) & + ((o1.id() < 31) | (o1.id() == commonHiRegIdOfType[o1.as().type()])) & + ((o2.id() < 31) | (o2.id() == commonHiRegIdOfType[o2.as().type()])) & + ((o3.id() < 31) | (o3.id() == commonHiRegIdOfType[o3.as().type()])) ; +} + +// a64::Assembler - Construction & Destruction +// =========================================== + +Assembler::Assembler(CodeHolder* code) noexcept : BaseAssembler() { + _archMask = uint64_t(1) << uint32_t(Arch::kAArch64); + assignEmitterFuncs(this); + + if (code) + code->attach(this); +} + +Assembler::~Assembler() noexcept {} + +// a64::Assembler - Emit +// ===================== + +#define ENC_OPS1(OP0) \ + (uint32_t(OperandType::k##OP0)) + +#define ENC_OPS2(OP0, OP1) \ + (uint32_t(OperandType::k##OP0) + \ + (uint32_t(OperandType::k##OP1) << 3)) + +#define ENC_OPS3(OP0, OP1, OP2) \ + (uint32_t(OperandType::k##OP0) + \ + (uint32_t(OperandType::k##OP1) << 3) + \ + (uint32_t(OperandType::k##OP2) << 6)) + +#define ENC_OPS4(OP0, OP1, OP2, OP3) \ + (uint32_t(OperandType::k##OP0) + \ + (uint32_t(OperandType::k##OP1) << 3) + \ + (uint32_t(OperandType::k##OP2) << 6) + \ + (uint32_t(OperandType::k##OP3) << 9)) + +Error Assembler::_emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* opExt) { + // Logging/Validation/Error. + constexpr InstOptions kRequiresSpecialHandling = InstOptions::kReserved; + + Error err; + CodeWriter writer(this); + + // Combine all instruction options and also check whether the instruction + // is valid. All options that require special handling (including invalid + // instruction) are handled by the next branch. + InstOptions options = InstOptions(instId - 1 >= Inst::_kIdCount - 1) | InstOptions((size_t)(_bufferEnd - writer.cursor()) < 4) | instOptions() | forcedInstOptions(); + + CondCode instCC = BaseInst::extractARMCondCode(instId); + instId = instId & uint32_t(InstIdParts::kRealId); + + if (instId >= Inst::_kIdCount) + instId = 0; + + const InstDB::InstInfo* instInfo = &InstDB::_instInfoTable[instId]; + uint32_t encodingIndex = instInfo->_encodingDataIndex; + + Opcode opcode; + uint32_t isign4; + uint32_t instFlags; + + const Operand_& o3 = opExt[EmitterUtils::kOp3]; + const Operand_* rmRel = nullptr; + + uint32_t multipleOpData[4]; + uint32_t multipleOpCount; + + // These are only used when instruction uses a relative displacement. + OffsetFormat offsetFormat; // Offset format. + uint64_t offsetValue; // Offset value (if known). + + if (ASMJIT_UNLIKELY(Support::test(options, kRequiresSpecialHandling))) { + if (ASMJIT_UNLIKELY(!_code)) + return reportError(DebugUtils::errored(kErrorNotInitialized)); + + // Unknown instruction. + if (ASMJIT_UNLIKELY(instId == 0)) + goto InvalidInstruction; + + // Condition code can only be used with 'B' instruction. + if (ASMJIT_UNLIKELY(instCC != CondCode::kAL && instId != Inst::kIdB)) + goto InvalidInstruction; + + // Grow request, happens rarely. + err = writer.ensureSpace(this, 4); + if (ASMJIT_UNLIKELY(err)) + goto Failed; + +#ifndef ASMJIT_NO_VALIDATION + // Strict validation. + if (hasDiagnosticOption(DiagnosticOptions::kValidateAssembler)) { + Operand_ opArray[Globals::kMaxOpCount]; + EmitterUtils::opArrayFromEmitArgs(opArray, o0, o1, o2, opExt); + + err = _funcs.validate(arch(), BaseInst(instId, options, _extraReg), opArray, Globals::kMaxOpCount, ValidationFlags::kNone); + if (ASMJIT_UNLIKELY(err)) + goto Failed; + } +#endif + } + + // Signature of the first 4 operands. + isign4 = (uint32_t(o0.opType()) ) + + (uint32_t(o1.opType()) << 3) + + (uint32_t(o2.opType()) << 6) + + (uint32_t(o3.opType()) << 9); + instFlags = instInfo->flags(); + + switch (instInfo->_encoding) { + // ------------------------------------------------------------------------ + // [Base - Universal] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseOp: { + const InstDB::EncodingData::BaseOp& opData = InstDB::EncodingData::baseOp[encodingIndex]; + + if (isign4 == 0) { + opcode.reset(opData.opcode); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseOpImm: { + const InstDB::EncodingData::BaseOpImm& opData = InstDB::EncodingData::baseOpImm[encodingIndex]; + + if (isign4 == ENC_OPS1(Imm)) { + uint64_t imm = o0.as().valueAs(); + uint32_t immMax = 1u << opData.immBits; + + if (imm >= immMax) + goto InvalidImmediate; + + opcode.reset(opData.opcode); + opcode.addImm(imm, opData.immOffset); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseR: { + const InstDB::EncodingData::BaseR& opData = InstDB::EncodingData::baseR[encodingIndex]; + + if (isign4 == ENC_OPS1(Reg)) { + if (!checkGpType(o0, opData.rType)) + goto InvalidInstruction; + + if (!checkGpId(o0, opData.rHiId)) + goto InvalidPhysId; + + opcode.reset(opData.opcode); + opcode.addReg(o0, opData.rShift); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseRR: { + const InstDB::EncodingData::BaseRR& opData = InstDB::EncodingData::baseRR[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + uint32_t x; + if (!checkGpType(o0, opData.aType, &x)) + goto InvalidInstruction; + + if (!checkGpType(o1, opData.bType)) + goto InvalidInstruction; + + if (opData.uniform && !checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o0, opData.aHiId)) + goto InvalidPhysId; + + if (!checkGpId(o1, opData.bHiId)) + goto InvalidPhysId; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addReg(o1, opData.bShift); + opcode.addReg(o0, opData.aShift); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseRRR: { + const InstDB::EncodingData::BaseRRR& opData = InstDB::EncodingData::baseRRR[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + uint32_t x; + if (!checkGpType(o0, opData.aType, &x)) + goto InvalidInstruction; + + if (!checkGpType(o1, opData.bType)) + goto InvalidInstruction; + + if (!checkGpType(o2, opData.cType)) + goto InvalidInstruction; + + if (opData.uniform && !checkSignature(o0, o1, o2)) + goto InvalidInstruction; + + if (!checkGpId(o0, opData.aHiId)) + goto InvalidPhysId; + + if (!checkGpId(o1, opData.bHiId)) + goto InvalidPhysId; + + if (!checkGpId(o2, opData.cHiId)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, 31); + opcode.addReg(o2, 16); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseRRRR: { + const InstDB::EncodingData::BaseRRRR& opData = InstDB::EncodingData::baseRRRR[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Reg)) { + uint32_t x; + if (!checkGpType(o0, opData.aType, &x)) + goto InvalidInstruction; + + if (!checkGpType(o1, opData.bType)) + goto InvalidInstruction; + + if (!checkGpType(o2, opData.cType)) + goto InvalidInstruction; + + if (!checkGpType(o3, opData.dType)) + goto InvalidInstruction; + + if (opData.uniform && !checkSignature(o0, o1, o2, o3)) + goto InvalidInstruction; + + if (!checkGpId(o0, opData.aHiId)) + goto InvalidPhysId; + + if (!checkGpId(o1, opData.bHiId)) + goto InvalidPhysId; + + if (!checkGpId(o2, opData.cHiId)) + goto InvalidPhysId; + + if (!checkGpId(o3, opData.dHiId)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, 31); + opcode.addReg(o2, 16); + opcode.addReg(o3, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseRRII: { + const InstDB::EncodingData::BaseRRII& opData = InstDB::EncodingData::baseRRII[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Imm, Imm)) { + if (!checkGpType(o0, opData.aType)) + goto InvalidInstruction; + + if (!checkGpType(o1, opData.bType)) + goto InvalidInstruction; + + if (!checkGpId(o0, opData.aHiId)) + goto InvalidPhysId; + + if (!checkGpId(o1, opData.bHiId)) + goto InvalidPhysId; + + if (o2.as().valueAs() >= Support::bitMask(opData.aImmSize + opData.aImmDiscardLsb) || + o3.as().valueAs() >= Support::bitMask(opData.bImmSize + opData.bImmDiscardLsb)) + goto InvalidImmediate; + + uint32_t aImm = o2.as().valueAs() >> opData.aImmDiscardLsb; + uint32_t bImm = o3.as().valueAs() >> opData.bImmDiscardLsb; + + if ((aImm << opData.aImmDiscardLsb) != o2.as().valueAs() || + (bImm << opData.bImmDiscardLsb) != o3.as().valueAs()) + goto InvalidImmediate; + + opcode.reset(opData.opcode()); + opcode.addImm(aImm, opData.aImmOffset); + opcode.addImm(bImm, opData.bImmOffset); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + // ------------------------------------------------------------------------ + // [Base - Mov] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseMov: { + // MOV is a pseudo instruction that uses various instructions depending on its signature. + uint32_t x = diff(o0.as().type(), RegType::kARM_GpW); + if (x > 1) + goto InvalidInstruction; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + if (!o0.as().isGp()) + goto InvalidInstruction; + + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + bool hasSP = o0.as().isSP() || o1.as().isSP(); + if (hasSP) { + // Cannot be combined with ZR. + if (!checkGpId(o0, o1, kSP)) + goto InvalidPhysId; + + // MOV Rd, Rm -> ADD Rd, Rn, #0. + opcode.reset(0b00010001000000000000000000000000); + opcode.addImm(x, 31); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + else { + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + // MOV Rd, Rm -> ORR Rd, , Rm. + opcode.reset(0b00101010000000000000001111100000); + opcode.addImm(x, 31); + opcode.addReg(o1, 16); + opcode.addReg(o0, 0); + goto EmitOp; + } + } + + if (isign4 == ENC_OPS2(Reg, Imm)) { + if (!o0.as().isGp()) + goto InvalidInstruction; + + uint64_t immValue = o1.as().valueAs(); + if (!x) + immValue &= 0xFFFFFFFFu; + + // Prefer a single MOVN/MOVZ instruction over a logical instruction. + multipleOpCount = encodeMovSequence64(multipleOpData, immValue, o0.id() & 31, x); + if (multipleOpCount == 1 && !o0.as().isSP()) { + opcode.reset(multipleOpData[0]); + goto EmitOp; + } + + // Logical instructions use 13-bit immediate pattern encoded as N:ImmR:ImmS. + LogicalImm logicalImm; + if (!o0.as().isZR()) { + if (Utils::encodeLogicalImm(immValue, x ? 64 : 32, &logicalImm)) { + if (!checkGpId(o0, kSP)) + goto InvalidPhysId; + + opcode.reset(0b00110010000000000000001111100000); + opcode.addImm(x, 31); + opcode.addLogicalImm(logicalImm); + opcode.addReg(o0, 0); + goto EmitOp; + } + } + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + goto EmitOp_Multiple; + } + + break; + } + + case InstDB::kEncodingBaseMovKNZ: { + const InstDB::EncodingData::BaseMovKNZ& opData = InstDB::EncodingData::baseMovKNZ[encodingIndex]; + + uint32_t x = diff(o0.as().type(), RegType::kARM_GpW); + if (x > 1) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + + if (isign4 == ENC_OPS2(Reg, Imm)) { + uint64_t imm16 = o1.as().valueAs(); + if (imm16 > 0xFFFFu) + goto InvalidImmediate; + + opcode.addImm(imm16, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + if (isign4 == ENC_OPS3(Reg, Imm, Imm)) { + uint64_t imm16 = o1.as().valueAs(); + uint32_t shiftType = o2.as().predicate(); + uint64_t shiftValue = o2.as().valueAs(); + + if (imm16 > 0xFFFFu || shiftValue > 48 || shiftType != uint32_t(ShiftOp::kLSL)) + goto InvalidImmediate; + + // Convert shift value to 'hw' field. + uint32_t hw = uint32_t(shiftValue) >> 4; + if ((hw << 4) != uint32_t(shiftValue)) + goto InvalidImmediate; + + opcode.addImm(hw, 21); + opcode.addImm(imm16, 5); + opcode.addReg(o0, 0); + + if (!x && hw > 1u) + goto InvalidImmediate; + + goto EmitOp; + } + + break; + } + + // ------------------------------------------------------------------------ + // [Base - Adr] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseAdr: { + const InstDB::EncodingData::BaseAdr& opData = InstDB::EncodingData::baseAdr[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Label) || isign4 == ENC_OPS2(Reg, Imm)) { + if (!o0.as().isGpX()) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addReg(o0, 0); + offsetFormat.resetToImmValue(opData.offsetType, 4, 5, 21, 0); + + if (instId == Inst::kIdAdrp) + offsetFormat._immDiscardLsb = 12; + + rmRel = &o1; + goto EmitOp_Rel; + } + + break; + } + + // ------------------------------------------------------------------------ + // [Base - Arithmetic and Logical] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseAddSub: { + const InstDB::EncodingData::BaseAddSub& opData = InstDB::EncodingData::baseAddSub[encodingIndex]; + + uint32_t x; + if (!checkGpType(o0, o1, kWX, &x)) + goto InvalidInstruction; + + if (isign4 == ENC_OPS3(Reg, Reg, Imm) || isign4 == ENC_OPS4(Reg, Reg, Imm, Imm)) { + opcode.reset(uint32_t(opData.immediateOp) << 24); + + // ADD | SUB (immediate) - ZR is not allowed. + // ADDS|SUBS (immediate) - ZR allowed in Rd, SP allowed in Rn. + uint32_t aHiId = opcode.get() & B(29) ? kZR : kSP; + uint32_t bHiId = kSP; + + if (!checkGpId(o0, aHiId) || !checkGpId(o1, bHiId)) + goto InvalidPhysId; + + // ADD|SUB (immediate) use 12-bit immediate optionally shifted by 'LSL #12'. + uint64_t imm = o2.as().valueAs(); + uint32_t shift = 0; + + if (isign4 == ENC_OPS4(Reg, Reg, Imm, Imm)) { + if (o3.as().predicate() != uint32_t(ShiftOp::kLSL)) + goto InvalidImmediate; + + if (o3.as().value() != 0 && o3.as().value() != 12) + goto InvalidImmediate; + + shift = uint32_t(o3.as().value() != 0); + } + + // Accept immediate value of '0x00XXX000' by setting 'shift' to 12. + if (imm > 0xFFFu) { + if (shift || (imm & ~uint64_t(0xFFFu << 12)) != 0) + goto InvalidImmediate; + shift = 1; + imm >>= 12; + } + + opcode.addImm(x, 31); + opcode.addImm(shift, 22); + opcode.addImm(imm, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + if (isign4 == ENC_OPS3(Reg, Reg, Reg) || isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + if (!checkSignature(o1, o2)) + goto InvalidInstruction; + + uint32_t opSize = x ? 64 : 32; + uint64_t shift = 0; + uint32_t sType = uint32_t(ShiftOp::kLSL); + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + sType = o3.as().predicate(); + shift = o3.as().valueAs(); + } + + if (!checkGpId(o2, kZR)) + goto InvalidPhysId; + + // Shift operation - LSL, LSR, ASR. + if (sType <= uint32_t(ShiftOp::kASR)) { + bool hasSP = o0.as().isSP() || o1.as().isSP(); + if (!hasSP) { + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + if (shift >= opSize) + goto InvalidImmediate; + + opcode.reset(uint32_t(opData.shiftedOp) << 21); + opcode.addImm(x, 31); + opcode.addImm(sType, 22); + opcode.addReg(o2, 16); + opcode.addImm(shift, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + // SP register can only be used with LSL or Extend. + if (sType != uint32_t(ShiftOp::kLSL)) + goto InvalidImmediate; + sType = x ? uint32_t(ShiftOp::kUXTX) : uint32_t(ShiftOp::kUXTW); + } + + // Extend operation - UXTB, UXTH, UXTW, UXTX, SXTB, SXTH, SXTW, SXTX. + opcode.reset(uint32_t(opData.extendedOp) << 21); + sType -= uint32_t(ShiftOp::kUXTB); + + if (sType > 7 || shift > 4) + goto InvalidImmediate; + + if (!(opcode.get() & B(29))) { + // ADD|SUB (extend) - ZR is not allowed. + if (!checkGpId(o0, o1, kSP)) + goto InvalidPhysId; + } + else { + // ADDS|SUBS (extend) - ZR allowed in Rd, SP allowed in Rn. + if (!checkGpId(o0, kZR) || !checkGpId(o1, kSP)) + goto InvalidPhysId; + } + + opcode.addImm(x, 31); + opcode.addReg(o2, 16); + opcode.addImm(sType, 13); + opcode.addImm(shift, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseLogical: { + const InstDB::EncodingData::BaseLogical& opData = InstDB::EncodingData::baseLogical[encodingIndex]; + + uint32_t x; + if (!checkGpType(o0, o1, kWX, &x)) + goto InvalidInstruction; + + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + uint32_t opSize = x ? 64 : 32; + + if (isign4 == ENC_OPS3(Reg, Reg, Imm) && opData.immediateOp != 0) { + opcode.reset(uint32_t(opData.immediateOp) << 23); + + // AND|ANDS|BIC|BICS|ORR|EOR (immediate) uses a LogicalImm format described by N:R:S values. + uint64_t immMask = Support::lsbMask(opSize); + uint64_t immValue = o2.as().valueAs(); + + if (opData.negateImm) + immValue ^= immMask; + + // Logical instructions use 13-bit immediate pattern encoded as N:ImmS:ImmR. + LogicalImm logicalImm; + if (!Utils::encodeLogicalImm(immValue & immMask, opSize, &logicalImm)) + goto InvalidImmediate; + + // AND|BIC|ORR|EOR (immediate) can have SP on destination, but ANDS|BICS (immediate) cannot. + uint32_t kOpANDS = 0x3 << 29; + bool isANDS = (opcode.get() & kOpANDS) == kOpANDS; + + if (!checkGpId(o0, isANDS ? kZR : kSP) || !checkGpId(o1, kZR)) + goto InvalidPhysId; + + opcode.addImm(x, 31); + opcode.addLogicalImm(logicalImm); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + if (!checkSignature(o1, o2)) + goto InvalidInstruction; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (!checkGpId(o0, o1, o2, kZR)) + goto InvalidPhysId; + + opcode.reset(uint32_t(opData.shiftedOp) << 21); + opcode.addImm(x, 31); + opcode.addReg(o2, 16); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + if (!checkGpId(o0, o1, o2, kZR)) + goto InvalidPhysId; + + uint32_t shiftType = o3.as().predicate(); + uint64_t opShift = o3.as().valueAs(); + + if (shiftType > 0x3 || opShift >= opSize) + goto InvalidImmediate; + + opcode.reset(uint32_t(opData.shiftedOp) << 21); + opcode.addImm(x, 31); + opcode.addImm(shiftType, 22); + opcode.addReg(o2, 16); + opcode.addImm(opShift, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseCmpCmn: { + const InstDB::EncodingData::BaseCmpCmn& opData = InstDB::EncodingData::baseCmpCmn[encodingIndex]; + + uint32_t x; + if (!checkGpType(o0, kWX, &x)) + goto InvalidInstruction; + + if (isign4 == ENC_OPS2(Reg, Imm)) { + // CMN|CMP (immediate) - ZR is not allowed. + if (!checkGpId(o0, kSP)) + goto InvalidPhysId; + + // CMN|CMP (immediate) use 12-bit immediate optionally shifted by 'LSL #12'. + const Imm& imm12 = o1.as(); + uint32_t immShift = 0; + uint64_t immValue = imm12.valueAs(); + + if (immValue > 0xFFFu) { + if ((immValue & ~uint64_t(0xFFFu << 12)) != 0) + goto InvalidImmediate; + immShift = 1; + immValue >>= 12; + } + + opcode.reset(uint32_t(opData.immediateOp) << 24); + opcode.addImm(x, 31); + opcode.addImm(immShift, 22); + opcode.addImm(immValue, 10); + opcode.addReg(o0, 5); + opcode.addReg(Gp::kIdZr, 0); + goto EmitOp; + } + + if (isign4 == ENC_OPS2(Reg, Reg) || isign4 == ENC_OPS3(Reg, Reg, Imm)) { + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + uint32_t opSize = x ? 64 : 32; + uint32_t sType = 0; + uint64_t shift = 0; + + if (isign4 == ENC_OPS3(Reg, Reg, Imm)) { + sType = o2.as().predicate(); + shift = o2.as().valueAs(); + } + + bool hasSP = o0.as().isSP() || o1.as().isSP(); + + // Shift operation - LSL, LSR, ASR. + if (sType <= uint32_t(ShiftOp::kASR)) { + if (!hasSP) { + if (shift >= opSize) + goto InvalidImmediate; + + opcode.reset(uint32_t(opData.shiftedOp) << 21); + opcode.addImm(x, 31); + opcode.addImm(sType, 22); + opcode.addReg(o1, 16); + opcode.addImm(shift, 10); + opcode.addReg(o0, 5); + opcode.addReg(Gp::kIdZr, 0); + goto EmitOp; + } + + // SP register can only be used with LSL or Extend. + if (sType != uint32_t(ShiftOp::kLSL)) + goto InvalidImmediate; + + sType = x ? uint32_t(ShiftOp::kUXTX) : uint32_t(ShiftOp::kUXTW); + } + + // Extend operation - UXTB, UXTH, UXTW, UXTX, SXTB, SXTH, SXTW, SXTX. + sType -= uint32_t(ShiftOp::kUXTB); + if (sType > 7 || shift > 4) + goto InvalidImmediate; + + opcode.reset(uint32_t(opData.extendedOp) << 21); + opcode.addImm(x, 31); + opcode.addReg(o1, 16); + opcode.addImm(sType, 13); + opcode.addImm(shift, 10); + opcode.addReg(o0, 5); + opcode.addReg(Gp::kIdZr, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseMvnNeg: { + const InstDB::EncodingData::BaseMvnNeg& opData = InstDB::EncodingData::baseMvnNeg[encodingIndex]; + + uint32_t x; + if (!checkGpType(o0, o1, kWX, &x)) + goto InvalidInstruction; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addReg(o1, 16); + opcode.addReg(o0, 0); + + if (isign4 == ENC_OPS2(Reg, Reg)) { + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + goto EmitOp; + } + + if (isign4 == ENC_OPS3(Reg, Reg, Imm)) { + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + uint32_t opSize = x ? 64 : 32; + uint32_t shiftType = o2.as().predicate(); + uint64_t opShift = o2.as().valueAs(); + + if (shiftType > uint32_t(ShiftOp::kROR) || opShift >= opSize) + goto InvalidImmediate; + + opcode.addImm(shiftType, 22); + opcode.addImm(opShift, 10); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseTst: { + const InstDB::EncodingData::BaseTst& opData = InstDB::EncodingData::baseTst[encodingIndex]; + + uint32_t x; + if (!checkGpType(o0, kWX, &x)) + goto InvalidInstruction; + + uint32_t opSize = x ? 64 : 32; + + if (isign4 == ENC_OPS2(Reg, Imm) && opData.immediateOp != 0) { + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + // TST (immediate) uses a LogicalImm format described by N:R:S values. + uint64_t immMask = Support::lsbMask(opSize); + uint64_t immValue = o1.as().valueAs(); + + // Logical instructions use 13-bit immediate pattern encoded as N:ImmS:ImmR. + LogicalImm logicalImm; + if (!Utils::encodeLogicalImm(immValue & immMask, opSize, &logicalImm)) + goto InvalidImmediate; + + opcode.reset(uint32_t(opData.immediateOp) << 22); + opcode.addLogicalImm(logicalImm); + opcode.addImm(x, 31); + opcode.addReg(o0, 5); + opcode.addReg(Gp::kIdZr, 0); + goto EmitOp; + } + + opcode.reset(uint32_t(opData.shiftedOp) << 21); + opcode.addImm(x, 31); + opcode.addReg(o1, 16); + opcode.addReg(o0, 5); + opcode.addReg(Gp::kIdZr, 0); + + if (isign4 == ENC_OPS2(Reg, Reg)) { + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + goto EmitOp; + } + + if (isign4 == ENC_OPS3(Reg, Reg, Imm)) { + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + uint32_t shiftType = o2.as().predicate(); + uint64_t opShift = o2.as().valueAs(); + + if (shiftType > 0x3 || opShift >= opSize) + goto InvalidImmediate; + + opcode.addImm(shiftType, 22); + opcode.addImm(opShift, 10); + goto EmitOp; + } + + break; + } + + // ------------------------------------------------------------------------ + // [Base - Bit Manipulation] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseBfc: { + const InstDB::EncodingData::BaseBfc& opData = InstDB::EncodingData::baseBfc[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Imm, Imm)) { + uint32_t x; + if (!checkGpType(o0, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0)) + goto InvalidPhysId; + + uint64_t lsb = o1.as().valueAs(); + uint64_t width = o2.as().valueAs(); + uint32_t opSize = x ? 64 : 32; + + if (lsb >= opSize || width == 0 || width > opSize) + goto InvalidImmediate; + + uint32_t lsb32 = Support::neg(uint32_t(lsb)) & (opSize - 1); + uint32_t width32 = uint32_t(width) - 1; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addImm(x, 22); + opcode.addImm(lsb32, 16); + opcode.addImm(width32, 10); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseBfi: { + const InstDB::EncodingData::BaseBfi& opData = InstDB::EncodingData::baseBfi[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Imm, Imm)) { + uint32_t x; + if (!checkGpType(o0, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1)) + goto InvalidPhysId; + + uint64_t lsb = o2.as().valueAs(); + uint64_t width = o3.as().valueAs(); + uint32_t opSize = x ? 64 : 32; + + if (lsb >= opSize || width == 0 || width > opSize) + goto InvalidImmediate; + + uint32_t lImm = Support::neg(uint32_t(lsb)) & (opSize - 1); + uint32_t wImm = uint32_t(width) - 1; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addImm(x, 22); + opcode.addImm(lImm, 16); + opcode.addImm(wImm, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseBfm: { + const InstDB::EncodingData::BaseBfm& opData = InstDB::EncodingData::baseBfm[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Imm, Imm)) { + uint32_t x; + if (!checkGpType(o0, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1)) + goto InvalidPhysId; + + uint64_t immR = o2.as().valueAs(); + uint64_t immS = o3.as().valueAs(); + uint32_t opSize = x ? 64 : 32; + + if ((immR | immS) >= opSize) + goto InvalidImmediate; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addImm(x, 22); + opcode.addImm(immR, 16); + opcode.addImm(immS, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseBfx: { + const InstDB::EncodingData::BaseBfx& opData = InstDB::EncodingData::baseBfx[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Imm, Imm)) { + uint32_t x; + if (!checkGpType(o0, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1)) + goto InvalidPhysId; + + uint64_t lsb = o2.as().valueAs(); + uint64_t width = o3.as().valueAs(); + uint32_t opSize = x ? 64 : 32; + + if (lsb >= opSize || width == 0 || width > opSize) + goto InvalidImmediate; + + uint32_t lsb32 = uint32_t(lsb); + uint32_t width32 = lsb32 + uint32_t(width) - 1u; + + if (width32 >= opSize) + goto InvalidImmediate; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addImm(x, 22); + opcode.addImm(lsb32, 16); + opcode.addImm(width32, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseExtend: { + const InstDB::EncodingData::BaseExtend& opData = InstDB::EncodingData::baseExtend[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + uint32_t x; + if (!checkGpType(o0, opData.rType, &x)) + goto InvalidInstruction; + + if (!o1.as().isGpW()) + goto InvalidInstruction; + + if (!checkGpId(o0, o1)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, 31); + opcode.addImm(x, 22); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseExtract: { + const InstDB::EncodingData::BaseExtract& opData = InstDB::EncodingData::baseExtract[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + uint32_t x; + if (!checkGpType(o0, kWX, &x)) + goto InvalidInstruction; + + if (!checkSignature(o0, o1, o2)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, o2)) + goto InvalidPhysId; + + uint64_t lsb = o3.as().valueAs(); + uint32_t opSize = x ? 64 : 32; + + if (lsb >= opSize) + goto InvalidImmediate; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addImm(x, 22); + opcode.addReg(o2, 16); + opcode.addImm(lsb, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseRev: { + if (isign4 == ENC_OPS2(Reg, Reg)) { + uint32_t x; + if (!checkGpType(o0, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1)) + goto InvalidPhysId; + + opcode.reset(0b01011010110000000000100000000000); + opcode.addImm(x, 31); + opcode.addImm(x, 10); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseShift: { + const InstDB::EncodingData::BaseShift& opData = InstDB::EncodingData::baseShift[encodingIndex]; + + uint32_t x; + if (!checkGpType(o0, kWX, &x)) + goto InvalidInstruction; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (!checkSignature(o0, o1, o2)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, o2, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.registerOp()); + opcode.addImm(x, 31); + opcode.addReg(o2, 16); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + if (isign4 == ENC_OPS3(Reg, Reg, Imm) && opData.immediateOp()) { + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + uint64_t immR = o2.as().valueAs(); + uint32_t opSize = x ? 64 : 32; + + if (immR >= opSize) + goto InvalidImmediate; + + opcode.reset(opData.immediateOp()); + opcode.addImm(x, 31); + opcode.addImm(x, 22); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + + if (opcode.get() & B(10)) { + // ASR and LSR (immediate) has the same logic. + opcode.addImm(x, 15); + opcode.addImm(immR, 16); + goto EmitOp; + } + + if (opData.ror == 0) { + // LSL (immediate) is an alias to UBFM + uint32_t ubfmImmR = Support::neg(uint32_t(immR)) & (opSize - 1); + uint32_t ubfmImmS = opSize - 1 - uint32_t(immR); + + opcode.addImm(ubfmImmR, 16); + opcode.addImm(ubfmImmS, 10); + goto EmitOp; + } + else { + // ROR (immediate) is an alias to EXTR. + opcode.addImm(immR, 10); + opcode.addReg(o1, 16); + goto EmitOp; + } + } + + break; + } + + // ------------------------------------------------------------------------ + // [Base - Conditionals] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseCCmp: { + const InstDB::EncodingData::BaseCCmp& opData = InstDB::EncodingData::baseCCmp[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Imm, Imm) || isign4 == ENC_OPS4(Reg, Imm, Imm, Imm)) { + uint32_t x; + if (!checkGpType(o0, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + uint64_t nzcv = o2.as().valueAs(); + uint64_t cond = o3.as().valueAs(); + + if ((nzcv | cond) > 0xFu) + goto InvalidImmediate; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addImm(condCodeToOpcodeCond(uint32_t(cond)), 12); + opcode.addImm(nzcv, 0); + + if (isign4 == ENC_OPS4(Reg, Reg, Imm, Imm)) { + // CCMN|CCMP (register) form. + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o1, kZR)) + goto InvalidPhysId; + + opcode.addReg(o1, 16); + opcode.addReg(o0, 5); + goto EmitOp; + } + else { + // CCMN|CCMP (immediate) form. + uint64_t imm5 = o1.as().valueAs(); + if (imm5 > 0x1F) + goto InvalidImmediate; + + opcode.addImm(1, 11); + opcode.addImm(imm5, 16); + opcode.addReg(o0, 5); + goto EmitOp; + } + } + + break; + } + + case InstDB::kEncodingBaseCInc: { + const InstDB::EncodingData::BaseCInc& opData = InstDB::EncodingData::baseCInc[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Imm)) { + uint32_t x; + if (!checkGpType(o0, o1, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + uint64_t cond = o2.as().valueAs(); + if (cond - 2u > 0xEu) + goto InvalidImmediate; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addReg(o1, 16); + opcode.addImm(condCodeToOpcodeCond(uint32_t(cond)) ^ 1u, 12); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseCSel: { + const InstDB::EncodingData::BaseCSel& opData = InstDB::EncodingData::baseCSel[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + uint32_t x; + if (!checkGpType(o0, o1, o2, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, o2, kZR)) + goto InvalidPhysId; + + uint64_t cond = o3.as().valueAs(); + if (cond > 0xFu) + goto InvalidImmediate; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addReg(o2, 16); + opcode.addImm(condCodeToOpcodeCond(uint32_t(cond)), 12); + opcode.addReg(o1, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseCSet: { + const InstDB::EncodingData::BaseCSet& opData = InstDB::EncodingData::baseCSet[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Imm)) { + uint32_t x; + if (!checkGpType(o0, InstDB::kWX, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + uint64_t cond = o1.as().valueAs(); + if (cond - 2u >= 0xEu) + goto InvalidImmediate; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addImm(condCodeToOpcodeCond(uint32_t(cond)) ^ 1u, 12); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + // ------------------------------------------------------------------------ + // [Base - Special] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseAtDcIcTlbi: { + const InstDB::EncodingData::BaseAtDcIcTlbi& opData = InstDB::EncodingData::baseAtDcIcTlbi[encodingIndex]; + + if (isign4 == ENC_OPS1(Imm) || isign4 == ENC_OPS2(Imm, Reg)) { + if (opData.mandatoryReg && isign4 != ENC_OPS2(Imm, Reg)) + goto InvalidInstruction; + + if (o0.as().valueAs() > 0x7FFFu) + goto InvalidImmediate; + + uint32_t imm = o0.as().valueAs(); + if ((imm & opData.immVerifyMask) != opData.immVerifyData) + goto InvalidImmediate; + + uint32_t rt = 31; + if (o1.isReg()) { + if (!o1.as().isGpX()) + goto InvalidInstruction; + + if (!checkGpId(o1, kZR)) + goto InvalidPhysId; + + rt = o1.id() & 31; + } + + opcode.reset(0b11010101000010000000000000000000); + opcode.addImm(imm, 5); + opcode.addReg(rt, 0); + goto EmitOp; + } + break; + } + + case InstDB::kEncodingBaseMrs: { + if (isign4 == ENC_OPS2(Reg, Imm)) { + if (!o0.as().isGpX()) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + if (o1.as().valueAs() > 0xFFFFu) + goto InvalidImmediate; + + uint32_t imm = o1.as().valueAs(); + if (!(imm & B(15))) + goto InvalidImmediate; + + opcode.reset(0b11010101001100000000000000000000); + opcode.addImm(imm, 5); + opcode.addReg(o0, 0); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseMsr: { + if (isign4 == ENC_OPS2(Imm, Reg)) { + if (!o1.as().isGpX()) + goto InvalidInstruction; + + if (o0.as().valueAs() > 0xFFFFu) + goto InvalidImmediate; + + uint32_t imm = o0.as().valueAs(); + if (!(imm & B(15))) + goto InvalidImmediate; + + if (!checkGpId(o1, kZR)) + goto InvalidPhysId; + + opcode.reset(0b11010101000100000000000000000000); + opcode.addImm(imm, 5); + opcode.addReg(o1, 0); + goto EmitOp; + } + + if (isign4 == ENC_OPS2(Imm, Imm)) { + if (o0.as().valueAs() > 0x1Fu) + goto InvalidImmediate; + + if (o1.as().valueAs() > 0xFu) + goto InvalidImmediate; + + uint32_t op = o0.as().valueAs(); + uint32_t cRm = o1.as().valueAs(); + + uint32_t op1 = uint32_t(op) >> 3; + uint32_t op2 = uint32_t(op) & 0x7u; + + opcode.reset(0b11010101000000000100000000011111); + opcode.addImm(op1, 16); + opcode.addImm(cRm, 8); + opcode.addImm(op2, 5); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseSys: { + if (isign4 == ENC_OPS4(Imm, Imm, Imm, Imm)) { + if (o0.as().valueAs() > 0x7u || + o1.as().valueAs() > 0xFu || + o2.as().valueAs() > 0xFu || + o3.as().valueAs() > 0x7u) + goto InvalidImmediate; + + uint32_t op1 = o0.as().valueAs(); + uint32_t cRn = o1.as().valueAs(); + uint32_t cRm = o2.as().valueAs(); + uint32_t op2 = o3.as().valueAs(); + uint32_t rt = 31; + + const Operand_& o4 = opExt[EmitterUtils::kOp4]; + if (o4.isReg()) { + if (!o4.as().isGpX()) + goto InvalidInstruction; + + if (!checkGpId(o4, kZR)) + goto InvalidPhysId; + + rt = o4.id() & 31; + } + else if (!o4.isNone()) { + goto InvalidInstruction; + } + + opcode.reset(0b11010101000010000000000000000000); + opcode.addImm(op1, 16); + opcode.addImm(cRn, 12); + opcode.addImm(cRm, 8); + opcode.addImm(op2, 5); + opcode.addImm(rt, 0); + goto EmitOp; + } + + break; + } + + // ------------------------------------------------------------------------ + // [Base - Branch] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseBranchReg: { + const InstDB::EncodingData::BaseBranchReg& opData = InstDB::EncodingData::baseBranchReg[encodingIndex]; + + if (isign4 == ENC_OPS1(Reg)) { + if (!o0.as().isGpX()) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.opcode); + opcode.addReg(o0, 5); + goto EmitOp; + } + + break; + } + + case InstDB::kEncodingBaseBranchRel: { + const InstDB::EncodingData::BaseBranchRel& opData = InstDB::EncodingData::baseBranchRel[encodingIndex]; + + if (isign4 == ENC_OPS1(Label) || isign4 == ENC_OPS1(Imm)) { + opcode.reset(opData.opcode); + rmRel = &o0; + + if (instCC != CondCode::kAL) { + opcode |= B(30); + opcode.addImm(condCodeToOpcodeCond(uint32_t(instCC)), 0); + offsetFormat.resetToImmValue(OffsetType::kSignedOffset, 4, 5, 19, 2); + goto EmitOp_Rel; + } + + offsetFormat.resetToImmValue(OffsetType::kSignedOffset, 4, 0, 26, 2); + goto EmitOp_Rel; + } + + break; + } + + case InstDB::kEncodingBaseBranchCmp: { + const InstDB::EncodingData::BaseBranchCmp& opData = InstDB::EncodingData::baseBranchCmp[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Label) || isign4 == ENC_OPS2(Reg, Imm)) { + uint32_t x; + if (!checkGpType(o0, kWX, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.opcode); + opcode.addImm(x, 31); + opcode.addReg(o0, 0); + offsetFormat.resetToImmValue(OffsetType::kSignedOffset, 4, 5, 19, 2); + + rmRel = &o1; + goto EmitOp_Rel; + } + + break; + } + + case InstDB::kEncodingBaseBranchTst: { + const InstDB::EncodingData::BaseBranchTst& opData = InstDB::EncodingData::baseBranchTst[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Imm, Label) || isign4 == ENC_OPS3(Reg, Imm, Imm)) { + uint32_t x; + if (!checkGpType(o0, kWX, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + uint64_t imm = o1.as().valueAs(); + + opcode.reset(opData.opcode); + if (imm >= 32) { + if (!x) + goto InvalidImmediate; + opcode.addImm(x, 31); + imm &= 0x1F; + } + + opcode.addReg(o0, 0); + opcode.addImm(imm, 19); + offsetFormat.resetToImmValue(OffsetType::kSignedOffset, 4, 5, 14, 2); + + rmRel = &o2; + goto EmitOp_Rel; + } + + break; + } + + // ------------------------------------------------------------------------ + // [Base - Load / Store] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingBaseLdSt: { + const InstDB::EncodingData::BaseLdSt& opData = InstDB::EncodingData::baseLdSt[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Mem)) { + const Mem& m = o1.as(); + rmRel = &m; + + uint32_t x; + if (!checkGpType(o0, opData.rType, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + // Instructions that work with either word or dword have the unsigned + // offset shift set to 2 (word), so we set it to 3 (dword) if this is + // X version of the instruction. + uint32_t xShiftMask = uint32_t(opData.uOffsetShift == 2); + uint32_t immShift = uint32_t(opData.uOffsetShift) + (x & xShiftMask); + + if (!armCheckMemBaseIndexRel(m)) + goto InvalidAddress; + + int64_t offset = m.offset(); + if (m.hasBaseReg()) { + // [Base {Offset | Index}] + if (m.hasIndex()) { + uint32_t opt = armShiftOpToLdStOptMap[m.predicate()]; + if (opt == 0xFF) + goto InvalidAddress; + + uint32_t shift = m.shift(); + uint32_t s = shift != 0; + + if (s && shift != immShift) + goto InvalidAddressScale; + + opcode.reset(uint32_t(opData.registerOp) << 21); + opcode.xorImm(x, opData.xOffset); + opcode.addImm(opt, 13); + opcode.addImm(s, 12); + opcode |= B(11); + opcode.addReg(o0, 0); + goto EmitOp_MemBaseIndex_Rn5_Rm16; + } + + // Makes it easier to work with the offset especially on 32-bit arch. + if (!Support::isInt32(offset)) + goto InvalidDisplacement; + int32_t offset32 = int32_t(offset); + + if (m.isPreOrPost()) { + if (!Support::isInt9(offset32)) + goto InvalidDisplacement; + + opcode.reset(uint32_t(opData.prePostOp) << 21); + opcode.xorImm(x, opData.xOffset); + opcode.addImm(offset32 & 0x1FF, 12); + opcode.addImm(m.isPreIndex(), 11); + opcode |= B(10); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + else { + uint32_t imm12 = uint32_t(offset32) >> immShift; + + // Alternative form of LDUR/STUR and related instructions as described by AArch64 reference manual: + // + // If this instruction is not encodable with scaled unsigned offset, try unscaled signed offset. + if (!Support::isUInt12(imm12) || (imm12 << immShift) != uint32_t(offset32)) { + instId = opData.uAltInstId; + instInfo = &InstDB::_instInfoTable[instId]; + encodingIndex = instInfo->_encodingDataIndex; + goto Case_BaseLdurStur; + } + + opcode.reset(uint32_t(opData.uOffsetOp) << 22); + opcode.xorImm(x, opData.xOffset); + opcode.addImm(imm12, 10); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + } + else { + if (!opData.literalOp) + goto InvalidAddress; + + opcode.reset(uint32_t(opData.literalOp) << 24); + opcode.xorImm(x, opData.xOffset); + opcode.addReg(o0, 0); + offsetFormat.resetToImmValue(OffsetType::kSignedOffset, 4, 5, 19, 2); + goto EmitOp_Rel; + } + } + + break; + } + + case InstDB::kEncodingBaseLdpStp: { + const InstDB::EncodingData::BaseLdpStp& opData = InstDB::EncodingData::baseLdpStp[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Mem)) { + const Mem& m = o2.as(); + rmRel = &m; + + uint32_t x; + if (!checkGpType(o0, o1, opData.rType, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + if (m.baseType() != RegType::kARM_GpX || m.hasIndex()) + goto InvalidAddress; + + if (m.isOffset64Bit()) + goto InvalidDisplacement; + + uint32_t offsetShift = opData.offsetShift + x; + int32_t offset32 = m.offsetLo32() >> offsetShift; + + // Make sure we didn't lose bits by applying the mandatory offset shift. + if (uint32_t(offset32) << offsetShift != uint32_t(m.offsetLo32())) + goto InvalidDisplacement; + + // Offset is encoded as 7-bit immediate. + if (!Support::isInt7(offset32)) + goto InvalidDisplacement; + + if (m.isPreOrPost() && offset32 != 0) { + if (!opData.prePostOp) + goto InvalidAddress; + + opcode.reset(uint32_t(opData.prePostOp) << 22); + opcode.addImm(m.isPreIndex(), 24); + } + else { + opcode.reset(uint32_t(opData.offsetOp) << 22); + } + + opcode.addImm(x, opData.xOffset); + opcode.addImm(offset32 & 0x7F, 15); + opcode.addReg(o1, 10); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + + break; + } + + case InstDB::kEncodingBaseStx: { + const InstDB::EncodingData::BaseStx& opData = InstDB::EncodingData::baseStx[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Mem)) { + const Mem& m = o2.as(); + uint32_t x; + + if (!o0.as().isGpW() || !checkGpType(o1, opData.rType, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, opData.xOffset); + opcode.addReg(o0, 16); + opcode.addReg(o1, 0); + + rmRel = &m; + goto EmitOp_MemBaseNoImm_Rn5; + } + + break; + } + + case InstDB::kEncodingBaseLdxp: { + const InstDB::EncodingData::BaseLdxp& opData = InstDB::EncodingData::baseLdxp[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Mem)) { + const Mem& m = o2.as(); + uint32_t x; + + if (!checkGpType(o0, opData.rType, &x) || !checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, opData.xOffset); + opcode.addReg(o1, 10); + opcode.addReg(o0, 0); + + rmRel = &m; + goto EmitOp_MemBaseNoImm_Rn5; + } + + break; + } + + case InstDB::kEncodingBaseStxp: { + const InstDB::EncodingData::BaseStxp& opData = InstDB::EncodingData::baseStxp[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Mem)) { + const Mem& m = o3.as(); + uint32_t x; + + if (!o0.as().isGpW() || !checkGpType(o1, opData.rType, &x) || !checkSignature(o1, o2)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, o2, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, opData.xOffset); + opcode.addReg(o0, 16); + opcode.addReg(o2, 10); + opcode.addReg(o1, 0); + + rmRel = &m; + goto EmitOp_MemBaseNoImm_Rn5; + } + + break; + } + + case InstDB::kEncodingBaseRM_NoImm: { + const InstDB::EncodingData::BaseRM_NoImm& opData = InstDB::EncodingData::baseRM_NoImm[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Mem)) { + const Mem& m = o1.as(); + rmRel = &m; + + uint32_t x; + if (!checkGpType(o0, opData.rType, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, opData.rHiId)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, opData.xOffset); + opcode.addReg(o0, 0); + goto EmitOp_MemBaseNoImm_Rn5; + } + + break; + } + + case InstDB::kEncodingBaseRM_SImm9: { +Case_BaseLdurStur: + const InstDB::EncodingData::BaseRM_SImm9& opData = InstDB::EncodingData::baseRM_SImm9[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Mem)) { + const Mem& m = o1.as(); + rmRel = &m; + + uint32_t x; + if (!checkGpType(o0, opData.rType, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, opData.rHiId)) + goto InvalidPhysId; + + if (m.hasBaseReg() && !m.hasIndex()) { + if (m.isOffset64Bit()) + goto InvalidDisplacement; + + int32_t offset32 = m.offsetLo32() >> opData.immShift; + if (Support::shl(offset32, opData.immShift) != m.offsetLo32()) + goto InvalidDisplacement; + + if (!Support::isInt9(offset32)) + goto InvalidDisplacement; + + if (m.isFixedOffset()) { + opcode.reset(opData.offsetOp()); + } + else { + if (!opData.prePostOp()) + goto InvalidInstruction; + + opcode.reset(opData.prePostOp()); + opcode.xorImm(m.isPreIndex(), 11); + } + + opcode.xorImm(x, opData.xOffset); + opcode.addImm(offset32 & 0x1FF, 12); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + + goto InvalidAddress; + } + + break; + } + + case InstDB::kEncodingBaseRM_SImm10: { + const InstDB::EncodingData::BaseRM_SImm10& opData = InstDB::EncodingData::baseRM_SImm10[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Mem)) { + const Mem& m = o1.as(); + rmRel = &m; + + uint32_t x; + if (!checkGpType(o0, opData.rType, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, opData.rHiId)) + goto InvalidPhysId; + + if (m.hasBaseReg() && !m.hasIndex()) { + if (m.isOffset64Bit()) + goto InvalidDisplacement; + + int32_t offset32 = m.offsetLo32() >> opData.immShift; + if (Support::shl(offset32, opData.immShift) != m.offsetLo32()) + goto InvalidDisplacement; + + if (!Support::isInt10(offset32)) + goto InvalidDisplacement; + + if (m.isPostIndex()) + goto InvalidAddress; + + // Offset has 10 bits, sign is stored in the 10th bit. + offset32 &= 0x3FF; + + opcode.reset(opData.opcode()); + opcode.xorImm(m.isPreIndex(), 11); + opcode.xorImm(x, opData.xOffset); + opcode.addImm(offset32 >> 9, 22); + opcode.addImm(offset32, 12); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + + goto InvalidAddress; + } + + break; + } + + case InstDB::kEncodingBaseAtomicOp: { + const InstDB::EncodingData::BaseAtomicOp& opData = InstDB::EncodingData::baseAtomicOp[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Mem)) { + const Mem& m = o2.as(); + uint32_t x; + + if (!checkGpType(o0, opData.rType, &x) || !checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkGpId(o0, o1, kZR)) + goto InvalidInstruction; + + opcode.reset(opData.opcode()); + opcode.addImm(x, opData.xOffset); + opcode.addReg(o0, 16); + opcode.addReg(o1, 0); + + rmRel = &m; + goto EmitOp_MemBaseNoImm_Rn5; + } + + break; + } + + case InstDB::kEncodingBaseAtomicSt: { + const InstDB::EncodingData::BaseAtomicSt& opData = InstDB::EncodingData::baseAtomicSt[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Mem)) { + const Mem& m = o1.as(); + uint32_t x; + + if (!checkGpType(o0, opData.rType, &x)) + goto InvalidInstruction; + + if (!checkGpId(o0, kZR)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, opData.xOffset); + opcode.addReg(o0, 16); + opcode.addReg(Gp::kIdZr, 0); + + rmRel = &m; + goto EmitOp_MemBaseNoImm_Rn5; + } + + break; + } + + case InstDB::kEncodingBaseAtomicCasp: { + const InstDB::EncodingData::BaseAtomicCasp& opData = InstDB::EncodingData::baseAtomicCasp[encodingIndex]; + const Operand_& o4 = opExt[EmitterUtils::kOp4]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Reg) && o4.isMem()) { + const Mem& m = o4.as(); + uint32_t x; + + if (!checkGpType(o0, opData.rType, &x)) + goto InvalidInstruction; + + if (!checkSignature(o0, o1, o2, o3)) + goto InvalidInstruction; + + if (!checkEven(o0, o2) || !checkGpId(o0, o2, kZR)) + goto InvalidPhysId; + + if (!checkConsecutive(o0, o1) || !checkConsecutive(o2, o3)) + goto InvalidPhysId; + + opcode.reset(opData.opcode()); + opcode.addImm(x, opData.xOffset); + opcode.addReg(o0, 16); + opcode.addReg(o2, 0); + + rmRel = &m; + goto EmitOp_MemBaseNoImm_Rn5; + } + + break; + } + + // ------------------------------------------------------------------------ + // [FSimd - Instructions] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingFSimdSV: { + const InstDB::EncodingData::FSimdSV& opData = InstDB::EncodingData::fSimdSV[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + uint32_t q = diff(o1.as().type(), RegType::kARM_VecD); + if (q > 1) + goto InvalidInstruction; + + if (o0.as().hasElementType()) + goto InvalidInstruction; + + // This operation is only defined for: + // hD, vS.{4|8}h (16-bit) + // sD, vS.4s (32-bit) + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecH); + uint32_t elementSz = o1.as().elementType() - Vec::kElementTypeH; + + // Size greater than 1 means 64-bit elements, not supported. + if ((sz | elementSz) > 1 || sz != elementSz) + goto InvalidInstruction; + + // Size 1 (32-bit float) requires at least 4 elements. + if (sz && !q) + goto InvalidInstruction; + + // Bit flipping according to sz. + static const uint32_t szBits[] = { B(29), 0 }; + + opcode.reset(opData.opcode << 10); + opcode ^= szBits[sz]; + opcode.addImm(q, 30); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingFSimdVV: { + const InstDB::EncodingData::FSimdVV& opData = InstDB::EncodingData::fSimdVV[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + if (!matchSignature(o0, o1, instFlags)) + goto InvalidInstruction; + + if (!pickFpOpcode(o0.as(), opData.scalarOp(), opData.scalarHf(), opData.vectorOp(), opData.vectorHf(), &opcode)) + goto InvalidInstruction; + + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingFSimdVVV: { + const InstDB::EncodingData::FSimdVVV& opData = InstDB::EncodingData::fSimdVVV[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + if (!pickFpOpcode(o0.as(), opData.scalarOp(), opData.scalarHf(), opData.vectorOp(), opData.vectorHf(), &opcode)) + goto InvalidInstruction; + + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingFSimdVVVe: { + const InstDB::EncodingData::FSimdVVVe& opData = InstDB::EncodingData::fSimdVVVe[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (!o2.as().hasElementIndex()) { + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + if (!pickFpOpcode(o0.as(), opData.scalarOp(), opData.scalarHf(), opData.vectorOp(), opData.vectorHf(), &opcode)) + goto InvalidInstruction; + + goto EmitOp_Rd0_Rn5_Rm16; + } + else { + if (!matchSignature(o0, o1, instFlags)) + goto InvalidInstruction; + + uint32_t q = o1.as().isVecQ(); + uint32_t sz; + + if (!pickFpOpcode(o0.as(), opData.elementScalarOp(), InstDB::kHF_D, opData.elementVectorOp(), InstDB::kHF_D, &opcode, &sz)) + goto InvalidInstruction; + + if (sz == 0 && o2.as().id() > 15) + goto InvalidPhysId; + + uint32_t elementIndex = o2.as().elementIndex(); + if (elementIndex > (7u >> sz)) + goto InvalidElementIndex; + + uint32_t hlm = elementIndex << sz; + opcode.addImm(q, 30); + opcode.addImm(hlm & 3u, 20); + opcode.addImm(hlm >> 2, 11); + goto EmitOp_Rd0_Rn5_Rm16; + } + } + + break; + } + + case InstDB::kEncodingFSimdVVVV: { + const InstDB::EncodingData::FSimdVVVV& opData = InstDB::EncodingData::fSimdVVVV[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Reg)) { + if (!matchSignature(o0, o1, o2, o3, instFlags)) + goto InvalidInstruction; + + if (!pickFpOpcode(o0.as(), opData.scalarOp(), opData.scalarHf(), opData.vectorOp(), opData.vectorHf(), &opcode)) + goto InvalidInstruction; + + goto EmitOp_Rd0_Rn5_Rm16_Ra10; + } + + break; + } + + case InstDB::kEncodingSimdFcadd: { + const InstDB::EncodingData::SimdFcadd& opData = InstDB::EncodingData::simdFcadd[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + if (!checkSignature(o0, o1, o2) || o0.as().hasElementIndex()) + goto InvalidInstruction; + + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + if (q > 1) + goto InvalidInstruction; + + uint32_t sz = o0.as().elementType() - Vec::kElementTypeB; + if (sz == 0 || sz > 3) + goto InvalidInstruction; + + // 0 <- 90deg. + // 1 <- 270deg. + uint32_t rot = 0; + if (o3.as().value() == 270) + rot = 1; + else if (o3.as().value() != 90) + goto InvalidImmediate; + + opcode.reset(opData.opcode()); + opcode.addImm(q, 30); + opcode.addImm(sz, 22); + opcode.addImm(rot, 12); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingSimdFccmpFccmpe: { + const InstDB::EncodingData::SimdFccmpFccmpe& opData = InstDB::EncodingData::simdFccmpFccmpe[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Imm, Imm)) { + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecH); + if (sz > 2) + goto InvalidInstruction; + + if (!checkSignature(o0, o1) || o0.as().hasElementType()) + goto InvalidInstruction; + + uint64_t nzcv = o2.as().valueAs(); + uint64_t cond = o3.as().valueAs(); + + if ((nzcv | cond) > 0xFu) + goto InvalidImmediate; + + uint32_t type = (sz - 1) & 0x3u; + + opcode.reset(opData.opcode()); + opcode.addImm(type, 22); + opcode.addImm(condCodeToOpcodeCond(uint32_t(cond)), 12); + opcode.addImm(nzcv, 0); + + goto EmitOp_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingSimdFcm: { + const InstDB::EncodingData::SimdFcm& opData = InstDB::EncodingData::simdFcm[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg) && opData.hasRegisterOp()) { + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + if (!pickFpOpcode(o0.as(), opData.registerScalarOp(), opData.registerScalarHf(), opData.registerVectorOp(), opData.registerVectorHf(), &opcode)) + goto InvalidInstruction; + + goto EmitOp_Rd0_Rn5_Rm16; + } + + if (isign4 == ENC_OPS3(Reg, Reg, Imm) && opData.hasZeroOp()) { + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (o2.as().value() != 0 || o2.as().predicate() != 0) + goto InvalidImmediate; + + if (!pickFpOpcode(o0.as(), opData.zeroScalarOp(), InstDB::kHF_B, opData.zeroVectorOp(), InstDB::kHF_B, &opcode)) + goto InvalidInstruction; + + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingSimdFcmla: { + const InstDB::EncodingData::SimdFcmla& opData = InstDB::EncodingData::simdFcmla[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + if (q > 1) + goto InvalidInstruction; + + uint32_t sz = o0.as().elementType() - Vec::kElementTypeB; + if (sz == 0 || sz > 3) + goto InvalidInstruction; + + uint32_t rot = 0; + switch (o3.as().value()) { + case 0 : rot = 0; break; + case 90 : rot = 1; break; + case 180: rot = 2; break; + case 270: rot = 3; break; + default: + goto InvalidImmediate; + } + + if (!o2.as().hasElementIndex()) { + if (!checkSignature(o1, o2)) + goto InvalidInstruction; + + opcode.reset(opData.regularOp()); + opcode.addImm(q, 30); + opcode.addImm(sz, 22); + opcode.addImm(rot, 11); + goto EmitOp_Rd0_Rn5_Rm16; + } + else { + if (o0.as().elementType() != o2.as().elementType()) + goto InvalidInstruction; + + // Only allowed vectors are: 4H, 8H, and 4S. + if (!(sz == 1 || (q == 1 && sz == 2))) + goto InvalidInstruction; + + // Element index ranges: + // 4H - ElementIndex[0..1] (index 2..3 is UNDEFINED). + // 8H - ElementIndex[0..3]. + // 4S - ElementIndex[0..1]. + uint32_t elementIndex = o2.as().elementIndex(); + uint32_t hlFieldShift = sz == 1 ? 0u : 1u; + uint32_t maxElementIndex = q == 1 && sz == 1 ? 3u : 1u; + + if (elementIndex > maxElementIndex) + goto InvalidElementIndex; + + uint32_t hl = elementIndex << hlFieldShift; + + opcode.reset(opData.elementOp()); + opcode.addImm(q, 30); + opcode.addImm(sz, 22); + opcode.addImm(hl & 1u, 21); // L field. + opcode.addImm(hl >> 1, 11); // H field. + opcode.addImm(rot, 13); + goto EmitOp_Rd0_Rn5_Rm16; + } + } + + break; + } + + case InstDB::kEncodingSimdFcmpFcmpe: { + const InstDB::EncodingData::SimdFcmpFcmpe& opData = InstDB::EncodingData::simdFcmpFcmpe[encodingIndex]; + + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecH); + uint32_t type = (sz - 1) & 0x3u; + + if (sz > 2) + goto InvalidInstruction; + + if (o0.as().hasElementType()) + goto InvalidInstruction; + + opcode.reset(opData.opcode()); + opcode.addImm(type, 22); + + if (isign4 == ENC_OPS2(Reg, Reg)) { + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + goto EmitOp_Rn5_Rm16; + } + + if (isign4 == ENC_OPS2(Reg, Imm)) { + if (o1.as().value() != 0 || o1.as().predicate() != 0) + goto InvalidInstruction; + + opcode |= B(3); + goto EmitOp_Rn5; + } + + break; + } + + case InstDB::kEncodingSimdFcsel: { + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + if (!checkSignature(o0, o1, o2)) + goto InvalidInstruction; + + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecH); + uint32_t type = (sz - 1) & 0x3u; + + if (sz > 2 || o0.as().hasElementType()) + goto InvalidInstruction; + + uint64_t cond = o3.as().valueAs(); + if (cond > 0xFu) + goto InvalidImmediate; + + opcode.reset(0b00011110001000000000110000000000); + opcode.addImm(type, 22); + opcode.addImm(condCodeToOpcodeCond(uint32_t(cond)), 12); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingSimdFcvt: { + if (isign4 == ENC_OPS2(Reg, Reg)) { + uint32_t dstSz = diff(o0.as().type(), RegType::kARM_VecH); + uint32_t srcSz = diff(o1.as().type(), RegType::kARM_VecH); + + if ((dstSz | srcSz) > 3) + goto InvalidInstruction; + + if (o0.as().hasElementType() || o1.as().hasElementType()) + goto InvalidInstruction; + + // Table that provides 'type' and 'opc' according to the dst/src combination. + static const uint8_t table[] = { + 0xFFu, // H <- H (Invalid). + 0x03u, // H <- S (type=00 opc=11). + 0x13u, // H <- D (type=01 opc=11). + 0xFFu, // H <- Q (Invalid). + 0x30u, // S <- H (type=11 opc=00). + 0xFFu, // S <- S (Invalid). + 0x10u, // S <- D (type=01 opc=00). + 0xFFu, // S <- Q (Invalid). + 0x31u, // D <- H (type=11 opc=01). + 0x01u, // D <- S (type=00 opc=01). + 0xFFu, // D <- D (Invalid). + 0xFFu, // D <- Q (Invalid). + 0xFFu, // Q <- H (Invalid). + 0xFFu, // Q <- S (Invalid). + 0xFFu, // Q <- D (Invalid). + 0xFFu // Q <- Q (Invalid). + }; + + uint32_t typeOpc = table[(dstSz << 2) | srcSz]; + opcode.reset(0b0001111000100010010000 << 10); + opcode.addImm(typeOpc >> 4, 22); + opcode.addImm(typeOpc & 15, 15); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingSimdFcvtLN: { + const InstDB::EncodingData::SimdFcvtLN& opData = InstDB::EncodingData::simdFcvtLN[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + // Scalar form - only FCVTXN. + if (o0.as().isVecS() && o1.as().isVecD()) { + if (!opData.hasScalar()) + goto InvalidInstruction; + + if (o0.as().hasElementType() || o1.as().hasElementType()) + goto InvalidInstruction; + + opcode.reset(opData.scalarOp()); + opcode |= B(22); // sz bit must be 1, the only supported combination of FCVTXN. + goto EmitOp_Rd0_Rn5; + } + + opcode.reset(opData.vectorOp()); + + const Vec& rL = (instFlags & InstDB::kInstFlagLong) ? o0.as() : o1.as(); + const Vec& rN = (instFlags & InstDB::kInstFlagLong) ? o1.as() : o0.as(); + + uint32_t q = diff(rN.type(), RegType::kARM_VecD); + if (uint32_t(opcode.hasQ()) != q) + goto InvalidInstruction; + + if (rL.isVecS4() && rN.elementType() == Vec::kElementTypeH && !opData.isCvtxn()) { + goto EmitOp_Rd0_Rn5; + } + + if (rL.isVecD2() && rN.elementType() == Vec::kElementTypeS) { + opcode |= B(22); + goto EmitOp_Rd0_Rn5; + } + } + + break; + } + + case InstDB::kEncodingSimdFcvtSV: { + const InstDB::EncodingData::SimdFcvtSV& opData = InstDB::EncodingData::simdFcvtSV[encodingIndex]; + + // So we can support both IntToFloat and FloatToInt conversions. + const Operand_& oGp = opData.isFloatToInt() ? o0 : o1; + const Operand_& oVec = opData.isFloatToInt() ? o1 : o0; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + if (oGp.as().isGp() && oVec.as().isVec()) { + uint32_t x = oGp.as().isGpX(); + uint32_t type = diff(oVec.as().type(), RegType::kARM_VecH); + + if (type > 2u) + goto InvalidInstruction; + + type = (type - 1u) & 0x3; + opcode.reset(opData.generalOp()); + opcode.addImm(type, 22); + opcode.addImm(x, 31); + goto EmitOp_Rd0_Rn5; + } + + if (o0.as().isVec() && o1.as().isVec()) { + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!pickFpOpcode(o0.as(), opData.scalarIntOp(), InstDB::kHF_B, opData.vectorIntOp(), InstDB::kHF_B, &opcode)) + goto InvalidInstruction; + + goto EmitOp_Rd0_Rn5; + } + } + + if (isign4 == ENC_OPS3(Reg, Reg, Imm) && opData.isFixedPoint()) { + if (o2.as().valueAs() >= 64) + goto InvalidInstruction; + + uint32_t scale = o2.as().valueAs(); + if (scale == 0) + goto InvalidInstruction; + + if (oGp.as().isGp() && oVec.as().isVec()) { + uint32_t x = oGp.as().isGpX(); + uint32_t type = diff(oVec.as().type(), RegType::kARM_VecH); + + uint32_t scaleLimit = 32u << x; + if (scale > scaleLimit) + goto InvalidInstruction; + + type = (type - 1u) & 0x3; + opcode.reset(opData.generalOp() ^ B(21)); + opcode.addImm(type, 22); + opcode.addImm(x, 31); + opcode.addImm(64u - scale, 10); + goto EmitOp_Rd0_Rn5; + } + + if (o0.as().isVec() && o1.as().isVec()) { + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + uint32_t sz; + if (!pickFpOpcode(o0.as(), opData.scalarFpOp(), InstDB::kHF_0, opData.vectorFpOp(), InstDB::kHF_0, &opcode, &sz)) + goto InvalidInstruction; + + uint32_t scaleLimit = 16u << sz; + if (scale > scaleLimit) + goto InvalidInstruction; + + uint32_t imm = Support::neg(scale) & Support::lsbMask(sz + 4 + 1); + opcode.addImm(imm, 16); + goto EmitOp_Rd0_Rn5; + } + } + + break; + } + + case InstDB::kEncodingSimdFmlal: { + const InstDB::EncodingData::SimdFmlal& opData = InstDB::EncodingData::simdFmlal[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + uint32_t qIsOptional = opData.optionalQ(); + + if (qIsOptional) { + // This instruction works with either 64-bit or 128-bit registers, + // encoded by Q bit. + if (q > 1) + goto InvalidInstruction; + } + else { + // This instruction requires 128-bit vector registers. + if (q != 1) + goto InvalidInstruction; + + // The instruction is ehtier B (bottom) or T (top), which is part of + // the opcode, which uses Q bit, so we have to clear it explicitly. + q = 0; + } + + if (uint32_t(o0.as().type()) != uint32_t(o1.as().type()) + qIsOptional || + o0.as().elementType() != opData.tA || + o1.as().elementType() != opData.tB) + goto InvalidInstruction; + + if (!o2.as().hasElementIndex()) { + if (!checkSignature(o1, o2)) + goto InvalidInstruction; + + opcode.reset(opData.vectorOp()); + opcode.addImm(q, 30); + goto EmitOp_Rd0_Rn5_Rm16; + } + else { + if (o2.as().elementType() != opData.tElement) + goto InvalidInstruction; + + if (o2.as().id() > 15) + goto InvalidPhysId; + + uint32_t elementIndex = o2.as().elementIndex(); + if (elementIndex > 7u) + goto InvalidElementIndex; + + opcode.reset(opData.elementOp()); + opcode.addImm(q, 30); + opcode.addImm(elementIndex & 3u, 20); + opcode.addImm(elementIndex >> 2, 11); + goto EmitOp_Rd0_Rn5_Rm16; + } + } + + break; + } + + case InstDB::kEncodingSimdFmov: { + if (isign4 == ENC_OPS2(Reg, Reg)) { + // FMOV Gp <-> Vec opcode: + opcode.reset(0b00011110001001100000000000000000); + + if (o0.as().isGp() && o1.as().isVec()) { + // FMOV Wd, Hn (sf=0 type=11 rmode=00 op=110) + // FMOV Xd, Hn (sf=1 type=11 rmode=00 op=110) + // FMOV Wd, Sn (sf=0 type=00 rmode=00 op=110) + // FMOV Xd, Dn (sf=1 type=11 rmode=00 op=110) + // FMOV Xd, Vn.d[1] (sf=1 type=10 rmode=01 op=110) + uint32_t x = o0.as().isGpX(); + uint32_t sz = diff(o1.as().type(), RegType::kARM_VecH); + + uint32_t type = (sz - 1) & 0x3u; + uint32_t rModeOp = 0b00110; + + if (o1.as().hasElementIndex()) { + // Special case. + if (!x || !o1.as().isVecD2() || o1.as().elementIndex() != 1) + goto InvalidInstruction; + type = 0b10; + rModeOp = 0b01110; + } + else { + // Must be scalar. + if (sz > 2) + goto InvalidInstruction; + + if (o1.as().hasElementType()) + goto InvalidInstruction; + + if (o1.as().isVecS() && x) + goto InvalidInstruction; + + if (o1.as().isVecD() && !x) + goto InvalidInstruction; + } + + opcode.addImm(x, 31); + opcode.addImm(type, 22); + opcode.addImm(rModeOp, 16); + goto EmitOp_Rd0_Rn5; + } + + if (o0.as().isVec() && o1.as().isGp()) { + // FMOV Hd, Wn (sf=0 type=11 rmode=00 op=111) + // FMOV Hd, Xn (sf=1 type=11 rmode=00 op=111) + // FMOV Sd, Wn (sf=0 type=00 rmode=00 op=111) + // FMOV Dd, Xn (sf=1 type=11 rmode=00 op=111) + // FMOV Vd.d[1], Xn (sf=1 type=10 rmode=01 op=111) + uint32_t x = o1.as().isGpX(); + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecH); + + uint32_t type = (sz - 1) & 0x3u; + uint32_t rModeOp = 0b00111; + + if (o0.as().hasElementIndex()) { + // Special case. + if (!x || !o0.as().isVecD2() || o0.as().elementIndex() != 1) + goto InvalidInstruction; + type = 0b10; + rModeOp = 0b01111; + } + else { + // Must be scalar. + if (sz > 2) + goto InvalidInstruction; + + if (o0.as().hasElementType()) + goto InvalidInstruction; + + if (o0.as().isVecS() && x) + goto InvalidInstruction; + + if (o0.as().isVecD() && !x) + goto InvalidInstruction; + } + + opcode.addImm(x, 31); + opcode.addImm(type, 22); + opcode.addImm(rModeOp, 16); + goto EmitOp_Rd0_Rn5; + } + + if (checkSignature(o0, o1)) { + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecH); + if (sz > 2) + goto InvalidInstruction; + + if (o0.as().hasElementType()) + goto InvalidInstruction; + + uint32_t type = (sz - 1) & 0x3; + opcode.reset(0b00011110001000000100000000000000); + opcode.addImm(type, 22); + goto EmitOp_Rd0_Rn5; + } + } + + if (isign4 == ENC_OPS2(Reg, Imm)) { + if (o0.as().isVec()) { + double fpValue; + if (o1.as().isDouble()) + fpValue = o1.as().valueAs(); + else if (o1.as().isInt32()) + fpValue = o1.as().valueAs(); + else + goto InvalidImmediate; + + if (!Utils::isFP64Imm8(fpValue)) + goto InvalidImmediate; + + uint32_t imm8 = Utils::encodeFP64ToImm8(fpValue); + if (!o0.as().hasElementType()) { + // FMOV (scalar, immediate). + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecH); + uint32_t type = (sz - 1u) & 0x3u; + + if (sz > 2) + goto InvalidInstruction; + + opcode.reset(0b00011110001000000001000000000000); + opcode.addImm(type, 22); + opcode.addImm(imm8, 13); + goto EmitOp_Rd0; + } + else { + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + uint32_t sz = o0.as().elementType() - Vec::kElementTypeH; + + if (q > 1 || sz > 2) + goto InvalidInstruction; + + static const uint32_t szBits[3] = { B(11), B(0), B(29) }; + opcode.reset(0b00001111000000001111010000000000); + opcode ^= szBits[sz]; + opcode.addImm(q, 30); + opcode.addImm(imm8 >> 5, 16); + opcode.addImm(imm8 & 31, 5); + goto EmitOp_Rd0; + } + } + } + + break; + } + + case InstDB::kEncodingFSimdPair: { + const InstDB::EncodingData::FSimdPair& opData = InstDB::EncodingData::fSimdPair[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + // This operation is only defined for: + // hD, vS.2h (16-bit) + // sD, vS.2s (32-bit) + // dD, vS.2d (64-bit) + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecH); + if (sz > 2) + goto InvalidInstruction; + + static const uint32_t szSignatures[3] = { + VecS::kSignature | (Vec::kSignatureElementH), + VecD::kSignature | (Vec::kSignatureElementS), + VecV::kSignature | (Vec::kSignatureElementD) + }; + + if (o1.signature() != szSignatures[sz]) + goto InvalidInstruction; + + static const uint32_t szBits[] = { B(29), 0, B(22) }; + opcode.reset(opData.scalarOp()); + opcode ^= szBits[sz]; + goto EmitOp_Rd0_Rn5; + } + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (!checkSignature(o0, o1, o2)) + goto InvalidInstruction; + + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + if (q > 1) + goto InvalidInstruction; + + uint32_t sz = o0.as().elementType() - Vec::kElementTypeH; + if (sz > 2) + goto InvalidInstruction; + + static const uint32_t szBits[3] = { B(22) | B(21) | B(15) | B(14), 0, B(22) }; + opcode.reset(opData.vectorOp()); + opcode ^= szBits[sz]; + opcode.addImm(q, 30); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + // ------------------------------------------------------------------------ + // [ISimd - Instructions] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingISimdSV: { + const InstDB::EncodingData::ISimdSV& opData = InstDB::EncodingData::iSimdSV[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + // The first destination operand is scalar, which matches element-type of source vectors. + uint32_t L = (instFlags & InstDB::kInstFlagLong) != 0; + if (diff(o0.as().type(), RegType::kARM_VecB) != o1.as().elementType() - Vec::kElementTypeB + L) + goto InvalidInstruction; + + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, o1.as().type(), o1.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + opcode.reset(opData.opcode()); + opcode.addImm(sizeOp.q(), 30); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingISimdVV: { + const InstDB::EncodingData::ISimdVV& opData = InstDB::EncodingData::iSimdVV[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + const Operand_& sop = significantSimdOp(o0, o1, instFlags); + if (!matchSignature(o0, o1, instFlags)) + goto InvalidInstruction; + + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, sop.as().type(), sop.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + opcode.reset(opData.opcode()); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingISimdVVx: { + const InstDB::EncodingData::ISimdVVx& opData = InstDB::EncodingData::iSimdVVx[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + if (o0.signature() != opData.op0Signature || + o1.signature() != opData.op1Signature) + goto InvalidInstruction; + + opcode.reset(opData.opcode()); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingISimdVVV: { + const InstDB::EncodingData::ISimdVVV& opData = InstDB::EncodingData::iSimdVVV[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + const Operand_& sop = significantSimdOp(o0, o1, instFlags); + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, sop.as().type(), sop.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + opcode.reset(opData.opcode()); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingISimdVVVx: { + const InstDB::EncodingData::ISimdVVVx& opData = InstDB::EncodingData::iSimdVVVx[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (o0.signature() != opData.op0Signature || + o1.signature() != opData.op1Signature || + o2.signature() != opData.op2Signature) + goto InvalidInstruction; + + opcode.reset(opData.opcode()); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingISimdWWV: { + // Special case for wide add/sub [s|b][add|sub][w]{2}. + const InstDB::EncodingData::ISimdWWV& opData = InstDB::EncodingData::iSimdWWV[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, o2.as().type(), o2.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + if (!checkSignature(o0, o1) || !o0.as().isVecV() || o0.as().elementType() != o2.as().elementType() + 1) + goto InvalidInstruction; + + opcode.reset(opData.opcode()); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingISimdVVVe: { + const InstDB::EncodingData::ISimdVVVe& opData = InstDB::EncodingData::iSimdVVVe[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + const Operand_& sop = significantSimdOp(o0, o1, instFlags); + if (!matchSignature(o0, o1, instFlags)) + goto InvalidInstruction; + + if (!o2.as().hasElementIndex()) { + SizeOp sizeOp = armElementTypeToSizeOp(opData.regularVecType, sop.as().type(), sop.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + if (!checkSignature(o1, o2)) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.regularOp) << 10); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5_Rm16; + } + else { + SizeOp sizeOp = armElementTypeToSizeOp(opData.elementVecType, sop.as().type(), sop.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + uint32_t elementIndex = o2.as().elementIndex(); + LMHImm lmh; + + if (!encodeLMH(sizeOp.size(), elementIndex, &lmh)) + goto InvalidElementIndex; + + if (o2.as().id() > lmh.maxRmId) + goto InvalidPhysId; + + opcode.reset(uint32_t(opData.elementOp) << 10); + opcode.addImm(sizeOp.q(), 30); + opcode.addImm(sizeOp.size(), 22); + opcode.addImm(lmh.lm, 20); + opcode.addImm(lmh.h, 11); + goto EmitOp_Rd0_Rn5_Rm16; + } + } + + break; + } + + case InstDB::kEncodingISimdVVVI: { + const InstDB::EncodingData::ISimdVVVI& opData = InstDB::EncodingData::iSimdVVVI[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Imm)) { + const Operand_& sop = significantSimdOp(o0, o1, instFlags); + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, sop.as().type(), sop.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + uint64_t immValue = o3.as().valueAs(); + uint32_t immSize = opData.immSize; + + if (opData.imm64HasOneBitLess && !sizeOp.q()) + immSize--; + + uint32_t immMax = 1u << immSize; + if (immValue >= immMax) + goto InvalidImmediate; + + opcode.reset(opData.opcode()); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + opcode.addImm(immValue, opData.immShift); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingISimdVVVV: { + const InstDB::EncodingData::ISimdVVVV& opData = InstDB::EncodingData::iSimdVVVV[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Reg)) { + const Operand_& sop = significantSimdOp(o0, o1, instFlags); + if (!matchSignature(o0, o1, o2, o3, instFlags)) + goto InvalidInstruction; + + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, sop.as().type(), sop.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.opcode) << 10); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5_Rm16_Ra10; + } + + break; + } + + case InstDB::kEncodingISimdVVVVx: { + const InstDB::EncodingData::ISimdVVVVx& opData = InstDB::EncodingData::iSimdVVVVx[encodingIndex]; + + if (isign4 == ENC_OPS4(Reg, Reg, Reg, Reg)) { + if (o0.signature() != opData.op0Signature || + o1.signature() != opData.op1Signature || + o2.signature() != opData.op2Signature || + o3.signature() != opData.op3Signature) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.opcode) << 10); + goto EmitOp_Rd0_Rn5_Rm16_Ra10; + } + + break; + } + + + case InstDB::kEncodingISimdPair: { + const InstDB::EncodingData::ISimdPair& opData = InstDB::EncodingData::iSimdPair[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg) && opData.opcode2) { + if (o0.as().isVecD1() && o1.as().isVecD2()) { + opcode.reset(uint32_t(opData.opcode2) << 10); + opcode.addImm(0x3, 22); // size. + goto EmitOp_Rd0_Rn5; + } + } + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + SizeOp sizeOp = armElementTypeToSizeOp(opData.opType3, o0.as().type(), o0.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.opcode3) << 10); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingSimdBicOrr: { + const InstDB::EncodingData::SimdBicOrr& opData = InstDB::EncodingData::simdBicOrr[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + SizeOp sizeOp = armElementTypeToSizeOp(InstDB::kVO_V_B, o0.as().type(), o0.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.registerOp) << 10); + opcode.addImm(sizeOp.q(), 30); + goto EmitOp_Rd0_Rn5_Rm16; + } + + if (isign4 == ENC_OPS2(Reg, Imm) || isign4 == ENC_OPS3(Reg, Imm, Imm)) { + SizeOp sizeOp = armElementTypeToSizeOp(InstDB::kVO_V_HS, o0.as().type(), o0.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + if (o1.as().valueAs() > 0xFFFFFFFFu) + goto InvalidImmediate; + + uint32_t imm = o1.as().valueAs(); + uint32_t shift = 0; + uint32_t maxShift = (8u << sizeOp.size()) - 8u; + + if (o2.isImm()) { + if (o2.as().predicate() != uint32_t(ShiftOp::kLSL)) + goto InvalidImmediate; + + if (imm > 0xFFu || o2.as().valueAs() > maxShift) + goto InvalidImmediate; + + shift = o2.as().valueAs(); + if ((shift & 0x7u) != 0u) + goto InvalidImmediate; + } + else if (imm) { + shift = Support::ctz(imm) & 0x7u; + imm >>= shift; + + if (imm > 0xFFu || shift > maxShift) + goto InvalidImmediate; + } + + uint32_t cmode = 0x1u | ((shift / 8u) << 1); + if (sizeOp.size() == 1) + cmode |= B(3); + + // The immediate value is split into ABC and DEFGH parts. + uint32_t abc = (imm >> 5) & 0x7u; + uint32_t defgh = imm & 0x1Fu; + + opcode.reset(uint32_t(opData.immediateOp) << 10); + opcode.addImm(sizeOp.q(), 30); + opcode.addImm(abc, 16); + opcode.addImm(cmode, 12); + opcode.addImm(defgh, 5); + goto EmitOp_Rd0; + } + + break; + } + + case InstDB::kEncodingSimdCmp: { + const InstDB::EncodingData::SimdCmp& opData = InstDB::EncodingData::simdCmp[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg) && opData.regOp) { + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, o0.as().type(), o0.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.regOp) << 10); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5_Rm16; + } + + if (isign4 == ENC_OPS3(Reg, Reg, Imm) && opData.zeroOp) { + if (!matchSignature(o0, o1, instFlags)) + goto InvalidInstruction; + + if (o2.as().value() != 0) + goto InvalidImmediate; + + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, o0.as().type(), o0.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.zeroOp) << 10); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingSimdDot: { + const InstDB::EncodingData::SimdDot& opData = InstDB::EncodingData::simdDot[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + uint32_t size = 2; + + if (q > 1u) + goto InvalidInstruction; + + if (!o2.as().hasElementIndex()) { + if (!opData.vectorOp) + goto InvalidInstruction; + + if (o0.as().type() != o1.as().type() || o1.as().type() != o2.as().type()) + goto InvalidInstruction; + + if (o0.as().elementType() != opData.tA || + o1.as().elementType() != opData.tB || + o2.as().elementType() != opData.tB) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.vectorOp) << 10); + opcode.addImm(q, 30); + goto EmitOp_Rd0_Rn5_Rm16; + } + else { + if (!opData.elementOp) + goto InvalidInstruction; + + if (o0.as().type() != o1.as().type() || !o2.as().isVecV()) + goto InvalidInstruction; + + if (o0.as().elementType() != opData.tA || + o1.as().elementType() != opData.tB || + o2.as().elementType() != opData.tElement) + goto InvalidInstruction; + + uint32_t elementIndex = o2.as().elementIndex(); + LMHImm lmh; + + if (!encodeLMH(size, elementIndex, &lmh)) + goto InvalidElementIndex; + + if (o2.as().id() > lmh.maxRmId) + goto InvalidPhysId; + + opcode.reset(uint32_t(opData.elementOp) << 10); + opcode.addImm(q, 30); + opcode.addImm(lmh.lm, 20); + opcode.addImm(lmh.h, 11); + goto EmitOp_Rd0_Rn5_Rm16; + } + } + + break; + } + + case InstDB::kEncodingSimdDup: SimdDup: { + if (isign4 == ENC_OPS2(Reg, Reg)) { + // Truth table of valid encodings of `Q:1|ElementType:3` + uint32_t kValidEncodings = B(Vec::kElementTypeB + 0) | + B(Vec::kElementTypeH + 0) | + B(Vec::kElementTypeS + 0) | + B(Vec::kElementTypeB + 8) | + B(Vec::kElementTypeH + 8) | + B(Vec::kElementTypeS + 8) | + B(Vec::kElementTypeD + 8) ; + + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + + if (o1.as().isGp()) { + // DUP - Vec (scalar|vector) <- GP register. + // + // NOTE: This is only scalar for `dup d, x` case, otherwise the value + // would be duplicated across all vector elements (1, 2, 4, 8, or 16). + uint32_t elementType = o0.as().elementType(); + if (q > 1 || !Support::bitTest(kValidEncodings, (q << 3) | elementType)) + goto InvalidInstruction; + + uint32_t lsbIndex = elementType - 1u; + uint32_t imm5 = 1u << lsbIndex; + + opcode.reset(0b0000111000000000000011 << 10); + opcode.addImm(q, 30); + opcode.addImm(imm5, 16); + goto EmitOp_Rd0_Rn5; + } + + if (!o1.as().isVec() || !o1.as().hasElementIndex()) + goto InvalidInstruction; + + uint32_t dstIndex = o1.as().elementIndex(); + if (!o0.as().hasElementType()) { + // DUP - Vec (scalar) <- Vec[N]. + uint32_t lsbIndex = diff(o0.as().type(), RegType::kARM_VecB); + + if (lsbIndex != o1.as().elementType() - Vec::kElementTypeB || lsbIndex > 3) + goto InvalidInstruction; + + uint32_t imm5 = ((dstIndex << 1) | 1u) << lsbIndex; + if (imm5 > 31) + goto InvalidElementIndex; + + opcode.reset(0b0101111000000000000001 << 10); + opcode.addImm(imm5, 16); + goto EmitOp_Rd0_Rn5; + } + else { + // DUP - Vec (all) <- Vec[N]. + uint32_t elementType = o0.as().elementType(); + if (q > 1 || !Support::bitTest(kValidEncodings, (q << 3) | elementType)) + goto InvalidInstruction; + + uint32_t lsbIndex = elementType - 1u; + uint32_t imm5 = ((dstIndex << 1) | 1u) << lsbIndex; + + if (imm5 > 31) + goto InvalidElementIndex; + + opcode.reset(0b0000111000000000000001 << 10); + opcode.addImm(q, 30); + opcode.addImm(imm5, 16); + goto EmitOp_Rd0_Rn5; + } + } + + break; + } + + case InstDB::kEncodingSimdIns: SimdIns: { + if (isign4 == ENC_OPS2(Reg, Reg) && o0.as().isVecV()) { + if (!o0.as().hasElementIndex()) + goto InvalidInstruction; + + uint32_t elementType = o0.as().elementType(); + uint32_t dstIndex = o0.as().elementIndex(); + uint32_t lsbIndex = elementType - 1u; + + uint32_t imm5 = ((dstIndex << 1) | 1u) << lsbIndex; + if (imm5 > 31) + goto InvalidElementIndex; + + if (o1.as().isGp()) { + // INS - Vec[N] <- GP register. + opcode.reset(0b0100111000000000000111 << 10); + opcode.addImm(imm5, 16); + goto EmitOp_Rd0_Rn5; + } + else if (o1.as().isVecV() && o1.as().hasElementIndex()) { + // INS - Vec[N] <- Vec[M]. + if (o0.as().elementType() != o1.as().elementType()) + goto InvalidInstruction; + + uint32_t srcIndex = o1.as().elementIndex(); + if (o0.as().type() != o1.as().type()) + goto InvalidInstruction; + + uint32_t imm4 = srcIndex << lsbIndex; + if (imm4 > 15) + goto InvalidElementIndex; + + opcode.reset(0b0110111000000000000001 << 10); + opcode.addImm(imm5, 16); + opcode.addImm(imm4, 11); + goto EmitOp_Rd0_Rn5; + } + } + + break; + } + + case InstDB::kEncodingSimdMov: { + if (isign4 == ENC_OPS2(Reg, Reg)) { + if (o0.as().isVec() && o1.as().isVec()) { + // INS v.x[index], v.x[index]. + if (o0.as().hasElementIndex() && o1.as().hasElementIndex()) + goto SimdIns; + + // DUP {b|h|s|d}, v.{b|h|s|d}[index]. + if (o1.as().hasElementIndex()) + goto SimdDup; + + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + // ORR Vd, Vn, Vm + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + if (q > 1) + goto InvalidInstruction; + + opcode.reset(0b0000111010100000000111 << 10); + opcode.addImm(q, 30); + opcode.addReg(o1, 16); // Vn == Vm. + goto EmitOp_Rd0_Rn5; + } + + if (o0.as().isVec() && o1.as().isGp()) { + // INS v.x[index], Rn. + if (o0.as().hasElementIndex()) + goto SimdIns; + + goto InvalidInstruction; + } + + if (o0.as().isGp() && o1.as().isVec()) { + // UMOV Rd, V.{s|d}[index]. + encodingIndex = 1; + goto SimdUmov; + } + } + + break; + } + + case InstDB::kEncodingSimdMoviMvni: { + const InstDB::EncodingData::SimdMoviMvni& opData = InstDB::EncodingData::simdMoviMvni[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Imm) || isign4 == ENC_OPS3(Reg, Imm, Imm)) { + SizeOp sizeOp = armElementTypeToSizeOp(InstDB::kVO_V_Any, o0.as().type(), o0.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + uint64_t imm64 = o1.as().valueAs(); + uint32_t imm8 = 0; + uint32_t cmode = 0; + uint32_t inverted = opData.inverted; + uint32_t op = 0; + uint32_t shift = 0; + uint32_t shiftOp = uint32_t(ShiftOp::kLSL); + + if (sizeOp.size() == 3u) { + // The second immediate should not be present, however, we accept + // an immediate value of zero as some user code may still pass it. + if (o2.isImm() && o0.as().value() != 0) + goto InvalidImmediate; + + if (Utils::isByteMaskImm8(imm64)) { + imm8 = encodeImm64ByteMaskToImm8(imm64); + } + else { + // Change from D to S and from 64-bit imm to 32-bit imm if this + // is not a byte-mask pattern. + if ((imm64 >> 32) == (imm64 & 0xFFFFFFFFu)) { + imm64 &= 0xFFFFFFFFu; + sizeOp.decrementSize(); + } + else { + goto InvalidImmediate; + } + } + } + + if (sizeOp.size() < 3u) { + if (imm64 > 0xFFFFFFFFu) + goto InvalidImmediate; + imm8 = uint32_t(imm64); + + if (sizeOp.size() == 2) { + if ((imm8 >> 16) == (imm8 & 0xFFFFu)) { + imm8 >>= 16; + sizeOp.decrementSize(); + } + } + + if (sizeOp.size() == 1) { + if (imm8 > 0xFFFFu) + goto InvalidImmediate; + + if ((imm8 >> 8) == (imm8 & 0xFFu)) { + imm8 >>= 8; + sizeOp.decrementSize(); + } + } + + uint32_t maxShift = (8u << sizeOp.size()) - 8u; + if (o2.isImm()) { + if (imm8 > 0xFFu || o2.as().valueAs() > maxShift) + goto InvalidImmediate; + + shift = o2.as().valueAs(); + shiftOp = o2.as().predicate(); + } + else if (imm8) { + shift = Support::ctz(imm8) & ~0x7u; + imm8 >>= shift; + + if (imm8 > 0xFFu || shift > maxShift) + goto InvalidImmediate; + } + + if ((shift & 0x7u) != 0u) + goto InvalidImmediate; + } + + shift /= 8u; + + switch (sizeOp.size()) { + case 0: + if (shiftOp != uint32_t(ShiftOp::kLSL)) + goto InvalidImmediate; + + if (inverted) { + imm8 = ~imm8 & 0xFFu; + inverted = 0; + } + + cmode = B(3) | B(2) | B(1); + break; + + case 1: + if (shiftOp != uint32_t(ShiftOp::kLSL)) + goto InvalidImmediate; + + cmode = B(3) | (shift << 1); + op = inverted; + break; + + case 2: + if (shiftOp == uint32_t(ShiftOp::kLSL)) { + cmode = shift << 1; + } + else if (shiftOp == uint32_t(ShiftOp::kMSL)) { + if (shift == 0 || shift > 2) + goto InvalidImmediate; + cmode = B(3) | B(2) | (shift - 1u); + } + else { + goto InvalidImmediate; + } + + op = inverted; + break; + + case 3: + if (inverted) { + imm8 = ~imm8 & 0xFFu; + inverted = 0; + } + + op = 1; + cmode = B(3) | B(2) | B(1); + break; + } + + // The immediate value is split into ABC and DEFGH parts. + uint32_t abc = (imm8 >> 5) & 0x7u; + uint32_t defgh = imm8 & 0x1Fu; + + opcode.reset(uint32_t(opData.opcode) << 10); + opcode.addImm(sizeOp.q(), 30); + opcode.addImm(op, 29); + opcode.addImm(abc, 16); + opcode.addImm(cmode, 12); + opcode.addImm(defgh, 5); + goto EmitOp_Rd0; + } + + break; + } + + case InstDB::kEncodingSimdShift: { + const InstDB::EncodingData::SimdShift& opData = InstDB::EncodingData::simdShift[encodingIndex]; + + const Operand_& sop = significantSimdOp(o0, o1, instFlags); + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, sop.as().type(), sop.as().elementType()); + + if (!sizeOp.isValid()) + goto InvalidInstruction; + + if (isign4 == ENC_OPS3(Reg, Reg, Imm) && opData.immediateOp) { + if (!matchSignature(o0, o1, instFlags)) + goto InvalidInstruction; + + if (o2.as().valueAs() > 63) + goto InvalidImmediate; + + uint32_t lsbShift = sizeOp.size() + 3u; + uint32_t lsbMask = (1u << lsbShift) - 1u; + uint32_t imm = o2.as().valueAs(); + + // Some instructions use IMM and some X - IMM, so negate if required. + if (opData.invertedImm) { + if (imm == 0 || imm > (1u << lsbShift)) + goto InvalidImmediate; + imm = Support::neg(imm) & lsbMask; + } + + if (imm > lsbMask) + goto InvalidImmediate; + imm |= (1u << lsbShift); + + opcode.reset(uint32_t(opData.immediateOp) << 10); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(imm, 16); + goto EmitOp_Rd0_Rn5; + } + + if (isign4 == ENC_OPS3(Reg, Reg, Reg) && opData.registerOp) { + if (!matchSignature(o0, o1, o2, instFlags)) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.registerOp) << 10); + opcode.addImm(sizeOp.qs(), 30); + opcode.addImm(sizeOp.scalar(), 28); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5_Rm16; + } + + break; + } + + case InstDB::kEncodingSimdShiftES: { + const InstDB::EncodingData::SimdShiftES& opData = InstDB::EncodingData::simdShiftES[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Imm)) { + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, o1.as().type(), o1.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + if (!matchSignature(o0, o1, instFlags)) + goto InvalidInstruction; + + // The immediate value must match the element size. + uint64_t shift = o2.as().valueAs(); + uint32_t shiftOp = o2.as().predicate(); + + if (shift != (8u << sizeOp.size()) || shiftOp != uint32_t(ShiftOp::kLSL)) + goto InvalidImmediate; + + opcode.reset(uint32_t(opData.opcode) << 10); + opcode.addImm(sizeOp.q(), 30); + opcode.addImm(sizeOp.size(), 22); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingSimdSm3tt: { + const InstDB::EncodingData::SimdSm3tt& opData = InstDB::EncodingData::simdSm3tt[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg)) { + if (o0.as().isVecS4() && o1.as().isVecS4() && o2.as().isVecS4() && o2.as().hasElementIndex()) { + uint32_t imm2 = o2.as().elementIndex(); + if (imm2 > 3) + goto InvalidElementIndex; + + opcode.reset(uint32_t(opData.opcode) << 10); + opcode.addImm(imm2, 12); + goto EmitOp_Rd0_Rn5_Rm16; + } + } + + break; + } + + + case InstDB::kEncodingSimdSmovUmov: SimdUmov: { + const InstDB::EncodingData::SimdSmovUmov& opData = InstDB::EncodingData::simdSmovUmov[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg) && o0.as().isGp() && o1.as().isVec()) { + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, o1.as().type(), o1.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + if (!o1.as().hasElementIndex()) + goto InvalidInstruction; + + uint32_t x = o0.as().isGpX(); + uint32_t gpMustBeX = uint32_t(sizeOp.size() >= 3u - opData.isSigned); + + if (opData.isSigned) { + if (gpMustBeX && !x) + goto InvalidInstruction; + } + else { + if (x != gpMustBeX) + goto InvalidInstruction; + } + + uint32_t elementIndex = o1.as().elementIndex(); + uint32_t maxElementIndex = 15u >> sizeOp.size(); + + if (elementIndex > maxElementIndex) + goto InvalidElementIndex; + + uint32_t imm5 = (1u | (elementIndex << 1)) << sizeOp.size(); + + opcode.reset(uint32_t(opData.opcode) << 10); + opcode.addImm(x, 30); + opcode.addImm(imm5, 16); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingSimdSxtlUxtl: { + const InstDB::EncodingData::SimdSxtlUxtl& opData = InstDB::EncodingData::simdSxtlUxtl[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Reg)) { + SizeOp sizeOp = armElementTypeToSizeOp(opData.vecOpType, o1.as().type(), o1.as().elementType()); + if (!sizeOp.isValid()) + goto InvalidInstruction; + + if (!matchSignature(o0, o1, instFlags)) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.opcode) << 10); + opcode.addImm(sizeOp.q(), 30); + opcode.addImm(1u, sizeOp.size() + 19); + goto EmitOp_Rd0_Rn5; + } + + break; + } + + case InstDB::kEncodingSimdTblTbx: { + const InstDB::EncodingData::SimdTblTbx& opData = InstDB::EncodingData::simdTblTbx[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Reg) || isign4 == ENC_OPS4(Reg, Reg, Reg, Reg)) { + // TBL/TBX ., { .16B }, . + // TBL/TBX ., { .16B, .16B }, . + // TBL/TBX ., { .16B, .16B, .16B }, . + // TBL/TBX ., { .16B, .16B, .16B, .16B }, . + opcode.reset(uint32_t(opData.opcode) << 10); + + const Operand_& o4 = opExt[EmitterUtils::kOp4]; + const Operand_& o5 = opExt[EmitterUtils::kOp5]; + + uint32_t q = diff(o0.as().type(), RegType::kARM_VecD); + if (q > 1 || o0.as().hasElementIndex()) + goto InvalidInstruction; + + if (!o1.as().isVecB16() || o1.as().hasElementIndex()) + goto InvalidInstruction; + + uint32_t len = uint32_t(!o3.isNone()) + uint32_t(!o4.isNone()) + uint32_t(!o5.isNone()); + opcode.addImm(q, 30); + opcode.addImm(len, 13); + + switch (len) { + case 0: + if (!checkSignature(o0, o2)) + goto InvalidInstruction; + + if (o2.id() > 31) + goto InvalidPhysId; + + opcode.addReg(o2, 16); + goto EmitOp_Rd0_Rn5; + + case 1: + if (!checkSignature(o0, o3)) + goto InvalidInstruction; + + if (o3.id() > 31) + goto InvalidPhysId; + + opcode.addReg(o3, 16); + goto EmitOp_Rd0_Rn5; + + case 2: + if (!checkSignature(o0, o4)) + goto InvalidInstruction; + + if (o4.id() > 31) + goto InvalidPhysId; + + opcode.addReg(o4, 16); + goto EmitOp_Rd0_Rn5; + + case 3: + if (!checkSignature(o0, o5)) + goto InvalidInstruction; + + if (o5.id() > 31) + goto InvalidPhysId; + + opcode.addReg(o5, 16); + goto EmitOp_Rd0_Rn5; + + default: + // Should never happen. + goto InvalidInstruction; + } + } + + break; + } + + // ------------------------------------------------------------------------ + // [Simd - Load / Store] + // ------------------------------------------------------------------------ + + case InstDB::kEncodingSimdLdSt: { + const InstDB::EncodingData::SimdLdSt& opData = InstDB::EncodingData::simdLdSt[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Mem)) { + const Mem& m = o1.as(); + rmRel = &m; + + // Width | SZ | XY | XSZ + // -------+----------+-----------+----- + // 8-bit | size==00 | opc == 01 | 000 + // 16-bit | size==01 | opc == 01 | 001 + // 32-bit | size==10 | opc == 01 | 010 + // 64-bit | size==11 | opc == 01 | 011 + // 128-bit| size==00 | opc == 11 | 100 + uint32_t xsz = diff(o0.as().type(), RegType::kARM_VecB); + if (xsz > 4u || o0.as().hasElementIndex()) + goto InvalidRegType; + + if (!checkVecId(o0)) + goto InvalidPhysId; + + if (!armCheckMemBaseIndexRel(m)) + goto InvalidAddress; + + int64_t offset = m.offset(); + if (m.hasBaseReg()) { + // [Base {Offset | Index}] + if (m.hasIndex()) { + uint32_t opt = armShiftOpToLdStOptMap[m.predicate()]; + if (opt == 0xFFu) + goto InvalidAddress; + + uint32_t shift = m.shift(); + uint32_t s = (shift != 0); + + if (s && shift != xsz) + goto InvalidAddressScale; + + opcode.reset(uint32_t(opData.registerOp) << 21); + opcode.addImm(xsz & 3u, 30); + opcode.addImm(xsz >> 2, 23); + opcode.addImm(opt, 13); + opcode.addImm(s, 12); + opcode |= B(11); + opcode.addReg(o0, 0); + goto EmitOp_MemBaseIndex_Rn5_Rm16; + } + + // Makes it easier to work with the offset especially on 32-bit arch. + if (!Support::isInt32(offset)) + goto InvalidDisplacement; + int32_t offset32 = int32_t(offset); + + if (m.isPreOrPost()) { + if (!Support::isInt9(offset32)) + goto InvalidDisplacement; + + opcode.reset(uint32_t(opData.prePostOp) << 21); + opcode.addImm(xsz & 3u, 30); + opcode.addImm(xsz >> 2, 23); + opcode.addImm(offset32 & 0x1FF, 12); + opcode.addImm(m.isPreIndex(), 11); + opcode |= B(10); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + else { + uint32_t imm12 = uint32_t(offset32) >> xsz; + + // If this instruction is not encodable with scaled unsigned offset, try unscaled signed offset. + if (!Support::isUInt12(imm12) || (imm12 << xsz) != uint32_t(offset32)) { + instId = opData.uAltInstId; + instInfo = &InstDB::_instInfoTable[instId]; + encodingIndex = instInfo->_encodingDataIndex; + goto Case_SimdLdurStur; + } + + opcode.reset(uint32_t(opData.uOffsetOp) << 22); + opcode.addImm(xsz & 3u, 30); + opcode.addImm(xsz >> 2, 23); + opcode.addImm(imm12, 10); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + } + else { + if (!opData.literalOp) + goto InvalidAddress; + + if (xsz < 2u) + goto InvalidRegType; + + uint32_t opc = xsz - 2u; + opcode.reset(uint32_t(opData.literalOp) << 24); + opcode.addImm(opc, 30); + opcode.addReg(o0, 0); + offsetFormat.resetToImmValue(OffsetType::kSignedOffset, 4, 5, 19, 2); + goto EmitOp_Rel; + } + } + + break; + } + + case InstDB::kEncodingSimdLdpStp: { + const InstDB::EncodingData::SimdLdpStp& opData = InstDB::EncodingData::simdLdpStp[encodingIndex]; + + if (isign4 == ENC_OPS3(Reg, Reg, Mem)) { + const Mem& m = o2.as(); + rmRel = &m; + + uint32_t opc = diff(o0.as().type(), RegType::kARM_VecS); + if (opc > 2u || o0.as().hasElementTypeOrIndex()) + goto InvalidInstruction; + + if (!checkSignature(o0, o1)) + goto InvalidInstruction; + + if (!checkVecId(o0, o1)) + goto InvalidPhysId; + + if (m.baseType() != RegType::kARM_GpX || m.hasIndex()) + goto InvalidAddress; + + if (m.isOffset64Bit()) + goto InvalidDisplacement; + + uint32_t offsetShift = 2u + opc; + int32_t offset32 = m.offsetLo32() >> offsetShift; + + // Make sure we didn't lose bits by applying the mandatory offset shift. + if (Support::shl(offset32, offsetShift) != m.offsetLo32()) + goto InvalidDisplacement; + + // Offset is encoded as a 7-bit immediate. + if (!Support::isInt7(offset32)) + goto InvalidDisplacement; + + if (m.isPreOrPost() && offset32 != 0) { + if (!opData.prePostOp) + goto InvalidAddress; + + opcode.reset(uint32_t(opData.prePostOp) << 22); + opcode.addImm(m.isPreIndex(), 24); + } + else { + opcode.reset(uint32_t(opData.offsetOp) << 22); + } + + opcode.addImm(opc, 30); + opcode.addImm(offset32 & 0x7F, 15); + opcode.addReg(o1, 10); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + + break; + } + + case InstDB::kEncodingSimdLdurStur: { +Case_SimdLdurStur: + const InstDB::EncodingData::SimdLdurStur& opData = InstDB::EncodingData::simdLdurStur[encodingIndex]; + + if (isign4 == ENC_OPS2(Reg, Mem)) { + const Mem& m = o1.as(); + rmRel = &m; + + uint32_t sz = diff(o0.as().type(), RegType::kARM_VecB); + if (sz > 4 || o0.as().hasElementTypeOrIndex()) + goto InvalidInstruction; + + if (!checkVecId(o0)) + goto InvalidPhysId; + + if (!armCheckMemBaseIndexRel(m)) + goto InvalidAddress; + + if (m.hasBaseReg() && !m.hasIndex() && !m.isPreOrPost()) { + if (m.isOffset64Bit()) + goto InvalidDisplacement; + + int32_t offset32 = m.offsetLo32(); + if (!Support::isInt9(offset32)) + goto InvalidDisplacement; + + opcode.reset(uint32_t(opData.opcode) << 10); + opcode.addImm(sz & 3u, 30); + opcode.addImm(sz >> 2, 23); + opcode.addImm(offset32 & 0x1FF, 12); + opcode.addReg(o0, 0); + goto EmitOp_MemBase_Rn5; + } + + goto InvalidAddress; + } + + break; + } + + case InstDB::kEncodingSimdLdNStN: { + const InstDB::EncodingData::SimdLdNStN& opData = InstDB::EncodingData::simdLdNStN[encodingIndex]; + const Operand_& o4 = opExt[EmitterUtils::kOp4]; + + uint32_t n = 1; + + if (isign4 == ENC_OPS2(Reg, Mem)) { + if (opData.n != 1) + goto InvalidInstruction; + + rmRel = &o1; + } + else if (isign4 == ENC_OPS3(Reg, Reg, Mem)) { + if (opData.n != 1 && opData.n != 2) + goto InvalidInstruction; + + if (!checkSignature(o0, o1) || !checkConsecutive(o0, o1)) + goto InvalidInstruction; + + n = 2; + rmRel = &o2; + } + else if (isign4 == ENC_OPS4(Reg, Reg, Reg, Mem) && o4.isNone()) { + if (opData.n != 1 && opData.n != 3) + goto InvalidInstruction; + + if (!checkSignature(o0, o1, o2) || !checkConsecutive(o0, o1, o2)) + goto InvalidInstruction; + + n = 3; + rmRel = &o3; + } + else if (isign4 == ENC_OPS4(Reg, Reg, Reg, Reg) && o4.isMem()) { + if (opData.n != 1 && opData.n != 4) + goto InvalidInstruction; + + if (!checkSignature(o0, o1, o2, o3) || !checkConsecutive(o0, o1, o2, o3)) + goto InvalidInstruction; + + n = 4; + rmRel = &o4; + } + else { + goto InvalidInstruction; + } + + // We will use `v` and `m` from now as those are relevant for encoding. + const Vec& v = o0.as(); + const Mem& m = rmRel->as(); + + uint32_t q = 0; + uint32_t rm = 0; + uint32_t rn = m.baseId(); + uint32_t sz = v.elementType() - Vec::kElementTypeB; + uint32_t opcSsize = sz; + uint32_t offsetPossibility = 0; + + if (sz > 3) + goto InvalidInstruction; + + if (m.baseType() != RegType::kARM_GpX) + goto InvalidAddress; + + // Rn cannot be ZR, but can be SP. + if (rn > 30 && rn != Gp::kIdSp) + goto InvalidAddress; + + rn &= 31; + + if (opData.replicate) { + if (n != opData.n) + goto InvalidInstruction; + + // Replicates to the whole register, element index cannot be used. + if (v.hasElementIndex()) + goto InvalidInstruction; + + q = diff(v.type(), RegType::kARM_VecD); + if (q > 1) + goto InvalidInstruction; + + opcode.reset(uint32_t(opData.singleOp) << 10); + offsetPossibility = (1u << sz) * n; + } + else if (v.hasElementIndex()) { + if (n != opData.n) + goto InvalidInstruction; + + // LDx/STx (single structure). + static const uint8_t opcSsizeBySzS[] = { 0x0u << 3, 0x2u << 3, 0x4u << 3, (0x4u << 3) | 1u }; + + opcode.reset(uint32_t(opData.singleOp) << 10); + opcSsize = opcSsizeBySzS[sz]; + offsetPossibility = (1u << sz) * opData.n; + + uint32_t elementIndex = v.elementIndex(); + uint32_t maxElementIndex = 15 >> sz; + + if (elementIndex > maxElementIndex) + goto InvalidElementIndex; + + elementIndex <<= sz; + q = elementIndex >> 3; + opcSsize |= elementIndex & 0x7u; + } + else { + // LDx/STx (multiple structures). + static const uint8_t opcSsizeByN[] = { 0u, 0x7u << 2, 0xAu << 2, 0x6u << 2, 0x2u << 2 }; + + q = diff(v.type(), RegType::kARM_VecD); + if (q > 1) + goto InvalidInstruction; + + if (opData.n == 1) + opcSsize |= opcSsizeByN[n]; + + opcode.reset(uint32_t(opData.multipleOp) << 10); + offsetPossibility = (8u << q) * n; + } + + if (m.hasIndex()) { + if (m.hasOffset() || !m.isPostIndex()) + goto InvalidAddress; + + rm = m.indexId(); + if (rm > 30) + goto InvalidAddress; + + // Bit 23 - PostIndex. + opcode |= B(23); + } + else { + if (m.hasOffset()) { + if (m.offset() != int32_t(offsetPossibility) || !m.isPostIndex()) + goto InvalidAddress; + rm = 31; + + // Bit 23 - PostIndex. + opcode |= B(23); + } + } + + opcode.addImm(q, 30); + opcode.addImm(rm, 16); + opcode.addImm(opcSsize, 10); + opcode.addImm(rn, 5); + goto EmitOp_Rd0; + } + + default: + break; + } + + goto InvalidInstruction; + + // -------------------------------------------------------------------------- + // [EmitGp - Single] + // -------------------------------------------------------------------------- + +EmitOp_Rd0: + if (!checkValidRegs(o0)) + goto InvalidPhysId; + + opcode.addReg(o0, 0); + goto EmitOp; + +EmitOp_Rn5: + if (!checkValidRegs(o0)) + goto InvalidPhysId; + + opcode.addReg(o0, 5); + goto EmitOp; + +EmitOp_Rn5_Rm16: + if (!checkValidRegs(o0, o1)) + goto InvalidPhysId; + + opcode.addReg(o0, 5); + opcode.addReg(o1, 16); + goto EmitOp; + +EmitOp_Rd0_Rn5: + if (!checkValidRegs(o0, o1)) + goto InvalidPhysId; + + opcode.addReg(o0, 0); + opcode.addReg(o1, 5); + goto EmitOp; + +EmitOp_Rd0_Rn5_Rm16_Ra10: + if (!checkValidRegs(o0, o1, o2, o3)) + goto InvalidPhysId; + + opcode.addReg(o0, 0); + opcode.addReg(o1, 5); + opcode.addReg(o2, 16); + opcode.addReg(o3, 10); + goto EmitOp; + +EmitOp_Rd0_Rn5_Rm16: + if (!checkValidRegs(o0, o1, o3)) + goto InvalidPhysId; + + opcode.addReg(o0, 0); + opcode.addReg(o1, 5); + opcode.addReg(o2, 16); + goto EmitOp; + + // -------------------------------------------------------------------------- + // [EmitGp - Multiple] + // -------------------------------------------------------------------------- + +EmitOp_Multiple: + { + ASMJIT_ASSERT(multipleOpCount > 0); + err = writer.ensureSpace(this, multipleOpCount * 4u); + if (ASMJIT_UNLIKELY(err)) + goto Failed; + + for (uint32_t i = 0; i < multipleOpCount; i++) + writer.emit32uLE(multipleOpData[i]); + + goto EmitDone; + } + + // -------------------------------------------------------------------------- + // [EmitGp - Memory] + // -------------------------------------------------------------------------- + +EmitOp_MemBase_Rn5: + if (!checkMemBase(rmRel->as())) + goto InvalidAddress; + + opcode.addReg(rmRel->as().baseId(), 5); + goto EmitOp; + +EmitOp_MemBaseNoImm_Rn5: + if (!checkMemBase(rmRel->as()) || rmRel->as().hasIndex()) + goto InvalidAddress; + + if (rmRel->as().hasOffset()) + goto InvalidDisplacement; + + opcode.addReg(rmRel->as().baseId(), 5); + goto EmitOp; + +EmitOp_MemBaseIndex_Rn5_Rm16: + if (!rmRel->as().hasBaseReg()) + goto InvalidAddress; + + if (rmRel->as().indexId() > 30 && rmRel->as().indexId() != Gp::kIdZr) + goto InvalidPhysId; + + opcode.addReg(rmRel->as().indexId(), 16); + opcode.addReg(rmRel->as().baseId(), 5); + goto EmitOp; + + // -------------------------------------------------------------------------- + // [EmitOp - PC Relative] + // -------------------------------------------------------------------------- + +EmitOp_Rel: + { + if (rmRel->isLabel() || rmRel->isMem()) { + uint32_t labelId; + int64_t labelOffset = 0; + + if (rmRel->isLabel()) { + labelId = rmRel->as