m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

77 lines
2.5 KiB

6 months ago
  1. from sympy.strategies.tree import treeapply, greedy, allresults, brute
  2. from functools import partial, reduce
  3. def test_treeapply():
  4. tree = ([3, 3], [4, 1], 2)
  5. assert treeapply(tree, {list: min, tuple: max}) == 3
  6. add = lambda *args: sum(args)
  7. mul = lambda *args: reduce(lambda a, b: a*b, args, 1)
  8. assert treeapply(tree, {list: add, tuple: mul}) == 60
  9. def test_treeapply_leaf():
  10. assert treeapply(3, {}, leaf=lambda x: x**2) == 9
  11. tree = ([3, 3], [4, 1], 2)
  12. treep1 = ([4, 4], [5, 2], 3)
  13. assert treeapply(tree, {list: min, tuple: max}, leaf=lambda x: x+1) == \
  14. treeapply(treep1, {list: min, tuple: max})
  15. def test_treeapply_strategies():
  16. from sympy.strategies import chain, minimize
  17. join = {list: chain, tuple: minimize}
  18. inc = lambda x: x + 1
  19. dec = lambda x: x - 1
  20. double = lambda x: 2*x
  21. assert treeapply(inc, join) == inc
  22. assert treeapply((inc, dec), join)(5) == minimize(inc, dec)(5)
  23. assert treeapply([inc, dec], join)(5) == chain(inc, dec)(5)
  24. tree = (inc, [dec, double]) # either inc or dec-then-double
  25. assert treeapply(tree, join)(5) == 6
  26. assert treeapply(tree, join)(1) == 0
  27. maximize = partial(minimize, objective=lambda x: -x)
  28. join = {list: chain, tuple: maximize}
  29. fn = treeapply(tree, join)
  30. assert fn(4) == 6 # highest value comes from the dec then double
  31. assert fn(1) == 2 # highest value comes from the inc
  32. def test_greedy():
  33. inc = lambda x: x + 1
  34. dec = lambda x: x - 1
  35. double = lambda x: 2*x
  36. tree = [inc, (dec, double)] # either inc or dec-then-double
  37. fn = greedy(tree, objective=lambda x: -x)
  38. assert fn(4) == 6 # highest value comes from the dec then double
  39. assert fn(1) == 2 # highest value comes from the inc
  40. tree = [inc, dec, [inc, dec, [(inc, inc), (dec, dec)]]]
  41. lowest = greedy(tree)
  42. assert lowest(10) == 8
  43. highest = greedy(tree, objective=lambda x: -x)
  44. assert highest(10) == 12
  45. def test_allresults():
  46. inc = lambda x: x+1
  47. dec = lambda x: x-1
  48. double = lambda x: x*2
  49. # square = lambda x: x**2
  50. assert set(allresults(inc)(3)) == {inc(3)}
  51. assert set(allresults([inc, dec])(3)) == {2, 4}
  52. assert set(allresults((inc, dec))(3)) == {3}
  53. assert set(allresults([inc, (dec, double)])(4)) == {5, 6}
  54. def test_brute():
  55. inc = lambda x: x+1
  56. dec = lambda x: x-1
  57. square = lambda x: x**2
  58. tree = ([inc, dec], square)
  59. fn = brute(tree, lambda x: -x)
  60. assert fn(2) == (2 + 1)**2
  61. assert fn(-2) == (-2 - 1)**2
  62. assert brute(inc)(1) == 2