My solution works for the very common Hibernate + Spring + MySQL stack.
Like the answer above, I based my solution on Dr. Richard Kennar's. However, since Hibernate is often used with Spring, I wanted my solution to work well with Spring and with the standard way of using Hibernate. My solution therefore combines thread locals with singleton beans. Technically, the interceptor is invoked for every SQL statement the SessionFactory prepares. However, it skips all of its logic, and does not set any ThreadLocal values, unless the current thread has explicitly flagged the query to count the total rows.
With the interceptor class shown below, your Spring configuration looks like this:
<!-- Row-counting interceptor. If the context holds several SessionFactory beans and the one to use
     is not named "sessionFactory", uncomment the property below and set the correct bean name. -->
<bean id="foundRowCalculator" class="my.hibernate.classes.MySQLCalcFoundRowsInterceptor">
    <!-- <property name="sessionFactoryBeanName" value="mySessionFactory"/> -->
</bean>

<bean id="mySessionFactory"
      class="org.springframework.orm.hibernate3.annotation.AnnotationSessionFactoryBean">
    <property name="dataSource" ref="dataSource"/>
    <property name="packagesToScan" value="my.hibernate.classes"/>
    <!-- Registers the row-counting interceptor with this session factory. -->
    <property name="entityInterceptor" ref="foundRowCalculator"/>
</bean>
Basically, you must declare the interceptor bean and then reference it in the "entityInterceptor" property of the session factory bean. You only need to set "sessionFactoryBeanName" if there is more than one SessionFactory in your Spring context and the one you want to use is not named "sessionFactory". You cannot inject the SessionFactory into the interceptor directly. The session factory bean already references the interceptor, so a reference back would create a circular dependency between the two beans that cannot be resolved. Instead, the interceptor looks the SessionFactory up lazily, the first time it needs it.
The result is returned in a simple wrapper bean:
package my.hibernate.classes;
/**
 * Simple holder for one page of query results plus the total number of rows the query matched.
 *
 * <p>The list is stored as given (not copied), so callers should not modify it after construction.
 *
 * @param <T> element type of the page
 */
public class PagedResponse<T> {
    /** Rows of the requested page, exactly as passed to the constructor. */
    public final java.util.List<T> items;
    /** Total number of matching rows reported for the query. */
    public final int total;

    /**
     * Creates a response for one page of results.
     *
     * @param items rows of the page
     * @param total total number of matching rows
     */
    public PagedResponse(java.util.List<T> items, int total) {
        this.items = items;
        this.total = total;
    }
}
Then, in an abstract base DAO class, call "setCalcFoundRows(true)" before running the query and "reset()" afterwards (in a finally block, so that it is always called):
package my.hibernate.classes;
import org.hibernate.Criteria;
import org.hibernate.Query;
import org.springframework.beans.factory.annotation.Autowired;
public abstract class BaseDAO {
@Autowired
private MySQLCalcFoundRowsInterceptor rowCounter;
public <T> PagedResponse<T> getPagedResponse(Criteria crit, int firstResult, int maxResults) {
rowCounter.setCalcFoundRows(true);
try {
@SuppressWarnings("unchecked")
return new PagedResponse<T>(
crit.
setFirstResult(firstResult).
setMaxResults(maxResults).
list(),
rowCounter.getFoundRows());
} finally {
rowCounter.reset();
}
}
public <T> PagedResponse<T> getPagedResponse(Query query, int firstResult, int maxResults) {
rowCounter.setCalcFoundRows(true);
try {
@SuppressWarnings("unchecked")
return new PagedResponse<T>(
query.
setFirstResult(firstResult).
setMaxResults(maxResults).
list(),
rowCounter.getFoundRows());
} finally {
rowCounter.reset();
}
}
}
Next, an example concrete DAO for an @Entity named MyEntity with a String property "prop":
package my.hibernate.classes;
import org.hibernate.SessionFactory;
import org.hibernate.criterion.Restrictions;
import org.springframework.beans.factory.annotation.Autowired;
/**
 * Example DAO for {@code MyEntity}. It pages through entities whose {@code prop} equals a given value,
 * using the row-counting support inherited from {@link BaseDAO}.
 */
public class MyEntityDAO extends BaseDAO {

    @Autowired
    private SessionFactory sessionFactory;

    /**
     * Returns one page of {@code MyEntity} rows whose {@code prop} equals {@code propVal}, together
     * with the total number of matching rows.
     *
     * @param propVal     value that {@code prop} must equal
     * @param firstResult index of the first row of the page
     * @param maxResults  maximum number of rows in the page
     * @return the requested page and the total row count
     */
    public PagedResponse<MyEntity> getPagedEntitiesWithPropertyValue(String propVal, int firstResult, int maxResults) {
        return this.<MyEntity>getPagedResponse(
                this.sessionFactory.getCurrentSession()
                        .createCriteria(MyEntity.class)
                        .add(Restrictions.eq("prop", propVal)),
                firstResult, maxResults);
    }
}
Finally, the interceptor class that does all the work:
package my.hibernate.classes;
import java.io.IOException;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import org.hibernate.EmptyInterceptor;
import org.hibernate.HibernateException;
import org.hibernate.SessionFactory;
import org.hibernate.Transaction;
import org.hibernate.jdbc.Work;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
/**
 * Hibernate interceptor that asks MySQL for the total number of rows a paged query matches.
 *
 * <p>After a thread calls {@link #setCalcFoundRows(boolean) setCalcFoundRows(true)}, the first SQL
 * statement prepared on that thread gets {@code SQL_CALC_FOUND_ROWS} inserted right after its
 * {@code select} keyword. The total is then read with {@code select FOUND_ROWS()} on the current
 * session's connection. That happens either just before a second statement is prepared, or on demand
 * from {@link #getFoundRows()}.
 *
 * <p>Per-thread state lives in {@link ThreadLocal}s, so one instance can be shared as a singleton
 * bean. Threads that never set the flag pass through untouched.
 *
 * <p>The {@link SessionFactory} is looked up lazily from the {@link BeanFactory} instead of being
 * injected, because the session factory bean itself references this interceptor.
 */
public class MySQLCalcFoundRowsInterceptor extends EmptyInterceptor implements BeanFactoryAware {

    /** Kept unchanged for serialization compatibility. */
    private static final long serialVersionUID = 2745492452467374139L;

    private static final String SELECT_PREFIX = "select ";
    private static final String CALC_FOUND_ROWS_HINT = "SQL_CALC_FOUND_ROWS ";
    private static final String SELECT_FOUND_ROWS = "select FOUND_ROWS()";
    /** Bean name tried first when no explicit session factory bean name is configured. */
    private static final String DEFAULT_SESSION_FACTORY_BEAN_NAME = "sessionFactory";

    /** Resolved lazily by {@link #init()}; volatile because this singleton is used from many threads. */
    private volatile SessionFactory sessionFactory;
    private BeanFactory beanFactory;
    private String sessionFactoryBeanName;

    /** {@code TRUE} while the current thread wants the found-rows total computed. */
    private final ThreadLocal<Boolean> mCalcFoundRows = new ThreadLocal<Boolean>();
    /** Statements prepared on this thread since the flag was set (only counted up to 2). */
    private final ThreadLocal<Integer> mSQLStatementsPrepared = new ThreadLocal<Integer>() {
        @Override
        protected Integer initialValue() {
            return Integer.valueOf(0);
        }
    };
    /** Captured {@code FOUND_ROWS()} value for this thread, or {@code null} if not captured yet. */
    private final ThreadLocal<Integer> mFoundRows = new ThreadLocal<Integer>();

    /**
     * Resolves {@link #sessionFactory} from the bean factory on first use.
     *
     * <p>Lookup order:
     * <ol>
     *   <li>the bean named via {@link #setSessionFactoryBeanName(String)}, if set;</li>
     *   <li>otherwise a bean named {@code "sessionFactory"}, if it exists and is a SessionFactory;</li>
     *   <li>otherwise a SessionFactory bean looked up by type.</li>
     * </ol>
     * Concurrent first calls may each perform the lookup; the race only costs a redundant lookup.
     *
     * @throws IllegalStateException if no BeanFactory has been set
     */
    private void init() {
        if (sessionFactory != null) {
            return;
        }
        if (beanFactory == null) {
            throw new IllegalStateException("No BeanFactory set; " + getClass().getSimpleName()
                    + " must be declared as a Spring bean");
        }
        final SessionFactory resolved;
        if (sessionFactoryBeanName != null) {
            resolved = beanFactory.getBean(sessionFactoryBeanName, SessionFactory.class);
        } else if (beanFactory.containsBean(DEFAULT_SESSION_FACTORY_BEAN_NAME)
                && beanFactory.isTypeMatch(DEFAULT_SESSION_FACTORY_BEAN_NAME, SessionFactory.class)) {
            resolved = beanFactory.getBean(DEFAULT_SESSION_FACTORY_BEAN_NAME, SessionFactory.class);
        } else {
            resolved = beanFactory.getBean(SessionFactory.class);
        }
        sessionFactory = resolved;
    }

    /**
     * Rewrites the first statement prepared after {@code setCalcFoundRows(true)}. Before the second
     * statement is prepared, it captures the found-rows total. All other statements pass through
     * unchanged.
     *
     * @param sql SQL about to be prepared
     * @return the SQL to prepare ({@code SQL_CALC_FOUND_ROWS} inserted for the first statement)
     * @throws HibernateException if the first statement has no {@code select} keyword
     */
    @Override
    public String onPrepareStatement(String sql) {
        if (!isCalcFoundRows()) {
            return sql;
        }
        switch (mSQLStatementsPrepared.get()) {
            case 0:
                mSQLStatementsPrepared.set(1);
                return insertCalcFoundRowsHint(sql);
            case 1:
                mSQLStatementsPrepared.set(2);
                // Capture FOUND_ROWS() before any secondary statement runs on the connection.
                // Skip this if getFoundRows() already captured it; capturing twice is rejected.
                // If no secondary statement ever runs, getFoundRows() captures it on demand.
                if (mFoundRows.get() == null) {
                    captureFoundRows();
                }
                return sql;
            default:
                // Pass-through untouched
                return sql;
        }
    }

    /** Clears all of this thread's state if the flag is set; otherwise does nothing. */
    public void reset() {
        if (isCalcFoundRows()) {
            mSQLStatementsPrepared.remove();
            mFoundRows.remove();
            mCalcFoundRows.remove();
        }
    }

    /** Clears this thread's state whenever a transaction completes. */
    @Override
    public void afterTransactionCompletion(Transaction tx) {
        reset();
    }

    /**
     * Flags (or un-flags) the current thread so its next query computes the total row count.
     *
     * @param calc {@code true} to flag the thread; {@code false} behaves like {@link #reset()}
     */
    public void setCalcFoundRows(boolean calc) {
        if (calc) {
            mCalcFoundRows.set(Boolean.TRUE);
        } else {
            reset();
        }
    }

    /**
     * Returns the total row count of the flagged query, capturing it now if needed.
     *
     * @return the value reported by {@code FOUND_ROWS()}
     * @throws IllegalStateException if {@code setCalcFoundRows(true)} was not called on this thread
     * @throws HibernateException if no statement has been prepared since the flag was set
     */
    public int getFoundRows() {
        if (!isCalcFoundRows()) {
            throw new IllegalStateException("Attempted to getFoundRows without first calling 'setCalcFoundRows'");
        }
        if (mFoundRows.get() == null) {
            captureFoundRows();
        }
        return mFoundRows.get();
    }

    /** Returns whether the current thread has requested found-rows calculation. */
    private boolean isCalcFoundRows() {
        return Boolean.TRUE.equals(mCalcFoundRows.get());
    }

    /**
     * Inserts {@code SQL_CALC_FOUND_ROWS} directly after the statement's {@code select} keyword.
     *
     * @throws HibernateException if no {@code select} keyword is found
     */
    private static String insertCalcFoundRowsHint(String sql) {
        final int selectIndex = indexOfSelectKeyword(sql);
        if (selectIndex == -1) {
            throw new HibernateException("First SQL statement did not contain '" + SELECT_PREFIX + "'");
        }
        final int insertAt = selectIndex + SELECT_PREFIX.length();
        return sql.substring(0, insertAt) + CALC_FOUND_ROWS_HINT + sql.substring(insertAt);
    }

    /**
     * Finds {@code "select "} in {@code sql}, ignoring case and skipping leading whitespace and
     * {@code /* ... *}{@code /} comments.
     *
     * <p>Skipping comments means a comment that mentions "select" is not mistaken for the
     * statement's keyword. Such a comment can appear, for example, when Hibernate SQL comments are
     * enabled.
     *
     * @return index of the keyword, or -1 if it is not found
     */
    private static int indexOfSelectKeyword(String sql) {
        int start = 0;
        while (true) {
            while (start < sql.length() && Character.isWhitespace(sql.charAt(start))) {
                start++;
            }
            if (!sql.startsWith("/*", start)) {
                break;
            }
            final int commentEnd = sql.indexOf("*/", start + 2);
            if (commentEnd == -1) {
                return -1;
            }
            start = commentEnd + 2;
        }
        final int keywordLength = SELECT_PREFIX.length();
        for (int i = start; i + keywordLength <= sql.length(); i++) {
            if (sql.regionMatches(true, i, SELECT_PREFIX, 0, keywordLength)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Runs {@code select FOUND_ROWS()} on the current session's connection and stores the result.
     *
     * @throws HibernateException if the total was already captured, or no statement was prepared yet
     */
    private void captureFoundRows() {
        init();
        // Sanity checks
        if (mFoundRows.get() != null) {
            throw new HibernateException("'" + SELECT_FOUND_ROWS + "' called more than once");
        }
        if (mSQLStatementsPrepared.get() < 1) {
            throw new HibernateException("'" + SELECT_FOUND_ROWS + "' called before '" + SELECT_PREFIX + CALC_FOUND_ROWS_HINT + "'");
        }
        // Fetch the total number of rows
        sessionFactory.getCurrentSession().doWork(new Work() {
            @Override
            public void execute(Connection connection) throws SQLException {
                final Statement stmt = connection.createStatement();
                try {
                    final ResultSet rs = stmt.executeQuery(SELECT_FOUND_ROWS);
                    try {
                        mFoundRows.set(rs.next() ? rs.getInt(1) : 0);
                    } finally {
                        rs.close();
                    }
                } finally {
                    // Runs even if closing the ResultSet failed, so the Statement never leaks.
                    stmt.close();
                }
            }
        });
    }

    /**
     * Sets the name of the SessionFactory bean to use. This is only needed when the context has
     * several SessionFactory beans and the desired one is not named {@code "sessionFactory"}.
     */
    public void setSessionFactoryBeanName(String sessionFactoryBeanName) {
        this.sessionFactoryBeanName = sessionFactoryBeanName;
    }

    /** Receives the bean factory used later to look up the SessionFactory. */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        this.beanFactory = beanFactory;
    }
}