图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

112 lines
3.0 KiB

  1. from sympy.core.symbol import symbols
  2. from sympy.printing.codeprinter import ccode
  3. from sympy.codegen.ast import Declaration, Variable, float64, int64, String, CodeBlock
  4. from sympy.codegen.cnodes import (
  5. alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement, PreIncrement, PostIncrement,
  6. sizeof, union, struct
  7. )
  8. x, y = symbols('x y')
  9. def test_alignof():
  10. ax = alignof(x)
  11. assert ccode(ax) == 'alignof(x)'
  12. assert ax.func(*ax.args) == ax
  13. def test_CommaOperator():
  14. expr = CommaOperator(PreIncrement(x), 2*x)
  15. assert ccode(expr) == '(++(x), 2*x)'
  16. assert expr.func(*expr.args) == expr
  17. def test_goto_Label():
  18. s = 'early_exit'
  19. g = goto(s)
  20. assert g.func(*g.args) == g
  21. assert g != goto('foobar')
  22. assert ccode(g) == 'goto early_exit'
  23. l1 = Label(s)
  24. assert ccode(l1) == 'early_exit:'
  25. assert l1 == Label('early_exit')
  26. assert l1 != Label('foobar')
  27. body = [PreIncrement(x)]
  28. l2 = Label(s, body)
  29. assert l2.name == String("early_exit")
  30. assert l2.body == CodeBlock(PreIncrement(x))
  31. assert ccode(l2) == ("early_exit:\n"
  32. "++(x);")
  33. body = [PreIncrement(x), PreDecrement(y)]
  34. l2 = Label(s, body)
  35. assert l2.name == String("early_exit")
  36. assert l2.body == CodeBlock(PreIncrement(x), PreDecrement(y))
  37. assert ccode(l2) == ("early_exit:\n"
  38. "{\n ++(x);\n --(y);\n}")
  39. def test_PreDecrement():
  40. p = PreDecrement(x)
  41. assert p.func(*p.args) == p
  42. assert ccode(p) == '--(x)'
  43. def test_PostDecrement():
  44. p = PostDecrement(x)
  45. assert p.func(*p.args) == p
  46. assert ccode(p) == '(x)--'
  47. def test_PreIncrement():
  48. p = PreIncrement(x)
  49. assert p.func(*p.args) == p
  50. assert ccode(p) == '++(x)'
  51. def test_PostIncrement():
  52. p = PostIncrement(x)
  53. assert p.func(*p.args) == p
  54. assert ccode(p) == '(x)++'
  55. def test_sizeof():
  56. typename = 'unsigned int'
  57. sz = sizeof(typename)
  58. assert ccode(sz) == 'sizeof(%s)' % typename
  59. assert sz.func(*sz.args) == sz
  60. assert not sz.is_Atom
  61. assert sz.atoms() == {String('unsigned int'), String('sizeof')}
  62. def test_struct():
  63. vx, vy = Variable(x, type=float64), Variable(y, type=float64)
  64. s = struct('vec2', [vx, vy])
  65. assert s.func(*s.args) == s
  66. assert s == struct('vec2', (vx, vy))
  67. assert s != struct('vec2', (vy, vx))
  68. assert str(s.name) == 'vec2'
  69. assert len(s.declarations) == 2
  70. assert all(isinstance(arg, Declaration) for arg in s.declarations)
  71. assert ccode(s) == (
  72. "struct vec2 {\n"
  73. " double x;\n"
  74. " double y;\n"
  75. "}")
  76. def test_union():
  77. vx, vy = Variable(x, type=float64), Variable(y, type=int64)
  78. u = union('dualuse', [vx, vy])
  79. assert u.func(*u.args) == u
  80. assert u == union('dualuse', (vx, vy))
  81. assert str(u.name) == 'dualuse'
  82. assert len(u.declarations) == 2
  83. assert all(isinstance(arg, Declaration) for arg in u.declarations)
  84. assert ccode(u) == (
  85. "union dualuse {\n"
  86. " double x;\n"
  87. " int64_t y;\n"
  88. "}")