diff --git a/python/lammps/mliap/pytorch.py b/python/lammps/mliap/pytorch.py index 9aa2da80f4..d699c239b0 100644 --- a/python/lammps/mliap/pytorch.py +++ b/python/lammps/mliap/pytorch.py @@ -300,7 +300,7 @@ class ElemwiseModels(torch.nn.Module): self.subnets = subnets self.n_types = n_types - def forward(self, descriptors, elems): + def forward(self, descriptors, elems, dtype=torch.float64): """ Feeds descriptors to network model after adding zeros into descriptor columns relating to different atom types @@ -319,8 +319,12 @@ class ElemwiseModels(torch.nn.Module): Per atom attribute computed by the network model """ - per_atom_attributes = torch.zeros(elems.size[0]) + self.dtype=dtype + self.to(self.dtype) + + per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=self.dtype) given_elems, elem_indices = torch.unique(elems, return_inverse=True) for i, elem in enumerate(given_elems): - per_atom_attribute[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]) + self.subnets[elem].to(self.dtype) + per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten() return per_atom_attributes diff --git a/src/ML-IAP/mliap_model_python.cpp b/src/ML-IAP/mliap_model_python.cpp index 8f88fb319d..acbf2ed92d 100644 --- a/src/ML-IAP/mliap_model_python.cpp +++ b/src/ML-IAP/mliap_model_python.cpp @@ -26,6 +26,7 @@ #include "pair_mliap.h" #include "python_compat.h" #include "utils.h" +#include "comm.h" #include @@ -104,7 +105,7 @@ void MLIAPModelPython::read_coeffs(char *fname) if (loaded) { this->connect_param_counts(); } else { - utils::logmesg(lmp, "Loading python model deferred.\n"); + if (comm->me == 0) utils::logmesg(lmp, "Loading python model deferred.\n"); } }