I'm trying to implement a 1D convolutional neural network in Google JAX with stax.GeneralConv() (https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv). I have a one-dimensional input array with 18 entries and an output array with 6 entries. I want to implement a CNN with kernel width 3 as follows:
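```python
import jax
from jax.experimental import stax
from jax.experimental.stax import Dense, GeneralConv, LogSoftmax

init_random_params, conv_net = stax.serial(
    # dimension_numbers=('NC', 'IO', 'NC'), out_chan=1, filter_shape=(3,)
    GeneralConv(('NC', 'IO', 'NC'), 1, (3,), padding='SAME'),
    LogSoftmax,
    Dense(6),
)
```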
I initialize the network parameters with:
```python
rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))
```
But I get the following error:
```
stax.py", line 75, in <listcomp>
next(filter_shape_iter) for c in rhs_spec]
IndexError: tuple index out of range
```
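The failing line is the tail of a list comprehension in GeneralConv's init function. The plain-Python sketch below re-creates that comprehension to show which lookup raises for the arguments used here. It assumes the comprehension matches the linked stax source, and `kernel_shape_for` is a hypothetical helper for illustration, not part of stax.

```python
def kernel_shape_for(dimension_numbers, out_chan, filter_shape, input_shape):
    """Re-creation (for illustration only) of the kernel-shape list
    comprehension assumed to be in stax.GeneralConv's init_fun."""
    lhs_spec, rhs_spec, _ = dimension_numbers
    filter_shape_iter = iter(filter_shape)
    # 'O' -> out_chan, 'I' -> input channel count taken from input_shape,
    # any other letter (a spatial dim) -> next entry of filter_shape
    return [out_chan if c == 'O' else
            input_shape[lhs_spec.index('C')] if c == 'I' else
            next(filter_shape_iter) for c in rhs_spec]


# Arguments from the question: lhs_spec 'NC' describes 2 axes but (18,) has 1,
# so input_shape[lhs_spec.index('C')] == input_shape[1] is out of range.
try:
    kernel_shape_for(('NC', 'IO', 'NC'), 1, (3,), (18,))
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range

# With a two-axis (N, C) input the lookup succeeds, but 'IO' has no spatial
# letter, so the width 3 in filter_shape is never consumed.
print(kernel_shape_for(('NC', 'IO', 'NC'), 1, (3,), (1, 18)))  # [18, 1]

# With a spatial letter 'W' in the specs, the width 3 ends up in the kernel.
print(kernel_shape_for(('NWC', 'WIO', 'NWC'), 1, (3,), (1, 18, 1)))  # [3, 1, 1]
```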
It seems that stax requires the rhs_spec entry of dimension_numbers to be at least two characters long, but I use a one-dimensional filter. Does anybody have an idea how to solve this problem?
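One possible layout is sketched below, assuming the current `jax.example_libraries.stax` module (where the linked `jax.experimental.stax` code was later moved) and inputs shaped (batch, width, channels) = (batch, 18, 1). Adding the spatial letter 'W' to all three spec strings gives GeneralConv a dimension to place the width-3 filter on. The batch size of 4, the added `stax.Flatten`, and moving `LogSoftmax` after `Dense(6)` are assumptions made for this sketch, not part of the original setup.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # formerly jax.experimental.stax

# Assumed layout: (batch, width, channels), with 'W' as the one spatial dim.
# out_chan=1 and filter_shape=(3,) are kept from the question.
init_random_params, conv_net = stax.serial(
    stax.GeneralConv(('NWC', 'WIO', 'NWC'), 1, (3,), padding='SAME'),
    stax.Flatten,     # (batch, 18, 1) -> (batch, 18)
    stax.Dense(6),    # (batch, 18) -> (batch, 6)
    stax.LogSoftmax,  # log-probabilities over the 6 outputs
)

batch_size = 4  # hypothetical batch size for illustration
rng = jax.random.PRNGKey(0)
output_shape, init_params = init_random_params(rng, (batch_size, 18, 1))

x = jnp.ones((batch_size, 18, 1))  # dummy input: 18 values, 1 channel each
log_probs = conv_net(init_params, x)
print(output_shape, log_probs.shape)
```

With this layout every input carries explicit batch and channel axes, so a single sample of 18 values would be reshaped to (1, 18, 1), e.g. with `x.reshape(1, 18, 1)`.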