图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

67 lines
2.7 KiB

  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from collections import defaultdict
  6. from logging import getLogger
  7. from typing import List, Union
  8. from onnx_model import OnnxModel
# Module-level logger named after this module; Fusion.apply() emits its
# debug/info progress and fused-node counts through it.
logger = getLogger(__name__)
  10. class Fusion:
  11. def __init__(
  12. self,
  13. model: OnnxModel,
  14. fused_op_type: str,
  15. search_op_types: Union[str, List[str]],
  16. description: str = None,
  17. ):
  18. self.search_op_types: List[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types
  19. self.fused_op_type: str = fused_op_type
  20. self.description: str = f"{fused_op_type}({description})" if description else fused_op_type
  21. self.model: OnnxModel = model
  22. self.nodes_to_remove: List = []
  23. self.nodes_to_add: List = []
  24. self.prune_graph: bool = False
  25. self.node_name_to_graph_name: dict = {}
  26. self.this_graph_name: str = None
  27. # It is optional that subclass updates fused_count since we will also check nodes_to_add to get counter.
  28. self.fused_count: defaultdict = defaultdict(int)
  29. def increase_counter(self, fused_op_name):
  30. self.fused_count[fused_op_name] += 1
  31. def apply(self):
  32. logger.debug(f"start {self.description} fusion...")
  33. input_name_to_nodes = self.model.input_name_to_nodes()
  34. output_name_to_node = self.model.output_name_to_node()
  35. # This assumes that two search ops will not be fused at same time!
  36. for search_op_type in self.search_op_types:
  37. for node in self.model.get_nodes_by_op_type(search_op_type):
  38. graph = self.model.get_graph_by_node(node)
  39. if graph is None:
  40. raise Exception("Can not find node in any graphs")
  41. self.this_graph_name = graph.name
  42. self.fuse(node, input_name_to_nodes, output_name_to_node)
  43. op_list = [node.op_type for node in self.nodes_to_add]
  44. if self.fused_count:
  45. for key, value in self.fused_count.items():
  46. if value:
  47. logger.info(f"Fused {key} count: {value}")
  48. else:
  49. count = op_list.count(self.fused_op_type)
  50. if count > 0:
  51. logger.info(f"Fused {self.description} count: {count}")
  52. self.model.remove_nodes(self.nodes_to_remove)
  53. self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
  54. if self.prune_graph:
  55. self.model.prune_graph()
  56. elif self.nodes_to_remove or self.nodes_to_add:
  57. self.model.update_graph()