M2M model translation

# coding: utf-8
import json
import logging
import time
import traceback
import uuid

import torch
from langdetect import detect
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

import global_dict
import queue_manager
from dataUtil import get_value

# Initialize logging
logger = logging.getLogger(__name__)


def translate_process():
    """Worker thread that processes translation tasks from the queue."""
    logger.info("Translation thread started ----")
    # Prefer the GPU, but fall back to CPU so the worker can still start
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # Load the model and the tokenizer
        logger.info(f"Loading the translation model onto {device}")
        model = M2M100ForConditionalGeneration.from_pretrained("/opt/m2m100_1.2B/model").to(device)
        tokenizer = M2M100Tokenizer.from_pretrained("/opt/m2m100_1.2B/tokenizer")
    except Exception as e:
        logger.error(f"Failed to load the model or tokenizer: {e}")
        return

    while global_dict.is_start:
        # Fetch a task from the queue
        if queue_manager.get_size() > 0:
            task = queue_manager.get_task()
        else:
            logger.info('No tasks in the queue -----')
            time.sleep(3)
            continue

        result, results = {}, {}
        try:
            logger.info('task size: {}, task: {}'.format(queue_manager.get_size(), task))
            # Compare the task version with the cached scene version
            scenes_id = str(task['scenes_id'])
            task_version = str(task['version'])
            cache_version = global_dict.global_scenes_manager[scenes_id]
            if task_version != cache_version:
                logger.info('Task has been paused: {}'.format(task))
                continue

            preTrContent = get_value(task['data'], task['input']['content'])
            from_language = task['input']['fromLanguage']
            to_language = task['input']['toLanguage']

            # 1. Split the text into sentences
            text_chunks = split_text(preTrContent)
            # 2. Translate chunk by chunk
            translated_chunks = []
            for chunk in text_chunks:
                translated_chunks.append(
                    translate_text(model, tokenizer, chunk, from_language, to_language, device))
            # 3. Merge the translated chunks (joined without a separator;
            #    insert one if space-delimited target languages need it)
            translated_text = "".join(translated_chunks)

            results.update({
                'isLast': True,
                'content': translated_text,
                'srcContent': preTrContent,
                'id': str(uuid.uuid4())
            })
            result.update({
                'results': json.dumps(results),
                'status': 1,
                'message': 'success'
            })
            task['result'] = result
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            traceback.print_exc()
            results.update({
                'isLast': True,
                'id': str(uuid.uuid4())
            })
            result.update({
                'results': json.dumps(results),
                'status': 2,
                'message': 'translation failed'
            })
        finally:
            # Mark the task as done and send it to Kafka
            queue_manager.task_done(task)
            # Optional: release cached GPU memory
            torch.cuda.empty_cache()


def split_text(text):
    """Split text into sentence-sized chunks on '.', '!' and '?'."""
    marked = text.replace(".", ".\n").replace("!", "!\n").replace("?", "?\n")
    return [s.strip() for s in marked.splitlines() if s.strip()]
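# For illustration (hypothetical input), split_text behaves like:
#   split_text("Hello world. How are you? Fine!")
#   -> ['Hello world.', 'How are you?', 'Fine!']
# The naive splitting also breaks on dots inside numbers or abbreviations
# (e.g. "3.14"); that is tolerable here, since chunks only need to stay
# under the tokenizer's max_length used below.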


def translate_text(model, tokenizer, text_chunk, src_lang, tgt_lang, device):
    if 'auto' in src_lang:
        # Auto-detect the source language.
        # Note: langdetect can return region codes such as 'zh-cn' that
        # M2M100 does not know; these may need mapping (e.g. to 'zh').
        src_lang = detect(text_chunk)
        logger.info('Source language unknown; auto-detected as: {}'.format(src_lang))
    tokenizer.src_lang = src_lang
    try:
        with torch.no_grad():  # Disable gradient computation
            encoded_input = tokenizer(text_chunk, return_tensors="pt", truncation=True, max_length=900)
            encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
            generated_tokens = model.generate(**encoded_input,
                                              forced_bos_token_id=tokenizer.get_lang_id(tgt_lang))
            translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            return translated_text
    except Exception as e:
        logger.error(f"Error during translation: {e}")
        return ""