feat: Add token-based rate limiting for RemoteInferenceEngine #2086
base: main
Conversation
- Add three new parameters to `RemoteParams`:
  - `requests_per_minute`: Limit API calls per minute
  - `input_tokens_per_minute`: Limit input tokens per minute
  - `output_tokens_per_minute`: Limit output tokens per minute
- Create `TokenRateLimiter` class with sliding window algorithm
- Integrate rate limiting into `RemoteInferenceEngine._query_api()`
- Add comprehensive unit tests for all new functionality

Fixes oumi-ai#1457
Pull request overview
This pull request adds token-based rate limiting support to the RemoteInferenceEngine, enabling automatic throttling of API requests based on requests per minute (RPM), input tokens per minute, and output tokens per minute (TPM) limits.
Key Changes:
- Introduces a new `TokenRateLimiter` class with a sliding window algorithm for tracking and enforcing rate limits
- Adds three new optional parameters to `RemoteParams`: `requests_per_minute`, `input_tokens_per_minute`, and `output_tokens_per_minute`
- Integrates rate limiting into the request flow by calling `wait_if_needed()` before API requests and recording token usage after successful responses
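As a rough illustration of how these limits might be configured (the three parameter names come from this PR; the import path is taken from the PR's file listing, and the other `RemoteParams` fields shown are assumptions, so the exact usage may differ):

```python
# Hedged sketch: enabling the new rate limits on RemoteParams.
from oumi.core.configs.params.remote_params import RemoteParams

remote_params = RemoteParams(
    api_url="https://api.openai.com/v1/chat/completions",  # assumed field
    requests_per_minute=60,           # at most 60 API calls per minute
    input_tokens_per_minute=90_000,   # at most 90k prompt tokens per minute
    output_tokens_per_minute=30_000,  # at most 30k completion tokens per minute
)
```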
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| src/oumi/inference/token_rate_limiter.py | New sliding window rate limiter supporting request count and token-based limits with async lock protection |
| src/oumi/core/configs/params/remote_params.py | Added three new rate limit parameters with validation ensuring values are ≥1 or None |
| src/oumi/inference/remote_inference_engine.py | Integrated rate limiter by initializing in constructor, calling wait_if_needed() before requests, and recording token usage from API responses |
| tests/unit/inference/test_token_rate_limiter.py | Comprehensive unit tests covering basic functionality, rate limit enforcement, sliding window cleanup, and concurrent request handling |
| tests/unit/core/configs/params/test_remote_params.py | Added validation tests for the three new rate limit parameters |
Comments suppressed due to low confidence (1)
src/oumi/inference/remote_inference_engine.py:650
- When a request fails (non-200 status, parse error, connection error, etc.), the pending request count is never decremented. The pending_requests counter is only decremented in record_usage(), which is only called on successful responses. This will cause the pending_requests count to grow unbounded over time with failed requests, eventually preventing new requests from being made as the limiter thinks it's at capacity.
Ensure that pending_requests is decremented in all failure paths, either by using a try-finally block around the request or by calling a cleanup method in error handlers.
```python
# Wait if token/request rate limits are being approached
if self._token_rate_limiter.is_enabled():
    await self._token_rate_limiter.wait_if_needed()

semaphore_or_controller = (
    self._adaptive_concurrency_controller
    if self._remote_params.use_adaptive_concurrency
    else semaphore
)
async with semaphore_or_controller:
    api_input = self._convert_conversation_to_api_input(
        conversation, generation_params, model_params
    )
    headers = self._get_request_headers(remote_params)
    failure_reason = None

    # Retry the request if it fails
    for attempt in range(remote_params.max_retries + 1):
        try:
            # Calculate exponential backoff delay
            if attempt > 0:
                delay = min(
                    remote_params.retry_backoff_base * (2 ** (attempt - 1)),
                    remote_params.retry_backoff_max,
                )
                await asyncio.sleep(delay)

            async with session.post(
                remote_params.api_url,
                json=api_input,
                headers=headers,
                timeout=remote_params.connection_timeout,
            ) as response:
                if response.status != 200:
                    await self._try_record_error()
                    failure_reason = await get_failure_reason_from_response(
                        response
                    )
                    # Check for non-retriable status codes to fail fast.
                    if is_non_retriable_status_code(response.status):
                        failure_reason = (
                            f"Non-retriable error: {failure_reason}"
                        )
                        raise RuntimeError(failure_reason)
                    continue

                # Try to parse the response as JSON
                try:
                    response_json = await response.json()
                except (aiohttp.ContentTypeError, json.JSONDecodeError):
                    # Try to parse as text if JSON parsing fails
                    text_response = await response.text()
                    try:
                        response_json = json.loads(text_response)
                    except (json.JSONDecodeError, ValueError) as e:
                        await self._try_record_error()
                        failure_reason = (
                            "Failed to parse response. "
                            f"Content type: {response.content_type}. "
                            f"Response text: {text_response[:200]}..."
                        )
                        if attempt >= remote_params.max_retries:
                            raise RuntimeError(
                                "Failed to parse response as JSON after "
                                f"{attempt + 1} attempts. {failure_reason}"
                            ) from e
                        continue

                # Process successful response
                try:
                    result = self._convert_api_output_to_conversation(
                        response_json, conversation
                    )
                    # Write what we have so far to our scratch directory
                    self._save_conversation_to_scratch(result, output_path)
                    await self._try_record_success()
                    # Record token usage for rate limiting
                    await self._record_token_usage(response_json)
                    return result
                except Exception as e:
                    # Response was successful, but we couldn't process it.
                    failure_reason = (
                        f"Failed to process successful response: {str(e)}"
                    )
                    await self._try_record_error()
                    if attempt >= remote_params.max_retries:
                        raise RuntimeError(failure_reason) from e
                    continue
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # Connection or timeout errors are retriable.
            failure_reason = f"Connection error: {str(e)}"
            await self._try_record_error()
            if attempt >= remote_params.max_retries:
                raise RuntimeError(
                    f"Failed to query API after {attempt + 1} attempts due to "
                    f"connection error: {str(e)}"
                ) from e
            continue
        except RuntimeError:
            # RuntimeError is raised by our code, so we don't need to retry.
            raise
        except Exception as e:
            # If we get here, we've hit an unexpected error.
            failure_reason = f"Unexpected error: {str(e)}"
            await self._try_record_error()
            if attempt >= remote_params.max_retries:
                raise RuntimeError(
                    f"Failed to query API after {attempt + 1} attempts due to "
                    f"unexpected error: {str(e)}"
                ) from e
            continue

    # This should only be reached if all retries failed
    raise RuntimeError(
        f"Failed to query API after {attempt + 1} attempts. "
        + (f"Reason: {failure_reason}" if failure_reason else "")
    )
```
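A self-contained sketch of the cleanup pattern suggested above. The `record_request_without_tokens()` name is taken from a later fix commit in this PR; the `_FakeLimiter` class and `query_once` function are stand-ins for illustration only, not the PR's actual `TokenRateLimiter` or request path:

```python
import asyncio


class _FakeLimiter:
    """Stand-in limiter used only to illustrate the try-finally cleanup flow."""

    def __init__(self) -> None:
        self.pending = 0

    async def wait_if_needed(self) -> None:
        self.pending += 1  # reserve a pending slot

    async def record_usage(self, input_tokens: int, output_tokens: int) -> None:
        self.pending -= 1  # successful request settles the slot

    async def record_request_without_tokens(self) -> None:
        self.pending -= 1  # failed request releases the slot


async def query_once(limiter: _FakeLimiter, should_fail: bool) -> None:
    await limiter.wait_if_needed()
    usage_recorded = False
    try:
        if should_fail:
            raise RuntimeError("simulated request failure")
        await limiter.record_usage(input_tokens=10, output_tokens=5)
        usage_recorded = True
    finally:
        if not usage_recorded:
            # Release the pending slot so failures do not leak capacity.
            await limiter.record_request_without_tokens()


async def main() -> None:
    limiter = _FakeLimiter()
    for should_fail in (True, False, True):
        try:
            await query_once(limiter, should_fail)
        except RuntimeError:
            pass
    assert limiter.pending == 0  # no leaked pending slots


asyncio.run(main())
```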
```python
async def _record_token_usage(self, response_json: dict) -> None:
    """Record token usage from API response for rate limiting.

    Extracts token counts from the standard OpenAI-format "usage" field
    in the API response and records them with the rate limiter.

    Args:
        response_json: The JSON response from the API.
    """
    if not self._token_rate_limiter.is_enabled():
        return

    usage = response_json.get("usage", {})
    input_tokens = usage.get("prompt_tokens", 0)
    output_tokens = usage.get("completion_tokens", 0)

    await self._token_rate_limiter.record_usage(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
    )
```
Copilot AI · Dec 13, 2025
The _record_token_usage method only handles OpenAI's token field naming ("prompt_tokens" and "completion_tokens"). However, other API providers like Anthropic use different field names ("input_tokens" and "output_tokens" according to their API documentation). This means token-based rate limiting won't work correctly for Anthropic and potentially other non-OpenAI providers.
Consider making _record_token_usage a virtual method that can be overridden by subclasses, or add logic to handle multiple token field naming conventions.
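A minimal sketch of one way to tolerate both naming conventions (the helper name is hypothetical; the field names follow the public OpenAI and Anthropic API documentation):

```python
def extract_token_usage(response_json: dict) -> tuple[int, int]:
    """Return (input_tokens, output_tokens) from an API response body."""
    usage = response_json.get("usage", {}) or {}
    # OpenAI-style: prompt_tokens / completion_tokens.
    # Anthropic-style: input_tokens / output_tokens.
    input_tokens = usage.get("prompt_tokens", usage.get("input_tokens", 0))
    output_tokens = usage.get("completion_tokens", usage.get("output_tokens", 0))
    return input_tokens, output_tokens
```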
```python
@pytest.mark.asyncio
async def test_pending_requests_tracked(self):
    """Test that pending requests are tracked correctly."""
    limiter = TokenRateLimiter(requests_per_minute=2)

    # Start a request (wait_if_needed increments pending)
    await limiter.wait_if_needed()

    summary = await limiter.get_usage_summary()
    # Pending request should be counted
    assert summary["request_count"] == 1

    # Complete the request
    await limiter.record_usage(input_tokens=10, output_tokens=5)

    summary = await limiter.get_usage_summary()
    # Still 1 request (now recorded, not pending)
    assert summary["request_count"] == 1
```
Copilot AI · Dec 13, 2025
The test suite is missing coverage for error scenarios where wait_if_needed() is called but record_usage() is never called (e.g., due to request failure). This is a critical scenario that would cause pending_requests to leak and eventually block all future requests.
Add a test that simulates this scenario, such as:
- Call wait_if_needed()
- Simulate a failure without calling record_usage()
- Verify that the limiter still functions correctly for subsequent requests
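A hedged sketch of such a test, assuming the `record_request_without_tokens()` cleanup hook that a later commit in this PR introduces (the exact API and import path may differ):

```python
import asyncio

import pytest

from oumi.inference.token_rate_limiter import TokenRateLimiter  # path per the PR's file list


@pytest.mark.asyncio
async def test_failed_request_releases_pending_slot():
    limiter = TokenRateLimiter(requests_per_minute=2)

    # Start a request, then simulate a failure: record_usage() is never called.
    await limiter.wait_if_needed()
    await limiter.record_request_without_tokens()  # cleanup hook from the later fix commit

    # A subsequent request should not hang on a leaked pending slot.
    await asyncio.wait_for(limiter.wait_if_needed(), timeout=1.0)
    await limiter.record_usage(input_tokens=1, output_tokens=1)
```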
```python
input_tokens_per_minute: Optional[int] = None
"""Maximum number of input (prompt) tokens allowed per minute.

If set, the engine will track input token usage and wait when approaching
the limit. Many API providers (OpenAI, Anthropic) enforce input token limits.
Token counts are extracted from API response headers or usage fields.
"""
```
Copilot AI · Dec 13, 2025
The documentation states that "Token counts are extracted from API response headers or usage fields." However, the implementation in _record_token_usage only extracts tokens from the response body's "usage" field, not from headers. This is misleading.
Either update the documentation to accurately reflect the implementation (only extracts from usage field in response body), or extend the implementation to also check response headers if that functionality is intended.
```python
output_tokens_per_minute: Optional[int] = None
"""Maximum number of output (completion) tokens allowed per minute.

If set, the engine will track output token usage and wait when approaching
the limit. Many API providers enforce output/completion token limits.
Token counts are extracted from API response headers or usage fields.
"""
```
Copilot AI · Dec 13, 2025
The documentation states that "Token counts are extracted from API response headers or usage fields." However, the implementation in _record_token_usage only extracts tokens from the response body's "usage" field, not from headers. This is misleading.
Either update the documentation to accurately reflect the implementation (only extracts from usage field in response body), or extend the implementation to also check response headers if that functionality is intended.
```python
self._pending_requests += 1

if wait_time > 0:
    await asyncio.sleep(wait_time)
```
Copilot AI · Dec 13, 2025
The pending request is incremented inside the lock, but the sleep happens outside the lock. If multiple coroutines call wait_if_needed() concurrently, they will all increment pending_requests and calculate wait times based on nearly the same state, then all sleep for similar durations. This can lead to a "thundering herd" problem where multiple requests wake up simultaneously after waiting, potentially exceeding rate limits.
Consider incrementing pending_requests after calculating the wait time but before releasing the lock, or implement a queue-based approach to serialize request scheduling.
Suggested change:

```diff
-self._pending_requests += 1
-if wait_time > 0:
-    await asyncio.sleep(wait_time)
+await asyncio.sleep(wait_time)
+self._pending_requests += 1
```
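One way to avoid the thundering-herd effect is to sleep, then re-check the window under the lock and only then reserve a slot. The class below is a self-contained sketch of that pattern, not the PR's implementation:

```python
import asyncio
import time
from collections import deque


class SimpleRequestLimiter:
    """Sketch: sleep, then re-check and reserve a slot while holding the lock."""

    def __init__(self, requests_per_minute: int, window_seconds: float = 60.0):
        self._rpm = requests_per_minute
        self._window = window_seconds
        self._timestamps: deque[float] = deque()
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        while True:
            async with self._lock:
                now = time.time()
                # Drop entries that have left the sliding window.
                while self._timestamps and now - self._timestamps[0] >= self._window:
                    self._timestamps.popleft()
                if len(self._timestamps) < self._rpm:
                    # Reserve the slot while still holding the lock, so
                    # concurrent waiters cannot all proceed at once.
                    self._timestamps.append(now)
                    return
                wait = self._timestamps[0] + self._window - now
            # Sleep outside the lock, then loop and re-check.
            await asyncio.sleep(max(wait, 0.01))
```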
```python
async def _calculate_wait_time(self) -> float:
    """Calculate how long to wait before the next request.

    Returns:
        Wait time in seconds. 0 if no wait needed.
    """
    if not self.is_enabled():
        return 0.0

    request_count, input_tokens, output_tokens = await self._get_current_usage()
    current_time = time.time()
    wait_time = 0.0

    # Check request limit
    if (
        self._requests_per_minute is not None
        and request_count >= self._requests_per_minute
    ):
        # Find when the oldest request will expire from the window
        if self._usage_history:
            oldest_time = self._usage_history[0].timestamp
            wait_time = max(
                wait_time, oldest_time + self._window_seconds - current_time
            )

    # Check input token limit
    if (
        self._input_tokens_per_minute is not None
        and input_tokens >= self._input_tokens_per_minute
    ):
        if self._usage_history:
            oldest_time = self._usage_history[0].timestamp
            wait_time = max(
                wait_time, oldest_time + self._window_seconds - current_time
            )

    # Check output token limit
    if (
        self._output_tokens_per_minute is not None
        and output_tokens >= self._output_tokens_per_minute
    ):
        if self._usage_history:
            oldest_time = self._usage_history[0].timestamp
            wait_time = max(
                wait_time, oldest_time + self._window_seconds - current_time
            )

    return max(0.0, wait_time)
```
Copilot AI · Dec 13, 2025
The rate limiter doesn't account for the tokens that will be consumed by pending requests when calculating wait times. The _calculate_wait_time method only considers completed requests (in _usage_history), but when multiple requests are pending, the actual token consumption will be higher than what the limiter sees. This could lead to exceeding rate limits.
Consider tracking estimated token usage for pending requests or implement a reservation system where requests reserve their expected token capacity before being sent.
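A self-contained sketch of a reservation-style approach (hypothetical, not the PR's code): each request reserves an estimated input-token budget before sending and settles the reservation with actual usage afterwards. Sliding-window expiry is omitted for brevity:

```python
import asyncio


class ReservingTokenLimiter:
    """Sketch: reserve estimated tokens up front, settle with actual usage."""

    def __init__(self, input_tokens_per_minute: int) -> None:
        self._limit = input_tokens_per_minute
        self._reserved = 0  # tokens promised to in-flight requests
        self._used = 0      # tokens already consumed in the current window
        self._lock = asyncio.Lock()

    async def reserve(self, estimated_input_tokens: int) -> None:
        while True:
            async with self._lock:
                projected = self._used + self._reserved + estimated_input_tokens
                if projected <= self._limit:
                    self._reserved += estimated_input_tokens
                    return
            await asyncio.sleep(0.1)  # back off and re-check

    async def settle(self, estimated_input_tokens: int, actual_input_tokens: int) -> None:
        async with self._lock:
            self._reserved -= estimated_input_tokens
            self._used += actual_input_tokens
```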
```python
# limitations under the License.

import asyncio
import time
```
Copilot AI · Dec 13, 2025
Import of 'time' is not used.
Suggested change:

```diff
-import time
```
FAQ additions:
- installation.md: Python versions, dependencies, Docker, virtual envs
- training.md: Memory issues, LoRA, datasets, checkpoints, multi-GPU
- inference.md: Engine selection, memory, model loading, generation params
- remote_jobs.md: Cloud setup, file mounts, job management, debugging

Inference guide:
- engine_comparison.md: Compare NATIVE vs VLLM vs REMOTE_VLLM vs SGLANG vs LLAMACPP with decision flowchart, memory requirements, and feature matrix

Updated index.md and infer.md toctrees to include new files
…ilure

The _pending_requests counter was only decremented on successful requests, causing it to leak indefinitely on any failure (HTTP errors, timeouts, JSON parsing errors). This could eventually block all subsequent requests.

Added try-finally block to ensure record_request_without_tokens() is always called when a request fails, preventing the counter from leaking.

Fixes the bug reported by Sentry bot in PR oumi-ai#2086
When concurrent requests arrive simultaneously, the rate limit could be reached purely by pending requests while _usage_history is empty. The code was not applying any wait in this case, allowing requests to exceed the configured limit. Added fallback wait time (0.1s) when limit is reached but history is empty, ensuring rate limiting is enforced under concurrent load.
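A sketch of the described fallback in isolation (the constant name and free-standing function are illustrative, not the exact patch):

```python
FALLBACK_WAIT_SECONDS = 0.1  # hypothetical constant name


def request_limit_wait(
    request_count: int,
    requests_per_minute: int,
    oldest_timestamp: float | None,
    window_seconds: float,
    now: float,
) -> float:
    """Return how long to wait before the next request is allowed."""
    if request_count < requests_per_minute:
        return 0.0
    if oldest_timestamp is not None:
        # Wait until the oldest entry leaves the sliding window.
        return max(0.0, oldest_timestamp + window_seconds - now)
    # Limit reached purely by pending requests, with no history yet:
    # apply a small fallback wait so concurrent callers are still throttled.
    return FALLBACK_WAIT_SECONDS
```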
Hi @RajdeepKushwaha5, thank you for your PR contribution! However, there is someone already working on issue #1457 in PR #2082. In the future, please comment on the issue to have it assigned to you before starting on the PR, to avoid duplicating work. Thanks!
Description
Adds support for token-based rate limiting in RemoteInferenceEngine.
Changes
- Add three new parameters to `RemoteParams`:
  - `requests_per_minute`: Limit API calls per minute
  - `input_tokens_per_minute`: Limit input tokens per minute
  - `output_tokens_per_minute`: Limit output tokens per minute
- Create `TokenRateLimiter` class with sliding window algorithm
- Integrate rate limiting into `RemoteInferenceEngine._query_api()`

Fixes
Closes #1457