Spring Data JPA 自定义全局DAO方法实例教程。最近项目中采用了Spring MVC + JPA的技术架构，其中用到了Spring Data JPA对JPA的进一步封装，于是顺便研究了一下。
public interface GenericRepository<T, ID extends Serializable> extends JpaRepository<T, ID> {
public QueryResult<T> getScrollDataByJpql(Class<T> entityClass, String whereJpql, Object[] queryParams,
public QueryResult<T> getScrollDataBySql(String sql, Object[] queryParams, Pageable pageable);
package com.trmp.commons.dao;
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;

import javax.persistence.EmbeddedId;
import javax.persistence.Entity;
import javax.persistence.EntityManager;
import javax.persistence.Query;

import org.apache.log4j.Logger;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.JpaEntityInformationSupport;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.data.repository.NoRepositoryBean;

import com.trmp.commons.util.QueryResult;
/**
* {@link BasicJpaRepository}接口实现类,并在{@link SimpleJpaRepository}基础上扩展。
* @author divine
*
* @param <T> ORM对象
* @param <ID> 主键ID
*/
@NoRepositoryBean // 必须的
public class GenericRepositoryImpl<T, ID extends Serializable> extends
SimpleJpaRepository<T, ID> implements GenericRepository<T, ID> {
static Logger logger = Logger.getLogger(GenericRepositoryImpl.class);
private final EntityManager em;
/**
* 构造函数
* @param domainClass
* @param em
*/
public GenericRepositoryImpl(final JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager) {
super(entityInformation, entityManager);
this.em = entityManager;
}
/**
* 构造函数
* @param domainClass
* @param em
*/
public GenericRepositoryImpl(Class<T> domainClass, EntityManager em) {
this(JpaEntityInformationSupport.getMetadata(domainClass, em), em);
}
@Override
public void setQueryParams(Query query, Object[] queryParams){
if(null != queryParams && queryParams.length != 0){
for(int i=0;i<queryParams.length;i++){
query.setParameter(i+1, queryParams[i]);
}
}
}
@Override
public String buildOrderby(LinkedHashMap<String, String> orderby) {
// TODO Auto-generated method stub
StringBuffer orderbyql = new StringBuffer(“”);
if(orderby!=null && orderby.size()>0){
orderbyql.append(” order by “);
for(String key : orderby.keySet()){
orderbyql.append(“o.”).append(key).append(” “).append(orderby.get(key)).append(“,”);
}
orderbyql.deleteCharAt(orderbyql.length()-1);
}
return orderbyql.toString();
}
@Override
public String getEntityName(Class<T> entityClass) {
// TODO Auto-generated method stub
String entityname = entityClass.getSimpleName();
Entity entity = entityClass.getAnnotation(Entity.class);
if(entity.name()!=null && !”".equals(entity.name())){
entityname = entity.name();
}
return entityname;
}
@Override
public QueryResult<T> getScrollDataByJpql(Class<T> entityClass, String whereJpql, Object[] queryParams,
LinkedHashMap<String, String> orderby, Pageable pageable) {
QueryResult<T> qr = new QueryResult<T>();
String entityname = getEntityName(entityClass);
String sql = “select o from “+ entityname+ ” o “;
String sqlWhere = whereJpql==null? “”: “where “+ whereJpql;
Query query = em.createQuery(sql+sqlWhere+ buildOrderby(orderby));
setQueryParams(query, queryParams);
if(pageable.getPageNumber()!=-1 && pageable.getPageSize()!=-1)
query.setFirstResult(pageable.getPageNumber()*pageable.getPageSize()).setMaxResults(pageable.getPageSize());
qr.setResultList(query.getResultList());
query = em.createQuery(“select count(“+ getCountField(entityClass)+ “) from “+ entityname+ ” o “+ sqlWhere);
setQueryParams(query, queryParams);
qr.setTotalRecord((Long)query.getSingleResult());
return qr;
}
@Override
public QueryResult<T> getScrollDataBySql(String sql, Object[] queryParams, Pageable pageable) {
// TODO Auto-generated method stub
//查询记录数
QueryResult<T> qr = new QueryResult<T>();
Query query = em.createNativeQuery(sql);
setQueryParams(query, queryParams);
if(pageable.getPageNumber()!=-1 && pageable.getPageSize()!=-1)
query.setFirstResult(pageable.getPageNumber()*pageable.getPageSize()).setMaxResults(pageable.getPageSize());
qr.setResultList(query.getResultList());
//
String from = getFromClause(sql);
//查询总记录数
query = em.createQuery(“select count(*) ” + from);
setQueryParams(query, queryParams);
qr.setTotalRecord((Long)query.getSingleResult());
return qr;
}
private String getCountField(Class<T> clazz) {
String out = “o”;
try {
PropertyDescriptor[] propertyDescriptors = Introspector.getBeanInfo(clazz).getPropertyDescriptors();
for(PropertyDescriptor propertydesc : propertyDescriptors) {
Method method = propertydesc.getReadMethod();
if(method!=null && method.isAnnotationPresent(EmbeddedId.class)){
PropertyDescriptor[] ps = Introspector.getBeanInfo(propertydesc.getPropertyType()).getPropertyDescriptors();
out = “o.”+ propertydesc.getName()+ “.” + (!ps[1].getName().equals(“class”)? ps[1].getName(): ps[0].getName());
break;
}
}
} catch (Exception e) {
e.printStackTrace();
}
return out;
}
/**
* 从sql中找出from子句
* @param sql
* @return
*/
private String getFromClause(String sql) {
String sql2 = sql.toLowerCase();
int index = sql2.indexOf(” from “);
if (index < 0) {
return null;
} else {
int i1 = sql2.lastIndexOf(” order by “);
int i2 = sql2.lastIndexOf(” group by “);
if (i1 >= 0 && i2 >= 0) {
return sql.substring(index, i1 > i2 ? i2 : i1);
} else if (i1 >= 0) {
return sql.substring(index, i1);
} else if (i2 >= 0) {
return sql.substring(index, i2);
} else {
return sql.substring(index);
}
}
}
}
3、创建自定义Repository的构建工厂
package com.trmp.commons.dao;
import java.io.Serializable;
import javax.persistence.EntityManager;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactory;
import org.springframework.data.querydsl.QueryDslPredicateExecutor;
import org.springframework.data.querydsl.QueryDslUtils;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.util.Assert;
/**
 * Repository factory that backs every repository it creates with {@link GenericRepositoryImpl}.
 *
 * @author divine
 */
public class DefaultRepositoryFactory extends JpaRepositoryFactory {

    /**
     * @param entityManager entity manager passed to the parent factory; must not be {@code null}
     */
    public DefaultRepositoryFactory(EntityManager entityManager) {
        super(entityManager);
        Assert.notNull(entityManager, "EntityManager must not be null");
    }

    /**
     * Creates a {@link GenericRepositoryImpl} for the domain type described by {@code metadata}.
     */
    @Override
    @SuppressWarnings({ "rawtypes", "unchecked" })
    protected <T, ID extends Serializable> JpaRepository<?, ?> getTargetRepository(RepositoryMetadata metadata,
            EntityManager entityManager) {
        JpaEntityInformation<?, Serializable> entityInformation = getEntityInformation(metadata.getDomainType());
        // Raw construction: the concrete domain type is only known at runtime.
        return new GenericRepositoryImpl(entityInformation, entityManager);
    }

    /**
     * Reports {@link GenericRepositoryImpl} as the base class of repositories created by this factory.
     */
    @Override
    protected Class<?> getRepositoryBaseClass(RepositoryMetadata metadata) {
        return GenericRepositoryImpl.class;
    }
}
4、创建自定义Repository工厂构建器，该类将被注入到Spring容器中，是Spring管理自定义DAO的入口。
/**
*
*/
package com.trmp.commons.dao;
import java.io.Serializable;
import javax.persistence.EntityManager;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean;
import org.springframework.data.repository.core.support.RepositoryFactorySupport;
/**
* @author divine
*
*/
public class DefaultRepositoryFactoryBean<T extends JpaRepository<S, ID>, S, ID extends Serializable>
extends JpaRepositoryFactoryBean<T, S, ID> {
/**
* Returns a {@link RepositoryFactorySupport}.
*
* @param entityManager
* @return
*/
protected RepositoryFactorySupport createRepositoryFactory(EntityManager entityManager) {
return new DefaultRepositoryFactory(entityManager);
}
}
5、applicationContext.xml中注入DefaultRepositoryFactoryBean
<!--
    Spring Data JPA repositories
    base-package: packages to scan; every interface that extends GenericRepository must be in one of them
    repository-impl-postfix: suffix of custom implementation classes, which are detected automatically
                             and merged into the matching repository interface
    factory-class: factory bean that creates the repository implementations
-->
<jpa:repositories base-package="com.trmp,com.platform"
        entity-manager-factory-ref="entityManagerFactory" transaction-manager-ref="transactionManager"
        repository-impl-postfix="Impl"
        factory-class="com.trmp.commons.dao.DefaultRepositoryFactoryBean">
</jpa:repositories>
6、如何使用
与使用Spring Data JPA的方式一样，让接口继承GenericRepository即可。例如，我定义了一个叫MyRepository的接口：
/**
*
*/
package com.trmp.base.sceneryManager.dao;
import com.trmp.base.sceneryManager.entity.BaseSceneryInfo;
import com.trmp.commons.dao.GenericRepository;
/**
 * Repository for {@link BaseSceneryInfo} entities with {@code Long} primary keys.
 * It declares nothing of its own; all operations are inherited from {@link GenericRepository}.
 *
 * @author divine
 */
public interface MyRepository extends GenericRepository<BaseSceneryInfo, Long> {
    // intentionally empty
}
7、JUnit测试
package com.trmp.base.sceneryManager;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.transaction.TransactionConfiguration;
import org.springframework.transaction.annotation.Transactional;
import com.trmp.base.sceneryManager.dao.MyRepository;
import com.trmp.base.sceneryManager.entity.BaseSceneryInfo;
import com.trmp.commons.util.QueryResult;
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(locations={“classpath:applicationContext.xml”})
@TransactionConfiguration(transactionManager=”transactionManager”)
@Transactional
public class SceneryTest {
@PersistenceContext
EntityManager em;
@Autowired
private MyRepository customDao;
@Test
public void testGetScrollData(){
Pageable pageable = new PageRequest(0, 10);
Object[] params = new Object[]{60L,”%沈%”};
QueryResult qr = customDao.getScrollDataByJpql(BaseSceneryInfo.class, “orgID = ? and cmpName like ? “, params, null, pageable);
System.out.println(qr.getTotalRecord());
}
}
8、用到的工具类
/**
*
*/
package com.trmp.commons.util;
import java.util.List;
/**
 * One page of query results together with the total number of matching records.
 *
 * @author divine
 *
 * @param <T> the entity type the query was issued for
 */
public class QueryResult<T> {

    /** Rows of the current page (raw list, not necessarily typed as T). */
    private List records;

    /** Total number of matching records across all pages. */
    private Long total;

    public List getResultList() {
        return records;
    }

    public void setResultList(List resultList) {
        records = resultList;
    }

    public Long getTotalRecord() {
        return total;
    }

    public void setTotalRecord(Long totalRecord) {
        total = totalRecord;
    }
}