I'm trying to implement a 1D convolutional neural network in JAX with stax.GeneralConv() (https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv). I have a one-dimensional input array with 18 entries and an output array with 6 entries. I want to implement a CNN with kernel width 3 as follows:

init_random_params, conv_net = stax.serial(
    GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
    LogSoftmax,
    Dense(6),
)

with the initial network parameters:

rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))

But I get the following error:

stax.py", line 75, in <listcomp>
    next(filter_shape_iter) for c in rhs_spec]

IndexError: tuple index out of range

stax requires the dimension number rhs_spec to be at least 2 characters long, but I use a 1 dimensional filter. Does anybody have an idea how to solve this problem?

salutnomo

1 Answer

I haven't tried this myself, but I expect that a 1d convolution still requires one direction over which to convolve, e.g.

Conv2d = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1d = functools.partial(GeneralConv, ('NHC', 'HIO', 'NHC'))

In other words, dropping the W axis to go from 2d to 1d convolutions.

The input shape corresponding to NHC is (batch_size, sequence_length, num_channels).

Note that even though the number of channels might be 1, you still need to include that axis, because GeneralConv does an index lookup along the lines of num_channels = input_shape['NHC'.index('C')].
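Putting this together for the network in the question, a minimal sketch might look like the following. It is untested against the asker's setup and makes several assumptions: the import falls back to jax.experimental.stax for older JAX releases, the batch size of 4 and the all-ones input are placeholders, a Flatten layer is added so that Dense sees a 2D array, and LogSoftmax is moved after Dense (the question has them the other way round).

```python
import functools

import jax
import jax.numpy as jnp

try:  # stax lives in jax.example_libraries in newer JAX releases
    from jax.example_libraries import stax
except ImportError:  # older releases, as linked in the question
    from jax.experimental import stax

# One spatial axis ('H') plus explicit batch ('N') and channel ('C') axes.
Conv1d = functools.partial(stax.GeneralConv, ('NHC', 'HIO', 'NHC'))

init_random_params, conv_net = stax.serial(
    Conv1d(1, (3,), padding='SAME'),  # 1 output channel, kernel width 3
    stax.Flatten,                     # (batch, 18, 1) -> (batch, 18)
    stax.Dense(6),                    # 6 outputs per example
    stax.LogSoftmax,
)

batch_size = 4  # hypothetical batch size, chosen only for illustration
rng = jax.random.PRNGKey(0)

# The input shape must name all three axes: (batch, length, channels).
out_shape, init_params = init_random_params(rng, (batch_size, 18, 1))

signals = jnp.ones((batch_size, 18))  # placeholder data, 18 entries each
inputs = signals[..., None]           # add the channel axis -> (4, 18, 1)

log_probs = conv_net(init_params, inputs)
print(log_probs.shape)  # (4, 6)
```

The key difference from the question is that both the dimension numbers and the input shape carry a batch axis, a length axis, and a channel axis, so the kernel shape lookup has a spatial entry to consume from the filter shape (3,).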

Kris