I'm trying to train a classifier on a set of text datasets at the character level (for example, "a cat" => "a", " ", "c", "a", "t") so that I can classify the texts accurately. I'm using the mxnet package in R with the Crepe model (a CNN). To prepare for training, I need to create iterators for both the training and test datasets. The code is as follows:

train.iter <- CustomCSVIter$new(iter=NULL, data.csv=train.file.output, 
                            batch.size=args$batch_size, alphabet=alphabet,
                            feature.len=feature.len)  
test.iter <- CustomCSVIter$new(iter=NULL, data.csv=test.file.output, 
                           batch.size=args$batch_size, alphabet=alphabet, 
                           feature.len=feature.len)

Here data.csv is the CSV file that holds the dataset, batch.size and feature.len are both just integers, and alphabet is a vector of alphanumeric characters and punctuation (abcd...?!""). When I run the above code, I get a message saying there was a fatal error, and RStudio crashes and reloads. I don't know what I'm doing wrong. To run the above code, you need the following class definition:

CustomCSVIter <- setRefClass("CustomCSVIter",
                         fields=c("iter", "data.csv", "batch.size",
                                  "alphabet","feature.len"),
                         contains = "Rcpp_MXArrayDataIter",
                         methods=list(
                           initialize=function(iter, data.csv, batch.size,
                                               alphabet, feature.len){
                             csv_iter <- mx.io.CSVIter(data.csv=data.csv, 
                                                       data.shape=feature.len+1, #=features + label
                                                       batch.size=batch.size)
                             .self$iter <- csv_iter 
                             .self$data.csv <- data.csv
                             .self$batch.size <- batch.size
                             .self$alphabet <- alphabet
                             .self$feature.len <- feature.len
                             .self
                           },
                           value=function(){
                             val <- as.array(.self$iter$value()$data)
                             val.y <- val[1,]
                             val.x <- val[-1,]
                             val.x <- dict.decoder(data=val.x, 
                                                   alphabet=.self$alphabet,
                                                   feature.len=.self$feature.len,
                                                   batch.size=.self$batch.size)
                             val.x <- mx.nd.array(val.x)
                             val.y <- mx.nd.array(val.y)
                             list(data=val.x, label=val.y)
                           },
                           iter.next=function(){
                             .self$iter$iter.next()
                           },
                           reset=function(){
                             .self$iter$reset()
                           },
                           num.pad=function(){
                             .self$iter$num.pad()
                           },
                           finalize=function(){
                             .self$iter$finalize()
                           }
                         )

)

Phillip1982

1 Answer

Usually a problem like this arises when the row length in the input file does not match the data.shape parameter of the iterator.

You can easily check whether this is the problem by running your code outside of RStudio. Start R from a terminal/command line and paste your code there. When the exception happens, it will terminate the R session, but the exception message will remain visible in the terminal. In my case it was:

Check failed: row.length == shape.Size() (2 vs. 1) The data size in CSV do not match size of shape: specified shape=(1,), the csv row-length=2

In your case it is probably something similar. Btw, there is an implementation of a custom iterator for the MNIST dataset, which you may find useful: https://github.com/apache/incubator-mxnet/issues/4105#issuecomment-266190690
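
If the cause is indeed a shape mismatch, it may be possible to catch it before the iterator is created by comparing the number of fields per CSV row with the data.shape you pass (feature.len+1 in your code). Below is a minimal sketch using only base R. The temporary file and its contents are made up for illustration and stand in for your train.file.output, and I am assuming a plain comma-separated file with no header row.

```r
# Illustrative stand-in for train.file.output: two rows, each with
# 1 label followed by 3 features (made-up values).
csv.path <- tempfile(fileext = ".csv")
writeLines(c("1,3,5,7", "2,4,6,8"), csv.path)

feature.len <- 3                   # assumed number of features per row
expected.len <- feature.len + 1    # features + label, as in data.shape above

# count.fields() returns the number of fields found on each line of the file.
row.lengths <- unique(count.fields(csv.path, sep = ","))

# TRUE only when every row has exactly expected.len fields.
shape.ok <- length(row.lengths) == 1 && row.lengths == expected.len

if (!shape.ok) {
  message("Row lengths found: ", paste(row.lengths, collapse = ", "),
          "; data.shape expects: ", expected.len)
}
```

If shape.ok comes out FALSE for your real files, either the CSV rows or feature.len would need adjusting before calling CustomCSVIter$new.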

Sergei