llm_load_balancer.py

  1. """
  2. Enhanced Multi-API-Key Load Balancer with Smart Rate Limiting
  3. Optimized for Groq's free tier (30 RPM per key, 14K daily limit)
  4. """
  5. import time
  6. import threading
  7. import requests
  8. import logging
  9. from typing import List, Dict, Optional
  10. from dataclasses import dataclass, field
  11. from datetime import datetime
  12. from django.conf import settings
  13. logger = logging.getLogger(__name__)

@dataclass
class APIKeyState:
    """Tracks state and health of a single API key"""
    key: str
    name: str
    requests_made: int = 0
    last_request_time: float = 0
    is_available: bool = True
    rate_limit_reset_time: Optional[float] = None
    consecutive_failures: int = 0
    total_requests: int = 0
    total_failures: int = 0
    request_times: list = field(default_factory=list)
    requests_per_minute: int = 25  # Conservative: 25 instead of 30
    min_request_interval: float = 2.5  # Minimum 2.5s between requests per key

    def can_make_request(self) -> bool:
        """Check if key can make a request (rate limit + spacing)"""
        now = time.time()
        # Check minimum interval between requests
        if self.last_request_time and (now - self.last_request_time) < self.min_request_interval:
            return False
        # Remove requests older than 1 minute
        self.request_times = [t for t in self.request_times if now - t < 60]
        return len(self.request_times) < self.requests_per_minute

    def mark_success(self):
        now = time.time()
        self.requests_made += 1
        self.total_requests += 1
        self.last_request_time = now
        self.request_times.append(now)
        self.consecutive_failures = 0
        self.is_available = True
        self.rate_limit_reset_time = None
        # Keep only the last 60 seconds of request timestamps
        self.request_times = [t for t in self.request_times if now - t < 60]

    def mark_failure(self, is_rate_limit: bool = False, retry_after: Optional[int] = None):
        self.consecutive_failures += 1
        self.total_failures += 1
        if is_rate_limit:
            self.is_available = False
            reset_time = time.time() + (retry_after or 65)  # 65s default
            self.rate_limit_reset_time = reset_time
            logger.warning(
                f"🚫 {self.name} rate limited until "
                f"{datetime.fromtimestamp(reset_time).strftime('%H:%M:%S')}"
            )
        # Disable after 5 consecutive failures (increased from 3)
        if self.consecutive_failures >= 5:
            self.is_available = False
            self.rate_limit_reset_time = time.time() + 120  # 2 min cooldown
            logger.error(f"❌ {self.name} disabled (cooldown 2min)")

    def check_availability(self) -> bool:
        """Check if key is available"""
        # Check rate limit reset
        if self.rate_limit_reset_time and time.time() >= self.rate_limit_reset_time:
            self.is_available = True
            self.rate_limit_reset_time = None
            self.consecutive_failures = 0
            logger.info(f"✅ {self.name} recovered")
            return True
        if not self.is_available:
            return False
        return self.can_make_request()

    def get_stats(self) -> Dict:
        success_count = self.total_requests - self.total_failures
        success_rate = (success_count / max(self.total_requests, 1)) * 100
        return {
            "name": self.name,
            "total_requests": self.total_requests,
            "total_failures": self.total_failures,
            "success_rate": round(success_rate, 2),
            "is_available": self.check_availability(),
            "consecutive_failures": self.consecutive_failures,
            "current_rpm": len(self.request_times),
            "max_rpm": self.requests_per_minute,
            "time_since_last_request": (
                round(time.time() - self.last_request_time, 1)
                if self.last_request_time else None
            ),
        }
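
# A minimal sketch of how the per-key throttle above behaves (uses only the
# class defined in this file; the key value is a placeholder):
#
#     state = APIKeyState(key="gsk_placeholder", name="demo")
#     state.can_make_request()    # True — no requests recorded yet
#     state.mark_success()        # stamps last_request_time and request_times
#     state.can_make_request()    # False until min_request_interval (2.5s) passes
#     state.mark_failure(is_rate_limit=True, retry_after=30)
#     state.check_availability()  # False for ~30s, then the key auto-recovers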

class MultiKeyLLMLoadBalancer:
    """Enhanced load balancer with smart rate limiting"""

    def __init__(self, api_keys: List[Dict[str, str]], strategy: str = "round_robin"):
        if not api_keys:
            raise ValueError("At least one API key required")
        self.keys = [APIKeyState(key=k['key'], name=k['name']) for k in api_keys]
        self.strategy = strategy
        self.current_index = 0
        self.lock = threading.Lock()
        self.total_requests = 0
        self.total_failures = 0
        self.global_last_request = 0
        self.min_global_interval = 0.5  # 500ms between ANY requests
        logger.info(f"🔑 Load balancer initialized: {len(self.keys)} keys, '{strategy}' strategy")

    def get_next_key(self) -> Optional[APIKeyState]:
        """Get next available key with global rate limiting"""
        with self.lock:
            # Enforce minimum global interval
            now = time.time()
            time_since_last = now - self.global_last_request
            if time_since_last < self.min_global_interval:
                wait_time = self.min_global_interval - time_since_last
                time.sleep(wait_time)
            if self.strategy == "least_loaded":
                return self._least_loaded_select()
            else:
                return self._round_robin_select()

    def _round_robin_select(self) -> Optional[APIKeyState]:
        """Round-robin with availability check"""
        attempts = 0
        total_keys = len(self.keys)
        while attempts < total_keys:
            key = self.keys[self.current_index]
            self.current_index = (self.current_index + 1) % total_keys
            if key.check_availability():
                return key
            attempts += 1
        return self._wait_for_available_key()

    def _least_loaded_select(self) -> Optional[APIKeyState]:
        """Select the least loaded key"""
        available = [k for k in self.keys if k.check_availability()]
        if not available:
            return self._wait_for_available_key()
        available.sort(key=lambda k: (len(k.request_times), k.last_request_time))
        return available[0]

    def _wait_for_available_key(self, max_wait: float = 5.0) -> Optional[APIKeyState]:
        """Wait for the next available key (with timeout)"""
        keys_with_reset = [k for k in self.keys if k.rate_limit_reset_time]
        if not keys_with_reset:
            # Check if any key just needs spacing
            now = time.time()
            for key in self.keys:
                if key.is_available:
                    wait = key.min_request_interval - (now - key.last_request_time)
                    if 0 < wait < max_wait:
                        logger.info(f"⏳ Waiting {wait:.1f}s for {key.name}...")
                        time.sleep(wait + 0.1)
                        return key if key.check_availability() else None
            return None
        keys_with_reset.sort(key=lambda k: k.rate_limit_reset_time)
        next_key = keys_with_reset[0]
        wait = max(0, next_key.rate_limit_reset_time - time.time())
        if 0 < wait < max_wait:
            logger.info(f"⏳ Waiting {wait:.1f}s for {next_key.name}...")
            time.sleep(wait + 0.5)
            return next_key if next_key.check_availability() else None
        return None

    def mark_success(self, key: APIKeyState):
        with self.lock:
            key.mark_success()
            self.total_requests += 1
            self.global_last_request = time.time()

    def mark_failure(self, key: APIKeyState, is_rate_limit: bool = False, retry_after: Optional[int] = None):
        with self.lock:
            key.mark_failure(is_rate_limit, retry_after)
            self.total_failures += 1

    def get_stats(self) -> Dict:
        with self.lock:
            available_count = sum(1 for k in self.keys if k.check_availability())
            success_rate = ((self.total_requests - self.total_failures) / max(self.total_requests, 1)) * 100
            return {
                "total_keys": len(self.keys),
                "available_keys": available_count,
                "strategy": self.strategy,
                "total_requests": self.total_requests,
                "total_failures": self.total_failures,
                "success_rate": round(success_rate, 2),
                "keys": [k.get_stats() for k in self.keys],
            }
    def call_llm(self, payload: dict, api_url: str, max_retries: Optional[int] = None) -> str:
        """Make an LLM call with smart retry and failover"""
        if max_retries is None:
            max_retries = len(self.keys) * 3
        attempt = 0
        last_error = None
        keys_tried = set()
        while attempt < max_retries:
            key_state = self.get_next_key()
            if not key_state:
                if len(keys_tried) >= len(self.keys):
                    # All keys tried, wait longer
                    logger.warning("⏳ All keys exhausted. Waiting 3s...")
                    time.sleep(3)
                    keys_tried.clear()
                attempt += 1
                continue
            keys_tried.add(key_state.name)
            try:
                headers = {
                    "Authorization": f"Bearer {key_state.key}",
                    "Content-Type": "application/json",
                }
                logger.debug(f"🔑 {key_state.name} (attempt {attempt + 1}/{max_retries})")
                response = requests.post(
                    api_url,
                    headers=headers,
                    json=payload,
                    timeout=30,
                )
                if response.status_code == 429:
                    retry_after = int(response.headers.get('Retry-After', 65))
                    self.mark_failure(key_state, is_rate_limit=True, retry_after=retry_after)
                    attempt += 1
                    time.sleep(1)  # Brief pause before trying the next key
                    continue
                response.raise_for_status()
                # Success
                self.mark_success(key_state)
                content = response.json()["choices"][0]["message"]["content"]
                logger.debug(f"✅ Success via {key_state.name}")
                return content
            except requests.exceptions.HTTPError as e:
                # Note: check `is not None` — a 4xx/5xx Response is falsy, so a
                # plain truthiness test would skip this branch on 429s.
                if e.response is not None and e.response.status_code == 429:
                    retry_after = int(e.response.headers.get('Retry-After', 65))
                    self.mark_failure(key_state, is_rate_limit=True, retry_after=retry_after)
                else:
                    self.mark_failure(key_state)
                logger.error(f"❌ HTTP error {key_state.name}: {e}")
                last_error = e
                attempt += 1
                time.sleep(0.5)
            except Exception as e:
                self.mark_failure(key_state)
                logger.error(f"❌ Error {key_state.name}: {e}")
                last_error = e
                attempt += 1
                time.sleep(0.5)
        stats = self.get_stats()
        error_msg = (
            f"LLM failed after {max_retries} attempts. "
            f"Available: {stats['available_keys']}/{stats['total_keys']}. "
            f"Error: {last_error}"
        )
        logger.error(f"💥 {error_msg}")
        raise RuntimeError(error_msg)

# Global instance
_load_balancer: Optional[MultiKeyLLMLoadBalancer] = None
_load_balancer_lock = threading.Lock()


def get_llm_load_balancer() -> MultiKeyLLMLoadBalancer:
    """Get the singleton load balancer"""
    global _load_balancer
    if _load_balancer is None:
        with _load_balancer_lock:
            if _load_balancer is None:
                api_keys = getattr(settings, 'GROQ_API_KEYS', None)
                if not api_keys:
                    single_key = getattr(settings, 'GROQ_API_KEY', None)
                    if single_key:
                        api_keys = [{'key': single_key, 'name': 'groq_key_1'}]
                if not api_keys:
                    raise ValueError("No GROQ API keys configured")
                strategy = getattr(settings, 'LLM_LOAD_BALANCER_STRATEGY', 'round_robin')
                _load_balancer = MultiKeyLLMLoadBalancer(api_keys, strategy=strategy)
    return _load_balancer
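
# Example Django settings for this module (a sketch only — the endpoint and
# key values below are assumptions; the loader above reads GROQ_API_KEYS,
# GROQ_API_KEY, GROQ_API_URL and LLM_LOAD_BALANCER_STRATEGY):
#
#     GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
#     GROQ_API_KEYS = [
#         {"key": "gsk_xxx1", "name": "groq_key_1"},
#         {"key": "gsk_xxx2", "name": "groq_key_2"},
#     ]
#     LLM_LOAD_BALANCER_STRATEGY = "round_robin"  # or "least_loaded"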

def reset_load_balancer():
    """Reset the load balancer"""
    global _load_balancer
    with _load_balancer_lock:
        _load_balancer = None


def call_llm_with_load_balancer(payload: dict) -> str:
    """Drop-in replacement for _call_llm"""
    balancer = get_llm_load_balancer()
    api_url = settings.GROQ_API_URL
    return balancer.call_llm(payload, api_url)
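
# Example call — a sketch only: the model name is hypothetical, and the payload
# simply follows the OpenAI-style chat-completions shape whose response
# call_llm() parses via ["choices"][0]["message"]["content"]:
#
#     payload = {
#         "model": "<groq-model-id>",
#         "messages": [{"role": "user", "content": "Summarize this ticket..."}],
#         "temperature": 0.2,
#     }
#     reply = call_llm_with_load_balancer(payload)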

def get_load_balancer_stats() -> Dict:
    """Get load balancer stats"""
    try:
        return get_llm_load_balancer().get_stats()
    except Exception as e:
        return {"error": str(e), "total_keys": 0, "available_keys": 0}
  278. # """
  279. # Ultra-Safe Sequential Load Balancer with Adaptive Rate Limiting
  280. # Guaranteed to work with strict API rate limits
  281. # """
  282. # import time
  283. # import threading
  284. # import requests
  285. # import logging
  286. # from typing import List, Dict, Optional
  287. # from dataclasses import dataclass
  288. # from datetime import datetime
  289. # from django.conf import settings
  290. # logger = logging.getLogger(__name__)
  291. # @dataclass
  292. # class APIKeyState:
  293. # """Simple key state tracker"""
  294. # key: str
  295. # name: str
  296. # last_used: float = 0
  297. # total_requests: int = 0
  298. # total_failures: int = 0
  299. # consecutive_failures: int = 0
  300. # disabled_until: float = 0
  301. # def is_available(self) -> bool:
  302. # """Check if key is available RIGHT NOW"""
  303. # now = time.time()
  304. # # Check if disabled
  305. # if self.disabled_until > now:
  306. # return False
  307. # # Require 5 seconds between requests on SAME key
  308. # if self.last_used > 0:
  309. # elapsed = now - self.last_used
  310. # if elapsed < 5.0:
  311. # return False
  312. # return True
  313. # def get_wait_time(self) -> float:
  314. # """How long until this key is available?"""
  315. # now = time.time()
  316. # if self.disabled_until > now:
  317. # return self.disabled_until - now
  318. # if self.last_used > 0:
  319. # elapsed = now - self.last_used
  320. # if elapsed < 5.0:
  321. # return 5.0 - elapsed
  322. # return 0
  323. # def mark_success(self):
  324. # self.last_used = time.time()
  325. # self.total_requests += 1
  326. # self.consecutive_failures = 0
  327. # self.disabled_until = 0
  328. # logger.info(f"✅ {self.name} success (total: {self.total_requests})")
  329. # def mark_failure(self, is_rate_limit: bool = False):
  330. # self.last_used = time.time()
  331. # self.total_requests += 1
  332. # self.total_failures += 1
  333. # self.consecutive_failures += 1
  334. # if is_rate_limit:
  335. # # Rate limit: wait 90 seconds
  336. # self.disabled_until = time.time() + 90
  337. # logger.error(f"🚫 {self.name} RATE LIMITED → disabled for 90s")
  338. # elif self.consecutive_failures >= 2:
  339. # # 2 failures: wait 60 seconds
  340. # self.disabled_until = time.time() + 60
  341. # logger.error(f"❌ {self.name} FAILED {self.consecutive_failures}x → disabled for 60s")
  342. # class UltraSafeLoadBalancer:
  343. # """
  344. # Ultra-conservative load balancer
  345. # - Minimum 5 seconds between requests on same key
  346. # - Minimum 1 second between ANY requests (global)
  347. # - Automatic waiting for key availability
  348. # - No parallel requests
  349. # """
  350. # def __init__(self, api_keys: List[Dict[str, str]]):
  351. # if not api_keys:
  352. # raise ValueError("At least one API key required")
  353. # self.keys = [APIKeyState(key=k['key'], name=k['name']) for k in api_keys]
  354. # self.current_index = 0
  355. # self.lock = threading.Lock()
  356. # self.last_global_request = 0
  357. # self.min_global_interval = 1.0 # 1 second between ANY requests
  358. # logger.info(f"🔑 Ultra-safe balancer: {len(self.keys)} keys, 5s per-key interval, 1s global interval")
  359. # def _enforce_global_rate_limit(self):
  360. # """Ensure minimum time between ANY requests"""
  361. # with self.lock:
  362. # if self.last_global_request > 0:
  363. # elapsed = time.time() - self.last_global_request
  364. # if elapsed < self.min_global_interval:
  365. # wait = self.min_global_interval - elapsed
  366. # logger.debug(f"⏱️ Global rate limit: waiting {wait:.1f}s")
  367. # time.sleep(wait)
  368. # self.last_global_request = time.time()
  369. # def get_next_key(self, max_wait: float = 30.0) -> Optional[APIKeyState]:
  370. # """Get next available key, waiting if necessary"""
  371. # start_time = time.time()
  372. # while (time.time() - start_time) < max_wait:
  373. # with self.lock:
  374. # # Try round-robin
  375. # for _ in range(len(self.keys)):
  376. # key = self.keys[self.current_index]
  377. # self.current_index = (self.current_index + 1) % len(self.keys)
  378. # if key.is_available():
  379. # return key
  380. # # No keys available - find the one that will be ready soonest
  381. # wait_times = [(k.get_wait_time(), k) for k in self.keys]
  382. # wait_times.sort()
  383. # if wait_times:
  384. # min_wait, next_key = wait_times[0]
  385. # if min_wait > 0 and min_wait < 15:
  386. # logger.info(f"⏳ All keys busy. Waiting {min_wait:.1f}s for {next_key.name}...")
  387. # time.sleep(min_wait + 0.2)
  388. # continue
  389. # time.sleep(0.5)
  390. # # Timeout
  391. # logger.error(f"❌ No keys available after {max_wait}s wait")
  392. # return None
  393. # def call_llm(self, payload: dict, api_url: str, retry_count: int = 0) -> str:
  394. # """
  395. # Make LLM call with ONE key
  396. # Retries with SAME key after waiting if it fails
  397. # """
  398. # # Enforce global rate limit FIRST
  399. # self._enforce_global_rate_limit()
  400. # # Get available key
  401. # key = self.get_next_key(max_wait=30.0)
  402. # if not key:
  403. # raise RuntimeError("No API keys available after 30s wait")
  404. # try:
  405. # headers = {
  406. # "Authorization": f"Bearer {key.key}",
  407. # "Content-Type": "application/json"
  408. # }
  409. # logger.info(f"🔑 Request via {key.name} (retry: {retry_count})")
  410. # response = requests.post(
  411. # api_url,
  412. # headers=headers,
  413. # json=payload,
  414. # timeout=45
  415. # )
  416. # # Check rate limit
  417. # if response.status_code == 429:
  418. # key.mark_failure(is_rate_limit=True)
  419. # if retry_count < 2:
  420. # logger.warning(f"⚠️ Rate limit hit, retrying with different key...")
  421. # time.sleep(2)
  422. # return self.call_llm(payload, api_url, retry_count + 1)
  423. # else:
  424. # raise RuntimeError(f"Rate limit on {key.name} after {retry_count} retries")
  425. # # Check for other errors
  426. # response.raise_for_status()
  427. # # Success!
  428. # key.mark_success()
  429. # content = response.json()["choices"][0]["message"]["content"]
  430. # return content
  431. # except requests.exceptions.HTTPError as e:
  432. # if e.response and e.response.status_code == 429:
  433. # key.mark_failure(is_rate_limit=True)
  434. # else:
  435. # key.mark_failure(is_rate_limit=False)
  436. # if retry_count < 2:
  437. # logger.warning(f"⚠️ HTTP error, retrying... ({e})")
  438. # time.sleep(2)
  439. # return self.call_llm(payload, api_url, retry_count + 1)
  440. # else:
  441. # raise RuntimeError(f"HTTP error after {retry_count} retries: {e}")
  442. # except requests.exceptions.Timeout as e:
  443. # key.mark_failure(is_rate_limit=False)
  444. # if retry_count < 1:
  445. # logger.warning(f"⏱️ Timeout, retrying...")
  446. # time.sleep(3)
  447. # return self.call_llm(payload, api_url, retry_count + 1)
  448. # else:
  449. # raise RuntimeError(f"Timeout after {retry_count} retries: {e}")
  450. # except Exception as e:
  451. # key.mark_failure(is_rate_limit=False)
  452. # raise RuntimeError(f"Unexpected error with {key.name}: {e}")
  453. # def get_stats(self) -> Dict:
  454. # with self.lock:
  455. # available = sum(1 for k in self.keys if k.is_available())
  456. # total_reqs = sum(k.total_requests for k in self.keys)
  457. # total_fails = sum(k.total_failures for k in self.keys)
  458. # success_rate = ((total_reqs - total_fails) / max(total_reqs, 1)) * 100
  459. # return {
  460. # "total_keys": len(self.keys),
  461. # "available_keys": available,
  462. # "total_requests": total_reqs,
  463. # "total_failures": total_fails,
  464. # "success_rate": round(success_rate, 2),
  465. # "keys": [
  466. # {
  467. # "name": k.name,
  468. # "total_requests": k.total_requests,
  469. # "total_failures": k.total_failures,
  470. # "consecutive_failures": k.consecutive_failures,
  471. # "is_available": k.is_available(),
  472. # "wait_time": round(k.get_wait_time(), 1)
  473. # }
  474. # for k in self.keys
  475. # ]
  476. # }
  477. # # Global singleton
  478. # _balancer: Optional[UltraSafeLoadBalancer] = None
  479. # _balancer_lock = threading.Lock()
  480. # def get_llm_load_balancer() -> UltraSafeLoadBalancer:
  481. # """Get singleton balancer"""
  482. # global _balancer
  483. # if _balancer is None:
  484. # with _balancer_lock:
  485. # if _balancer is None:
  486. # api_keys = getattr(settings, 'GROQ_API_KEYS', None)
  487. # if not api_keys:
  488. # single_key = getattr(settings, 'GROQ_API_KEY', None)
  489. # if single_key:
  490. # api_keys = [{'key': single_key, 'name': 'groq_key_1'}]
  491. # if not api_keys:
  492. # raise ValueError("No GROQ_API_KEYS configured in settings")
  493. # _balancer = UltraSafeLoadBalancer(api_keys)
  494. # return _balancer
  495. # def reset_load_balancer():
  496. # """Reset balancer (for testing)"""
  497. # global _balancer
  498. # with _balancer_lock:
  499. # _balancer = None
  500. # logger.info("🔄 Balancer reset")
  501. # def call_llm_with_load_balancer(payload: dict) -> str:
  502. # """
  503. # Call LLM with ultra-safe rate limiting
  504. # This is the drop-in replacement for services.py
  505. # """
  506. # balancer = get_llm_load_balancer()
  507. # api_url = getattr(settings, 'GROQ_API_URL')
  508. # return balancer.call_llm(payload, api_url)
  509. # def get_load_balancer_stats() -> Dict:
  510. # """Get balancer stats"""
  511. # try:
  512. # return get_llm_load_balancer().get_stats()
  513. # except Exception as e:
  514. # return {"error": str(e)}