m2m model translation
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -*- coding: UTF-8 -*-
import argparse
import logging

import numpy as np
import onnx
import sympy
from onnx import helper, numpy_helper, shape_inference
from packaging import version

assert version.parse(onnx.__version__) >= version.parse("1.8.0")

logger = logging.getLogger(__name__)


def get_attribute(node, attr_name, default_value=None):
    found = [attr for attr in node.attribute if attr.name == attr_name]
    if found:
        return helper.get_attribute_value(found[0])
    return default_value


def get_dim_from_proto(dim):
    return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) == str else None


def is_sequence(type_proto):
    cls_type = type_proto.WhichOneof("value")
    assert cls_type in ["tensor_type", "sequence_type"]
    return cls_type == "sequence_type"


def get_shape_from_type_proto(type_proto):
    assert not is_sequence(type_proto)
    if type_proto.tensor_type.HasField("shape"):
        return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim]
    else:
        return None  # note no shape is different from shape without dim (scalar)


def get_elem_type_from_type_proto(type_proto):
    if is_sequence(type_proto):
        return type_proto.sequence_type.elem_type.tensor_type.elem_type
    else:
        return type_proto.tensor_type.elem_type


def get_shape_from_value_info(vi):
    cls_type = vi.type.WhichOneof("value")
    if cls_type is None:
        return None
    if is_sequence(vi.type):
        if "tensor_type" == vi.type.sequence_type.elem_type.WhichOneof("value"):
            return get_shape_from_type_proto(vi.type.sequence_type.elem_type)
        else:
            return None
    else:
        return get_shape_from_type_proto(vi.type)


def make_named_value_info(name):
    vi = onnx.ValueInfoProto()
    vi.name = name
    return vi


def get_shape_from_sympy_shape(sympy_shape):
    return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape]


def is_literal(dim):
    return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number)


def handle_negative_axis(axis, rank):
    assert axis < rank and axis >= -rank
    return axis if axis >= 0 else rank + axis


def get_opset(mp, domain=None):
    domain = domain or ["", "onnx", "ai.onnx"]
    if type(domain) != list:
        domain = [domain]
    for opset in mp.opset_import:
        if opset.domain in domain:
            return opset.version
    return None


def as_scalar(x):
    if type(x) == list:
        assert len(x) == 1
        return x[0]
    elif type(x) == np.ndarray:
        return x.item()
    else:
        return x


def as_list(x, keep_none):
    if type(x) == list:
        return x
    elif type(x) == np.ndarray:
        return list(x)
    elif keep_none and x is None:
        return None
    else:
        return [x]


def sympy_reduce_product(x):
    if type(x) == list:
        value = sympy.Integer(1)
        for v in x:
            value = value * v
    else:
        value = x
    return value
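
# Sketch of how an instance is typically driven (it mirrors the loop that
# _onnx_infer_subgraph uses below for nested graphs; the public entry point is outside
# this excerpt, so treat the exact call sequence as illustrative, not authoritative):
#
#   inference = SymbolicShapeInference(int_max=2**31 - 1, auto_merge=True,
#                                      guess_output_rank=False, verbose=0)
#   inference._preprocess(onnx.load("model.onnx"))
#   while inference.run_:
#       all_inferred = inference._infer_impl({})  # sympy data seed, as passed in _onnx_infer_subgraph
#   inference._update_output_from_vi()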

class SymbolicShapeInference:
    def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
        self.dispatcher_ = {
            "Add": self._infer_symbolic_compute_ops,
            "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor,
            "AveragePool": self._infer_Pool,
            "BatchNormalization": self._infer_BatchNormalization,
            "Cast": self._infer_Cast,
            "CategoryMapper": self._infer_CategoryMapper,
            "Compress": self._infer_Compress,
            "Concat": self._infer_Concat,
            "ConcatFromSequence": self._infer_ConcatFromSequence,
            "Constant": self._infer_Constant,
            "ConstantOfShape": self._infer_ConstantOfShape,
            "Conv": self._infer_Conv,
            "CumSum": self._pass_on_shape_and_type,
            "Div": self._infer_symbolic_compute_ops,
            "Einsum": self._infer_Einsum,
            "Expand": self._infer_Expand,
            "Equal": self._infer_symbolic_compute_ops,
            "Floor": self._infer_symbolic_compute_ops,
            "Gather": self._infer_Gather,
            "GatherElements": self._infer_GatherElements,
            "GatherND": self._infer_GatherND,
            "Identity": self._pass_on_shape_and_type,
            "If": self._infer_If,
            "Loop": self._infer_Loop,
            "MatMul": self._infer_MatMul,
            "MatMulInteger16": self._infer_MatMulInteger,
            "MaxPool": self._infer_Pool,
            "Max": self._infer_symbolic_compute_ops,
            "Min": self._infer_symbolic_compute_ops,
            "Mul": self._infer_symbolic_compute_ops,
            "NonMaxSuppression": self._infer_NonMaxSuppression,
            "NonZero": self._infer_NonZero,
            "OneHot": self._infer_OneHot,
            "Pad": self._infer_Pad,
            "Range": self._infer_Range,
            "Reciprocal": self._pass_on_shape_and_type,
            "ReduceSum": self._infer_ReduceSum,
            "ReduceProd": self._infer_ReduceProd,
            "Reshape": self._infer_Reshape,
            "Resize": self._infer_Resize,
            "Round": self._pass_on_shape_and_type,
            "Scan": self._infer_Scan,
            "ScatterElements": self._infer_ScatterElements,
            "SequenceAt": self._infer_SequenceAt,
            "SequenceInsert": self._infer_SequenceInsert,
            "Shape": self._infer_Shape,
            "Size": self._infer_Size,
            "Slice": self._infer_Slice,
            "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss,
            "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss,
            "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss,
            "Split": self._infer_Split,
            "SplitToSequence": self._infer_SplitToSequence,
            "Squeeze": self._infer_Squeeze,
            "Sub": self._infer_symbolic_compute_ops,
            "Tile": self._infer_Tile,
            "TopK": self._infer_TopK,
            "Transpose": self._infer_Transpose,
            "Unsqueeze": self._infer_Unsqueeze,
            "Where": self._infer_symbolic_compute_ops,
            "ZipMap": self._infer_ZipMap,
            "Neg": self._infer_symbolic_compute_ops,
            # contrib ops:
            "Attention": self._infer_Attention,
            "BiasGelu": self._infer_BiasGelu,
            "MultiHeadAttention": self._infer_MultiHeadAttention,
            "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
            "FastGelu": self._infer_FastGelu,
            "Gelu": self._infer_Gelu,
            "GemmFastGelu": self._infer_GemmFastGelu,
            "LayerNormalization": self._infer_LayerNormalization,
            "LongformerAttention": self._infer_LongformerAttention,
            "PythonOp": self._infer_PythonOp,
            "SkipLayerNormalization": self._infer_SkipLayerNormalization,
            "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
            "GroupNorm": self._infer_GroupNorm,
            "BiasSplitGelu": self._infer_BiasSplitGelu,
            "NhwcConv": self._infer_NhwcConv,
        }
        self.aten_op_dispatcher_ = {
            "embedding": self._infer_Gather,
            "bitwise_or": self._infer_aten_bitwise_or,
            "diagonal": self._infer_aten_diagonal,
            "max_pool2d_with_indices": self._infer_aten_pool2d,
            "max": self._infer_aten_minmax,
            "min": self._infer_aten_minmax,
            "multinomial": self._infer_aten_multinomial,
            "unfold": self._infer_aten_unfold,
            "argmax": self._infer_aten_argmax,
            "avg_pool2d": self._infer_aten_pool2d,
            "_adaptive_avg_pool2d": self._infer_aten_pool2d,
            "numpy_T": self._infer_Transpose,
            "native_group_norm": self._infer_aten_group_norm,
            "upsample_nearest1d": self._infer_aten_upsample_nearest,
            "upsample_nearest2d": self._infer_aten_upsample_nearest,
            "upsample_nearest3d": self._infer_aten_upsample_nearest,
        }
        self.run_ = True
        self.suggested_merge_ = {}
        self.symbolic_dims_ = {}
        self.input_symbols_ = {}
        self.auto_merge_ = auto_merge
        self.guess_output_rank_ = guess_output_rank
        self.verbose_ = verbose
        self.int_max_ = int_max
        self.subgraph_id_ = 0
        self.prefix_ = prefix
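
    # Note: dispatcher_ maps ONNX (and contrib) op types to the per-op handlers defined
    # below, and aten_op_dispatcher_ covers ATen operators (e.g. from PyTorch export);
    # these handlers refine or replace the per-node result of standard ONNX shape
    # inference performed in _onnx_infer_single_node.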

    def _add_suggested_merge(self, symbols, apply=False):
        assert all([(type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols])
        symbols = set(symbols)
        for k, v in self.suggested_merge_.items():
            if k in symbols:
                symbols.remove(k)
                symbols.add(v)
        map_to = None
        # if there is literal, map to it first
        for s in symbols:
            if is_literal(s):
                map_to = s
                break
        # when no literals, map to input symbolic dims, then existing symbolic dims
        if map_to is None:
            for s in symbols:
                if s in self.input_symbols_:
                    map_to = s
                    break
        if map_to is None:
            for s in symbols:
                if type(self.symbolic_dims_[s]) == sympy.Symbol:
                    map_to = s
                    break
        # when nothing to map to, use the shorter one
        if map_to is None:
            if self.verbose_ > 0:
                logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols)))
            symbols_list = list(symbols)
            lens = [len(s) for s in symbols_list]
            map_to = symbols_list[lens.index(min(lens))]
            symbols.remove(map_to)
        for s in symbols:
            if s == map_to:
                continue
            if is_literal(map_to) and is_literal(s):
                assert int(map_to) == int(s)
            self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to
            for k, v in self.suggested_merge_.items():
                if v == s:
                    self.suggested_merge_[k] = map_to
        if apply and self.auto_merge_:
            self._apply_suggested_merge()

    def _apply_suggested_merge(self, graph_input_only=False):
        if not self.suggested_merge_:
            return
        for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)):
            for d in i.type.tensor_type.shape.dim:
                if d.dim_param in self.suggested_merge_:
                    v = self.suggested_merge_[d.dim_param]
                    if is_literal(v):
                        d.dim_value = int(v)
                    else:
                        d.dim_param = v

    def _preprocess(self, in_mp):
        self.out_mp_ = onnx.ModelProto()
        self.out_mp_.CopyFrom(in_mp)
        self.graph_inputs_ = dict([(i.name, i) for i in list(self.out_mp_.graph.input)])
        self.initializers_ = dict([(i.name, i) for i in self.out_mp_.graph.initializer])
        self.known_vi_ = dict([(i.name, i) for i in list(self.out_mp_.graph.input)])
        self.known_vi_.update(
            dict(
                [
                    (
                        i.name,
                        helper.make_tensor_value_info(i.name, i.data_type, list(i.dims)),
                    )
                    for i in self.out_mp_.graph.initializer
                ]
            )
        )

    def _merge_symbols(self, dims):
        if not all([type(d) == str for d in dims]):
            if self.auto_merge_:
                unique_dims = list(set(dims))
                is_int = [is_literal(d) for d in unique_dims]
                assert sum(is_int) <= 1  # if there are more than 1 unique ints, something is wrong
                if sum(is_int) == 1:
                    int_dim = is_int.index(1)
                    if self.verbose_ > 0:
                        logger.debug(
                            "dim {} has been merged with value {}".format(
                                unique_dims[:int_dim] + unique_dims[int_dim + 1 :],
                                unique_dims[int_dim],
                            )
                        )
                    self._check_merged_dims(unique_dims, allow_broadcast=False)
                    return unique_dims[int_dim]
                else:
                    if self.verbose_ > 0:
                        logger.debug("dim {} has been merged with dim {}".format(unique_dims[1:], unique_dims[0]))
                    return dims[0]
            else:
                return None
        if all([d == dims[0] for d in dims]):
            return dims[0]
        merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims]
        if all([d == merged[0] for d in merged]):
            assert merged[0] in self.symbolic_dims_
            return merged[0]
        else:
            return None

    # broadcast from right to left, and merge symbolic dims if needed
    def _broadcast_shapes(self, shape1, shape2):
        new_shape = []
        rank1 = len(shape1)
        rank2 = len(shape2)
        new_rank = max(rank1, rank2)
        for i in range(new_rank):
            dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
            dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
            if dim1 == 1 or dim1 == dim2:
                new_dim = dim2
            elif dim2 == 1:
                new_dim = dim1
            else:
                new_dim = self._merge_symbols([dim1, dim2])
                if not new_dim:
                    # warning about unsupported broadcast when not auto merge
                    # note that auto merge has the risk of incorrectly merging symbols while one of them is 1
                    # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
                    if self.auto_merge_:
                        self._add_suggested_merge([dim1, dim2], apply=True)
                    else:
                        logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2))
            new_shape = [new_dim] + new_shape
        return new_shape
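
    # Illustrative behavior of the merge above: broadcasting ['b', 1, 4] with [3, 4]
    # yields ['b', 3, 4]; broadcasting ['a', 4] with ['b', 4] cannot be resolved
    # statically, so with auto_merge_ the symbols 'a' and 'b' are merged via
    # _add_suggested_merge, otherwise an "unsupported broadcast" warning is logged.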

    def _get_shape(self, node, idx):
        name = node.input[idx]
        if name in self.known_vi_:
            vi = self.known_vi_[name]
            return get_shape_from_value_info(vi)
        else:
            assert name in self.initializers_
            return list(self.initializers_[name].dims)

    def _get_shape_rank(self, node, idx):
        return len(self._get_shape(node, idx))

    def _get_sympy_shape(self, node, idx):
        sympy_shape = []
        for d in self._get_shape(node, idx):
            if type(d) == str:
                sympy_shape.append(
                    self.symbolic_dims_[d]
                    if d in self.symbolic_dims_
                    else sympy.Symbol(d, integer=True, nonnegative=True)
                )
            else:
                assert d is not None
                sympy_shape.append(d)
        return sympy_shape

    def _get_value(self, node, idx):
        name = node.input[idx]
        assert name in self.sympy_data_ or name in self.initializers_
        return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name])

    def _try_get_value(self, node, idx):
        if idx >= len(node.input):
            return None
        name = node.input[idx]
        if name in self.sympy_data_ or name in self.initializers_:
            return self._get_value(node, idx)
        return None

    def _update_computed_dims(self, new_sympy_shape):
        for i, new_dim in enumerate(new_sympy_shape):
            if not is_literal(new_dim) and not type(new_dim) == str:
                str_dim = str(new_dim)
                if str_dim in self.suggested_merge_:
                    if is_literal(self.suggested_merge_[str_dim]):
                        continue  # no need to create dim for literals
                    new_sympy_shape[i] = self.symbolic_dims_[self.suggested_merge_[str_dim]]
                else:
                    # add new_dim if it's a computational expression
                    if not str(new_dim) in self.symbolic_dims_:
                        self.symbolic_dims_[str(new_dim)] = new_dim

    def _onnx_infer_single_node(self, node):
        # skip onnx shape inference for some ops, as they are handled in _infer_*
        skip_infer = node.op_type in [
            "If",
            "Loop",
            "Scan",
            "SplitToSequence",
            "ZipMap",  # contrib ops
            "Attention",
            "BiasGelu",
            "EmbedLayerNormalization",
            "FastGelu",
            "Gelu",
            "GemmFastGelu",
            "LayerNormalization",
            "LongformerAttention",
            "SkipLayerNormalization",
            "PythonOp",
            "MultiHeadAttention",
            "GroupNorm",
            "BiasSplitGelu",
            "NhwcConv",
        ]
        if not skip_infer:
            # Only pass initializers that satisfy the following conditions:
            # (1) The operator needs the value of some input for shape inference.
            #     For example, Unsqueeze in opset 13 uses the axes input to calculate the output shape.
            # (2) The opset version is >= 9. In older versions, initializers are required in graph input by the onnx spec.
            # (3) The initializer is not in graph input. This means the node input is "constant" in inference.
            initializers = []
            if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]:
                initializers = [
                    self.initializers_[name]
                    for name in node.input
                    if (name in self.initializers_ and name not in self.graph_inputs_)
                ]
            # run single node inference with self.known_vi_ shapes
            tmp_graph = helper.make_graph(
                [node],
                "tmp",
                [self.known_vi_[i] for i in node.input if i],
                [make_named_value_info(i) for i in node.output],
                initializers,
            )
            self.tmp_mp_.graph.CopyFrom(tmp_graph)
            self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
        for i_o in range(len(node.output)):
            o = node.output[i_o]
            vi = self.out_mp_.graph.value_info.add()
            if not skip_infer:
                vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
            else:
                vi.name = o
            self.known_vi_[o] = vi

    def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True):
        if self.verbose_ > 2:
            logger.debug(
                "Inferencing subgraph of node {} with output({}...): {}".format(node.name, node.output[0], node.op_type)
            )
        # node inputs are not passed directly to the subgraph
        # it's up to the node dispatcher to prepare subgraph input
        # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
        # besides, inputs in subgraph could shadow implicit inputs
        subgraph_inputs = set([i.name for i in list(subgraph.initializer) + list(subgraph.input)])
        subgraph_implicit_input = set([name for name in self.known_vi_.keys() if not name in subgraph_inputs])
        tmp_graph = helper.make_graph(
            list(subgraph.node),
            "tmp",
            list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input],
            [make_named_value_info(i.name) for i in subgraph.output],
        )
        tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input])
        tmp_graph.initializer.extend(subgraph.initializer)
        self.tmp_mp_.graph.CopyFrom(tmp_graph)
        symbolic_shape_inference = SymbolicShapeInference(
            self.int_max_,
            self.auto_merge_,
            self.guess_output_rank_,
            self.verbose_,
            prefix=self.prefix_ + "_" + str(self.subgraph_id_),
        )
        if inc_subgraph_id:
            self.subgraph_id_ += 1
        all_shapes_inferred = False
        symbolic_shape_inference._preprocess(self.tmp_mp_)
        symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
        while symbolic_shape_inference.run_:
            all_shapes_inferred = symbolic_shape_inference._infer_impl(self.sympy_data_.copy())
        symbolic_shape_inference._update_output_from_vi()
        if use_node_input:
            # if subgraph uses node input, it needs to update to merged dims
            subgraph.ClearField("input")
            subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)])
        subgraph.ClearField("output")
        subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
        subgraph.ClearField("value_info")
        subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info)
        subgraph.ClearField("node")
        subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
        # for new symbolic dims from subgraph output, add to main graph symbolic dims
        subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output]
        subgraph_new_symbolic_dims = set(
            [d for s in subgraph_shapes if s for d in s if type(d) == str and not d in self.symbolic_dims_]
        )
        new_dims = {}
        for d in subgraph_new_symbolic_dims:
            assert d in symbolic_shape_inference.symbolic_dims_
            new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
        self.symbolic_dims_.update(new_dims)
        return symbolic_shape_inference

    def _get_int_values(self, node, broadcast=False):
        values = [self._try_get_value(node, i) for i in range(len(node.input))]
        if all([v is not None for v in values]):
            # some shape compute is in floating point, cast to int for sympy
            for i, v in enumerate(values):
                if type(v) != np.ndarray:
                    continue
                if len(v.shape) > 1:
                    new_v = None  # ignore value for rank > 1
                elif len(v.shape) == 0:
                    new_v = int(v.item())
                else:
                    assert len(v.shape) == 1
                    new_v = [int(vv) for vv in v]
                values[i] = new_v
        values_len = [len(v) if type(v) == list else 0 for v in values]
        max_len = max(values_len)
        if max_len >= 1 and broadcast:
            # broadcast
            for i, v in enumerate(values):
                if v is None:
                    continue  # don't broadcast if value is unknown
                if type(v) == list:
                    if len(v) < max_len:
                        values[i] = v * max_len
                    else:
                        assert len(v) == max_len
                else:
                    values[i] = [v] * max_len
        return values

    def _compute_on_sympy_data(self, node, op_func):
        assert len(node.output) == 1
        values = self._get_int_values(node, broadcast=True)
        if all([v is not None for v in values]):
            is_list = [type(v) == list for v in values]
            as_list = any(is_list)
            if as_list:
                self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)]
            else:
                self.sympy_data_[node.output[0]] = op_func(values)

    def _pass_on_sympy_data(self, node):
        assert len(node.input) == 1 or node.op_type in [
            "Reshape",
            "Unsqueeze",
            "Squeeze",
        ]
        self._compute_on_sympy_data(node, lambda x: x[0])

    def _pass_on_shape_and_type(self, node):
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type),
                self._get_shape(node, 0),
            )
        )

    def _new_symbolic_dim(self, prefix, dim):
        new_dim = "{}_d{}".format(prefix, dim)
        if new_dim in self.suggested_merge_:
            v = self.suggested_merge_[new_dim]
            new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v
        else:
            new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True)
            self.symbolic_dims_[new_dim] = new_symbolic_dim
        return new_symbolic_dim

    def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
        return self._new_symbolic_dim(
            "{}{}_{}_o{}_".format(
                node.op_type,
                self.prefix_,
                list(self.out_mp_.graph.node).index(node),
                out_idx,
            ),
            dim,
        )

    def _new_symbolic_shape(self, rank, node, out_idx=0):
        return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]

    def _compute_conv_pool_shape(self, node, channels_last=False):
        sympy_shape = self._get_sympy_shape(node, 0)
        if len(node.input) > 1:
            W_shape = self._get_sympy_shape(node, 1)
            rank = len(W_shape) - 2  # number of spatial axes
            kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
            sympy_shape[3 if channels_last else 1] = W_shape[0]
        else:
            W_shape = None
            kernel_shape = get_attribute(node, "kernel_shape")
            rank = len(kernel_shape)
        assert len(sympy_shape) == rank + 2
        # only need to symbolic shape inference if input has symbolic dims in spatial axes
        spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
        is_symbolic_dims = [not is_literal(i) for i in spatial_shape]
        if not any(is_symbolic_dims):
            shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
            if len(shape) > 0:
                assert len(sympy_shape) == len(shape)
                if channels_last:
                    sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
                else:
                    sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
                return sympy_shape
        dilations = get_attribute(node, "dilations", [1] * rank)
        strides = get_attribute(node, "strides", [1] * rank)
        effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]
        pads = get_attribute(node, "pads")
        if pads is None:
            pads = [0] * (2 * rank)
            auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
            if auto_pad != "VALID" and auto_pad != "NOTSET":
                try:
                    residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)]
                    total_pads = [
                        max(0, (k - s) if r == 0 else (k - r))
                        for k, s, r in zip(effective_kernel_shape, strides, residual)
                    ]
                except TypeError:  # sympy may throw TypeError: cannot determine truth value of Relational
                    total_pads = [
                        max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides)
                    ]  # assuming no residual if sympy throws error
            elif auto_pad == "VALID":
                total_pads = []
            else:
                total_pads = [0] * rank
        else:
            assert len(pads) == 2 * rank
            total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]
        ceil_mode = get_attribute(node, "ceil_mode", 0)
        for i in range(rank):
            effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)]
            if len(total_pads) > 0:
                effective_input_size = effective_input_size + total_pads[i]
            if ceil_mode:
                strided_kernel_positions = sympy.ceiling(
                    (effective_input_size - effective_kernel_shape[i]) / strides[i]
                )
            else:
                strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
            sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1
        return sympy_shape
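
    # The spatial-size update above follows the standard convolution/pooling formula:
    #   effective_kernel = (kernel - 1) * dilation + 1
    #   out = floor_or_ceil((in + total_pad - effective_kernel) / stride) + 1
    # where sympy.ceiling is used when ceil_mode is set and integer floor division otherwise.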

    def _check_merged_dims(self, dims, allow_broadcast=True):
        if allow_broadcast:
            dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
        if not all([d == dims[0] for d in dims]):
            self._add_suggested_merge(dims, apply=True)

    def _compute_matmul_shape(self, node, output_dtype=None):
        lhs_shape = self._get_shape(node, 0)
        rhs_shape = self._get_shape(node, 1)
        lhs_rank = len(lhs_shape)
        rhs_rank = len(rhs_shape)
        lhs_reduce_dim = 0
        rhs_reduce_dim = 0
        assert lhs_rank > 0 and rhs_rank > 0
        if lhs_rank == 1 and rhs_rank == 1:
            new_shape = []
        elif lhs_rank == 1:
            rhs_reduce_dim = -2
            new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]]
        elif rhs_rank == 1:
            lhs_reduce_dim = -1
            new_shape = lhs_shape[:lhs_reduce_dim]
        else:
            lhs_reduce_dim = -1
            rhs_reduce_dim = -2
            new_shape = self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]]
        # merge reduce dim
        self._check_merged_dims(
            [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
            allow_broadcast=False,
        )
        if output_dtype is None:
            # infer output_dtype from input type when not specified
            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
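
    # Illustrative shapes for the cases above: ['b', 's', 64] @ [64, 128] -> ['b', 's', 128];
    # a 1-D lhs [64] @ [64, 128] -> [128]; two 1-D operands produce a scalar (shape []).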

    def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
        """
        update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches
        """
        dst_tensor_type = (
            dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type
        )
        src_tensor_type = (
            src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
        )
        if dst_tensor_type.elem_type != src_tensor_type.elem_type:
            node_id = node.name if node.name else node.op_type
            raise ValueError(
                f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
                f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
                f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
            )
        if dst_tensor_type.HasField("shape"):
            for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
                if ds[0] != ds[1]:
                    # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type
                    # for sequence_type, clear the dimension
                    new_dim = onnx.TensorShapeProto.Dimension()
                    if not is_sequence(dst_type):
                        new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di))
                    dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
        else:
            dst_tensor_type.CopyFrom(src_tensor_type)

    def _infer_ArrayFeatureExtractor(self, node):
        data_shape = self._get_shape(node, 0)
        indices_shape = self._get_shape(node, 1)
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                data_shape[:-1] + indices_shape,
            )
        )

    def _infer_symbolic_compute_ops(self, node):
        funcs = {
            "Add": lambda l: l[0] + l[1],
            "Div": lambda l: l[0] // l[1],  # integer div in sympy
            "Equal": lambda l: l[0] == l[1],
            "Floor": lambda l: sympy.floor(l[0]),
            "Max": lambda l: l[1]
            if is_literal(l[0]) and int(l[0]) < -self.int_max_
            else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
            "Min": lambda l: l[1]
            if is_literal(l[0]) and int(l[0]) > self.int_max_
            else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
            "Mul": lambda l: l[0] * l[1],
            "Sub": lambda l: l[0] - l[1],
            "Where": lambda l: l[1] if l[0] else l[2],
            "Neg": lambda l: -l[0],
        }
        assert node.op_type in funcs
        self._compute_on_sympy_data(node, funcs[node.op_type])

    def _infer_Cast(self, node):
        self._pass_on_sympy_data(node)

    def _infer_CategoryMapper(self, node):
        input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        if input_type == onnx.TensorProto.STRING:
            output_type = onnx.TensorProto.INT64
        else:
            output_type = onnx.TensorProto.STRING
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0)))

    def _infer_Compress(self, node):
        input_shape = self._get_shape(node, 0)
        # create a new symbolic dimension for Compress output
        compress_len = str(self._new_symbolic_dim_from_output(node))
        axis = get_attribute(node, "axis")
        if axis is None:
            # when axis is not specified, input is flattened before compress so output is 1D
            output_shape = [compress_len]
        else:
            output_shape = input_shape
            output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                output_shape,
            )
        )

    def _infer_Concat(self, node):
        if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]):
            values = self._get_int_values(node)
            if all([v is not None for v in values]):
                assert 0 == get_attribute(node, "axis")
                self.sympy_data_[node.output[0]] = []
                for i in range(len(node.input)):
                    value = values[i]
                    if type(value) == list:
                        self.sympy_data_[node.output[0]].extend(value)
                    else:
                        self.sympy_data_[node.output[0]].append(value)
        sympy_shape = self._get_sympy_shape(node, 0)
        axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))
        for i_idx in range(1, len(node.input)):
            input_shape = self._get_sympy_shape(node, i_idx)
            if input_shape:
                sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
        self._update_computed_dims(sympy_shape)
        # merge symbolic dims for non-concat axes
        for d in range(len(sympy_shape)):
            if d == axis:
                continue
            dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)]
            if all([d == dims[0] for d in dims]):
                continue
            merged = self._merge_symbols(dims)
            if type(merged) == str:
                sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
            else:
                sympy_shape[d] = merged
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(sympy_shape),
            )
        )

    def _infer_ConcatFromSequence(self, node):
        seq_shape = self._get_shape(node, 0)
        new_axis = 1 if get_attribute(node, "new_axis") else 0
        axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis)
        concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))
        new_shape = seq_shape
        if new_axis:
            new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:]
        else:
            new_shape[axis] = concat_dim
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type,
                new_shape,
            )
        )

    def _infer_Constant(self, node):
        t = get_attribute(node, "value")
        self.sympy_data_[node.output[0]] = numpy_helper.to_array(t)

    def _infer_ConstantOfShape(self, node):
        sympy_shape = self._get_int_values(node)[0]
        vi = self.known_vi_[node.output[0]]
        if sympy_shape is not None:
            if type(sympy_shape) != list:
                sympy_shape = [sympy_shape]
            self._update_computed_dims(sympy_shape)
            # update sympy data if output type is int, and shape is known
            if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]):
                self.sympy_data_[node.output[0]] = np.ones(
                    [int(x) for x in sympy_shape], dtype=np.int64
                ) * numpy_helper.to_array(get_attribute(node, "value", 0))
        else:
            # create new dynamic shape
            # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length
            sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node)
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(sympy_shape),
            )
        )

    def _infer_Conv(self, node):
        sympy_shape = self._compute_conv_pool_shape(node)
        self._update_computed_dims(sympy_shape)
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(sympy_shape),
            )
        )

    def _infer_NhwcConv(self, node):
        sympy_shape = self._compute_conv_pool_shape(node, channels_last=True)
        self._update_computed_dims(sympy_shape)
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(sympy_shape),
            )
        )

    def _infer_Einsum(self, node):
        # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
        equation = get_attribute(node, "equation")
        equation = equation.replace(b" ", b"")
        mid_index = equation.find(b"->")
        left_equation = equation[:mid_index] if mid_index != -1 else equation
        num_operands = 0
        num_ellipsis = 0
        num_ellipsis_indices = 0
        letter_to_dim = {}
        terms = left_equation.split(b",")
        for term in terms:
            ellipsis_index = term.find(b"...")
            shape = self._get_shape(node, num_operands)
            rank = len(shape)
            if ellipsis_index != -1:
                if num_ellipsis == 0:
                    num_ellipsis_indices = rank - len(term) + 3
                num_ellipsis = num_ellipsis + 1
            for i in range(1, rank + 1):
                letter = term[-i]
                if letter != 46:  # letter != b'.'
                    dim = shape[-i]
                    if letter not in letter_to_dim.keys():
                        letter_to_dim[letter] = dim
                    elif type(dim) != sympy.Symbol:
                        letter_to_dim[letter] = dim
            num_operands = num_operands + 1
        new_sympy_shape = []
        from collections import OrderedDict
        num_letter_occurrences = OrderedDict()
        if mid_index != -1:
            right_equation = equation[mid_index + 2 :]
            right_ellipsis_index = right_equation.find(b"...")
            if right_ellipsis_index != -1:
                for i in range(num_ellipsis_indices):
                    new_sympy_shape.append(shape[i])
            for c in right_equation:
                if c != 46:  # c != b'.'
                    new_sympy_shape.append(letter_to_dim[c])
        else:
            for i in range(num_ellipsis_indices):
                new_sympy_shape.append(shape[i])
            for c in left_equation:
                if c != 44 and c != 46:  # c != b',' and c != b'.'
                    if c in num_letter_occurrences:
                        num_letter_occurrences[c] = num_letter_occurrences[c] + 1
                    else:
                        num_letter_occurrences[c] = 1
            for key, value in num_letter_occurrences.items():
                if value == 1:
                    new_sympy_shape.append(letter_to_dim[key])
        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape))
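
    # Illustrative example for the equation parsing above: for equation b"bij,bjk->bik"
    # with inputs of shape ['b', 's', 64] and ['b', 64, 'h'], the letters map to
    # {b: 'b', i: 's', j: 64, k: 'h'} and the inferred output shape is ['b', 's', 'h'].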

    def _infer_Expand(self, node):
        expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True)
        if expand_to_shape is not None:
            # new_shape's dim can come from shape value
            self._update_computed_dims(expand_to_shape)
            shape = self._get_shape(node, 0)
            new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape))
            vi = self.known_vi_[node.output[0]]
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    new_shape,
                )
            )

    def _infer_Gather(self, node):
        data_shape = self._get_shape(node, 0)
        axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
        indices_shape = self._get_shape(node, 1)
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
            )
        )
        # for 1D input, do some sympy compute
        if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and 0 == get_attribute(node, "axis", 0):
            idx = self._try_get_value(node, 1)
            if idx is not None:
                data = self.sympy_data_[node.input[0]]
                if type(data) == list:
                    if type(idx) == np.ndarray and len(idx.shape) == 1:
                        self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx]
                    else:
                        self.sympy_data_[node.output[0]] = data[int(idx)]
                else:
                    assert idx == 0 or idx == -1
                    self.sympy_data_[node.output[0]] = data

    def _infer_GatherElements(self, node):
        indices_shape = self._get_shape(node, 1)
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                indices_shape,
            )
        )

    def _infer_GatherND(self, node):
        data_shape = self._get_shape(node, 0)
        data_rank = len(data_shape)
        indices_shape = self._get_shape(node, 1)
        indices_rank = len(indices_shape)
        last_index_dimension = indices_shape[-1]
        assert is_literal(last_index_dimension) and last_index_dimension <= data_rank
        new_shape = indices_shape[:-1] + data_shape[last_index_dimension:]
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                new_shape,
            )
        )

    def _infer_If(self, node):
        # special case for constant condition, in case there are mismatching shapes from the non-executed branch
        subgraphs = [
            get_attribute(node, "then_branch"),
            get_attribute(node, "else_branch"),
        ]
        cond = self._try_get_value(node, 0)
        if cond is not None:
            if as_scalar(cond) > 0:
                subgraphs[1].CopyFrom(subgraphs[0])
            else:
                subgraphs[0].CopyFrom(subgraphs[1])
        for i_sub, subgraph in enumerate(subgraphs):
            subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False)
            for i_out in range(len(node.output)):
                vi = self.known_vi_[node.output[i_out]]
                if i_sub == 0:
                    vi.CopyFrom(subgraph.output[i_out])
                    vi.name = node.output[i_out]
                else:
                    self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type)
                # pass on sympy data from subgraph, if cond is constant
                if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else 1):
                    if subgraph.output[i_out].name in subgraph_infer.sympy_data_:
                        self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name]

    def _infer_Loop(self, node):
        subgraph = get_attribute(node, "body")
        assert len(subgraph.input) == len(node.input)
        num_loop_carried = len(node.input) - 2  # minus the length and initial loop condition
        # when sequence_type is used as loop carried input
        # needs to run subgraph infer twice if the tensor shape in sequence contains None
        for i, si in enumerate(subgraph.input):
            si_name = si.name
            si.CopyFrom(self.known_vi_[node.input[i]])
            si.name = si_name
        self._onnx_infer_subgraph(node, subgraph)
        # check subgraph input/output for shape changes in loop carried variables
        # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a)
        # for sequence_type, propagate from output to input
        need_second_infer = False
        for i_out in range(1, num_loop_carried + 1):
            so = subgraph.output[i_out]
            so_shape = get_shape_from_value_info(so)
            if is_sequence(so.type):
                if so_shape and None in so_shape:
                    # copy shape from output to input
                    # note that loop input is [loop_len, cond, input_0, input_1, ...]
                    # while loop output is [cond, output_0, output_1, ...]
                    subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type)
                    need_second_infer = True
            else:
                si = subgraph.input[i_out + 1]
                si_shape = get_shape_from_value_info(si)
                for di, dims in enumerate(zip(si_shape, so_shape)):
                    if dims[0] != dims[1]:
                        new_dim = onnx.TensorShapeProto.Dimension()
                        new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di))
                        si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
                        so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
                        need_second_infer = True
        if need_second_infer:
            if self.verbose_ > 2:
                logger.debug(
                    "Rerun Loop: {}({}...), because of sequence in loop carried variables".format(
                        node.name, node.output[0]
                    )
                )
            self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
        # create a new symbolic dimension for iteration dependent dimension
        loop_iter_dim = str(self._new_symbolic_dim_from_output(node))
        for i in range(len(node.output)):
            vi = self.known_vi_[node.output[i]]
            vi.CopyFrom(subgraph.output[i + 1])  # first subgraph output is condition, not in node output
            if i >= num_loop_carried:
                assert not is_sequence(vi.type)  # TODO: handle loop accumulation in sequence_type
                subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim
                vi.type.tensor_type.shape.ClearField("dim")
                vi_dim = vi.type.tensor_type.shape.dim
                vi_dim.add().dim_param = loop_iter_dim
                vi_dim.extend(list(subgraph_vi_dim))
            vi.name = node.output[i]
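
    # As handled above, Loop node outputs are the loop-carried values followed by scan
    # outputs; each scan output gets a fresh symbolic leading dimension (loop_iter_dim)
    # prepended, since the iteration count is generally not known statically.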

    def _infer_MatMul(self, node):
        self._compute_matmul_shape(node)

    def _infer_MatMulInteger(self, node):
        self._compute_matmul_shape(node, onnx.TensorProto.INT32)

    def _infer_NonMaxSuppression(self, node):
        selected = str(self._new_symbolic_dim_from_output(node))
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3]))

    def _infer_NonZero(self, node):
        input_rank = self._get_shape_rank(node, 0)
        # create a new symbolic dimension for NonZero output
        nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1))
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))

    def _infer_OneHot(self, node):
        sympy_shape = self._get_sympy_shape(node, 0)
        depth = self._try_get_value(node, 1)
        axis = get_attribute(node, "axis", -1)
        axis = handle_negative_axis(axis, len(sympy_shape) + 1)
        new_shape = get_shape_from_sympy_shape(
            sympy_shape[:axis]
            + [self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth]
            + sympy_shape[axis:]
        )
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[2]].type.tensor_type.elem_type,
                new_shape,
            )
        )

    def _infer_Pad(self, node):
        if get_opset(self.out_mp_) <= 10:
            pads = get_attribute(node, "pads")
        else:
            pads = self._try_get_value(node, 1)
        sympy_shape = self._get_sympy_shape(node, 0)
        rank = len(sympy_shape)
        if pads is not None:
            assert len(pads) == 2 * rank
            new_sympy_shape = [
                d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:])
            ]
            self._update_computed_dims(new_sympy_shape)
        else:
            # dynamic pads, create new symbolic dimensions
            new_sympy_shape = self._new_symbolic_shape(rank, node)
        output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))
        )

    def _infer_Pool(self, node):
        sympy_shape = self._compute_conv_pool_shape(node)
        self._update_computed_dims(sympy_shape)
        for o in node.output:
            if not o:
                continue
            vi = self.known_vi_[o]
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    o,
                    vi.type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(sympy_shape),
                )
            )

    def _infer_aten_bitwise_or(self, node):
        shape0 = self._get_shape(node, 0)
        shape1 = self._get_shape(node, 1)
        new_shape = self._broadcast_shapes(shape0, shape1)
        t0 = self.known_vi_[node.input[0]]
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape))

    def _infer_aten_diagonal(self, node):
        sympy_shape = self._get_sympy_shape(node, 0)
        rank = len(sympy_shape)
        offset = self._try_get_value(node, 1)
        dim1 = self._try_get_value(node, 2)
        dim2 = self._try_get_value(node, 3)
        assert offset is not None and dim1 is not None and dim2 is not None
        dim1 = handle_negative_axis(dim1, rank)
        dim2 = handle_negative_axis(dim2, rank)
        new_shape = []
        for dim, val in enumerate(sympy_shape):
            if dim not in [dim1, dim2]:
                new_shape.append(val)
        shape1 = sympy_shape[dim1]
        shape2 = sympy_shape[dim2]
        if offset >= 0:
            diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset))
        else:
            diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2))
        new_shape.append(diag_shape)
        if node.output[0]:
            vi = self.known_vi_[node.output[0]]
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(new_shape),
                )
            )

    def _infer_aten_multinomial(self, node):
        sympy_shape = self._get_sympy_shape(node, 0)
        rank = len(sympy_shape)
        assert rank in [1, 2]
        num_samples = self._try_get_value(node, 1)
        di = rank - 1
        last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di))
        output_shape = sympy_shape[:-1] + [last_dim]
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                onnx.TensorProto.INT64,
                get_shape_from_sympy_shape(output_shape),
            )
        )

    def _infer_aten_pool2d(self, node):
        sympy_shape = self._get_sympy_shape(node, 0)
        assert len(sympy_shape) == 4
        sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]]
        self._update_computed_dims(sympy_shape)
        for i, o in enumerate(node.output):
            if not o:
                continue
            vi = self.known_vi_[o]
            elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[node.input[0]].type.tensor_type.elem_type
            vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape)))

    def _infer_aten_minmax(self, node):
        vi = self.known_vi_[node.output[0]]
        if len(node.input) == 1:
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, []
                )
            )
        else:
            assert len(node.input) == 3
            keepdim = self._try_get_value(node, 2)
            assert keepdim is not None  # can only handle known keepdim case.
            dim = self._try_get_value(node, 1)
            if dim is None:
                rank = self._get_shape_rank(node, 0)
                output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node)
            else:
                shape = self._get_sympy_shape(node, 0)
                dim = handle_negative_axis(dim, len(shape))
                output_shape = shape[:dim]
                if keepdim:
                    output_shape += [1]
                output_shape += shape[dim + 1 :]
            output_shape = get_shape_from_sympy_shape(output_shape)
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, output_shape
                )
            )
            vi1 = self.known_vi_[node.output[1]]
            vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape))

    def _infer_aten_unfold(self, node):
        sympy_shape = self._get_sympy_shape(node, 0)
        dimension = self._try_get_value(node, 1)
        size = self._try_get_value(node, 2)
        step = self._try_get_value(node, 3)
        if dimension is not None and size is not None and step is not None:
            assert dimension < len(sympy_shape)
            sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1
            sympy_shape.append(size)
        else:
            rank = len(sympy_shape)
            sympy_shape = self._new_symbolic_shape(rank + 1, node)
        self._update_computed_dims(sympy_shape)
        if node.output[0]:
            vi = self.known_vi_[node.output[0]]
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(sympy_shape),
                )
            )

    def _infer_aten_argmax(self, node):
        new_shape = None
        if node.input[1] == "":
            # The argmax of the flattened input is returned.
            new_shape = []
        else:
            dim = self._try_get_value(node, 1)
            keepdim = self._try_get_value(node, 2)
            if keepdim is not None:
                sympy_shape = self._get_sympy_shape(node, 0)
                if dim is not None:
                    dim = handle_negative_axis(dim, len(sympy_shape))
                    if keepdim:
                        sympy_shape[dim] = 1
                    else:
                        del sympy_shape[dim]
                else:
                    rank = len(sympy_shape)
                    sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node)
                self._update_computed_dims(sympy_shape)
                new_shape = get_shape_from_sympy_shape(sympy_shape)
        if node.output[0] and new_shape is not None:
            vi = self.known_vi_[node.output[0]]
            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))

    def _infer_aten_group_norm(self, node):
        self._propagate_shape_and_type(node)
        input_shape = self._get_shape(node, 0)
        N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None
        group = self._try_get_value(node, 6)
        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        for i in [1, 2]:
            if node.output[i]:
                vi = self.known_vi_[node.output[i]]
                vi.CopyFrom(
                    helper.make_tensor_value_info(
                        node.output[i],
                        output_dtype,
                        [
                            N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)),
                            as_scalar(group)
                            if group is not None
                            else str(self._new_symbolic_dim_from_output(node, i, 1)),
                        ],
                    )
                )

    def _infer_aten_upsample_nearest(self, node):
        new_shape = None
        input_shape = self._get_shape(node, 0)
        if input_shape is not None:
            new_shape = input_shape[:2]
            output_size = self._try_get_value(node, 1)
            if output_size is not None:
                new_shape += [dim_size.item() for dim_size in output_size]
            else:
                rank = len(input_shape)
                new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
        if node.output[0] and new_shape is not None:
            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
            vi = self.known_vi_[node.output[0]]
            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))

    def _infer_BatchNormalization(self, node):
        self._propagate_shape_and_type(node)
        # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop
        for i in [1, 2, 3, 4]:
            if i < len(node.output) and node.output[i] != "":
                # all of these parameters have the same shape as the 1st input
                self._propagate_shape_and_type(node, input_index=1, output_index=i)

    def _infer_Range(self, node):
        vi = self.known_vi_[node.output[0]]
        input_data = self._get_int_values(node)
        if all([i is not None for i in input_data]):
            start = as_scalar(input_data[0])
            limit = as_scalar(input_data[1])
            delta = as_scalar(input_data[2])
            new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)]
        else:
            new_sympy_shape = [self._new_symbolic_dim_from_output(node)]
        self._update_computed_dims(new_sympy_shape)
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(new_sympy_shape),
            )
        )

    def _infer_ReduceSum(self, node):
        keep_dims = get_attribute(node, "keepdims", 1)
        if get_opset(self.out_mp_) >= 13 and len(node.input) > 1:
            # ReduceSum changes axes to input[1] in opset 13
            axes = self._try_get_value(node, 1)
            vi = self.known_vi_[node.output[0]]
            if axes is None:
                assert keep_dims  # can only handle keep_dims==True when axes is unknown, by generating new ranks
                vi.CopyFrom(
                    helper.make_tensor_value_info(
                        node.output[0],
                        self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                        get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)),
                    )
                )
            else:
                shape = self._get_shape(node, 0)
                output_shape = []
                axes = [handle_negative_axis(a, len(shape)) for a in axes]
                for i, d in enumerate(shape):
                    if i in axes:
                        if keep_dims:
                            output_shape.append(1)
                    else:
                        output_shape.append(d)
                vi.CopyFrom(
                    helper.make_tensor_value_info(
                        node.output[0],
                        self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                        output_shape,
                    )
                )

    def _infer_ReduceProd(self, node):
        axes = get_attribute(node, "axes")
        keep_dims = get_attribute(node, "keepdims", 1)
        if keep_dims == 0 and axes == [0]:
            data = self._get_int_values(node)[0]
            if data is not None:
                self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
  1358. def _infer_Reshape(self, node):
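# Descriptive note (added): when the target shape is a known constant, a 0 entry copies the input dim
# and a single -1 entry is inferred from the remaining size; otherwise a fresh symbolic shape of the
# known rank is emitted.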
  1359. shape_value = self._try_get_value(node, 1)
  1360. vi = self.known_vi_[node.output[0]]
  1361. if shape_value is None:
  1362. shape_shape = self._get_shape(node, 1)
  1363. assert len(shape_shape) == 1
  1364. shape_rank = shape_shape[0]
  1365. assert is_literal(shape_rank)
  1366. vi.CopyFrom(
  1367. helper.make_tensor_value_info(
  1368. node.output[0],
  1369. vi.type.tensor_type.elem_type,
  1370. get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)),
  1371. )
  1372. )
  1373. else:
  1374. input_sympy_shape = self._get_sympy_shape(node, 0)
  1375. total = int(1)
  1376. for d in input_sympy_shape:
  1377. total = total * d
  1378. new_sympy_shape = []
  1379. deferred_dim_idx = -1
  1380. non_deferred_size = int(1)
  1381. for i, d in enumerate(shape_value):
  1382. if type(d) == sympy.Symbol:
  1383. new_sympy_shape.append(d)
  1384. elif d == 0:
  1385. new_sympy_shape.append(input_sympy_shape[i])
  1386. non_deferred_size = non_deferred_size * input_sympy_shape[i]
  1387. else:
  1388. new_sympy_shape.append(d)
  1389. if d == -1:
  1390. deferred_dim_idx = i
  1391. elif d != 0:
  1392. non_deferred_size = non_deferred_size * d
  1393. assert new_sympy_shape.count(-1) < 2
  1394. if -1 in new_sympy_shape:
  1395. new_dim = total // non_deferred_size
  1396. new_sympy_shape[deferred_dim_idx] = new_dim
  1397. self._update_computed_dims(new_sympy_shape)
  1398. vi.CopyFrom(
  1399. helper.make_tensor_value_info(
  1400. node.output[0],
  1401. vi.type.tensor_type.elem_type,
  1402. get_shape_from_sympy_shape(new_sympy_shape),
  1403. )
  1404. )
  1405. self._pass_on_sympy_data(node)
  1406. def _infer_Resize(self, node):
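# Descriptive note (added): for opset <= 10 the output shape comes from the scales input (input 1);
# for opset 11+ it prefers sizes (input 3), then scales (input 2) with the optional roi input.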
  1407. vi = self.known_vi_[node.output[0]]
  1408. input_sympy_shape = self._get_sympy_shape(node, 0)
  1409. if get_opset(self.out_mp_) <= 10:
  1410. scales = self._try_get_value(node, 1)
  1411. if scales is not None:
  1412. new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)]
  1413. self._update_computed_dims(new_sympy_shape)
  1414. vi.CopyFrom(
  1415. helper.make_tensor_value_info(
  1416. node.output[0],
  1417. self.known_vi_[node.input[0]].type.tensor_type.elem_type,
  1418. get_shape_from_sympy_shape(new_sympy_shape),
  1419. )
  1420. )
  1421. else:
  1422. roi = self._try_get_value(node, 1)
  1423. scales = self._try_get_value(node, 2)
  1424. sizes = self._try_get_value(node, 3)
  1425. if sizes is not None:
  1426. new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes]
  1427. self._update_computed_dims(new_sympy_shape)
  1428. elif scales is not None:
  1429. rank = len(scales)
  1430. if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize":
  1431. assert len(roi) == 2 * rank
  1432. roi_start = list(roi)[:rank]
  1433. roi_end = list(roi)[rank:]
  1434. else:
  1435. roi_start = [0] * rank
  1436. roi_end = [1] * rank
  1437. scales = list(scales)
  1438. new_sympy_shape = [
  1439. sympy.simplify(sympy.floor(d * (end - start) * scale))
  1440. for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales)
  1441. ]
  1442. self._update_computed_dims(new_sympy_shape)
  1443. else:
  1444. new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
  1445. vi.CopyFrom(
  1446. helper.make_tensor_value_info(
  1447. node.output[0],
  1448. self.known_vi_[node.input[0]].type.tensor_type.elem_type,
  1449. get_shape_from_sympy_shape(new_sympy_shape),
  1450. )
  1451. )
  1452. def _infer_Scan(self, node):
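# Descriptive note (added): copy the node input shapes onto the subgraph inputs (dropping the scan
# axis for scan inputs), run inference on the subgraph, then insert the scan dim back into the
# scan outputs.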
  1453. subgraph = get_attribute(node, "body")
  1454. num_scan_inputs = get_attribute(node, "num_scan_inputs")
  1455. scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs)
  1456. num_scan_states = len(node.input) - num_scan_inputs
  1457. scan_input_axes = [
  1458. handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states))
  1459. for i, ax in enumerate(scan_input_axes)
  1460. ]
  1461. # We may have cases where the subgraph has optional inputs that appear in both subgraph's input and initializer,
  1462. # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs.
  1463. assert len(subgraph.input) >= len(node.input)
  1464. subgraph_inputs = subgraph.input[: len(node.input)]
  1465. for i, si in enumerate(subgraph_inputs):
  1466. subgraph_name = si.name
  1467. si.CopyFrom(self.known_vi_[node.input[i]])
  1468. if i >= num_scan_states:
  1469. scan_input_dim = si.type.tensor_type.shape.dim
  1470. scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]])
  1471. si.name = subgraph_name
  1472. self._onnx_infer_subgraph(node, subgraph)
  1473. num_scan_outputs = len(node.output) - num_scan_states
  1474. scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs)
  1475. scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
  1476. for i, o in enumerate(node.output):
  1477. vi = self.known_vi_[o]
  1478. if i >= num_scan_states:
  1479. shape = get_shape_from_type_proto(subgraph.output[i].type)
  1480. new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1)
  1481. shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:]
  1482. vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape))
  1483. else:
  1484. vi.CopyFrom(subgraph.output[i])
  1485. vi.name = o
  1486. def _infer_ScatterElements(self, node):
  1487. data_shape = self._get_shape(node, 0)
  1488. vi = self.known_vi_[node.output[0]]
  1489. vi.CopyFrom(
  1490. helper.make_tensor_value_info(
  1491. node.output[0],
  1492. self.known_vi_[node.input[0]].type.tensor_type.elem_type,
  1493. data_shape,
  1494. )
  1495. )
  1496. def _infer_SequenceAt(self, node):
  1497. # need to create new symbolic dimension if sequence shape has None:
  1498. seq_shape = self._get_shape(node, 0)
  1499. vi = self.known_vi_[node.output[0]]
  1500. if seq_shape is not None:
  1501. for di, d in enumerate(seq_shape):
  1502. if d is not None:
  1503. continue
  1504. new_dim = onnx.TensorShapeProto.Dimension()
  1505. new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, di))
  1506. vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
  1507. def _infer_SequenceInsert(self, node):
  1508. # workaround bug in onnx's shape inference
  1509. vi_seq = self.known_vi_[node.input[0]]
  1510. vi_tensor = self.known_vi_[node.input[1]]
  1511. vi_out_seq = self.known_vi_[node.output[0]]
  1512. vi_out_seq.CopyFrom(vi_seq)
  1513. vi_out_seq.name = node.output[0]
  1514. self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type)
  1515. def _infer_Shape(self, node):
  1516. self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)
  1517. def _infer_Size(self, node):
  1518. sympy_shape = self._get_sympy_shape(node, 0)
  1519. self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape)
  1520. self.known_vi_[node.output[0]].CopyFrom(
  1521. helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])
  1522. )
  1523. def _infer_Slice(self, node):
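# Descriptive note (added): starts/ends/steps may be symbolic; the helpers below try several sympy
# comparisons before giving up, and indices beyond int_max are treated as boundless.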
  1524. def less_equal(x, y):
  1525. try:
  1526. return bool(x <= y)
  1527. except TypeError:
  1528. pass
  1529. try:
  1530. return bool(y >= x)
  1531. except TypeError:
  1532. pass
  1533. try:
  1534. return bool(-x >= -y)
  1535. except TypeError:
  1536. pass
  1537. try:
  1538. return bool(-y <= -x)
  1539. except TypeError:
  1540. # the last attempt; this may raise TypeError
  1541. return bool(y - x >= 0)
  1542. def handle_negative_index(index, bound):
  1543. """normalizes a negative index to be in [0, bound)"""
  1544. try:
  1545. if not less_equal(0, index):
  1546. if is_literal(index) and index <= -self.int_max_:
  1547. # this case is handled separately
  1548. return index
  1549. return bound + index
  1550. except TypeError:
  1551. logger.warning("Cannot determine if {} < 0".format(index))
  1552. return index
  1553. if get_opset(self.out_mp_) <= 9:
  1554. axes = get_attribute(node, "axes")
  1555. starts = get_attribute(node, "starts")
  1556. ends = get_attribute(node, "ends")
  1557. if not axes:
  1558. axes = list(range(len(starts)))
  1559. steps = [1] * len(axes)
  1560. else:
  1561. starts = as_list(self._try_get_value(node, 1), keep_none=True)
  1562. ends = as_list(self._try_get_value(node, 2), keep_none=True)
  1563. axes = self._try_get_value(node, 3)
  1564. steps = self._try_get_value(node, 4)
  1565. if axes is None and not (starts is None and ends is None):
  1566. axes = list(range(0, len(starts if starts is not None else ends)))
  1567. if steps is None and not (starts is None and ends is None):
  1568. steps = [1] * len(starts if starts is not None else ends)
  1569. axes = as_list(axes, keep_none=True)
  1570. steps = as_list(steps, keep_none=True)
  1571. new_sympy_shape = self._get_sympy_shape(node, 0)
  1572. if starts is None or ends is None:
  1573. if axes is None:
  1574. for i in range(len(new_sympy_shape)):
  1575. new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
  1576. else:
  1577. new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
  1578. for i in axes:
  1579. new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
  1580. else:
  1581. for i, s, e, t in zip(axes, starts, ends, steps):
  1582. e = handle_negative_index(e, new_sympy_shape[i])
  1583. if is_literal(e):
  1584. if e >= self.int_max_:
  1585. e = new_sympy_shape[i]
  1586. elif e <= -self.int_max_:
  1587. e = 0 if s > 0 else -1
  1588. elif is_literal(new_sympy_shape[i]):
  1589. if e < 0:
  1590. e = max(0, e + new_sympy_shape[i])
  1591. e = min(e, new_sympy_shape[i])
  1592. else:
  1593. if e > 0:
  1594. e = (
  1595. sympy.Min(e, new_sympy_shape[i]) if e > 1 else e
  1596. ) # special case for slicing first to make computation easier
  1597. else:
  1598. if is_literal(new_sympy_shape[i]):
  1599. e = sympy.Min(e, new_sympy_shape[i])
  1600. else:
  1601. try:
  1602. if not less_equal(e, new_sympy_shape[i]):
  1603. e = new_sympy_shape[i]
  1604. except Exception:
  1605. logger.warning(
  1606. "Unable to determine if {} <= {}, treat as equal".format(e, new_sympy_shape[i])
  1607. )
  1608. e = new_sympy_shape[i]
  1609. s = handle_negative_index(s, new_sympy_shape[i])
  1610. if is_literal(new_sympy_shape[i]) and is_literal(s):
  1611. s = max(0, min(s, new_sympy_shape[i]))
  1612. new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)
  1613. self._update_computed_dims(new_sympy_shape)
  1614. vi = self.known_vi_[node.output[0]]
  1615. vi.CopyFrom(
  1616. helper.make_tensor_value_info(
  1617. node.output[0],
  1618. vi.type.tensor_type.elem_type,
  1619. get_shape_from_sympy_shape(new_sympy_shape),
  1620. )
  1621. )
  1622. # handle sympy_data if needed, for slice in shape computation
  1623. if (
  1624. node.input[0] in self.sympy_data_
  1625. and [0] == axes
  1626. and len(starts) == 1
  1627. and len(ends) == 1
  1628. and len(steps) == 1
  1629. ):
  1630. input_sympy_data = self.sympy_data_[node.input[0]]
  1631. if type(input_sympy_data) == list or (
1632. type(input_sympy_data) == np.ndarray and len(input_sympy_data.shape) == 1
  1633. ):
  1634. self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]]
  1635. def _infer_SoftmaxCrossEntropyLoss(self, node):
  1636. vi = self.known_vi_[node.output[0]]
  1637. elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
  1638. vi.type.tensor_type.elem_type = elem_type
  1639. vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
  1640. if len(node.output) > 1:
  1641. data_shape = self._get_shape(node, 0)
  1642. vi = self.known_vi_[node.output[1]]
  1643. vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape))
  1644. def _infer_Split_Common(self, node, make_value_info_func):
  1645. input_sympy_shape = self._get_sympy_shape(node, 0)
  1646. axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape))
  1647. split = get_attribute(node, "split")
  1648. if not split:
  1649. num_outputs = len(node.output)
  1650. split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs
  1651. self._update_computed_dims(split)
  1652. else:
  1653. split = [sympy.Integer(s) for s in split]
  1654. for i_o in range(len(split)):
  1655. vi = self.known_vi_[node.output[i_o]]
  1656. vi.CopyFrom(
  1657. make_value_info_func(
  1658. node.output[i_o],
  1659. self.known_vi_[node.input[0]].type.tensor_type.elem_type,
  1660. get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :]),
  1661. )
  1662. )
  1663. self.known_vi_[vi.name] = vi
  1664. def _infer_Split(self, node):
  1665. self._infer_Split_Common(node, helper.make_tensor_value_info)
  1666. def _infer_SplitToSequence(self, node):
  1667. self._infer_Split_Common(node, helper.make_sequence_value_info)
  1668. def _infer_Squeeze(self, node):
  1669. input_shape = self._get_shape(node, 0)
  1670. op_set = get_opset(self.out_mp_)
  1671. # Depending on op-version 'axes' are provided as attribute or via 2nd input
  1672. if op_set < 13:
  1673. axes = get_attribute(node, "axes")
  1674. assert self._try_get_value(node, 1) is None
  1675. else:
  1676. axes = self._try_get_value(node, 1)
  1677. assert get_attribute(node, "axes") is None
  1678. if axes is None:
  1679. # No axes have been provided (neither via attribute nor via input).
1680. # In this case the 'Squeeze' op should remove all axes with dimension 1.
  1681. # For symbolic dimensions we guess they are !=1.
  1682. output_shape = [s for s in input_shape if s != 1]
  1683. if self.verbose_ > 0:
  1684. symbolic_dimensions = [s for s in input_shape if type(s) != int]
  1685. if len(symbolic_dimensions) > 0:
  1686. logger.debug(
  1687. f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
  1688. + f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
  1689. )
  1690. else:
  1691. axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
  1692. output_shape = []
  1693. for i in range(len(input_shape)):
  1694. if i not in axes:
  1695. output_shape.append(input_shape[i])
  1696. else:
  1697. assert input_shape[i] == 1 or type(input_shape[i]) != int
  1698. if self.verbose_ > 0 and type(input_shape[i]) != int:
  1699. logger.debug(
  1700. f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
  1701. + f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1."
  1702. )
  1703. vi = self.known_vi_[node.output[0]]
  1704. vi.CopyFrom(
  1705. helper.make_tensor_value_info(
  1706. node.output[0],
  1707. self.known_vi_[node.input[0]].type.tensor_type.elem_type,
  1708. output_shape,
  1709. )
  1710. )
  1711. self._pass_on_sympy_data(node)
  1712. def _infer_Tile(self, node):
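# Descriptive note (added): output dim i is input dim i * repeats[i] when repeats is a known
# constant; otherwise the whole output becomes a new symbolic shape of the input rank.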
  1713. repeats_value = self._try_get_value(node, 1)
  1714. new_sympy_shape = []
  1715. if repeats_value is not None:
  1716. input_sympy_shape = self._get_sympy_shape(node, 0)
  1717. for i, d in enumerate(input_sympy_shape):
  1718. new_dim = d * repeats_value[i]
  1719. new_sympy_shape.append(new_dim)
  1720. self._update_computed_dims(new_sympy_shape)
  1721. else:
  1722. new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
  1723. vi = self.known_vi_[node.output[0]]
  1724. vi.CopyFrom(
  1725. helper.make_tensor_value_info(
  1726. node.output[0],
  1727. vi.type.tensor_type.elem_type,
  1728. get_shape_from_sympy_shape(new_sympy_shape),
  1729. )
  1730. )
  1731. def _infer_TopK(self, node):
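# Descriptive note (added): k comes from the 'k' attribute (opset <= 9) or input 1 (opset 10+);
# an unknown k becomes a new symbolic dim on the selected axis.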
  1732. rank = self._get_shape_rank(node, 0)
  1733. axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
  1734. new_shape = self._get_shape(node, 0)
  1735. if get_opset(self.out_mp_) <= 9:
  1736. k = get_attribute(node, "k")
  1737. else:
  1738. k = self._get_int_values(node)[1]
1739. if k is None:
  1740. k = self._new_symbolic_dim_from_output(node)
  1741. else:
  1742. k = as_scalar(k)
  1743. if type(k) in [int, str]:
  1744. new_shape[axis] = k
  1745. else:
  1746. new_sympy_shape = self._get_sympy_shape(node, 0)
  1747. new_sympy_shape[axis] = k
  1748. self._update_computed_dims(
  1749. new_sympy_shape
  1750. ) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape
  1751. new_shape = get_shape_from_sympy_shape(new_sympy_shape)
  1752. for i_o in range(len(node.output)):
  1753. vi = self.known_vi_[node.output[i_o]]
  1754. vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape))
  1755. def _infer_Transpose(self, node):
  1756. if node.input[0] in self.sympy_data_:
  1757. data_shape = self._get_shape(node, 0)
  1758. perm = get_attribute(node, "perm", reversed(list(range(len(data_shape)))))
  1759. input_data = self.sympy_data_[node.input[0]]
  1760. self.sympy_data_[node.output[0]] = (
  1761. np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist()
  1762. )
  1763. def _infer_Unsqueeze(self, node):
  1764. input_shape = self._get_shape(node, 0)
  1765. op_set = get_opset(self.out_mp_)
  1766. # Depending on op-version 'axes' are provided as attribute or via 2nd input
  1767. if op_set < 13:
  1768. axes = get_attribute(node, "axes")
  1769. assert self._try_get_value(node, 1) is None
  1770. else:
  1771. axes = self._try_get_value(node, 1)
  1772. assert get_attribute(node, "axes") is None
  1773. output_rank = len(input_shape) + len(axes)
  1774. axes = [handle_negative_axis(a, output_rank) for a in axes]
  1775. input_axis = 0
  1776. output_shape = []
  1777. for i in range(output_rank):
  1778. if i in axes:
  1779. output_shape.append(1)
  1780. else:
  1781. output_shape.append(input_shape[input_axis])
  1782. input_axis += 1
  1783. vi = self.known_vi_[node.output[0]]
  1784. vi.CopyFrom(
  1785. helper.make_tensor_value_info(
  1786. node.output[0],
  1787. self.known_vi_[node.input[0]].type.tensor_type.elem_type,
  1788. output_shape,
  1789. )
  1790. )
  1791. self._pass_on_sympy_data(node)
  1792. def _infer_ZipMap(self, node):
  1793. map_key_type = None
  1794. if get_attribute(node, "classlabels_int64s") is not None:
  1795. map_key_type = onnx.TensorProto.INT64
  1796. elif get_attribute(node, "classlabels_strings") is not None:
  1797. map_key_type = onnx.TensorProto.STRING
  1798. assert map_key_type is not None
  1799. new_vi = onnx.ValueInfoProto()
  1800. new_vi.name = node.output[0]
  1801. new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
  1802. new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type
  1803. vi = self.known_vi_[node.output[0]]
  1804. vi.CopyFrom(new_vi)
  1805. def _infer_Attention(self, node):
  1806. shape = self._get_shape(node, 0)
  1807. shape_bias = self._get_shape(node, 2)
  1808. if shape and len(shape) == 3 and shape_bias and len(shape_bias) == 1:
  1809. qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
  1810. if qkv_hidden_sizes_attr is not None:
  1811. assert len(qkv_hidden_sizes_attr) == 3
  1812. shape[2] = int(qkv_hidden_sizes_attr[2])
  1813. elif isinstance(shape_bias[0], int):
  1814. shape[2] = int(shape_bias[0] / 3)
  1815. output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
  1816. vi = self.known_vi_[node.output[0]]
  1817. vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
  1818. if len(node.output) > 1:
  1819. # input shape: (batch_size, sequence_length, hidden_size)
  1820. # past shape: (2, batch_size, num_heads, past_sequence_length, head_size)
  1821. # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len)
  1822. # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length
  1823. input_shape = self._get_shape(node, 0)
  1824. past_shape = self._get_shape(node, 4)
  1825. mask_shape = self._get_shape(node, 3)
  1826. if past_shape and len(past_shape) == 5:
  1827. if mask_shape and len(mask_shape) in [2, 3]:
  1828. past_shape[3] = mask_shape[-1]
  1829. elif input_shape and len(input_shape) == 3:
  1830. if isinstance(input_shape[1], int) and isinstance(past_shape[3], int):
  1831. past_shape[3] = input_shape[1] + past_shape[3]
  1832. else:
  1833. past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
  1834. vi = self.known_vi_[node.output[1]]
  1835. vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
  1836. def _infer_BiasGelu(self, node):
  1837. self._propagate_shape_and_type(node)
  1838. def _infer_MultiHeadAttention(self, node):
  1839. # Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
  1840. # Without packed KV:
  1841. # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
  1842. # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
  1843. # With packed KV:
  1844. # Input 1 (key) has shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
  1845. # Input 2 (value) is nullptr
  1846. # Output 0 has shape (batch_size, sequence_length, v_hidden_size)
  1847. query_shape = self._get_shape(node, 0)
  1848. key_shape = self._get_shape(node, 1)
  1849. if query_shape is not None and len(query_shape) == 3:
  1850. # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
  1851. output_shape = query_shape
  1852. if key_shape and len(key_shape) == 3:
  1853. value_shape = self._get_shape(node, 2)
  1854. if value_shape and len(value_shape) == 3:
  1855. output_shape[2] = value_shape[2]
  1856. output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
  1857. vi = self.known_vi_[node.output[0]]
  1858. vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
  1859. def _infer_FastGelu(self, node):
  1860. self._propagate_shape_and_type(node)
  1861. def _infer_Gelu(self, node):
  1862. self._propagate_shape_and_type(node)
  1863. def _infer_GemmFastGelu(self, node):
  1864. self._compute_matmul_shape(node)
  1865. def _infer_LayerNormalization(self, node):
  1866. self._propagate_shape_and_type(node)
  1867. def _infer_LongformerAttention(self, node):
  1868. self._propagate_shape_and_type(node)
  1869. def _infer_EmbedLayerNormalization(self, node):
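# Descriptive note (added): output 0 is (batch, seq, word_embedding_dim) and output 1 (mask_index)
# is (batch,); the optional output 2 mirrors output 0's shape and type.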
  1870. input_ids_shape = self._get_shape(node, 0)
  1871. word_embedding_shape = self._get_shape(node, 2)
  1872. assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2
  1873. output_shape = input_ids_shape + [word_embedding_shape[1]]
  1874. word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
  1875. vi = self.known_vi_[node.output[0]]
  1876. vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape))
  1877. mask_index_shape = [input_ids_shape[0]]
  1878. vi = self.known_vi_[node.output[1]]
  1879. vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape))
  1880. if len(node.output) > 2:
1881. # Optional output of the add before layer normalization is done
  1882. # shape is same as the output
  1883. vi = self.known_vi_[node.output[2]]
  1884. vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape))
  1885. def _infer_SkipLayerNormalization(self, node):
  1886. self._propagate_shape_and_type(node)
  1889. # If the SkipLayerNormalization node contains the optional
  1890. # output for inference, infer the shape and type for it too
  1891. if len(node.output) > 3:
  1892. self._propagate_shape_and_type(node, 0, 3)
  1893. def _infer_GroupNorm(self, node):
  1894. self._propagate_shape_and_type(node)
  1895. def _infer_BiasSplitGelu(self, node):
  1896. input_shape = self._get_shape(node, 0)
  1897. bias_shape = self._get_shape(node, 1)
  1898. if input_shape and bias_shape and isinstance(bias_shape[0], int):
  1899. output_shape = input_shape
  1900. output_shape[2] = int(bias_shape[0] / 2)
  1901. vi = self.known_vi_[node.output[0]]
  1902. output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
  1903. vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))
  1904. def _infer_PythonOp(self, node):
  1905. output_tensor_types = get_attribute(node, "output_tensor_types")
  1906. assert output_tensor_types
  1907. output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
  1908. assert output_tensor_ranks
1909. # set the context output separately.
  1910. # The first output is autograd's context.
  1911. vi = self.known_vi_[node.output[0]]
  1912. vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
  1913. # Outputs after autograd's context are tensors.
  1914. # We assume their ranks are fixed for different model inputs.
  1915. for i in range(len(node.output) - 1):
  1916. # Process the i-th tensor outputs.
  1917. vi = self.known_vi_[node.output[i + 1]]
  1918. sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node)
  1919. shape = get_shape_from_sympy_shape(sympy_shape)
  1920. value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape)
  1921. vi.CopyFrom(value_info)
  1922. def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
  1923. shape = self._get_shape(node, input_index)
  1924. output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
  1925. vi = self.known_vi_[node.output[output_index]]
  1926. vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape))
  1927. def _is_none_dim(self, dim_value):
  1928. if type(dim_value) != str:
  1929. return False
  1930. if "unk__" not in dim_value:
  1931. return False
  1932. if dim_value in self.symbolic_dims_.keys():
  1933. return False
  1934. return True
  1935. def _is_shape_contains_none_dim(self, out_shape):
  1936. for out in out_shape:
  1937. if self._is_none_dim(out):
  1938. return out
  1939. return None
  1940. def _infer_impl(self, start_sympy_data=None):
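# Descriptive note (added): main inference loop: resolve symbolic dims on graph inputs, topologically
# sort the nodes, run per-node ONNX shape inference, then apply the symbolic dispatchers and merge
# or guess dims where standard inference comes back incomplete.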
  1941. self.sympy_data_ = start_sympy_data or {}
  1942. self.out_mp_.graph.ClearField("value_info")
  1943. self._apply_suggested_merge(graph_input_only=True)
  1944. self.input_symbols_ = set()
  1945. for i in self.out_mp_.graph.input:
  1946. input_shape = get_shape_from_value_info(i)
  1947. if input_shape is None:
  1948. continue
  1949. if is_sequence(i.type):
  1950. input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
  1951. else:
  1952. input_dims = i.type.tensor_type.shape.dim
  1953. for i_dim, dim in enumerate(input_shape):
  1954. if dim is None:
  1955. # some models use None for symbolic dim in input, replace it with a string
  1956. input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim))
  1957. self.input_symbols_.update([d for d in input_shape if type(d) == str])
  1958. for s in self.input_symbols_:
  1959. if s in self.suggested_merge_:
  1960. s_merge = self.suggested_merge_[s]
  1961. assert s_merge in self.symbolic_dims_
  1962. self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
  1963. else:
  1964. # Since inputs are not produced by other ops, we can assume positivity
  1965. self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
  1966. # create a temporary ModelProto for single node inference
  1967. # note that we remove initializer to have faster inference
  1968. # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways
  1969. self.tmp_mp_ = onnx.ModelProto()
  1970. self.tmp_mp_.CopyFrom(self.out_mp_)
  1971. self.tmp_mp_.graph.ClearField("initializer")
1972. # compute prerequisites for each node for topological sort
  1973. # node with subgraphs may have dependency on implicit inputs, which will affect topological sort
  1974. prereq_for_node = {} # map from node to all its inputs, including implicit ones in subgraph
  1975. def get_prereq(node):
  1976. names = set(i for i in node.input if i)
  1977. subgraphs = []
  1978. if "If" == node.op_type:
  1979. subgraphs = [
  1980. get_attribute(node, "then_branch"),
  1981. get_attribute(node, "else_branch"),
  1982. ]
  1983. elif node.op_type in ["Loop", "Scan"]:
  1984. subgraphs = [get_attribute(node, "body")]
  1985. for g in subgraphs:
  1986. g_outputs_and_initializers = {i.name for i in g.initializer}
  1987. g_prereq = set()
  1988. for n in g.node:
  1989. g_outputs_and_initializers.update(n.output)
  1990. for n in g.node:
  1991. g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
  1992. names.update(g_prereq)
1993. # remove subgraph inputs from the collected prerequisites since those are local to the subgraph
  1994. for i in g.input:
  1995. if i.name in names:
  1996. names.remove(i.name)
  1997. return names
  1998. for n in self.tmp_mp_.graph.node:
  1999. prereq_for_node[n.output[0]] = get_prereq(n)
  2000. # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate
  2001. sorted_nodes = []
  2002. sorted_known_vi = set([i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)])
  2003. if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
  2004. # Loop/Scan will have some graph output in graph inputs, so don't do topological sort
  2005. sorted_nodes = self.out_mp_.graph.node
  2006. else:
  2007. while not all([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
  2008. old_sorted_nodes_len = len(sorted_nodes)
  2009. for node in self.out_mp_.graph.node:
  2010. if (node.output[0] not in sorted_known_vi) and all(
  2011. [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i]
  2012. ):
  2013. sorted_known_vi.update(node.output)
  2014. sorted_nodes.append(node)
  2015. if old_sorted_nodes_len == len(sorted_nodes) and not all(
  2016. [o.name in sorted_known_vi for o in self.out_mp_.graph.output]
  2017. ):
  2018. raise Exception("Invalid model with cyclic graph")
  2019. for node in sorted_nodes:
  2020. assert all([i in self.known_vi_ for i in node.input if i])
  2021. self._onnx_infer_single_node(node)
  2022. known_aten_op = False
  2023. if node.op_type in self.dispatcher_:
  2024. self.dispatcher_[node.op_type](node)
  2025. elif node.op_type in ["ConvTranspose"]:
  2026. # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
  2027. # before adding symbolic compute for them
  2028. # mark the output type as UNDEFINED to allow guessing of rank
  2029. vi = self.known_vi_[node.output[0]]
  2030. if len(vi.type.tensor_type.shape.dim) == 0:
  2031. vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
  2032. elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
  2033. for attr in node.attribute:
  2034. # TODO: Is overload_name needed?
  2035. if attr.name == "operator":
  2036. aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
  2037. if aten_op_name in self.aten_op_dispatcher_:
  2038. known_aten_op = True
  2039. self.aten_op_dispatcher_[aten_op_name](node)
  2040. break
  2041. if self.verbose_ > 2:
  2042. logger.debug(node.op_type + ": " + node.name)
  2043. for i, name in enumerate(node.input):
  2044. logger.debug(
  2045. " Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "")
  2046. )
2047. # onnx automatically merges dims with values, e.g. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb']
  2048. # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case
  2049. if node.op_type in [
  2050. "Add",
  2051. "Sub",
  2052. "Mul",
  2053. "Div",
  2054. "MatMul",
  2055. "MatMulInteger",
  2056. "MatMulInteger16",
  2057. "Where",
  2058. "Sum",
  2059. ]:
  2060. vi = self.known_vi_[node.output[0]]
  2061. out_rank = len(get_shape_from_type_proto(vi.type))
  2062. in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
  2063. for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)):
  2064. in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
  2065. if len(in_dims) > 1:
  2066. self._check_merged_dims(in_dims, allow_broadcast=True)
  2067. for i_o in range(len(node.output)):
  2068. # Special case: We do not care about the training related
  2069. # outputs of SkipLayerNormalization
  2070. if node.op_type == "SkipLayerNormalization" and i_o in [1, 2]:
  2071. continue
  2072. vi = self.known_vi_[node.output[i_o]]
  2073. out_type = vi.type
  2074. out_type_kind = out_type.WhichOneof("value")
  2075. # do not process shape for non-tensors
  2076. if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]:
  2077. if self.verbose_ > 2:
  2078. if out_type_kind == "sequence_type":
  2079. seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
  2080. if "tensor_type" == seq_cls_type:
  2081. logger.debug(
  2082. " {}: sequence of {} {}".format(
  2083. node.output[i_o],
  2084. str(get_shape_from_value_info(vi)),
  2085. onnx.TensorProto.DataType.Name(
  2086. vi.type.sequence_type.elem_type.tensor_type.elem_type
  2087. ),
  2088. )
  2089. )
  2090. else:
  2091. logger.debug(" {}: sequence of {}".format(node.output[i_o], seq_cls_type))
  2092. else:
  2093. logger.debug(" {}: {}".format(node.output[i_o], out_type_kind))
  2094. continue
  2095. out_shape = get_shape_from_value_info(vi)
  2096. out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
  2097. if self.verbose_ > 2:
  2098. logger.debug(
  2099. " {}: {} {}".format(
  2100. node.output[i_o],
  2101. str(out_shape),
  2102. onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type),
  2103. )
  2104. )
  2105. if node.output[i_o] in self.sympy_data_:
  2106. logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))
  2107. # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain
  2108. if (
  2109. out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))
  2110. ) or out_type_undefined:
  2111. if self.auto_merge_:
  2112. if node.op_type in [
  2113. "Add",
  2114. "Sub",
  2115. "Mul",
  2116. "Div",
  2117. "MatMul",
  2118. "MatMulInteger",
  2119. "MatMulInteger16",
  2120. "Concat",
  2121. "Where",
  2122. "Sum",
  2123. "Equal",
  2124. "Less",
  2125. "Greater",
  2126. "LessOrEqual",
  2127. "GreaterOrEqual",
  2128. "Min",
  2129. "Max",
  2130. ]:
  2131. shapes = [self._get_shape(node, i) for i in range(len(node.input))]
  2132. if node.op_type in [
  2133. "MatMul",
  2134. "MatMulInteger",
  2135. "MatMulInteger16",
  2136. ]:
  2137. if None in out_shape or self._is_shape_contains_none_dim(out_shape):
  2138. if None in out_shape:
  2139. idx = out_shape.index(None)
  2140. else:
  2141. idx = out_shape.index(self._is_shape_contains_none_dim(out_shape))
  2142. dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
  2143. # only support auto merge for MatMul for dim < rank-2 when rank > 2
  2144. assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
  2145. assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
  2146. elif node.op_type == "Expand":
  2147. # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
  2148. shapes = [
  2149. self._get_shape(node, 0),
  2150. self._get_value(node, 1),
  2151. ]
  2152. else:
  2153. shapes = []
  2154. if shapes:
  2155. for idx in range(len(out_shape)):
  2156. if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]):
  2157. continue
  2158. # note that the broadcasting rule aligns from right to left
  2159. # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge
  2160. dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
  2161. if len(dim_idx) > 0:
  2162. self._add_suggested_merge(
  2163. [
  2164. s[i] if is_literal(s[i]) else str(s[i])
  2165. for s, i in zip(shapes, dim_idx)
  2166. if i >= 0
  2167. ]
  2168. )
  2169. self.run_ = True
  2170. else:
  2171. self.run_ = False
  2172. else:
  2173. self.run_ = False
  2174. # create new dynamic dims for ops not handled by symbolic shape inference
2175. if not self.run_ and node.op_type not in self.dispatcher_ and not known_aten_op:
  2176. is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
  2177. if is_unknown_op:
  2178. # unknown op to ONNX, maybe from higher opset or other domain
  2179. # only guess the output rank from input 0 when using guess_output_rank option
  2180. out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1
  2181. else:
  2182. # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape
  2183. out_rank = len(out_shape)
  2184. if out_rank >= 0:
  2185. new_shape = self._new_symbolic_shape(out_rank, node, i_o)
  2186. if out_type_undefined:
  2187. # guess output data type from input vi if not defined
  2188. out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
  2189. else:
  2190. # otherwise, use original data type
  2191. out_dtype = vi.type.tensor_type.elem_type
  2192. vi.CopyFrom(
  2193. helper.make_tensor_value_info(
  2194. vi.name,
  2195. out_dtype,
  2196. get_shape_from_sympy_shape(new_shape),
  2197. )
  2198. )
  2199. if self.verbose_ > 0:
  2200. if is_unknown_op:
  2201. logger.debug(
  2202. "Possible unknown op: {} node: {}, guessing {} shape".format(
  2203. node.op_type, node.name, vi.name
  2204. )
  2205. )
  2206. if self.verbose_ > 2:
  2207. logger.debug(
  2208. " {}: {} {}".format(
  2209. node.output[i_o],
  2210. str(new_shape),
  2211. vi.type.tensor_type.elem_type,
  2212. )
  2213. )
  2214. self.run_ = True
  2215. continue # continue the inference after guess, no need to stop as no merge is needed
  2216. if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
  2217. logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
  2218. logger.debug("node inputs:")
  2219. for i in node.input:
  2220. logger.debug(self.known_vi_[i])
  2221. logger.debug("node outputs:")
  2222. for o in node.output:
  2223. logger.debug(self.known_vi_[o])
  2224. if self.auto_merge_ and not out_type_undefined:
  2225. logger.debug("Merging: " + str(self.suggested_merge_))
  2226. return False
  2227. self.run_ = False
  2228. return True
  2229. def _update_output_from_vi(self):
  2230. for output in self.out_mp_.graph.output:
  2231. if output.name in self.known_vi_:
  2232. output.CopyFrom(self.known_vi_[output.name])
  2233. @staticmethod
  2234. def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
  2235. onnx_opset = get_opset(in_mp)
  2236. if (not onnx_opset) or onnx_opset < 7:
  2237. logger.warning("Only support models of onnx opset 7 and above.")
  2238. return None
  2239. symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
  2240. all_shapes_inferred = False
  2241. symbolic_shape_inference._preprocess(in_mp)
  2242. while symbolic_shape_inference.run_:
  2243. all_shapes_inferred = symbolic_shape_inference._infer_impl()
  2244. symbolic_shape_inference._update_output_from_vi()
  2245. if not all_shapes_inferred:
  2246. onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
  2247. raise Exception("Incomplete symbolic shape inference")
  2248. return symbolic_shape_inference.out_mp_
  2249. def parse_arguments():
  2250. parser = argparse.ArgumentParser()
  2251. parser.add_argument("--input", required=True, help="The input model file")
  2252. parser.add_argument("--output", help="The output model file")
  2253. parser.add_argument(
  2254. "--auto_merge",
  2255. help="Automatically merge symbolic dims when confliction happens",
  2256. action="store_true",
  2257. default=False,
  2258. )
  2259. parser.add_argument(
  2260. "--int_max",
  2261. help="maximum value for integer to be treated as boundless for ops like slice",
  2262. type=int,
  2263. default=2**31 - 1,
  2264. )
  2265. parser.add_argument(
  2266. "--guess_output_rank",
  2267. help="guess output rank to be the same as input 0 for unknown ops",
  2268. action="store_true",
  2269. default=False,
  2270. )
  2271. parser.add_argument(
  2272. "--verbose",
  2273. help="Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed",
  2274. type=int,
  2275. default=0,
  2276. )
  2277. parser.add_argument(
  2278. "--save_as_external_data",
  2279. help="Saving an ONNX model to external data",
  2280. action="store_true",
  2281. default=False,
  2282. )
  2283. parser.add_argument(
  2284. "--all_tensors_to_one_file",
  2285. help="Saving all the external data to one file",
  2286. action="store_true",
  2287. default=False,
  2288. )
  2289. parser.add_argument(
  2290. "--external_data_location",
  2291. help="The file location to save the external file",
  2292. default="./",
  2293. )
  2294. parser.add_argument(
  2295. "--external_data_size_threshold",
  2296. help="The size threshold for external data",
  2297. type=int,
  2298. default=1024,
  2299. )
  2300. return parser.parse_args()
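# Example invocation (added, illustrative only; the script and model file names are assumptions):
#   python symbolic_shape_infer.py --input model.onnx --output model_with_shapes.onnx --auto_merge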
  2301. if __name__ == "__main__":
  2302. args = parse_arguments()
  2303. logger.info("input model: " + args.input)
  2304. if args.output:
  2305. logger.info("output model " + args.output)
  2306. logger.info("Doing symbolic shape inference...")
  2307. out_mp = SymbolicShapeInference.infer_shapes(
  2308. onnx.load(args.input),
  2309. args.int_max,
  2310. args.auto_merge,
  2311. args.guess_output_rank,
  2312. args.verbose,
  2313. )
  2314. if args.output and out_mp:
  2315. if args.save_as_external_data:
  2316. onnx.save_model(
  2317. out_mp,
  2318. args.output,
  2319. save_as_external_data=True,
  2320. all_tensors_to_one_file=args.all_tensors_to_one_file,
  2321. location=args.external_data_location,
  2322. size_threshold=args.external_data_size_threshold,
  2323. convert_attribute=False,
  2324. )
  2325. else:
  2326. onnx.save(out_mp, args.output)
  2327. logger.info("Done!")