0

Today I tried using CUDA in my ND4J and Deeplearning4j project. After that, the neural net (imported from Keras) began to work faster, but the following code began to work slowly.

I have already tried changing the ND4J backend to native (CPU), and with it I got a fast result.

The problem part is highlighted with comments (on 2 lines).

import com.rabbitmq.client.Channel;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.io.IOException;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

public class GraphUpdater implements Runnable {
    private Pair pubPair;
    private ConcurrentHashMap<Integer, INDArray> pubsList;
    private Connection connectionMain;
    private Connection connectionSite;
    private Channel channel;

    GraphUpdater(Pair pubPair, ConcurrentHashMap<Integer, INDArray> pubsList, Channel channel) throws SQLException {
    this.pubPair = pubPair;
    this.channel = channel;
    this.pubsList = pubsList;
    connectionMain = DataBaseConnectionsPool.getConnection();
    connectionSite = DataBaseConnectionsPool.getConnectionSite();
}

/**
 * Processes one publication task.
 *
 * <ol>
 *   <li>Acknowledges the RabbitMQ delivery.</li>
 *   <li>Loads the rows of {@code new_public_vectors} for this publication and
 *       stores their 64-element vectors in {@code pubsList}.</li>
 *   <li>Scores the publication's vector against every other vector in
 *       {@code pubsList} and collects the ids of the best matches.</li>
 *   <li>Appends this publication to the {@code closed_pubs} of every match,
 *       inserts its own {@code vec_graph} row, and calls {@code showToUser}
 *       when {@code pubPair.byUser} is non-zero.</li>
 * </ol>
 *
 * SQL failures are printed, not rethrown. Both connections are always closed.
 *
 * NOTE(review): the delivery is acknowledged before any work is done, so a
 * failure below does not cause redelivery - confirm this is intended.
 */
@Override
public void run(){
    try {
        channel.basicAck(pubPair.deliveryTag, false);
    } catch (IOException e) {
        System.out.println("Error, pub="+pubPair.pub);
        e.printStackTrace();
    }

    final int pub = pubPair.pub;
    // try-with-resources: the statements were previously never closed explicitly.
    try (PreparedStatement st = connectionMain.prepareStatement("update vec_graph set closed_pubs=closed_pubs || ? where pub=?");
         PreparedStatement stNew = connectionMain.prepareStatement("insert into vec_graph values (?, ?)")) {

        // Bound parameter instead of concatenating the id into the SQL text.
        try (PreparedStatement psNew = connectionMain.prepareStatement("select * from new_public_vectors where pub=?")) {
            psNew.setInt(1, pub);
            try (ResultSet rs = psNew.executeQuery()) {
                while (rs.next()) {
                    Array arr = rs.getArray("vector");
                    Object[] obj = (Object[]) arr.getArray();
                    // Fresh array per row so no two stored vectors can share one
                    // backing array, whatever Nd4j.create does with its argument.
                    float[] vector = new float[64];
                    for (int vIndex = 0; vIndex < vector.length; vIndex++) {
                        vector[vIndex] = ((Number) obj[vIndex]).floatValue();
                    }
                    pubsList.put(rs.getInt(1), Nd4j.create(vector));
                }
            }
        }

        // Compare the task's publication with every publication in the map.
        List<Integer> closed = new ArrayList<>();
        double mean = 0.96D; // minimum score a candidate has to exceed
        INDArray currentVector = pubsList.get(pub);
        // Without a vector for this publication nothing can match; the check is
        // loop-invariant, so it is done once here instead of per iteration.
        if (currentVector != null) {
            // Slow part of the code (reported by the author when running on CUDA).
            for (int pubId : pubsList.keySet()) {
                INDArray publicVector = pubsList.get(pubId);
                if (pub == pubId || publicVector == null) {
                    continue;
                }
                // Hot spot: ~99% of CPU time in VisualVM according to the author.
                // Presumably cosineDistance returns 1 - cosine similarity, which
                // makes this value a similarity (higher = closer), matching how
                // the threshold below is used - verify for the ND4J version in use.
                double similarity = 1 - Transforms.cosineDistance(currentVector, publicVector);
                double gain = similarity - mean;
                if (gain > 0 && gain < 0.01) {
                    // Roughly as good as the current best: keep it and average.
                    mean = (mean + similarity) / 2;
                } else if (similarity > mean) {
                    // Clearly better: earlier candidates are discarded.
                    mean = similarity;
                    closed.clear();
                    st.clearBatch();
                } else {
                    continue;
                }
                Array a = connectionMain.createArrayOf("int", new Object[]{pub});
                st.setArray(1, a);
                st.setInt(2, pubId);
                st.addBatch();
                closed.add(pubId);
            }
        }

        Array closedArray = connectionMain.createArrayOf("int", closed.toArray());
        stNew.setInt(1, pub);
        stNew.setArray(2, closedArray);
        stNew.addBatch();

        if (pubPair.byUser != 0){
            showToUser(closed, pub, pubPair.byUser);
        }
        st.executeBatch();
        stNew.executeBatch();
    } catch (SQLException e) {
        // Covers BatchUpdateException too. getNextException() may return null,
        // so it is checked before printing the chained cause.
        e.printStackTrace();
        SQLException next = e.getNextException();
        if (next != null) {
            next.printStackTrace();
        }
    } finally {
        // Close each connection independently so that a failure on the first
        // does not leave the second one open.
        for (Connection connection : new Connection[]{connectionMain, connectionSite}) {
            try {
                if (connection != null) {
                    connection.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
}

I would like something from this list:

  1. Get a faster result and use the GPU

  2. Turn off the GPU for this part of the code and leave it on for the NN

vladF
  • 101
  • 2
  • Using a GPU for something doesn't just make it faster. The way you set up your problem, and architecture are completely different. They each have different methodologies. Most likely, your code is not set up to take advantage of running parallel on 500+ cores like in a GPU. – Dylan Mar 27 '19 at 18:05

1 Answers1

0

Okay, I'll rewrite the part of the code that uses cosineDistance with my own implementation.

vladF
  • 101
  • 2