m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

232 lines
8.0 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import argparse
  6. import os
  7. import time
  8. SD_MODELS = {
  9. "1.5": "runwayml/stable-diffusion-v1-5",
  10. "2.0": "stabilityai/stable-diffusion-2",
  11. "2.1": "stabilityai/stable-diffusion-2-1",
  12. }
  13. def get_test_settings():
  14. height = 512
  15. width = 512
  16. num_inference_steps = 50
  17. prompts = [
  18. "a photo of an astronaut riding a horse on mars",
  19. "cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
  20. "a cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital painting",
  21. "an illustration of a house with large barn with many cute flower pots and beautiful blue sky scenery",
  22. "one apple sitting on a table, still life, reflective, full color photograph, centered, close-up product",
  23. "background texture of stones, masterpiece, artistic, stunning photo, award winner photo",
  24. "new international organic style house, tropical surroundings, architecture, 8k, hdr",
  25. "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
  26. "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
  27. "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k",
  28. ]
  29. return height, width, num_inference_steps, prompts
  30. def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_safety_checker: bool):
  31. from diffusers import OnnxStableDiffusionPipeline
  32. import onnxruntime
  33. if directory is not None:
  34. assert os.path.exists(directory)
  35. session_options = onnxruntime.SessionOptions()
  36. pipe = OnnxStableDiffusionPipeline.from_pretrained(
  37. directory,
  38. provider=provider,
  39. sess_options=session_options,
  40. )
  41. else:
  42. pipe = OnnxStableDiffusionPipeline.from_pretrained(
  43. model_name,
  44. revision="onnx",
  45. provider=provider,
  46. use_auth_token=True,
  47. )
  48. if disable_safety_checker:
  49. pipe.safety_checker = None
  50. pipe.feature_extractor = None
  51. return pipe
  52. def get_torch_pipeline(model_name: str, disable_safety_checker: bool):
  53. from diffusers import StableDiffusionPipeline
  54. from torch import channels_last, float16
  55. pipe = StableDiffusionPipeline.from_pretrained(
  56. model_name, torch_dtype=float16, revision="fp16", use_auth_token=True
  57. ).to("cuda")
  58. pipe.unet.to(memory_format=channels_last) # in-place operation
  59. if disable_safety_checker:
  60. pipe.safety_checker = None
  61. pipe.feature_extractor = None
  62. return pipe
  63. def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, disable_safety_checker: bool):
  64. short_model_name = model_name.split("/")[-1].replace("stable-diffusion-", "sd")
  65. return f"{engine}_{short_model_name}_b{batch_size}" + ("" if disable_safety_checker else "_safe")
  66. def run_ort_pipeline(pipe, batch_size: int, image_filename_prefix: str):
  67. from diffusers import OnnxStableDiffusionPipeline
  68. assert isinstance(pipe, OnnxStableDiffusionPipeline)
  69. height, width, num_inference_steps, prompts = get_test_settings()
  70. pipe("warm up", height, width, num_inference_steps=2)
  71. latency_list = []
  72. for i, prompt in enumerate(prompts):
  73. input_prompts = [prompt] * batch_size
  74. inference_start = time.time()
  75. image = pipe(input_prompts, height, width, num_inference_steps).images[0]
  76. inference_end = time.time()
  77. latency = inference_end - inference_start
  78. latency_list.append(latency)
  79. print(f"Inference took {latency} seconds")
  80. image.save(f"{image_filename_prefix}_{i}.jpg")
  81. print("Average latency in seconds:", sum(latency_list) / len(latency_list))
  82. def run_torch_pipeline(pipe, batch_size: int, image_filename_prefix: str):
  83. import torch
  84. height, width, num_inference_steps, prompts = get_test_settings()
  85. pipe("warm up", height, width, num_inference_steps=2)
  86. torch.set_grad_enabled(False)
  87. latency_list = []
  88. for i, prompt in enumerate(prompts):
  89. input_prompts = [prompt] * batch_size
  90. torch.cuda.synchronize()
  91. inference_start = time.time()
  92. image = pipe(input_prompts, height, width, num_inference_steps).images[0]
  93. torch.cuda.synchronize()
  94. inference_end = time.time()
  95. latency = inference_end - inference_start
  96. latency_list.append(latency)
  97. print(f"Inference took {latency} seconds")
  98. image.save(f"{image_filename_prefix}_{i}.jpg")
  99. print("Average latency in seconds:", sum(latency_list) / len(latency_list))
  100. def run_ort(model_name: str, directory: str, provider: str, batch_size: int, disable_safety_checker: bool):
  101. load_start = time.time()
  102. pipe = get_ort_pipeline(model_name, directory, provider, disable_safety_checker)
  103. load_end = time.time()
  104. print(f"Model loading took {load_end - load_start} seconds")
  105. image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, disable_safety_checker)
  106. run_ort_pipeline(pipe, batch_size, image_filename_prefix)
  107. def run_torch(model_name: str, batch_size: int, disable_safety_checker: bool):
  108. import torch
  109. torch.backends.cudnn.enabled = True
  110. torch.backends.cudnn.benchmark = True
  111. # torch.backends.cuda.matmul.allow_tf32 = True
  112. torch.set_grad_enabled(False)
  113. load_start = time.time()
  114. pipe = get_torch_pipeline(model_name, disable_safety_checker)
  115. load_end = time.time()
  116. print(f"Model loading took {load_end - load_start} seconds")
  117. image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker)
  118. with torch.inference_mode():
  119. run_torch_pipeline(pipe, batch_size, image_filename_prefix)
  120. def parse_arguments():
  121. parser = argparse.ArgumentParser()
  122. parser.add_argument(
  123. "-e",
  124. "--engine",
  125. required=False,
  126. type=str,
  127. default="onnxruntime",
  128. choices=["onnxruntime", "torch"],
  129. help="Engines to benchmark. Default is onnxruntime.",
  130. )
  131. parser.add_argument(
  132. "-v",
  133. "--version",
  134. required=True,
  135. type=str,
  136. choices=list(SD_MODELS.keys()),
  137. help="Stable diffusion version like 1.5, 2.0 or 2.1",
  138. )
  139. parser.add_argument(
  140. "-p",
  141. "--pipeline",
  142. required=False,
  143. type=str,
  144. default=None,
  145. help="Directory of saved onnx pipeline. It could be output directory of optimize_pipeline.py.",
  146. )
  147. parser.add_argument(
  148. "--enable_safety_checker",
  149. required=False,
  150. action="store_true",
  151. help="Enable safety checker",
  152. )
  153. parser.set_defaults(enable_safety_checker=False)
  154. parser.add_argument("-b", "--batch_size", type=int, default=1)
  155. args = parser.parse_args()
  156. return args
  157. def main():
  158. args = parse_arguments()
  159. print(args)
  160. sd_model = SD_MODELS[args.version]
  161. if args.engine == "onnxruntime":
  162. assert args.pipeline, "--pipeline should be specified for onnxruntime engine"
  163. if args.batch_size > 1:
  164. # Need remove a line https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L307
  165. # in diffuers to run batch_size > 1.
  166. assert (
  167. args.enable_safety_checker
  168. ), "batch_size > 1 is not compatible with safety checker due to a bug in diffuers"
  169. provider = "CUDAExecutionProvider" # TODO: use ["CUDAExecutionProvider", "CPUExecutionProvider"] in diffuers
  170. run_ort(sd_model, args.pipeline, provider, args.batch_size, not args.enable_safety_checker)
  171. else:
  172. run_torch(sd_model, args.batch_size, not args.enable_safety_checker)
  173. if __name__ == "__main__":
  174. main()