Merge pull request #3577 from bathmatt/kokkos-mliap-pytorch

Have PyTorch interface for MLIAP working in Kokkos.  This uses cuPy a…
This commit is contained in:
Axel Kohlmeyer
2023-01-06 14:24:10 -05:00
committed by GitHub
16 changed files with 681 additions and 36 deletions

View File

@ -3,6 +3,7 @@
if(CMAKE_CXX_STANDARD LESS 14)
message(FATAL_ERROR "The KOKKOS package requires the C++ standard to be set to at least C++14")
endif()
########################################################################
# consistency checks and Kokkos options/settings required by LAMMPS
if(Kokkos_ENABLE_CUDA)
@ -89,7 +90,6 @@ else()
endif()
add_subdirectory(${LAMMPS_LIB_KOKKOS_SRC_DIR} ${LAMMPS_LIB_KOKKOS_BIN_DIR})
set(Kokkos_INCLUDE_DIRS ${LAMMPS_LIB_KOKKOS_SRC_DIR}/core/src
${LAMMPS_LIB_KOKKOS_SRC_DIR}/containers/src
${LAMMPS_LIB_KOKKOS_SRC_DIR}/algorithms/src
@ -143,7 +143,23 @@ if(PKG_ML-IAP)
list(APPEND KOKKOS_PKG_SOURCES ${KOKKOS_PKG_SOURCES_DIR}/mliap_data_kokkos.cpp
${KOKKOS_PKG_SOURCES_DIR}/mliap_descriptor_so3_kokkos.cpp
${KOKKOS_PKG_SOURCES_DIR}/mliap_model_linear_kokkos.cpp
${KOKKOS_PKG_SOURCES_DIR}/mliap_model_python_kokkos.cpp
${KOKKOS_PKG_SOURCES_DIR}/mliap_so3_kokkos.cpp)
# Add KOKKOS version of ML-IAP Python coupling if non-KOKKOS version is included
if(MLIAP_ENABLE_PYTHON AND Cythonize_EXECUTABLE)
file(GLOB MLIAP_KOKKOS_CYTHON_SRC ${LAMMPS_SOURCE_DIR}/KOKKOS/*.pyx)
foreach(MLIAP_CYTHON_FILE ${MLIAP_KOKKOS_CYTHON_SRC})
get_filename_component(MLIAP_CYTHON_BASE ${MLIAP_CYTHON_FILE} NAME_WE)
add_custom_command(OUTPUT ${MLIAP_BINARY_DIR}/${MLIAP_CYTHON_BASE}.cpp ${MLIAP_BINARY_DIR}/${MLIAP_CYTHON_BASE}.h
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${MLIAP_CYTHON_FILE} ${MLIAP_BINARY_DIR}/${MLIAP_CYTHON_BASE}.pyx
COMMAND ${Cythonize_EXECUTABLE} -3 ${MLIAP_BINARY_DIR}/${MLIAP_CYTHON_BASE}.pyx
WORKING_DIRECTORY ${MLIAP_BINARY_DIR}
MAIN_DEPENDENCY ${MLIAP_CYTHON_FILE}
COMMENT "Generating C++ sources with cythonize...")
list(APPEND KOKKOS_PKG_SOURCES ${MLIAP_BINARY_DIR}/${MLIAP_CYTHON_BASE}.cpp)
endforeach()
endif()
endif()
if(PKG_PHONON)