I am implementing a radial basis function (RBF) in Halide, and although it runs correctly, it is quite slow. For each pixel I compute its distance to every control point, then take a weighted sum of these distances to produce the output. To loop over the control points and weights I use an RDom (see the code below). In this implementation, every pixel's computation reloads all of the 3000+ weights (and control points), hence the slow speed.
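Written out, the code below computes the following. Here $p(x,y)$ is the input pixel, $q_i$ is control point $i$ (`ctrl_pts_h(k, i)`), $w_{c,i}$ is its weight for output channel $c$ (`weights_h(c, i)`), and $N$ is `num_ctrl_pts`. The raw Euclidean distance is used directly as the basis function, as in the code.

```latex
d_i(x,y) = \sqrt{\sum_{k=0}^{2} \bigl(p_k(x,y) - q_{k,i}\bigr)^2},
\qquad
\mathrm{out}(x,y,c) = \sum_{i=0}^{N-1} w_{c,i}\, d_i(x,y), \quad c \in \{0,1,2\}
```

Every output value therefore depends on all $N$ control points and weights, so a naive loop nest evaluates $W \cdot H \cdot N$ distances for a $W \times H$ image.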
My question is how to take advantage of Halide's scheduling functionality in this case. What I want is to load a subset of the weights, compute partial weighted sums for a subset of the pixels, then load the next subset of weights, and so on until done. This preserves locality for each smaller group of weights, which is exactly the kind of optimization Halide is built for. Unfortunately, I haven't found anything that addresses this specific problem. The RDom seems to sit at a lower level of abstraction than the scheduling primitives, so it's unclear to me how to schedule it.
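The access pattern described above can be sketched as a cache-blocked loop nest in plain C++ (no Halide involved). An outer loop walks tiles of pixels, a middle loop walks blocks of control points and weights, and the per-pixel accumulation runs innermost, so each block of weights is reused by a whole tile while that tile's partial sums stay in cache. The flat interleaved layouts, the function name `rbf_blocked`, and the tile sizes `pixel_tile` and `ctrl_block` are hypothetical choices for illustration only.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Cache-blocked RBF weighted sum (plain C++ sketch). Hypothetical layouts:
//   image[p * 3 + k]    channel k of pixel p (interleaved RGB)
//   ctrl[i * 3 + k]     coordinate k of control point i
//   weights[i * 3 + c]  weight of control point i for output channel c
// pixel_tile and ctrl_block are tuning parameters and must be > 0.
std::vector<float> rbf_blocked(const std::vector<float>& image, std::size_t num_pixels,
                               const std::vector<float>& ctrl,
                               const std::vector<float>& weights, std::size_t num_ctrl,
                               std::size_t pixel_tile, std::size_t ctrl_block) {
    std::vector<float> out(num_pixels * 3, 0.0f);
    // Outer loop: one tile of pixels, whose partial sums stay in cache.
    for (std::size_t p0 = 0; p0 < num_pixels; p0 += pixel_tile) {
        const std::size_t p1 = std::min(p0 + pixel_tile, num_pixels);
        // Middle loop: one block of control points/weights, reused by every pixel in the tile.
        for (std::size_t i0 = 0; i0 < num_ctrl; i0 += ctrl_block) {
            const std::size_t i1 = std::min(i0 + ctrl_block, num_ctrl);
            for (std::size_t p = p0; p < p1; ++p) {
                const float* px = &image[p * 3];
                float* acc = &out[p * 3];  // partial sums carried over from earlier blocks
                for (std::size_t i = i0; i < i1; ++i) {
                    const float dr = px[0] - ctrl[i * 3 + 0];
                    const float dg = px[1] - ctrl[i * 3 + 1];
                    const float db = px[2] - ctrl[i * 3 + 2];
                    // Distance computed once per (pixel, control point), used for all channels.
                    const float dist = std::sqrt(dr * dr + dg * dg + db * db);
                    acc[0] += weights[i * 3 + 0] * dist;
                    acc[1] += weights[i * 3 + 1] * dist;
                    acc[2] += weights[i * 3 + 2] * dist;
                }
            }
        }
    }
    return out;
}
```

Because each pixel still adds its terms in ascending control-point order, the result matches an unblocked loop; only the traversal order changes. In Halide, a comparable schedule would presumably be built with `split` and `reorder` on the `rbf_ctrl_pts.update()` stage, which accept RVars as well as Vars, though the exact schedule would need to be worked out and measured.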
Alternative suggestions for implementing the weighted sum in Halide are also welcome. It doesn't have to use an RDom; I'm just not aware of any other way.
```cpp
// Assumes Var x, y, c, the Func pointers in_func, ctrl_pts_h and weights_h,
// and num_ctrl_pts are defined elsewhere.
Func rbf_ctrl_pts("rbf_ctrl_pts");
// Pure definition: initialize every output value to zero
rbf_ctrl_pts(x, y, c) = cast<float>(0);
// Reduction domain over the control points
RDom idx(0, num_ctrl_pts);
// Body of the reduction:
// subtract the control point from the pixel color
Expr red_sub   = (*in_func)(x, y, 0) - (*ctrl_pts_h)(0, idx);
Expr green_sub = (*in_func)(x, y, 1) - (*ctrl_pts_h)(1, idx);
Expr blue_sub  = (*in_func)(x, y, 2) - (*ctrl_pts_h)(2, idx);
// Take the L2 norm to get the distance
Expr dist = sqrt(red_sub * red_sub +
                 green_sub * green_sub +
                 blue_sub * blue_sub);
// Update definition: add this channel's weighted distance to the running sum
rbf_ctrl_pts(x, y, c) = select(c == 0, rbf_ctrl_pts(x, y, c) + (*weights_h)(0, idx) * dist,
                               c == 1, rbf_ctrl_pts(x, y, c) + (*weights_h)(1, idx) * dist,
                                       rbf_ctrl_pts(x, y, c) + (*weights_h)(2, idx) * dist);
```
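To make the loop nest concrete, here is a plain C++ sketch (not Halide code) of what the two definitions above compute, assuming Halide's default order for the update stage with the pure variables outside and the RDom innermost. The flat arrays `in`, `ctrl` and `w` are hypothetical stand-ins for `in_func`, `ctrl_pts_h` and `weights_h`, laid out to mirror their `(x, y, k)` and `(k, idx)` indexing; `rbf_reference` and `planar` are illustrative names.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical planar layout matching (*in_func)(x, y, k): x varies fastest.
static std::size_t planar(int x, int y, int k, int width, int height) {
    return static_cast<std::size_t>((k * height + y) * width + x);
}

// ctrl[idx * 3 + k] ~ (*ctrl_pts_h)(k, idx); w[idx * 3 + c] ~ (*weights_h)(c, idx).
std::vector<float> rbf_reference(const std::vector<float>& in, int width, int height,
                                 const std::vector<float>& ctrl, const std::vector<float>& w,
                                 int num_ctrl_pts) {
    // Pure definition: every output value starts at zero.
    std::vector<float> out(static_cast<std::size_t>(3 * width * height), 0.0f);
    // Update definition: pure loops (c, y, x) outside, reduction over idx innermost.
    for (int c = 0; c < 3; ++c)
        for (int y = 0; y < height; ++y)
            for (int x = 0; x < width; ++x)
                for (int idx = 0; idx < num_ctrl_pts; ++idx) {
                    const float r = in[planar(x, y, 0, width, height)] - ctrl[idx * 3 + 0];
                    const float g = in[planar(x, y, 1, width, height)] - ctrl[idx * 3 + 1];
                    const float b = in[planar(x, y, 2, width, height)] - ctrl[idx * 3 + 2];
                    // The distance is recomputed for each channel c.
                    const float dist = std::sqrt(r * r + g * g + b * b);
                    // For c in 0..2 the three-way select reduces to weights_h(c, idx).
                    out[planar(x, y, c, width, height)] += w[idx * 3 + c] * dist;
                }
    return out;
}
```

Under that assumed order, each `(x, y, c)` iteration streams through all `num_ctrl_pts` control points and weights, and the square root runs three times per (pixel, control point) pair because `dist` is evaluated once per channel.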