m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
3.7 KiB

6 months ago
  1. from sympy.core.symbol import symbols, Dummy
  2. from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
  3. from sympy.core.function import Lambda
  4. from sympy.functions.elementary.exponential import exp
  5. from sympy.functions.elementary.trigonometric import sin
  6. from sympy.matrices.dense import Matrix
  7. from sympy.matrices.expressions.matexpr import MatrixSymbol
  8. from sympy.matrices.expressions.matmul import MatMul
  9. from sympy.simplify.simplify import simplify
  10. from sympy.testing.pytest import raises
  11. from sympy.matrices.common import ShapeError
  12. X = MatrixSymbol("X", 3, 3)
  13. Y = MatrixSymbol("Y", 3, 3)
  14. k = symbols("k")
  15. Xk = MatrixSymbol("X", k, k)
  16. Xd = X.as_explicit()
  17. x, y, z, t = symbols("x y z t")
  18. def test_applyfunc_matrix():
  19. x = Dummy('x')
  20. double = Lambda(x, x**2)
  21. expr = ElementwiseApplyFunction(double, Xd)
  22. assert isinstance(expr, ElementwiseApplyFunction)
  23. assert expr.doit() == Xd.applyfunc(lambda x: x**2)
  24. assert expr.shape == (3, 3)
  25. assert expr.func(*expr.args) == expr
  26. assert simplify(expr) == expr
  27. assert expr[0, 0] == double(Xd[0, 0])
  28. expr = ElementwiseApplyFunction(double, X)
  29. assert isinstance(expr, ElementwiseApplyFunction)
  30. assert isinstance(expr.doit(), ElementwiseApplyFunction)
  31. assert expr == X.applyfunc(double)
  32. assert expr.func(*expr.args) == expr
  33. expr = ElementwiseApplyFunction(exp, X*Y)
  34. assert expr.expr == X*Y
  35. assert expr.function.dummy_eq(Lambda(x, exp(x)))
  36. assert expr.dummy_eq((X*Y).applyfunc(exp))
  37. assert expr.func(*expr.args) == expr
  38. assert isinstance(X*expr, MatMul)
  39. assert (X*expr).shape == (3, 3)
  40. Z = MatrixSymbol("Z", 2, 3)
  41. assert (Z*expr).shape == (2, 3)
  42. expr = ElementwiseApplyFunction(exp, Z.T)*ElementwiseApplyFunction(exp, Z)
  43. assert expr.shape == (3, 3)
  44. expr = ElementwiseApplyFunction(exp, Z)*ElementwiseApplyFunction(exp, Z.T)
  45. assert expr.shape == (2, 2)
  46. raises(ShapeError, lambda: ElementwiseApplyFunction(exp, Z)*ElementwiseApplyFunction(exp, Z))
  47. M = Matrix([[x, y], [z, t]])
  48. expr = ElementwiseApplyFunction(sin, M)
  49. assert isinstance(expr, ElementwiseApplyFunction)
  50. assert expr.function.dummy_eq(Lambda(x, sin(x)))
  51. assert expr.expr == M
  52. assert expr.doit() == M.applyfunc(sin)
  53. assert expr.doit() == Matrix([[sin(x), sin(y)], [sin(z), sin(t)]])
  54. assert expr.func(*expr.args) == expr
  55. expr = ElementwiseApplyFunction(double, Xk)
  56. assert expr.doit() == expr
  57. assert expr.subs(k, 2).shape == (2, 2)
  58. assert (expr*expr).shape == (k, k)
  59. M = MatrixSymbol("M", k, t)
  60. expr2 = M.T*expr*M
  61. assert isinstance(expr2, MatMul)
  62. assert expr2.args[1] == expr
  63. assert expr2.shape == (t, t)
  64. expr3 = expr*M
  65. assert expr3.shape == (k, t)
  66. raises(ShapeError, lambda: M*expr)
  67. expr1 = ElementwiseApplyFunction(lambda x: x+1, Xk)
  68. expr2 = ElementwiseApplyFunction(lambda x: x, Xk)
  69. assert expr1 != expr2
  70. def test_applyfunc_entry():
  71. af = X.applyfunc(sin)
  72. assert af[0, 0] == sin(X[0, 0])
  73. af = Xd.applyfunc(sin)
  74. assert af[0, 0] == sin(X[0, 0])
  75. def test_applyfunc_as_explicit():
  76. af = X.applyfunc(sin)
  77. assert af.as_explicit() == Matrix([
  78. [sin(X[0, 0]), sin(X[0, 1]), sin(X[0, 2])],
  79. [sin(X[1, 0]), sin(X[1, 1]), sin(X[1, 2])],
  80. [sin(X[2, 0]), sin(X[2, 1]), sin(X[2, 2])],
  81. ])
  82. def test_applyfunc_transpose():
  83. af = Xk.applyfunc(sin)
  84. assert af.T.dummy_eq(Xk.T.applyfunc(sin))
  85. def test_applyfunc_shape_11_matrices():
  86. M = MatrixSymbol("M", 1, 1)
  87. double = Lambda(x, x*2)
  88. expr = M.applyfunc(sin)
  89. assert isinstance(expr, ElementwiseApplyFunction)
  90. expr = M.applyfunc(double)
  91. assert isinstance(expr, MatMul)
  92. assert expr == 2*M