I am using JAX MD to simulate the bonded energies/forces between atoms. I have an energy function which computes the bond energy between two atoms. Using `smap.bond` I can promote this function to act on a whole system of atoms, and I can also pass it a list of the bonded pairs within the system. I would like to transform this function (or write a new one) so that only those bonded pairs in the original bond list whose separation is below a certain cutoff contribute to the energy. Ideally I would achieve this through neighbour lists: I can already apply a cutoff, but that is not as efficient as a neighbour list. The function needs to be jittable, since I will ultimately use it in simulations driven by `lax.scan`.
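For reference, the plain cutoff approach mentioned above can be written as a masked sum. This is a minimal sketch in plain JAX, assuming a harmonic bond energy and free-space (non-periodic) distances; `harmonic_bond`, `make_cutoff_bond_energy`, `r0`, `k` and `r_cut` are hypothetical names. Because the cutoff is applied with `jnp.where` instead of Python control flow, array shapes never depend on the positions, so the function can be jitted, differentiated and used inside `lax.scan`.

```python
import jax
import jax.numpy as jnp


def harmonic_bond(dr, r0=1.0, k=1.0):
    # Hypothetical per-bond harmonic energy; stands in for the real bond energy.
    return 0.5 * k * (dr - r0) ** 2


def make_cutoff_bond_energy(bonds, r_cut, r0=1.0, k=1.0):
    """Total bond energy where bonds longer than r_cut contribute nothing.

    `bonds` is an (n_bonds, 2) integer array of atom indices, the same
    layout used for the bond list passed to smap.bond.
    """
    bonds = jnp.asarray(bonds)

    def energy(R):
        # Free-space bond lengths; a periodic box would need a displacement function.
        dr = jnp.linalg.norm(R[bonds[:, 0]] - R[bonds[:, 1]], axis=-1)
        # jnp.where keeps shapes static, so this traces under jit and lax.scan.
        return jnp.sum(jnp.where(dr < r_cut, harmonic_bond(dr, r0, k), 0.0))

    return energy


energy_fn = jax.jit(make_cutoff_bond_energy([[0, 1], [1, 2]], r_cut=2.0))
force_fn = jax.jit(lambda R: -jax.grad(energy_fn)(R))
```

As far as I understand, `smap.bond` evaluates the per-bond function on the bond distances, so wrapping the bond energy in the same `jnp.where` before passing it to `smap.bond` should behave the same way. A hard cutoff makes the energy discontinuous at `r_cut` unless the bond energy vanishes there. The cost also still scales with the full bond list, which is the inefficiency a neighbour list is meant to remove.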
I am also aware of `smap.pair_neighbor_list`, but I am not sure how to use it when I also want to pass a bond list.
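As I understand it, `smap.pair_neighbor_list` works with a neighbour list built over all atoms rather than over a fixed bond list, so one possible alternative is to apply the neighbour-list idea to the bond list itself. The sketch below is plain JAX, and every function name as well as the `skin`, `capacity` and `dt` parameters are hypothetical. It keeps a fixed-size buffer of indices into the bond list for bonds shorter than `r_cut + skin`, padded with an out-of-range sentinel so shapes stay static, and rebuilds it inside `lax.cond` only when some atom has moved more than `skin / 2` since the last build, the usual Verlet-list criterion. The steepest-descent step is only there to show the buffer living in the `lax.scan` carry.

```python
import jax
import jax.numpy as jnp


def bond_lengths(R, bonds):
    # Free-space length of each (i, j) bond; periodic boxes need a displacement fn.
    return jnp.linalg.norm(R[bonds[:, 0]] - R[bonds[:, 1]], axis=-1)


def build_active_bonds(R, bonds, r_cut, skin, capacity):
    """Return (idx, R_ref, overflow) for bonds shorter than r_cut + skin.

    `capacity` must be a Python int so the output shape is static under jit.
    Unused slots hold the sentinel index len(bonds).
    """
    n_bonds = bonds.shape[0]
    within = bond_lengths(R, bonds) < r_cut + skin
    idx = jnp.nonzero(within, size=capacity, fill_value=n_bonds)[0]
    overflow = jnp.sum(within) > capacity
    return idx, R, overflow


def needs_rebuild(R, R_ref, skin):
    # A bond shrinks by at most the sum of its atoms' displacements, so no bond
    # outside r_cut + skin can reach r_cut before some atom has moved skin / 2.
    return jnp.max(jnp.linalg.norm(R - R_ref, axis=-1)) > 0.5 * skin


def active_bond_energy(R, bonds, idx, r_cut, r0=1.0, k=1.0):
    # Harmonic energy summed over buffered bonds that are currently inside r_cut.
    valid = idx < bonds.shape[0]
    sub = bonds[jnp.where(valid, idx, 0)]  # padded slots point at bond 0, masked below
    dr = bond_lengths(R, sub)
    e = 0.5 * k * (dr - r0) ** 2
    return jnp.sum(jnp.where(valid & (dr < r_cut), e, 0.0))


def make_step(bonds, r_cut, skin, capacity, dt=1e-3):
    # One steepest-descent step for lax.scan; the buffer is rebuilt only when needed.
    force = jax.grad(lambda R, idx: -active_bond_energy(R, bonds, idx, r_cut))

    def step(carry, _):
        R, (idx, R_ref, overflow) = carry
        idx, R_ref, new_overflow = jax.lax.cond(
            needs_rebuild(R, R_ref, skin),
            lambda _: build_active_bonds(R, bonds, r_cut, skin, capacity),
            lambda _: (idx, R_ref, jnp.array(False)),
            None,
        )
        R = R + dt * force(R, idx)
        return (R, (idx, R_ref, overflow | new_overflow)), None

    return step
```

The `overflow` flag plays a role similar to `did_buffer_overflow` on JAX MD neighbour lists: if it is set after a `scan`, the buffer was too small and the run should be repeated with a larger `capacity` chosen outside `jit`. The saving only materialises when the active buffer is much smaller than the full bond list; for a modest bond list the masked full sum may be just as fast.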
I would greatly appreciate any help.