
I need to create a 3D tensor, for example one of shape (5, 3, 2), like this:

array([[[0, 0],
        [0, 1],
        [0, 0]],

       [[1, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [1, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [1, 0]],

       [[0, 0],
        [0, 1],
        [0, 0]]])

There should be exactly one 1, placed at a random position, in every slice along the first axis (if you think of the tensor as a loaf of bread, each slice is one 3×2 matrix). This could be done using loops, but I want to vectorize this part.
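
For reference, the loop version the question wants to avoid might look like the sketch below. The function name `one_hot_slices_loop` and the use of `np.random.default_rng` are illustrative choices, not part of the question; the function simply picks a random row and column for each slice.

```python
import numpy as np

def one_hot_slices_loop(K, M, N, rng=None):
    """Hypothetical loop baseline: one 1 at a random (row, col) in each of K slices."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.zeros((K, M, N), dtype=int)
    for k in range(K):
        # pick a random row and a random column independently for slice k
        x[k, rng.integers(M), rng.integers(N)] = 1
    return x

x = one_hot_slices_loop(5, 3, 2, rng=np.random.default_rng(0))
print(x.sum(axis=(1, 2)))  # [1 1 1 1 1]: every slice contains exactly one 1
```

The vectorized answers below aim to produce the same kind of array without the Python-level loop.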

Nicolas Gervais
Atul Vinayak

3 Answers

Try generating a random array, then marking the maximum of each slice:

import numpy as np
a = np.random.rand(5,3,2)
# 1 where an element equals its slice's maximum, 0 elsewhere
out = (a == a.max(axis=(1,2))[:,None,None]).astype(int)
Quang Hoang
    This is not a great approach, because there is a small but nonzero chance that the maximum will appear more than once, leading to multiple 1s in a slice. – jakevdp Feb 16 '21 at 05:08
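
One way to keep the random-array idea while avoiding the tie problem raised in the comment is to use `argmax` instead of an equality test: `argmax` returns exactly one index per slice (the first occurrence if the maximum is tied). A minimal sketch, flattening each slice to a vector of length M*N:

```python
import numpy as np

K, M, N = 5, 3, 2
a = np.random.rand(K, M, N)

# One flat position per slice; argmax never returns more than one index,
# even when the maximum value is tied.
flat_idx = a.reshape(K, -1).argmax(axis=1)  # shape (K,)

out = np.zeros((K, M * N), dtype=int)
out[np.arange(K), flat_idx] = 1  # set exactly one element per row
out = out.reshape(K, M, N)
```

Because the entries of `a` are i.i.d. uniform, the position of the maximum is uniformly distributed over the M*N cells, so this behaves like drawing a random index directly, just with more work.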

The most straightforward way to do this is probably to create an array of zeros and set one randomly chosen index in each slice to 1. In NumPy, it might look like this:

import numpy as np

K, M, N = 5, 3, 2
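# random row and column index for each of the K slices
i = np.random.randint(0, M, K)
j = np.random.randint(0, N, K)
x = np.zeros((K, M, N), dtype=int)
# slice k gets a 1 at (i[k], j[k])
x[np.arange(K), i, j] = 1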
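
A variant of the same zeros-plus-indexing idea draws a single flat position per slice and splits it into row and column indices with `np.unravel_index`. This sketch also uses the `np.random.default_rng` Generator API and an integer dtype so the result holds 0/1 integers as in the question; both are choices made here, not requirements.

```python
import numpy as np

K, M, N = 5, 3, 2
rng = np.random.default_rng(0)

# One flat position in [0, M*N) per slice ...
flat = rng.integers(0, M * N, size=K)
# ... converted into (row, column) indices within an M x N slice
i, j = np.unravel_index(flat, (M, N))

x = np.zeros((K, M, N), dtype=int)
x[np.arange(K), i, j] = 1
```

Drawing one flat index instead of two separate ones makes it easy to reuse the same index with other constructions, such as the broadcast comparison further down.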

In JAX, it might look something like this:

import jax.numpy as jnp
from jax import random

K, M, N = 5, 3, 2
key1, key2 = random.split(random.PRNGKey(0))
i = random.randint(key1, (K,), 0, M)
j = random.randint(key2, (K,), 0, N)
x = jnp.zeros((K, M, N)).at[jnp.arange(K), i, j].set(1)

A more concise option, which also guarantees a single 1 per slice, is to draw one random flat position per slice and compare it, via broadcasting, against an (M, N) grid of flat positions:

r = random.randint(random.PRNGKey(0), (K, 1, 1), 0, M * N)
x = (r == jnp.arange(M * N).reshape(M, N)).astype(int)
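
The same "one random flat position per slice" idea can also be written in plain NumPy by indexing rows of an identity matrix, each of which is a one-hot vector of length M*N, and reshaping. This is a sketch using `np.eye`; in JAX, `jax.nn.one_hot` plays a similar role, though it is not exercised here.

```python
import numpy as np

K, M, N = 5, 3, 2
rng = np.random.default_rng(0)

r = rng.integers(0, M * N, size=K)  # one flat position per slice
# Row r[k] of the (M*N) x (M*N) identity is one-hot at position r[k];
# reshaping turns each length-M*N row into an M x N slice.
x = np.eye(M * N, dtype=int)[r].reshape(K, M, N)
```

Building the identity matrix costs O((M*N)^2) memory, so the broadcast comparison above is likely preferable when the slices are large.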
jakevdp

You can create a zero array in which the first element of each slice is 1, and then shuffle it independently along each of the last two axes:

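import numpy as np

x = np.zeros((5,3,2)); x[:,0,0] = 1

rng = np.random.default_rng()
# shuffle each row (axis -1), then each column (axis -2), independently per slice
x = rng.permuted(rng.permuted(x, axis=-1), axis=-2)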
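
A quick empirical check of this approach: shuffling each row and then each column independently should leave the 1 equally likely to land in any of the 3×2 = 6 positions. The sketch below repeats the construction for many slices at once and measures how often each position receives the 1. `Generator.permuted` requires NumPy 1.20 or later, and the slice count `T` is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 60_000  # number of slices to sample (arbitrary)

x = np.zeros((T, 3, 2)); x[:, 0, 0] = 1
x = rng.permuted(rng.permuted(x, axis=-1), axis=-2)

# Fraction of slices in which each (row, col) position holds the 1
freq = x.mean(axis=0)
print(freq.round(3))  # each entry should be close to 1/6 (about 0.167)
```

The row shuffle picks the column uniformly (1/2 each) and the column shuffle then picks the row uniformly (1/3 each), which is why every position should come out near 1/6.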
iacob