Add os env vars to allow more MPI procs
This commit is contained in:
@ -4,11 +4,17 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from functools import partial
|
||||
import os
|
||||
|
||||
# Required else get `jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory`
|
||||
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
|
||||
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
|
||||
|
||||
@jax.jit
|
||||
def lj_potential(epsilon, sigma, rij):
|
||||
# A pure function we can differentiate:
|
||||
def _tot_e(rij):
|
||||
"""A differentiable fn for total energy."""
|
||||
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
|
||||
r6inv = r2inv * r2inv * r2inv
|
||||
|
||||
@ -17,7 +23,7 @@ def lj_potential(epsilon, sigma, rij):
|
||||
|
||||
eij = r6inv * (lj1 * r6inv - lj2)
|
||||
return 0.5 * jnp.sum(eij), eij
|
||||
# Construct a function computing _tot_e and its derivative
|
||||
# Compute _tot_e and its derivative.
|
||||
(_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij)
|
||||
return eij, fij
|
||||
|
||||
|
||||
@ -6,6 +6,13 @@ import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from functools import partial
|
||||
import cupy
|
||||
import os
|
||||
|
||||
# Required else get `jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory`
|
||||
# Does not fix GPU problem with larger num. atoms.
|
||||
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
|
||||
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
|
||||
#os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
|
||||
|
||||
@jax.jit
|
||||
def lj_potential(epsilon, sigma, rij):
|
||||
|
||||
Reference in New Issue
Block a user