
I am building a simple feed-forward neural network at compile time using const generics and macros. The network is essentially a bunch of matrices applied one after the other.

I have created the network! macro, which works like this:

network!(2, 4, 1)

The first item is the number of inputs, and the rest are the number of neurons per layer. The macro looks as follows:

#[macro_export]
macro_rules! network {
    ( $inputs:expr, $($outputs:expr),* ) => {
        {
            Network {
                layers: [
                    $(
                        &Layer::<$inputs, $outputs>::new(),
                    )*
                ]
            }
        }
    };
}

It declares an array of layer elements. Each Layer uses const generics to hold a fixed-size array of weights: the first const parameter is the number of inputs it expects and the second is the number of outputs it produces.
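
For reference, the Layer definition itself is not shown here; a minimal sketch of what such a const-generic layer might look like (the field names and weight layout are assumptions, not the actual code):

struct Layer<const I: usize, const O: usize> {
    weights: [[f32; I]; O], // one row of I weights per output neuron (assumed layout)
    biases: [f32; O],
}

impl<const I: usize, const O: usize> Layer<I, O> {
    fn new() -> Self {
        Layer {
            weights: [[0.0; I]; O],
            biases: [0.0; O],
        }
    }
}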

This macro produces the following code:

Network {
    layers: [
         &Layer::<2, 4>::new(),
         &Layer::<2, 1>::new(),
    ]
}

This is completely wrong: each layer's number of inputs should be the number of outputs of the previous layer, like so (notice the second layer's 2 has become a 4):

Network {
    layers: [
         &Layer::<2, 4>::new(),
         &Layer::<4, 1>::new(),
    ]
}

To do this, I need to replace the $inputs value with the previous iteration's $outputs value on each iteration, but I have no clue how to do it.

    It's hard to answer your question because it doesn't include a [MRE]. We can't tell what crates (and their versions), types, traits, fields, etc. are present in the code. It would make it easier for us to help you if you try to reproduce your error on the [Rust Playground](https://play.rust-lang.org) if possible, otherwise in a brand new Cargo project, then [edit] your question to include the additional info. There are [Rust-specific MRE tips](//stackoverflow.com/tags/rust/info) you can use to reduce your original code for posting here. Thanks! – Shepmaster Mar 31 '21 at 13:48
  • Doing `&Layer::new()` seems likely to not actually work because you are taking a reference to a temporary. Have you tried to write this same code _without_ a macro first? – Shepmaster Mar 31 '21 at 13:53

2 Answers


You can match on the two leading values and then all the rest. Do something specific for the two values and call the macro recursively, reusing the second value:

struct Layer<const I: usize, const O: usize>;

macro_rules! example {
    // Do something interesting for a given pair of arguments
    ($a:literal, $b:literal) => {
        Layer::<$a, $b>;
    };

    // Recursively traverse the arguments
    ($a:literal, $b:literal, $($rest:literal),+) => {
        example!($a, $b);
        example!($b, $($rest),*);
    };
}

fn main() {
    example!(1, 2, 3);
}

Expanding the macro leads to:

fn main() {
    Layer::<1, 2>;
    Layer::<2, 3>;
}
  • This is an excellent answer. Thank you very much! Integer32 must be an awesome company if you are in it :) – codearm Mar 31 '21 at 14:24
  • I've added the final solution if you are interested, it uses TT munchers mixed with your smart trick to populate the array, pretty neat: https://stackoverflow.com/a/66903701/11386119 – codearm Apr 01 '21 at 11:34

For those interested, I was finally able to populate my network like this, based on @Shepmaster's answer:

struct Network<'a, const L: usize> {
    layers: [&'a dyn Forward; L],
}

macro_rules! network {
    // Recursively accumulate token tree
    (@accum ($a:literal, $b:literal, $($others:literal),+) $($e:tt)*) => {
        network!(@accum ($b, $($others),*) $($e)*, &Layer::<$a, $b>::new())
    };

    // Last iteration, convert to an array expression
    (@accum ($a:literal, $b:literal) $($e:tt)*) => {[$($e)*, &Layer::<$a, $b>::new()]};

    // Entrance
    ($a:literal, $b:literal, $($others:literal),+) => {
        Network {
            layers: network!(@accum ($b, $($others),*) &Layer::<$a, $b>::new())
        }
    };
}

For network!(2, 3, 4, 5, 1) it translates to:

Network {
     layers:
          [&Layer::<2, 3>::new(),
           &Layer::<3, 4>::new(),
           &Layer::<4, 5>::new(),
           &Layer::<5, 1>::new()]
};
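
The Forward trait is not shown above. As a rough sketch (an assumption, not the actual trait), one object-safe shape that makes this compile, together with the Network struct and network! macro from this answer, could be:

trait Forward {
    fn forward(&self, input: &[f32]) -> Vec<f32>;
}

struct Layer<const I: usize, const O: usize> {
    weights: [[f32; I]; O], // assumed layout: one row of I weights per output neuron
}

impl<const I: usize, const O: usize> Layer<I, O> {
    fn new() -> Self {
        Layer { weights: [[0.0; I]; O] }
    }
}

impl<const I: usize, const O: usize> Forward for Layer<I, O> {
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        // Plain matrix-vector product; bias and activation are omitted for brevity.
        self.weights
            .iter()
            .map(|row| row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>())
            .collect()
    }
}

fn main() {
    // Temporary lifetime extension in the `let` binding keeps the
    // `&Layer::new()` temporaries alive for as long as `net`.
    let net = network!(2, 3, 4, 1);
    assert_eq!(net.layers.len(), 3);

    // Feed an input of length 2 through every layer in turn.
    let out = net.layers.iter().fold(vec![1.0, 1.0], |x, layer| layer.forward(&x));
    assert_eq!(out.len(), 1);
}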