m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

85 lines
3.2 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. # An offline standalone script to declassify an ONNX model by randomizing the tensor data in initializers.
  6. # The ORT Performance may change especially on generative models.
import argparse
from collections import deque
from pathlib import Path

import numpy as np
from onnx import load_model, numpy_helper, onnx_pb, save_model
  11. # An experimental small value for differentiating shape data and weights.
  12. # The tensor data with larger size can't be shape data.
  13. # User may adjust this value as needed.
  14. SIZE_THRESHOLD = 10
  15. def graph_iterator(model, func):
  16. graph_queue = [model.graph]
  17. while graph_queue:
  18. graph = graph_queue.pop(0)
  19. func(graph)
  20. for node in graph.node:
  21. for attr in node.attribute:
  22. if attr.type == onnx_pb.AttributeProto.AttributeType.GRAPH:
  23. assert isinstance(attr.g, onnx_pb.GraphProto)
  24. graph_queue.append(attr.g)
  25. if attr.type == onnx_pb.AttributeProto.AttributeType.GRAPHS:
  26. for g in attr.graphs:
  27. assert isinstance(g, onnx_pb.GraphProto)
  28. graph_queue.append(g)
  29. def randomize_graph_initializer(graph):
  30. for i_tensor in graph.initializer:
  31. array = numpy_helper.to_array(i_tensor)
  32. # TODO: need to find a better way to differentiate shape data and weights.
  33. if array.size > SIZE_THRESHOLD:
  34. random_array = np.random.uniform(array.min(), array.max(), size=array.shape).astype(array.dtype)
  35. o_tensor = numpy_helper.from_array(random_array, i_tensor.name)
  36. i_tensor.CopyFrom(o_tensor)
  37. def main():
  38. parser = argparse.ArgumentParser(description="Randomize the weights of an ONNX model")
  39. parser.add_argument("-m", type=str, required=True, help="input onnx model path")
  40. parser.add_argument("-o", type=str, required=True, help="output onnx model path")
  41. parser.add_argument(
  42. "--use_external_data_format",
  43. required=False,
  44. action="store_true",
  45. help="Store or Save in external data format",
  46. )
  47. parser.add_argument(
  48. "--all_tensors_to_one_file",
  49. required=False,
  50. action="store_true",
  51. help="Save all tensors to one file",
  52. )
  53. args = parser.parse_args()
  54. data_path = None
  55. if args.use_external_data_format:
  56. if Path(args.m).parent == Path(args.o).parent:
  57. raise RuntimeError("Please specify output directory with different parent path to input directory.")
  58. if args.all_tensors_to_one_file:
  59. data_path = Path(args.o).name + ".data"
  60. Path(args.o).parent.mkdir(parents=True, exist_ok=True)
  61. onnx_model = load_model(args.m, load_external_data=args.use_external_data_format)
  62. graph_iterator(onnx_model, randomize_graph_initializer)
  63. save_model(
  64. onnx_model,
  65. args.o,
  66. save_as_external_data=args.use_external_data_format,
  67. all_tensors_to_one_file=args.all_tensors_to_one_file,
  68. location=data_path,
  69. )
  70. if __name__ == "__main__":
  71. main()