Files
iOSAI/Utils/OCRUtils.py
2025-11-19 17:23:41 +08:00

243 lines
8.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
from typing import List, Tuple, Union, Optional
from PIL import Image
# Image input accepted by OCRUtils: a NumPy array, a file path string,
# or a PIL image object.
ArrayLikeImage = Union[np.ndarray, str, Image.Image]
class OCRUtils:
    """Template-matching helpers: grayscale conversion, NMS and template search."""

    @classmethod
    def _to_gray(cls, img: "ArrayLikeImage") -> np.ndarray:
        """Convert a path, ``np.ndarray`` or ``PIL.Image`` into a grayscale array.

        Args:
            img: File path, NumPy array (2D gray, or 3D treated as BGR), or
                PIL image.

        Returns:
            A 2D grayscale ``np.ndarray``. A 2D input array is returned as-is
            (no copy).

        Raises:
            FileNotFoundError: The path could not be decoded by ``cv2.imread``.
            ValueError: The array is neither 2D nor 3D.
            TypeError: ``img`` is none of the supported types.
        """
        # NumPy array
        if isinstance(img, np.ndarray):
            if img.ndim == 2:
                return img  # already grayscale
            if img.ndim == 3:
                if img.shape[2] == 1:
                    # (H, W, 1) is already gray; BGR2GRAY would reject it,
                    # so just drop the channel axis.
                    return np.ascontiguousarray(img[:, :, 0])
                return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            raise ValueError("不支持的图像维度(期望 2D 灰度或 3D BGR/RGB")

        # File path
        if isinstance(img, str):
            arr = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
            if arr is None:
                raise FileNotFoundError(f"图像加载失败,请检查路径: {img}")
            return arr

        # PIL.Image
        if isinstance(img, Image.Image):
            rgb = np.array(img.convert("RGB"))
            return cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

        raise TypeError("large_image 类型必须是 str / np.ndarray / PIL.Image.Image")

    @classmethod
    def non_max_suppression(
        cls,
        boxes: List[List[float]],
        scores: Optional[np.ndarray] = None,
        overlapThresh: float = 0.5
    ) -> np.ndarray:
        """Greedy non-maximum suppression over axis-aligned boxes.

        Boxes are visited from best to worst. Each visited box is kept and
        every remaining box whose overlap with it exceeds ``overlapThresh``
        is discarded. The overlap measure is
        ``intersection / area(remaining box)`` (not IoU), with widths and
        heights computed pixel-inclusively (``x2 - x1 + 1``).

        Args:
            boxes: ``[[x1, y1, x2, y2], ...]``.
            scores: Per-box confidence; higher is visited first. When
                ``None``, boxes are visited by descending ``y2`` instead.
            overlapThresh: Maximum overlap ratio a box may have with an
                already-kept box and still survive.

        Returns:
            Integer ``np.ndarray`` of shape ``(N, 4)`` holding the kept boxes
            in the order they were selected.

        Raises:
            ValueError: ``boxes`` is not ``(N, 4)`` or ``scores`` does not
                have one entry per box.
        """
        boxes = np.asarray(boxes, dtype=np.float32)
        if boxes.size == 0:
            return np.empty((0, 4), dtype=int)
        if boxes.ndim != 2 or boxes.shape[1] != 4:
            raise ValueError(
                f"boxes must have shape (N, 4), got {boxes.shape}"
            )

        x1, y1, x2, y2 = boxes.T
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)

        if scores is None:
            # Classic variant: the box with the largest y2 goes first.
            order = np.argsort(y2)[::-1]
        else:
            scores = np.asarray(scores, dtype=np.float32).ravel()
            if scores.shape[0] != boxes.shape[0]:
                raise ValueError(
                    f"scores has {scores.shape[0]} entries "
                    f"but there are {boxes.shape[0]} boxes"
                )
            order = np.argsort(scores)[::-1]  # highest score first

        keep = []
        while order.size > 0:
            best, rest = order[0], order[1:]
            keep.append(best)

            # Intersection rectangle of `best` with every remaining box.
            inter_w = np.maximum(
                0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1
            )
            inter_h = np.maximum(
                0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1
            )
            overlap = (inter_w * inter_h) / areas[rest]

            # Only boxes that do not overlap `best` too much stay in play.
            order = rest[overlap <= overlapThresh]

        return boxes[keep].astype(int)

    @classmethod
    def find_template(
        cls,
        template_path: str,
        large_image: "ArrayLikeImage",
        threshold: float = 0.8,
        overlapThresh: float = 0.5,
        return_boxes: bool = False
    ) -> Union[List[Tuple[int, int]], Tuple[List[Tuple[int, int]], np.ndarray]]:
        """Locate every occurrence of a template image inside ``large_image``.

        The match score map (``cv2.TM_CCOEFF_NORMED``) is computed once.
        Locations scoring at least ``threshold`` become candidate boxes,
        which are then thinned with score-ordered NMS.

        Args:
            template_path: Path to the template image file.
            large_image: Image to search: file path, ``np.ndarray`` or
                ``PIL.Image``.
            threshold: Minimum match score for a location to be a candidate.
            overlapThresh: Overlap threshold passed to
                ``non_max_suppression``.
            return_boxes: When ``True``, also return the kept boxes.

        Returns:
            ``centers`` or ``(centers, boxes)``, where ``centers`` is
            ``[(cx, cy), ...]`` of Python ints and ``boxes`` is an integer
            ``np.ndarray`` of shape ``(N, 4)`` as ``[x1, y1, x2, y2]``.
            Both are empty when nothing matches.

        Raises:
            FileNotFoundError: The template file is missing or undecodable.
            ValueError: The template file is empty (0 bytes).
        """
        if not os.path.isfile(template_path):
            print(f"模板文件不存在 → {template_path}")
            raise FileNotFoundError(f"模板文件不存在 → {template_path}")

        if os.path.getsize(template_path) == 0:
            print(f"模板文件大小为 0 → {template_path} ")
            raise ValueError(f"模板文件大小为 0 → {template_path}")

        # Template (grayscale)
        template = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
        if template is None:
            raise FileNotFoundError(f"模板图像加载失败,请检查路径: {template_path}")

        # Search image (grayscale)
        gray = cls._to_gray(large_image)

        th, tw = template.shape[:2]

        # The score map depends only on the two images, never on the
        # threshold, so compute it a single time and re-threshold on retry.
        result = cv2.matchTemplate(gray, template, cv2.TM_CCOEFF_NORMED)

        # ===== Retry control =====
        MAX_RETRIES = 3
        THRESHOLD_DECAY = 0.0  # set to e.g. 0.02-0.05 to relax the threshold on each retry
        MIN_THRESHOLD = 0.6    # relaxation never goes below this floor

        centers: List[Tuple[int, int]] = []
        boxes_nms = np.empty((0, 4), dtype=int)
        cur_threshold = float(threshold)

        for _ in range(MAX_RETRIES):
            ys, xs = np.where(result >= cur_threshold)
            if xs.size:
                candidates = np.stack([xs, ys, xs + tw, ys + th], axis=1)
                boxes_nms = cls.non_max_suppression(
                    candidates,
                    scores=result[ys, xs],
                    overlapThresh=overlapThresh
                )
                centers = [
                    (int((x1 + x2) // 2), int((y1 + y2) // 2))
                    for (x1, y1, x2, y2) in boxes_nms
                ]
                break

            # Relax the threshold for the next attempt, but never raise it
            # (a caller-supplied threshold may already be below the floor).
            next_threshold = cur_threshold
            if THRESHOLD_DECAY > 0.0:
                next_threshold = min(
                    cur_threshold,
                    max(MIN_THRESHOLD, cur_threshold - THRESHOLD_DECAY)
                )
            if next_threshold == cur_threshold:
                # Same inputs and same threshold give the same empty result;
                # retrying would be wasted work.
                break
            cur_threshold = next_threshold

        if return_boxes:
            return centers, boxes_nms
        return centers