# ==================== visual_processing_service.py ====================
# import torch
# import cv2
# import numpy as np
# import requests
# from io import BytesIO
# from PIL import Image
# from typing import Dict, List, Optional, Tuple
# import logging
# from transformers import CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModelForImageClassification
# from sklearn.cluster import KMeans
# import webcolors
# logger = logging.getLogger(__name__)
# class VisualProcessingService:
#     """Service for extracting visual attributes from product images using CLIP and computer vision."""
#     def __init__(self):
#         self.clip_model = None
#         self.clip_processor = None
#         self.classification_model = None
#         self.classification_processor = None
#     def _get_clip_model(self):
#         """Lazy load CLIP model."""
#         if self.clip_model is None:
#             logger.info("Loading CLIP model...")
#             self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#             self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#             self.clip_model.eval()
#         return self.clip_model, self.clip_processor
#     def _get_classification_model(self):
#         """Lazy load image classification model for product categories."""
#         if self.classification_model is None:
#             logger.info("Loading classification model...")
#             # Google's ViT base model (ImageNet fine-tuned; generic, not fashion-specific)
#             self.classification_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
#             self.classification_model = AutoModelForImageClassification.from_pretrained(
#                 "google/vit-base-patch16-224"
#             )
#             self.classification_model.eval()
#         return self.classification_model, self.classification_processor
#     def download_image(self, image_url: str) -> Optional[Image.Image]:
#         """Download image from URL."""
#         try:
#             response = requests.get(image_url, timeout=10)
#             response.raise_for_status()
#             image = Image.open(BytesIO(response.content)).convert('RGB')
#             return image
#         except Exception as e:
#             logger.error(f"Error downloading image from {image_url}: {str(e)}")
#             return None
#     def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
#         """Extract dominant colors from image using K-means clustering."""
#         try:
#             # Resize image for faster processing
#             img_small = image.resize((150, 150))
#             img_array = np.array(img_small)
#             # Reshape to pixels
#             pixels = img_array.reshape(-1, 3)
#             # Apply K-means
#             kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=10)
#             kmeans.fit(pixels)
#             colors = []
#             for center in kmeans.cluster_centers_:
#                 rgb = tuple(center.astype(int))
#                 color_name = self._get_color_name(rgb)
#                 colors.append({
#                     "name": color_name,
#                     "rgb": rgb,
#                     "percentage": float(np.sum(kmeans.labels_ == len(colors)) / len(kmeans.labels_) * 100)
#                 })
#             # Sort by percentage
#             colors.sort(key=lambda x: x['percentage'], reverse=True)
#             return colors
#         except Exception as e:
#             logger.error(f"Error extracting colors: {str(e)}")
#             return []
#     def _get_color_name(self, rgb: Tuple[int, int, int]) -> str:
#         """Convert RGB to closest color name."""
#         try:
#             # Try to get exact match
#             color_name = webcolors.rgb_to_name(rgb)
#             return color_name
#         except ValueError:
#             # Find closest color
#             min_distance = float('inf')
#             closest_name = 'unknown'
#             for name in webcolors.CSS3_NAMES_TO_HEX:
#                 hex_color = webcolors.CSS3_NAMES_TO_HEX[name]
#                 r, g, b = webcolors.hex_to_rgb(hex_color)
#                 distance = sum((c1 - c2) ** 2 for c1, c2 in zip(rgb, (r, g, b)))
#                 if distance < min_distance:
#                     min_distance = distance
#                     closest_name = name
#             return closest_name
#     def classify_with_clip(self, image: Image.Image, candidates: List[str], attribute_name: str) -> Dict:
#         """Use CLIP to classify image against candidate labels."""
#         try:
#             model, processor = self._get_clip_model()
#             # Prepare inputs
#             inputs = processor(
#                 text=candidates,
#                 images=image,
#                 return_tensors="pt",
#                 padding=True
#             )
#             # Get predictions
#             with torch.no_grad():
#                 outputs = model(**inputs)
#                 logits_per_image = outputs.logits_per_image
#                 probs = logits_per_image.softmax(dim=1)
#             # Get top predictions
#             top_probs, top_indices = torch.topk(probs[0], k=min(3, len(candidates)))
#             results = []
#             for prob, idx in zip(top_probs, top_indices):
#                 if prob.item() > 0.15:  # Confidence threshold
#                     results.append({
#                         "value": candidates[idx.item()],
#                         "confidence": float(prob.item())
#                     })
#             return {
#                 "attribute": attribute_name,
#                 "predictions": results
#             }
#         except Exception as e:
#             logger.error(f"Error in CLIP classification: {str(e)}")
#             return {"attribute": attribute_name, "predictions": []}
#     def detect_patterns(self, image: Image.Image) -> Dict:
#         """Detect patterns in the image using edge detection and texture analysis."""
#         try:
#             # Convert to OpenCV format
#             img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
#             gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
#             # Calculate edge density
#             edges = cv2.Canny(gray, 50, 150)
#             edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
#             # Calculate texture variance
#             laplacian = cv2.Laplacian(gray, cv2.CV_64F)
#             texture_variance = laplacian.var()
#             # Determine pattern type based on metrics
#             pattern_candidates = []
#             if edge_density > 0.15:
#                 pattern_candidates.append("geometric")
#             if texture_variance > 500:
#                 pattern_candidates.append("textured")
#             if edge_density < 0.05 and texture_variance < 200:
#                 pattern_candidates.append("solid")
#             # Use CLIP for the final pattern decision (the heuristic candidates above are currently unused)
#             pattern_types = [
#                 "solid color", "striped", "checkered", "polka dot", "floral",
#                 "geometric", "abstract", "graphic print", "camouflage", "paisley"
#             ]
#             clip_result = self.classify_with_clip(image, pattern_types, "pattern")
#             return clip_result
#         except Exception as e:
#             logger.error(f"Error detecting patterns: {str(e)}")
#             return {"attribute": "pattern", "predictions": []}
#     def detect_material(self, image: Image.Image) -> Dict:
#         """Detect material type using CLIP."""
#         materials = [
#             "cotton", "polyester", "denim", "leather", "silk", "wool",
#             "linen", "satin", "velvet", "fleece", "knit", "jersey",
#             "canvas", "nylon", "suede", "corduroy"
#         ]
#         return self.classify_with_clip(image, materials, "material")
#     def detect_style(self, image: Image.Image) -> Dict:
#         """Detect style/occasion using CLIP."""
#         styles = [
#             "casual", "formal", "sporty", "business", "vintage", "modern",
#             "bohemian", "streetwear", "elegant", "preppy", "athletic",
#             "loungewear", "party", "workwear", "outdoor"
#         ]
#         return self.classify_with_clip(image, styles, "style")
#     def detect_fit(self, image: Image.Image) -> Dict:
#         """Detect clothing fit using CLIP."""
#         fits = [
#             "slim fit", "regular fit", "loose fit", "oversized",
#             "tight", "relaxed", "tailored", "athletic fit"
#         ]
#         return self.classify_with_clip(image, fits, "fit")
#     def detect_neckline(self, image: Image.Image, product_type: str) -> Dict:
#         """Detect neckline type for tops using CLIP."""
#         if product_type.lower() not in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater']:
#             return {"attribute": "neckline", "predictions": []}
#         necklines = [
#             "crew neck", "v-neck", "round neck", "collar", "turtleneck",
#             "scoop neck", "boat neck", "off-shoulder", "square neck", "halter"
#         ]
#         return self.classify_with_clip(image, necklines, "neckline")
#     def detect_sleeve_type(self, image: Image.Image, product_type: str) -> Dict:
#         """Detect sleeve type using CLIP."""
#         if product_type.lower() not in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater', 'jacket']:
#             return {"attribute": "sleeve_type", "predictions": []}
#         sleeves = [
#             "short sleeve", "long sleeve", "sleeveless", "three-quarter sleeve",
#             "cap sleeve", "flutter sleeve", "bell sleeve", "raglan sleeve"
#         ]
#         return self.classify_with_clip(image, sleeves, "sleeve_type")
#     def detect_product_type(self, image: Image.Image) -> Dict:
#         """Detect product type using CLIP."""
#         product_types = [
#             "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
#             "skirt", "jacket", "coat", "sweater", "hoodie", "blazer",
#             "suit", "jumpsuit", "romper", "cardigan", "vest", "top",
#             "blouse", "tank top", "polo shirt", "sweatshirt"
#         ]
#         return self.classify_with_clip(image, product_types, "product_type")
#     def detect_closure_type(self, image: Image.Image) -> Dict:
#         """Detect closure type using CLIP."""
#         closures = [
#             "button", "zipper", "snap", "hook and eye", "velcro",
#             "lace-up", "pull-on", "elastic", "tie", "buckle"
#         ]
#         return self.classify_with_clip(image, closures, "closure_type")
#     def detect_length(self, image: Image.Image, product_type: str) -> Dict:
#         """Detect garment length using CLIP."""
#         if product_type.lower() in ['pants', 'jeans', 'trousers']:
#             lengths = ["full length", "ankle length", "cropped", "capri", "shorts"]
#         elif product_type.lower() in ['skirt', 'dress']:
#             lengths = ["mini", "knee length", "midi", "maxi", "floor length"]
#         elif product_type.lower() in ['jacket', 'coat']:
#             lengths = ["waist length", "hip length", "thigh length", "knee length", "full length"]
#         else:
#             lengths = ["short", "regular", "long"]
#         return self.classify_with_clip(image, lengths, "length")
#     def process_image(self, image_url: str, product_type_hint: Optional[str] = None) -> Dict:
#         """
#         Main method to process image and extract all visual attributes.
#         """
#         try:
#             # Download image
#             image = self.download_image(image_url)
#             if image is None:
#                 return {
#                     "visual_attributes": {},
#                     "error": "Failed to download image"
#                 }
#             # Extract all attributes
#             visual_attributes = {}
#             # 1. Product Type Detection
#             logger.info("Detecting product type...")
#             product_type_result = self.detect_product_type(image)
#             if product_type_result["predictions"]:
#                 visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
#                 detected_product_type = visual_attributes["product_type"]
#             else:
#                 detected_product_type = product_type_hint or "unknown"
#             # 2. Color Detection
#             logger.info("Extracting colors...")
#             colors = self.extract_dominant_colors(image, n_colors=3)
#             if colors:
#                 visual_attributes["primary_color"] = colors[0]["name"]
#                 visual_attributes["color_palette"] = [c["name"] for c in colors]
#                 visual_attributes["color_details"] = colors
#             # 3. Pattern Detection
#             logger.info("Detecting patterns...")
#             pattern_result = self.detect_patterns(image)
#             if pattern_result["predictions"]:
#                 visual_attributes["pattern"] = pattern_result["predictions"][0]["value"]
#             # 4. Material Detection
#             logger.info("Detecting material...")
#             material_result = self.detect_material(image)
#             if material_result["predictions"]:
#                 visual_attributes["material"] = material_result["predictions"][0]["value"]
#             # 5. Style Detection
#             logger.info("Detecting style...")
#             style_result = self.detect_style(image)
#             if style_result["predictions"]:
#                 visual_attributes["style"] = style_result["predictions"][0]["value"]
#             # 6. Fit Detection
#             logger.info("Detecting fit...")
#             fit_result = self.detect_fit(image)
#             if fit_result["predictions"]:
#                 visual_attributes["fit"] = fit_result["predictions"][0]["value"]
#             # 7. Neckline Detection (if applicable)
#             logger.info("Detecting neckline...")
#             neckline_result = self.detect_neckline(image, detected_product_type)
#             if neckline_result["predictions"]:
#                 visual_attributes["neckline"] = neckline_result["predictions"][0]["value"]
#             # 8. Sleeve Type Detection (if applicable)
#             logger.info("Detecting sleeve type...")
#             sleeve_result = self.detect_sleeve_type(image, detected_product_type)
#             if sleeve_result["predictions"]:
#                 visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"]
#             # 9. Closure Type Detection
#             logger.info("Detecting closure type...")
#             closure_result = self.detect_closure_type(image)
#             if closure_result["predictions"]:
#                 visual_attributes["closure_type"] = closure_result["predictions"][0]["value"]
#             # 10. Length Detection
#             logger.info("Detecting length...")
#             length_result = self.detect_length(image, detected_product_type)
#             if length_result["predictions"]:
#                 visual_attributes["length"] = length_result["predictions"][0]["value"]
#             # Format response
#             return {
#                 "visual_attributes": visual_attributes,
#                 "detailed_predictions": {
#                     "product_type": product_type_result,
#                     "pattern": pattern_result,
#                     "material": material_result,
#                     "style": style_result,
#                     "fit": fit_result,
#                     "neckline": neckline_result,
#                     "sleeve_type": sleeve_result,
#                     "closure_type": closure_result,
#                     "length": length_result
#                 }
#             }
#         except Exception as e:
#             logger.error(f"Error processing image: {str(e)}")
#             return {
#                 "visual_attributes": {},
#                 "error": str(e)
#             }
# ==================== visual_processing_service_optimized.py ====================
# """
# Optimized version with:
# - Result caching
# - Batch processing support
# - Memory management
# - Error recovery
# - Performance monitoring
# """
# import torch
# import cv2
# import numpy as np
# import requests
# from io import BytesIO
# from PIL import Image
# from typing import Dict, List, Optional, Tuple
# import logging
# import time
# import hashlib
# from functools import lru_cache
# from transformers import CLIPProcessor, CLIPModel
# from sklearn.cluster import KMeans
# import webcolors
# logger = logging.getLogger(__name__)
# class VisualProcessingService:
#     """Optimized service for extracting visual attributes from product images."""
#     # Class-level model caching (shared across instances)
#     _clip_model = None
#     _clip_processor = None
#     _model_device = None
#     def __init__(self, use_cache: bool = True, cache_ttl: int = 3600):
#         self.use_cache = use_cache
#         self.cache_ttl = cache_ttl
#         self._cache = {}
#     @classmethod
#     def _get_device(cls):
#         """Get optimal device (GPU if available, else CPU)."""
#         if cls._model_device is None:
#             cls._model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#             logger.info(f"Using device: {cls._model_device}")
#         return cls._model_device
#     @classmethod
#     def _get_clip_model(cls):
#         """Lazy load CLIP model with class-level caching."""
#         if cls._clip_model is None:
#             logger.info("Loading CLIP model...")
#             start_time = time.time()
#             try:
#                 cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#                 cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#                 device = cls._get_device()
#                 cls._clip_model.to(device)
#                 cls._clip_model.eval()
#                 logger.info(f"CLIP model loaded in {time.time() - start_time:.2f}s")
#             except Exception as e:
#                 logger.error(f"Failed to load CLIP model: {str(e)}")
#                 raise
#         return cls._clip_model, cls._clip_processor
#     def _get_cache_key(self, image_url: str, operation: str) -> str:
#         """Generate cache key for results."""
#         url_hash = hashlib.md5(image_url.encode()).hexdigest()
#         return f"visual_{operation}_{url_hash}"
#     def _get_cached(self, key: str) -> Optional[Dict]:
#         """Get cached result if available and not expired."""
#         if not self.use_cache:
#             return None
#         if key in self._cache:
#             result, timestamp = self._cache[key]
#             if time.time() - timestamp < self.cache_ttl:
#                 return result
#             else:
#                 del self._cache[key]
#         return None
#     def _set_cached(self, key: str, value: Dict):
#         """Cache result with timestamp."""
#         if self.use_cache:
#             self._cache[key] = (value, time.time())
#     def download_image(self, image_url: str, max_size: Tuple[int, int] = (1024, 1024)) -> Optional[Image.Image]:
#         """Download and optionally resize image for faster processing."""
#         try:
#             response = requests.get(image_url, timeout=10)
#             response.raise_for_status()
#             image = Image.open(BytesIO(response.content)).convert('RGB')
#             # Resize if image is too large
#             if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
#                 image.thumbnail(max_size, Image.Resampling.LANCZOS)
#                 logger.info(f"Resized image from original size to {image.size}")
#             return image
#         except Exception as e:
#             logger.error(f"Error downloading image from {image_url}: {str(e)}")
#             return None
#     @lru_cache(maxsize=100)
#     def _get_color_name_cached(self, rgb: Tuple[int, int, int]) -> str:
#         """Cached version of color name lookup."""
#         try:
#             return webcolors.rgb_to_name(rgb)
#         except ValueError:
#             min_distance = float('inf')
#             closest_name = 'unknown'
#             for name in webcolors.CSS3_NAMES_TO_HEX:
#                 hex_color = webcolors.CSS3_NAMES_TO_HEX[name]
#                 r, g, b = webcolors.hex_to_rgb(hex_color)
#                 distance = sum((c1 - c2) ** 2 for c1, c2 in zip(rgb, (r, g, b)))
#                 if distance < min_distance:
#                     min_distance = distance
#                     closest_name = name
#             return closest_name
#     def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
#         """Extract dominant colors with optimized K-means."""
#         try:
#             # Resize for faster processing
#             img_small = image.resize((150, 150))
#             img_array = np.array(img_small)
#             pixels = img_array.reshape(-1, 3)
#             # Sample pixels if too many
#             if len(pixels) > 10000:
#                 indices = np.random.choice(len(pixels), 10000, replace=False)
#                 pixels = pixels[indices]
#             # K-means with optimized parameters
#             kmeans = KMeans(
#                 n_clusters=n_colors,
#                 random_state=42,
#                 n_init=5,  # Reduced from 10 for speed
#                 max_iter=100,
#                 algorithm='elkan'  # Faster for low dimensions
#             )
#             kmeans.fit(pixels)
#             colors = []
#             labels_counts = np.bincount(kmeans.labels_)
#             for i, center in enumerate(kmeans.cluster_centers_):
#                 rgb = tuple(center.astype(int))
#                 color_name = self._get_color_name_cached(rgb)
#                 percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
#                 colors.append({
#                     "name": color_name,
#                     "rgb": rgb,
#                     "percentage": percentage
#                 })
#             colors.sort(key=lambda x: x['percentage'], reverse=True)
#             return colors
#         except Exception as e:
#             logger.error(f"Error extracting colors: {str(e)}")
#             return []
#     def classify_with_clip(
#         self,
#         image: Image.Image,
#         candidates: List[str],
#         attribute_name: str,
#         confidence_threshold: float = 0.15
#     ) -> Dict:
#         """Optimized CLIP classification with batching."""
#         try:
#             model, processor = self._get_clip_model()
#             device = self._get_device()
#             # Prepare inputs
#             inputs = processor(
#                 text=candidates,
#                 images=image,
#                 return_tensors="pt",
#                 padding=True
#             )
#             # Move to device
#             inputs = {k: v.to(device) for k, v in inputs.items()}
#             # Get predictions with no_grad for speed
#             with torch.no_grad():
#                 outputs = model(**inputs)
#                 logits_per_image = outputs.logits_per_image
#                 probs = logits_per_image.softmax(dim=1).cpu()
#             # Get top predictions
#             top_k = min(3, len(candidates))
#             top_probs, top_indices = torch.topk(probs[0], k=top_k)
#             results = []
#             for prob, idx in zip(top_probs, top_indices):
#                 if prob.item() > confidence_threshold:
#                     results.append({
#                         "value": candidates[idx.item()],
#                         "confidence": float(prob.item())
#                     })
#             return {
#                 "attribute": attribute_name,
#                 "predictions": results
#             }
#         except Exception as e:
#             logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
#             return {"attribute": attribute_name, "predictions": []}
#     def batch_classify(
#         self,
#         image: Image.Image,
#         attribute_configs: List[Dict[str, any]]
#     ) -> Dict[str, Dict]:
#         """
#         Batch multiple CLIP classifications for efficiency.
#         attribute_configs: [{"name": "pattern", "candidates": [...], "threshold": 0.15}, ...]
#         """
#         results = {}
#         for config in attribute_configs:
#             attr_name = config["name"]
#             candidates = config["candidates"]
#             threshold = config.get("threshold", 0.15)
#             result = self.classify_with_clip(image, candidates, attr_name, threshold)
#             results[attr_name] = result
#         return results
#     def detect_patterns(self, image: Image.Image) -> Dict:
#         """Detect patterns using CLIP."""
#         pattern_types = [
#             "solid color", "striped", "checkered", "polka dot", "floral",
#             "geometric", "abstract", "graphic print", "camouflage", "paisley"
#         ]
#         return self.classify_with_clip(image, pattern_types, "pattern")
#     def process_image(
#         self,
#         image_url: str,
#         product_type_hint: Optional[str] = None,
#         attributes_to_extract: Optional[List[str]] = None
#     ) -> Dict:
#         """
#         Main method with caching and selective attribute extraction.
#         Args:
#             image_url: URL of the product image
#             product_type_hint: Optional hint about product type
#             attributes_to_extract: List of attributes to extract (None = all)
#         """
#         # Check cache
#         cache_key = self._get_cache_key(image_url, "full")
#         cached_result = self._get_cached(cache_key)
#         if cached_result:
#             logger.info(f"Returning cached result for {image_url}")
#             return cached_result
#         start_time = time.time()
#         try:
#             # Download image
#             image = self.download_image(image_url)
#             if image is None:
#                 return {
#                     "visual_attributes": {},
#                     "error": "Failed to download image"
#                 }
#             visual_attributes = {}
#             detailed_predictions = {}
#             # Default: extract all attributes
#             if attributes_to_extract is None:
#                 attributes_to_extract = [
#                     "product_type", "color", "pattern", "material",
#                     "style", "fit", "neckline", "sleeve_type",
#                     "closure_type", "length"
#                 ]
#             # 1. Product Type Detection
#             if "product_type" in attributes_to_extract:
#                 logger.info("Detecting product type...")
#                 product_types = [
#                     "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
#                     "skirt", "jacket", "coat", "sweater", "hoodie", "blazer",
#                     "top", "blouse"
#                 ]
#                 product_type_result = self.classify_with_clip(image, product_types, "product_type")
#                 if product_type_result["predictions"]:
#                     visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
#                     detected_product_type = visual_attributes["product_type"]
#                 else:
#                     detected_product_type = product_type_hint or "unknown"
#                 detailed_predictions["product_type"] = product_type_result
#             else:
#                 detected_product_type = product_type_hint or "unknown"
#             # 2. Color Detection
#             if "color" in attributes_to_extract:
#                 logger.info("Extracting colors...")
#                 colors = self.extract_dominant_colors(image, n_colors=3)
#                 if colors:
#                     visual_attributes["primary_color"] = colors[0]["name"]
#                     visual_attributes["color_palette"] = [c["name"] for c in colors]
#                     visual_attributes["color_details"] = colors
#             # 3. Batch classify remaining attributes
#             batch_configs = []
#             if "pattern" in attributes_to_extract:
#                 batch_configs.append({
#                     "name": "pattern",
#                     "candidates": [
#                         "solid color", "striped", "checkered", "polka dot",
#                         "floral", "geometric", "abstract", "graphic print"
#                     ]
#                 })
#             if "material" in attributes_to_extract:
#                 batch_configs.append({
#                     "name": "material",
#                     "candidates": [
#                         "cotton", "polyester", "denim", "leather", "silk",
#                         "wool", "linen", "satin", "fleece", "knit"
#                     ]
#                 })
#             if "style" in attributes_to_extract:
#                 batch_configs.append({
#                     "name": "style",
#                     "candidates": [
#                         "casual", "formal", "sporty", "business", "vintage",
#                         "modern", "streetwear", "elegant", "athletic"
#                     ]
#                 })
#             if "fit" in attributes_to_extract:
#                 batch_configs.append({
#                     "name": "fit",
#                     "candidates": [
#                         "slim fit", "regular fit", "loose fit", "oversized",
#                         "relaxed", "tailored"
#                     ]
#                 })
#             # Product-type specific attributes
#             if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater']:
#                 if "neckline" in attributes_to_extract:
#                     batch_configs.append({
#                         "name": "neckline",
#                         "candidates": [
#                             "crew neck", "v-neck", "round neck", "collar",
#                             "turtleneck", "scoop neck", "boat neck"
#                         ]
#                     })
#                 if "sleeve_type" in attributes_to_extract:
#                     batch_configs.append({
#                         "name": "sleeve_type",
#                         "candidates": [
#                             "short sleeve", "long sleeve", "sleeveless",
#                             "three-quarter sleeve", "cap sleeve"
#                         ]
#                     })
#             if "closure_type" in attributes_to_extract:
#                 batch_configs.append({
#                     "name": "closure_type",
#                     "candidates": [
#                         "button", "zipper", "snap", "pull-on",
#                         "lace-up", "elastic", "buckle"
#                     ]
#                 })
#             if "length" in attributes_to_extract:
#                 if detected_product_type.lower() in ['pants', 'jeans', 'trousers']:
#                     batch_configs.append({
#                         "name": "length",
#                         "candidates": ["full length", "ankle length", "cropped", "capri", "shorts"]
#                     })
#                 elif detected_product_type.lower() in ['skirt', 'dress']:
#                     batch_configs.append({
#                         "name": "length",
#                         "candidates": ["mini", "knee length", "midi", "maxi", "floor length"]
#                     })
#             # Execute batch classification
#             logger.info(f"Batch classifying {len(batch_configs)} attributes...")
#             batch_results = self.batch_classify(image, batch_configs)
#             # Process batch results
#             for attr_name, result in batch_results.items():
#                 detailed_predictions[attr_name] = result
#                 if result["predictions"]:
#                     visual_attributes[attr_name] = result["predictions"][0]["value"]
#             # Format response
#             result = {
#                 "visual_attributes": visual_attributes,
#                 "detailed_predictions": detailed_predictions,
#                 "processing_time": round(time.time() - start_time, 2)
#             }
#             # Cache result
#             self._set_cached(cache_key, result)
#             logger.info(f"Visual processing completed in {result['processing_time']}s")
#             return result
#         except Exception as e:
#             logger.error(f"Error processing image: {str(e)}")
#             return {
#                 "visual_attributes": {},
#                 "error": str(e),
#                 "processing_time": round(time.time() - start_time, 2)
#             }
#     def clear_cache(self):
#         """Clear all cached results."""
#         self._cache.clear()
#         logger.info("Cache cleared")
#     def get_cache_stats(self) -> Dict:
#         """Get cache statistics."""
#         return {
#             "cache_size": len(self._cache),
#             "cache_enabled": self.use_cache,
#             "cache_ttl": self.cache_ttl
#         }
#     @classmethod
#     def cleanup_models(cls):
#         """Free up memory by unloading models."""
#         if cls._clip_model is not None:
#             del cls._clip_model
#             del cls._clip_processor
#             cls._clip_model = None
#             cls._clip_processor = None
#             if torch.cuda.is_available():
#                 torch.cuda.empty_cache()
#             logger.info("Models unloaded and memory freed")
# # ==================== Usage Example ====================
# def example_usage():
#     """Example of how to use the optimized service."""
#     # Initialize service with caching
#     service = VisualProcessingService(use_cache=True, cache_ttl=3600)
#     # Process single image with all attributes
#     result1 = service.process_image("https://example.com/product1.jpg")
#     print("All attributes:", result1["visual_attributes"])
#     # Process with selective attributes (faster)
#     result2 = service.process_image(
#         "https://example.com/product2.jpg",
#         product_type_hint="t-shirt",
#         attributes_to_extract=["color", "pattern", "style"]
#     )
#     print("Selected attributes:", result2["visual_attributes"])
#     # Check cache stats
#     print("Cache stats:", service.get_cache_stats())
#     # Clear cache when needed
#     service.clear_cache()
#     # Cleanup models (call when shutting down)
#     VisualProcessingService.cleanup_models()
# if __name__ == "__main__":
#     example_usage()
# ==================== visual_processing_service.py (FIXED) ====================
import torch
import cv2
import numpy as np
import requests
import time
from io import BytesIO
from PIL import Image
from typing import Dict, List, Optional, Tuple
import logging
from transformers import CLIPProcessor, CLIPModel
from sklearn.cluster import KMeans

logger = logging.getLogger(__name__)

class VisualProcessingService:
    """Service for extracting visual attributes from product images using CLIP."""

    # Class-level caching (shared across instances)
    _clip_model = None
    _clip_processor = None
    _device = None

    def __init__(self):
        pass

    @classmethod
    def _get_device(cls):
        """Get optimal device."""
        if cls._device is None:
            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Visual Processing using device: {cls._device}")
        return cls._device

    @classmethod
    def _get_clip_model(cls):
        """Lazy load CLIP model with class-level caching."""
        if cls._clip_model is None:
            logger.info("Loading CLIP model (this may take a few minutes on first use)...")
            cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            device = cls._get_device()
            cls._clip_model.to(device)
            cls._clip_model.eval()
            logger.info("✓ CLIP model loaded successfully")
        return cls._clip_model, cls._clip_processor
    def download_image(self, image_url: str) -> Optional[Image.Image]:
        """Download image from URL."""
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
            return image
        except Exception as e:
            logger.error(f"Error downloading image from {image_url}: {str(e)}")
            return None

    def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
        """Extract dominant colors using K-means (FIXED webcolors issue)."""
        try:
            # Resize for faster processing
            img_small = image.resize((150, 150))
            img_array = np.array(img_small)
            pixels = img_array.reshape(-1, 3)
            # K-means clustering
            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
            kmeans.fit(pixels)
            colors = []
            labels_counts = np.bincount(kmeans.labels_)
            for i, center in enumerate(kmeans.cluster_centers_):
                rgb = tuple(center.astype(int))
                color_name = self._get_color_name_simple(rgb)
                percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
                colors.append({
                    "name": color_name,
                    "rgb": rgb,
                    "percentage": percentage
                })
            colors.sort(key=lambda x: x['percentage'], reverse=True)
            return colors
        except Exception as e:
            logger.error(f"Error extracting colors: {str(e)}")
            return []

    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
        """
        Simple color name detection without webcolors dependency.
        Maps RGB to basic color names.
        """
        r, g, b = rgb
        # Define basic color ranges
        colors = {
            'black': (r < 50 and g < 50 and b < 50),
            'white': (r > 200 and g > 200 and b > 200),
            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
            'red': (r > 150 and g < 100 and b < 100),
            'green': (g > 150 and r < 100 and b < 100),
            'blue': (b > 150 and r < 100 and g < 100),
            'yellow': (r > 200 and g > 200 and b < 100),
            'orange': (r > 200 and 100 < g < 200 and b < 100),
            'purple': (r > 100 and b > 100 and g < 100),
            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
            'cyan': (r < 100 and g > 150 and b > 150),
        }
        for color_name, condition in colors.items():
            if condition:
                return color_name
        # Default fallback
        if r > g and r > b:
            return 'red'
        elif g > r and g > b:
            return 'green'
        elif b > r and b > g:
            return 'blue'
        else:
            return 'gray'
    def classify_with_clip(
        self,
        image: Image.Image,
        candidates: List[str],
        attribute_name: str
    ) -> Dict:
        """Use CLIP to classify image against candidate labels."""
        try:
            model, processor = self._get_clip_model()
            device = self._get_device()
            # Prepare inputs
            inputs = processor(
                text=candidates,
                images=image,
                return_tensors="pt",
                padding=True
            )
            # Move to device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # Get predictions
            with torch.no_grad():
                outputs = model(**inputs)
                logits_per_image = outputs.logits_per_image
                probs = logits_per_image.softmax(dim=1).cpu()
            # Get top predictions
            top_k = min(3, len(candidates))
            top_probs, top_indices = torch.topk(probs[0], k=top_k)
            results = []
            for prob, idx in zip(top_probs, top_indices):
                if prob.item() > 0.15:  # Confidence threshold
                    results.append({
                        "value": candidates[idx.item()],
                        "confidence": float(prob.item())
                    })
            return {
                "attribute": attribute_name,
                "predictions": results
            }
        except Exception as e:
            logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
            return {"attribute": attribute_name, "predictions": []}
    def process_image(
        self,
        image_url: str,
        product_type_hint: Optional[str] = None
    ) -> Dict:
        """
        Main method to process image and extract visual attributes.
        """
        start_time = time.time()
        try:
            # Download image
            image = self.download_image(image_url)
            if image is None:
                return {
                    "visual_attributes": {},
                    "error": "Failed to download image"
                }
            visual_attributes = {}
            detailed_predictions = {}
            # 1. Product Type Detection
            product_types = [
                "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
                "skirt", "jacket", "coat", "sweater", "hoodie", "top"
            ]
            product_type_result = self.classify_with_clip(image, product_types, "product_type")
            if product_type_result["predictions"]:
                visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
                detected_product_type = visual_attributes["product_type"]
            else:
                detected_product_type = product_type_hint or "unknown"
            detailed_predictions["product_type"] = product_type_result
            # 2. Color Detection
            colors = self.extract_dominant_colors(image, n_colors=3)
            if colors:
                visual_attributes["primary_color"] = colors[0]["name"]
                visual_attributes["color_palette"] = [c["name"] for c in colors]
            # 3. Pattern Detection
            patterns = ["solid color", "striped", "checkered", "graphic print", "floral", "geometric"]
            pattern_result = self.classify_with_clip(image, patterns, "pattern")
            if pattern_result["predictions"]:
                visual_attributes["pattern"] = pattern_result["predictions"][0]["value"]
            detailed_predictions["pattern"] = pattern_result
            # 4. Material Detection
            materials = ["cotton", "polyester", "denim", "leather", "silk", "wool", "linen"]
            material_result = self.classify_with_clip(image, materials, "material")
            if material_result["predictions"]:
                visual_attributes["material"] = material_result["predictions"][0]["value"]
            detailed_predictions["material"] = material_result
            # 5. Style Detection
            styles = ["casual", "formal", "sporty", "streetwear", "elegant", "vintage"]
            style_result = self.classify_with_clip(image, styles, "style")
            if style_result["predictions"]:
                visual_attributes["style"] = style_result["predictions"][0]["value"]
            detailed_predictions["style"] = style_result
            # 6. Fit Detection
            fits = ["slim fit", "regular fit", "loose fit", "oversized"]
            fit_result = self.classify_with_clip(image, fits, "fit")
            if fit_result["predictions"]:
                visual_attributes["fit"] = fit_result["predictions"][0]["value"]
            detailed_predictions["fit"] = fit_result
            # 7. Neckline (for tops)
            if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'dress']:
                necklines = ["crew neck", "v-neck", "round neck", "collar"]
                neckline_result = self.classify_with_clip(image, necklines, "neckline")
                if neckline_result["predictions"]:
                    visual_attributes["neckline"] = neckline_result["predictions"][0]["value"]
                detailed_predictions["neckline"] = neckline_result
            # 8. Sleeve Type (for tops)
            if detected_product_type.lower() in ['shirt', 't-shirt', 'top']:
                sleeves = ["short sleeve", "long sleeve", "sleeveless"]
                sleeve_result = self.classify_with_clip(image, sleeves, "sleeve_type")
                if sleeve_result["predictions"]:
                    visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"]
                detailed_predictions["sleeve_type"] = sleeve_result
            # 9. Closure Type
            closures = ["button", "zipper", "pull-on"]
            closure_result = self.classify_with_clip(image, closures, "closure_type")
            if closure_result["predictions"]:
                visual_attributes["closure_type"] = closure_result["predictions"][0]["value"]
            detailed_predictions["closure_type"] = closure_result
            processing_time = time.time() - start_time
            return {
                "visual_attributes": visual_attributes,
                "detailed_predictions": detailed_predictions,
                "processing_time": round(processing_time, 2)
            }
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return {
                "visual_attributes": {},
                "error": str(e),
                "processing_time": round(time.time() - start_time, 2)
            }
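

# ==================== Usage Example (FIXED service) ====================
# Minimal usage sketch for the FIXED service above, mirroring the example in the
# commented-out optimized version. The image URL is a placeholder (assumption),
# not a real endpoint; CLIP weights are downloaded on first use.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    service = VisualProcessingService()
    result = service.process_image(
        "https://example.com/product.jpg",  # placeholder product image URL
        product_type_hint="t-shirt"
    )
    if "error" in result:
        print("Processing failed:", result["error"])
    else:
        print("Visual attributes:", result["visual_attributes"])
        print("Processing time:", result["processing_time"], "s")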