
I was looking for a way to read a csv file with an unknown number of columns into a nested dictionary. For example, for input of the form

file.csv:
1,  2,  3,  4
1,  6,  7,  8
9, 10, 11, 12

I want a dictionary of the form:

{1:{2:{3:4}, 6:{7:8}}, 9:{10:{11:12}}}

This is to allow O(1) lookup of a value from the csv file. It is acceptable for building the dictionary to take a relatively long time, since my application creates it only once but searches it millions of times.
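To make the intended lookup concrete, here is a minimal sketch of how such a nested dictionary could be queried. The helper name `lookup` is hypothetical. Each step is an average-case O(1) hash lookup, so the cost of a search grows with the number of columns rather than the number of rows.

```python
from functools import reduce

# The target structure from the example above.
d = {1: {2: {3: 4}, 6: {7: 8}}, 9: {10: {11: 12}}}


def lookup(nested, *keys):
    """Follow the keys one level at a time; raises KeyError if a key is missing."""
    return reduce(lambda level, key: level[key], keys, nested)


print(lookup(d, 1, 6, 7))  # 8
print(lookup(d, 9, 10))    # {11: 12}
```

Passing fewer keys than there are levels returns the remaining sub-dictionary, which can be handy for partial matches.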

I also wanted an option to name the relevant columns, so that I can ignore the unnecessary ones.

shayelk

2 Answers


Here's a simple, albeit brittle approach:

>>> import csv, io
>>> s = "1,2,3,4\n1,6,7,8\n9,10,11,12\n"  # sample data from the question
>>> d = {}
>>> with io.StringIO(s) as f: # fake a file
...     reader = csv.reader(f)
...     for row in reader:
...         nested = d
...         for val in map(int, row[:-2]):
...             nested = nested.setdefault(val, {})
...         k, v = map(int, row[-2:]) # this will fail if you don't have enough columns
...         nested[k] = v
...
>>> d
{1: {2: {3: 4}, 6: {7: 8}}, 9: {10: {11: 12}}}

However, this assumes the number of columns is at least 2.
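One way to make that assumption explicit is to check the row length and fail with a clear message. This is only a sketch: the function name `rows_to_nested` is hypothetical, and it assumes every cell is an integer, as in the question's sample.

```python
import csv
import io


def rows_to_nested(lines):
    """Build a nested dict from CSV rows of ints; every row needs >= 2 columns."""
    d = {}
    for lineno, row in enumerate(csv.reader(lines), start=1):
        if len(row) < 2:
            raise ValueError(
                f"row {lineno} needs at least 2 columns, got {len(row)}"
            )
        # All but the last two cells form the path; the last two are key and value.
        *path, key, value = map(int, row)
        nested = d
        for part in path:
            nested = nested.setdefault(part, {})
        nested[key] = value
    return d


s = "1,2,3,4\n1,6,7,8\n9,10,11,12\n"
print(rows_to_nested(io.StringIO(s)))
# {1: {2: {3: 4}, 6: {7: 8}}, 9: {10: {11: 12}}}
```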

juanpa.arrivillaga
  • Any chance of getting more explanation of how this works? It looks like there is something interesting going on with the "nested = d" line, which I'm guessing makes use of an interesting feature of [pointers?] that I don't fully understand. When I step through this one line at a time, the values for d and nested are different, and I don't see d being explicitly updated (but it's still getting updates)? – Richard W Jun 21 '18 at 12:12
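
As a small illustration of what the comment asks about (a sketch, not part of the original answer): in Python, assignment binds a name to an existing object and never copies it. So `nested = d` makes both names refer to the same dictionary, and `setdefault` both inserts a new inner dictionary into that shared object and returns it. Rebinding `nested` afterwards moves only the name one level deeper, while `d` keeps pointing at the top level.

```python
d = {}
nested = d                    # both names refer to the same dict object
assert nested is d

# setdefault inserts {} under key 1 in the shared dict and returns that inner dict
nested = nested.setdefault(1, {})
assert d == {1: {}}
assert nested is d[1]         # nested now names the inner dict, d the outer one

nested = nested.setdefault(2, {})
nested[3] = 4                 # mutates the innermost dict, which d still contains
print(d)                      # {1: {2: {3: 4}}}
```

This is why `d` and `nested` show different values while stepping through the loop, yet every write through `nested` is visible from `d`.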

Here is what I came up with. Feel free to comment and suggest improvements.

import csv
import itertools

def list_to_dict(lst):
    # Takes a list and recursively turns it into a nested dictionary, where
    # the first element is a key whose value is the dictionary created from the
    # rest of the list. The last element in the list becomes the value of the
    # innermost dictionary.
    # INPUTS:
    #   lst - a list (e.g. of strings or floats). Note: it is modified in place.
    # OUTPUT:
    #   A nested dictionary
    # EXAMPLE RUN:
    #   >>> lst = [1, 2, 3, 4]
    #   >>> list_to_dict(lst)
    #   {1:{2:{3:4}}}
    if len(lst) == 1:
        return lst[0]
    else:
        data_dict = {lst[-2]: lst[-1]}
        lst.pop()
        lst[-1] = data_dict
        return list_to_dict(lst)


def dict_combine(d1, d2):
    # Combines two nested dictionaries into one.
    # INPUTS:
    #   d1, d2: Two nested dictionaries. The function might change d1 and d2, 
    #           therefore if the input dictionaries are not to be mutated, 
    #           you should pass copies of d1 and d2.
    #           Note that the function works more efficiently if d1 is the 
    #           bigger dictionary.
    # OUTPUT:
    #   The combined dictionary
    # EXAMPLE RUN:
    #   >>> d1 = {1: {2: {3: 4, 5: 6}}}
    #   >>> d2 = {1: {2: {7: 8}, 9: {10: 11}}}
    #   >>> dict_combine(d1, d2)
    #   {1: {2: {3: 4, 5: 6, 7: 8}, 9: {10: 11}}}

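    for key in d2:
        if key in d1 and isinstance(d1[key], dict) and isinstance(d2[key], dict):
            # Both sides are dictionaries: merge them recursively
            d1[key] = dict_combine(d1[key], d2[key])
        else:
            # New key, or a non-dictionary value on either side: d2's value wins
            d1[key] = d2[key]
    return d1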


def csv_to_dict(csv_file_path, params=None, n_row_max=None):
    # NAME: csv_to_dict
    #
    # DESCRIPTION: Reads a csv file and turns relevant columns into a nested 
    #              dictionary.
    #
    # INPUTS:
    #   csv_file_path: The full path to the data file
    #   params:        A list of relevant column names. The resulting dictionary
    #                  will be nested in the same order as parameters in 'params'.
    #                  Default is None (read all columns)
    #   n_row_max:     The maximum number of rows to read. Default is None
    #                  (read all rows)
    #
    # OUTPUT:
    #   A nested dictionary containing all the relevant csv data

    csv_dictionary = {}

    with open(csv_file_path, 'r', newline='') as csv_file:
        csv_data = csv.reader(csv_file, delimiter=',')
        names = next(csv_data)           # Read title line
        if not params:
            # A list of column indices to read from csv (all columns)
            relevant_param_indices = list(range(len(names)))
        else:
            # A list of column indices to read from csv
            relevant_param_indices = []
            for name in params:
                if name not in names:
                    # Parameter name is not found in title line
                    raise ValueError('Could not find {} in csv file'.format(name))
                else:
                    # Get indices of the relevant columns
                    relevant_param_indices.append(names.index(name))
        # The title line was already consumed by next(), so start at the first data row
        for row in itertools.islice(csv_data, n_row_max):
            # Get a list containing relevant columns only
            relevant_cols = [row[i] for i in relevant_param_indices]
            # Convert the strings to numbers (optional)
            float_row = [float(element) for element in relevant_cols]
            # Build nested dictionary
            csv_dictionary = dict_combine(csv_dictionary, list_to_dict(float_row))

        return csv_dictionary
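
For comparison, here is a more compact sketch that combines column selection by name with the `setdefault` technique from the other answer, avoiding the per-row merge. It is an illustration under assumptions, not a drop-in replacement: the name `csv_to_nested` is hypothetical, it takes an open file object rather than a path, and it assumes the first line is a title line and all selected cells are numeric.

```python
import csv
import io
from itertools import islice


def csv_to_nested(csv_file, params=None, n_row_max=None):
    """Read selected columns of an open CSV file into a nested dict of floats."""
    reader = csv.reader(csv_file, skipinitialspace=True)
    names = next(reader)                      # title line
    columns = list(params) if params else names
    missing = [name for name in columns if name not in names]
    if missing:
        raise ValueError('Could not find {} in csv file'.format(missing))
    if len(columns) < 2:
        raise ValueError('Need at least 2 columns to build a dictionary')
    indices = [names.index(name) for name in columns]

    result = {}
    for row in islice(reader, n_row_max):     # n_row_max=None reads all rows
        if not row:                           # skip blank lines
            continue
        *path, key, value = (float(row[i]) for i in indices)
        nested = result
        for part in path:
            nested = nested.setdefault(part, {})
        nested[key] = value
    return result


sample = io.StringIO("a,b,c,d\n1,2,3,4\n1,6,7,8\n9,10,11,12\n")
print(csv_to_nested(sample, params=["a", "c", "d"]))
# {1.0: {3.0: 4.0, 7.0: 8.0}, 9.0: {11.0: 12.0}}
```

Because each row walks the existing dictionary directly, no intermediate per-row dictionary is built and merged.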
shayelk