"""
Enhanced Multi-API-Key Load Balancer with Smart Rate Limiting

Optimized for Groq's free tier (30 RPM per key, 14K daily limit). Requests are
spread across several API keys: each key enforces its own requests-per-minute
budget and minimum spacing, and the balancer adds a global minimum interval
between any two dispatched requests. (The daily limit is not tracked here.)
"""
import email.utils
import logging
import math
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, Optional

import requests
from django.conf import settings

logger = logging.getLogger(__name__)

# Cooldown (seconds) applied to a rate-limited key when no usable Retry-After is given.
DEFAULT_RETRY_AFTER = 65.0


def _parse_retry_after(value: Optional[str], default: float = DEFAULT_RETRY_AFTER) -> float:
    """Convert a ``Retry-After`` header value into a delay in seconds.

    The header may carry either a number of seconds or an HTTP-date. Missing,
    malformed, non-finite or negative values fall back to ``default`` so that a
    bad header can never crash the retry loop or disable a key forever.
    """
    if value is None:
        return default
    try:
        seconds = float(value)
    except (TypeError, ValueError):
        try:
            when = email.utils.parsedate_to_datetime(value)
        except (TypeError, ValueError):
            return default
        if when.tzinfo is None:
            # HTTP-dates are GMT; a naive datetime would be interpreted as local time.
            when = when.replace(tzinfo=timezone.utc)
        seconds = when.timestamp() - time.time()
    if not math.isfinite(seconds) or seconds < 0:
        return default
    return seconds


@dataclass
class APIKeyState:
    """Tracks rate-limit state and health of a single API key.

    The balancer reserves a request slot with ``record_attempt`` when it hands
    the key out, so the key already counts as used while the HTTP call is in
    flight. ``mark_success`` / ``mark_failure`` then settle that attempt. If they
    are called without a prior reservation, they record the attempt themselves.
    """

    key: str
    name: str
    requests_made: int = 0  # successful requests only
    last_request_time: float = 0  # epoch seconds; 0 means "never used"
    is_available: bool = True
    rate_limit_reset_time: Optional[float] = None  # epoch seconds when the cooldown ends
    consecutive_failures: int = 0
    total_requests: int = 0  # settled attempts: successes + failures
    total_failures: int = 0
    request_times: List[float] = field(default_factory=list)  # attempt times, last 60s
    requests_per_minute: int = 25  # Conservative: 25 instead of 30
    min_request_interval: float = 2.5  # Minimum 2.5s between requests per key
    # Attempts reserved via record_attempt() not yet settled by mark_success/mark_failure.
    _in_flight: int = field(default=0, init=False, repr=False, compare=False)

    def _prune_window(self, now: float) -> None:
        """Drop attempt timestamps that fell out of the 60-second RPM window."""
        self.request_times = [t for t in self.request_times if now - t < 60]

    def _stamp(self, now: float) -> None:
        """Count a request sent at ``now`` toward spacing and the RPM window."""
        self.last_request_time = now
        self.request_times.append(now)
        self._prune_window(now)

    def _settle_attempt(self, now: float) -> None:
        """Close out one attempt, stamping it now if it was never reserved."""
        if self._in_flight > 0:
            self._in_flight -= 1
            # Also space the next request from completion, not only from dispatch.
            self.last_request_time = max(self.last_request_time, now)
        else:
            self._stamp(now)

    def record_attempt(self, now: Optional[float] = None) -> None:
        """Reserve a request slot on this key just before the request is sent."""
        self._stamp(time.time() if now is None else now)
        self._in_flight += 1

    def can_make_request(self) -> bool:
        """Check if key can make a request (rate limit + spacing)."""
        now = time.time()

        # Check minimum interval between requests
        if self.last_request_time and (now - self.last_request_time) < self.min_request_interval:
            return False

        self._prune_window(now)
        return len(self.request_times) < self.requests_per_minute

    def seconds_until_available(self) -> float:
        """Estimate the seconds until ``check_availability()`` would return True.

        Takes the cooldown, per-key spacing and RPM window into account.
        Returns 0.0 if the key is usable now, and ``math.inf`` if it can never
        become usable. Does not modify any state.
        """
        if self.requests_per_minute <= 0:
            return math.inf
        if self.rate_limit_reset_time is None and not self.is_available:
            return math.inf

        now = time.time()
        waits = [0.0]
        if self.rate_limit_reset_time is not None:
            waits.append(self.rate_limit_reset_time - now)
        if self.last_request_time:
            waits.append(self.min_request_interval - (now - self.last_request_time))
        recent = sorted(t for t in self.request_times if now - t < 60)
        excess = len(recent) - self.requests_per_minute
        if excess >= 0:
            # The window must shed (excess + 1) entries; the last of those is recent[excess].
            waits.append(recent[excess] + 60 - now)
        return max(waits)

    def mark_success(self) -> None:
        """Record a successful request and clear any failure/cooldown state."""
        self._settle_attempt(time.time())
        self.requests_made += 1
        self.total_requests += 1
        self.consecutive_failures = 0
        self.is_available = True
        self.rate_limit_reset_time = None

    def mark_failure(self, is_rate_limit: bool = False, retry_after: Optional[float] = None) -> None:
        """Record a failed request and put the key into cooldown when needed.

        Args:
            is_rate_limit: True for an HTTP 429; disables the key until
                ``retry_after`` seconds (or ``DEFAULT_RETRY_AFTER``) have passed.
            retry_after: Server-suggested delay in seconds; 0/None use the default.
        """
        now = time.time()
        self._settle_attempt(now)
        self.consecutive_failures += 1
        self.total_requests += 1  # failures are requests too; keeps success_rate >= 0
        self.total_failures += 1

        if is_rate_limit:
            self.is_available = False
            reset_time = now + (retry_after or DEFAULT_RETRY_AFTER)
            self.rate_limit_reset_time = reset_time
            logger.warning(
                f"🚫 {self.name} rate limited until "
                f"{datetime.fromtimestamp(reset_time).strftime('%H:%M:%S')}"
            )

        # Disable after 5 consecutive failures (increased from 3)
        if self.consecutive_failures >= 5:
            self.is_available = False
            # Never shorten a longer cooldown already requested by the server.
            self.rate_limit_reset_time = max(self.rate_limit_reset_time or 0.0, now + 120)
            logger.error(f"❌ {self.name} disabled (cooldown 2min)")

    def check_availability(self) -> bool:
        """Return True if a request may be sent on this key right now.

        Side effect: once a cooldown has expired, the key is re-enabled and its
        consecutive-failure counter is reset. The RPM/spacing limits still apply
        after recovery.
        """
        if self.rate_limit_reset_time is not None and time.time() >= self.rate_limit_reset_time:
            self.is_available = True
            self.rate_limit_reset_time = None
            self.consecutive_failures = 0
            logger.info(f"✅ {self.name} recovered")

        if not self.is_available:
            return False

        return self.can_make_request()

    def get_stats(self) -> Dict:
        """Return a JSON-serialisable snapshot of this key's counters and state."""
        available = self.check_availability()
        now = time.time()
        # check_availability() only prunes when the key is enabled; prune regardless.
        self._prune_window(now)
        success_count = self.total_requests - self.total_failures
        success_rate = (success_count / max(self.total_requests, 1)) * 100
        return {
            "name": self.name,
            "total_requests": self.total_requests,
            "total_failures": self.total_failures,
            "success_rate": round(success_rate, 2),
            "is_available": available,
            "consecutive_failures": self.consecutive_failures,
            "current_rpm": len(self.request_times),
            "max_rpm": self.requests_per_minute,
            "time_since_last_request": (
                round(now - self.last_request_time, 1) if self.last_request_time else None
            ),
        }


class MultiKeyLLMLoadBalancer:
    """Distributes LLM calls across several API keys with smart rate limiting.

    Key selection and all bookkeeping happen under ``self.lock``. Waiting for
    the global interval or for a short key cooldown also happens while the lock
    is held, which serializes dispatch across threads (and briefly blocks
    ``get_stats``).
    """

    def __init__(self, api_keys: List[Dict[str, str]], strategy: str = "round_robin"):
        """
        Args:
            api_keys: Dicts with ``'key'`` (the secret) and ``'name'`` (a label for logs).
            strategy: ``"least_loaded"``; any other value uses round-robin.

        Raises:
            ValueError: If ``api_keys`` is empty.
        """
        if not api_keys:
            raise ValueError("At least one API key required")

        self.keys = [APIKeyState(key=k['key'], name=k['name']) for k in api_keys]
        self.strategy = strategy
        self.current_index = 0
        self.lock = threading.Lock()
        self.total_requests = 0  # settled attempts: successes + failures
        self.total_failures = 0
        self.global_last_request = 0  # epoch seconds of the latest dispatch/success
        self.min_global_interval = 0.5  # 500ms between ANY requests

        logger.info(f"🔑 Load balancer initialized: {len(self.keys)} keys, '{strategy}' strategy")

    def get_next_key(self) -> Optional[APIKeyState]:
        """Pick the next usable key and reserve a request slot on it.

        Enforces ``min_global_interval`` between dispatches, then selects a key
        using the configured strategy. The slot is recorded before the lock is
        released, so concurrent callers cannot push a key past its spacing/RPM
        limits. Returns None if no key became usable within the short wait.
        """
        with self.lock:
            # Enforce minimum global interval
            elapsed = time.time() - self.global_last_request
            if elapsed < self.min_global_interval:
                time.sleep(self.min_global_interval - elapsed)

            if self.strategy == "least_loaded":
                key = self._least_loaded_select()
            else:
                key = self._round_robin_select()

            if key is not None:
                now = time.time()
                key.record_attempt(now)
                self.global_last_request = now
            return key

    def _round_robin_select(self) -> Optional[APIKeyState]:
        """Round-robin with availability check; falls back to a short wait."""
        total_keys = len(self.keys)
        for _ in range(total_keys):
            key = self.keys[self.current_index]
            self.current_index = (self.current_index + 1) % total_keys
            if key.check_availability():
                return key
        return self._wait_for_available_key()

    def _least_loaded_select(self) -> Optional[APIKeyState]:
        """Select the available key with the fewest requests in the current window."""
        available = [k for k in self.keys if k.check_availability()]
        if not available:
            return self._wait_for_available_key()
        # Ties go to the key that has been idle the longest.
        return min(available, key=lambda k: (len(k.request_times), k.last_request_time))

    def _wait_for_available_key(self, max_wait: float = 5.0) -> Optional[APIKeyState]:
        """Sleep until the soonest-ready key frees up, if that is under ``max_wait``.

        Every key is considered, covering cooldowns, per-key spacing and the RPM
        window. Must be called with ``self.lock`` held. Returns the key if it is
        usable after the wait, otherwise None.
        """
        wait, next_key = min(
            ((k.seconds_until_available(), k) for k in self.keys),
            key=lambda pair: pair[0],
        )
        if wait >= max_wait:
            return None
        if wait > 0:
            logger.info(f"⏳ Waiting {wait:.1f}s for {next_key.name}...")
            time.sleep(wait + 0.1)  # small margin so the boundary check passes
        return next_key if next_key.check_availability() else None

    def mark_success(self, key: APIKeyState) -> None:
        """Record a successful call made with ``key``."""
        with self.lock:
            key.mark_success()
            self.total_requests += 1
            self.global_last_request = time.time()

    def mark_failure(
        self, key: APIKeyState, is_rate_limit: bool = False, retry_after: Optional[float] = None
    ) -> None:
        """Record a failed call made with ``key`` (see ``APIKeyState.mark_failure``)."""
        with self.lock:
            key.mark_failure(is_rate_limit, retry_after)
            self.total_requests += 1  # failures are requests too; keeps success_rate >= 0
            self.total_failures += 1

    def get_stats(self) -> Dict:
        """Return aggregate and per-key statistics."""
        with self.lock:
            available_count = sum(1 for k in self.keys if k.check_availability())
            success_rate = (
                (self.total_requests - self.total_failures) / max(self.total_requests, 1)
            ) * 100
            return {
                "total_keys": len(self.keys),
                "available_keys": available_count,
                "strategy": self.strategy,
                "total_requests": self.total_requests,
                "total_failures": self.total_failures,
                "success_rate": round(success_rate, 2),
                "keys": [k.get_stats() for k in self.keys],
            }

    def call_llm(
        self,
        payload: dict,
        api_url: str,
        max_retries: Optional[int] = None,
        timeout: float = 30,
    ) -> str:
        """POST ``payload`` to ``api_url`` with retry and key failover.

        Args:
            payload: JSON body for the chat-completions style endpoint.
            api_url: Endpoint URL.
            max_retries: Total attempts (default: 3 per key). Attempts where no
                key was available also count.
            timeout: Per-request timeout in seconds passed to ``requests.post``.

        Returns:
            ``choices[0].message.content`` from the JSON response.

        Raises:
            RuntimeError: If every attempt failed. The last error is chained as ``__cause__``.
        """
        if max_retries is None:
            max_retries = len(self.keys) * 3

        last_error: Optional[Exception] = None

        for attempt in range(max_retries):
            key_state = self.get_next_key()
            if key_state is None:
                # get_next_key() already waited for short cooldowns, so every key is
                # blocked for longer. Back off instead of burning attempts instantly.
                logger.warning("⏳ All keys exhausted. Waiting 3s...")
                time.sleep(3)
                continue

            headers = {
                "Authorization": f"Bearer {key_state.key}",
                "Content-Type": "application/json",
            }
            logger.debug(f"🔑 {key_state.name} (attempt {attempt + 1}/{max_retries})")

            try:
                response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)

                if response.status_code == 429:
                    retry_after = _parse_retry_after(response.headers.get('Retry-After'))
                    self.mark_failure(key_state, is_rate_limit=True, retry_after=retry_after)
                    last_error = requests.exceptions.HTTPError(
                        f"429 Too Many Requests via {key_state.name}", response=response
                    )
                    time.sleep(1)  # Brief pause before next key
                    continue  # skips the else-clause below

                response.raise_for_status()
                # Parse before marking success so a malformed body counts only as a failure.
                content = response.json()["choices"][0]["message"]["content"]

            except requests.exceptions.HTTPError as e:
                # Response.__bool__ is False for 4xx/5xx, so compare against None explicitly.
                if e.response is not None and e.response.status_code == 429:
                    retry_after = _parse_retry_after(e.response.headers.get('Retry-After'))
                    self.mark_failure(key_state, is_rate_limit=True, retry_after=retry_after)
                else:
                    self.mark_failure(key_state)
                logger.error(f"❌ HTTP error {key_state.name}: {e}")
                last_error = e
                time.sleep(0.5)

            except Exception as e:
                # Network errors, timeouts, unexpected/malformed response bodies.
                self.mark_failure(key_state)
                logger.error(f"❌ Error {key_state.name}: {e}")
                last_error = e
                time.sleep(0.5)

            else:
                self.mark_success(key_state)
                logger.debug(f"✅ Success via {key_state.name}")
                return content

        stats = self.get_stats()
        error_msg = (
            f"LLM failed after {max_retries} attempts. "
            f"Available: {stats['available_keys']}/{stats['total_keys']}. "
            f"Error: {last_error}"
        )
        logger.error(f"💥 {error_msg}")
        raise RuntimeError(error_msg) from last_error


# Global instance
_load_balancer: Optional[MultiKeyLLMLoadBalancer] = None
_load_balancer_lock = threading.Lock()


def get_llm_load_balancer() -> MultiKeyLLMLoadBalancer:
    """Return the process-wide load balancer, creating it on first use.

    Keys come from ``settings.GROQ_API_KEYS`` (a list of ``{'key', 'name'}``
    dicts), falling back to a single ``settings.GROQ_API_KEY``. The strategy
    comes from ``settings.LLM_LOAD_BALANCER_STRATEGY`` (default
    ``'round_robin'``).

    Raises:
        ValueError: If no key is configured.
    """
    global _load_balancer
    if _load_balancer is None:
        with _load_balancer_lock:
            # Re-check under the lock: another thread may have created it meanwhile.
            if _load_balancer is None:
                api_keys = getattr(settings, 'GROQ_API_KEYS', None)
                if not api_keys:
                    single_key = getattr(settings, 'GROQ_API_KEY', None)
                    if single_key:
                        api_keys = [{'key': single_key, 'name': 'groq_key_1'}]
                if not api_keys:
                    raise ValueError("No GROQ API keys configured")
                strategy = getattr(settings, 'LLM_LOAD_BALANCER_STRATEGY', 'round_robin')
                _load_balancer = MultiKeyLLMLoadBalancer(api_keys, strategy=strategy)
    return _load_balancer


def reset_load_balancer() -> None:
    """Discard the singleton so the next call rebuilds it from settings."""
    global _load_balancer
    with _load_balancer_lock:
        _load_balancer = None


def call_llm_with_load_balancer(payload: dict) -> str:
    """Drop-in replacement for _call_llm: send ``payload`` to ``settings.GROQ_API_URL``."""
    balancer = get_llm_load_balancer()
    api_url = getattr(settings, 'GROQ_API_URL')
    return balancer.call_llm(payload, api_url)


def get_load_balancer_stats() -> Dict:
    """Return balancer stats, or an error dict if the balancer cannot be built."""
    try:
        return get_llm_load_balancer().get_stats()
    except Exception as e:
        return {"error": str(e), "total_keys": 0, "available_keys": 0}


# NOTE: An earlier, more conservative "UltraSafeLoadBalancer" implementation
# (5s per-key spacing, 1s global spacing, 90s rate-limit cooldown) used to sit
# here as commented-out code. It was removed as dead code (presumably still
# available in version-control history if it is ever needed again).