# coding: utf-8
import queue_manager
import logging
from dataUtil import get_value
import uuid
import json
import traceback
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from langdetect import detect
import global_dict
import time

# Initialize the module-level logger
logger = logging.getLogger(__name__)


def translate_process():
    """Worker loop: processes translation tasks from the queue in a dedicated thread."""
    logger.info("Translation thread started ----")
    device = torch.device("cuda")  # GPU required; model loading below fails if CUDA is unavailable

    try:
        # Load the model and tokenizer onto the GPU
        logger.info("Loading translation model on GPU")
        model = M2M100ForConditionalGeneration.from_pretrained("/opt/m2m100_1.2B/model").to(device)
        tokenizer = M2M100Tokenizer.from_pretrained("/opt/m2m100_1.2B/tokenizer")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {e}")
        return

    while global_dict.is_start:
        # Fetch the next task; if the queue is empty, wait and poll again
        size = queue_manager.get_size()
        if size > 0:
            task = queue_manager.get_task()
        else:
            logger.info('No tasks in the queue yet -----')
            time.sleep(3)
            continue
        result, results = {}, {}
        try:
            logger.info('task size:{}, task:{}'.format(queue_manager.get_size(), task))

            # Check the version number: skip tasks whose version no longer
            # matches the cached version for their scene
            scenes_id = str(task['scenes_id'])
            task_version = str(task['version'])
            cache_version = global_dict.global_scenes_manager[scenes_id]
            if task_version != cache_version:
                logger.info('Task has been paused: {}'.format(task))
                continue

            preTrContent = get_value(task['data'], task['input']['content'])
            from_language = task['input']['fromLanguage']
            to_language = task['input']['toLanguage']

            # 1. Split the text into sentence-sized chunks
            text_chunks = split_text(preTrContent)

            # 2. Translate chunk by chunk
            translated_chunks = []
            for chunk in text_chunks:
                translated_text = translate_text(model, tokenizer, chunk, from_language, to_language, device)
                translated_chunks.append(translated_text)

            # 3. Merge the results (chunks are concatenated without a separator)
            translated_text = "".join(translated_chunks)
            results.update({
                'isLast': True,
                'content': translated_text,
                'srcContent': preTrContent,
                'id': str(uuid.uuid4())
            })
            result.update({
                'results': json.dumps(results),
                'status': 1,
                'message': 'success'
            })
            task['result'] = result
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            traceback.print_exc()
            results.update({
                'isLast': True,
                'id': str(uuid.uuid4())
            })
            result.update({
                'results': json.dumps(results),
                'status': 2,
                'message': 'translation failed'
            })
            task['result'] = result  # attach the failure result so downstream consumers see it
        finally:
            # Mark the task done and send it to Kafka
            queue_manager.task_done(task)
            # Optional: release cached GPU memory between tasks
            torch.cuda.empty_cache()
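
# Expected task shape, inferred from the fields accessed above (illustrative only):
#   {
#       "scenes_id": ...,           # scene the task belongs to
#       "version": ...,             # compared against the cached scene version
#       "data": {...},              # payload passed to dataUtil.get_value
#       "input": {
#           "content": ...,         # locates the source text inside "data"
#           "fromLanguage": "en",   # source language; any value containing 'auto' triggers detection
#           "toLanguage": "fr"      # target language code
#       }
#   }
# On completion the loop attaches task["result"] = {"results": <JSON string>,
# "status": 1 (success) or 2 (failure), "message": ...}.
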
def split_text(text):
    # Naive sentence splitter: insert a newline after '.', '!' and '?',
    # then split into lines and drop empty pieces
    return [s.strip()
            for s in text.replace(".", ".\n").replace("!", "!\n").replace("?", "?\n").splitlines()
            if s.strip()]
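
# For example (the split is purely punctuation-based, so abbreviations like
# "e.g." are broken up as well):
#   split_text("Hi. How are you? Great!")  ->  ['Hi.', 'How are you?', 'Great!']

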
def translate_text(model, tokenizer, text_chunk, src_lang, tgt_lang, device):
    if 'auto' in src_lang:
        # Source language unknown: auto-detect it from the chunk
        src_lang = detect(text_chunk)
        logger.info('Language unknown; auto-detected source language: {}'.format(src_lang))
    tokenizer.src_lang = src_lang
    try:
        with torch.no_grad():  # disable gradient computation for inference
            encoded_input = tokenizer(text_chunk, return_tensors="pt", truncation=True, max_length=900)
            encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
            generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(tgt_lang))
            translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            return translated_text
    except Exception as e:
        logger.error(f"Error during translation: {e}")
        return ""
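

if __name__ == "__main__":
    # Minimal launcher sketch, not part of the original service wiring: it assumes
    # queue_manager and global_dict are configured elsewhere, and that a CUDA GPU
    # plus the model files under /opt/m2m100_1.2B are available.
    import threading

    logging.basicConfig(level=logging.INFO)
    global_dict.is_start = True  # run flag checked by the worker loop above
    worker = threading.Thread(target=translate_process, daemon=True)
    worker.start()
    worker.join()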