
I am trying to perform a matrix-matrix multiplication from within a Generator. I know I should be using define_extern, as I have before with other functions, but for some reason with GEMM (i.e. the BLAS matrix-matrix multiply) I get a segfault.

Here's my code:

class TestBLAS : public Halide::Generator<TestBLAS> {
public:
    Input<Buffer<float>> A{"A", 2};
    Input<Buffer<float>> B{"B", 2};
    Output<Buffer<float>> C{"C", 2};

    Var x, y;

    void generate() {
        Func g;
        g.define_extern("hblas_sgemm", {false, false, 1.f, A, B, 0.f}, type_of<float>(), 2);

        C(x, y) = g(x, y);
    }
};

HALIDE_REGISTER_GENERATOR(TestBLAS, testblas)

In some apps/linear_algebra scripts, I found Halide::Runtime::Buffer::raw_buffer() being passed. How can I access this pointer from a Halide::GeneratorInput<Halide::Buffer<float> > or even a Halide::Func?

I understand that the unknown bounds of a Func would make this difficult, but is there perhaps a way to supply that information separately?

I can't seem to find a way in the docs.

UPDATE:

Following @Fabian's response, I haven't managed to initialize context from within the generator. The closest I got to working code was this:

...
GEMMGenerator<float> gemm;
void generate() {
    gemm.set_inputs(1., A, B, 0., B);
    C = gemm.result_;
}

This compiles, but when I run it, I get:

Condition failed: funcs_.size() == array_size() && exprs_.empty()
Aborted (core dumped)

Is there any way around this?

zanbri

1 Answer


Looking at your code, I assume you want to use another Generator inside generate(). In that case, this should work:

auto gen = context.create<GEMMGenerator<float>>();
gen->transpose_A_.set(false); // GeneratorParam
gen->apply(a, A, B, ...); // Inputs
Func res = gen->result_;
  • How am I meant to initialize `context`? Also, will this work only if I pass Inputs, or also Funcs? – zanbri Nov 20 '17 at 19:16
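
For reference when checking the output of either approach, here is a minimal plain C++ sketch of what a GEMM call computes, assuming the conventional BLAS meaning of the arguments (result = alpha * A * B + beta * C). It does not use Halide at all. The name reference_sgemm, the row-major flat-vector layout, and the explicit m, n, k sizes are assumptions made for illustration; they are not part of Halide or of the apps/linear_algebra code.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not a Halide or BLAS API).
// Returns alpha * A * B + beta * C, where A is m x k, B is k x n and
// C is m x n. All matrices are stored row-major in flat vectors.
std::vector<float> reference_sgemm(std::size_t m, std::size_t n, std::size_t k,
                                   float alpha,
                                   const std::vector<float> &A,
                                   const std::vector<float> &B,
                                   float beta,
                                   const std::vector<float> &C) {
    std::vector<float> result(m * n, 0.0f);
    for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            // Dot product of row i of A with column j of B.
            float acc = 0.0f;
            for (std::size_t p = 0; p < k; ++p) {
                acc += A[i * k + p] * B[p * n + j];
            }
            result[i * n + j] = alpha * acc + beta * C[i * n + j];
        }
    }
    return result;
}
```

With alpha = 1 and beta = 0, as in the snippets above, the C operand contributes nothing and the result is just the product A * B, which makes this handy for comparing against small test matrices.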