图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

154 lines
4.1 KiB

  1. from collections.abc import Iterable
  2. from functools import singledispatch
  3. from sympy.core.expr import Expr
  4. from sympy.core.mul import Mul
  5. from sympy.core.singleton import S
  6. from sympy.core.sympify import sympify
  7. from sympy.core.parameters import global_parameters
  8. class TensorProduct(Expr):
  9. """
  10. Generic class for tensor products.
  11. """
  12. is_number = False
  13. def __new__(cls, *args, **kwargs):
  14. from sympy.tensor.array import NDimArray, tensorproduct, Array
  15. from sympy.matrices.expressions.matexpr import MatrixExpr
  16. from sympy.matrices.matrices import MatrixBase
  17. from sympy.strategies import flatten
  18. args = [sympify(arg) for arg in args]
  19. evaluate = kwargs.get("evaluate", global_parameters.evaluate)
  20. if not evaluate:
  21. obj = Expr.__new__(cls, *args)
  22. return obj
  23. arrays = []
  24. other = []
  25. scalar = S.One
  26. for arg in args:
  27. if isinstance(arg, (Iterable, MatrixBase, NDimArray)):
  28. arrays.append(Array(arg))
  29. elif isinstance(arg, (MatrixExpr,)):
  30. other.append(arg)
  31. else:
  32. scalar *= arg
  33. coeff = scalar*tensorproduct(*arrays)
  34. if len(other) == 0:
  35. return coeff
  36. if coeff != 1:
  37. newargs = [coeff] + other
  38. else:
  39. newargs = other
  40. obj = Expr.__new__(cls, *newargs, **kwargs)
  41. return flatten(obj)
  42. def rank(self):
  43. return len(self.shape)
  44. def _get_args_shapes(self):
  45. from sympy.tensor.array import Array
  46. return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args]
  47. @property
  48. def shape(self):
  49. shape_list = self._get_args_shapes()
  50. return sum(shape_list, ())
  51. def __getitem__(self, index):
  52. index = iter(index)
  53. return Mul.fromiter(
  54. arg.__getitem__(tuple(next(index) for i in shp))
  55. for arg, shp in zip(self.args, self._get_args_shapes())
  56. )
  57. @singledispatch
  58. def shape(expr):
  59. """
  60. Return the shape of the *expr* as a tuple. *expr* should represent
  61. suitable object such as matrix or array.
  62. Parameters
  63. ==========
  64. expr : SymPy object having ``MatrixKind`` or ``ArrayKind``.
  65. Raises
  66. ======
  67. NoShapeError : Raised when object with wrong kind is passed.
  68. Examples
  69. ========
  70. This function returns the shape of any object representing matrix or array.
  71. >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral
  72. >>> from sympy.abc import x
  73. >>> A = Array([1, 2])
  74. >>> shape(A)
  75. (2,)
  76. >>> shape(Integral(A, x))
  77. (2,)
  78. >>> M = ImmutableDenseMatrix([1, 2])
  79. >>> shape(M)
  80. (2, 1)
  81. >>> shape(Integral(M, x))
  82. (2, 1)
  83. You can support new type by dispatching.
  84. >>> from sympy import Expr
  85. >>> class NewExpr(Expr):
  86. ... pass
  87. >>> @shape.register(NewExpr)
  88. ... def _(expr):
  89. ... return shape(expr.args[0])
  90. >>> shape(NewExpr(M))
  91. (2, 1)
  92. If unsuitable expression is passed, ``NoShapeError()`` will be raised.
  93. >>> shape(Integral(x, x))
  94. Traceback (most recent call last):
  95. ...
  96. sympy.tensor.functions.NoShapeError: shape() called on non-array object: Integral(x, x)
  97. Notes
  98. =====
  99. Array-like classes (such as ``Matrix`` or ``NDimArray``) has ``shape``
  100. property which returns its shape, but it cannot be used for non-array
  101. classes containing array. This function returns the shape of any
  102. registered object representing array.
  103. """
  104. if hasattr(expr, "shape"):
  105. return expr.shape
  106. raise NoShapeError(
  107. "%s does not have shape, or its type is not registered to shape()." % expr)
  108. class NoShapeError(Exception):
  109. """
  110. Raised when ``shape()`` is called on non-array object.
  111. This error can be imported from ``sympy.tensor.functions``.
  112. Examples
  113. ========
  114. >>> from sympy import shape
  115. >>> from sympy.abc import x
  116. >>> shape(x)
  117. Traceback (most recent call last):
  118. ...
  119. sympy.tensor.functions.NoShapeError: shape() called on non-array object: x
  120. """
  121. pass