I'm trying to make a Sankey plot with Plotly that follows how documents are filtered into being either in scope or out of scope, i.e. 1 source and 2 targets. However, some documents are filtered out during step 1, some during step 2, and so on. This leads to the following Sankey plot:

Current output

Now what I would ideally like is for it to look something like this:

Ideal output

I've already looked through the documentation at https://plot.ly/python/reference/#sankey, but I can't find what I'm looking for. Ideally, I would like a way to prevent nodes and links from overlapping in the plot.

This is the code I'm using to generate the plot object:

import math
from random import shuffle

import pandas as pd
import plotly.graph_objects as go

def genSankeyPlotObject(df, cat_cols=[], value_cols='', visible=False):

    ### COLOR PALETTE TO USE (HEX STRINGS NEED A LEADING '#')
    colorPalette = ['#472d3c', '#5e3643', '#7a444a', '#a05b53', '#bf7958', '#eea160', '#f4cca1', '#b6d53c', '#71aa34', '#397b44',
                    '#3c5956', '#302c2e', '#5a5353', '#7d7071', '#a0938e', '#cfc6b8', '#dff6f5', '#8aebf1', '#28ccdf', '#3978a8',
                    '#394778', '#39314b', '#564064', '#8e478c', '#cd6093', '#ffaeb6', '#f4b41b', '#f47e1b', '#e6482e', '#a93b3b',
                    '#827094', '#4f546b']

    ### CREATES LABELLIST FROM DEFINED COLUMNS
    labelList = []
    for catCol in cat_cols:
        labelListTemp = list(set(df[catCol].values))
        labelList = labelList + labelListTemp
    labelList = list(dict.fromkeys(labelList))

    ### PICKS ONE COLOR PER NODE LABEL, REPEATING THE PALETTE IF THERE ARE MORE LABELS THAN COLORS
    colorNum = len(labelList)
    tempColorPalette = colorPalette * math.ceil(colorNum / len(colorPalette))
    shuffle(tempColorPalette)
    colorList = tempColorPalette[0:colorNum]

    ### TRANSFORMS DF INTO SOURCE -> TARGET PAIRS
    for i in range(len(cat_cols)-1):
        if i==0:
            sourceTargetDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
            sourceTargetDf.columns = ['source','target','count']
        else:
            tempDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
            tempDf.columns = ['source','target','count']
            sourceTargetDf = pd.concat([sourceTargetDf,tempDf])
        sourceTargetDf = sourceTargetDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()

    ### ADDING INDEX TO SOURCE -> TARGET PAIRS
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].apply(lambda x: labelList.index(x))
    sourceTargetDf['targetID'] = sourceTargetDf['target'].apply(lambda x: labelList.index(x))

    ### CREATES THE SANKEY PLOT OBJECT
    data = go.Sankey(node = dict(pad = 15,
                                 thickness = 20,
                                 line = dict(color = "black",
                                             width = 0.5),
                                 label = labelList,
                                 color = colorList),
                     link = dict(source = sourceTargetDf['sourceID'],
                                 target = sourceTargetDf['targetID'],
                                 value = sourceTargetDf['count']),
                     valuesuffix = ' ' + value_cols,
                     visible = visible)

    return data
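
For context, here is a minimal sketch of the kind of data the function above turns into links, written in plain Python so it runs without pandas or Plotly. The labels and counts are made up for illustration and are not my real data, and `build_links` is a hypothetical helper that only mirrors the label-indexing and aggregation steps; it is not part of the function above.

```python
# Hypothetical rows: one category per filter step, followed by a document count.
rows = [
    ("All documents", "Passed step 1", "In scope", 50),
    ("All documents", "Passed step 1", "Out of scope", 20),
    ("All documents", "Failed step 1", "Out of scope", 30),
]


def build_links(rows):
    """Turn category rows into the label/source/target/value lists a Sankey trace uses."""
    n_cols = len(rows[0]) - 1  # the last entry of each row is the count

    # Collect node labels column by column, keeping first-seen order.
    labels = []
    for col in range(n_cols):
        for row in rows:
            if row[col] not in labels:
                labels.append(row[col])

    # Sum the counts for every pair of adjacent columns (source -> target).
    totals = {}
    for row in rows:
        cats, count = row[:n_cols], row[n_cols]
        for src, tgt in zip(cats, cats[1:]):
            totals[(src, tgt)] = totals.get((src, tgt), 0) + count

    pairs = sorted(totals)  # sorted only to make the output order deterministic
    return {
        "label": labels,
        "source": [labels.index(s) for s, _ in pairs],
        "target": [labels.index(t) for _, t in pairs],
        "value": [totals[p] for p in pairs],
    }


links = build_links(rows)
for key, values in links.items():
    print(key, values)
```

The four lists correspond to what ends up in `node.label`, `link.source`, `link.target` and `link.value` of the `go.Sankey` object. Note that "Out of scope" is reached from two different steps, which is the situation I described above.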
DasBoot