CSP算法

• 16 min read • 3054 words
Tags: CSP
Categories: Introduction to Artificial Intelligence

下面是对CSP的相关概念的接口预实现,之后会用到

class CSP:
    """Constraint Satisfaction Problem class"""
    
    def __init__(self, variables: List[str], domains: Dict[str, Set[Any]], constraints: List[Constraint]):
        self.variables = variables
        self.domains = domains  # Using sets for efficient domain operations
        self.constraints = constraints
    
    def is_consistent(self, var: str, value: Any, assignment: Dict[str, Any]) -> bool:
        """Check if assigning value to var is consistent with all constraints"""
        for constraint in self.constraints:
            if var in constraint.variables and not constraint.is_consistent(var, value, assignment):
                return False
        return True
    
    def is_complete(self, assignment: Dict[str, Any]) -> bool:
        """Check if assignment is complete (all variables assigned)"""
        return len(assignment) == len(self.variables)
    
    def get_unassigned_variables(self, assignment: Dict[str, Any]) -> List[str]:
        """Get list of unassigned variables"""
        return [var for var in self.variables if var not in assignment]

class Constraint(ABC):
    def __init__(self, variables: List[str]):
        self.variables = variables
    
    @abstractmethod
    def is_satisfied(self, assignment: Dict[str, Any]) -> bool:
        """Check if the constraint is satisfied"""
        pass
    
    @abstractmethod
    def is_consistent(self, var: str, value: Any, assignment: Dict[str, Any]) -> bool:
        """Check whether a single variable assignment is consistent with the constraint"""
            pass
 
class BinaryConstraint(Constraint):
    
    def __init__(self, var1: str, var2: str):
        super().__init__([var1, var2])
        self.var1 = var1
        self.var2 = var2
 
class NotEqualConstraint(BinaryConstraint):
    
    def is_satisfied(self, assignment: Dict[str, Any]) -> bool:
        if self.var1 not in assignment or self.var2 not in assignment:
            return True
        return assignment[self.var1] != assignment[self.var2]
    
    def is_consistent(self, var: str, value: Any, assignment: Dict[str, Any]) -> bool:
        if var == self.var1 and self.var2 in assignment:
            return value != assignment[self.var2]
        elif var == self.var2 and self.var1 in assignment:
            return value != assignment[self.var1]
        return True

class CSPSolver:

    @abstractmethod
    def solve(self, csp: 'CSP') -> Optional[Dict[str, Any]]:
        """Solving CSP problems"""
        pass

1. 回溯搜索

a.a.相关概念

回溯搜索是解决CSP最基本的算法,它是深度优先搜索在CSP上的应用。可以把它理解为带有剪枝的dfs搜索:在每一步搜索时,回溯搜索都会通过前向检验判断当前状态是否已经和约束条件矛盾了,如果矛盾了,回溯搜索就会停止无意义的搜索、转而尝试其他选择

b.b.简单实现

def backtrace(self, assignment: Dict[str, Any], csp: 'CSP') -> Optional[Dict[str, Any]]:
    if len(assignment) == len(csp.variables):
        return assignment

    # Start assigning value 
    var = self.select_unassigned_variable(assignment, csp)
    
    # The order_domain_values method is used to get the values in a specific order
    # We will discuss it later
    for value in self.order_domain_values(var, assignment, csp):
        if csp.is_consistent(val, value, assignment):
            assignment[var] = value

            # Do forward check
            inference = self.forward_check(var, value, assignment, csp)
            if inference is not None:
                result = self.backtrace(assignment, csp)
            self.restore_domains(inference, csp) 
        else 
            result = self.backtrace(assignment, csp)
            if result is not None:
                return result
        del assignment[var]
    
    return None

c.c. 改进策略

我们可以通过让回溯搜索选择最可能不与限制条件冲突的值,然后继续搜索,常见策略如下:

变量选择启发式

  • 最少剩余值(MRV):选择合法值最少的变量
  • 度启发式:选择约束最多未赋值变量的变量

值选择启发式

  • 最小约束值(LCV):选择对其他变量约束最少的值

推理

  • 前向检查:赋值后立即检查相关变量域的一致性(见CSP过滤部分笔记)
  • 弧一致性:确保每个变量的值都有对应的一致邻居值

在回溯算法中对应的部分传入优化后的函数方法即可。

2. 约束传播

约束传播是通过局部一致性检查来缩小搜索空间的技术。

a.a. 弧一致性(Arc Consistency)

对于约束XiXjX_i \rightarrow X_j,如果对于XiX_i域中的每个值,XjX_j域中都存在至少一个值使得约束得到满足,则称该弧是一致的。如果XiX_i域中的某个值XjX_j的域中找不到能够满足约束的值,那么这个值就是错误的、应该从XiX_i的域中剔除掉。

在执行具体的搜索之前,我们希望做一些预处理、把绝对不可能的情况排除掉,来减少搜索可能性。AC3AC3算法(Arc Consistency Algorithm #3) 就是这个预处理过程的自动化和通用化版本。该算法的流程如下:

  1. 我们将所有需要检查的弧放在一个“工作清单”中,这个可以用队列实现。
  2. 从队列中弹出一个弧(Xi,Xj)(X_i, X_j),开始检查它的一致性。
    1. 我们先确认这个弧的一致性:弹出XiX_i的域中所有不满足一致性的值。
    2. 如果XiX_i的域缩小了,我们也需要调整所有指向XiX_i的弧,因为XiX_i的域的缩小可能导致它们的域也出现不一致的值、也需要缩小。

简单的代码实现如下:

def ac3(self, csp: 'CSP') -> bool:
        # Initialize queue with all arcs
        queue = deque()
        
        for constraint in csp.constraints:
            if isinstance(constraint, BinaryConstraint):
                queue.append((constraint.var1, constraint.var2, constraint))
                queue.append((constraint.var2, constraint.var1, constraint))
        
        while queue:
            xi, xj, constraint = queue.popleft()
            
            if self.revise(csp, xi, xj, constraint):
                # If xi's domain become empty, there's no solution
                if not csp.domains[xi]:
                    return False
                
                # Add all relative arcs back to queue
                for other_constraint in csp.constraints:
                    if isinstance(other_constraint, BinaryConstraint):
                        if other_constraint.var1 == xi and other_constraint.var2 != xj:
                            queue.append((other_constraint.var2, xi, other_constraint))
                        elif other_constraint.var2 == xi and other_constraint.var1 != xj:
                            queue.append((other_constraint.var1, xi, other_constraint))
        
        return True
    
    def revise(self, csp: 'CSP', xi: str, xj: str, constraint: BinaryConstraint) -> bool:
        """Revise function for AC-3 algorithm"""
        revised = False
        to_remove = set()
        
        for x in csp.domains[xi]:
            satisfied = False
            for y in csp.domains[xj]:
                temp_assignment = {xi: x, xj: y}
                if constraint.is_satisfied(temp_assignment):
                    satisfied = True
                    break
            
            if not satisfied:
                to_remove.add(x)
                revised = True
        
        csp.domains[xi] -= to_remove
        return revised

b.b. 路径一致性(Path Consistency)

确保对于任意两个变量的值组合,都存在第三个变量的值使得所有相关约束得到满足。

6. CSP的结构利用

a.a. 树结构CSP

如果约束图是树结构(无环图),可以在O(nd2)O(nd^2)时间内解决CSP:

  1. 选择任意变量作为根
  2. 从叶子到根进行弧一致性检查
  3. 从根到叶子进行赋值

b.b. 通用图的树分解

对于一般约束图,可以通过以下方法利用树结构:

  • 变量消除:逐步消除变量,构建新约束
  • 条件独立性:利用变量间的独立性分解问题
  • 树宽度:衡量图接近树结构的程度

7. 局部搜索方法

a.a.相关概念

由于回溯搜索是部分赋值的,需要系统性探索整个空间来找到一个满足约束条件的解。它是完备的,但是这个过程可能占用大量内存(递归栈)。对此,我们可以采用局部搜索的方法。

在CSP问题中,局部搜索会进行完整赋值,然后这个最初的赋值会在解空间中行走。然后这个值在解空间中“行走”,从一个完整解移动到它的“邻居”完整解,试图找到一个没有冲突的解。这在景观图中相当于从一个有冲突的“山丘”状态,走到一个没有冲突的“山谷”状态。

这种方法在许多实际问题中表现良好,特别是当解的密度较高时

b.b.简单实现

CSP的局部搜索中,“状态”转化为了约束冲突的数量。CSP会往冲突数量最小的方向靠近。因此,我们使用min_conflicts_value作为启发式函数,然后随即寻找一个位置进行搜索:

class LocalSearchSolver(CSPSolver):
    """Local search solver using min-conflicts heuristic"""
    
    def __init__(self, max_steps: int = 10000):
        self.max_steps = max_steps
    
    def solve(self, csp: CSP) -> Optional[Dict[str, Any]]:
        """Solve using min-conflicts local search"""
        # Start with random complete assignment
        current = self.random_assignment(csp)
        
        for step in range(self.max_steps):
            conflicts = self.get_conflicts(current, csp)
            if not conflicts:
                return current  # Found solution
            
            # Select random conflicted variable
            var = random.choice(list(conflicts))
            
            # Choose value that minimizes conflicts
            best_value = self.min_conflicts_value(var, current, csp)
            current[var] = best_value
        
        return None  # Failed to find solution
    
    def random_assignment(self, csp: CSP) -> Dict[str, Any]:
        """Generate random complete assignment"""
        assignment = {}
        for var in csp.variables:
            assignment[var] = random.choice(list(csp.domains[var]))
        return assignment
    
    def get_conflicts(self, assignment: Dict[str, Any], csp: CSP) -> Set[str]:
        """Get set of variables involved in constraint violations"""
        conflicts = set()
        for constraint in csp.constraints:
            if not constraint.is_satisfied(assignment):
                conflicts.update(constraint.variables)
        return conflicts
    
    def min_conflicts_value(self, var: str, assignment: Dict[str, Any], csp: CSP) -> Any:
        """Choose value for var that minimizes conflicts"""
        min_conflicts = float('inf')
        best_value = None
        
        for value in csp.domains[var]:
            test_assignment = assignment.copy()
            test_assignment[var] = value
            
            conflicts = len(self.get_conflicts(test_assignment, csp))
            if conflicts < min_conflicts:
                min_conflicts = conflicts
                best_value = value
        
        return best_value