You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
136 lines
4.3 KiB
136 lines
4.3 KiB
#coding:utf8
|
|
import queue_manager
|
|
import logging
|
|
from cnocr import CnOcr
|
|
import onnxruntime as ort
|
|
from dataUtil import get_value
|
|
import uuid
|
|
import json
|
|
import requests
|
|
import os
|
|
from global_dict import global_scenes_manager
|
|
import global_dict
|
|
import time
|
|
# Module-level logger for this OCR worker module.
logger = logging.getLogger(__name__)


# Single CnOcr engine created once at import time; ocr_process() reuses it
# for every task instead of constructing a new engine per image.
ocr = CnOcr()
|
|
def ocr_process():
    """Worker loop: take OCR tasks from ``queue_manager`` and process them.

    Runs while ``global_dict.is_start`` is truthy. When the queue is empty it
    logs and sleeps 3 seconds before polling again. For each task it:

    * drops the task (``queue_manager.task_done`` is NOT called) when
      ``str(task['version'])`` differs from
      ``global_scenes_manager[str(task['scenes_id'])]``;
    * otherwise resolves the image URL with
      ``get_value(task['data'], task['input']['filePath'])``, downloads it to
      ``./files/<uuid>.<ext>``, runs ``ocr.ocr`` on it and concatenates the
      ``'text'`` field of every returned item;
    * stores a result dict in ``task['result']``. On success it has
      ``status`` 1 / ``'成功'``; on any exception ``status`` 2 /
      ``'识别失败'``. Its ``'results'`` value is a JSON string holding
      ``isLast``, ``id`` and, on success, ``content``;
    * deletes the downloaded file (if one was chosen) and passes the task to
      ``queue_manager.task_done``.
    """
    logger.info("ocr线程启动----")

    while global_dict.is_start:
        size = queue_manager.get_size()
        if size <= 0:
            logger.info('队列暂无任务-----')
            time.sleep(3)
            continue
        task = queue_manager.get_task()

        result = {}
        results = {}
        save_path = ''
        try:
            logger.info('task size:{},task:{}'.format(size, task))

            # Version check: a mismatch means the scene was paused/changed
            # after this task was queued, so the task is dropped.
            # NOTE(review): task_done() is skipped for dropped tasks --
            # confirm that is what queue_manager expects.
            scenes_id = str(task['scenes_id'])
            task_version = str(task['version'])
            cache_version = global_scenes_manager[scenes_id]
            if task_version != cache_version:
                logger.info('任务已暂停:{}'.format(task))
                continue

            img_path_url = get_value(task['data'], task['input']['filePath'])
            file_name = str(uuid.uuid4())
            extension = get_file_extension(img_path_url)
            save_path = './files/{}.{}'.format(file_name, extension)
            download_file(img_path_url, save_path)

            # Run OCR and join the recognised text fragments in order.
            logger.info("识别开始-----")
            identification_result = ocr.ocr(save_path)
            text = ''.join(item['text'] for item in identification_result)

            results['isLast'] = True
            results['content'] = text
            results['id'] = file_name
            result['results'] = json.dumps(results)
            result['status'] = 1
            result['message'] = '成功'
        except Exception as e:
            # logger.exception keeps the traceback, which the old
            # logger.error call discarded.
            logger.exception("Error processing OCR task: %s", e)
            results['isLast'] = True
            results['id'] = str(uuid.uuid4())
            result['results'] = json.dumps(results)
            result['status'] = 2
            result['message'] = '识别失败'

        # Attach the result on BOTH paths; previously the failure result was
        # built but never stored, so failed tasks were handed on without one.
        task['result'] = result

        # Remove the downloaded image (no path exists if the task failed
        # before one was chosen), then mark the task done. The original
        # comment says task_done forwards the task to Kafka.
        if save_path:
            delete_file(save_path)
        queue_manager.task_done(task)

    logger.info("执行线程安全退出-----")
|
|
|
|
|
|
def download_file(url, save_path, timeout=30):
    """Download ``url`` and stream the response body into ``save_path``.

    Network and HTTP errors (``requests.exceptions.RequestException``) are
    logged and swallowed. These include non-2xx statuses via
    ``raise_for_status``, timeouts, and errors while streaming. A failure
    mid-stream may leave a partially written file at ``save_path``. Errors
    opening or writing ``save_path`` itself (``OSError``) propagate to the
    caller.

    :param url: download URL of the file
    :param save_path: full destination path, including the file name
    :param timeout: seconds passed to ``requests.get`` as its connect/read
        timeout, so a stalled server cannot block the caller indefinitely
    :return: True if the body was written completely, False if the download
        failed. The previous version always returned None, so callers that
        ignore the result are unaffected.
    """
    try:
        # stream=True avoids loading the whole body into memory; the ``with``
        # releases the connection even when raise_for_status() raises.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(save_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip empty chunks
                        file.write(chunk)
    except requests.exceptions.RequestException as e:
        logger.error(f"文件下载失败: {e}")
        return False

    logger.info(f"文件已成功下载并保存到: {save_path}")
    return True
|
|
|
|
|
|
def get_file_extension(url):
    """Return the extension (without the dot) of the file named by ``url``.

    Only the last segment of the URL *path* is inspected. Dots in the host
    name, in directory names, in the query string or in the fragment are
    ignored. The previous implementation used the last '.' anywhere in the
    URL, so ``x.jpg?v=1.2`` yielded ``'2'``. Returns ``''`` when the final
    path segment has no extension.

    Examples::

        >>> get_file_extension('http://host.com/a/b.jpg?v=1.2#x.y')
        'jpg'
        >>> get_file_extension('http://host.com/dir.v2/file')
        ''

    :param url: URL (or plain relative/absolute path) of the file
    :return: the extension with its original case, without the leading '.'
    """
    import posixpath
    from urllib.parse import urlsplit

    # urlsplit separates "?query" and "#fragment" from the path; URL paths
    # always use '/', so posixpath is correct regardless of the host OS.
    path = urlsplit(url).path
    _, ext = posixpath.splitext(posixpath.basename(path))
    return ext[1:]
|
|
|
|
|
|
def delete_file(file_path):
    """Remove ``file_path`` from disk, logging rather than raising.

    If ``os.path.exists`` reports the path as absent, only a warning is
    logged. Any exception raised while checking or removing the path is
    caught and logged as an error instead of propagating.

    :param file_path: path of the file to remove
    :return: None
    """
    try:
        # Guard clause: nothing to remove.
        if not os.path.exists(file_path):
            logger.warning(f"文件 '{file_path}' 不存在。")
            return

        os.remove(file_path)
        logger.info(f"文件 '{file_path}' 已成功删除。")
    except Exception as e:
        logger.error(f"删除文件 '{file_path}' 时发生错误: {e}")
|