
Hey, I have a PySpark DataFrame like this:

id   emb             pt
1    [5, 0.8, ..]    'h'
2    [0.7, 0.8, ..]  'd'
3    [1, 3, ..]      'h'
4    [0.3, 0.8, ..]  'f'

Now I have to compute a Faiss index on the emb column, separately for each pt group.

So first I'm collecting all the unique pt values:

# imports used throughout
import math
import multiprocessing as mp

import faiss
import numpy as np
from pyspark.sql import functions as f

pt = [val.pt for val in df.select('pt').distinct().collect()]
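(The same list could presumably also be built with a single collect_set aggregation; noting it here in case it matters:)

pt = df.agg(f.collect_set('pt').alias('pts')).collect()[0]['pts']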

Now I'm writing a function that takes a pt value as input, computes the Faiss index for the corresponding rows, and saves it to a file:

def multiprocess(pt):
    # Collect the embeddings and ids for this pt group into numpy arrays
    item_emb = df.filter(f.col('pt') == pt).select('emb').rdd.flatMap(lambda x: np.float32(x)).collect()
    item_id = df.filter(f.col('pt') == pt).select('id').rdd.flatMap(lambda x: x).collect()
    item_ids_np = np.array(item_id, dtype=np.int64)
    item_embs_np = np.array(item_emb)
    pool_size = item_embs_np.shape[0]   # number of vectors in this group
    emb_size = item_embs_np.shape[1]    # embedding dimension

    # IVF index with sqrt(n) cells over a flat L2 quantizer
    nlist = int(math.sqrt(pool_size))
    quantizer = faiss.IndexFlatL2(emb_size)
    index = faiss.IndexIVFFlat(quantizer, emb_size, nlist)

    # Train on this group's vectors, then add them under their original ids
    index.train(item_embs_np)
    index.add_with_ids(item_embs_np, item_ids_np)

    faiss.write_index(index, 'Faiss_Index/train_' + pt + '.index')
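For reference, a quick sanity check that could be run on one of the written files (just a sketch: it assumes the index for the 'h' group was actually written, and it uses a random query vector):

index = faiss.read_index('Faiss_Index/train_h.index')
print(index.ntotal, index.d)                    # stored vector count and embedding dimension

query = np.random.rand(1, index.d).astype(np.float32)   # dummy query, just to exercise search
index.nprobe = 10                               # probe more IVF cells than the default of 1
k = min(5, index.ntotal)
distances, ids = index.search(query, k)
print(ids)                                      # these are the original `id` values passed to add_with_ids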

Now I'm using the multiprocessing library to run this in parallel:

pool = mp.Pool(processes=mp.cpu_count() - 1)
pool.map(multiprocess, pt)
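Here is a minimal sketch of how the same pool call could be wrapped so that any exception raised inside a worker gets printed instead of disappearing silently (the os.makedirs call is just an assumption that the Faiss_Index folder might not exist yet):

import os

def run_one(p):
    # thin wrapper so a failure in one pt group is reported instead of lost
    try:
        multiprocess(p)
        return (p, None)
    except Exception as exc:
        return (p, repr(exc))

if __name__ == '__main__':
    os.makedirs('Faiss_Index', exist_ok=True)   # assumption: the output folder may be missing
    with mp.Pool(processes=mp.cpu_count() - 1) as pool:
        for p, err in pool.map(run_one, pt):
            print(p, 'ok' if err is None else err)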

My code keeps running but never writes any output file, so I'm not sure whether I've set up multiprocessing correctly or whether my function is saving the file somewhere else. Can someone please help?
