I am implementing a derive macro to reduce the amount of boilerplate I have to write for similar types.

I want the macro to operate on structs which have the following format:

#[derive(MyTrait)]
struct SomeStruct {
    records: HashMap<Id, Record>
}

Calling the macro should generate an implementation like so:

impl MyTrait for SomeStruct {
    fn foo(&self, id: Id) -> Record { ... }
}
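For reference, a hand-written version of the impl the derive is meant to produce might look like the sketch below. The question does not show how MyTrait is declared, so several details are assumptions. The trait here uses associated types so that different structs can pick their own key and value types. `Id = u32` and the `Record` struct are hypothetical stand-ins. The value is cloned out of the map, so `Record` must implement Clone.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for the question's Id and Record types.
type Id = u32;

#[derive(Clone, Debug, PartialEq)]
struct Record {
    name: String,
}

// Assumed trait shape: associated types let each struct choose its own key and value types.
trait MyTrait {
    type Id;
    type Record;
    fn foo(&self, id: Self::Id) -> Self::Record;
}

struct SomeStruct {
    records: HashMap<Id, Record>,
}

// Roughly what `#[derive(MyTrait)]` should expand to for SomeStruct.
impl MyTrait for SomeStruct {
    type Id = Id;
    type Record = Record;

    fn foo(&self, id: Id) -> Record {
        // Clone the value out of the map; panics if the id is missing.
        self.records.get(&id).cloned().expect("unknown id")
    }
}

fn main() {
    let mut records = HashMap::new();
    records.insert(7, Record { name: "seven".to_string() });
    let s = SomeStruct { records };
    println!("{}", s.foo(7).name);
}
```

If MyTrait is declared with associated types like this, the generated impl would also need `type Id = #id_type;` and `type Record = #record_type;` lines alongside `foo`.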

I understand how to generate the code with quote:

#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: TokenStream) -> TokenStream {
    ...
    let generated = quote!{

        impl MyTrait for #struct_name {
            fn foo(&self, id: #id_type) -> #record_type { ... }
        }

    };
    ...
}

But what is the best way to get #struct_name, #id_type and #record_type from the input token stream?

  • Parsing it with `syn`. – Chayim Friedman Sep 11 '22 at 10:49
  • Parsing it with `syn` is the definitive, standard, by-the-book answer to this. But personally, I find `syn` fairly complex and only really started getting the hang of proc macros when I tried [`venial`](https://docs.rs/venial/), which is much smaller, but also far less complete. – Caesar Sep 11 '22 at 11:44


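One way is to parse the TokenStream with the venial crate, a smaller alternative to syn. This derive accepts only a struct with exactly one named field whose type is written as T<Id, Record>: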

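// Dependencies (in a crate with `proc-macro = true`): proc-macro2, quote and venial.
// proc_macro2 and venial are used by full path, so only the quote! macro is imported.
use quote::quote;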

#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
    // Ensure it's deriving for a struct.
    let s = match venial::parse_declaration(proc_macro2::TokenStream::from(item)) {
        Ok(venial::Declaration::Struct(s)) => s,
        Ok(_) => panic!("Can only derive this trait on a struct"),
        Err(_) => panic!("Error parsing into valid Rust"),
    };

    let struct_name = s.name;

    // Require named fields, i.e. `struct S { field: T }`.
    let fields = s.fields;
    let named_fields = match fields {
        venial::StructFields::Named(named_fields) => named_fields,
        _ => panic!("Expected a named field"),
    };

    let inners: Vec<(venial::NamedField, proc_macro2::Punct)> = named_fields.fields.inner;
    if inners.len() != 1 {
        panic!("Expected exactly one named field");
    }

    // Get the name and type of the first field.
    let first_field_name = &inners[0].0.name;
    let first_field_type = &inners[0].0.ty;

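    // Extract Id and Record from a type written exactly as HashMap<Id, Record>:
    // its six tokens are `HashMap`, `<`, `Id`, `,`, `Record`, `>`. Paths such as
    // std::collections::HashMap<..> or nested generics like Vec<Record> are rejected.
    if first_field_type.tokens.len() != 6 {
        panic!("Expected a field type of the form T<R, S>");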
    }

    let id = first_field_type.tokens[2].clone();
    let record = first_field_type.tokens[4].clone();

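    // Implement MyTrait. The record is cloned out of the map, so Record must
    // implement Clone (dereferencing with `*` would only compile for Copy types).
    let generated = quote! {
        impl MyTrait for #struct_name {
            fn foo(&self, id: #id) -> #record { self.#first_field_name.get(&id).unwrap().clone() }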
        }
    };

    proc_macro::TokenStream::from(generated)
}
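The fixed indices `tokens[2]` and `tokens[4]` only work when the field type is exactly six tokens long. A somewhat more tolerant approach is to find the first `<`, track `<`/`>` nesting depth, and split at top-level commas. The sketch below illustrates that idea with plain strings standing in for proc_macro2 TokenTrees, an assumption made to keep it dependency-free. As with proc_macro2's Punct tokens, `::` appears as two ":" tokens and `>>` as two ">" tokens. The function name `split_generic_args` is hypothetical. In a real macro you would compare `Punct::as_char()` against '<', '>' and ',' instead of comparing strings.

```rust
/// Splits the tokens between the first `<` and its matching `>` into the
/// top-level generic arguments. Tokens are plain strings here (a stand-in for TokenTree).
fn split_generic_args<'a>(tokens: &[&'a str]) -> Option<Vec<Vec<&'a str>>> {
    // Start right after the first `<`, e.g. the one following `HashMap`.
    let open = tokens.iter().position(|t| *t == "<")?;
    let mut depth = 0usize;
    let mut args: Vec<Vec<&'a str>> = vec![Vec::new()];
    for &tok in &tokens[open + 1..] {
        match tok {
            // A `>` at depth 0 closes the outer generic list.
            ">" if depth == 0 => return Some(args),
            // Commas at depth 0 separate top-level arguments.
            "," if depth == 0 => args.push(Vec::new()),
            _ => {
                // Track nesting so `Vec<Record>` stays a single argument.
                if tok == "<" {
                    depth += 1;
                } else if tok == ">" {
                    depth -= 1;
                }
                args.last_mut().unwrap().push(tok);
            }
        }
    }
    None // unbalanced: no closing `>`
}

fn main() {
    // Same layout the answer indexes with tokens[2] and tokens[4].
    let simple = ["HashMap", "<", "Id", ",", "Record", ">"];
    println!("{:?}", split_generic_args(&simple)); // prints Some([["Id"], ["Record"]])

    // A full path and a nested generic, both rejected by the six-token check.
    let nested = [
        "std", ":", ":", "collections", ":", ":", "HashMap",
        "<", "Id", ",", "Vec", "<", "Record", ">", ">",
    ];
    println!("{:?}", split_generic_args(&nested));
}
```

Each returned argument is still a token sequence, so a multi-token type such as `Vec<Record>` can be spliced back into `quote!` as a whole.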