m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

661 lines
20 KiB

6 months ago
"""Implementation of the DPLL algorithm.

Features:
- Clause learning
- Watched-literal scheme
- VSIDS heuristic

References:
- https://en.wikipedia.org/wiki/DPLL_algorithm
"""
  9. from collections import defaultdict
  10. from heapq import heappush, heappop
  11. from sympy.core.sorting import ordered
  12. from sympy.assumptions.cnf import EncodedCNF
  13. def dpll_satisfiable(expr, all_models=False):
  14. """
  15. Check satisfiability of a propositional sentence.
  16. It returns a model rather than True when it succeeds.
  17. Returns a generator of all models if all_models is True.
  18. Examples
  19. ========
  20. >>> from sympy.abc import A, B
  21. >>> from sympy.logic.algorithms.dpll2 import dpll_satisfiable
  22. >>> dpll_satisfiable(A & ~B)
  23. {A: True, B: False}
  24. >>> dpll_satisfiable(A & ~A)
  25. False
  26. """
  27. if not isinstance(expr, EncodedCNF):
  28. exprs = EncodedCNF()
  29. exprs.add_prop(expr)
  30. expr = exprs
  31. # Return UNSAT when False (encoded as 0) is present in the CNF
  32. if {0} in expr.data:
  33. if all_models:
  34. return (f for f in [False])
  35. return False
  36. solver = SATSolver(expr.data, expr.variables, set(), expr.symbols)
  37. models = solver._find_model()
  38. if all_models:
  39. return _all_models(models)
  40. try:
  41. return next(models)
  42. except StopIteration:
  43. return False
  44. # Uncomment to confirm the solution is valid (hitting set for the clauses)
  45. #else:
  46. #for cls in clauses_int_repr:
  47. #assert solver.var_settings.intersection(cls)
  48. def _all_models(models):
  49. satisfiable = False
  50. try:
  51. while True:
  52. yield next(models)
  53. satisfiable = True
  54. except StopIteration:
  55. if not satisfiable:
  56. yield False
  57. class SATSolver:
  58. """
  59. Class for representing a SAT solver capable of
  60. finding a model to a boolean theory in conjunctive
  61. normal form.
  62. """
  63. def __init__(self, clauses, variables, var_settings, symbols=None,
  64. heuristic='vsids', clause_learning='none', INTERVAL=500):
  65. self.var_settings = var_settings
  66. self.heuristic = heuristic
  67. self.is_unsatisfied = False
  68. self._unit_prop_queue = []
  69. self.update_functions = []
  70. self.INTERVAL = INTERVAL
  71. if symbols is None:
  72. self.symbols = list(ordered(variables))
  73. else:
  74. self.symbols = symbols
  75. self._initialize_variables(variables)
  76. self._initialize_clauses(clauses)
  77. if 'vsids' == heuristic:
  78. self._vsids_init()
  79. self.heur_calculate = self._vsids_calculate
  80. self.heur_lit_assigned = self._vsids_lit_assigned
  81. self.heur_lit_unset = self._vsids_lit_unset
  82. self.heur_clause_added = self._vsids_clause_added
  83. # Note: Uncomment this if/when clause learning is enabled
  84. #self.update_functions.append(self._vsids_decay)
  85. else:
  86. raise NotImplementedError
  87. if 'simple' == clause_learning:
  88. self.add_learned_clause = self._simple_add_learned_clause
  89. self.compute_conflict = self.simple_compute_conflict
  90. self.update_functions.append(self.simple_clean_clauses)
  91. elif 'none' == clause_learning:
  92. self.add_learned_clause = lambda x: None
  93. self.compute_conflict = lambda: None
  94. else:
  95. raise NotImplementedError
  96. # Create the base level
  97. self.levels = [Level(0)]
  98. self._current_level.varsettings = var_settings
  99. # Keep stats
  100. self.num_decisions = 0
  101. self.num_learned_clauses = 0
  102. self.original_num_clauses = len(self.clauses)
  103. def _initialize_variables(self, variables):
  104. """Set up the variable data structures needed."""
  105. self.sentinels = defaultdict(set)
  106. self.occurrence_count = defaultdict(int)
  107. self.variable_set = [False] * (len(variables) + 1)
  108. def _initialize_clauses(self, clauses):
  109. """Set up the clause data structures needed.
  110. For each clause, the following changes are made:
  111. - Unit clauses are queued for propagation right away.
  112. - Non-unit clauses have their first and last literals set as sentinels.
  113. - The number of clauses a literal appears in is computed.
  114. """
  115. self.clauses = []
  116. for cls in clauses:
  117. self.clauses.append(list(cls))
  118. for i in range(len(self.clauses)):
  119. # Handle the unit clauses
  120. if 1 == len(self.clauses[i]):
  121. self._unit_prop_queue.append(self.clauses[i][0])
  122. continue
  123. self.sentinels[self.clauses[i][0]].add(i)
  124. self.sentinels[self.clauses[i][-1]].add(i)
  125. for lit in self.clauses[i]:
  126. self.occurrence_count[lit] += 1
  127. def _find_model(self):
  128. """
  129. Main DPLL loop. Returns a generator of models.
  130. Variables are chosen successively, and assigned to be either
  131. True or False. If a solution is not found with this setting,
  132. the opposite is chosen and the search continues. The solver
  133. halts when every variable has a setting.
  134. Examples
  135. ========
  136. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  137. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  138. ... {3, -2}], {1, 2, 3}, set())
  139. >>> list(l._find_model())
  140. [{1: True, 2: False, 3: False}, {1: True, 2: True, 3: True}]
  141. >>> from sympy.abc import A, B, C
  142. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  143. ... {3, -2}], {1, 2, 3}, set(), [A, B, C])
  144. >>> list(l._find_model())
  145. [{A: True, B: False, C: False}, {A: True, B: True, C: True}]
  146. """
  147. # We use this variable to keep track of if we should flip a
  148. # variable setting in successive rounds
  149. flip_var = False
  150. # Check if unit prop says the theory is unsat right off the bat
  151. self._simplify()
  152. if self.is_unsatisfied:
  153. return
  154. # While the theory still has clauses remaining
  155. while True:
  156. # Perform cleanup / fixup at regular intervals
  157. if self.num_decisions % self.INTERVAL == 0:
  158. for func in self.update_functions:
  159. func()
  160. if flip_var:
  161. # We have just backtracked and we are trying to opposite literal
  162. flip_var = False
  163. lit = self._current_level.decision
  164. else:
  165. # Pick a literal to set
  166. lit = self.heur_calculate()
  167. self.num_decisions += 1
  168. # Stopping condition for a satisfying theory
  169. if 0 == lit:
  170. yield {self.symbols[abs(lit) - 1]:
  171. lit > 0 for lit in self.var_settings}
  172. while self._current_level.flipped:
  173. self._undo()
  174. if len(self.levels) == 1:
  175. return
  176. flip_lit = -self._current_level.decision
  177. self._undo()
  178. self.levels.append(Level(flip_lit, flipped=True))
  179. flip_var = True
  180. continue
  181. # Start the new decision level
  182. self.levels.append(Level(lit))
  183. # Assign the literal, updating the clauses it satisfies
  184. self._assign_literal(lit)
  185. # _simplify the theory
  186. self._simplify()
  187. # Check if we've made the theory unsat
  188. if self.is_unsatisfied:
  189. self.is_unsatisfied = False
  190. # We unroll all of the decisions until we can flip a literal
  191. while self._current_level.flipped:
  192. self._undo()
  193. # If we've unrolled all the way, the theory is unsat
  194. if 1 == len(self.levels):
  195. return
  196. # Detect and add a learned clause
  197. self.add_learned_clause(self.compute_conflict())
  198. # Try the opposite setting of the most recent decision
  199. flip_lit = -self._current_level.decision
  200. self._undo()
  201. self.levels.append(Level(flip_lit, flipped=True))
  202. flip_var = True
  203. ########################
  204. # Helper Methods #
  205. ########################
  206. @property
  207. def _current_level(self):
  208. """The current decision level data structure
  209. Examples
  210. ========
  211. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  212. >>> l = SATSolver([{1}, {2}], {1, 2}, set())
  213. >>> next(l._find_model())
  214. {1: True, 2: True}
  215. >>> l._current_level.decision
  216. 0
  217. >>> l._current_level.flipped
  218. False
  219. >>> l._current_level.var_settings
  220. {1, 2}
  221. """
  222. return self.levels[-1]
  223. def _clause_sat(self, cls):
  224. """Check if a clause is satisfied by the current variable setting.
  225. Examples
  226. ========
  227. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  228. >>> l = SATSolver([{1}, {-1}], {1}, set())
  229. >>> try:
  230. ... next(l._find_model())
  231. ... except StopIteration:
  232. ... pass
  233. >>> l._clause_sat(0)
  234. False
  235. >>> l._clause_sat(1)
  236. True
  237. """
  238. for lit in self.clauses[cls]:
  239. if lit in self.var_settings:
  240. return True
  241. return False
  242. def _is_sentinel(self, lit, cls):
  243. """Check if a literal is a sentinel of a given clause.
  244. Examples
  245. ========
  246. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  247. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  248. ... {3, -2}], {1, 2, 3}, set())
  249. >>> next(l._find_model())
  250. {1: True, 2: False, 3: False}
  251. >>> l._is_sentinel(2, 3)
  252. True
  253. >>> l._is_sentinel(-3, 1)
  254. False
  255. """
  256. return cls in self.sentinels[lit]
  257. def _assign_literal(self, lit):
  258. """Make a literal assignment.
  259. The literal assignment must be recorded as part of the current
  260. decision level. Additionally, if the literal is marked as a
  261. sentinel of any clause, then a new sentinel must be chosen. If
  262. this is not possible, then unit propagation is triggered and
  263. another literal is added to the queue to be set in the future.
  264. Examples
  265. ========
  266. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  267. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  268. ... {3, -2}], {1, 2, 3}, set())
  269. >>> next(l._find_model())
  270. {1: True, 2: False, 3: False}
  271. >>> l.var_settings
  272. {-3, -2, 1}
  273. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  274. ... {3, -2}], {1, 2, 3}, set())
  275. >>> l._assign_literal(-1)
  276. >>> try:
  277. ... next(l._find_model())
  278. ... except StopIteration:
  279. ... pass
  280. >>> l.var_settings
  281. {-1}
  282. """
  283. self.var_settings.add(lit)
  284. self._current_level.var_settings.add(lit)
  285. self.variable_set[abs(lit)] = True
  286. self.heur_lit_assigned(lit)
  287. sentinel_list = list(self.sentinels[-lit])
  288. for cls in sentinel_list:
  289. if not self._clause_sat(cls):
  290. other_sentinel = None
  291. for newlit in self.clauses[cls]:
  292. if newlit != -lit:
  293. if self._is_sentinel(newlit, cls):
  294. other_sentinel = newlit
  295. elif not self.variable_set[abs(newlit)]:
  296. self.sentinels[-lit].remove(cls)
  297. self.sentinels[newlit].add(cls)
  298. other_sentinel = None
  299. break
  300. # Check if no sentinel update exists
  301. if other_sentinel:
  302. self._unit_prop_queue.append(other_sentinel)
  303. def _undo(self):
  304. """
  305. _undo the changes of the most recent decision level.
  306. Examples
  307. ========
  308. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  309. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  310. ... {3, -2}], {1, 2, 3}, set())
  311. >>> next(l._find_model())
  312. {1: True, 2: False, 3: False}
  313. >>> level = l._current_level
  314. >>> level.decision, level.var_settings, level.flipped
  315. (-3, {-3, -2}, False)
  316. >>> l._undo()
  317. >>> level = l._current_level
  318. >>> level.decision, level.var_settings, level.flipped
  319. (0, {1}, False)
  320. """
  321. # Undo the variable settings
  322. for lit in self._current_level.var_settings:
  323. self.var_settings.remove(lit)
  324. self.heur_lit_unset(lit)
  325. self.variable_set[abs(lit)] = False
  326. # Pop the level off the stack
  327. self.levels.pop()
  328. #########################
  329. # Propagation #
  330. #########################
  331. """
  332. Propagation methods should attempt to soundly simplify the boolean
  333. theory, and return True if any simplification occurred and False
  334. otherwise.
  335. """
  336. def _simplify(self):
  337. """Iterate over the various forms of propagation to simplify the theory.
  338. Examples
  339. ========
  340. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  341. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  342. ... {3, -2}], {1, 2, 3}, set())
  343. >>> l.variable_set
  344. [False, False, False, False]
  345. >>> l.sentinels
  346. {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}}
  347. >>> l._simplify()
  348. >>> l.variable_set
  349. [False, True, False, False]
  350. >>> l.sentinels
  351. {-3: {0, 2}, -2: {3, 4}, -1: set(), 2: {0, 3},
  352. ...3: {2, 4}}
  353. """
  354. changed = True
  355. while changed:
  356. changed = False
  357. changed |= self._unit_prop()
  358. changed |= self._pure_literal()
  359. def _unit_prop(self):
  360. """Perform unit propagation on the current theory."""
  361. result = len(self._unit_prop_queue) > 0
  362. while self._unit_prop_queue:
  363. next_lit = self._unit_prop_queue.pop()
  364. if -next_lit in self.var_settings:
  365. self.is_unsatisfied = True
  366. self._unit_prop_queue = []
  367. return False
  368. else:
  369. self._assign_literal(next_lit)
  370. return result
  371. def _pure_literal(self):
  372. """Look for pure literals and assign them when found."""
  373. return False
  374. #########################
  375. # Heuristics #
  376. #########################
  377. def _vsids_init(self):
  378. """Initialize the data structures needed for the VSIDS heuristic."""
  379. self.lit_heap = []
  380. self.lit_scores = {}
  381. for var in range(1, len(self.variable_set)):
  382. self.lit_scores[var] = float(-self.occurrence_count[var])
  383. self.lit_scores[-var] = float(-self.occurrence_count[-var])
  384. heappush(self.lit_heap, (self.lit_scores[var], var))
  385. heappush(self.lit_heap, (self.lit_scores[-var], -var))
  386. def _vsids_decay(self):
  387. """Decay the VSIDS scores for every literal.
  388. Examples
  389. ========
  390. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  391. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  392. ... {3, -2}], {1, 2, 3}, set())
  393. >>> l.lit_scores
  394. {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0}
  395. >>> l._vsids_decay()
  396. >>> l.lit_scores
  397. {-3: -1.0, -2: -1.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -1.0}
  398. """
  399. # We divide every literal score by 2 for a decay factor
  400. # Note: This doesn't change the heap property
  401. for lit in self.lit_scores.keys():
  402. self.lit_scores[lit] /= 2.0
  403. def _vsids_calculate(self):
  404. """
  405. VSIDS Heuristic Calculation
  406. Examples
  407. ========
  408. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  409. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  410. ... {3, -2}], {1, 2, 3}, set())
  411. >>> l.lit_heap
  412. [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)]
  413. >>> l._vsids_calculate()
  414. -3
  415. >>> l.lit_heap
  416. [(-2.0, -2), (-2.0, 2), (0.0, -1), (0.0, 1), (-2.0, 3)]
  417. """
  418. if len(self.lit_heap) == 0:
  419. return 0
  420. # Clean out the front of the heap as long the variables are set
  421. while self.variable_set[abs(self.lit_heap[0][1])]:
  422. heappop(self.lit_heap)
  423. if len(self.lit_heap) == 0:
  424. return 0
  425. return heappop(self.lit_heap)[1]
  426. def _vsids_lit_assigned(self, lit):
  427. """Handle the assignment of a literal for the VSIDS heuristic."""
  428. pass
  429. def _vsids_lit_unset(self, lit):
  430. """Handle the unsetting of a literal for the VSIDS heuristic.
  431. Examples
  432. ========
  433. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  434. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  435. ... {3, -2}], {1, 2, 3}, set())
  436. >>> l.lit_heap
  437. [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)]
  438. >>> l._vsids_lit_unset(2)
  439. >>> l.lit_heap
  440. [(-2.0, -3), (-2.0, -2), (-2.0, -2), (-2.0, 2), (-2.0, 3), (0.0, -1),
  441. ...(-2.0, 2), (0.0, 1)]
  442. """
  443. var = abs(lit)
  444. heappush(self.lit_heap, (self.lit_scores[var], var))
  445. heappush(self.lit_heap, (self.lit_scores[-var], -var))
  446. def _vsids_clause_added(self, cls):
  447. """Handle the addition of a new clause for the VSIDS heuristic.
  448. Examples
  449. ========
  450. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  451. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  452. ... {3, -2}], {1, 2, 3}, set())
  453. >>> l.num_learned_clauses
  454. 0
  455. >>> l.lit_scores
  456. {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0}
  457. >>> l._vsids_clause_added({2, -3})
  458. >>> l.num_learned_clauses
  459. 1
  460. >>> l.lit_scores
  461. {-3: -1.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -2.0}
  462. """
  463. self.num_learned_clauses += 1
  464. for lit in cls:
  465. self.lit_scores[lit] += 1
  466. ########################
  467. # Clause Learning #
  468. ########################
  469. def _simple_add_learned_clause(self, cls):
  470. """Add a new clause to the theory.
  471. Examples
  472. ========
  473. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  474. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  475. ... {3, -2}], {1, 2, 3}, set())
  476. >>> l.num_learned_clauses
  477. 0
  478. >>> l.clauses
  479. [[2, -3], [1], [3, -3], [2, -2], [3, -2]]
  480. >>> l.sentinels
  481. {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}}
  482. >>> l._simple_add_learned_clause([3])
  483. >>> l.clauses
  484. [[2, -3], [1], [3, -3], [2, -2], [3, -2], [3]]
  485. >>> l.sentinels
  486. {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4, 5}}
  487. """
  488. cls_num = len(self.clauses)
  489. self.clauses.append(cls)
  490. for lit in cls:
  491. self.occurrence_count[lit] += 1
  492. self.sentinels[cls[0]].add(cls_num)
  493. self.sentinels[cls[-1]].add(cls_num)
  494. self.heur_clause_added(cls)
  495. def _simple_compute_conflict(self):
  496. """ Build a clause representing the fact that at least one decision made
  497. so far is wrong.
  498. Examples
  499. ========
  500. >>> from sympy.logic.algorithms.dpll2 import SATSolver
  501. >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
  502. ... {3, -2}], {1, 2, 3}, set())
  503. >>> next(l._find_model())
  504. {1: True, 2: False, 3: False}
  505. >>> l._simple_compute_conflict()
  506. [3]
  507. """
  508. return [-(level.decision) for level in self.levels[1:]]
  509. def _simple_clean_clauses(self):
  510. """Clean up learned clauses."""
  511. pass
  512. class Level:
  513. """
  514. Represents a single level in the DPLL algorithm, and contains
  515. enough information for a sound backtracking procedure.
  516. """
  517. def __init__(self, decision, flipped=False):
  518. self.decision = decision
  519. self.var_settings = set()
  520. self.flipped = flipped