I have this function, which works for a single vector:
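    import numpy as np

    def vec_to_board(vector, player, dim, reverse=False):
        # Flat board with one cell per square
        player_board = np.zeros(dim * dim)
        # Positions in the vector occupied by `player`
        player_pos = np.argwhere(vector == player)
        # mapping / reverse_mapping (defined elsewhere) translate vector positions to board cells
        if not reverse:
            player_board[mapping[player_pos.T]] = 1
        else:
            player_board[reverse_mapping[player_pos.T]] = 1
        return np.reshape(player_board, [dim, dim])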
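For reference, the argwhere-plus-fancy-indexing step can be written as a single fixed-shape assignment of the whole boolean mask. This matters for batching because vmap (like jit) traces the function with abstract values: np.argwhere cannot be applied to a tracer, and its output length depends on the data in any case. The NumPy sketch below checks that the two forms agree. It assumes that mapping is a permutation of the dim*dim board cells; the random permutation and the two function names are hypothetical stand-ins, since the question does not show mapping.

```python
import numpy as np

dim = 4
# Hypothetical stand-in for the question's `mapping`: a permutation of board cells
mapping = np.random.default_rng(0).permutation(dim * dim)

def vec_to_board_argwhere(vector, player):
    # Original approach: the number of indices depends on the data
    board = np.zeros(dim * dim)
    pos = np.argwhere(vector == player)
    board[mapping[pos.T]] = 1
    return board.reshape(dim, dim)

def vec_to_board_mask(vector, player):
    # Fixed-shape alternative: scatter the whole boolean mask at once
    # (True/False become 1.0/0.0 in the float board)
    board = np.zeros(dim * dim)
    board[mapping] = vector == player
    return board.reshape(dim, dim)

vector = np.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2])
assert np.array_equal(vec_to_board_argwhere(vector, 1), vec_to_board_mask(vector, 1))
```

Because every cell is written exactly once when mapping is a permutation, the mask form produces the same board while keeping all shapes independent of the input values.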
However, I want it to work for a batch of vectors.
What I have tried so far:
    import jax.numpy as jnp
    from jax import vmap
    states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2],
                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])
    batch_size = 1
    b_states = vmap(vec_to_board)((states, 1, 4), batch_size)
This doesn't work. If I understand correctly, though, shouldn't vmap be able to apply this function across a batch?
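A sketch of a batched version, assuming mapping and reverse_mapping are permutations of the board cells (the identity and reversed index tables below are hypothetical placeholders). Two things differ from the attempt above. First, vmap maps over the positional arguments of the function it wraps, by default along axis 0 for every leaf, so passing the tuple (states, 1, 4) plus batch_size asks it to map the scalars 1, 4 and batch_size too, and they have no axis 0. The batch size is inferred from states, so no batch_size argument is needed. Second, the body has to avoid argwhere and in-place assignment, so it uses jnp's .at[...].set(...) scatter with the full boolean mask.

```python
from functools import partial

import jax
import jax.numpy as jnp

DIM = 4
# Hypothetical stand-ins for the question's `mapping` / `reverse_mapping`
mapping = jnp.arange(DIM * DIM)
reverse_mapping = mapping[::-1]

def vec_to_board(vector, player, dim, reverse=False):
    # dim and reverse stay plain Python values: one fixes a shape, the other picks a table
    table = reverse_mapping if reverse else mapping
    mask = (vector == player).astype(jnp.float32)
    # Fixed-shape scatter replaces argwhere + in-place assignment
    board = jnp.zeros(dim * dim).at[table].set(mask)
    return board.reshape(dim, dim)

states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2],
                    [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])

# Map over axis 0 of `vector` only; player and dim are bound by partial
batched = jax.vmap(partial(vec_to_board, player=1, dim=DIM))
b_states = batched(states)  # one 4x4 board per row of `states`
```

Binding player and dim with partial keeps dim an ordinary Python int, which jnp.zeros needs as a shape; arguments that should be shared rather than batched can also be marked with in_axes=None.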