package com.sfa.job.config;

import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.sharding.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.RangeShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.StandardShardingAlgorithm;

import java.time.LocalDateTime;
import java.util.*;

/**
 * ShardingSphere 5.4.1 standard sharding algorithm that routes rows to one table per year based on
 * the {@code pay_time} column.
 *
 * <p>Actual tables are matched by the name pattern {@code <logicTableName>_<year>}, for example
 * {@code t_order_2023}. Apart from the {@code null} pay_time fallback of the precise overload,
 * only names contained in the available-target collection supplied by ShardingSphere are returned.
 */
@Slf4j
public class PayTimeYearShardingAlgorithm implements StandardShardingAlgorithm<LocalDateTime> {

    /** Suffix of the table used by precise routing when the supplied pay_time is {@code null}. */
    private static final String NULL_PAY_TIME_SUFFIX = "_2021";

    /**
     * Initializes the algorithm. No custom properties are read; the call is delegated to the
     * interface default.
     */
    @Override
    public void init(Properties props) {
        StandardShardingAlgorithm.super.init(props);
    }

    /** Returns the custom algorithm type name ({@code PAY_TIME_YEAR}) used in the configuration. */
    @Override
    public String getType() {
        return "PAY_TIME_YEAR";
    }

    /** Delegates to the interface default. */
    @Override
    public Collection<Object> getTypeAliases() {
        return StandardShardingAlgorithm.super.getTypeAliases();
    }

    /** Delegates to the interface default. */
    @Override
    public boolean isDefault() {
        return StandardShardingAlgorithm.super.isDefault();
    }

    /**
     * Precise sharding for {@code =} conditions.
     * Example: {@code pay_time = '2023-05-01 10:00:00'} routes to {@code t_order_2023}.
     *
     * @param collection           available actual table names
     * @param preciseShardingValue logic table name and the pay_time value of the condition
     * @return {@code <logicTableName>_<year of pay_time>}, or {@code <logicTableName>_2021} when the
     *         value is {@code null} (that fallback is not checked against {@code collection})
     * @throws IllegalArgumentException if the table for the value's year is not in {@code collection}
     */
    @Override
    public String doSharding(Collection<String> collection, PreciseShardingValue<LocalDateTime> preciseShardingValue) {
        LocalDateTime payTime = preciseShardingValue.getValue();
        if (payTime == null) {
            // Rows without a pay_time are sent to a fixed fallback table.
            return preciseShardingValue.getLogicTableName() + NULL_PAY_TIME_SUFFIX;
        }
        String targetTable = preciseShardingValue.getLogicTableName() + "_" + payTime.getYear();
        if (collection.contains(targetTable)) {
            return targetTable;
        }
        throw new IllegalArgumentException("未找到匹配的表：" + targetTable);
    }

    /**
     * Range sharding for conditions such as {@code >}, {@code <}, {@code >=}, {@code <=} and
     * {@code BETWEEN}.
     * Example: {@code pay_time BETWEEN '2023-01-01' AND '2024-12-31'} routes to
     * {@code t_order_2023} and {@code t_order_2024}.
     *
     * <p>One-sided ranges (for example {@code pay_time >= '2023-01-01'}) are supported: the missing
     * bound is treated as unbounded, so every available yearly table on that side is selected.
     *
     * @param collection    available actual table names
     * @param shardingValue logic table name and the pay_time value range of the condition
     * @return the available tables whose year lies within the years spanned by the range
     *         (possibly empty)
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, RangeShardingValue<LocalDateTime> shardingValue) {
        String prefix = shardingValue.getLogicTableName() + "_";
        // lowerEndpoint()/upperEndpoint() throw IllegalStateException when that bound is absent,
        // so check for presence first and treat a missing bound as open-ended.
        int startYear = shardingValue.getValueRange().hasLowerBound()
                ? shardingValue.getValueRange().lowerEndpoint().getYear()
                : Integer.MIN_VALUE;
        int endYear = shardingValue.getValueRange().hasUpperBound()
                ? shardingValue.getValueRange().upperEndpoint().getYear()
                : Integer.MAX_VALUE;

        // Filter the available tables instead of enumerating years, so open-ended ranges stay finite.
        Set<String> result = new LinkedHashSet<>();
        for (String table : collection) {
            OptionalInt year = parseYearSuffix(table, prefix);
            if (year.isPresent() && year.getAsInt() >= startYear && year.getAsInt() <= endYear) {
                result.add(table);
            }
        }
        return result;
    }

    /** Delegates to the interface default. */
    @Override
    public Optional<String> getAlgorithmStructure(String dataNodePrefix, String shardingColumn) {
        return StandardShardingAlgorithm.super.getAlgorithmStructure(dataNodePrefix, shardingColumn);
    }

    /**
     * Extracts the year from a table named exactly {@code prefix + year}.
     *
     * @param table  candidate actual table name
     * @param prefix {@code <logicTableName>_}
     * @return the year, or empty if {@code table} does not follow that pattern
     */
    private static OptionalInt parseYearSuffix(String table, String prefix) {
        if (table == null || !table.startsWith(prefix)) {
            return OptionalInt.empty();
        }
        try {
            int year = Integer.parseInt(table.substring(prefix.length()));
            // Require the canonical form (rejects "+2023", "02023", ...) so matching is identical to
            // building the name as prefix + year.
            return table.equals(prefix + year) ? OptionalInt.of(year) : OptionalInt.empty();
        } catch (NumberFormatException ignored) {
            // Not a yearly table of this logic table (e.g. a different suffix); skip it.
            return OptionalInt.empty();
        }
    }
}
