I followed example/warpctc/lstm_ocr.py to train a model, and I have saved a checkpoint: mymodel-0100.params and mymodel-symbol.json.
How can I make a prediction with this checkpoint using only one image?
I tried to use the Predictor interface with the code below:
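# Load the pre-trained model
# (Predictor comes from amalgamation/python/mxnet_predict.py)
from mxnet_predict import Predictor

symbol_file = "mymodel-symbol.json"
param_file = "mymodel-0100.params"
# The .params file is binary, so open it in "rb" mode
predictor = Predictor(open(symbol_file).read(),
                      open(param_file, "rb").read(),
                      {'data': (80, 30)})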
But the data shape always raises an error, and I don't know how to set this value. Can anybody help? Thank you.
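For reference, a minimal NumPy sketch of preparing one image and a shape dictionary for Predictor. It assumes that the training iterator (OCRIter in lstm_ocr.py) fed each sample as a grayscale image resized to 80x30, transposed, flattened to 80 * 30 values and scaled to [0, 1], with a leading batch dimension; that the unrolled LSTM has state inputs named l%d_init_c and l%d_init_h; and that the layer count, hidden size and label length below are hypothetical placeholders to be replaced by the values used in training. Check all of these against your copy of the script.

```python
import numpy as np

# Hypothetical settings: replace them with the values used in training
# (num_lstm_layer, num_hidden and num_label in lstm_ocr.py).
NUM_LSTM_LAYER = 2
NUM_HIDDEN = 100
NUM_LABEL = 4


def preprocess(gray_img):
    """Turn one (30, 80) grayscale image (height x width, as produced by
    cv2.resize(img, (80, 30))) into a batch of one flattened sample with
    shape (1, 80 * 30) and values scaled to [0, 1]."""
    img = np.asarray(gray_img, dtype=np.float32)
    if img.shape != (30, 80):
        raise ValueError("expected a (30, 80) image, got %s" % (img.shape,))
    img = img.transpose(1, 0)             # (80, 30): assumed to match training
    flat = img.reshape((80 * 30,)) / 255.0
    return flat[np.newaxis, :]            # add the batch axis -> (1, 2400)


def input_shapes(batch_size=1):
    """Shape of every input that is not a trained parameter: 'data', the
    per-layer LSTM states (names assumed to be l%d_init_c / l%d_init_h),
    and 'label', which the saved symbol may still require."""
    shapes = {'data': (batch_size, 80 * 30),
              'label': (batch_size, NUM_LABEL)}
    for layer in range(NUM_LSTM_LAYER):
        shapes['l%d_init_c' % layer] = (batch_size, NUM_HIDDEN)
        shapes['l%d_init_h' % layer] = (batch_size, NUM_HIDDEN)
    return shapes
```

With these helpers the constructor call would become `Predictor(open(symbol_file).read(), open(param_file, "rb").read(), input_shapes(1))`, followed by `predictor.forward(data=preprocess(img), ...)` with an all-zeros array of shape (1, NUM_HIDDEN) for each state input, and `predictor.get_output(0)` to read the network output.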
I also tried another approach: adding one line of code at the end of mxnet/example/warpctc/lstm_ocr.py:
model = mx.model.FeedForward(ctx=contexts,
                             symbol=symbol,
                             num_epoch=num_epoch,
                             learning_rate=learning_rate,
                             momentum=momentum,
                             wd=0.00001,
                             initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
model.fit(X=data_train, eval_data=data_val,
          eval_metric=mx.metric.np(Accuracy),
          batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50))
model.save("ocr")
# new line added for prediction
model.predict(data_val)
But it always fails with the following error:
Traceback (most recent call last):
  File "lstm_ctc_ocr.py", line 211, in <module>
    training_all()
  File "lstm_ctc_ocr.py", line 188, in training_all
    model.predict(data_val)
  File "/home/bobliu/Work/code/DL/mxnet/python/mxnet/model.py", line 618, in predict
    self._init_predictor(data_shapes, type_dict)
  File "/home/bobliu/Work/code/DL/mxnet/python/mxnet/model.py", line 541, in _init_predictor
    self.ctx[0], grad_req='null', type_dict=type_dict, **dict(input_shapes))
  File "/home/bobliu/Work/code/DL/mxnet/python/mxnet/symbol.py", line 685, in simple_bind
    arg_types, _, aux_types = self.infer_type(**type_dict)
  File "/home/bobliu/Work/code/DL/mxnet/python/mxnet/symbol.py", line 417, in infer_type
    ctypes.byref(complete)))
  File "/home/bobliu/Work/code/DL/mxnet/python/mxnet/base.py", line 77, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: InferType Error in reshape0: [21:39:03] src/operator/./reshape-inl.h:345: Check failed: (dtype) != (-1) First input must have specified type
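A possible reading of this error, which the traceback alone does not confirm: if the example's LSTM symbol reshapes `label` before the WarpCTC loss, and if `predict()` builds `type_dict` only from the trained parameters and the iterator's `provide_data` (the `_init_predictor` frame shows a `type_dict` being passed), then `label` reaches `infer_type` with no dtype, and `reshape0` cannot infer its input type. Binding with an explicit shape and dtype for every input would avoid that. The NumPy sketch below builds such inputs; the dummy all-zeros `label` and the state names `l%d_init_c` / `l%d_init_h` are assumptions, not something the example is known to require.

```python
import numpy as np


def make_feed(data, num_lstm_layer, num_hidden, num_label=4):
    """Build a float32 value for every non-parameter input of the OCR symbol.

    'label' is a dummy all-zeros array, assumed to be needed only so that
    the Reshape feeding the CTC loss gets a shape and a dtype. The state
    names l%d_init_c / l%d_init_h are also assumptions."""
    data = np.asarray(data, dtype=np.float32)
    batch = data.shape[0]
    feed = {'data': data,
            'label': np.zeros((batch, num_label), dtype=np.float32)}
    for layer in range(num_lstm_layer):
        for kind in ('c', 'h'):
            feed['l%d_init_%s' % (layer, kind)] = np.zeros(
                (batch, num_hidden), dtype=np.float32)
    return feed


def shapes_and_types(feed):
    """Split a feed dict into a shape dict and a dtype dict.

    dtype.type gives the NumPy scalar class (e.g. np.float32), the same
    kind of value that FeedForward puts into type_dict."""
    shapes = {name: value.shape for name, value in feed.items()}
    types = {name: value.dtype.type for name, value in feed.items()}
    return shapes, types
```

These dictionaries could then be passed as `symbol.simple_bind(mx.cpu(), grad_req='null', type_dict=types, **shapes)`. After binding, the weights from `mx.model.load_checkpoint(prefix, epoch)` can be loaded with `copy_params_from(arg_params, aux_params)`, each feed array copied in with `exe.arg_dict[name][:] = value`, and the result read from `exe.outputs[0]` after `exe.forward(is_train=False)`.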