
I am working on an image classification problem with a CNN. I am trying to build my own custom convolution layer for a deep learning architecture in MATLAB, and wrote the following class for it:

classdef CustomConvolutionLayer < nnet.layer.Layer % & nnet.layer.Acceleratable
    properties
        FilterSize
        NumFilters
    end

    properties (Learnable)
        % Layer learnable parameters
        Weights
    end

    methods
        function layer = CustomConvolutionLayer(filterSize)
            % Set layer name.
            layer.Name = "SC";
            % Set layer description.
            layer.Description = "Sparse Convolution";
            % Initialize layer weights.
            layer.FilterSize = filterSize;
            layer.NumFilters = 1;
            layer.Weights = randn(filterSize);
        end

        function Z = predict(layer, X)
            Z = conv2(X, layer.Weights, 'same');
        end

        function [dLdX, dLdW] = backward(layer, X, ~, dLdZ, ~)
            dLdW = zeros(size(layer.Weights), 'like', layer.Weights);
            tW = conv2(X, dLdZ, 'same');
            fvalue = mean2(tW(:));
            dLdW(:,:) = fvalue;
            dLdX = conv2(dLdZ, rot90(layer.Weights, 2), 'same');
        end
    end
end
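During training, `backward` receives the same H×W×C×N mini-batches as `predict`, so the gradients have to be accumulated over the fourth dimension. Below is a hedged sketch of gradients consistent with a `'same'` convolution on such a batch. It assumes a single input channel, one filter, and an odd square kernel; the sizes `H`, `k`, `N` and the random `dLdZ` are made-up test values, and the inline loop only stands in for what a batch-aware `backward` might compute. Unlike the method above, it sums a full k×k weight gradient over every observation instead of filling all weights with one averaged value.

```matlab
% Hypothetical mini-batch in the H x W x C x N layout (C = 1 channel).
H = 8;               % image height and width (square, for brevity)
k = 5;               % odd kernel size, like [5 5] in the question
p = (k - 1) / 2;     % offset that 'same' uses for an odd kernel
N = 3;               % number of observations in the batch
X = randn(H, H, 1, N);
W = randn(k, k);

% Forward pass: convn applies the 2-D kernel to each observation separately.
Z = convn(X, W, 'same');

% Stand-in for the gradient that the next layer would pass back.
dLdZ = randn(size(Z));

% Input gradient: 'same' convolution with the kernel rotated by 180 degrees.
dLdX = convn(dLdZ, rot90(W, 2), 'same');

% Weight gradient: correlate each zero-padded image with its dL/dZ slice
% and sum the k-by-k results over the batch.
dLdW = zeros(size(W));
for n = 1:N
    Xp = zeros(H + 2*p, H + 2*p);
    Xp(p+1:p+H, p+1:p+H) = X(:, :, 1, n);
    c = conv2(Xp, rot90(dLdZ(:, :, 1, n), 2), 'valid');   % k-by-k result
    dLdW = dLdW + rot90(c, 2);
end
```

A central finite difference on the scalar loss `sum(Z .* dLdZ)` is a quick way to check such gradients, since that loss is linear in both `W` and `X`.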

When I try to use the custom layer in the following deep learning architecture:

inputSize = [256 256 1];
numClasses = 25;

layers = [
    imageInputLayer(inputSize)
    sparseConvolutionLayer03([5 5])
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

options = trainingOptions('sgdm', ...
    'MaxEpochs',1, ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',7, ...
    'Verbose',true, ...
    'Plots','training-progress');

net = trainNetwork(imdsTrain,layers,options);
disp('Training Over...')

I get the following error:

Error in funCNNclassifier01 (line 34)
net = trainNetwork(imdsTrain,layers,options);
Caused by:
    Error using conv2
    N-D arrays are not supported.
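The failure can be reproduced without `trainNetwork`. This sketch builds random data with the size reported for the layer input (256×256×1×128) and a hypothetical 5×5 kernel. `conv2` rejects the 4-D array, while `convn` with the same 2-D kernel filters every 256×256 slice independently and keeps the channel and batch dimensions. Assuming `X` reaches `predict` as an ordinary numeric array with one channel, `Z = convn(X, layer.Weights, 'same');` is one possible replacement for the `conv2` call, though it has not been tested inside the layer here.

```matlab
% Random data with the size reported in the question: H x W x C x N.
X = rand(256, 256, 1, 128);
W = randn(5, 5);   % hypothetical 5x5 kernel, as in CustomConvolutionLayer([5 5])

% conv2 only accepts 2-D inputs, so the 4-D batch is rejected.
try
    conv2(X, W, 'same');
    conv2Failed = false;
catch err
    conv2Failed = true;
    disp(err.message)
end

% convn treats the 5x5 kernel as 5x5x1x1 and convolves each 256x256 slice
% on its own, so size(Z) is [256 256 1 128].
Z = convn(X, W, 'same');
disp(size(Z))
```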

My input images are only 256×256, but when I check the size of the input that reaches the layer, it is 256×256×1×128, which is why this error is thrown. I don't understand why or how the input changes from 2-D to 4-D. Any help would be greatly appreciated.
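The fourth dimension is the mini-batch: `trainNetwork` passes images to each layer in groups arranged as height × width × channels × observations, and `trainingOptions` uses a `MiniBatchSize` of 128 unless told otherwise. This sketch only inspects two options objects; the value 32 is an arbitrary example. Changing it shrinks the fourth dimension rather than removing it, so a custom layer still has to handle 4-D input.

```matlab
% Default options: MiniBatchSize is 128, matching the 4th dimension seen in the layer.
defaultOpts = trainingOptions('sgdm');
disp(defaultOpts.MiniBatchSize)

% A smaller (arbitrary) batch size: the layer would then receive
% 256x256x1x32 arrays, which are still 4-D.
smallOpts = trainingOptions('sgdm', ...
    'MaxEpochs',1, ...
    'MiniBatchSize',32);
disp(smallOpts.MiniBatchSize)
```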
