I wanted to implement a computational graph data structure in Rust (one that can hopefully support autodiff), and I decided to use the ndarray crate, since it is the most similar to NumPy. Making a graph data structure isn't all that difficult in itself: you either use a Vec<Node>
where each Node stores the indices of other vector elements, or you give each node reference-counted pointers to the other nodes.
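To make the first option concrete, here is a minimal sketch of the index-based version (ignoring ndarray for a moment); the `Node`/`Graph` names and the `f32` payload are just illustrative, not from any crate:

```rust
// Minimal index-based graph: nodes live in a Vec and refer to each
// other by index, which sidesteps borrow-checker issues entirely.
struct Node {
    value: f32,
    // Indices into `Graph::nodes` of this node's inputs.
    deps: Vec<usize>,
}

struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    // Appends a node and returns its index, which acts as its handle.
    fn add_node(&mut self, value: f32, deps: Vec<usize>) -> usize {
        self.nodes.push(Node { value, deps });
        self.nodes.len() - 1
    }
}

fn main() {
    let mut g = Graph { nodes: Vec::new() };
    let a = g.add_node(1.0, vec![]);
    let b = g.add_node(2.0, vec![]);
    let c = g.add_node(3.0, vec![a, b]);
    assert_eq!(g.nodes[c].deps, vec![a, b]);
}
```

This works fine as long as every node stores the same concrete type.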
However, this approach falls apart when using ndarray. Suppose we work with the fairly common type Array<T, D>.
This is an n-dimensional array that owns its data and has two generic parameters: T for the element type (e.g. f32, i32, Complex<f32>,
etc.) and D for the number of dimensions (which can be made dynamic using IxDyn).
How can I create a useful graph data structure where each node has different generics? For instance, how can I make a graph where one node has some T and D, and the node it points to has a T and D of different types?
You might ask: why would I need that? Why not fix T to some value that makes sense, like f32,
and use dynamically sized arrays? The problem is that I would like to support transformations from real numbers to complex numbers, such as the FFT.
So given that you want to implement some variant of reverse mode automatic differentiation on tensors, how would you implement the computational graph?
This is one of my attempts:
use std::{cell::RefCell, ops::Add, rc::Rc};

use ndarray::prelude::*;
use num_traits::One;

pub enum OpType {
    Add,
    MatVecMul,
    Leaf,
}

// Handle to a node on the tape, generic over the array's element and
// dimension types.
pub struct Expr<'t, T, D> {
    index: usize,
    tape: &'t Tape,
    output: Rc<Array<T, D>>,
}

// Note how I have to give Node generic types due to the Array
pub struct Node<T, D> {
    value: Rc<Array<T, D>>,
    op: OpType,
    deps: [usize; 2],
}

// This is causing all the problems. We cannot return the array from this trait.
pub trait TapeNode {}

impl<T, D> TapeNode for Node<T, D> {}

pub struct Tape {
    nodes: RefCell<Vec<Box<dyn TapeNode>>>,
}

impl Tape {
    pub fn new() -> Self {
        Tape {
            nodes: RefCell::new(Vec::new()),
        }
    }

    pub fn push_leaf<'t, T, D>(&'t self, value: Array<T, D>) -> Expr<'t, T, D>
    where
        T: 'static,
        D: 'static,
    {
        let mut nodes = self.nodes.borrow_mut();
        let value_ref = Rc::new(value);
        let new_node = Node {
            value: value_ref.clone(),
            op: OpType::Leaf,
            deps: [0, 0],
        };
        let len = nodes.len();
        nodes.push(Box::new(new_node));
        Expr { index: len, tape: self, output: value_ref }
    }

    pub fn push_1<'t, T, D>(&'t self, value: Array<T, D>, id_0: usize, op: OpType) -> Expr<'t, T, D>
    where
        T: 'static,
        D: 'static,
    {
        let mut nodes = self.nodes.borrow_mut();
        let value_ref = Rc::new(value);
        let new_node = Node {
            value: value_ref.clone(),
            op,
            deps: [id_0, 0],
        };
        let len = nodes.len();
        nodes.push(Box::new(new_node));
        Expr { index: len, tape: self, output: value_ref }
    }

    pub fn push_2<'t, T, D>(&'t self, value: Array<T, D>, id_0: usize, id_1: usize, op: OpType) -> Expr<'t, T, D>
    where
        T: 'static,
        D: 'static,
    {
        let mut nodes = self.nodes.borrow_mut();
        let value_ref = Rc::new(value);
        let new_node = Node {
            value: value_ref.clone(),
            op,
            deps: [id_0, id_1],
        };
        let len = nodes.len();
        nodes.push(Box::new(new_node));
        Expr { index: len, tape: self, output: value_ref }
    }
}

impl<'t, T, D> Add for Expr<'t, T, D>
where
    T: Add<T, Output = T> + Clone + One + 'static,
    D: Dimension + 'static,
{
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        // `&Array + &Array` allocates a new array; record it as a new node.
        let added = self.output.as_ref() + rhs.output.as_ref();
        self.tape.push_2(added, self.index, rhs.index, OpType::Add)
    }
}
/*
The problem shows up during backpropagation. For that, we must iterate
through the tape backwards and apply the chain rule repeatedly. This
requires being able to get the array stored at each node, and that is
exactly what fails.
*/
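For contrast, if the tape were homogeneous, the backward iteration described in that comment would be unproblematic. Here is a minimal sketch using plain f64 scalars instead of ndarray arrays (so the generics issue never arises); each node stores the local partial derivatives with respect to its two dependencies, which is an assumption of this sketch rather than something from the code above:

```rust
// Reverse-mode autodiff over a homogeneous tape of f64 scalars.
#[derive(Clone)]
struct ScalarNode {
    deps: [usize; 2],
    // d(this node's output) / d(output of deps[i]); zero for leaves.
    local_grads: [f64; 2],
}

fn backward(tape: &[ScalarNode], result: usize) -> Vec<f64> {
    let mut grads = vec![0.0; tape.len()];
    grads[result] = 1.0; // d(result)/d(result) = 1
    // Walk the tape backwards, applying the chain rule at each node.
    for i in (0..tape.len()).rev() {
        for k in 0..2 {
            grads[tape[i].deps[k]] += tape[i].local_grads[k] * grads[i];
        }
    }
    grads
}

fn main() {
    // Tape for z = x * y + x, with x = 3 (index 0) and y = 4 (index 1):
    //   node 2: m = x * y  (dm/dx = y = 4, dm/dy = x = 3)
    //   node 3: z = m + x  (dz/dm = 1,     dz/dx = 1)
    let tape = vec![
        ScalarNode { deps: [0, 0], local_grads: [0.0, 0.0] }, // leaf x
        ScalarNode { deps: [1, 1], local_grads: [0.0, 0.0] }, // leaf y
        ScalarNode { deps: [0, 1], local_grads: [4.0, 3.0] }, // m = x * y
        ScalarNode { deps: [2, 0], local_grads: [1.0, 1.0] }, // z = m + x
    ];
    let grads = backward(&tape, 3);
    assert_eq!(grads[0], 5.0); // dz/dx = y + 1
    assert_eq!(grads[1], 3.0); // dz/dy = x
}
```

The loop itself is trivial; the difficulty in my code is purely that I cannot get the typed arrays back out of the `Box<dyn TapeNode>` entries.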
This code is inspired by this blog post:
I actually did get such a graph working, where a central list holds pointers to a bunch of dynamically dispatched nodes (this is what the Tape struct does). What's the problem, you say? Well, given an index, suppose I want to retrieve the array stored at that position in the Tape. How would I do that? Since the tape stores pointers to trait objects, should we just add a getter method to the trait definition? No! If we try to return an ndarray Array with generic parameters from a method in the trait definition, the trait is no longer object safe. So, this whole system falls apart. To be honest, it's a bit frustrating.
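One standard workaround for exactly this object-safety problem (I don't know whether it is what neuronika does) is to keep the trait free of generic methods and instead expose an escape hatch to `std::any::Any`, letting the caller downcast to the concrete `Node<T, D>` it expects. A sketch of the pattern, using a plain `Vec<T>` as a dependency-free stand-in for ndarray's `Array<T, D>`:

```rust
use std::any::Any;

// Stand-in for ndarray's Array<T, D>; the pattern is identical.
struct Node<T> {
    value: Vec<T>,
}

// Object-safe trait: no generic methods, just an upcast to Any.
trait TapeNode {
    fn as_any(&self) -> &dyn Any;
}

impl<T: 'static> TapeNode for Node<T> {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn main() {
    // Heterogeneous tape: an f32 node and a (re, im) "complex" node.
    let tape: Vec<Box<dyn TapeNode>> = vec![
        Box::new(Node { value: vec![1.0f32, 2.0] }),
        Box::new(Node { value: vec![(1.0f32, 0.0f32)] }),
    ];

    // The caller names the concrete type it expects and downcasts.
    let real = tape[0]
        .as_any()
        .downcast_ref::<Node<f32>>()
        .expect("node 0 should hold f32 data");
    assert_eq!(real.value, vec![1.0, 2.0]);

    // Asking for the wrong type fails gracefully with None.
    assert!(tape[1].as_any().downcast_ref::<Node<f32>>().is_none());
}
```

The cost is that the caller must already know each node's concrete T and D, so type errors surface at runtime rather than compile time.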
Is there any way to make this system work, or does the entire approach need to change? If so, what approach would you use? I know there is some approach that works, because the neuronika crate is able to generate a computational graph with nodes of differing generic types.
EDIT: After reviewing how Neuronika's implementation is able to function with ndarray generics (answer posted below), how would you approach this problem? Would you try a different approach, or would you just stick to some variant of this one?