2

I just generated an .mlmodel with the Create ML application from this feature set. The image below shows the input features I selected for training the model.

feature set for model training

After training finished, the model preview shows an extra input feature vector named stateIn (LSTM state input). The image below shows the preview of all input feature vectors, including this LSTM input.

Extra input vector

I have no idea why this input was added to the model's prediction inputs or what should be passed to this parameter. Below is the class that Xcode generated automatically when I dragged the .mlmodel into the project. I need an explanation of the stateIn input vector.

import CoreML


/// Prediction input for the generated Core ML model wrapper.
///
/// Holds the six sensor-window arrays plus the recurrent `stateIn` array and
/// hands each one to Core ML when it is asked for by feature name.
@available(macOS 10.13, iOS 11.0, tvOS 11.0, watchOS 4.0, *)
class ReloadInput : MLFeatureProvider {

    /// Accelerometer X window (model description: 100 element vector of doubles).
    var accX: MLMultiArray

    /// Accelerometer Y window (model description: 100 element vector of doubles).
    var accY: MLMultiArray

    /// Accelerometer Z window (model description: 100 element vector of doubles).
    var accZ: MLMultiArray

    /// Gyroscope X window (model description: 100 element vector of doubles).
    var gyroX: MLMultiArray

    /// Gyroscope Y window (model description: 100 element vector of doubles).
    var gyroY: MLMultiArray

    /// Gyroscope Z window (model description: 100 element vector of doubles).
    var gyroZ: MLMultiArray

    /// LSTM state input (model description: 400 element vector of doubles).
    ///
    /// NOTE(review): for Create ML activity classifiers this is usually seeded
    /// with zeros on the first prediction and then fed the previous
    /// prediction's `stateOut` output — confirm against the model's outputs.
    var stateIn: MLMultiArray

    /// Every input name this provider can answer for.
    var featureNames: Set<String> {
        ["accX", "accY", "accZ", "gyroX", "gyroY", "gyroZ", "stateIn"]
    }

    /// Returns the stored array for `featureName` wrapped as a feature value,
    /// or `nil` when the name is not one of `featureNames`.
    func featureValue(for featureName: String) -> MLFeatureValue? {
        let selected: MLMultiArray
        switch featureName {
        case "accX":
            selected = accX
        case "accY":
            selected = accY
        case "accZ":
            selected = accZ
        case "gyroX":
            selected = gyroX
        case "gyroY":
            selected = gyroY
        case "gyroZ":
            selected = gyroZ
        case "stateIn":
            selected = stateIn
        default:
            return nil
        }
        return MLFeatureValue(multiArray: selected)
    }

    /// Stores each input array as given; no shape validation is done here.
    init(accX: MLMultiArray, accY: MLMultiArray, accZ: MLMultiArray, gyroX: MLMultiArray, gyroY: MLMultiArray, gyroZ: MLMultiArray, stateIn: MLMultiArray) {
        self.accX = accX
        self.accY = accY
        self.accZ = accZ
        self.gyroX = gyroX
        self.gyroY = gyroY
        self.gyroZ = gyroZ
        self.stateIn = stateIn
    }

    /// Convenience initializer that converts each shaped array to an
    /// `MLMultiArray` before delegating to the designated initializer.
    @available(macOS 12.0, iOS 15.0, tvOS 15.0, watchOS 8.0, *)
    convenience init(accX: MLShapedArray<Double>, accY: MLShapedArray<Double>, accZ: MLShapedArray<Double>, gyroX: MLShapedArray<Double>, gyroY: MLShapedArray<Double>, gyroZ: MLShapedArray<Double>, stateIn: MLShapedArray<Double>) {
        let accXArray = MLMultiArray(accX)
        let accYArray = MLMultiArray(accY)
        let accZArray = MLMultiArray(accZ)
        let gyroXArray = MLMultiArray(gyroX)
        let gyroYArray = MLMultiArray(gyroY)
        let gyroZArray = MLMultiArray(gyroZ)
        let stateArray = MLMultiArray(stateIn)
        self.init(accX: accXArray, accY: accYArray, accZ: accZArray, gyroX: gyroXArray, gyroY: gyroYArray, gyroZ: gyroZArray, stateIn: stateArray)
    }

}
Qazi Ammar
  • 953
  • 1
  • 8
  • 23

0 Answers0