m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

56 lines
2.3 KiB

6 months ago
  1. from sympy.core import symbols, Lambda
  2. from sympy.functions import KroneckerDelta
  3. from sympy.matrices import Matrix
  4. from sympy.matrices.expressions import FunctionMatrix, MatrixExpr, Identity
  5. from sympy.testing.pytest import raises, warns
  6. from sympy.utilities.exceptions import SymPyDeprecationWarning
  def test_funcmatrix_creation():
      """FunctionMatrix validates its dimensions and its element function."""
      i, j, k = symbols('i j k')
      zero = Lambda((i, j), 0)

      # Valid explicit dimensions, including the empty matrix. A bare
      # ``assert FunctionMatrix(...)`` is vacuous (the object is always
      # truthy), so check the resulting shape instead.
      assert FunctionMatrix(2, 2, zero).shape == (2, 2)
      assert FunctionMatrix(0, 0, zero).shape == (0, 0)

      # Rows and columns must each be non-negative integers.
      raises(ValueError, lambda: FunctionMatrix(-1, 0, zero))
      raises(ValueError, lambda: FunctionMatrix(2.0, 0, zero))
      raises(ValueError, lambda: FunctionMatrix(2j, 0, zero))
      raises(ValueError, lambda: FunctionMatrix(0, -1, zero))
      raises(ValueError, lambda: FunctionMatrix(0, 2.0, zero))
      raises(ValueError, lambda: FunctionMatrix(0, 2j, zero))

      # The element function must take exactly two arguments; plain Python
      # callables and arbitrary expressions are rejected.
      raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda(i, 0)))
      with warns(SymPyDeprecationWarning, test_stacklevel=False):
          # This raises a deprecation warning from sympify()
          raises(ValueError, lambda: FunctionMatrix(2, 2, lambda i, j: 0))
      raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i,), 0)))
      raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i, j, k), 0)))
      raises(ValueError, lambda: FunctionMatrix(2, 2, i+j))

      # A string is sympified into the equivalent Lambda.
      assert FunctionMatrix(2, 2, "lambda i, j: 0") == \
          FunctionMatrix(2, 2, zero)

      # A function class such as KroneckerDelta is stored as a Lambda.
      m = FunctionMatrix(2, 2, KroneckerDelta)
      assert m.as_explicit() == Identity(2).as_explicit()
      assert m.args[2].dummy_eq(Lambda((i, j), KroneckerDelta(i, j)))

      # Symbolic dimensions are accepted unless they are known to be
      # non-integer or negative.
      n = symbols('n')
      assert FunctionMatrix(n, n, zero).shape == (n, n)
      n_nonint = symbols('n', integer=False)
      raises(ValueError, lambda: FunctionMatrix(n_nonint, n_nonint, zero))
      n_neg = symbols('n', negative=True)
      raises(ValueError, lambda: FunctionMatrix(n_neg, n_neg, zero))
  def test_funcmatrix():
      """Indexing, shape and arithmetic of a concrete 3x3 FunctionMatrix."""
      i, j = symbols('i,j')
      diff = FunctionMatrix(3, 3, Lambda((i, j), i - j))

      # Element access evaluates the Lambda at the requested position.
      assert diff[1, 1] == 0
      assert diff[1, 2] == -1

      # Dimensions agree across every accessor.
      assert diff.shape == (3, 3)
      assert diff.rows == diff.cols == 3

      # Converting to an explicit Matrix matches building one from a callable.
      expected = Matrix(3, 3, lambda r, c: r - c)
      assert Matrix(diff) == expected

      # Arithmetic on it yields a matrix expression.
      combined = diff*diff + diff
      assert isinstance(combined, MatrixExpr)
  def test_replace_issue():
      """replace() with an identity rule leaves a FunctionMatrix unchanged."""
      kd_matrix = FunctionMatrix(3, 3, KroneckerDelta)
      # Match every node and map it to itself.
      replaced = kd_matrix.replace(lambda _node: True, lambda node: node)
      assert replaced == kd_matrix