m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
4.1 KiB

6 months ago
  1. #coding:utf8
  2. import queue_manager
  3. import logging
  4. from dataUtil import get_value
  5. import uuid, json, traceback
  6. import torch
  7. from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
  8. from langdetect import detect
  9. import global_dict
  10. import time
# Initialize the module-level logger (standard per-module logger convention).
logger = logging.getLogger(__name__)
  13. def translate_process():
  14. """独立线程处理队列中的 翻译 任务"""
  15. logger.info("翻译线程启动----")
  16. device = torch.device("cuda")
  17. try:
  18. # 加载模型和分词器
  19. logger.info("GPU 加载翻译模型")
  20. model = M2M100ForConditionalGeneration.from_pretrained("/opt/m2m100_1.2B/model").to(device)
  21. tokenizer = M2M100Tokenizer.from_pretrained("/opt/m2m100_1.2B/tokenizer")
  22. except Exception as e:
  23. logger.error(f"加载模型或分词器失败: {e}")
  24. return
  25. while global_dict.is_start:
  26. # 获取任务
  27. size = queue_manager.get_size()
  28. if size> 0 :
  29. task = queue_manager.get_task()
  30. else:
  31. logger.info('队列暂无任务-----')
  32. time.sleep(3)
  33. continue
  34. result, results = {}, {}
  35. try:
  36. logger.info('task size:{},task:{}'.format(queue_manager.get_size(),task))
  37. # 根据版本号判断
  38. scenes_id = str(task['scenes_id'])
  39. task_version = str(task['version'])
  40. cache_version = global_dict.global_scenes_manager[scenes_id]
  41. if not task_version == cache_version:
  42. logger.info('任务已暂停:{}'.format(task))
  43. continue
  44. preTrContent = get_value(task['data'], task['input']['content'])
  45. from_language = task['input']['fromLanguage']
  46. to_language = task['input']['toLanguage']
  47. # 1. 按句子切分
  48. text_chunks = split_text(preTrContent)
  49. # 2. 逐段翻译
  50. translated_chunks = []
  51. for chunk in text_chunks:
  52. translated_text = translate_text(model, tokenizer, chunk, from_language, to_language, device)
  53. translated_chunks.append(translated_text)
  54. # 3. 合并翻译结果
  55. translated_text = "".join(translated_chunks)
  56. results.update({
  57. 'isLast': True,
  58. 'content': translated_text,
  59. 'srcContent': preTrContent,
  60. 'id': str(uuid.uuid4())
  61. })
  62. result.update({
  63. 'results': json.dumps(results),
  64. 'status': 1,
  65. 'message': '成功'
  66. })
  67. task['result'] = result
  68. except Exception as e:
  69. logger.error(f"翻译失败: {e}")
  70. traceback.print_exc()
  71. results.update({
  72. 'isLast': True,
  73. 'id': str(uuid.uuid4())
  74. })
  75. result.update({
  76. 'results': json.dumps(results),
  77. 'status': 2,
  78. 'message': '翻译失败'
  79. })
  80. finally:
  81. # 标记任务完成并发送到 Kafka
  82. queue_manager.task_done(task)
  83. # 可选:清理缓存
  84. torch.cuda.empty_cache()
  85. def split_text(text):
  86. return [s.strip() for s in text.replace(".", ".\n").replace("!", "!\n").replace("?", "?\n").splitlines() if
  87. s.strip()]
  88. def translate_text(model, tokenizer, text_chunk, src_lang, tgt_lang, device):
  89. if 'auto' in src_lang:
  90. # 自动检测源语言
  91. src_lang = detect(text_chunk)
  92. logging.info('语种未知,模型自动识别语种为:{}'.format(src_lang))
  93. tokenizer.src_lang = src_lang
  94. try:
  95. with torch.no_grad(): # 禁用梯度计算
  96. encoded_input = tokenizer(text_chunk, return_tensors="pt", truncation=True, max_length=900)
  97. encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
  98. generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(tgt_lang))
  99. translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
  100. return translated_text
  101. except Exception as e:
  102. logger.error(f"翻译过程中出错: {e}")
  103. return ""