
I am playing around with ResNet architectures to compare their performance with a CNN, and I used this ResNet for my first test. I am reusing my code to load and prepare the data for the network, and that part works fine. The rest of the script also seems to work, except when it reaches fit_generator: it pauses for a while and then seems to exit, landing on the print statement that says "what happened?". I am confused, since I would expect an error message, a crash, or something similar. I am on Windows 10 running the latest version of Anaconda. In my conda environment I am using Python 3.6, the latest version of Keras 2.3 and the latest version of TensorFlow. I would appreciate any insights.

import numpy as np


def batch_generator(X_train, Y_train):
    """Loop over the files forever, yielding one (samples, labels) batch per file."""
    while True:
        for fl, lb in zip(X_train, Y_train):
            # get_IQsamples is my own loader (not shown here)
            sam, lam = get_IQsamples(fl, lb)
            # keep every sample; the zip/break copy loop stopped at the
            # shorter of sam and lam, which this slice reproduces
            n = min(len(sam), len(lam))
            sample = np.asarray(sam[:n])   # data batch
            label = np.asarray(lam[:n])    # label batch
            yield sample, label
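One thing worth checking: batch_generator yields one whole file's worth of samples per step, while the step counts passed to fit_generator are computed by dividing by batchsize. If batchsize is meant to be the number of samples per step, a re-chunking wrapper keeps the two consistent. This is a minimal sketch assuming each file yields NumPy arrays with the sample axis first; the name fixed_size_batches is hypothetical.

```python
import numpy as np


def fixed_size_batches(file_batches, batch_size):
    """Re-chunk per-file (samples, labels) pairs into batches of exactly batch_size.

    file_batches is any iterable of (samples, labels) array pairs, for example
    the output of batch_generator. Leftover samples carry over to the next file.
    """
    buf_x, buf_y = [], []  # samples not yet emitted
    buffered = 0
    for sam, lab in file_batches:
        buf_x.append(np.asarray(sam))
        buf_y.append(np.asarray(lab))
        buffered += len(sam)
        while buffered >= batch_size:
            x = np.concatenate(buf_x)
            y = np.concatenate(buf_y)
            # emit one full batch and keep the remainder buffered
            yield x[:batch_size], y[:batch_size]
            buf_x, buf_y = [x[batch_size:]], [y[batch_size:]]
            buffered -= batch_size
```

Usage would be something like `train_gen = fixed_size_batches(batch_generator(X_train, Y_train), batchsize)`. Whether this matches the intent depends on what val_length_train actually counts (samples or files), which the question does not show.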


from keras.layers import Conv2D, Activation, Add, MaxPooling2D


def residual_stack(x, f):
    """1x1 projection, two residual units, then 2x2 max pooling."""

    # 1x1 conv linear
    x = Conv2D(f, (1, 1), strides=1, padding='same', data_format='channels_last')(x)
    x = Activation('linear')(x)

    # residual unit 1
    x_shortcut = x
    x = Conv2D(f, (3, 2), strides=1, padding='same', data_format='channels_last')(x)
    x = Activation('relu')(x)
    x = Conv2D(f, 3, strides=1, padding='same', data_format='channels_last')(x)
    x = Activation('linear')(x)

    # add skip connection
    if x.shape[1:] == x_shortcut.shape[1:]:
        x = Add()([x, x_shortcut])
    else:
        raise Exception('Skip Connection Failure!')

    # residual unit 2
    x_shortcut = x
    x = Conv2D(f, 3, strides=1, padding='same', data_format='channels_last')(x)
    x = Activation('relu')(x)
    x = Conv2D(f, 3, strides=1, padding='same', data_format='channels_last')(x)
    x = Activation('linear')(x)

    # add skip connection
    if x.shape[1:] == x_shortcut.shape[1:]:
        x = Add()([x, x_shortcut])
    else:
        raise Exception('Skip Connection Failure!')

    # max pooling layer
    x = MaxPooling2D(pool_size=2, strides=None, padding='valid', data_format='channels_last')(x)

    return x
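As a sanity check on the architecture: every convolution in residual_stack uses padding='same', so only the MaxPooling2D(pool_size=2, padding='valid') layer changes the spatial size, flooring each side by 2. The pure-Python sketch below (helper names are hypothetical) traces that for a (32, 32, 2) input with 40 filters. Five stacks take it exactly down to 1x1x40, so a sixth stack would have nothing left to pool.

```python
def stack_output_shape(h, w, filters):
    """Shape after one residual_stack: 'same' convs keep h x w,
    then 2x2 'valid' max pooling floors each side by 2."""
    return h // 2, w // 2, filters


def trace_shapes(input_shape, filters, num_stacks):
    """Return the (h, w, channels) shape after each residual_stack."""
    h, w, _ = input_shape
    shapes = []
    for _ in range(num_stacks):
        if h < 2 or w < 2:
            raise ValueError(f"cannot pool a {h}x{w} feature map")
        h, w, c = stack_output_shape(h, w, filters)
        shapes.append((h, w, c))
    return shapes


for shape in trace_shapes((32, 32, 2), 40, 5):
    print(shape)
# (16, 16, 40)
# (8, 8, 40)
# (4, 4, 40)
# (2, 2, 40)
# (1, 1, 40)
```

So Flatten() in ResNet receives 1 * 1 * 40 = 40 features.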


Define ResNet Model

# define resnet model

from keras.layers import Input, Flatten, Dense, Dropout
from keras.models import Model
from keras.initializers import glorot_uniform


def ResNet(input_shape, classes):

    # create input tensor
    x_input = Input(input_shape)
    x = x_input

    # residual stack: five stacks, each halving the spatial size
    num_filters = 40
    for _ in range(5):
        x = residual_stack(x, num_filters)

    # output layer
    x = Flatten()(x)
    x = Dense(128, activation='selu', kernel_initializer='he_normal')(x)
    x = Dropout(0.5)(x)
    x = Dense(128, activation='selu', kernel_initializer='he_normal')(x)
    x = Dropout(0.5)(x)
    x = Dense(classes, activation='softmax', kernel_initializer=glorot_uniform(seed=0))(x)

    # Create model
    model = Model(inputs=x_input, outputs=x)
    model.summary()

    return model
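The parameter count can be worked out by hand the same way, which gives something to compare against the model.summary() output. This sketch assumes Keras's default use_bias=True for every Conv2D and Dense layer, and that Activation, Add, MaxPooling2D, Flatten and Dropout carry no weights; the helper names are made up for illustration.

```python
def conv_params(kh, kw, in_ch, filters):
    # kernel weights plus one bias per filter
    return kh * kw * in_ch * filters + filters


def dense_params(n_in, n_out):
    # weight matrix plus one bias per output unit
    return n_in * n_out + n_out


def stack_params(in_ch, f):
    return (conv_params(1, 1, in_ch, f)      # 1x1 linear projection
            + conv_params(3, 2, f, f)        # unit 1, first conv uses a (3, 2) kernel
            + conv_params(3, 3, f, f)        # unit 1, second conv
            + 2 * conv_params(3, 3, f, f))   # unit 2, two 3x3 convs


def resnet_params(input_shape=(32, 32, 2), classes=8, f=40, num_stacks=5):
    h, w, in_ch = input_shape
    total = 0
    for _ in range(num_stacks):
        total += stack_params(in_ch, f)
        in_ch = f
        h, w = h // 2, w // 2  # MaxPooling2D(pool_size=2)
    flat = h * w * f           # features coming out of Flatten()
    total += dense_params(flat, 128) + dense_params(128, 128) + dense_params(128, classes)
    return total


print(resnet_params())  # → 294272
```

If the layer reading above is right, this should agree with the "Total params" line that model.summary() prints for ResNet((32, 32, 2), 8).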


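model = ResNet((32, 32, 2), 8)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])


print('Load complete!')
print('\n')


# val_length_train, val_length, batchsize, train_gen, valid_gen and tensorboard
# come from the data-loading part of the script (not shown)
steps = val_length_train // batchsize
valid_steps = val_length // batchsize

# verbose=0 suppresses the Keras progress bar and per-epoch output
history = model.fit_generator(
            generator=train_gen,
            epochs=3,
            verbose=0,
            steps_per_epoch=steps,
            validation_data=valid_gen,
            validation_steps=valid_steps,
            callbacks=[tensorboard])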

print("what happened?")
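Because verbose=0 hides all per-epoch output, it is hard to tell from the outside whether training ran at all. One cheap thing to rule out is the step arithmetic: steps and valid_steps use floor division, which silently drops a trailing partial batch and returns 0 if a dataset is smaller than batchsize. A small sketch with made-up sizes (count_steps is a hypothetical helper, not a Keras function):

```python
def count_steps(num_samples, batch_size, drop_remainder=True):
    """Steps needed to cover num_samples in batches of batch_size.

    drop_remainder=True mirrors the floor division used for steps/valid_steps:
    a final partial batch is skipped, and the result is 0 when
    num_samples < batch_size. drop_remainder=False rounds up instead.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_remainder:
        return num_samples // batch_size
    return -(-num_samples // batch_size)  # integer ceiling division


# hypothetical sizes; the real val_length_train and batchsize are not shown
print(count_steps(1000, 64))                        # → 15
print(count_steps(1000, 64, drop_remainder=False))  # → 16
print(count_steps(50, 64))                          # → 0
```

Before the fit_generator call, something like `assert steps > 0 and valid_steps > 0` rules out the zero case. I'm not certain how Keras 2.3 reacts to a zero step count, so treat this as one possibility to eliminate rather than a diagnosis.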
Robi Sen

1 Answer


Sort of. If there is an error, it will still be thrown and printed even if verbose is 0. That being said, verbose=0 seems to cause issues for some people. This post is from 2017, but I've seen the same issue as recently as Nov 2019: https://github.com/keras-team/keras/issues/5818. If I use 0 or 2, things work fine, but all of that is irrelevant since the script never seems to start grabbing data or training. I appreciate the feedback.
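To see whether fit_generator ever starts pulling data, the generators can be wrapped so that every batch is logged as it is fetched. This works regardless of the verbose setting. A minimal sketch; logged is a hypothetical helper, and it assumes each item the generator yields is an (x, y) pair.

```python
def logged(gen, name, every=1, log=print):
    """Pass batches through unchanged, logging each one as it is pulled."""
    for i, batch in enumerate(gen):
        if i % every == 0:
            x, y = batch
            log(f"{name}: batch {i} ({len(x)} samples)")
        yield batch
```

Usage would be something like `train_gen = logged(batch_generator(X_train, Y_train), 'train')` before the fit_generator call. Note that fit_generator may pull several batches ahead into a queue (its max_queue_size argument), so the log counts batches fetched, not batches trained on.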

Robi Sen