m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

95 lines
3.1 KiB

11 months ago
  1. from sympy.core.basic import Basic
  2. from sympy.printing import pprint
  3. import random
  4. def interactive_traversal(expr):
  5. """Traverse a tree asking a user which branch to choose. """
  6. RED, BRED = '\033[0;31m', '\033[1;31m'
  7. GREEN, BGREEN = '\033[0;32m', '\033[1;32m'
  8. YELLOW, BYELLOW = '\033[0;33m', '\033[1;33m' # noqa
  9. BLUE, BBLUE = '\033[0;34m', '\033[1;34m' # noqa
  10. MAGENTA, BMAGENTA = '\033[0;35m', '\033[1;35m'# noqa
  11. CYAN, BCYAN = '\033[0;36m', '\033[1;36m' # noqa
  12. END = '\033[0m'
  13. def cprint(*args):
  14. print("".join(map(str, args)) + END)
  15. def _interactive_traversal(expr, stage):
  16. if stage > 0:
  17. print()
  18. cprint("Current expression (stage ", BYELLOW, stage, END, "):")
  19. print(BCYAN)
  20. pprint(expr)
  21. print(END)
  22. if isinstance(expr, Basic):
  23. if expr.is_Add:
  24. args = expr.as_ordered_terms()
  25. elif expr.is_Mul:
  26. args = expr.as_ordered_factors()
  27. else:
  28. args = expr.args
  29. elif hasattr(expr, "__iter__"):
  30. args = list(expr)
  31. else:
  32. return expr
  33. n_args = len(args)
  34. if not n_args:
  35. return expr
  36. for i, arg in enumerate(args):
  37. cprint(GREEN, "[", BGREEN, i, GREEN, "] ", BLUE, type(arg), END)
  38. pprint(arg)
  39. print()
  40. if n_args == 1:
  41. choices = '0'
  42. else:
  43. choices = '0-%d' % (n_args - 1)
  44. try:
  45. choice = input("Your choice [%s,f,l,r,d,?]: " % choices)
  46. except EOFError:
  47. result = expr
  48. print()
  49. else:
  50. if choice == '?':
  51. cprint(RED, "%s - select subexpression with the given index" %
  52. choices)
  53. cprint(RED, "f - select the first subexpression")
  54. cprint(RED, "l - select the last subexpression")
  55. cprint(RED, "r - select a random subexpression")
  56. cprint(RED, "d - done\n")
  57. result = _interactive_traversal(expr, stage)
  58. elif choice in ('d', ''):
  59. result = expr
  60. elif choice == 'f':
  61. result = _interactive_traversal(args[0], stage + 1)
  62. elif choice == 'l':
  63. result = _interactive_traversal(args[-1], stage + 1)
  64. elif choice == 'r':
  65. result = _interactive_traversal(random.choice(args), stage + 1)
  66. else:
  67. try:
  68. choice = int(choice)
  69. except ValueError:
  70. cprint(BRED,
  71. "Choice must be a number in %s range\n" % choices)
  72. result = _interactive_traversal(expr, stage)
  73. else:
  74. if choice < 0 or choice >= n_args:
  75. cprint(BRED, "Choice must be in %s range\n" % choices)
  76. result = _interactive_traversal(expr, stage)
  77. else:
  78. result = _interactive_traversal(args[choice], stage + 1)
  79. return result
  80. return _interactive_traversal(expr, 0)