I am trying to load about 1M rows from a PostgreSQL database into Spark. With Spark the load takes about 10 s, whereas running the same query through the psycopg2 driver takes about 2 s. I am using PostgreSQL JDBC driver version 42.0.0.
from pyspark.sql import SparkSession

def _loadFromPostGres(name):
    # dbname holds the database name and is defined elsewhere (not shown)
    url_connect = "jdbc:postgresql:" + dbname
    properties = {"user": "postgres", "password": "postgres"}
    df = SparkSession.builder.getOrCreate().read.jdbc(url=url_connect, table=name, properties=properties)
    return df
df = _loadFromPostGres("""
(SELECT "seriesId", "companyId", "userId", "score"
FROM user_series_game
WHERE "companyId"=655124304077004298) as
user_series_game""")
print(measure(lambda: len(df.collect())))
The output is -
--- 10.7214591503 seconds ---
1076131
Using psycopg2 -
import psycopg2
conn = psycopg2.connect(conn_string)
cur = conn.cursor()
def _exec():
    cur.execute("""(SELECT "seriesId", "companyId", "userId", "score"
        FROM user_series_game
        WHERE "companyId"=655124304077004298)""")
    return cur.fetchall()
print(measure(lambda: len(_exec())))
cur.close()
conn.close()
The output is -
--- 2.27961301804 seconds ---
1076131
The measure function -
import time

def measure(func):
    # Run func once, print the wall-clock time it took, and return its result
    start_time = time.time()
    x = func()
    print("--- %s seconds ---" % (time.time() - start_time))
    return x
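A single timed run can be skewed by JVM warm-up or a cold database cache, so repeating the call and keeping the best time may give a fairer comparison. The following is only a sketch of that idea: `measure_best` and its `repeats` parameter are hypothetical names that are not part of the original benchmark, and it assumes Python 3 for `time.perf_counter`.

```python
import time


def measure_best(func, repeats=3):
    """Call func `repeats` times; return (best_seconds, last_result).

    Hypothetical helper, not part of the original benchmark.
    """
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = func()
        elapsed = time.perf_counter() - start
        best = min(best, elapsed)
    return best, result


# Usage with a stand-in workload; in the benchmark this would be
# something like lambda: len(df.collect()).
seconds, total = measure_best(lambda: sum(range(100000)))
print("--- best of 3: %s seconds ---" % seconds)
print(total)
```

Taking the minimum rather than the mean is a design choice: it reports the run least disturbed by other activity on the machine.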
Kindly help me find the cause of this problem.
Edit 1
I did a few more benchmarks. Using Scala and JDBC -
import java.sql._
import scala.collection.mutable.ArrayBuffer

def exec(): Unit = {
  val url = ("jdbc:postgresql://prod.caumccqvmegm.ap-southeast-1.rds.amazonaws.com/prod" +
    "?tcpKeepAlive=true&prepareThreshold=-1&binaryTransfer=true&defaultRowFetchSize=10000")
  val conn = DriverManager.getConnection(url, "postgres", "postgres")
  val sqlText = """SELECT "seriesId", "companyId", "userId", "score"
    FROM user_series_game
    WHERE "companyId"=655124304077004298"""
  val t0 = System.nanoTime()
  val stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
  val rs = stmt.executeQuery()
  val list = new ArrayBuffer[(Long, Long, Long, Double)]()
  while (rs.next()) {
    val seriesId = rs.getLong("seriesId")
    val companyId = rs.getLong("companyId")
    val userId = rs.getLong("userId")
    val score = rs.getDouble("score")
    list.append((seriesId, companyId, userId, score))
  }
  val t1 = System.nanoTime()
  println("Elapsed time: " + (t1 - t0) * 1e-9 + "s")
  println(list.size)
  rs.close()
  stmt.close()
  conn.close()
}

exec()
The output was -
Elapsed time: 1.922102285s
1143402
When I did collect() in Spark + Scala -
import org.apache.spark.sql.SparkSession

def exec2(): Unit = {
  val spark = SparkSession.builder().getOrCreate()
  val url = ("jdbc:postgresql://prod.caumccqvmegm.ap-southeast-1.rds.amazonaws.com/prod" +
    "?tcpKeepAlive=true&prepareThreshold=-1&binaryTransfer=true&defaultRowFetchSize=10000")
  val sqlText = """(SELECT "seriesId", "companyId", "userId", "score"
    FROM user_series_game
    WHERE "companyId"=655124304077004298) as user_series_game"""
  val t0 = System.nanoTime()
  val df = spark.read
    .format("jdbc")
    .option("url", url)
    .option("dbtable", sqlText)
    .option("user", "postgres")
    .option("password", "postgres")
    .load()
  val list = df.collect()
  val t1 = System.nanoTime()
  println("Elapsed time: " + (t1 - t0) * 1e-9 + "s")
  println(list.size)
}

exec2()
The output was -
Elapsed time: 1.486141076s
1143445
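So the extra time appears to be spent in Python serialisation: collect() takes about 10.7 s from PySpark, versus about 1.5 s from Spark in Scala and about 2.3 s with psycopg2. I understand there will be some penalty, but this seems too much.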
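To get a rough feel for how much of that gap could come from serialisation alone, one option is to time a pickle round trip over rows of the same shape in plain Python, with no Spark or database involved. This is only a sketch under stated assumptions: the rows are synthetic, `make_rows` and `pickle_round_trip_seconds` are hypothetical helper names, and pickle protocol 2 is an assumption about what PySpark uses when it ships collected rows to Python. It also leaves out socket transfer and Row object construction, so the real cost is likely higher.

```python
import pickle
import time


def make_rows(n):
    """Build n synthetic (seriesId, companyId, userId, score) tuples."""
    company_id = 655124304077004298
    return [(i, company_id, i * 7, float(i % 100)) for i in range(n)]


def pickle_round_trip_seconds(rows):
    """Serialise and deserialise rows with pickle; return (seconds, restored)."""
    start = time.perf_counter()
    payload = pickle.dumps(rows, protocol=2)  # protocol 2 is an assumption
    restored = pickle.loads(payload)
    elapsed = time.perf_counter() - start
    return elapsed, restored


# 100,000 rows keeps the example quick; raise this to about 1,000,000
# to match the size of the query result above.
rows = make_rows(100000)
elapsed, restored = pickle_round_trip_seconds(rows)
print("--- %s seconds for %d rows ---" % (elapsed, len(restored)))
```

If the round trip for about 1M rows is far below the observed difference, the overhead probably lies elsewhere, for example in per-row object creation on the Python side.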