328 lines
11 KiB
Python
328 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
||
import base64
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
import os
|
||
import re
|
||
import socket
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from http.client import HTTPSConnection
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
# Shared value type for the coordinate dictionaries returned by Tencent OCR.
_IntDict = Dict[str, int]

# One polygon vertex, e.g. {"X": 248, "Y": 1701}.
Point = _IntDict

# Axis-aligned box with "X", "Y", "Width" and "Height" keys.
ItemPolygon = _IntDict
|
||
|
||
|
||
class TencentOCR:
    """Tencent Cloud OCR client plus helpers for reading its ``TextDetections``.

    ``recognize`` signs requests with TC3-HMAC-SHA256 and posts them over HTTPS.
    Credentials come from the environment variables ``TENCENT_SECRET_ID`` /
    ``TENCENT_SECRET_KEY`` or, for whichever is missing, from
    ``~/.tencent_ocr.json``.
    """

    # Fallback credential file: {"secret_id": "...", "secret_key": "..."}.
    _CONFIG_PATH = "~/.tencent_ocr.json"

    # "Looks like a username": 3+ ASCII letters, digits, '_', '.' or '-'.
    _USERNAME_RE = re.compile(r"^[A-Za-z0-9_.-]{3,}$")

    # ------------------------------------------------------------------ #
    # Credentials / small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _load_secret() -> Dict[str, str]:
        """Return ``{"secret_id": ..., "secret_key": ...}``.

        Lookup order: environment variables ``TENCENT_SECRET_ID`` /
        ``TENCENT_SECRET_KEY`` first. For any value still missing, the
        ``secret_id`` / ``secret_key`` fields of ``~/.tencent_ocr.json`` are used.

        :raises RuntimeError: if either value cannot be found.
        """
        # SECURITY: never hard-code credentials in this file. Any key pair that
        # was ever committed must be treated as leaked, revoked and rotated.
        sid = os.environ.get("TENCENT_SECRET_ID")
        skey = os.environ.get("TENCENT_SECRET_KEY")

        if not sid or not skey:
            cfg_path = os.path.expanduser(TencentOCR._CONFIG_PATH)
            if os.path.exists(cfg_path):
                with open(cfg_path, "r", encoding="utf-8") as f:
                    cfg = json.load(f)
                sid = sid or cfg.get("secret_id")
                skey = skey or cfg.get("secret_key")

        if not sid or not skey:
            raise RuntimeError(
                "❌ 未找到腾讯云 OCR 密钥,请设置环境变量 TENCENT_SECRET_ID / TENCENT_SECRET_KEY,"
                "或在用户目录下创建 ~/.tencent_ocr.json(格式:{\"secret_id\":\"...\",\"secret_key\":\"...\"})"
            )

        return {"secret_id": sid, "secret_key": skey}

    @staticmethod
    def _hmac_sha256(key: bytes, msg: str) -> bytes:
        """Return the raw HMAC-SHA256 digest of the UTF-8 encoded ``msg``."""
        return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()

    @staticmethod
    def _strip_data_uri_prefix(b64: str) -> str:
        """Drop a ``data:...;base64,`` prefix if present; otherwise return ``b64`` unchanged."""
        if "," in b64 and b64.strip().lower().startswith("data:"):
            return b64.split(",", 1)[1]
        return b64

    @staticmethod
    def _now_ts_and_date(ts: Optional[int] = None):
        """Return ``(unix_seconds, "YYYY-MM-DD")``.

        The date is the UTC calendar date of ``ts``, which defaults to the
        current time.
        """
        if ts is None:
            ts = int(time.time())
        date = datetime.fromtimestamp(ts, tz=timezone.utc).strftime("%Y-%m-%d")
        return ts, date

    # ------------------------------------------------------------------ #
    # API call
    # ------------------------------------------------------------------ #

    @staticmethod
    def recognize(
        *,
        image_path: Optional[str] = None,
        image_bytes: Optional[bytes] = None,
        image_url: Optional[str] = None,
        region: Optional[str] = None,
        token: Optional[str] = None,
        action: str = "GeneralBasicOCR",
        version: str = "2018-11-19",
        service: str = "ocr",
        host: str = "ocr.tencentcloudapi.com",
        timeout: int = 15,
    ) -> Dict[str, Any]:
        """Call a Tencent Cloud OCR action on exactly one image source.

        Exactly one of ``image_path`` / ``image_bytes`` / ``image_url`` must be
        given. Credentials are loaded by ``_load_secret`` (env vars first, then
        ``~/.tencent_ocr.json``).

        :param image_path: local file; its bytes are sent base64-encoded.
        :param image_bytes: raw image bytes, sent base64-encoded.
        :param image_url: image URL, sent as ``ImageUrl``.
        :param region: optional ``X-TC-Region`` header value.
        :param token: optional ``X-TC-Token`` header value.
        :param action: API action, e.g. ``GeneralBasicOCR`` / ``GeneralAccurateOCR``.
        :param version: value of the ``X-TC-Version`` header.
        :param service: service name used in the signing credential scope.
        :param host: API endpoint host.
        :param timeout: socket timeout in seconds.
        :return: when a response arrives, ``{"http_status", "http_reason",
            "headers", "body"}``, where ``body`` is the parsed JSON or
            ``{"NonJSONBody": raw_text}``. On a transport failure, returns
            ``{"error": <code>, "detail": <message>}`` instead of raising.
        :raises ValueError: if not exactly one image source is supplied.
        :raises RuntimeError: if no credentials can be found.
        """
        # Validate before touching credentials or the filesystem.
        if sum(v is not None for v in (image_path, image_bytes, image_url)) != 1:
            raise ValueError("必须且只能提供 image_path / image_bytes / image_url 之一")

        sec = TencentOCR._load_secret()
        payload_str = TencentOCR._build_payload(image_path, image_bytes, image_url)
        headers = TencentOCR._build_headers(
            payload_str,
            secret_id=sec["secret_id"],
            secret_key=sec["secret_key"],
            action=action,
            version=version,
            service=service,
            host=host,
            region=region,
            token=token,
        )
        return TencentOCR._post(host, payload_str, headers, timeout)

    @staticmethod
    def _build_payload(
        image_path: Optional[str],
        image_bytes: Optional[bytes],
        image_url: Optional[str],
    ) -> str:
        """Serialize the JSON body: ``ImageUrl`` if given, else base64 ``ImageBase64``."""
        payload: Dict[str, Any] = {}
        if image_url is not None:
            payload["ImageUrl"] = image_url
        else:
            if image_bytes is None:
                with open(image_path, "rb") as f:
                    image_bytes = f.read()
            payload["ImageBase64"] = base64.b64encode(image_bytes).decode("utf-8")
        # Compact separators; these exact bytes are both hashed and sent.
        return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))

    @staticmethod
    def _build_headers(
        payload_str: str,
        *,
        secret_id: str,
        secret_key: str,
        action: str,
        version: str,
        service: str,
        host: str,
        region: Optional[str] = None,
        token: Optional[str] = None,
        timestamp: Optional[int] = None,
    ) -> Dict[str, str]:
        """Return the TC3-HMAC-SHA256 signed headers for ``POST /`` with ``payload_str``.

        ``timestamp`` (Unix seconds) defaults to now. It is a parameter so that a
        signature can be reproduced deterministically.
        """
        algorithm = "TC3-HMAC-SHA256"
        http_method = "POST"
        canonical_uri = "/"
        canonical_querystring = ""
        content_type = "application/json; charset=utf-8"
        signed_headers = "content-type;host;x-tc-action"

        timestamp, date = TencentOCR._now_ts_and_date(timestamp)
        credential_scope = f"{date}/{service}/tc3_request"

        # 1. Canonical request. The action value is lower-cased here only; the
        #    X-TC-Action header below carries it unchanged.
        canonical_headers = (
            f"content-type:{content_type}\n"
            f"host:{host}\n"
            f"x-tc-action:{action.lower()}\n"
        )
        hashed_request_payload = hashlib.sha256(payload_str.encode("utf-8")).hexdigest()
        canonical_request = (
            f"{http_method}\n{canonical_uri}\n{canonical_querystring}\n"
            f"{canonical_headers}\n{signed_headers}\n{hashed_request_payload}"
        )

        # 2. String to sign.
        hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
        string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}"

        # 3. Signing key: HMAC chain over date -> service -> "tc3_request".
        secret_date = TencentOCR._hmac_sha256(("TC3" + secret_key).encode("utf-8"), date)
        secret_service = TencentOCR._hmac_sha256(secret_date, service)
        secret_signing = TencentOCR._hmac_sha256(secret_service, "tc3_request")
        signature = TencentOCR._hmac_sha256(secret_signing, string_to_sign).hex()

        authorization = (
            f"{algorithm} "
            f"Credential={secret_id}/{credential_scope}, "
            f"SignedHeaders={signed_headers}, "
            f"Signature={signature}"
        )

        headers = {
            "Authorization": authorization,
            "Content-Type": content_type,
            "Host": host,
            "X-TC-Action": action,
            "X-TC-Timestamp": str(timestamp),
            "X-TC-Version": version,
        }
        if region:
            headers["X-TC-Region"] = region
        if token:
            headers["X-TC-Token"] = token
        return headers

    @staticmethod
    def _post(host: str, payload_str: str, headers: Dict[str, str], timeout: int) -> Dict[str, Any]:
        """POST ``payload_str`` to ``https://{host}/`` and wrap the outcome in a dict.

        Transport errors are reported as ``{"error": ..., "detail": ...}`` rather
        than raised.
        """
        conn: Optional[HTTPSConnection] = None
        try:
            conn = HTTPSConnection(host, timeout=timeout)
            # Method and path must match the canonical request used for signing.
            conn.request("POST", "/", body=payload_str.encode("utf-8"), headers=headers)
            resp = conn.getresponse()
            raw = resp.read().decode("utf-8", errors="replace")
            try:
                data: Any = json.loads(raw)
            except ValueError:
                # Non-JSON bodies are still handed back to the caller.
                data = {"NonJSONBody": raw}
            return {
                "http_status": resp.status,
                "http_reason": resp.reason,
                "headers": dict(resp.getheaders()),
                "body": data,
            }
        except socket.gaierror as e:
            return {"error": "DNS_RESOLUTION_FAILED", "detail": str(e)}
        except socket.timeout:
            return {"error": "NETWORK_TIMEOUT", "detail": f"Timeout after {timeout}s"}
        except Exception as e:  # deliberate: every transport failure becomes an error dict
            return {"error": "REQUEST_FAILED", "detail": str(e)}
        finally:
            if conn is not None:
                try:
                    conn.close()
                except OSError:
                    # Best-effort cleanup; must not mask the result returned above.
                    pass

    # ------------------------------------------------------------------ #
    # Result post-processing
    # ------------------------------------------------------------------ #

    @staticmethod
    def _norm(s: str) -> str:
        """Normalize a name for matching: strip whitespace, drop leading '@', lower-case."""
        return (s or "").strip().lstrip("@").lower()

    @staticmethod
    def _rect_from_polygon(poly: "List[Point]") -> "Optional[ItemPolygon]":
        """Return the axis-aligned bounding box of ``poly``, or ``None`` if it is empty."""
        if not poly:
            return None
        xs = [p["X"] for p in poly]
        ys = [p["Y"] for p in poly]
        return {"X": min(xs), "Y": min(ys), "Width": max(xs) - min(xs), "Height": max(ys) - min(ys)}

    @classmethod
    def find_last_name_bbox(cls, ocr: Dict[str, Any], name: str) -> Optional[Dict[str, Any]]:
        """Locate the *last* detection whose text matches ``name`` and return its box.

        Texts are compared exactly after ``_norm`` (surrounding whitespace and
        leading '@' removed, lower-cased).

        :param ocr: OCR JSON, raw (``{"Response": ...}``) or as wrapped by
            ``recognize`` (``{"body": {"Response": ...}}``).
        :param name: name to look for, e.g. ``'lee39160'``.
        :return: ``None`` if there is no match or no box can be derived. Otherwise
            a dict like::

                {
                    "index": 21,
                    "text": "lee39160",
                    "item": {"X": 248, "Y": 1701, "Width": 214, "Height": 49},
                    "polygon": [...],
                    "center": {"x": 355.0, "y": 1725.5},
                }

            ``item`` is the detection's ``ItemPolygon`` when present, otherwise
            the bounding box of its ``Polygon``. ``center`` is the centre of ``item``.
        """
        dets = cls._get_detections(ocr)
        if not dets or not name:
            return None

        target = cls._norm(name)
        if not target:
            # e.g. name == "@": an empty target would match blank detections.
            return None

        # Scan from the end so the last match wins.
        found = next(
            (i for i in range(len(dets) - 1, -1, -1)
             if cls._norm(dets[i].get("DetectedText", "")) == target),
            -1,
        )
        if found == -1:
            return None

        det = dets[found]
        item: "Optional[ItemPolygon]" = det.get("ItemPolygon")
        poly: "List[Point]" = det.get("Polygon") or []

        # Fall back to the polygon's bounding box when ItemPolygon is absent.
        if not item:
            item = cls._rect_from_polygon(poly)
        if not item:
            return None

        center = {"x": item["X"] + item["Width"] / 2.0, "y": item["Y"] + item["Height"] / 2.0}

        return {
            "index": found,
            "text": det.get("DetectedText", ""),
            "item": item,
            "polygon": poly,
            "center": center,
        }

    @staticmethod
    def _get_detections(ocr: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return ``Response.TextDetections`` from raw or ``recognize``-wrapped OCR JSON.

        Returns ``[]`` when any level is missing or has an unexpected type, e.g.
        an error dict from ``recognize`` or ``"Response": None``.
        """
        if not isinstance(ocr, dict):
            return []
        root = ocr.get("body") or ocr
        if not isinstance(root, dict):
            return []
        response = root.get("Response")
        if not isinstance(response, dict):
            return []
        dets = response.get("TextDetections")
        return dets if isinstance(dets, list) else []

    @staticmethod
    def _norm_txt(s: str) -> str:
        """Clean detected text: strip surrounding whitespace (``None`` -> "")."""
        return (s or "").strip()

    @classmethod
    def slice_texts_between(
        cls,
        ocr: Dict[str, Any],
        start_keyword: str = "切换账号",
        end_keyword: str = "添加账号",
        *,
        username_like: bool = False,  # keep only texts that look like usernames
        min_conf: int = 0,  # minimum Confidence
    ) -> List[Dict[str, Any]]:
        """Return detections strictly between ``start_keyword`` and ``end_keyword``.

        The range starts after the *last* exact match of ``start_keyword``. It
        ends before the first exact match of ``end_keyword`` that follows that
        start. Texts are compared after ``strip()``.

        Detections with empty text, or with ``Confidence`` below ``min_conf``,
        are dropped. The remaining detection dicts are returned unchanged (same
        objects, original order).

        :param ocr: raw or ``recognize``-wrapped OCR JSON.
        :param start_keyword: opening marker, default "切换账号" ("switch account").
        :param end_keyword: closing marker, default "添加账号" ("add account").
        :param username_like: if true, keep only texts of 3+ ASCII letters,
            digits, '_', '.' or '-'.
        :param min_conf: minimum ``Confidence``; a missing/None value counts as 0.
        :return: list of detection dicts; ``[]`` if either marker is missing.
        """
        dets = cls._get_detections(ocr)
        if not dets:
            return []

        texts = [cls._norm_txt(d.get("DetectedText", "")) for d in dets]

        # Last occurrence of the start marker.
        start_idx = -1
        for i, txt in enumerate(texts):
            if txt == start_keyword:
                start_idx = i
        if start_idx == -1:
            return []

        # First end marker *after* the start, so that a stray end marker earlier
        # in the list cannot hide a valid range.
        end_idx = next(
            (i for i in range(start_idx + 1, len(texts)) if texts[i] == end_keyword),
            -1,
        )
        if end_idx == -1:
            return []

        mid = []
        for d, txt in zip(dets[start_idx + 1:end_idx], texts[start_idx + 1:end_idx]):
            if int(d.get("Confidence") or 0) < min_conf:
                continue
            if not txt:
                continue
            if username_like and not cls._USERNAME_RE.match(txt):
                continue
            mid.append(d)
        return mid
|
||
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    import sys

    # Image to recognize: the first CLI argument if given, otherwise the
    # original developer-local sample path.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"C:\Users\zhangkai\Desktop\last-item\iosai\test.png"
    )
    result = TencentOCR.recognize(
        image_path=image_path,
        action="GeneralAccurateOCR",
    )
    print(json.dumps(result, ensure_ascii=False, indent=2))
|
||
|