Spring Data JPA 自定义全局 DAO 方法实例教程



本文是 Spring Data JPA 自定义全局 DAO 方法的实例教程。最近项目中采用了 Spring MVC + JPA 的技术架构，其中用到了 Spring Data JPA 对 JPA 的进一步封装，于是顺便研究了一下。

Spring Data JPA 入门可以参考 IBM 的一篇文章：

使用 Spring Data JPA 简化 JPA 开发

Spring Data JPA 开发指南

 

开发了一段时间后发现有些需求仍然无法满足，于是考虑在其上再封装一层。以下是开发步骤：

1、扩展 JpaRepository 接口，今后所有 DAO 接口都继承该接口，以实现 DAO 层的统一封装

 

/**

 * 

 */

package com.trmp.commons.dao;

 

import java.io.Serializable;

import java.util.LinkedHashMap;

import javax.persistence.Query;

import org.springframework.data.domain.Pageable;

import org.springframework.data.jpa.repository.JpaRepository;

import org.springframework.data.repository.NoRepositoryBean;

import com.trmp.commons.util.QueryResult;

/**

 * @author divine

 * 针对spring data jpa所提供的接口{@link JpaRepository}再次扩展

 * @NoRepositoryBean是必须的

 */

@NoRepositoryBean

public interface GenericRepository<T, ID extends Serializable> extends JpaRepository<T, ID> {

//JpaRepository本身是一个空接口,下面所有的方法声明都是自定义的

/**

* 设置query的参数

* @param query 查询对象

* @param queryParams 参数

*/

public void setQueryParams(Query query, Object[] queryParams);

 

/**


* 组装ORDER BY 语句

* @param orderby

* @return

*/

public String buildOrderby(LinkedHashMap<String, String> orderby);

 

/**

* 获取实体名

* @param entityClass

* @return

*/

public String getEntityName(Class<T> entityClass);

 

/**

* jpql语句查询

* @param entityClass

* @param whereSql

* @param queryParams

* @param orderby

* @param pageable

* @return

*/

public QueryResult<T> getScrollDataByJpql(Class<T> entityClass, String whereJpql, Object[] queryParams,

LinkedHashMap<String, String> orderby, Pageable pageable);

 

/**

* sql语句查询

* @param sql

* @param queryParams

* @param pageable

* @return

*/

public QueryResult<T> getScrollDataBySql(String sql, Object[] queryParams, Pageable pageable);

}

2、创建 GenericRepository 接口的实现类（在 SimpleJpaRepository 基础上扩展）
package com.trmp.commons.dao;
 
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
 
import javax.persistence.EmbeddedId;
import javax.persistence.Entity;
import javax.persistence.EntityManager;
import javax.persistence.Query;
 
import org.apache.log4j.Logger;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.JpaEntityInformationSupport;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.data.repository.NoRepositoryBean;
 
import com.trmp.commons.util.QueryResult;
 
/**
 * {@link BasicJpaRepository}接口实现类,并在{@link SimpleJpaRepository}基础上扩展。
 * @author divine
 *
 * @param <T> ORM对象
 * @param <ID> 主键ID
 */
@NoRepositoryBean   // 必须的
public class GenericRepositoryImpl<T, ID extends Serializable> extends
SimpleJpaRepository<T, ID> implements GenericRepository<T, ID> {
 
static Logger logger = Logger.getLogger(GenericRepositoryImpl.class);
    private final EntityManager em;
 
/**
* 构造函数
     * @param domainClass
     * @param em
     */
public GenericRepositoryImpl(final JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager) {
 
super(entityInformation, entityManager);
this.em = entityManager;
}
 
/**
* 构造函数
* @param domainClass
* @param em
*/
public GenericRepositoryImpl(Class<T> domainClass, EntityManager em) {
        this(JpaEntityInformationSupport.getMetadata(domainClass, em), em); 
    }
 
@Override
public void setQueryParams(Query query, Object[] queryParams){
 
if(null != queryParams && queryParams.length != 0){
for(int i=0;i<queryParams.length;i++){
query.setParameter(i+1, queryParams[i]);
}
}
}
 
@Override
public String buildOrderby(LinkedHashMap<String, String> orderby) {
// TODO Auto-generated method stub
StringBuffer orderbyql = new StringBuffer(“”);
if(orderby!=null && orderby.size()>0){
orderbyql.append(” order by “);
for(String key : orderby.keySet()){
orderbyql.append(“o.”).append(key).append(” “).append(orderby.get(key)).append(“,”);
}
orderbyql.deleteCharAt(orderbyql.length()-1);
}
return orderbyql.toString();
}
 
@Override
public String getEntityName(Class<T> entityClass) {
// TODO Auto-generated method stub
String entityname = entityClass.getSimpleName();
Entity entity = entityClass.getAnnotation(Entity.class);
if(entity.name()!=null && !”".equals(entity.name())){
entityname = entity.name();
}
return entityname;
}
 
@Override
public QueryResult<T> getScrollDataByJpql(Class<T> entityClass, String whereJpql, Object[] queryParams,
LinkedHashMap<String, String> orderby, Pageable pageable) {
 
QueryResult<T> qr = new QueryResult<T>();
        String entityname = getEntityName(entityClass);
        String sql = “select o from “+ entityname+ ” o “;
        String sqlWhere = whereJpql==null? “”: “where “+ whereJpql;
        Query query = em.createQuery(sql+sqlWhere+ buildOrderby(orderby));
        
        setQueryParams(query, queryParams);
        if(pageable.getPageNumber()!=-1 && pageable.getPageSize()!=-1)
        query.setFirstResult(pageable.getPageNumber()*pageable.getPageSize()).setMaxResults(pageable.getPageSize());  
        qr.setResultList(query.getResultList());
        
        query = em.createQuery(“select count(“+ getCountField(entityClass)+ “) from “+ entityname+ ” o “+ sqlWhere);  
        setQueryParams(query, queryParams);
        qr.setTotalRecord((Long)query.getSingleResult());
        
        return qr;
}
 
@Override
public QueryResult<T> getScrollDataBySql(String sql, Object[] queryParams, Pageable pageable) {
// TODO Auto-generated method stub
//查询记录数
QueryResult<T> qr = new QueryResult<T>();
Query query = em.createNativeQuery(sql);
setQueryParams(query, queryParams);
if(pageable.getPageNumber()!=-1 && pageable.getPageSize()!=-1)
        query.setFirstResult(pageable.getPageNumber()*pageable.getPageSize()).setMaxResults(pageable.getPageSize());
qr.setResultList(query.getResultList());
 
//
String from = getFromClause(sql);
//查询总记录数
query = em.createQuery(“select count(*) ” + from);  
        setQueryParams(query, queryParams);
        qr.setTotalRecord((Long)query.getSingleResult());
return qr;
}
 
private String getCountField(Class<T> clazz) {
 
        String out = “o”;
        try {  
            PropertyDescriptor[] propertyDescriptors = Introspector.getBeanInfo(clazz).getPropertyDescriptors();  
            for(PropertyDescriptor propertydesc : propertyDescriptors) {
                Method method = propertydesc.getReadMethod();  
                if(method!=null && method.isAnnotationPresent(EmbeddedId.class)){                     
                    PropertyDescriptor[] ps = Introspector.getBeanInfo(propertydesc.getPropertyType()).getPropertyDescriptors();  
                    out = “o.”+ propertydesc.getName()+ “.” + (!ps[1].getName().equals(“class”)? ps[1].getName(): ps[0].getName());  
                    break;  
                }  
            }  
        } catch (Exception e) {
            e.printStackTrace();  
        }  
        return out;  
    }
 
/**
* 从sql中找出from子句
* @param sql
* @return
*/
private String getFromClause(String sql) {
String sql2 = sql.toLowerCase();
int index = sql2.indexOf(” from “);
if (index < 0) {
return null;
} else {
int i1 = sql2.lastIndexOf(” order by “);
int i2 = sql2.lastIndexOf(” group by “);
 
if (i1 >= 0 && i2 >= 0) {
return sql.substring(index, i1 > i2 ? i2 : i1);
} else if (i1 >= 0) {
return sql.substring(index, i1);
} else if (i2 >= 0) {
return sql.substring(index, i2);
} else {
return sql.substring(index);
}
}
}
 
}
 
3、创建自定义 Repository 的工厂类
package com.trmp.commons.dao;
 
import java.io.Serializable;
 
import javax.persistence.EntityManager;
 
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactory;
import org.springframework.data.querydsl.QueryDslPredicateExecutor;
import org.springframework.data.querydsl.QueryDslUtils;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.util.Assert;
 
/**
 * @author divine
 *
 */
public class DefaultRepositoryFactory extends JpaRepositoryFactory {
 
private final EntityManager entityManager;
    
public DefaultRepositoryFactory(EntityManager entityManager) {
super(entityManager);
Assert.notNull(entityManager);
this.entityManager = entityManager;
 
}
 
@Override
protected <T, ID extends Serializable> JpaRepository<?, ?> getTargetRepository(RepositoryMetadata metadata, EntityManager entityManager) {
 
JpaEntityInformation<?, Serializable> entityInformation = getEntityInformation(metadata.getDomainType());
return new GenericRepositoryImpl(entityInformation, entityManager); // custom implementation
}
  
    @Override
    protected Class<?> getRepositoryBaseClass(RepositoryMetadata metadata) {
 
    return GenericRepositoryImpl.class;
    }
}
4、创建自定义 Repository 的 FactoryBean。该类将被注入 Spring 容器，是 Spring 创建自定义 DAO 的入口
/**
 * 
 */
package com.trmp.commons.dao;
import java.io.Serializable;
import javax.persistence.EntityManager;
 
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean;
import org.springframework.data.repository.core.support.RepositoryFactorySupport;
 
/**
 * @author divine
 *
 */
public class DefaultRepositoryFactoryBean<T extends JpaRepository<S, ID>, S, ID extends Serializable>
extends JpaRepositoryFactoryBean<T, S, ID> {
/**
     * Returns a {@link RepositoryFactorySupport}.
     *
     * @param entityManager
     * @return
     */
protected RepositoryFactorySupport createRepositoryFactory(EntityManager entityManager) {
 
    return new DefaultRepositoryFactory(entityManager);
    }
}
5、在 applicationContext.xml 中配置 DefaultRepositoryFactoryBean
<!--
    Spring Data JPA repositories.
    base-package: packages scanned for repository interfaces; every interface extending
                  GenericRepository must live under one of these packages.
    repository-impl-postfix: class-name suffix of custom implementation classes, which are
                  detected automatically and merged into the matching repository interface.
    factory-class: FactoryBean used to create every repository instance.
-->
<jpa:repositories base-package="com.trmp,com.platform"
    entity-manager-factory-ref="entityManagerFactory" transaction-manager-ref="transactionManager" repository-impl-postfix="Impl"
    factory-class="com.trmp.commons.dao.DefaultRepositoryFactoryBean">
</jpa:repositories>
6、如何使用
用法与普通 Spring Data JPA 相同，只需让接口继承 GenericRepository。例如，下面定义了一个名为 MyRepository 的接口：
/**
 * 
 */
package com.trmp.base.sceneryManager.dao;
 
import com.trmp.base.sceneryManager.entity.BaseSceneryInfo;
import com.trmp.commons.dao.GenericRepository;
 
/**
 * @author divine
 *
 */
public interface MyRepository extends GenericRepository<BaseSceneryInfo, Long> {
 
 
}
7、JUnit 测试
package com.trmp.base.sceneryManager;
 
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
 
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.transaction.TransactionConfiguration;
import org.springframework.transaction.annotation.Transactional;
 
import com.trmp.base.sceneryManager.dao.MyRepository;
import com.trmp.base.sceneryManager.entity.BaseSceneryInfo;
import com.trmp.commons.util.QueryResult;
 
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(locations={“classpath:applicationContext.xml”})
@TransactionConfiguration(transactionManager=”transactionManager”)
@Transactional
public class SceneryTest {
 
@PersistenceContext
EntityManager em;
 
@Autowired
private MyRepository customDao;
 
@Test
public void testGetScrollData(){
 
Pageable pageable = new PageRequest(0, 10);
Object[] params = new Object[]{60L,”%沈%”};
 
 
QueryResult qr = customDao.getScrollDataByJpql(BaseSceneryInfo.class, “orgID = ? and cmpName like ? “, params, null, pageable);
System.out.println(qr.getTotalRecord());
}
}
8、用到的工具类
/**
 * 
 */
package com.trmp.commons.util;
 
import java.util.List;
 
/**
 * One page of query results plus the total number of records that match the query.
 * <p>
 * Both values are plain mutable bean properties and start out {@code null}.
 *
 * @author divine
 *
 * @param <T> the entity type the query was issued for. NOTE(review): the list is raw and may
 *            hold other row shapes (e.g. native-SQL rows); confirm with callers before typing it.
 */
public class QueryResult<T> {

    /** The records of the requested page. */
    private List resultList;

    /** Total number of records matching the query, independent of paging. */
    private Long totalRecord;

    public Long getTotalRecord() {
        return this.totalRecord;
    }

    public void setTotalRecord(Long totalRecord) {
        this.totalRecord = totalRecord;
    }

    public List getResultList() {
        return this.resultList;
    }

    public void setResultList(List resultList) {
        this.resultList = resultList;
    }
}