Hey, I have a PySpark DataFrame like this:
```
id  emb             pt
1   [5,0.8..]       'h'
2   [0.7, 0.8..]    'd'
3   [1, 3, ..]      'h'
4   [0.3, 0.8..]    'f'
```
Now I need to build a Faiss index on the `emb` column, separately for each `pt` value.
So first I collect all the distinct `pt` values:
```python
pt = [val.pt for val in df.select('pt').distinct().collect()]
```
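Since every group is collected to the driver anyway, the per-`pt` split could also be done in one pass instead of two `filter` jobs per group. A minimal stdlib-only sketch, assuming the rows were already collected (for example with `df.select('id', 'emb', 'pt').collect()`, whose `Row` objects unpack like tuples) and that they fit in driver memory. `group_by_pt` is a hypothetical helper and the 2-dimensional rows are made up:

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

# Each row mirrors (id, emb, pt) from the DataFrame; values below are made up.
Row = Tuple[int, List[float], str]
Groups = Dict[str, Tuple[List[int], List[List[float]]]]


def group_by_pt(rows: Iterable[Row]) -> Groups:
    """Split rows into {pt: (ids, embs)} in a single pass (hypothetical helper)."""
    groups: Groups = defaultdict(lambda: ([], []))
    for row_id, emb, pt in rows:
        ids, embs = groups[pt]  # fresh pair of lists the first time pt is seen
        ids.append(row_id)
        embs.append([float(v) for v in emb])
    return dict(groups)


if __name__ == "__main__":
    rows = [
        (1, [5.0, 0.8], "h"),
        (2, [0.7, 0.8], "d"),
        (3, [1.0, 3.0], "h"),
        (4, [0.3, 0.8], "f"),
    ]
    grouped = group_by_pt(rows)
    print(sorted(grouped))  # ['d', 'f', 'h']
    print(grouped["h"][0])  # [1, 3]
```

Each group's `(ids, embs)` pair can then be turned into the `int64` / `float32` NumPy arrays that the Faiss calls expect.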
Then I wrote a function that takes a `pt` value as input, builds the Faiss index for that group and saves it to a file:
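```python
import math
import faiss
import numpy as np
import pyspark.sql.functions as f


def multiprocess(pt):
    # Collect this pt group's embeddings and ids to the driver
    item_emb = df.filter(f.col('pt') == pt).select('emb').rdd.flatMap(lambda x: np.float32(x)).collect()
    item_id = df.filter(f.col('pt') == pt).select('id').rdd.flatMap(lambda x: x).collect()
    item_ids_np = np.array(item_id, dtype=np.int64)
    item_embs_np = np.array(item_emb)  # (pool_size, emb_size), float32
    pool_size = item_embs_np.shape[0]
    emb_size = item_embs_np.shape[1]
    nlist = int(math.sqrt(pool_size))  # number of IVF clusters
    # IVF index over an exact L2 quantizer; it must be trained before adding vectors
    quantizer = faiss.IndexFlatL2(emb_size)
    index = faiss.IndexIVFFlat(quantizer, emb_size, nlist)
    index.train(item_embs_np)
    index.add_with_ids(item_embs_np, item_ids_np)
    # Relative path: resolved against the current working directory,
    # and the Faiss_Index/ directory must already exist
    faiss.write_index(index, 'Faiss_Index/train_' + pt + '.index')
```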
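For a sense of how `nlist = int(math.sqrt(pool_size))` behaves per group: it is always at least 1 and never larger than the group, but small groups end up with very few points per cluster. As far as I know, Faiss's k-means prints a warning when it gets fewer than 39 training points per centroid (its default `min_points_per_centroid`; worth double-checking against the installed version). A small stdlib sketch, with hypothetical helper names and made-up group sizes, that flags groups below that ratio:

```python
import math
from typing import Dict

# Assumed Faiss default (ClusteringParameters.min_points_per_centroid); below
# this many training points per centroid, k-means prints a warning.
MIN_POINTS_PER_CENTROID = 39


def ivf_nlist(pool_size: int) -> int:
    """Integer square root, matching int(math.sqrt(pool_size)) for realistic sizes."""
    if pool_size < 1:
        raise ValueError("cannot train an IVF index on an empty group")
    return math.isqrt(pool_size)  # 1 <= nlist <= pool_size when pool_size >= 1


def undersized_groups(sizes: Dict[str, int]) -> Dict[str, int]:
    """Return {pt: nlist} for groups with fewer than 39 * nlist vectors."""
    flagged = {}
    for pt, n in sizes.items():
        nlist = ivf_nlist(n)
        if n < MIN_POINTS_PER_CENTROID * nlist:
            flagged[pt] = nlist
    return flagged


if __name__ == "__main__":
    # Made-up group sizes, not real counts
    print(undersized_groups({"h": 4, "d": 10000, "f": 1}))  # {'h': 2, 'f': 1}
```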
Then I use the `multiprocessing` library to run this in parallel:
```python
import multiprocessing as mp

pool = mp.Pool(processes=mp.cpu_count() - 1)
pool.map(multiprocess, pt)
```
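For comparison, here is a stripped-down version of the same pattern in which the workers receive only plain Python data (ids and embedding lists per `pt`) instead of touching the Spark DataFrame, and the output directory is made absolute and created up front. My understanding is that a SparkSession and its DataFrames belong to the driver process and aren't meant to be used from separate worker processes, so this is one way to rule that out. `build_one`, `build_all` and the JSON output are hypothetical stand-ins: the JSON write only marks where the Faiss `train` / `add_with_ids` / `faiss.write_index` steps would go.

```python
import json
import multiprocessing as mp
import os
from typing import Dict, List, Tuple

# One task per pt: (pt, ids, embs, out_dir). Only plain Python data is sent to
# the workers; the Spark DataFrame stays in the driver.
Task = Tuple[str, List[int], List[List[float]], str]


def build_one(task: Task) -> str:
    pt, ids, embs, out_dir = task
    path = os.path.join(out_dir, f"train_{pt}.index")
    # Placeholder for the real work: this is where the IndexIVFFlat would be
    # trained, filled with add_with_ids and saved with faiss.write_index(index, path).
    # JSON keeps this sketch dependency-free.
    with open(path, "w") as fh:
        json.dump({"ids": ids, "embs": embs}, fh)
    return path


def build_all(groups: Dict[str, Tuple[List[int], List[List[float]]]],
              out_dir: str, processes: int) -> List[str]:
    out_dir = os.path.abspath(out_dir)   # make the destination explicit
    os.makedirs(out_dir, exist_ok=True)  # writing fails if the directory is missing
    tasks = [(pt, ids, embs, out_dir) for pt, (ids, embs) in groups.items()]
    with mp.Pool(processes=processes) as pool:  # pool is cleaned up on exit
        return pool.map(build_one, tasks)


if __name__ == "__main__":
    # Made-up groups in the {pt: (ids, embs)} shape
    groups = {"h": ([1, 3], [[5.0, 0.8], [1.0, 3.0]]), "d": ([2], [[0.7, 0.8]])}
    for p in build_all(groups, "Faiss_Index", processes=max(1, mp.cpu_count() - 1)):
        print(p)
```

Printing the absolute paths shows exactly where the files land. With the spawn start method (the default on Windows and macOS), the worker function also has to be importable by the child processes, so a function defined only in a notebook cell can make the pool fail or hang.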
The code keeps running but never writes any files, so I'm not sure whether I'm using multiprocessing incorrectly or whether my function is saving the files somewhere else. Can someone please help?