This is not really a question; rather, I was wondering whether anyone has a better way of building an occupancy grid for a 3D grid in JAX (or in another language). Here is some working code. Does anyone have a better solution, or see any problems with my code?
import jax.numpy as jnp
from jax import lax
import numpy as np
def mipmap(mat):
    """Build a binary occupancy pyramid ("mipmap") for a cubic 3D grid.

    Level 0 marks every voxel of ``mat`` whose value is strictly positive.
    Each following level halves the resolution along every axis. A coarse
    voxel is occupied (1) when any voxel of the 2x2x2 block beneath it in the
    previous level is occupied. Blocks are taken with stride 2 and 'SAME'
    padding, so odd sizes are padded with empty voxels at the high end.

    Args:
        mat: array-like with ``ndim == 3`` and shape ``(n, n, n)``.

    Returns:
        A list of ``floor(log2(n)) + 1`` integer arrays of 0/1 values, finest
        level first (a single level when ``n <= 1``).

    Raises:
        ValueError: if ``mat`` is not 3-D or not cubic.
    """
    grid = jnp.asarray(mat, dtype=jnp.float32)
    if grid.ndim != 3:
        raise ValueError(f"expected a 3-D grid, got shape {grid.shape}")
    xdim, ydim, zdim = grid.shape
    if not xdim == ydim == zdim:
        raise ValueError(f"expected a cubic grid, got shape {grid.shape}")

    # floor(log2(n)) computed exactly on integers. jnp.log2 works in float32
    # and int() truncates, so any rounding just below an integer would
    # silently drop a level.
    num_levels = max(xdim.bit_length() - 1, 0)

    # Level 0: a voxel is occupied iff its value is strictly positive.
    occupancy = (grid > 0).astype(jnp.float32)
    pyramid = [occupancy.astype(int)]

    # Add batch and channel axes for lax.conv_general_dilated (NHWDC layout).
    occupancy = occupancy[None, :, :, :, None]
    kernel = jnp.ones((2, 2, 2, 1, 1), dtype=jnp.float32)
    dn = lax.conv_dimension_numbers(
        occupancy.shape, kernel.shape, ("NHWDC", "HWDIO", "NHWDC")
    )
    for _ in range(num_levels):
        # Sum each 2x2x2 block of 0/1 values; > 0 means "any child occupied".
        # The input is always the thresholded occupancy, never the raw values,
        # so negative entries cannot cancel positive ones in a coarse cell.
        block_sums = lax.conv_general_dilated(
            occupancy,
            kernel,
            window_strides=(2, 2, 2),
            padding="SAME",
            lhs_dilation=(1, 1, 1),
            rhs_dilation=(1, 1, 1),
            dimension_numbers=dn,
        )
        occupancy = (block_sums > 0).astype(jnp.float32)
        pyramid.append(occupancy[0, :, :, :, 0].astype(int))
    return pyramid
# Example: a 4x4x4 grid with four occupied voxels, at (x, y, z) =
# (0, 0, 0), (2, 0, 0), (3, 3, 3) and (0, 3, 3), set in one
# advanced-indexing assignment (one index tuple per axis).
entry = np.zeros((4, 4, 4))
entry[(0, 2, 3, 0), (0, 0, 3, 3), (0, 0, 3, 3)] = 1
occupancy = mipmap(entry)
Thanks for reading, and thanks in advance for letting me know :)