Browse Source

initial commit attr extraction

Harshit Pathak 3 tháng trước
mục cha
commit
ce204b31f6

+ 0 - 0
attr_extraction/__init__.py


+ 3 - 0
attr_extraction/admin.py

@@ -0,0 +1,3 @@
+from django.contrib import admin
+
+# Register your models here.

+ 6 - 0
attr_extraction/apps.py

@@ -0,0 +1,6 @@
+from django.apps import AppConfig
+
+
class AttrExtractionConfig(AppConfig):
    """Django application config for the attr_extraction app."""
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'attr_extraction'

+ 0 - 0
attr_extraction/migrations/__init__.py


+ 54 - 0
attr_extraction/models.py

@@ -0,0 +1,54 @@
+# models.py
+from django.db import models
+from django.contrib.postgres.fields import JSONField
+
class Product(models.Model):
    """A catalogue product whose attributes are extracted asynchronously."""
    title = models.CharField(max_length=500)
    description = models.TextField()
    short_description = models.TextField(blank=True)
    # Set by the extraction task once ProductAttribute rows are persisted.
    attributes_extracted = models.BooleanField(default=False)
    created_at = models.DateTimeField(auto_now_add=True)
    # auto_now: bumped on every save(); tasks.py keys its cache on this value.
    updated_at = models.DateTimeField(auto_now=True)
    
    class Meta:
        db_table = 'products'
        # Supports "unprocessed products, oldest first" queries.
        indexes = [
            models.Index(fields=['attributes_extracted', 'created_at']),
        ]
+
class ProductImage(models.Model):
    """An ordered image attached to a product; used as OCR/LLM input."""
    product = models.ForeignKey(Product, related_name='images', on_delete=models.CASCADE)
    image = models.ImageField(upload_to='products/')
    # Display/processing position; lower values come first.
    order = models.PositiveIntegerField(default=0)
    
    class Meta:
        db_table = 'product_images'
        ordering = ['order']
+
class ProductAttribute(models.Model):
    """One extracted (name, value) pair for a product."""
    product = models.ForeignKey(Product, related_name='attributes', on_delete=models.CASCADE)
    attribute_name = models.CharField(max_length=100, db_index=True)
    # Stored as text; lists/dicts are JSON-encoded by the extraction task.
    attribute_value = models.TextField()
    # 0.0–1.0; set to 1.0 when a human reviewer confirms the value.
    confidence_score = models.FloatField(default=0.0)
    extraction_method = models.CharField(
        max_length=20,
        choices=[('nlp', 'NLP'), ('llm', 'LLM'), ('hybrid', 'Hybrid')],
        default='hybrid'
    )
    needs_review = models.BooleanField(default=False)
    reviewed = models.BooleanField(default=False)
    created_at = models.DateTimeField(auto_now_add=True)
    
    class Meta:
        db_table = 'product_attributes'
        # At most one row per (product, attribute name); the extraction
        # task relies on this via update_or_create.
        unique_together = ['product', 'attribute_name']
        indexes = [
            models.Index(fields=['attribute_name', 'confidence_score']),
            models.Index(fields=['needs_review', 'reviewed']),
        ]
    
    def save(self, *args, **kwargs):
        # Auto-flag low confidence for review
        # NOTE(review): this 0.7 cutoff differs from the extractor's
        # confidence_threshold of 0.6 — confirm which value is intended.
        if self.confidence_score < 0.7:
            self.needs_review = True
        super().save(*args, **kwargs)

+ 13 - 0
attr_extraction/serializers.py

@@ -0,0 +1,13 @@
+# serializers.py
+from rest_framework import serializers
+from .models import ProductAttribute
+
class ProductAttributeSerializer(serializers.ModelSerializer):
    """Serializer for ProductAttribute rows.

    Exposes the parent product's title as a read-only convenience field;
    all model fields remain writable per ModelSerializer defaults.
    """
    product_title = serializers.CharField(source='product.title', read_only=True)
    
    class Meta:
        model = ProductAttribute
        fields = ['id', 'product', 'product_title', 'attribute_name', 
                  'attribute_value', 'confidence_score', 'extraction_method',
                  'needs_review', 'reviewed', 'created_at']

+ 322 - 0
attr_extraction/services/attribute_extractor.py

@@ -0,0 +1,322 @@
+# services/attribute_extractor.py
+import re
+import spacy
+from typing import Dict, List, Optional
+from anthropic import Anthropic
+import base64
+from PIL import Image
+import pytesseract
+from collections import defaultdict
+
class HybridAttributeExtractor:
    """
    Hybrid product-attribute extractor.

    Pipeline (see extract_attributes):
      1. Cheap NLP pass (regex patterns + spaCy NER) over title/description.
      2. OCR (pytesseract) over up to three product images, fed back
         through the same NLP pass.
      3. An Anthropic LLM call that validates and enriches the NLP output.

    Results are merged into a flat dict of
    ``{attr: value, f"{attr}_confidence": float}`` pairs, with
    ``*_conflict`` / ``*_nlp_value`` keys when the two passes disagree.
    """

    def __init__(self, anthropic_api_key: str, product_type_mappings: Dict = None):
        """
        :param anthropic_api_key: API key for the Anthropic client.
        :param product_type_mappings: optional product-type -> attribute
            mapping; falls back to the built-in defaults when omitted.
        """
        self.nlp = spacy.load("en_core_web_sm")
        self.client = Anthropic(api_key=anthropic_api_key)
        self.product_type_mappings = product_type_mappings or self._load_default_mappings()

        # Regex patterns for common, well-structured attributes.
        # Matched case-insensitively (see _extract_with_nlp).
        self.patterns = {
            'size': [
                r'\b(XXS|XS|S|M|L|XL|XXL|XXXL)\b',
                r'\b(\d+(?:\.\d+)?)\s*(inch|inches|cm|mm|meter|metres?|ft|feet|")\b',
                r'\b(small|medium|large|extra large)\b'
            ],
            'color': [
                r'\b(black|white|red|blue|green|yellow|orange|purple|pink|brown|gray|grey|silver|gold|beige|navy|maroon|olive|teal|turquoise|lavender|cream|ivory)\b'
            ],
            'weight': [
                r'\b(\d+(?:\.\d+)?)\s*(kg|g|lb|lbs|oz|pounds?|grams?|kilograms?)\b'
            ],
            'material': [
                r'\b(cotton|polyester|silk|wool|leather|denim|linen|nylon|spandex|rayon|acrylic|metal|plastic|wood|glass|ceramic|steel|aluminum|rubber)\b'
            ],
            'brand': [
                r'(?:by|from|brand:?)\s+([A-Z][a-zA-Z0-9\s&]+?)(?:\s|$|,|\.|;)'
            ]
        }

        # Minimum per-attribute confidence before an NLP result is trusted.
        self.confidence_threshold = 0.6

    def _load_default_mappings(self) -> Dict:
        """
        Return the fallback product-type mappings.

        BUGFIX: __init__ called this method but it was never defined, so
        constructing the extractor without explicit mappings raised
        AttributeError. An empty dict preserves the "no custom mappings"
        intent.
        """
        return {}

    def extract_attributes(self, product_data: Dict) -> Dict:
        """
        Run the full extraction pipeline for one product.

        :param product_data: dict with optional keys 'title',
            'description', 'short_description', 'images' (list of paths).
        :returns: merged attribute dict (see class docstring).
        """
        # Phase 1: Quick NLP extraction
        nlp_attributes = self._extract_with_nlp(
            product_data.get('title', ''),
            product_data.get('description', '')
        )

        # Phase 2: OCR from images if provided
        ocr_text = ""
        if product_data.get('images'):
            ocr_text = self._extract_text_from_images(product_data['images'])
            if ocr_text:
                ocr_attributes = self._extract_with_nlp("", ocr_text)
                nlp_attributes = self._merge_attributes(nlp_attributes, ocr_attributes)

        # Phase 3: Always call LLM to enrich and validate NLP results
        llm_attributes = self._extract_with_llm(
            product_data,
            nlp_attributes,
            ocr_text
        )
        final_attributes = self._merge_attributes(nlp_attributes, llm_attributes)

        return final_attributes

    def _extract_with_nlp(self, title: str, description: str) -> Dict:
        """
        Fast extraction using regex patterns and spaCy NER.

        Returns ``{attr: value-or-list, f"{attr}_confidence": 0.8|0.5}``;
        0.8 when exactly one distinct value was found, 0.5 otherwise.
        """
        text = f"{title} {description}".lower()
        attributes = defaultdict(list)

        # Pattern matching for structured attributes.
        for attr_type, patterns in self.patterns.items():
            for pattern in patterns:
                matches = re.finditer(pattern, text, re.IGNORECASE)
                for match in matches:
                    value = match.group(1) if match.groups() else match.group(0)
                    attributes[attr_type].append(value.strip())

        # Named Entity Recognition for brands, organizations, etc.
        doc = self.nlp(title + " " + description)
        for ent in doc.ents:
            # Only fall back to NER for brand when regex found nothing.
            if ent.label_ == "ORG" and 'brand' not in attributes:
                attributes['brand'].append(ent.text)
            elif ent.label_ == "PRODUCT":
                attributes['product_type'].append(ent.text)
            elif ent.label_ == "MONEY":
                attributes['price'].append(ent.text)

        # Deduplicate and clean.
        cleaned_attributes = {}
        for key, values in attributes.items():
            if not values:
                continue
            # BUGFIX: order-preserving de-duplication — the original kept
            # duplicate values when more than one distinct value was found,
            # and recomputed set(values) three times.
            unique_values = list(dict.fromkeys(values))
            if len(unique_values) == 1:
                cleaned_attributes[key] = unique_values[0]
                cleaned_attributes[f'{key}_confidence'] = 0.8
            else:
                cleaned_attributes[key] = unique_values
                cleaned_attributes[f'{key}_confidence'] = 0.5

        return cleaned_attributes

    def _extract_text_from_images(self, image_paths: List[str]) -> str:
        """
        OCR up to three product images and join the recovered text.

        Per-image failures are logged and skipped (best-effort).
        """
        extracted_text = []

        for img_path in image_paths[:3]:  # Limit to 3 images
            try:
                img = Image.open(img_path)
                text = pytesseract.image_to_string(img)
                if text.strip():
                    extracted_text.append(text.strip())
            except Exception as e:
                print(f"OCR error for {img_path}: {e}")

        return " ".join(extracted_text)

    def _needs_llm_extraction(self, attributes: Dict, product_data: Dict) -> bool:
        """
        Heuristic: is an LLM pass worthwhile for this product?

        NOTE: currently unused — extract_attributes always calls the LLM.
        Kept for callers that want a cheaper, conditional pipeline.
        """
        # Check if critical attributes are missing.
        critical_attrs = ['category', 'brand', 'color', 'size']
        missing_critical = any(attr not in attributes for attr in critical_attrs)

        # Check confidence levels.
        low_confidence = any(
            attributes.get(f'{key}_confidence', 0) < self.confidence_threshold
            for key in attributes.keys() if not key.endswith('_confidence')
        )

        # Check if description is complex/unstructured.
        description = product_data.get('description', '')
        is_complex = len(description.split()) > 100 or 'features' in description.lower()

        return missing_critical or low_confidence or is_complex

    def _extract_with_llm(self, product_data: Dict, existing_attrs: Dict, ocr_text: str) -> Dict:
        """
        Ask the LLM for a comprehensive attribute JSON object, attaching up
        to three product images for multimodal context.

        Returns the parsed dict with 0.95 confidence keys added, or {} on
        any API/parsing failure (best-effort by design).
        """
        import json

        prompt = f"""Analyze this product and extract ALL possible attributes with high accuracy.

Title: {product_data.get('title', 'N/A')}
Description: {product_data.get('description', 'N/A')}
Short Description: {product_data.get('short_description', 'N/A')}
Text from images (OCR): {ocr_text if ocr_text else 'N/A'}

NLP Pre-extracted attributes (validate and enhance): {existing_attrs}

Extract a comprehensive JSON object with these fields (include all that apply):

**Basic Info:**
- category: specific product category/type
- subcategory: more specific classification
- brand: brand name
- model: model number/name
- product_line: product series/collection

**Physical Attributes:**
- color: all colors (list if multiple)
- size: size information (with units)
- dimensions: length/width/height with units
- weight: weight with units
- material: materials used (list all)
- finish: surface finish/texture

**Technical Specs (if applicable):**
- specifications: key technical specs as object
- compatibility: what it works with
- capacity: storage/volume capacity
- power: power requirements/battery info

**Commercial Info:**
- condition: new/used/refurbished
- warranty: warranty information
- country_of_origin: manufacturing country
- certifications: safety/quality certifications

**Descriptive:**
- key_features: list of 5-8 main features
- benefits: main benefits/use cases
- target_audience: who this is for
- usage_instructions: how to use (if mentioned)
- care_instructions: care/maintenance info
- style: style/aesthetic (modern, vintage, etc)
- season: seasonal relevance (if applicable)
- occasion: suitable occasions (if applicable)

**Additional:**
- package_contents: what's included
- variants: available variants/options
- tags: relevant search tags (list)

Only include fields where you have high confidence. Use null for uncertain values.
For lists, provide all relevant items. Be thorough and extract every possible detail."""

        content = [{"type": "text", "text": prompt}]

        # Add images if available.
        if product_data.get('images'):
            for img_path in product_data['images'][:3]:  # Include up to 3 images for better context
                try:
                    with open(img_path, 'rb') as f:
                        img_data = base64.b64encode(f.read()).decode()

                    # Determine media type from the file extension.
                    media_type = "image/jpeg"
                    if img_path.lower().endswith('.png'):
                        media_type = "image/png"
                    elif img_path.lower().endswith('.webp'):
                        media_type = "image/webp"

                    content.append({
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": img_data
                        }
                    })
                except Exception as e:
                    print(f"Error processing image {img_path}: {e}")

        try:
            response = self.client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=2048,  # Increased for comprehensive extraction
                messages=[{"role": "user", "content": content}]
            )

            # BUGFIX: the model may wrap the JSON in markdown fences or
            # prose; extract the outermost {...} before parsing instead of
            # feeding the raw completion to json.loads.
            raw = response.content[0].text
            start, end = raw.find('{'), raw.rfind('}')
            if start != -1 and end > start:
                raw = raw[start:end + 1]
            llm_result = json.loads(raw)

            # Add high confidence to LLM results.
            # BUGFIX: iterate over a snapshot of the keys — the original
            # inserted *_confidence keys while iterating llm_result, which
            # raises "dictionary changed size during iteration".
            for key in list(llm_result):
                if llm_result[key] is not None:
                    llm_result[f'{key}_confidence'] = 0.95

            return llm_result

        except Exception as e:
            print(f"LLM extraction error: {e}")
            return {}

    def _identify_missing_attributes(self, existing_attrs: Dict) -> List[str]:
        """
        List important attributes that are absent or below 0.7 confidence.

        NOTE: currently unused by the pipeline; kept as a public-ish helper.
        """
        important_attrs = ['category', 'brand', 'color', 'size', 'material', 'key_features']
        missing = []

        for attr in important_attrs:
            if attr not in existing_attrs or existing_attrs.get(f'{attr}_confidence', 0) < 0.7:
                missing.append(attr)

        return missing

    def _merge_attributes(self, base: Dict, additional: Dict) -> Dict:
        """
        Merge two attribute dicts, preferring `additional` (typically LLM).

        Rules: null in `additional` keeps the base value; a conflicting
        value wins but is flagged with ``*_conflict`` and the old value is
        preserved in ``*_nlp_value``; agreement boosts confidence.
        """
        merged = {}

        # Start with all base (NLP) attributes.
        for key, value in base.items():
            if not key.endswith('_confidence'):
                merged[key] = value
                merged[f'{key}_confidence'] = base.get(f'{key}_confidence', 0.7)

        # Add or override with `additional` attributes.
        for key, value in additional.items():
            if key.endswith('_confidence'):
                continue

            if value is None:
                # Keep the base value if the LLM returned null.
                continue

            if key not in merged:
                # New attribute discovered by the LLM.
                merged[key] = value
                merged[f'{key}_confidence'] = additional.get(f'{key}_confidence', 0.95)
            else:
                llm_conf = additional.get(f'{key}_confidence', 0.95)
                nlp_conf = merged.get(f'{key}_confidence', 0.7)

                if str(value).lower() != str(merged[key]).lower():
                    # Values differ — use the LLM value but flag the conflict.
                    merged[key] = value
                    merged[f'{key}_confidence'] = llm_conf
                    merged[f'{key}_nlp_value'] = base.get(key)  # Store NLP value for reference
                    merged[f'{key}_conflict'] = True
                else:
                    # Values match — boost confidence (capped at 0.99).
                    merged[key] = value
                    merged[f'{key}_confidence'] = min(0.99, (llm_conf + nlp_conf) / 2 + 0.1)

        return merged
+
+
# Example usage — run this module directly for a quick smoke test.
if __name__ == "__main__":
    demo_extractor = HybridAttributeExtractor(anthropic_api_key="your-api-key")

    sample_product = {
        'title': 'Nike Air Max 270 Running Shoes - Black/White',
        'description': 'Premium running shoes with Max Air cushioning. Breathable mesh upper, rubber outsole. Perfect for daily training.',
        'images': ['path/to/image1.jpg', 'path/to/image2.jpg'],
    }

    extracted = demo_extractor.extract_attributes(sample_product)
    print(extracted)

+ 78 - 0
attr_extraction/tasks.py

@@ -0,0 +1,78 @@
+# tasks.py
+from celery import shared_task
+from django.core.cache import cache
+from .models import Product, ProductAttribute
+from .services.attribute_extractor import HybridAttributeExtractor
+import json
+import hashlib
+
@shared_task(bind=True, max_retries=3)
def extract_product_attributes(self, product_id: int):
    """
    Celery task: run hybrid attribute extraction for one product and
    persist the results as ProductAttribute rows.

    Returns the merged attribute dict (also cached for 24h keyed on the
    product's updated_at), or {'error': ...} for an unknown product.
    Transient failures are retried up to 3 times with exponential backoff.
    """
    # BUGFIX: `settings` was referenced below but never imported, so every
    # run raised NameError before reaching the extractor.
    from django.conf import settings

    try:
        product = Product.objects.get(id=product_id)

        # Cache key tied to updated_at, so product edits invalidate it.
        cache_key = f"product_attrs_{product.id}_{product.updated_at.timestamp()}"
        cached_attrs = cache.get(cache_key)

        if cached_attrs:
            return cached_attrs

        # Prepare product data for the extractor.
        product_data = {
            'title': product.title,
            'description': product.description,
            'short_description': product.short_description,
            'images': [img.image.path for img in product.images.all()]
        }

        # Extract attributes.
        extractor = HybridAttributeExtractor(
            anthropic_api_key=settings.ANTHROPIC_API_KEY
        )
        attributes = extractor.extract_attributes(product_data)

        # Save to database (confidence keys are metadata, not attributes).
        for attr_name, attr_value in attributes.items():
            if not attr_name.endswith('_confidence'):
                confidence = attributes.get(f'{attr_name}_confidence', 0.5)

                ProductAttribute.objects.update_or_create(
                    product=product,
                    attribute_name=attr_name,
                    defaults={
                        'attribute_value': json.dumps(attr_value) if isinstance(attr_value, (list, dict)) else str(attr_value),
                        'confidence_score': confidence,
                        'extraction_method': 'hybrid'
                    }
                )

        # Cache for 24 hours.
        cache.set(cache_key, attributes, 86400)

        # BUGFIX: product.save() bumped the auto_now `updated_at`, which
        # changed the cache key and made the entry written above
        # unreachable on the next run. A queryset update() flags the
        # product without touching updated_at (note: bypasses save()
        # overrides and post_save signals, which Product does not use).
        Product.objects.filter(id=product.id).update(attributes_extracted=True)

        return attributes

    except Product.DoesNotExist:
        return {'error': 'Product not found'}
    except Exception as e:
        # Retry with exponential backoff: 60s, 120s, 240s.
        raise self.retry(exc=e, countdown=60 * (2 ** self.request.retries))
+
+
@shared_task
def batch_extract_attributes(product_ids: list):
    """
    Fan extraction out over many products at once.

    Returns a mapping of product id -> queued Celery task id.
    """
    return {
        pid: extract_product_attributes.delay(pid).id
        for pid in product_ids
    }

+ 3 - 0
attr_extraction/tests.py

@@ -0,0 +1,3 @@
+from django.test import TestCase
+
+# Create your tests here.

+ 15 - 0
attr_extraction/urls.py

@@ -0,0 +1,15 @@
+from django.urls import path
+from .views import (
+    ExtractAttributesView, 
+    BatchExtractAttributesView,
+    ProductAttributesView,
+    AttributeReviewView
+)
+
# Named routes so templates/clients can use reverse()/reverse_lazy()
# instead of hard-coding URLs (backward compatible — paths unchanged).
urlpatterns = [
    path('products/<int:product_id>/extract/', ExtractAttributesView.as_view(), name='extract-attributes'),
    path('products/batch-extract/', BatchExtractAttributesView.as_view(), name='batch-extract-attributes'),
    path('products/<int:product_id>/attributes/', ProductAttributesView.as_view(), name='product-attributes'),
    # The review view serves both the list (GET) and detail (PATCH) routes.
    path('attributes/review/', AttributeReviewView.as_view(), name='attribute-review-list'),
    path('attributes/<int:attribute_id>/review/', AttributeReviewView.as_view(), name='attribute-review-detail'),
]

+ 99 - 0
attr_extraction/views.py

@@ -0,0 +1,99 @@
+from django.shortcuts import render
+
+# Create your views here.
+# views.py
+from rest_framework.views import APIView
+from rest_framework.response import Response
+from rest_framework import status
+from .tasks import extract_product_attributes, batch_extract_attributes
+from .models import Product, ProductAttribute
+from .serializers import ProductAttributeSerializer
+
class ExtractAttributesView(APIView):
    """
    POST /products/<product_id>/extract/ — queue async attribute extraction.
    """
    def post(self, request, product_id):
        """Return 202 with the Celery task id, or 404 for an unknown product."""
        # .exists() avoids loading the whole row: the fetched instance was
        # previously assigned to an unused variable.
        if not Product.objects.filter(id=product_id).exists():
            return Response({'error': 'Product not found'}, status=status.HTTP_404_NOT_FOUND)

        # Trigger async task.
        task = extract_product_attributes.delay(product_id)

        return Response({
            'message': 'Extraction started',
            'task_id': task.id,
            'product_id': product_id
        }, status=status.HTTP_202_ACCEPTED)
+
+
class BatchExtractAttributesView(APIView):
    """
    POST /products/batch-extract/ — queue extraction for many products.
    """
    def post(self, request):
        """Return 202 with the batch task id, or 400 when no ids are given."""
        ids = request.data.get('product_ids', [])
        if not ids:
            return Response({'error': 'No product IDs provided'}, status=status.HTTP_400_BAD_REQUEST)

        batch_task = batch_extract_attributes.delay(ids)
        payload = {
            'message': f'Batch extraction started for {len(ids)} products',
            'task_id': batch_task.id,
        }
        return Response(payload, status=status.HTTP_202_ACCEPTED)
+
+
class ProductAttributesView(APIView):
    """
    GET /products/<product_id>/attributes/ — extracted attributes of one product.
    """
    def get(self, request, product_id):
        """Return the product's extraction status and serialized attributes."""
        try:
            product = Product.objects.get(id=product_id)
        except Product.DoesNotExist:
            return Response({'error': 'Product not found'}, status=status.HTTP_404_NOT_FOUND)

        attribute_qs = ProductAttribute.objects.filter(product=product)
        payload = {
            'product_id': product_id,
            'attributes_extracted': product.attributes_extracted,
            'attributes': ProductAttributeSerializer(attribute_qs, many=True).data,
        }
        return Response(payload)
+
+
class AttributeReviewView(APIView):
    """
    Review workflow for low-confidence attributes:
    GET lists up to 50 pending items; PATCH confirms/corrects one.
    """
    def get(self, request, attribute_id=None):
        """List attributes flagged for review and not yet reviewed.

        BUGFIX: accepts an optional attribute_id — the URLconf also routes
        attributes/<id>/review/ here, and a GET on that URL previously
        crashed with TypeError because this method took no such argument.
        """
        # Get attributes needing review (capped at 50 per request).
        attributes = ProductAttribute.objects.filter(
            needs_review=True,
            reviewed=False
        ).select_related('product')[:50]

        serializer = ProductAttributeSerializer(attributes, many=True)
        return Response(serializer.data)

    def patch(self, request, attribute_id):
        """Apply a human correction/confirmation to one attribute."""
        try:
            attribute = ProductAttribute.objects.get(id=attribute_id)

            # Update attribute (value unchanged if not supplied).
            attribute.attribute_value = request.data.get('attribute_value', attribute.attribute_value)
            attribute.reviewed = True
            # BUGFIX: clear the flag — previously a reviewed attribute kept
            # needs_review=True forever, leaving stale review state.
            attribute.needs_review = False
            attribute.confidence_score = 1.0  # Human verified
            attribute.save()

            return Response({'message': 'Attribute updated'})

        except ProductAttribute.DoesNotExist:
            return Response({'error': 'Attribute not found'}, status=status.HTTP_404_NOT_FOUND)

+ 30 - 0
celery.py

@@ -0,0 +1,30 @@
# celery.py (in your project root)
import os
from celery import Celery

# Must be set before Celery loads Django settings.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'your_project.settings')

app = Celery('your_project')
# Read all CELERY_*-prefixed keys from Django settings.
app.config_from_object('django.conf:settings', namespace='CELERY')
# Discover tasks.py modules in every installed app.
app.autodiscover_tasks()


# settings.py additions
# NOTE(review): everything below is documentation of what belongs in
# settings.py — defined here in celery.py it has no effect on Django
# (config_from_object above reads django.conf:settings, not this module).
CELERY_BROKER_URL = 'redis://localhost:6379/0'
CELERY_RESULT_BACKEND = 'redis://localhost:6379/0'
CELERY_TASK_SERIALIZER = 'json'
CELERY_ACCEPT_CONTENT = ['json']
CELERY_RESULT_SERIALIZER = 'json'
CELERY_TIMEZONE = 'UTC'

CACHES = {
    'default': {
        'BACKEND': 'django_redis.cache.RedisCache',
        'LOCATION': 'redis://127.0.0.1:6379/1',
        'OPTIONS': {
            'CLIENT_CLASS': 'django_redis.client.DefaultClient',
        }
    }
}

ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY')

+ 25 - 1
content_quality_tool/settings.py

@@ -29,6 +29,7 @@ INSTALLED_APPS = [
     'django.contrib.staticfiles',
     'core',
     'rest_framework',
+    'attr_extraction',
 ]
 MIDDLEWARE = [
     'django.middleware.security.SecurityMiddleware',
@@ -111,4 +112,27 @@ MESSAGE_TAGS = {
     messages.INFO: 'info',
     messages.WARNING: 'warning',
     messages.DEBUG: 'debug',
-}
+}
+
+
+
+
# settings.py additions
# Celery: Redis as both broker and result backend; JSON-only payloads.
CELERY_BROKER_URL = 'redis://localhost:6379/0'
CELERY_RESULT_BACKEND = 'redis://localhost:6379/0'
CELERY_TASK_SERIALIZER = 'json'
CELERY_ACCEPT_CONTENT = ['json']
CELERY_RESULT_SERIALIZER = 'json'
CELERY_TIMEZONE = 'UTC'

# django-redis cache used by tasks.py to memoise extraction results
# (Redis db 1, kept separate from the Celery broker on db 0).
CACHES = {
    'default': {
        'BACKEND': 'django_redis.cache.RedisCache',
        'LOCATION': 'redis://127.0.0.1:6379/1',
        'OPTIONS': {
            'CLIENT_CLASS': 'django_redis.client.DefaultClient',
        }
    }
}

# NOTE(review): assumes `import os` exists at the top of settings.py —
# confirm, otherwise this line raises NameError at startup.
ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY')