I'm using a JDK 21 early-access (EA) build to test the performance of the Vector API.
My original (non-vector) code looks like this:
    double[] src;
    double divisor;
    float[] dst;
    for (int i = 0; i < src.length; ++i) {
        if (src[i] <= 0.0d) {
            dst[i] = -1.0f;
        }
        else {
            double v = src[i] / divisor;
            float f = 10.0f * log10((float) v);
            dst[i] = f;
        }
    }
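For checking a vectorized version against known-good output, the scalar loop above can be wrapped in a small standalone class. This is only a sketch: the class name `ScalarReference` and the method `convert` are made up for illustration, and it assumes that `log10` in the snippet above means `Math.log10` with the result narrowed to float.

```java
import java.util.Arrays;

// Hypothetical reference harness; names are illustrative, not from the original code.
final class ScalarReference {
    // Scalar version of the loop above. dst must be at least as long as src.
    static void convert(double[] src, double divisor, float[] dst) {
        for (int i = 0; i < src.length; ++i) {
            if (src[i] <= 0.0d) {
                dst[i] = -1.0f;
            } else {
                float v = (float) (src[i] / divisor);
                // Assumption: log10 in the original snippet is Math.log10.
                dst[i] = 10.0f * (float) Math.log10(v);
            }
        }
    }

    public static void main(String[] args) {
        double[] src = {100.0, 0.0, -5.0, 10.0};
        float[] dst = new float[src.length];
        convert(src, 1.0, dst);
        System.out.println(Arrays.toString(dst));
    }
}
```

Running the same input through a vectorized loop and comparing the two output arrays element by element (with a small tolerance, since vectorized math routines may round differently) would show whether the rewrite is correct.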
I've rewritten it to use the Vector API (I'm sure there are bugs in this, as I haven't checked the results):
    double[] src;
    double divisor;
    float[] dst;
    for (int j = 0; j < src.length; j += DoubleVector.SPECIES_512.length()) {
        DoubleVector dvsrc = DoubleVector.fromArray(DoubleVector.SPECIES_512, src, j);
        FloatVector fvdst = FloatVector.fromArray(FloatVector.SPECIES_256, dst, j);
        // Identify values <= 0.0d
        VectorMask<Double> mask = dvsrc.compare(VectorOperators.LE, 0.0d);
        // Set those values to -1.0f
        fvdst.blend(-1.0f, mask.cast(FloatVector.SPECIES_256));
        // Invert the mask, so it now points to values > 0.0d
        mask = mask.not();
        // Divide and cast to float
        FloatVector x = dvsrc.div(divisor, mask).reinterpretAsFloats();
        VectorMask<Float> floatMask = mask.cast(FloatVector.SPECIES_256);
        // Log10 and multiply
        FloatVector processed = x.lanewise(VectorOperators.LOG10).mul(10.0f);
        // Store
        fvdst.blend(processed, floatMask);
    }
The code above compiles, but it's a long way from working correctly.
First off, the conversion of the DoubleVector (dvsrc) to a FloatVector (x) is wrong: dvsrc.length() is 8, but x.length() is 16 (I assume this is because each double ends up as 2 floats?).
This presumably also means that the mask no longer lines up with the lanes of x.
Then the final blend into fvdst fails with the exception below, because fvdst is 256 bits wide whereas processed is 512 bits wide.
java.lang.ClassCastException: class jdk.incubator.vector.Float512Vector cannot be cast to class jdk.incubator.vector.Float256Vector (jdk.incubator.vector.Float512Vector and jdk.incubator.vector.Float256Vector are in module jdk.incubator.vector of loader 'bootstrap')
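On the 8-versus-16 lane count mentioned above: `reinterpretAsFloats()` is documented as viewing the same bits as float lanes rather than converting each value, which would explain the doubling. The plain-Java sketch below (no Vector API involved; the class `ReinterpretVsConvert` and its methods are made up for illustration) shows the difference between the two operations on 8 doubles.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Plain-Java analogy only; this does not use jdk.incubator.vector.
final class ReinterpretVsConvert {
    // Value conversion: one float per double, so the element count is unchanged.
    static float[] convert(double[] in) {
        float[] out = new float[in.length];
        for (int i = 0; i < in.length; i++) {
            out[i] = (float) in[i];
        }
        return out;
    }

    // Bit reinterpretation: the same bytes viewed as 32-bit floats,
    // so every 64-bit double yields two unrelated float values.
    static float[] reinterpret(double[] in) {
        ByteBuffer buf = ByteBuffer.allocate(in.length * Double.BYTES)
                                   .order(ByteOrder.LITTLE_ENDIAN);
        buf.asDoubleBuffer().put(in);
        float[] out = new float[in.length * 2];
        buf.asFloatBuffer().get(out);
        return out;
    }

    public static void main(String[] args) {
        double[] eight = {1, 2, 3, 4, 5, 6, 7, 8};
        System.out.println("converted lanes:     " + convert(eight).length);
        System.out.println("reinterpreted lanes: " + reinterpret(eight).length);
        // The first double (1.0) becomes two floats that are not 1.0.
        float[] r = reinterpret(eight);
        System.out.println("first two reinterpreted values: " + r[0] + ", " + r[1]);
    }
}
```

If that is what is happening here, then a value-converting operation that yields 8 float lanes (a 256-bit float vector) would also keep the mask aligned lane for lane; whether `convertShape()` is the right call for that is exactly what I haven't managed to work out.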
I've tried the JDK docs, Google searches, and different ways of converting the DoubleVector to a FloatVector, including convertShape(), but so far I've been unable to get this code to run.
I'm having difficulty getting any real information from the JDK docs.
Can anyone suggest how to fix this code properly, or where I can find more detailed information about what's going on?