
I am implementing a radial basis function (RBF) in Halide, and while it runs correctly, it is quite slow. For each pixel I compute its distance to each control point, then take a weighted sum of these distances to produce the output. To loop over the weights I use an RDom (see the code below). In this implementation every pixel's computation reloads all of the (3000+) weights, hence the slow speed.
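For reference, the same computation can be written as plain C++ loops (not Halide). This is a minimal sketch assuming interleaved RGB buffers stored as `std::vector<float>`; the function name `rbf_naive` and the buffer layout are hypothetical. It makes the access pattern explicit: the inner loop walks over every control point and weight once per pixel.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive RBF evaluation (hypothetical layout, all buffers interleaved RGB):
//   pixels:   num_pixels   x 3
//   ctrl_pts: num_ctrl_pts x 3
//   weights:  num_ctrl_pts x 3  (one weight per output channel)
// Returns a num_pixels x 3 buffer of weighted distance sums.
std::vector<float> rbf_naive(const std::vector<float>& pixels,
                             const std::vector<float>& ctrl_pts,
                             const std::vector<float>& weights) {
    assert(pixels.size() % 3 == 0 && ctrl_pts.size() % 3 == 0);
    assert(ctrl_pts.size() == weights.size());
    const std::size_t num_pixels = pixels.size() / 3;
    const std::size_t num_ctrl_pts = ctrl_pts.size() / 3;
    std::vector<float> out(num_pixels * 3, 0.0f);
    for (std::size_t p = 0; p < num_pixels; ++p) {
        // Each pixel streams through every control point and weight.
        for (std::size_t i = 0; i < num_ctrl_pts; ++i) {
            const float dr = pixels[3 * p + 0] - ctrl_pts[3 * i + 0];
            const float dg = pixels[3 * p + 1] - ctrl_pts[3 * i + 1];
            const float db = pixels[3 * p + 2] - ctrl_pts[3 * i + 2];
            const float dist = std::sqrt(dr * dr + dg * dg + db * db);
            for (std::size_t c = 0; c < 3; ++c) {
                out[3 * p + c] += weights[3 * i + c] * dist;
            }
        }
    }
    return out;
}
```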

My question is how to take advantage of Halide's scheduling functionality in this case. What I want is to load a subset of the weights, compute partial weighted sums for a subset of the pixels, load the next set of weights, and continue until done. This keeps each smaller group of weights local, which is exactly the kind of thing Halide is built for. Unfortunately I haven't found anything that addresses this specific problem. The RDom seems to sit at a lower level of abstraction than the scheduling primitives, so it's unclear how to schedule it.
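The loop order described above can be sketched in plain C++ (again not Halide) as a tiled loop nest: pixels are processed in tiles, and for each tile the control points and weights are consumed one chunk at a time, so a chunk is reused by every pixel in the tile while it is still in cache. The names `rbf_blocked`, `tile_size`, and `chunk_size` are hypothetical, and the interleaved-RGB layout is an assumption. Each pixel still accumulates its terms in the same order as the straightforward loop, so only the memory access pattern changes.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Tiled RBF evaluation (hypothetical names, interleaved-RGB N x 3 buffers).
// Outer loops: a tile of pixels, then a chunk of control points/weights.
// Inner loops: every pixel in the tile consumes every point in the chunk.
std::vector<float> rbf_blocked(const std::vector<float>& pixels,
                               const std::vector<float>& ctrl_pts,
                               const std::vector<float>& weights,
                               std::size_t tile_size,
                               std::size_t chunk_size) {
    assert(pixels.size() % 3 == 0 && ctrl_pts.size() % 3 == 0);
    assert(ctrl_pts.size() == weights.size());
    assert(tile_size > 0 && chunk_size > 0);
    const std::size_t num_pixels = pixels.size() / 3;
    const std::size_t num_ctrl_pts = ctrl_pts.size() / 3;
    std::vector<float> out(num_pixels * 3, 0.0f);
    for (std::size_t p0 = 0; p0 < num_pixels; p0 += tile_size) {
        const std::size_t p1 = std::min(p0 + tile_size, num_pixels);
        for (std::size_t i0 = 0; i0 < num_ctrl_pts; i0 += chunk_size) {
            const std::size_t i1 = std::min(i0 + chunk_size, num_ctrl_pts);
            // Partial weighted sums for this tile, using only this chunk.
            for (std::size_t p = p0; p < p1; ++p) {
                for (std::size_t i = i0; i < i1; ++i) {
                    const float dr = pixels[3 * p + 0] - ctrl_pts[3 * i + 0];
                    const float dg = pixels[3 * p + 1] - ctrl_pts[3 * i + 1];
                    const float db = pixels[3 * p + 2] - ctrl_pts[3 * i + 2];
                    const float dist = std::sqrt(dr * dr + dg * dg + db * db);
                    for (std::size_t c = 0; c < 3; ++c) {
                        out[3 * p + c] += weights[3 * i + c] * dist;
                    }
                }
            }
        }
    }
    return out;
}
```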

Alternative suggestions for implementing a weighted sum in Halide are also welcome. It doesn't have to use an RDom; I'm just not aware of any other way.

    // x, y, c are Vars declared elsewhere
    Func rbf_ctrl_pts("rbf_ctrl_pts");
    // Initialization: every output starts at zero
    rbf_ctrl_pts(x,y,c) = cast<float>(0);
    // Reduction domain over the control points
    RDom idx(0,num_ctrl_pts);
    // Subtract the control point from the input pixel, per channel
    Expr red_sub   = (*in_func)(x,y,0) - (*ctrl_pts_h)(0,idx);
    Expr green_sub = (*in_func)(x,y,1) - (*ctrl_pts_h)(1,idx);
    Expr blue_sub  = (*in_func)(x,y,2) - (*ctrl_pts_h)(2,idx);
    // Take the L2 norm to get the distance
    Expr dist      = sqrt( red_sub*red_sub +
                           green_sub*green_sub +
                           blue_sub*blue_sub );
    // Update step: accumulate the channel-c weight times the distance
    rbf_ctrl_pts(x,y,c) = select( c == 0, rbf_ctrl_pts(x,y,c) +
                                    ( (*weights_h)(0,idx) * dist),
                                  c == 1, rbf_ctrl_pts(x,y,c) +
                                    ( (*weights_h)(1,idx) * dist),
                                          rbf_ctrl_pts(x,y,c) +
                                    ( (*weights_h)(2,idx) * dist));
Kantthpel

1 Answer


You can use split (or tile) and rfactor on the idx dimension of rbf_ctrl_pts's update definition to factor the reduction and schedule it. Getting locality on the weights should be doable via these mechanisms. I'm not 100% sure the associative prover will handle the select, so it may be necessary to unroll by channels or move to a Tuple across the channels. That said, in the code above I'm not sure the select does anything compared to simply passing c through.
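To make the suggestion concrete, the plain C++ below models the two stages that splitting `idx` and applying `rfactor` to the outer part would produce: an intermediate holding one partial sum per chunk `u` of control points, and a merge stage that adds those partials. In Halide the schedule would look roughly like `RVar io, ii; Var u; rbf_ctrl_pts.update().split(idx.x, io, ii, K); Func intm = rbf_ctrl_pts.update().rfactor(io, u);`, after which `u` is an ordinary pure dimension of `intm` that can be tiled and reordered along with `x` and `y`. That schedule is an untested sketch; the names `io`, `ii`, `u`, `K`, `intm`, and `rbf_rfactored`, plus the interleaved layout, are assumptions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Plain C++ model of a reduction split into chunks and rfactor'ed
// (hypothetical names; interleaved-RGB buffers as N x 3 floats).
std::vector<float> rbf_rfactored(const std::vector<float>& pixels,
                                 const std::vector<float>& ctrl_pts,
                                 const std::vector<float>& weights,
                                 std::size_t chunk_size) {
    assert(pixels.size() % 3 == 0 && ctrl_pts.size() % 3 == 0);
    assert(ctrl_pts.size() == weights.size() && chunk_size > 0);
    const std::size_t num_pixels = pixels.size() / 3;
    const std::size_t n = ctrl_pts.size() / 3;
    const std::size_t num_chunks = (n + chunk_size - 1) / chunk_size;

    // Stage 1 (the intermediate): partial(u, p, c) sums only the control
    // points in chunk u. Entries for different (u, p) are independent, so
    // the u and p loops may be reordered or tiled freely.
    std::vector<float> partial(num_chunks * num_pixels * 3, 0.0f);
    for (std::size_t u = 0; u < num_chunks; ++u) {
        for (std::size_t p = 0; p < num_pixels; ++p) {
            for (std::size_t ii = 0; ii < chunk_size; ++ii) {
                const std::size_t i = u * chunk_size + ii;
                if (i >= n) break;  // the last chunk may be partial
                const float dr = pixels[3 * p + 0] - ctrl_pts[3 * i + 0];
                const float dg = pixels[3 * p + 1] - ctrl_pts[3 * i + 1];
                const float db = pixels[3 * p + 2] - ctrl_pts[3 * i + 2];
                const float dist = std::sqrt(dr * dr + dg * dg + db * db);
                for (std::size_t c = 0; c < 3; ++c) {
                    partial[(u * num_pixels + p) * 3 + c] += weights[3 * i + c] * dist;
                }
            }
        }
    }

    // Stage 2 (the merge): reduce the intermediate over u.
    std::vector<float> out(num_pixels * 3, 0.0f);
    for (std::size_t u = 0; u < num_chunks; ++u) {
        for (std::size_t k = 0; k < num_pixels * 3; ++k) {
            out[k] += partial[u * num_pixels * 3 + k];
        }
    }
    return out;
}
```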

Zalman Stern
  • Yeah, I think the last stage should just be: rbf_ctrl_pts(x,y,c) += ((*weights_h)(c,idx) * dist); – Andrew Adams Oct 25 '16 at 19:29
  • Thank you. From my understanding of Func definitions, though, rbf_ctrl_pts only has the dimensions x, y, and c. Can I use split/tile on idx even though it isn't in the function definition? – Kantthpel Oct 25 '16 at 21:21
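The simplification suggested in the comments can be checked in isolation. The plain C++ sketch below (hypothetical function names, with weights assumed interleaved as `weights[3 * idx + c]`) writes the per-channel update both ways: once as the conditional chain from the question's `select`, and once by indexing the weights with `c` directly. For the three valid channels `c` = 0, 1, 2 they produce the same value; they differ only for out-of-range `c`, where the chain falls through to channel 2. Dropping the chain also leaves a plain `+=` update, which sidesteps the answer's concern about whether the associative prover handles the `select`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-channel update written as the conditional chain from the question.
float update_with_select(float acc, const std::vector<float>& weights,
                         std::size_t idx, int c, float dist) {
    return c == 0 ? acc + weights[3 * idx + 0] * dist
         : c == 1 ? acc + weights[3 * idx + 1] * dist
                  : acc + weights[3 * idx + 2] * dist;  // any other c
}

// Same update with the weights indexed by the channel directly.
float update_direct(float acc, const std::vector<float>& weights,
                    std::size_t idx, int c, float dist) {
    assert(c >= 0 && c < 3);  // only defined for R, G, B
    return acc + weights[3 * idx + static_cast<std::size_t>(c)] * dist;
}
```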