iohannes
Published on 2025-04-11 / 1 Visits

二维二分法获取函数最大时的,两个参数值

def get_best_param_2d(x_min, x_max, y_min, y_max, fun, epsilon=1e-6):
    """
    使用二维二分法在区间 [x_min, x_max] 和 [y_min, y_max] 内找到使目标函数 fun 最大的参数值 (x, y)。
    
    参数:
        x_min: x 的搜索区间的下限
        x_max: x 的搜索区间的上限
        y_min: y 的搜索区间的下限
        y_max: y 的搜索区间的上限
        fun: 目标函数,输入两个参数值,返回一个数值
        epsilon: 精度阈值,用于控制搜索的终止条件
    
    返回:
        best_x, best_y: 使目标函数 fun 最大的参数值 (x, y)
    """
    while (x_max - x_min > epsilon) or (y_max - y_min > epsilon):
        # 计算 x 和 y 的中点
        x_mid1 = x_min + (x_max - x_min) / 3
        x_mid2 = x_min + 2 * (x_max - x_min) / 3
        y_mid1 = y_min + (y_max - y_min) / 3
        y_mid2 = y_min + 2 * (y_max - y_min) / 3
        
        # 计算中点处的目标函数值
        rtn11 = fun(x_mid1, y_mid1)
        rtn12 = fun(x_mid1, y_mid2)
        rtn21 = fun(x_mid2, y_mid1)
        rtn22 = fun(x_mid2, y_mid2)
        
        # 根据目标函数值调整搜索区间
        if rtn11 < rtn12:
            y_min = y_mid1
        else:
            y_max = y_mid2
        if rtn21 < rtn22:
            y_min = y_mid1
        else:
            y_max = y_mid2
        if rtn11 < rtn21:
            x_min = x_mid1
        else:
            x_max = x_mid2
        if rtn12 < rtn22:
            x_min = x_mid1
        else:
            x_max = x_mid2
    
    # 返回最优参数值
    best_x = (x_min + x_max) / 2
    best_y = (y_min + y_max) / 2
    return best_x, best_y