I'm trying to use the scipy.cluster.hierarchy module to hierarchically cluster some text. I've done the following:

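from scipy.cluster.hierarchy import linkage, dendrogram, to_tree

l = linkage(model.wv.syn0, method='complete', metric='cosine')

den = dendrogram(
    l,
    leaf_rotation=0.,
    leaf_font_size=16.,
    orientation='left',
    # label each leaf with the word at that index in the model vocabulary
    leaf_label_func=lambda v: str(model.wv.index2word[v])
)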

The dendrogram function returns a dict containing a representation of the tree where:

den['ivl'] is a list of labels corresponding to the leaves:

['politics', 'protest', 'characterfirstvo', 'machine', 'writing', 'learning', 'healthcare', 'climate', 'of', 'rights', 'activism', 'resistance', 'apk', 'week', 'challenge', 'water', 'obamacare', 'colorado', 'change', 'voiceovers', '52', 'acting', 'android']

den['leaves'] is a list of the leaf node ids in left-to-right traversal order, i.e. den['leaves'][i] is the id of the leaf at position i:
[0, 18, 5, 6, 2, 7, 12, 16, 21, 20, 22, 3, 10, 14, 15, 19, 11, 1, 17, 4, 13, 8, 9]

I know that scipy's to_tree() method converts a hierarchical clustering represented by a linkage matrix into a tree object by returning a reference to the root node (a ClusterNode object) - but I'm not sure how this root node corresponds to my leaves/labels. For example, the ids returned by the get_id() method in this case are root = 44, left = 41, right = 43:

rootnode, nodelist = to_tree(l, rd=True)
rootID = rootnode.get_id()
leftID = rootnode.get_left().get_id()
rightID = rootnode.get_right().get_id()

My question essentially is, how can I traverse this tree and get the corresponding position in den['leaves'] and label in den['ivl'] for each ClusterNode?

Thank you in advance for any help!

For reference, this is the linkage matrix l:

[[20.         22.          0.72081252  2.        ]
[12.         16.          0.78620636  2.        ]
[ 3.         10.          0.79635815  2.        ]
[ 0.         18.          0.80193474  2.        ]
[15.         19.          0.82297097  2.        ]
[ 2.          7.          0.84152483  2.        ]
[ 1.         17.          0.84453892  2.        ]
[ 4.         13.          0.86098654  2.        ]
[ 8.          9.          0.88163748  2.        ]
[14.         27.          0.91252009  3.        ]
[11.         29.          0.92034739  3.        ]
[21.         23.          0.92406542  3.        ]
[ 5.          6.          0.93213108  2.        ]
[25.         32.          0.98555722  5.        ]
[26.         35.          0.99214198  4.        ]
[30.         31.          1.05624908  4.        ]
[24.         34.          1.0606247   5.        ]
[28.         39.          1.06322889  7.        ]
[37.         40.          1.1455562  11.        ]
[33.         38.          1.15171714  7.        ]
[36.         42.          1.17330334 12.        ]
[41.         43.          1.25056073 23.        ]]

1 Answer

You don't need dendrogram to traverse the cluster tree. Given a linkage matrix and an array of cluster ids (the output of scipy.cluster.hierarchy.fcluster), you can use the get_node function below to get the cluster-tree node corresponding to a given cluster id:

import numpy as np
from scipy.cluster.hierarchy import leaders, ClusterNode, to_tree
from typing import Optional, List


def get_node(
    linkage_matrix: np.ndarray,
    clusters_array: np.ndarray,
    cluster_num: int
) -> ClusterNode:
    """
    Returns ClusterNode (the node of the cluster tree) corresponding to the given cluster number.
    :param linkage_matrix: linkage matrix
    :param clusters_array: array of cluster numbers for each point
    :param cluster_num: id of cluster for which we want to get ClusterNode
    :return: ClusterNode corresponding to the given cluster number
    """
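    L, M = leaders(linkage_matrix, clusters_array)
    # leaders() returns the root node id (L) of each flat cluster (M);
    # pick the single id that matches cluster_num as a plain int.
    idx = int(L[M == cluster_num][0])
    tree = to_tree(linkage_matrix)
    result = search_for_node(tree, idx)
    assert result is not None
    return result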


def search_for_node(
    cur_node: Optional[ClusterNode],
    target: int
) -> Optional[ClusterNode]:
    """
    Searches for the node with the given id of the cluster in the given subtree.
    :param cur_node: root of the cluster subtree to search for target node
    :param target: id of the target node (cluster)
    :return: ClusterNode with the given id if it exists in the subtree, None otherwise
    """
    if cur_node is None:
        # Past a leaf: nothing found on this branch.
        return None
    if cur_node.get_id() == target:
        return cur_node
    left = search_for_node(cur_node.get_left(), target)
    if left is not None:
        return left
    return search_for_node(cur_node.get_right(), target)

To get all samples that belong to the current cluster node you should just get all descendant leaf nodes:

def get_leaves_ids(node: ClusterNode) -> List[int]:
    """
    Returns ids of all samples (leaf nodes) that belong to the given ClusterNode (belong to the node's subtree).
    :param node: ClusterNode for which we want to get ids of samples
    :return: list of ids of samples that belong to the given ClusterNode
    """
    res = []

    def dfs(cur: Optional[ClusterNode]):
        if cur is None:
            return
        if cur.is_leaf():
            res.append(cur.get_id())
            return
        dfs(cur.get_left())
        dfs(cur.get_right())
    dfs(node)
    return res
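
As an alternative sketch (not the helpers above), to_tree(Z, rd=True) also returns a list of nodes indexed by node id, so a leader id from leaders() can be looked up directly, and ClusterNode.pre_order() returns the leaf ids under a node. The toy data X, the choice of t=3, and the name cluster_members are made up for illustration:

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, leaders, linkage, to_tree

# Toy 1-D data (hypothetical): two tight pairs and one outlier.
X = np.array([[0.0], [0.1], [5.0], [5.2], [10.0]])
Z = linkage(X, method="single")

# Flat clustering into at most 3 clusters; T[i] is the cluster number of sample i.
T = fcluster(Z, t=3, criterion="maxclust")

# L[k] is the id of the tree node that roots flat cluster M[k].
L, M = leaders(Z, T)

# rd=True additionally returns a list where nodelist[i] is the node with id i.
root, nodelist = to_tree(Z, rd=True)

# Map each flat cluster number to the sorted sample ids under its leader node.
cluster_members = {
    int(m): sorted(nodelist[int(l)].pre_order()) for l, m in zip(L, M)
}
print(cluster_members)
```

This avoids the recursive search, at the cost of building the full node list once.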

Traversing the tree and finding siblings, ancestors, or descendants works much the same as for any regular tree. The IDs of the leaves are the IDs (row indices) of the samples in your dataset; the IDs of non-terminal nodes can be mapped to flat-cluster IDs using the leaders function (see the get_node implementation).
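
To tie node ids back to the dendrogram output, here is a minimal sketch on made-up data. It relies on three facts: leaf ids are the row indices of the input, an internal node with id n + i corresponds to row i of the linkage matrix, and, with dendrogram's default count_sort and distance_sort settings, a left-to-right pre-order walk from the root reproduces den['leaves']. The names labels, leaf_position, describe, and the toy X are illustrative assumptions:

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage, to_tree

X = np.array([[0.0], [0.1], [5.0], [5.2], [10.0]])
labels = ["a", "b", "c", "d", "e"]  # hypothetical; labels[i] names sample i
Z = linkage(X, method="single")
n = X.shape[0]

root, nodelist = to_tree(Z, rd=True)

# no_plot=True computes the layout without drawing anything.
den = dendrogram(Z, no_plot=True, labels=labels)

# Leaf ids in left-to-right order, obtained from the tree itself.
leaf_order = root.pre_order()

# Position of each leaf id along the dendrogram's leaf axis.
leaf_position = {leaf_id: pos for pos, leaf_id in enumerate(den["leaves"])}

def describe(node):
    """Return (positions, labels) of the leaves under a ClusterNode."""
    ids = node.pre_order()
    return [leaf_position[i] for i in ids], [labels[i] for i in ids]

# Children and merge distance of an internal node come from row (id - n) of Z.
row = Z[root.get_id() - n]
print(root.get_id(), int(row[0]), int(row[1]), row[2])
print(describe(root.get_left()))
```

In the question's setting, labels[i] would be str(model.wv.index2word[i]), and describe(node) gives the den['leaves'] positions and den['ivl'] labels for any ClusterNode.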

Daniel Savenkov