m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

234 lines
6.9 KiB

7 months ago
  1. """ Generic Unification algorithm for expression trees with lists of children
  2. This implementation is a direct translation of
Artificial Intelligence: A Modern Approach by Stuart Russell and Peter Norvig
  4. Second edition, section 9.2, page 276
  5. It is modified in the following ways:
  6. 1. We allow associative and commutative Compound expressions. This results in
  7. combinatorial blowup.
  8. 2. We explore the tree lazily.
  9. 3. We provide generic interfaces to symbolic algebra libraries in Python.
  10. A more traditional version can be found here
  11. http://aima.cs.berkeley.edu/python/logic.html
  12. """
  13. from sympy.utilities.iterables import kbins
  14. class Compound:
  15. """ A little class to represent an interior node in the tree
  16. This is analogous to SymPy.Basic for non-Atoms
  17. """
  18. def __init__(self, op, args):
  19. self.op = op
  20. self.args = args
  21. def __eq__(self, other):
  22. return (type(self) is type(other) and self.op == other.op and
  23. self.args == other.args)
  24. def __hash__(self):
  25. return hash((type(self), self.op, self.args))
  26. def __str__(self):
  27. return "%s[%s]" % (str(self.op), ', '.join(map(str, self.args)))
  28. class Variable:
  29. """ A Wild token """
  30. def __init__(self, arg):
  31. self.arg = arg
  32. def __eq__(self, other):
  33. return type(self) is type(other) and self.arg == other.arg
  34. def __hash__(self):
  35. return hash((type(self), self.arg))
  36. def __str__(self):
  37. return "Variable(%s)" % str(self.arg)
  38. class CondVariable:
  39. """ A wild token that matches conditionally.
  40. arg - a wild token.
  41. valid - an additional constraining function on a match.
  42. """
  43. def __init__(self, arg, valid):
  44. self.arg = arg
  45. self.valid = valid
  46. def __eq__(self, other):
  47. return (type(self) is type(other) and
  48. self.arg == other.arg and
  49. self.valid == other.valid)
  50. def __hash__(self):
  51. return hash((type(self), self.arg, self.valid))
  52. def __str__(self):
  53. return "CondVariable(%s)" % str(self.arg)
  54. def unify(x, y, s=None, **fns):
  55. """ Unify two expressions.
  56. Parameters
  57. ==========
  58. x, y - expression trees containing leaves, Compounds and Variables.
  59. s - a mapping of variables to subtrees.
  60. Returns
  61. =======
  62. lazy sequence of mappings {Variable: subtree}
  63. Examples
  64. ========
  65. >>> from sympy.unify.core import unify, Compound, Variable
  66. >>> expr = Compound("Add", ("x", "y"))
  67. >>> pattern = Compound("Add", ("x", Variable("a")))
  68. >>> next(unify(expr, pattern, {}))
  69. {Variable(a): 'y'}
  70. """
  71. s = s or {}
  72. if x == y:
  73. yield s
  74. elif isinstance(x, (Variable, CondVariable)):
  75. yield from unify_var(x, y, s, **fns)
  76. elif isinstance(y, (Variable, CondVariable)):
  77. yield from unify_var(y, x, s, **fns)
  78. elif isinstance(x, Compound) and isinstance(y, Compound):
  79. is_commutative = fns.get('is_commutative', lambda x: False)
  80. is_associative = fns.get('is_associative', lambda x: False)
  81. for sop in unify(x.op, y.op, s, **fns):
  82. if is_associative(x) and is_associative(y):
  83. a, b = (x, y) if len(x.args) < len(y.args) else (y, x)
  84. if is_commutative(x) and is_commutative(y):
  85. combs = allcombinations(a.args, b.args, 'commutative')
  86. else:
  87. combs = allcombinations(a.args, b.args, 'associative')
  88. for aaargs, bbargs in combs:
  89. aa = [unpack(Compound(a.op, arg)) for arg in aaargs]
  90. bb = [unpack(Compound(b.op, arg)) for arg in bbargs]
  91. yield from unify(aa, bb, sop, **fns)
  92. elif len(x.args) == len(y.args):
  93. yield from unify(x.args, y.args, sop, **fns)
  94. elif is_args(x) and is_args(y) and len(x) == len(y):
  95. if len(x) == 0:
  96. yield s
  97. else:
  98. for shead in unify(x[0], y[0], s, **fns):
  99. yield from unify(x[1:], y[1:], shead, **fns)
  100. def unify_var(var, x, s, **fns):
  101. if var in s:
  102. yield from unify(s[var], x, s, **fns)
  103. elif occur_check(var, x):
  104. pass
  105. elif isinstance(var, CondVariable) and var.valid(x):
  106. yield assoc(s, var, x)
  107. elif isinstance(var, Variable):
  108. yield assoc(s, var, x)
  109. def occur_check(var, x):
  110. """ var occurs in subtree owned by x? """
  111. if var == x:
  112. return True
  113. elif isinstance(x, Compound):
  114. return occur_check(var, x.args)
  115. elif is_args(x):
  116. if any(occur_check(var, xi) for xi in x): return True
  117. return False
  118. def assoc(d, key, val):
  119. """ Return copy of d with key associated to val """
  120. d = d.copy()
  121. d[key] = val
  122. return d
  123. def is_args(x):
  124. """ Is x a traditional iterable? """
  125. return type(x) in (tuple, list, set)
  126. def unpack(x):
  127. if isinstance(x, Compound) and len(x.args) == 1:
  128. return x.args[0]
  129. else:
  130. return x
  131. def allcombinations(A, B, ordered):
  132. """
  133. Restructure A and B to have the same number of elements.
  134. Parameters
  135. ==========
  136. ordered must be either 'commutative' or 'associative'.
  137. A and B can be rearranged so that the larger of the two lists is
  138. reorganized into smaller sublists.
  139. Examples
  140. ========
  141. >>> from sympy.unify.core import allcombinations
  142. >>> for x in allcombinations((1, 2, 3), (5, 6), 'associative'): print(x)
  143. (((1,), (2, 3)), ((5,), (6,)))
  144. (((1, 2), (3,)), ((5,), (6,)))
  145. >>> for x in allcombinations((1, 2, 3), (5, 6), 'commutative'): print(x)
  146. (((1,), (2, 3)), ((5,), (6,)))
  147. (((1, 2), (3,)), ((5,), (6,)))
  148. (((1,), (3, 2)), ((5,), (6,)))
  149. (((1, 3), (2,)), ((5,), (6,)))
  150. (((2,), (1, 3)), ((5,), (6,)))
  151. (((2, 1), (3,)), ((5,), (6,)))
  152. (((2,), (3, 1)), ((5,), (6,)))
  153. (((2, 3), (1,)), ((5,), (6,)))
  154. (((3,), (1, 2)), ((5,), (6,)))
  155. (((3, 1), (2,)), ((5,), (6,)))
  156. (((3,), (2, 1)), ((5,), (6,)))
  157. (((3, 2), (1,)), ((5,), (6,)))
  158. """
  159. if ordered == "commutative":
  160. ordered = 11
  161. if ordered == "associative":
  162. ordered = None
  163. sm, bg = (A, B) if len(A) < len(B) else (B, A)
  164. for part in kbins(list(range(len(bg))), len(sm), ordered=ordered):
  165. if bg == B:
  166. yield tuple((a,) for a in A), partition(B, part)
  167. else:
  168. yield partition(A, part), tuple((b,) for b in B)
  169. def partition(it, part):
  170. """ Partition a tuple/list into pieces defined by indices.
  171. Examples
  172. ========
  173. >>> from sympy.unify.core import partition
  174. >>> partition((10, 20, 30, 40), [[0, 1, 2], [3]])
  175. ((10, 20, 30), (40,))
  176. """
  177. return type(it)([index(it, ind) for ind in part])
  178. def index(it, ind):
  179. """ Fancy indexing into an indexable iterable (tuple, list).
  180. Examples
  181. ========
  182. >>> from sympy.unify.core import index
  183. >>> index([10, 20, 30], (1, 2, 0))
  184. [20, 30, 10]
  185. """
  186. return type(it)([it[i] for i in ind])