Source code for niquests.hooks

"""
requests.hooks
~~~~~~~~~~~~~~

This module provides the capabilities for the Requests hooks system.

Available hooks:

``pre_request``:
    The prepared request has just been built. You may alter it before it is sent over HTTP.
``pre_send``:
    The prepared request has had its ConnectionInfo injected.
    This event is triggered just after picking a live connection from the pool.
``on_upload``:
    Lets you monitor the upload progress of the request body.
    This event is triggered each time a block of data is transmitted to the remote peer.
    Use this hook carefully as it may impact the overall performance.
``early_response``:
    An early response caught before receiving the final Response for a given Request.
    Such as, but not limited to, 103 Early Hints.
``response``:
    The response generated from a Request.
"""

from __future__ import annotations

import asyncio
import threading
import time
import typing
from collections.abc import MutableMapping

from ._compat import iscoroutinefunction
from .typing import (
    _HV,
    AsyncHookCallableType,
    AsyncHookType,
    HookCallableType,
    HookType,
)

if typing.TYPE_CHECKING:
    from .models import PreparedRequest, Response

# Names of the hook events known to this module.
# default_hooks() creates one empty callable list per entry.
HOOKS = [
    "pre_request",
    "pre_send",
    "on_upload",
    "early_response",
    "response",
]


def default_hooks() -> HookType[_HV]:
    """Return a new hook mapping with an empty callable list for every event in ``HOOKS``.

    A distinct list object is created for each event on every call, so the
    returned mapping can be mutated freely without affecting other callers.
    """
    registry = {}
    for event in HOOKS:
        registry[event] = []
    return registry


def dispatch_hook(key: str, hooks: HookType[_HV] | None, hook_data: _HV, **kwargs: typing.Any) -> _HV:
    """Run every hook registered under ``key`` against ``hook_data`` and return the result.

    Hooks run in order. Each one receives the current data plus ``kwargs``;
    when a hook returns something other than ``None`` that value replaces the
    data handed to the next hook (and ultimately returned).

    Note that *any* ``TypeError`` raised by the first call attempt causes the
    hook to be invoked a second time without ``kwargs``. This is what lets
    hooks that do not accept keyword arguments work.

    :param key: Name of the event to dispatch.
    :param hooks: Mapping of event name to a callable or a list of callables, or ``None``.
    :param hook_data: The object handed to the hooks.
    :param kwargs: Extra keyword arguments forwarded to each hook.
    :return: ``hook_data``, possibly replaced by what the hooks returned.
    """
    if hooks is None:
        return hook_data

    registered: list[HookCallableType[_HV]] | None = hooks.get(key)  # type: ignore[assignment]
    if not registered:
        return hook_data

    # A single bare callable is accepted in place of a list.
    handlers = [registered] if callable(registered) else registered

    for handler in handlers:
        try:
            outcome = handler(hook_data, **kwargs)
        except TypeError:
            outcome = handler(hook_data)
        if outcome is not None:
            hook_data = outcome

    return hook_data


async def async_dispatch_hook(key: str, hooks: AsyncHookType[_HV] | None, hook_data: _HV, **kwargs: typing.Any) -> _HV:
    """Asynchronous counterpart of the hook dispatcher.

    Hooks registered under ``key`` run in order. Coroutine functions (as
    reported by ``iscoroutinefunction``) are awaited; any other callable is
    invoked directly. A non-``None`` result replaces the data passed on to the
    next hook.

    Note that *any* ``TypeError`` raised while calling (or awaiting) a hook
    with ``kwargs`` causes that hook to be invoked again without them.

    :param key: Name of the event to dispatch.
    :param hooks: Mapping of event name to a callable or a list of callables, or ``None``.
    :param hook_data: The object handed to the hooks.
    :param kwargs: Extra keyword arguments forwarded to each hook.
    :return: ``hook_data``, possibly replaced by what the hooks returned.
    """
    if hooks is None:
        return hook_data

    registered: list[HookCallableType[_HV] | AsyncHookCallableType[_HV]] | None = hooks.get(key)
    if not registered:
        return hook_data

    # A single bare callable is accepted in place of a list.
    handlers = [registered] if callable(registered) else registered

    for handler in handlers:
        must_await = iscoroutinefunction(handler)
        try:
            outcome = handler(hook_data, **kwargs)
            if must_await:
                outcome = await outcome
        except TypeError:
            outcome = handler(hook_data)
            if must_await:
                outcome = await outcome

        if outcome is not None:
            hook_data = outcome

    return hook_data


class _BaseLifeCycleHook(
    typing.MutableMapping[str, typing.List[typing.Union[HookCallableType, AsyncHookCallableType]]], typing.Generic[_HV]
):
    """Read-only mapping from hook event name to the list of callables registered for it.

    Item assignment and deletion are rejected with ``NotImplementedError``.
    Two instances can be merged with ``+``, which yields a new instance whose
    per-event lists hold the left operand's callables followed by the right's.
    """

    def __init__(self) -> None:
        # One (initially empty) list of callables per supported event.
        self._store: MutableMapping[str, list[HookCallableType[_HV] | AsyncHookCallableType[_HV]]] = {
            "pre_request": [],
            "pre_send": [],
            "on_upload": [],
            "early_response": [],
            "response": [],
        }

    def __setitem__(self, key: str | bytes, value: list[HookCallableType[_HV] | AsyncHookCallableType[_HV]]) -> None:
        """Always raise ``NotImplementedError``: the mapping is read-only."""
        raise NotImplementedError("LifeCycleHook is Read Only")

    def __getitem__(self, key: str) -> list[HookCallableType[_HV] | AsyncHookCallableType[_HV]]:
        """Return the callables registered for ``key``; raise ``KeyError`` for an unknown event."""
        return self._store[key]

    def get(  # type: ignore[override]
        self,
        key: str,
        default: list[HookCallableType[_HV] | AsyncHookCallableType[_HV]] | None = None,
    ) -> list[HookCallableType[_HV] | AsyncHookCallableType[_HV]] | None:
        """Return the callables registered for ``key``, or ``default`` for an unknown event.

        Like ``Mapping.get`` this never raises ``KeyError``, so dispatching an
        event this object does not know about is a no-op rather than an error.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __add__(self, other) -> _BaseLifeCycleHook:
        """Return a new hook container combining the callables of ``self`` and ``other``.

        Neither operand is modified.

        :raises TypeError: if ``other`` is not a ``_BaseLifeCycleHook``.
        """
        if not isinstance(other, _BaseLifeCycleHook):
            raise TypeError(
                f"unsupported operand type(s) for +: {type(self).__name__!r} and {type(other).__name__!r}"
            )

        combined_hooks: _BaseLifeCycleHook[_HV] = _BaseLifeCycleHook()

        # Build brand-new lists: extending the existing ones in place (``+=``)
        # would permanently append other's callables to self's own store.
        combined_hooks._store = {event: [*fns, *other._store[event]] for event, fns in self._store.items()}

        return combined_hooks

    def __iter__(self):
        """Iterate over the event names."""
        yield from self._store

    def items(self):
        """Yield ``(event name, callables)`` pairs."""
        for key in self:
            yield key, self[key]

    def __delitem__(self, key):
        """Always raise ``NotImplementedError``: the mapping is read-only."""
        raise NotImplementedError("LifeCycleHook is Read Only")

    def __len__(self):
        """Return the number of events held in the store."""
        return len(self._store)


class LifeCycleHook(_BaseLifeCycleHook[_HV]):
    """
    Synchronous middleware for the request/response lifecycle.

    Each event is pre-registered with the instance method of the same name.
    Every method is a no-op by default; subclass and override the ones you need.
    """

    def __init__(self) -> None:
        super().__init__()
        # Register this instance's own (bound) handler as the sole callable of each event.
        for event in ("pre_request", "pre_send", "on_upload", "early_response", "response"):
            self._store[event] = [getattr(self, event)]

    def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
        """Called once the prepared request has been built. You may alter it before it is sent over HTTP."""
        return None

    def pre_send(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> None:
        """Called once the prepared request has had its ConnectionInfo injected.

        This event is triggered just after picking a live connection from the pool.
        You may not alter the prepared request."""
        return None

    def on_upload(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> None:
        """Lets you monitor the upload progress of the request body.

        This event is triggered each time a block of data is transmitted to the remote peer.
        Use this hook carefully as it may impact the overall performance.
        You may not alter the prepared request."""
        return None

    def early_response(self, response: Response, **kwargs: typing.Any) -> None:
        """Called for an early response caught before the final Response of a given Request.

        Such as, but not limited to, 103 Early Hints."""
        return None

    def response(self, response: Response, **kwargs: typing.Any) -> Response | None:
        """Called with the response generated from a Request. You may alter the response at will."""
        return None
class AsyncLifeCycleHook(_BaseLifeCycleHook[_HV]):
    """
    Asynchronous middleware for the request/response lifecycle.

    Each event is pre-registered with the instance coroutine method of the same name.
    Every method is a no-op by default; subclass and override the ones you need.
    """

    def __init__(self) -> None:
        super().__init__()
        # Register this instance's own (bound) handler as the sole callable of each event.
        for event in ("pre_request", "pre_send", "on_upload", "early_response", "response"):
            self._store[event] = [getattr(self, event)]

    async def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
        """Called once the prepared request has been built. You may alter it before it is sent over HTTP."""
        return None

    async def pre_send(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> None:
        """Called once the prepared request has had its ConnectionInfo injected.

        This event is triggered just after picking a live connection from the pool.
        You may not alter the prepared request."""
        return None

    async def on_upload(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> None:
        """Lets you monitor the upload progress of the request body.

        This event is triggered each time a block of data is transmitted to the remote peer.
        Use this hook carefully as it may impact the overall performance.
        You may not alter the prepared request."""
        return None

    async def early_response(self, response: Response, **kwargs: typing.Any) -> None:
        """Called for an early response caught before the final Response of a given Request.

        Such as, but not limited to, 103 Early Hints."""
        return None

    async def response(self, response: Response, **kwargs: typing.Any) -> Response | None:
        """Called with the response generated from a Request. You may alter the response at will."""
        return None
class _LeakyBucketMixin:
    """Shared leaky bucket algorithm logic.

    Holds only the bookkeeping; the class mixing this in is responsible for
    locking and for actually sleeping.
    """

    rate: float
    interval: float
    last_request: float | None

    def _init_leaky_bucket(self, rate: float) -> None:
        """Set up the bucket state.

        Args:
            rate: Maximum requests per second; must be greater than zero.

        Raises:
            ValueError: If ``rate`` is not a positive number.
        """
        # ``not rate > 0`` (rather than ``rate <= 0``) also rejects NaN.
        if not rate > 0:
            raise ValueError(f"rate must be a positive number of requests per second, got {rate!r}")
        self.rate = rate
        self.interval = 1.0 / rate
        self.last_request = None

    def _compute_wait(self) -> float:
        """Return how long to wait before the next request (may be <= 0).

        This only reads the state; call ``_record_request`` once the request
        is actually let through.
        """
        if self.last_request is None:
            # Nothing has been recorded yet: no need to wait.
            return 0.0
        elapsed = time.monotonic() - self.last_request
        return self.interval - elapsed

    def _record_request(self) -> None:
        """Record that a request was made."""
        self.last_request = time.monotonic()


class _TokenBucketMixin:
    """Shared token bucket algorithm logic.

    Holds only the bookkeeping; the class mixing this in is responsible for
    locking and for actually sleeping.
    """

    rate: float
    capacity: float
    tokens: float
    last_update: float

    def _init_token_bucket(self, rate: float, capacity: float | None) -> None:
        """Set up the bucket state; the bucket starts full.

        Args:
            rate: Token replenishment rate (tokens per second); must be greater than zero.
            capacity: Maximum bucket capacity. Defaults to ``rate``. Values
                below 1.0 are raised to 1.0.

        Raises:
            ValueError: If ``rate`` is not a positive number.
        """
        # ``not rate > 0`` (rather than ``rate <= 0``) also rejects NaN.
        if not rate > 0:
            raise ValueError(f"rate must be a positive number of tokens per second, got {rate!r}")
        self.rate = rate
        requested = capacity if capacity is not None else rate
        # Every request costs one whole token. A bucket that cannot hold a full
        # token would never satisfy ``tokens >= 1.0`` and would push the balance
        # negative after each wait, throttling below the configured rate.
        self.capacity = max(1.0, requested)
        self.tokens = self.capacity
        self.last_update = time.monotonic()

    def _acquire_token(self) -> float | None:
        """Replenish tokens and try to acquire one.

        Returns:
            ``None`` if a token was consumed, otherwise the number of seconds
            to wait before calling ``_post_wait``.
        """
        now = time.monotonic()
        elapsed = now - self.last_update
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last_update = now

        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return None

        # Nothing is consumed here: the caller sleeps for the returned duration
        # and then calls _post_wait(), which credits the tokens accrued since
        # the refill recorded above and takes one.
        return (1.0 - self.tokens) / self.rate

    def _post_wait(self) -> None:
        """Called after waiting to consume the token."""
        now = time.monotonic()
        elapsed = now - self.last_update
        # Replenish tokens accumulated during the wait
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last_update = now
        # Now consume the token
        self.tokens -= 1.0
class LeakyBucketLimiter(_LeakyBucketMixin, LifeCycleHook):
    """Synchronous rate limiter based on the leaky bucket algorithm.

    Requests "leak" out at a constant rate: before each request the limiter
    waits until enough time has passed since the previous one.

    Usage::

        limiter = LeakyBucketLimiter(rate=10.0)  # 10 requests per second
        with niquests.Session(hooks=limiter) as session:
            ...
    """

    def __init__(self, rate: float = 10.0) -> None:
        """Initialize the leaky bucket limiter.

        Args:
            rate: Maximum requests per second
        """
        super().__init__()
        self._init_leaky_bucket(rate)
        self._lock = threading.Lock()

    def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
        """Block for as long as needed to respect the rate, then record the request."""
        # The lock is held across the sleep so that concurrent callers are
        # released one at a time.
        with self._lock:
            delay = self._compute_wait()
            if delay > 0:
                time.sleep(delay)
            self._record_request()
        return None
class AsyncLeakyBucketLimiter(_LeakyBucketMixin, AsyncLifeCycleHook):
    """Asynchronous rate limiter based on the leaky bucket algorithm.

    Requests "leak" out at a constant rate: before each request the limiter
    waits until enough time has passed since the previous one.

    Usage::

        limiter = AsyncLeakyBucketLimiter(rate=10.0)  # 10 requests per second
        async with niquests.AsyncSession(hooks=limiter) as session:
            ...
    """

    def __init__(self, rate: float = 10.0) -> None:
        """Initialize the leaky bucket limiter.

        Args:
            rate: Maximum requests per second
        """
        super().__init__()
        self._init_leaky_bucket(rate)
        self._lock = asyncio.Lock()

    async def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
        """Suspend for as long as needed to respect the rate, then record the request."""
        # The lock is held across the sleep so that concurrent tasks are
        # released one at a time.
        async with self._lock:
            delay = self._compute_wait()
            if delay > 0:
                await asyncio.sleep(delay)
            self._record_request()
        return None
class TokenBucketLimiter(_TokenBucketMixin, LifeCycleHook):
    """Synchronous rate limiter based on the token bucket algorithm.

    Tokens are added to a bucket at a constant rate, up to a maximum capacity,
    and each request consumes one token. Bursts up to the bucket capacity are
    therefore allowed.

    Usage::

        limiter = TokenBucketLimiter(rate=10.0, capacity=50.0)  # 10/s, burst of 50
        with niquests.Session(hooks=limiter) as session:
            ...
    """

    def __init__(self, rate: float = 10.0, capacity: float | None = None) -> None:
        """Initialize the token bucket limiter.

        Args:
            rate: Token replenishment rate (tokens per second)
            capacity: Maximum bucket capacity (defaults to rate, allowing 1 second burst)
        """
        super().__init__()
        self._init_token_bucket(rate, capacity)
        self._lock = threading.Lock()

    def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
        """Consume a token, blocking first if none is currently available."""
        with self._lock:
            delay = self._acquire_token()
            # ``None`` means a token was already taken; otherwise wait, then take it.
            if delay is not None:
                time.sleep(delay)
                self._post_wait()
        return None
class AsyncTokenBucketLimiter(_TokenBucketMixin, AsyncLifeCycleHook):
    """Asynchronous rate limiter based on the token bucket algorithm.

    Tokens are added to a bucket at a constant rate, up to a maximum capacity,
    and each request consumes one token. Bursts up to the bucket capacity are
    therefore allowed.

    Usage::

        limiter = AsyncTokenBucketLimiter(rate=10.0, capacity=50.0)  # 10/s, burst of 50
        async with niquests.AsyncSession(hooks=limiter) as session:
            ...
    """

    def __init__(self, rate: float = 10.0, capacity: float | None = None) -> None:
        """Initialize the token bucket limiter.

        Args:
            rate: Token replenishment rate (tokens per second)
            capacity: Maximum bucket capacity (defaults to rate, allowing 1 second burst)
        """
        super().__init__()
        self._init_token_bucket(rate, capacity)
        self._lock = asyncio.Lock()

    async def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
        """Consume a token, suspending first if none is currently available."""
        async with self._lock:
            delay = self._acquire_token()
            # ``None`` means a token was already taken; otherwise wait, then take it.
            if delay is not None:
                await asyncio.sleep(delay)
                self._post_wait()
        return None