
Conversation

@RajdeepKushwaha5

Description

Adds support for token-based rate limiting in RemoteInferenceEngine.

Changes

  • Added three new parameters to RemoteParams (see the usage sketch at the end of this description):
    • requests_per_minute: Limit API calls per minute
    • input_tokens_per_minute: Limit input tokens per minute
    • output_tokens_per_minute: Limit output tokens per minute
  • Created TokenRateLimiter class with a sliding-window algorithm
  • Integrated rate limiting into RemoteInferenceEngine._query_api()
  • Added comprehensive unit tests

Fixes

Closes #1457
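
Example usage (a minimal sketch; the import path follows the file layout in this PR, while the endpoint URL and limit values are illustrative only):

    from oumi.core.configs.params.remote_params import RemoteParams

    # Throttle to at most 60 requests, 90k input tokens, and 30k output tokens
    # per minute against an OpenAI-compatible endpoint (URL is a placeholder).
    remote_params = RemoteParams(
        api_url="https://api.example.com/v1/chat/completions",
        requests_per_minute=60,
        input_tokens_per_minute=90_000,
        output_tokens_per_minute=30_000,
    )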

Copilot AI review requested due to automatic review settings December 13, 2025 08:55
Comment on lines 534 to 535
        if self._token_rate_limiter.is_enabled():
            await self._token_rate_limiter.wait_if_needed()

This comment was marked as outdated.

Contributor

Copilot AI left a comment

Pull request overview

This pull request adds token-based rate-limiting support to the RemoteInferenceEngine, enabling automatic throttling of API requests based on requests-per-minute (RPM), input-tokens-per-minute, and output-tokens-per-minute (TPM) limits.

Key Changes:

  • Introduces a new TokenRateLimiter class with a sliding-window algorithm for tracking and enforcing rate limits (a standalone illustration of the sliding-window bookkeeping follows this list)
  • Adds three new optional parameters to RemoteParams: requests_per_minute, input_tokens_per_minute, and output_tokens_per_minute
  • Integrates rate limiting into the request flow by calling wait_if_needed() before API requests and recording token usage after successful responses
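
For context, here is a standalone sketch of the sliding-window bookkeeping (this is not the PR's TokenRateLimiter; names and structure are illustrative only): records older than the window are dropped before the per-minute totals are compared against the limits.

    import time
    from collections import deque
    from dataclasses import dataclass

    @dataclass
    class UsageRecord:
        timestamp: float
        input_tokens: int
        output_tokens: int

    def current_window_usage(history: deque, window_seconds: float = 60.0):
        """Drop expired records, then total what remains in the window."""
        now = time.time()
        while history and now - history[0].timestamp >= window_seconds:
            history.popleft()
        return (
            len(history),
            sum(r.input_tokens for r in history),
            sum(r.output_tokens for r in history),
        )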

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.

Summary per file:

  • src/oumi/inference/token_rate_limiter.py: New sliding window rate limiter supporting request count and token-based limits with async lock protection
  • src/oumi/core/configs/params/remote_params.py: Added three new rate limit parameters with validation ensuring values are ≥1 or None
  • src/oumi/inference/remote_inference_engine.py: Integrated rate limiter by initializing in constructor, calling wait_if_needed() before requests, and recording token usage from API responses
  • tests/unit/inference/test_token_rate_limiter.py: Comprehensive unit tests covering basic functionality, rate limit enforcement, sliding window cleanup, and concurrent request handling
  • tests/unit/core/configs/params/test_remote_params.py: Added validation tests for the three new rate limit parameters
Comments suppressed due to low confidence (1)

src/oumi/inference/remote_inference_engine.py:650

  • When a request fails (non-200 status, parse error, connection error, etc.), the pending request count is never decremented. The pending_requests counter is only decremented in record_usage(), which is only called on successful responses. This will cause the pending_requests count to grow unbounded over time with failed requests, eventually preventing new requests from being made as the limiter thinks it's at capacity.

Ensure that pending_requests is decremented in all failure paths, either by using a try-finally block around the request or by calling a cleanup method in error handlers.

        # Wait if token/request rate limits are being approached
        if self._token_rate_limiter.is_enabled():
            await self._token_rate_limiter.wait_if_needed()

        semaphore_or_controller = (
            self._adaptive_concurrency_controller
            if self._remote_params.use_adaptive_concurrency
            else semaphore
        )
        async with semaphore_or_controller:
            api_input = self._convert_conversation_to_api_input(
                conversation, generation_params, model_params
            )
            headers = self._get_request_headers(remote_params)
            failure_reason = None

            # Retry the request if it fails
            for attempt in range(remote_params.max_retries + 1):
                try:
                    # Calculate exponential backoff delay
                    if attempt > 0:
                        delay = min(
                            remote_params.retry_backoff_base * (2 ** (attempt - 1)),
                            remote_params.retry_backoff_max,
                        )
                        await asyncio.sleep(delay)

                    async with session.post(
                        remote_params.api_url,
                        json=api_input,
                        headers=headers,
                        timeout=remote_params.connection_timeout,
                    ) as response:
                        if response.status != 200:
                            await self._try_record_error()
                            failure_reason = await get_failure_reason_from_response(
                                response
                            )

                            # Check for non-retriable status codes to fail fast.
                            if is_non_retriable_status_code(response.status):
                                failure_reason = (
                                    f"Non-retriable error: {failure_reason}"
                                )
                                raise RuntimeError(failure_reason)
                            continue

                        # Try to parse the response as JSON
                        try:
                            response_json = await response.json()
                        except (aiohttp.ContentTypeError, json.JSONDecodeError):
                            # Try to parse as text if JSON parsing fails
                            text_response = await response.text()
                            try:
                                response_json = json.loads(text_response)
                            except (json.JSONDecodeError, ValueError) as e:
                                await self._try_record_error()
                                failure_reason = (
                                    "Failed to parse response. "
                                    f"Content type: {response.content_type}. "
                                    f"Response text: {text_response[:200]}..."
                                )
                                if attempt >= remote_params.max_retries:
                                    raise RuntimeError(
                                        "Failed to parse response as JSON after "
                                        f"{attempt + 1} attempts. {failure_reason}"
                                    ) from e
                                continue

                        # Process successful response
                        try:
                            result = self._convert_api_output_to_conversation(
                                response_json, conversation
                            )
                            # Write what we have so far to our scratch directory
                            self._save_conversation_to_scratch(result, output_path)
                            await self._try_record_success()
                            # Record token usage for rate limiting
                            await self._record_token_usage(response_json)
                            return result
                        except Exception as e:
                            # Response was successful, but we couldn't process it.
                            failure_reason = (
                                f"Failed to process successful response: {str(e)}"
                            )
                            await self._try_record_error()
                            if attempt >= remote_params.max_retries:
                                raise RuntimeError(failure_reason) from e
                            continue

                except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                    # Connection or timeout errors are retriable.
                    failure_reason = f"Connection error: {str(e)}"
                    await self._try_record_error()
                    if attempt >= remote_params.max_retries:
                        raise RuntimeError(
                            f"Failed to query API after {attempt + 1} attempts due to "
                            f"connection error: {str(e)}"
                        ) from e
                    continue
                except RuntimeError:
                    # RuntimeError is raised by our code, so we don't need to retry.
                    raise
                except Exception as e:
                    # If we get here, we've hit an unexpected error.
                    failure_reason = f"Unexpected error: {str(e)}"
                    await self._try_record_error()
                    if attempt >= remote_params.max_retries:
                        raise RuntimeError(
                            f"Failed to query API after {attempt + 1} attempts due to "
                            f"unexpected error: {str(e)}"
                        ) from e
                    continue
            # This should only be reached if all retries failed
            raise RuntimeError(
                f"Failed to query API after {attempt + 1} attempts. "
                + (f"Reason: {failure_reason}" if failure_reason else "")
            )
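
A minimal sketch of the try/finally cleanup suggested in the comment above. The helper name record_request_without_tokens() is taken from a follow-up commit message later in this thread; the wrapper function and its parameters are hypothetical and only illustrate the pattern, not the PR's exact code.

    async def query_with_rate_limit(limiter, send_request) -> dict:
        """Sketch: always release the limiter's pending slot, even on failure.

        `limiter` is assumed to expose wait_if_needed(), record_usage(), and a
        cleanup hook record_request_without_tokens(); `send_request` is a
        caller-supplied coroutine that performs the actual API call.
        """
        await limiter.wait_if_needed()
        usage_recorded = False
        try:
            response_json = await send_request()
            usage = response_json.get("usage", {})
            await limiter.record_usage(
                input_tokens=usage.get("prompt_tokens", 0),
                output_tokens=usage.get("completion_tokens", 0),
            )
            usage_recorded = True
            return response_json
        finally:
            if not usage_recorded:
                # Failure path: release the pending slot without recording tokens.
                await limiter.record_request_without_tokens()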


Comment on lines +299 to +318
    async def _record_token_usage(self, response_json: dict) -> None:
        """Record token usage from API response for rate limiting.

        Extracts token counts from the standard OpenAI-format "usage" field
        in the API response and records them with the rate limiter.

        Args:
            response_json: The JSON response from the API.
        """
        if not self._token_rate_limiter.is_enabled():
            return

        usage = response_json.get("usage", {})
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)

        await self._token_rate_limiter.record_usage(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )

Copilot AI Dec 13, 2025

The _record_token_usage method only handles OpenAI's token field naming ("prompt_tokens" and "completion_tokens"). However, other API providers like Anthropic use different field names ("input_tokens" and "output_tokens" according to their API documentation). This means token-based rate limiting won't work correctly for Anthropic and potentially other non-OpenAI providers.

Consider making _record_token_usage a virtual method that can be overridden by subclasses, or add logic to handle multiple token field naming conventions.
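
A small sketch of the fallback suggested here; the Anthropic field names ("input_tokens"/"output_tokens") are taken from this comment and should be verified against the provider's documentation:

    def extract_token_counts(usage: dict) -> tuple[int, int]:
        """Sketch: accept both OpenAI-style and Anthropic-style usage fields."""
        input_tokens = usage.get("prompt_tokens", usage.get("input_tokens", 0))
        output_tokens = usage.get("completion_tokens", usage.get("output_tokens", 0))
        return input_tokens, output_tokens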

Comment on lines +230 to +248
    @pytest.mark.asyncio
    async def test_pending_requests_tracked(self):
        """Test that pending requests are tracked correctly."""
        limiter = TokenRateLimiter(requests_per_minute=2)

        # Start a request (wait_if_needed increments pending)
        await limiter.wait_if_needed()

        summary = await limiter.get_usage_summary()
        # Pending request should be counted
        assert summary["request_count"] == 1

        # Complete the request
        await limiter.record_usage(input_tokens=10, output_tokens=5)

        summary = await limiter.get_usage_summary()
        # Still 1 request (now recorded, not pending)
        assert summary["request_count"] == 1

Copilot AI Dec 13, 2025

The test suite is missing coverage for error scenarios where wait_if_needed() is called but record_usage() is never called (e.g., due to request failure). This is a critical scenario that would cause pending_requests to leak and eventually block all future requests.

Add a test that simulates this scenario (a sketch follows this list), such as:

  • Call wait_if_needed()
  • Simulate a failure without calling record_usage()
  • Verify that the limiter still functions correctly for subsequent requests
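
A sketch of such a test, assuming the cleanup hook record_request_without_tokens() described in a later commit in this thread; the exact assertions would depend on the limiter's final semantics:

    @pytest.mark.asyncio
    async def test_failed_request_does_not_leak_pending(self):
        """Sketch: a failure path must not leave the limiter stuck at capacity."""
        limiter = TokenRateLimiter(requests_per_minute=2)

        # First request "fails": wait_if_needed() runs, record_usage() never does.
        await limiter.wait_if_needed()
        await limiter.record_request_without_tokens()

        # Subsequent requests should still be admitted promptly.
        await limiter.wait_if_needed()
        await limiter.record_usage(input_tokens=10, output_tokens=5)

        summary = await limiter.get_usage_summary()
        assert summary["request_count"] >= 1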

Comment on lines +66 to +72
    input_tokens_per_minute: Optional[int] = None
    """Maximum number of input (prompt) tokens allowed per minute.

    If set, the engine will track input token usage and wait when approaching
    the limit. Many API providers (OpenAI, Anthropic) enforce input token limits.
    Token counts are extracted from API response headers or usage fields.
    """

Copilot AI Dec 13, 2025

The documentation states that "Token counts are extracted from API response headers or usage fields." However, the implementation in _record_token_usage only extracts tokens from the response body's "usage" field, not from headers. This is misleading.

Either update the documentation to accurately reflect the implementation (only extracts from usage field in response body), or extend the implementation to also check response headers if that functionality is intended.

Comment on lines +74 to +80
    output_tokens_per_minute: Optional[int] = None
    """Maximum number of output (completion) tokens allowed per minute.

    If set, the engine will track output token usage and wait when approaching
    the limit. Many API providers enforce output/completion token limits.
    Token counts are extracted from API response headers or usage fields.
    """

Copilot AI Dec 13, 2025

The documentation states that "Token counts are extracted from API response headers or usage fields." However, the implementation in _record_token_usage only extracts tokens from the response body's "usage" field, not from headers. This is misleading.

Either update the documentation to accurately reflect the implementation (only extracts from usage field in response body), or extend the implementation to also check response headers if that functionality is intended.

Comment on lines +183 to +187
            self._pending_requests += 1

        if wait_time > 0:
            await asyncio.sleep(wait_time)

Copilot AI Dec 13, 2025

The pending request is incremented inside the lock, but the sleep happens outside the lock. If multiple coroutines call wait_if_needed() concurrently, they will all increment pending_requests and calculate wait times based on nearly the same state, then all sleep for similar durations. This can lead to a "thundering herd" problem where multiple requests wake up simultaneously after waiting, potentially exceeding rate limits.

Consider incrementing pending_requests after calculating the wait time but before releasing the lock, or implement a queue-based approach to serialize request scheduling.

Suggested change (move the increment to after the sleep):
        if wait_time > 0:
            await asyncio.sleep(wait_time)
        self._pending_requests += 1
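
One way to avoid the thundering herd is to do the check-and-reserve atomically while holding the lock and loop after sleeping. The standalone sketch below (not the PR's implementation; names are illustrative) shows the pattern for the request-count limit only:

    import asyncio
    import time
    from collections import deque

    class RpmGate:
        """Sketch: admit at most `requests_per_minute` calls per sliding window."""

        def __init__(self, requests_per_minute: int, window_seconds: float = 60.0):
            self._rpm = requests_per_minute
            self._window = window_seconds
            self._admissions = deque()  # admission timestamps inside the window
            self._lock = asyncio.Lock()

        async def acquire(self) -> None:
            while True:
                async with self._lock:
                    now = time.time()
                    # Drop admissions that have left the sliding window.
                    while self._admissions and now - self._admissions[0] >= self._window:
                        self._admissions.popleft()
                    if len(self._admissions) < self._rpm:
                        # Reserve the slot while still holding the lock.
                        self._admissions.append(now)
                        return
                    wait_time = self._admissions[0] + self._window - now
                # Sleep outside the lock, then loop and re-check the window.
                await asyncio.sleep(max(wait_time, 0.0))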

Comment on lines 122 to 169
    async def _calculate_wait_time(self) -> float:
        """Calculate how long to wait before the next request.

        Returns:
            Wait time in seconds. 0 if no wait needed.
        """
        if not self.is_enabled():
            return 0.0

        request_count, input_tokens, output_tokens = await self._get_current_usage()
        current_time = time.time()
        wait_time = 0.0

        # Check request limit
        if (
            self._requests_per_minute is not None
            and request_count >= self._requests_per_minute
        ):
            # Find when the oldest request will expire from the window
            if self._usage_history:
                oldest_time = self._usage_history[0].timestamp
                wait_time = max(
                    wait_time, oldest_time + self._window_seconds - current_time
                )

        # Check input token limit
        if (
            self._input_tokens_per_minute is not None
            and input_tokens >= self._input_tokens_per_minute
        ):
            if self._usage_history:
                oldest_time = self._usage_history[0].timestamp
                wait_time = max(
                    wait_time, oldest_time + self._window_seconds - current_time
                )

        # Check output token limit
        if (
            self._output_tokens_per_minute is not None
            and output_tokens >= self._output_tokens_per_minute
        ):
            if self._usage_history:
                oldest_time = self._usage_history[0].timestamp
                wait_time = max(
                    wait_time, oldest_time + self._window_seconds - current_time
                )

        return max(0.0, wait_time)

Copilot AI Dec 13, 2025

The rate limiter doesn't account for the tokens that will be consumed by pending requests when calculating wait times. The _calculate_wait_time method only considers completed requests (in _usage_history), but when multiple requests are pending, the actual token consumption will be higher than what the limiter sees. This could lead to exceeding rate limits.

Consider tracking estimated token usage for pending requests or implement a reservation system where requests reserve their expected token capacity before being sent.
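
A standalone sketch of the reservation idea; the reserve/commit split, the 0.1 s back-off, and all names below are illustrative assumptions rather than the PR's design, and the sliding-window expiry of committed usage is omitted for brevity:

    import asyncio

    class InputTokenBudget:
        """Sketch: reserve an estimated input-token budget before sending,
        then reconcile it with the usage reported by the API."""

        def __init__(self, input_tokens_per_minute: int):
            self._limit = input_tokens_per_minute
            self._reserved = 0
            self._lock = asyncio.Lock()

        async def reserve(self, estimated_input_tokens: int) -> None:
            while True:
                async with self._lock:
                    if self._reserved + estimated_input_tokens <= self._limit:
                        self._reserved += estimated_input_tokens
                        return
                await asyncio.sleep(0.1)  # back off, then re-check

        async def commit(self, estimated_input_tokens: int, actual_input_tokens: int) -> None:
            async with self._lock:
                # Swap the estimate for the actual count; window expiry (not
                # shown) would later release this usage again.
                self._reserved += actual_input_tokens - estimated_input_tokens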

# limitations under the License.

import asyncio
import time

Copilot AI Dec 13, 2025

Import of 'time' is not used.

Suggested change: delete the unused "import time" line.

Comment on lines 533 to 535
        # Wait if token/request rate limits are being approached
        if self._token_rate_limiter.is_enabled():
            await self._token_rate_limiter.wait_if_needed()

This comment was marked as outdated.

…ilure

The _pending_requests counter was only decremented on successful requests, causing it to leak indefinitely on any failure (HTTP errors, timeouts, JSON parsing errors). This could eventually block all subsequent requests.

Added try-finally block to ensure record_request_without_tokens() is always called when a request fails, preventing the counter from leaking.

Fixes the bug reported by Sentry bot in PR oumi-ai#2086
Comment on lines +141 to +145
            if self._usage_history:
                oldest_time = self._usage_history[0].timestamp
                wait_time = max(
                    wait_time, oldest_time + self._window_seconds - current_time
                )

This comment was marked as outdated.

When concurrent requests arrive simultaneously, the rate limit could be reached purely by pending requests while _usage_history is empty. The code was not applying any wait in this case, allowing requests to exceed the configured limit.

Added fallback wait time (0.1s) when limit is reached but history is empty, ensuring rate limiting is enforced under concurrent load.
@oelachqar oelachqar requested a review from jgreer013 December 14, 2025 17:28
@wizeng23
Contributor

Hi @RajdeepKushwaha5, thank you for your PR contribution! However, there is someone already working on issue #1457 in PR #2082. In the future, please comment on the issue to have it assigned to you before starting on the PR, to avoid duplicating work. Thanks!


Development

Successfully merging this pull request may close these issues.

[Feature] Add support for token-based rate-limiting in RemoteInferenceEngine
