UpdateEntityDao.java:
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.UpdateProvider;
/**
 * MyBatis mapper exposing a generic single-statement UPDATE whose SQL text is generated at
 * call time by {@code UpdateEntityProvider#updateEntity}.
 */
@Mapper
public interface UpdateEntityDao {

    /**
     * Executes the UPDATE statement built from the given request.
     *
     * <p>The parameter intentionally carries no {@code @Param}: the provider emits placeholders of
     * the form {@code #{entity.<field>}}, which resolve against this single parameter object.
     *
     * @param updateEntityReq the entity, the fields to write, the key field name and the
     *     null/empty filtering flags
     * @return the update count returned by MyBatis for the statement
     */
    @UpdateProvider(
            type = UpdateEntityProvider.class,
            method = "updateEntity")
    int updateEntity(UpdateEntityReq<? extends IEntity> updateEntityReq);
}
UpdateEntityProvider.java:
import cn.hutool.core.util.ReflectUtil;
import org.apache.ibatis.jdbc.SQL;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * MyBatis SQL provider that builds a single {@code UPDATE} statement from an entity instance.
 *
 * <p>Mapping conventions: the table name is the entity class's simple name and each column name
 * is the Java field name, both used verbatim (no case conversion). Values are never inlined into
 * the SQL text; every column is bound as the MyBatis parameter {@code #{entity.<field>}}.
 */
public class UpdateEntityProvider {

    /**
     * Builds {@code UPDATE <Table> SET f1 = #{entity.f1}, ... WHERE <id> = #{entity.<id>}}.
     *
     * @param updateEntityReq entity, names of the fields to write, key field name and the
     *     null/empty-string filtering flags
     * @param <T> entity type
     * @return the SQL text handed to MyBatis
     * @throws IllegalArgumentException if the entity, the update-field list or the key name is
     *     empty; if a requested field or the key is not an instance field of the entity class;
     *     or if filtering leaves no column to update
     */
    public <T extends IEntity> String updateEntity(UpdateEntityReq<T> updateEntityReq) {
        if (ObjectUtils.isEmpty(updateEntityReq.getEntity())) {
            throw new IllegalArgumentException("entity不能为空");
        }
        if (ObjectUtils.isEmpty(updateEntityReq.getUpdateFields())) {
            throw new IllegalArgumentException("updateFields不能为空");
        }
        String idName = updateEntityReq.getIdName();
        if (StringUtils.isEmpty(idName)) {
            throw new IllegalArgumentException("idName不能为空");
        }

        IEntity entity = updateEntityReq.getEntity();
        Class<?> clazz = entity.getClass();
        // An anonymous subclass (e.g. double-brace initialization) has an empty simple name,
        // so fall back to the class it extends.
        if (clazz.isAnonymousClass() || clazz.isSynthetic()) {
            clazz = clazz.getSuperclass();
        }
        String tableName = clazz.getSimpleName();

        List<Field> columns = columnFields(clazz);
        Set<String> fieldNames = columns.stream().map(Field::getName).collect(Collectors.toSet());
        for (String field : updateEntityReq.getUpdateFields()) {
            if (!fieldNames.contains(field)) {
                throw new IllegalArgumentException(String.format("updateFields中的字段无效:%s", field));
            }
        }
        if (!fieldNames.contains(idName)) {
            throw new IllegalArgumentException("idName字段无效");
        }

        // Set lookup instead of List.contains inside the loop below.
        Set<String> requested = new HashSet<>(updateEntityReq.getUpdateFields());
        List<String> setList = new ArrayList<>();
        for (Field field : columns) {
            String fieldName = field.getName();
            if (!requested.contains(fieldName)) {
                continue;
            }
            // Reflective read only for fields that are actually being written.
            Object value = ReflectUtil.getFieldValue(entity, field);
            // filterNull defaults to true in the UpdateEntityReq builder.
            if (updateEntityReq.isFilterNull() && value == null) {
                continue;
            }
            // filterEmptyStrings defaults to false; only the exact empty string is skipped.
            if (updateEntityReq.isFilterEmptyStrings() && "".equals(value)) {
                continue;
            }
            setList.add(String.format("%s = #{entity.%s}", fieldName, fieldName));
        }

        if (StringUtils.isEmpty(tableName)) {
            throw new IllegalArgumentException("表名不能为空");
        }
        if (setList.isEmpty()) {
            throw new IllegalArgumentException("未生成更新字段");
        }
        // idName was validated against the entity's fields above, so the condition always exists.
        String where = String.format("%s = #{entity.%s}", idName, idName);

        SQL sql = new SQL().UPDATE(tableName);
        setList.forEach(sql::SET);
        return sql.WHERE(where).toString();
    }

    /**
     * Returns the fields of {@code clazz} (including inherited ones, as listed by
     * {@code ReflectUtil.getFields}) that can map to columns.
     *
     * <p>Static fields (e.g. {@code serialVersionUID}) and compiler-generated synthetic fields are
     * excluded. Only the first field of each name is kept, so a field redeclared in a subclass
     * does not produce a duplicate SET clause.
     */
    private static List<Field> columnFields(Class<?> clazz) {
        Map<String, Field> byName = new LinkedHashMap<>();
        for (Field field : ReflectUtil.getFields(clazz)) {
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            byName.putIfAbsent(field.getName(), field);
        }
        return new ArrayList<>(byName.values());
    }
}
UpdateEntityReq.java:
import lombok.Builder;
import lombok.Data;
import lombok.NonNull;
import java.util.List;
/**
 * Parameter object for {@code UpdateEntityDao#updateEntity}: which entity to write, which of its
 * fields become SET columns, which field forms the WHERE condition, and how null / empty-string
 * values are filtered.
 *
 * <p>Field declaration order is significant: the all-args constructor generated for
 * {@code @Builder} follows it.
 *
 * @param <T> entity type
 */
@Builder
@Data
public class UpdateEntityReq<T extends IEntity> {

    /** Entity supplying the new column values and the key value; Lombok rejects null. */
    @NonNull
    private T entity;

    /** Names of the entity fields to write; Lombok rejects null. */
    @NonNull
    private List<String> updateFields;

    /** Name of the entity field used in the WHERE condition; the builder defaults it to "id". */
    @Builder.Default
    private String idName = "id";

    /** When true (builder default), fields whose value is null are left out of the SET clause. */
    @Builder.Default
    private boolean filterNull = true;

    /** When true, String fields whose value is empty are left out; the builder default is false. */
    @Builder.Default
    private boolean filterEmptyStrings = false;
}
UpdateUtil.java:
import cn.hutool.core.lang.func.Func1;
import cn.hutool.core.lang.func.LambdaUtil;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Fluent, type-safe front end for {@code UpdateEntityDao}.
 *
 * <p>Typical use: {@code UpdateUtil.update(entity).setNotNull(E::getName).whereById().execute(dao)}.
 * Field names are derived from getter method references via Hutool's {@code LambdaUtil}.
 */
public final class UpdateUtil {

    private UpdateUtil() {
    }

    /**
     * Starts an update for {@code entity}, which supplies both the new values and the key value.
     *
     * @param entity the entity to write
     * @param <T> entity type
     * @return the step used to choose the fields to write
     */
    public static <T extends IEntity> UpdateStep<T> update(T entity) {
        return new UpdateStep<>(entity);
    }

    /** Step 1: choose the fields to write and how null / empty values are treated. */
    public static class UpdateStep<T extends IEntity> {

        private final T target;

        public UpdateStep(T entity) {
            this.target = entity;
        }

        /** Writes every listed field, including those whose value is null. */
        @SafeVarargs
        public final SetStep<T> set(Func1<T, ?>... updateFields) {
            return next(toList(updateFields), false, false);
        }

        /** Writes the listed fields but skips those whose value is null. */
        @SafeVarargs
        public final SetStep<T> setNotNull(Func1<T, ?>... updateFields) {
            return next(toList(updateFields), true, false);
        }

        /**
         * Writes the listed fields, skipping null values and String values that are empty.
         * NOTE: despite the name, whitespace-only strings are still written; the provider only
         * skips the empty string.
         */
        @SafeVarargs
        public final SetStep<T> setNotBlank(Func1<T, ?>... updateFields) {
            return next(toList(updateFields), true, true);
        }

        /** Creates the next step for this entity with the given field list and filters. */
        private SetStep<T> next(List<String> fieldNames, boolean skipNulls, boolean skipEmpty) {
            return new SetStep<>(target, fieldNames, skipNulls, skipEmpty);
        }

        /** Resolves each getter reference to its field name, preserving order. */
        @SafeVarargs
        private static <T extends IEntity> List<String> toList(Func1<T, ?>... updateFields) {
            return Arrays.stream(updateFields)
                    .map(LambdaUtil::getFieldName)
                    .collect(Collectors.toList());
        }
    }

    /** Step 2: choose the field that forms the WHERE condition. */
    public static class SetStep<T extends IEntity> {

        private final T target;
        private final List<String> fieldNames;
        private final boolean skipNulls;
        private final boolean skipEmptyStrings;

        public SetStep(T entity, List<String> updateFields, boolean filterNull, boolean filterEmptyStrings) {
            this.target = entity;
            this.fieldNames = updateFields;
            this.skipNulls = filterNull;
            this.skipEmptyStrings = filterEmptyStrings;
        }

        /** Uses the field behind the given getter reference as the key. */
        public ExecuteStep<T> whereBy(Func1<T, ?> idName) {
            return keyedBy(LambdaUtil.getFieldName(idName));
        }

        /** Uses the field named "id" as the key. */
        public ExecuteStep<T> whereById() {
            return keyedBy("id");
        }

        /** Assembles the request shared by both where-variants. */
        private ExecuteStep<T> keyedBy(String keyField) {
            UpdateEntityReq<T> request = UpdateEntityReq.<T>builder()
                    .entity(target)
                    .updateFields(fieldNames)
                    .idName(keyField)
                    .filterNull(skipNulls)
                    .filterEmptyStrings(skipEmptyStrings)
                    .build();
            return new ExecuteStep<>(request);
        }
    }

    /** Step 3: run the prepared request against a mapper. */
    public static class ExecuteStep<T extends IEntity> {

        private final UpdateEntityReq<T> request;

        public ExecuteStep(UpdateEntityReq<T> updateEntityReq) {
            this.request = updateEntityReq;
        }

        /**
         * Executes the update.
         *
         * @param updateEntityDao mapper used to run the statement
         * @return the value returned by {@code UpdateEntityDao#updateEntity}
         * @throws IllegalArgumentException if {@code updateEntityDao} is null
         */
        public int execute(UpdateEntityDao updateEntityDao) {
            if (updateEntityDao != null) {
                return updateEntityDao.updateEntity(request);
            }
            throw new IllegalArgumentException("updateEntityDao不能为空");
        }
    }
}