I've implemented the following solution, which applies the producer/consumer pattern with BlockingQueue and ExecutorService.
The main thread (producer) creates one BlockingQueue per worker thread (consumer) and a volatile boolean flag, "terminated", which signals to the workers that all data has been generated and that they should finish: leave the while loop, drain the queue, and write the remaining data over the JDBC connection. The producer generates different data for each thread through the two queues, blockingQueue1 and blockingQueue2.
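Stripped of JDBC and logging, this signalling scheme can be sketched roughly as follows. It is only a minimal illustration under stated assumptions, not the real code: the names TerminationFlagSketch, consume and runDemo, and the 50 ms poll timeout, are made up for the example. It also uses poll with a timeout rather than take, so that a consumer waiting on an empty queue still gets a chance to re-check the flag.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

// Hypothetical demo class: only the flag-and-drain idea comes from the post.
final class TerminationFlagSketch {
    // Set by the producer once every item has been queued.
    static volatile boolean terminated = false;

    // Consumer: counts items until the flag is set AND the queue is empty.
    static int consume(BlockingQueue<Integer> queue) throws InterruptedException {
        int processed = 0;
        while (!terminated || !queue.isEmpty()) {
            // Timed poll: returns null after 50 ms so the loop re-checks the flag.
            Integer item = queue.poll(50, TimeUnit.MILLISECONDS);
            if (item != null) {
                processed++;
            }
        }
        return processed;
    }

    // Producer: feeds two bounded queues, raises the flag, collects both counts.
    static int[] runDemo(int items) {
        terminated = false;
        BlockingQueue<Integer> q1 = new LinkedBlockingQueue<>(100);
        BlockingQueue<Integer> q2 = new LinkedBlockingQueue<>(100);
        Callable<Integer> c1 = () -> consume(q1);
        Callable<Integer> c2 = () -> consume(q2);
        ExecutorService executor = Executors.newFixedThreadPool(2);
        try {
            Future<Integer> f1 = executor.submit(c1);
            Future<Integer> f2 = executor.submit(c2);
            for (int i = 0; i < items; i++) {
                q1.put(i);      // blocks while the queue is full
                q2.put(i * 2);
            }
            terminated = true;  // all data has been queued
            return new int[] { f1.get(), f2.get() };
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            executor.shutdown();
        }
    }

    public static void main(String[] args) {
        int[] totals = runDemo(10023);
        System.out.println("consumer 1: " + totals[0] + ", consumer 2: " + totals[1]);
    }
}
```

In this sketch the consumer keeps looping while the flag is unset or the queue still holds items, so the separate drain step is folded into a single loop condition.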
Here is the simplified MainThreadProducer, which just generates integer data for two worker threads:
// MainThreadProducer.java
import java.util.concurrent.*;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
public class MainThreadProducer {
    public static Logger logger = LogManager.getLogger(MainThreadProducer.class);
    public final static BlockingQueue<Integer> blockingQueue1 = new LinkedBlockingDeque<>(100);
    public final static BlockingQueue<Integer> blockingQueue2 = new LinkedBlockingDeque<>(100);
    /* signal to the worker threads that all data has been generated */
    public static volatile boolean terminated = false;
    private void run () {
        try {
            ExecutorService executor = Executors.newFixedThreadPool(2);
            Future<Integer> future1 = executor.submit(new WorkerThreadConsumer("1"));
            Future<Integer> future2 = executor.submit(new WorkerThreadConsumer("2"));
            // put() blocks while a queue is full, so the producer is throttled by the consumers
            for (int i = 0; i < 10023; ++i) {
                blockingQueue1.put(i);
                blockingQueue2.put(i*2);
            }
            executor.shutdown();
            terminated = true;
            // wait for both workers and collect their row counts
            int res1 = future1.get();
            int res2 = future2.get();
            logger.info("Total rows written (thread 1): " + res1);
            logger.info("Total rows written (thread 2): " + res2);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }
    public static void main(String[] args) {
        MainThreadProducer instance = new MainThreadProducer();
        instance.run();
    }
}
Here is the WorkerThreadConsumer class. For this test I create two threads, which write to tables TARGET_1 and TARGET_2 of the database DBTEST, respectively.
Each thread is instantiated with a specific String type ("1" or "2"), which tells it which BlockingQueue to read from.
// WorkerThreadConsumer.java
import java.sql.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
// Configuration (which holds DWH_DB_CONNECTION_URL) is assumed to live in the same package, so it needs no import
public class WorkerThreadConsumer implements Callable<Integer> {
    private String type;
    public WorkerThreadConsumer (String type) {
        this.type = type;
    }
    @Override
    public Integer call() {
        String TAG = "[THREAD_" + Thread.currentThread().getId() + "]";
        int processed = 0; // number of rows currently processed
        int batchSize = 100; // size of the batch we write to the server with the PreparedStatement
        try {
            // load jdbc driver
            Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver");
            MainThreadProducer.logger.info(TAG + "\tLoaded com.microsoft.sqlserver.jdbc.SQLServerDriver");
            String stub = String.format("INSERT INTO DBTEST.dbo.TARGET_%s (id) VALUES (?);", this.type);
            // pick the queue that belongs to this worker
            BlockingQueue<Integer> queue;
            switch (this.type) {
                case "1":
                    queue = MainThreadProducer.blockingQueue1;
                    break;
                case "2":
                    queue = MainThreadProducer.blockingQueue2;
                    break;
                default:
                    queue = MainThreadProducer.blockingQueue1;
            }
            try (Connection connection = DriverManager.getConnection(Configuration.DWH_DB_CONNECTION_URL);
                 PreparedStatement stmt = connection.prepareStatement(stub);) {
                connection.setAutoCommit(false);
                while (!MainThreadProducer.terminated) {
                    // poll with a timeout instead of take(): take() would block forever
                    // if the queue is already empty when terminated is set
                    Integer data = queue.poll(100, TimeUnit.MILLISECONDS);
                    if (data == null) {
                        continue;
                    }
                    stmt.setInt(1, data);
                    stmt.addBatch();
                    processed += 1;
                    if (processed % batchSize == 0) {
                        int[] result = stmt.executeBatch();
                        connection.commit();
                        MainThreadProducer.logger.info(TAG + "\tWritten rows count: " + result.length);
                    }
                }
                // empty queue and write (the producer has finished, so take() cannot block here)
                while (!queue.isEmpty()) {
                    int data = queue.take();
                    stmt.setInt(1, data);
                    stmt.addBatch();
                    processed += 1;
                    if (processed % batchSize == 0) {
                        int[] result = stmt.executeBatch();
                        connection.commit();
                        MainThreadProducer.logger.info(TAG + "\tWritten rows count: " + result.length);
                    }
                }
                // final write: flush any rows left over in a partial batch
                int[] result = stmt.executeBatch();
                connection.commit();
                MainThreadProducer.logger.info(TAG + "\tWritten rows count: " + result.length);
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return processed;
    }
}
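For reference, the batching logic in call() reduces to the pattern below, shown here without JDBC. This is only an illustrative sketch: BatchFlushSketch and batchSizes are hypothetical names, and a plain counter stands in for the PreparedStatement calls. With the 10023 rows and batch size of 100 used above, it shows why the final write after the loops is needed.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: mimics "flush every batchSize rows, then flush the rest".
final class BatchFlushSketch {
    // Returns the size of each flushed batch, in order.
    static List<Integer> batchSizes(int totalRows, int batchSize) {
        List<Integer> flushed = new ArrayList<>();
        int pending = 0;
        for (int row = 0; row < totalRows; row++) {
            pending++;                    // stands in for stmt.addBatch()
            if (pending == batchSize) {
                flushed.add(pending);     // stands in for executeBatch() + commit()
                pending = 0;
            }
        }
        // Final flush of the partial batch. The worker above always calls
        // executeBatch() at this point, even when nothing is pending.
        if (pending > 0) {
            flushed.add(pending);
        }
        return flushed;
    }

    public static void main(String[] args) {
        List<Integer> sizes = batchSizes(10023, 100);
        // prints "101 batches, last one has 23 rows"
        System.out.println(sizes.size() + " batches, last one has "
                + sizes.get(sizes.size() - 1) + " rows");
    }
}
```

Without the final flush, the last 23 rows of each worker would be added to the batch but never sent to the server.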
The solution seems to work. Please let me know if you see potential issues.