I am unable to plot (or print a summary of) my graph neural network model. I have seen a few related questions (1, 2, 3) on this topic, but their answers do not apply to graph neural networks.
What makes this case different is that the model input consists of tensors with different dimensions: e.g. the node-feature (properties) matrix has shape [n_nodes, n_node_features], the adjacency matrix has shape [n_nodes, n_nodes], etc. Here is an example of my model:
class GIN0(Model):
    """GIN-0 graph network: a stack of GINConv layers, pooling, and a dense head.

    The forward pass applies ``n_layers`` ``GINConv`` layers (each with
    ``epsilon=0`` and an MLP of two hidden layers of width ``channels``),
    pools node features with ``GlobalAvgPool``, then runs
    ``Dense(relu) -> Dropout(0.5) -> Dense(relu)``.

    Args:
        channels: width shared by every conv layer, their MLP hidden layers,
            and both dense layers.
        n_layers: total number of ``GINConv`` layers: ``conv1`` plus
            ``n_layers - 1`` more stored in ``self.convs``.
    """

    def __init__(self, channels, n_layers):
        super().__init__()

        def make_conv():
            # Build a brand-new mlp_hidden list for every layer, so no two
            # layers share the same list object.
            return GINConv(channels, epsilon=0, mlp_hidden=[channels, channels])

        self.conv1 = make_conv()
        self.convs = [make_conv() for _ in range(n_layers - 1)]
        self.pool = GlobalAvgPool()
        self.dense1 = Dense(channels, activation="relu")
        self.dropout = Dropout(0.5)
        self.dense2 = Dense(channels, activation="relu")

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: a 3-item sequence unpacked as ``(x, a, i)``. ``x`` and
                ``a`` are fed to every conv layer, and ``i`` is fed to the
                pooling layer. Presumably these are node features, adjacency,
                and per-node graph index; confirm against the data loader.

        Returns:
            The output of ``dense2``.
        """
        node_feats, adjacency, graph_index = inputs
        for conv_layer in [self.conv1, *self.convs]:
            node_feats = conv_layer([node_feats, adjacency])
        pooled = self.pool([node_feats, graph_index])
        hidden = self.dropout(self.dense1(pooled))
        return self.dense2(hidden)
One of the answers to question 2 suggested adding a `build_graph`
method, as follows:
class my_model(Model):
    """VGG16 backbone followed by a pooling / dense head with a sigmoid output.

    Args:
        dim: input shape handed to ``VGG16`` as ``input_shape``. It is also
            reused by :meth:`build_graph`, e.g. ``(124, 124, 3)``.
    """

    def __init__(self, dim):
        super().__init__()
        # Keep the input shape on the instance so build_graph() uses the shape
        # this model was built with, not a module-level ``dim`` global.
        self.dim = dim
        self.Base = VGG16(input_shape=dim, include_top=False, weights='imagenet')
        self.GAP = L.GlobalAveragePooling2D()
        self.BAT = L.BatchNormalization()
        self.DROP = L.Dropout(rate=0.1)
        self.DENS = L.Dense(256, activation='relu', name='dense_A')
        self.OUT = L.Dense(1, activation='sigmoid')

    def call(self, inputs):
        """Apply backbone -> global pooling -> batch norm -> dropout -> dense -> output."""
        x = self.Base(inputs)
        g = self.GAP(x)
        b = self.BAT(g)
        d = self.DROP(b)
        d = self.DENS(d)
        return self.OUT(d)

    def build_graph(self):
        """Wrap this subclassed model in a functional ``Model``.

        This gives a convenient way to call ``.summary()`` with output similar
        to the Sequential / functional API.

        Returns:
            A functional ``Model`` whose single ``Input`` has shape
            ``self.dim`` and whose outputs come from ``self.call``.
        """
        x = Input(shape=self.dim)
        return Model(inputs=[x], outputs=self.call(x))
# Example usage: build the model for a (124, 124, 3) input with an unknown
# batch size, then print the functional-API style summary.
dim = (124, 124, 3)
model = my_model(dim)
model.build(input_shape=(None,) + dim)
model.build_graph().summary()
However, I am not sure how to define `dim`, or the input layers via `tf.keras.layers.Input`,
for a hybrid data structure like the one described above (several input tensors with different shapes).
Any suggestions?