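import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.catalog.CatalogRelation; // Spark 2.2; the equivalent class in Spark 2.3+ is HiveTableRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTable;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.execution.datasources.LogicalRelation;
import scala.collection.JavaConverters;

// Returns "database.table" for every catalog table the DataFrame reads from.
public static Set<String> getRelations(Dataset<Row> dataframe) {
    // Use the analyzed plan: table names are already resolved against the catalog
    LogicalPlan plan = dataframe.queryExecution().analyzed();
    return JavaConverters.seqAsJavaListConverter(plan.collectLeaves()).asJava()
            .stream()
            .map(leaf -> {
                CatalogTable table = null;
                if (leaf instanceof CatalogRelation) {
                    // Hive table
                    table = ((CatalogRelation) leaf).tableMeta();
                } else if (leaf instanceof LogicalRelation
                        && ((LogicalRelation) leaf).catalogTable().isDefined()) {
                    // Data source table; catalogTable() is empty for e.g. spark.read().parquet(path)
                    table = ((LogicalRelation) leaf).catalogTable().get();
                }
                return table == null ? null : table.database() + "." + table.identifier().table();
            })
            .filter(Objects::nonNull) // skip leaves that are not backed by a catalog table
            .collect(Collectors.toSet());
}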
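The getRelations() method above walks the leaves of the DataFrame's analyzed logical plan and returns the fully qualified name (database.table) of every catalog table the DataFrame reads from. For example, run a query like this: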
Dataset<org.apache.spark.sql.Row> df = spark.sql("SELECT * FROM table_a AS a JOIN table_b AS b ON a.pk = b.fk");
This gives you a Dataset that you can pass to getRelations(); the result is the set of tables the query reads, here table_a and table_b, each prefixed with the name of its database.
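To see why collecting the plan's leaves finds every table, and why leaves without a catalog table should be skipped rather than mapped to "", here is a minimal plain-Java sketch that needs no Spark dependency. The classes PlanNode, TableLeaf, FileLeaf, JoinNode and RelationsSketch are hypothetical stand-ins made up for illustration, not Spark classes: TableLeaf plays the role of CatalogRelation (or a LogicalRelation whose catalogTable() is defined), and FileLeaf plays the role of a leaf with no catalog table, such as a DataFrame read directly from files.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical, simplified plan tree. NOT Spark's API; it only mimics the shape of LogicalPlan.
abstract class PlanNode {
    final List<PlanNode> children;

    PlanNode(List<PlanNode> children) {
        this.children = children;
    }

    // Same idea as collectLeaves() in Spark: every node without children, depth-first.
    List<PlanNode> collectLeaves() {
        List<PlanNode> leaves = new ArrayList<>();
        if (children.isEmpty()) {
            leaves.add(this);
        } else {
            for (PlanNode child : children) {
                leaves.addAll(child.collectLeaves());
            }
        }
        return leaves;
    }
}

// Leaf backed by a catalog table (stand-in for CatalogRelation / LogicalRelation with a catalog table).
class TableLeaf extends PlanNode {
    final String database;
    final String table;

    TableLeaf(String database, String table) {
        super(Collections.emptyList());
        this.database = database;
        this.table = table;
    }
}

// Leaf with no catalog table, e.g. a DataFrame read straight from a file path.
class FileLeaf extends PlanNode {
    final String path;

    FileLeaf(String path) {
        super(Collections.emptyList());
        this.path = path;
    }
}

// Inner node such as a join; its tables are only reachable through its leaves.
class JoinNode extends PlanNode {
    JoinNode(PlanNode left, PlanNode right) {
        super(Arrays.asList(left, right));
    }
}

final class RelationsSketch {
    // Maps each table-backed leaf to "database.table" and drops every other leaf.
    static Set<String> getRelations(PlanNode plan) {
        return plan.collectLeaves().stream()
                .map(RelationsSketch::qualifiedName)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(Collectors.toSet());
    }

    private static Optional<String> qualifiedName(PlanNode leaf) {
        if (leaf instanceof TableLeaf) {
            TableLeaf t = (TableLeaf) leaf;
            return Optional.of(t.database + "." + t.table);
        }
        return Optional.empty();
    }
}
```

Modeling the join query above as new JoinNode(new TableLeaf("default", "table_a"), new TableLeaf("default", "table_b")), with "default" as an assumed database name, getRelations returns a set containing "default.table_a" and "default.table_b". Filtering out empty results instead of returning "" keeps placeholder entries out of the set when a plan also reads from files.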