I want to use the graph_cnn (Defferrard et al. 2016) for inputs with a varying number of nodes. The author provided example code (see graph_cnn). Below is what I think is the critical part of the code:

def chebyshev5(self, x, L, Fout, K):
    N, M, Fin = x.get_shape()
    N, M, Fin = int(N), int(M), int(Fin)
    # Rescale Laplacian and store as a TF sparse tensor. Copy to not modify the shared L.
    L = scipy.sparse.csr_matrix(L)
    L = graph.rescale_L(L, lmax=2)
    L = L.tocoo()
    indices = np.column_stack((L.row, L.col))
    L = tf.SparseTensor(indices, L.data, L.shape)
    L = tf.sparse_reorder(L)
    # Transform to Chebyshev basis
    x0 = tf.transpose(x, perm=[1, 2, 0])  # M x Fin x N
    x0 = tf.reshape(x0, [M, Fin*N])  # M x Fin*N
    x = tf.expand_dims(x0, 0)  # 1 x M x Fin*N
    def concat(x, x_):
        x_ = tf.expand_dims(x_, 0)  # 1 x M x Fin*N
        return tf.concat([x, x_], axis=0)  # K x M x Fin*N
    if K > 1:
        x1 = tf.sparse_tensor_dense_matmul(L, x0)
        x = concat(x, x1)
    for k in range(2, K):
        x2 = 2 * tf.sparse_tensor_dense_matmul(L, x1) - x0  # M x Fin*N
        x = concat(x, x2)
        x0, x1 = x1, x2
    x = tf.reshape(x, [K, M, Fin, N])  # K x M x Fin x N
    x = tf.transpose(x, perm=[3,1,2,0])  # N x M x Fin x K
    x = tf.reshape(x, [N*M, Fin*K])  # N*M x Fin*K
    # Filter: Fin*Fout filters of order K, i.e. one filterbank per feature pair.
    W = self._weight_variable([Fin*K, Fout], regularization=False)
    x = tf.matmul(x, W)  # N*M x Fout
    return tf.reshape(x, [N, M, Fout])  # N x M x Fout

Essentially, I think what this does can be simplified to something like

y = concat{T_k(L) * x for k = 0 to K-1} * W

where T_k is the k-th Chebyshev polynomial, built by the recurrence T_0(L)x = x, T_1(L)x = Lx, T_k(L)x = 2L T_{k-1}(L)x - T_{k-2}(L)x (see the NumPy sketch after the definitions below).

x is the input of size N x M x Fin (M varies from sample to sample within a batch);

L is an array of N operators on x, each of size M x M and matching the corresponding sample (M varies from sample to sample within a batch);

W holds the neural-network parameters to be optimized; its size is Fin x K x Fout (stored as [Fin*K, Fout] in the code);

N: the number of samples in a batch (fixed for any batch);

M: the number of nodes in the graph (varies within a batch);

Fin: the number of input features (fixed for any batch);

Fout: the number of output features (fixed for any batch);

K: a constant, the number of steps (hops) in the graph.
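To make this concrete, below is a minimal NumPy sketch of the per-sample computation I believe the code performs (the function name and shapes are mine; note that the code builds Chebyshev terms T_k(L)x via the recurrence rather than literal powers of L):

import numpy as np

def chebyshev_filter_single(x, L, W, K):
    # One graph: x is [M, Fin], L is the [M, M] rescaled Laplacian, W is [Fin*K, Fout].
    Tx = [x]                                    # T_0(L) x = x
    if K > 1:
        Tx.append(L @ x)                        # T_1(L) x = L x
    for k in range(2, K):
        Tx.append(2 * (L @ Tx[-1]) - Tx[-2])    # T_k(L) x = 2 L T_{k-1}(L) x - T_{k-2}(L) x
    xk = np.stack(Tx, axis=-1)                  # [M, Fin, K]
    xk = xk.reshape(x.shape[0], -1)             # [M, Fin*K], same ordering as the TF reshape
    return xk @ W                               # [M, Fout]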

For a single example, the above code works. But since both x and L have a variable size for each sample in a batch, I don't know how to make it work for a batch of samples.

1 Answer


tf.matmul currently (v1.4) only supports batch matrix multiplication over the innermost two dimensions, and only for dense tensors. If either of the input tensors is sparse, it will raise a dimension-mismatch error. tf.sparse_tensor_dense_matmul cannot be applied to batch inputs either.
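As a quick illustration of what does work (a minimal sketch against the TF 1.x API; the shapes are made up):

import numpy as np
import tensorflow as tf

N, M, Fin = 4, 10, 3
L_batch = tf.constant(np.random.rand(N, M, M), dtype=tf.float32)    # dense batch of Laplacians
x_batch = tf.constant(np.random.rand(N, M, Fin), dtype=tf.float32)  # dense batch of inputs

# Dense tf.matmul batches over the leading dimension:
# [N, M, M] x [N, M, Fin] -> [N, M, Fin]
y = tf.matmul(L_batch, x_batch)

# tf.sparse_tensor_dense_matmul, by contrast, only accepts a rank-2 SparseTensor,
# so a sparse [N, M, M] Laplacian cannot be multiplied batch-wise this way.
with tf.Session() as sess:
    print(sess.run(y).shape)  # (4, 10, 3)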

Therefore, my current solution is to move all of the L preparation steps out of the function, pass L in as a dense tensor (shape: [N, M, M]), and use tf.matmul to perform the batch matrix multiplication.
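The L preparation that now happens before the call looks roughly like the sketch below (prepare_L_batch is a name I made up; graph.rescale_L is the helper from the author's repository, and the sketch assumes every graph in the batch shares the same M, or has already been padded to a common M):

import numpy as np
import scipy.sparse
from lib import graph  # graph.py from the cnn_graph repository, assumed to be on the path

def prepare_L_batch(L_list, lmax=2):
    # Rescale each per-sample Laplacian and stack them into one dense [N, M, M] array.
    L_dense = []
    for L in L_list:
        L = scipy.sparse.csr_matrix(L)
        L = graph.rescale_L(L, lmax=lmax)      # same rescaling as in the original chebyshev5
        L_dense.append(L.toarray().astype(np.float32))
    return np.stack(L_dense, axis=0)           # fed to chebyshev5_batch as a dense tf.Tensor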

Here is my revised code:

def chebyshev5_batch(x, L, Fout, K, lyr_num):
    '''
    chebyshev5_batch
    Purpose:
        perform the graph filtering on the given layer
    Args:
        x: the batch of inputs for the given layer,
           dense tensor, size: [N, M, Fin]
        L: the batch of rescaled Laplacians for the given layer (tf.Tensor),
           dense format, size: [N, M, M]
        Fout: the number of output features on the given layer
        K: the filter size, i.e. the number of hops on the given layer
        lyr_num: the index of the original Laplacian layer (starting from 0)
    Output:
        y: the filtered output from the given layer
    '''
    N, M, Fin = x.get_shape()
    N, M, Fin = int(N), int(M), int(Fin)
    # The Laplacian preparation from the original chebyshev5 (rescaling, sparse
    # conversion, reordering) is now done before this function is called, and the
    # transpose/reshape to M x Fin*N is no longer needed: L arrives as a dense
    # [N, M, M] tensor and x stays in its [N, M, Fin] layout.

    def expand_concat(orig, new):
        new = tf.expand_dims(new, 0)  # 1 x N x M x Fin
        return tf.concat([orig, new], axis=0)  # (shape(x)[0] + 1) x N x M x Fin

    # L:  # N x M x M
    # x0: # N x M x Fin
    # L*x0: # N x M x Fin

    x0 = x  # N x M x Fin
    stk_x = tf.expand_dims(x0, axis=0)  # 1 x N x M x Fin (eventually K x N x M x Fin, if K>1)

    if K > 1:
        x1 = tf.matmul(L, x0) # N x M x Fin
        stk_x = expand_concat(stk_x, x1) 
    for kk in range(2, K):
        x2 = 2 * tf.matmul(L, x1) - x0  # N x M x Fin (Chebyshev recurrence, as in the original chebyshev5)
        stk_x = expand_concat(stk_x, x2)
        x0 = x1
        x1 = x2

    # now stk_x has the shape of K x N x M x Fin
    # transpose to the shape of  N x M x Fin x K
    ##  source positions         1   2   3     0   
    stk_x_transp = tf.transpose(stk_x, perm=[1,2,3,0])
    stk_x_forMul = tf.reshape(stk_x_transp, [N*M, Fin*K])


    #W = self._weight_variable([Fin*K, Fout], regularization=False)  
    W_initial = tf.truncated_normal_initializer(0, 0.1)
    W = tf.get_variable('weights_L_'+str(lyr_num), [Fin*K, Fout], tf.float32, initializer=W_initial)
    tf.summary.histogram(W.op.name, W)

    y = tf.matmul(stk_x_forMul, W)
    y = tf.reshape(y, [N, M, Fout])
    return y
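For completeness, a hypothetical call site (the placeholder names and sizes are mine; N, M and Fin must be known statically here because of the reshapes inside the function):

import tensorflow as tf

N, M, Fin, Fout, K = 32, 100, 6, 16, 5

x_ph = tf.placeholder(tf.float32, [N, M, Fin], name='x')  # batch of node features
L_ph = tf.placeholder(tf.float32, [N, M, M], name='L')    # batch of dense, rescaled Laplacians

y = chebyshev5_batch(x_ph, L_ph, Fout, K, lyr_num=0)       # [N, M, Fout]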