# ocr_service.py
  1. # ==================== ocr_service.py ====================
  2. import cv2
  3. import easyocr
  4. import numpy as np
  5. import re
  6. import requests
  7. from io import BytesIO
  8. from PIL import Image
  9. from typing import List, Tuple, Dict, Optional
  10. import logging
  11. logger = logging.getLogger(__name__)
  12. class OCRService:
  13. """Service for extracting text from product images using OCR."""
  14. def __init__(self):
  15. self.reader = None
  16. def _get_reader(self):
  17. """Lazy load EasyOCR reader."""
  18. if self.reader is None:
  19. self.reader = easyocr.Reader(['en'], gpu=False)
  20. return self.reader
  21. def download_image(self, image_url: str) -> Optional[np.ndarray]:
  22. """Download image from URL and convert to OpenCV format."""
  23. try:
  24. response = requests.get(image_url, timeout=10)
  25. response.raise_for_status()
  26. # Convert to PIL Image then to OpenCV format
  27. pil_image = Image.open(BytesIO(response.content))
  28. image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
  29. return image
  30. except Exception as e:
  31. logger.error(f"Error downloading image from {image_url}: {str(e)}")
  32. return None
  33. def preprocess_horizontal(self, image: np.ndarray) -> np.ndarray:
  34. """Preprocess image for horizontal text."""
  35. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  36. enhanced = cv2.GaussianBlur(gray, (5, 5), 0)
  37. _, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
  38. return binary
  39. def preprocess_vertical(self, image: np.ndarray) -> np.ndarray:
  40. """Preprocess image for vertical text."""
  41. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  42. enhanced = cv2.equalizeHist(gray)
  43. thresh = cv2.adaptiveThreshold(
  44. enhanced, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 15, 10
  45. )
  46. return thresh
  47. def detect_text_regions(self, image: np.ndarray, preprocess_func) -> List[Tuple]:
  48. """Detect text regions using contours."""
  49. processed = preprocess_func(image)
  50. contours, _ = cv2.findContours(processed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  51. text_regions = []
  52. for contour in contours:
  53. x, y, w, h = cv2.boundingRect(contour)
  54. if w > 30 and h > 30: # Filter small regions
  55. aspect_ratio = h / w
  56. text_regions.append((x, y, w, h, aspect_ratio))
  57. return text_regions
  58. def classify_and_extract_text(self, image: np.ndarray, regions: List[Tuple]) -> List[Tuple]:
  59. """Classify regions as horizontal or vertical and extract text."""
  60. reader = self._get_reader()
  61. all_detected_text = []
  62. for (x, y, w, h, aspect_ratio) in regions:
  63. roi = image[y:y + h, x:x + w]
  64. if aspect_ratio > 1.5: # Vertical text
  65. roi = cv2.rotate(roi, cv2.ROTATE_90_CLOCKWISE)
  66. results = reader.readtext(roi, detail=1)
  67. for _, text, confidence in results:
  68. all_detected_text.append((text, confidence))
  69. return all_detected_text
  70. def clean_ocr_output(self, ocr_results: List[Tuple], confidence_threshold: float = 0.40) -> List[Tuple]:
  71. """Clean OCR results by removing unwanted characters and low-confidence detections."""
  72. cleaned_results = []
  73. for text, confidence in ocr_results:
  74. if confidence < confidence_threshold:
  75. continue
  76. # Remove unwanted characters using regex
  77. cleaned_text = re.sub(r"[^A-Za-z0-9\s\.\,\(\)\-\%\/]", "", text)
  78. cleaned_text = re.sub(r"\s+", " ", cleaned_text).strip()
  79. # Remove unwanted numeric characters like single digits
  80. if len(cleaned_text) == 1 and cleaned_text.isdigit():
  81. continue
  82. if any(char.isdigit() for char in cleaned_text) and len(cleaned_text) < 2:
  83. continue
  84. if len(cleaned_text.strip()) > 0:
  85. cleaned_results.append((cleaned_text.strip(), confidence))
  86. return cleaned_results
  87. def process_image(self, image_url: str) -> Dict:
  88. """Main method to process image and extract text."""
  89. try:
  90. # Download image
  91. image = self.download_image(image_url)
  92. if image is None:
  93. return {
  94. "detected_text": [],
  95. "extracted_attributes": {},
  96. "error": "Failed to download image"
  97. }
  98. # Detect and process horizontal text
  99. horizontal_regions = self.detect_text_regions(image, self.preprocess_horizontal)
  100. horizontal_text = self.classify_and_extract_text(image, horizontal_regions)
  101. # Detect and process vertical text
  102. vertical_regions = self.detect_text_regions(image, self.preprocess_vertical)
  103. vertical_text = self.classify_and_extract_text(image, vertical_regions)
  104. # Combine results
  105. all_text = horizontal_text + vertical_text
  106. # Clean results
  107. cleaned_results = self.clean_ocr_output(all_text, confidence_threshold=0.40)
  108. # Format for response
  109. detected_text = [
  110. {"text": text, "confidence": float(confidence)}
  111. for text, confidence in cleaned_results
  112. ]
  113. return {
  114. "detected_text": detected_text,
  115. "extracted_attributes": {}
  116. }
  117. except Exception as e:
  118. logger.error(f"Error processing image: {str(e)}")
  119. return {
  120. "detected_text": [],
  121. "extracted_attributes": {},
  122. "error": str(e)
  123. }