diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7a779cd..597b4df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,5 +35,8 @@ jobs: - name: Run Ruff run: uv run ruff check --output-format=github . + - name: Run Ty (type checker) + run: uv run ty check --output-format=github . + - name: Run pytest run: uv run pytest diff --git a/README.md b/README.md index 45a54e5..81bfa90 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,7 @@ Allows you to serve multiple Telegram bots in a single application. Useful if yo - Allows serving multiple bots via a single endpoint - Uses the bot token for request routing -- Requires dispatcher, web_adapter, routing, bot_settings (optional), webhook_config (optional), and security (optional) +- Requires dispatcher, web_adapter, routing, bot_config (optional), webhook_config (optional), and security (optional) **Example:** @@ -198,7 +198,16 @@ engine = TokenEngine( #### Custom Engines -You can create your own engine by inheriting from the base engine class (`BaseEngine`). This allows you to implement custom logic for webhook processing, routing, or bot management. +You can create your own engine by inheriting from `WebhookEngine`. This allows you to implement custom logic for webhook processing, routing, or bot management. + +### Request processing + +`WebhookEngine` handles incoming updates in this order: + +1. Extract token from request (`_get_bot_token_for_request`) +2. Run security checks for the token (`Security.verify(token, bound_request)`) +3. Resolve bot (`_get_bot_by_token`) +4. Pass update to aiogram dispatcher --- @@ -237,9 +246,9 @@ routing = StaticRouting(url="https://example.com/webhook") ### TokenRouting (Multi-bot, Abstract) Base class for token-based routing strategies. Used with **TokenEngine** to serve multiple bots. -- Requires a URL template with a parameter placeholder (e.g. `{bot_token}`) +- Defines the token parameter name (default: `bot_token`) - Extracts bot token from incoming requests -- Automatically formats webhook URL using the bot token +- Automatically builds webhook URL using the bot token ### PathRouting (Multi-bot) Extracts bot token from the URL path parameter. @@ -277,7 +286,7 @@ routing = QueryRouting(url="https://example.com/webhook?other=value") ``` ### Custom Routing -You can implement your own routing by inheriting from `BaseRouting` or `TokenRouting` and implementing the `webhook_point()` method (and `extract_token()` if using token-based routing). +You can implement your own routing by inheriting from `BaseRouting` or `TokenRouting` and implementing the `webhook_url()` method (and `resolve_token()` if using token-based routing). See [routing examples](/src/aiogram_webhook/routing) for implementation details. diff --git a/src/aiogram_webhook/adapters/fastapi/adapter.py b/src/aiogram_webhook/adapters/fastapi/adapter.py index 2aa5e66..9e88a94 100644 --- a/src/aiogram_webhook/adapters/fastapi/adapter.py +++ b/src/aiogram_webhook/adapters/fastapi/adapter.py @@ -4,14 +4,14 @@ from fastapi.responses import JSONResponse from aiogram_webhook.adapters.base_adapter import BoundRequest, WebAdapter -from aiogram_webhook.adapters.fastapi.mapping import FastAPIHeadersMapping, FastAPIQueryMapping +from aiogram_webhook.adapters.fastapi.mapping import FastApiHeadersMapping, FastApiQueryMapping -class FastAPIBoundRequest(BoundRequest[Request]): +class FastApiBoundRequest(BoundRequest[Request]): def __init__(self, request: Request): super().__init__(request) - self._headers = FastAPIHeadersMapping(self.request.headers) - self._query_params = FastAPIQueryMapping(self.request.query_params) + self._headers = FastApiHeadersMapping(self.request.headers) + self._query_params = FastApiQueryMapping(self.request.query_params) async def json(self) -> dict[str, Any]: return await self.request.json() @@ -23,11 +23,11 @@ def client_ip(self): return None @property - def headers(self) -> FastAPIHeadersMapping: + def headers(self) -> FastApiHeadersMapping: return self._headers @property - def query_params(self) -> FastAPIQueryMapping: + def query_params(self) -> FastApiQueryMapping: return self._query_params @property @@ -36,8 +36,8 @@ def path_params(self): class FastApiWebAdapter(WebAdapter): - def bind(self, request: Request) -> FastAPIBoundRequest: - return FastAPIBoundRequest(request=request) + def bind(self, request: Request) -> FastApiBoundRequest: + return FastApiBoundRequest(request=request) def register(self, app: FastAPI, path, handler, on_startup=None, on_shutdown=None) -> None: # noqa: ARG002 async def endpoint(request: Request): diff --git a/src/aiogram_webhook/adapters/fastapi/mapping.py b/src/aiogram_webhook/adapters/fastapi/mapping.py index bc7e717..dab4cd2 100644 --- a/src/aiogram_webhook/adapters/fastapi/mapping.py +++ b/src/aiogram_webhook/adapters/fastapi/mapping.py @@ -5,11 +5,11 @@ from aiogram_webhook.adapters.base_mapping import MappingABC -class FastAPIHeadersMapping(MappingABC[Headers]): +class FastApiHeadersMapping(MappingABC[Headers]): def getlist(self, name: str) -> list[Any]: return self._mapping.getlist(name) -class FastAPIQueryMapping(MappingABC[QueryParams]): +class FastApiQueryMapping(MappingABC[QueryParams]): def getlist(self, name: str) -> list[Any]: return self._mapping.getlist(name) diff --git a/src/aiogram_webhook/config/bot.py b/src/aiogram_webhook/config/bot.py index 05f7455..6c57b1e 100644 --- a/src/aiogram_webhook/config/bot.py +++ b/src/aiogram_webhook/config/bot.py @@ -4,7 +4,7 @@ from aiogram.client.session.base import BaseSession -@dataclass +@dataclass(slots=True) class BotConfig: session: BaseSession | None = None """HTTP Client session (For example AiohttpSession). If not specified it will be automatically created.""" diff --git a/src/aiogram_webhook/engines/base.py b/src/aiogram_webhook/engines/base.py index f8bda1b..69ae39b 100644 --- a/src/aiogram_webhook/engines/base.py +++ b/src/aiogram_webhook/engines/base.py @@ -1,10 +1,12 @@ from __future__ import annotations import asyncio +import warnings from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any from aiogram.methods import TelegramMethod +from aiogram.utils.token import TokenValidationError from aiogram_webhook.config.webhook import WebhookConfig @@ -45,9 +47,8 @@ def __init__( self.handle_in_background = handle_in_background self._background_feed_update_tasks: set[asyncio.Task[Any]] = set() - @abstractmethod - def _get_bot_from_request(self, bound_request: BoundRequest) -> Bot | None: - raise NotImplementedError + if self.security is None: + warnings.warn("Security is not configured, skipping verification", UserWarning, stacklevel=3) @abstractmethod async def set_webhook(self, *args, **kwargs) -> Bot: @@ -71,25 +72,41 @@ def _build_workflow_data(self, app: Any, **kwargs) -> dict[str, Any]: **kwargs, } + @abstractmethod + async def _get_bot_token_for_request(self, bound_request: BoundRequest) -> str | None: + raise NotImplementedError + + @abstractmethod + async def _get_bot_by_token(self, token: str) -> Bot | None: + raise NotImplementedError + async def handle_request(self, bound_request: BoundRequest): - bot = self._get_bot_from_request(bound_request) - if bot is None: - return self.web_adapter.create_json_response(status=400, payload={"detail": "Bot not found"}) + token = await self._get_bot_token_for_request(bound_request) + if token is None: + return self.web_adapter.create_json_response(status=400, payload={"detail": "Bot token not found"}) - if self.security is not None and not await self.security.verify(bot=bot, bound_request=bound_request): + if self.security is not None and not await self.security.verify( + bot_token=token, bound_request=bound_request, dispatcher=self.dispatcher + ): return self.web_adapter.create_json_response(status=403, payload={"detail": "Forbidden"}) - update = await bound_request.json() + try: + bot = await self._get_bot_by_token(token) + except TokenValidationError: + return self.web_adapter.create_json_response(status=400, payload={"detail": "Invalid bot token"}) + + if bot is None: + return self.web_adapter.create_json_response(status=400, payload={"detail": "Bot not found"}) + update = await bound_request.json() if self.handle_in_background: return await self._handle_request_background(bot=bot, update=update) - return await self._handle_request(bot=bot, update=update) def register(self, app: Any) -> None: self.web_adapter.register( app=app, - path=self.routing.path, + path=self.routing.webhook_path, handler=self.handle_request, on_startup=self.on_startup, on_shutdown=self.on_shutdown, @@ -98,7 +115,7 @@ def register(self, app: Any) -> None: async def _handle_request(self, bot: Bot, update: dict[str, Any]) -> dict[str, Any]: result = await self.dispatcher.feed_webhook_update(bot=bot, update=update) - if not isinstance(result, TelegramMethod): + if result is None: return self.web_adapter.create_json_response(status=200, payload={}) payload = self._build_webhook_payload(bot, result) @@ -110,14 +127,12 @@ async def _handle_request(self, bot: Bot, update: dict[str, Any]) -> dict[str, A return self.web_adapter.create_json_response(status=200, payload=payload) async def _background_feed_update(self, bot: Bot, update: dict[str, Any]) -> None: - result = await self.dispatcher.feed_raw_update(bot=bot, update=update) # **self.data + result = await self.dispatcher.feed_raw_update(bot=bot, update=update) if isinstance(result, TelegramMethod): await self.dispatcher.silent_call_request(bot=bot, result=result) async def _handle_request_background(self, bot: Bot, update: dict[str, Any]): - feed_update_task = asyncio.create_task( - self._background_feed_update(bot=bot, update=update), - ) + feed_update_task = asyncio.create_task(self._background_feed_update(bot=bot, update=update)) self._background_feed_update_tasks.add(feed_update_task) feed_update_task.add_done_callback(self._background_feed_update_tasks.discard) diff --git a/src/aiogram_webhook/engines/simple.py b/src/aiogram_webhook/engines/simple.py index 051f652..f1fe380 100644 --- a/src/aiogram_webhook/engines/simple.py +++ b/src/aiogram_webhook/engines/simple.py @@ -42,13 +42,16 @@ def __init__( handle_in_background=handle_in_background, ) - def _get_bot_from_request(self, bound_request: BoundRequest) -> Bot | None: # noqa: ARG002 + async def _get_bot_token_for_request(self, bound_request: BoundRequest) -> str: # noqa: ARG002 """ - Always returns the single Bot instance for any request. + Always returns the single bot token for any request. :param bound_request: The incoming bound request. - :return: The single Bot instance + :return: The single bot token """ + return self.bot.token + + async def _get_bot_by_token(self, token: str) -> Bot: # noqa: ARG002 return self.bot async def set_webhook( @@ -78,11 +81,13 @@ async def set_webhook( params = config.model_dump(exclude_none=True) if self.security is not None: - secret_token = await self.security.get_secret_token(bot=self.bot) + secret_token = await self.security.secret_token(bot_token=self.bot.token) if secret_token is not None: params["secret_token"] = secret_token - await self.bot.set_webhook(url=self.routing.webhook_point(self.bot), request_timeout=request_timeout, **params) + await self.bot.set_webhook( + url=await self.routing.webhook_url(self.bot), request_timeout=request_timeout, **params + ) return self.bot async def on_startup(self, app: Any, *args, **kwargs) -> None: # noqa: ARG002 diff --git a/src/aiogram_webhook/engines/token.py b/src/aiogram_webhook/engines/token.py index 724e0bc..731c0b8 100644 --- a/src/aiogram_webhook/engines/token.py +++ b/src/aiogram_webhook/engines/token.py @@ -1,8 +1,10 @@ from __future__ import annotations +from types import MappingProxyType from typing import TYPE_CHECKING, Any from aiogram import Bot, Dispatcher +from aiogram.client.session.aiohttp import AiohttpSession from aiogram.utils.token import extract_bot_id from aiogram_webhook.config.bot import BotConfig @@ -44,36 +46,30 @@ def __init__( ) self.routing: TokenRouting = routing # for type checker self.bot_config = bot_config or BotConfig() + self._session = self.bot_config.session or AiohttpSession() self._bots: dict[int, Bot] = {} - def _get_bot_from_request(self, bound_request: BoundRequest) -> Bot | None: - """ - Get a :class:`Bot` instance from request by token. - If the bot is not yet created, it will be created automatically. + @property + def bots(self) -> MappingProxyType[int, Bot]: + return MappingProxyType(self._bots) - :param bound_request: Incoming request - :return: Bot instance or None - """ - token = self.routing.extract_token(bound_request) - if not token: - return None - return self.get_bot(token) + async def _get_bot_token_for_request(self, bound_request: BoundRequest) -> str | None: + return await self.routing.resolve_token(bound_request) - def get_bot(self, token: str) -> Bot: - """ - Resolve or create a Bot instance by token and cache it. + async def _get_bot_by_token(self, token: str) -> Bot: + bot_id = extract_bot_id(token) + existing_bot = self._bots.get(bot_id) - :param token: The bot token - :return: Bot + if existing_bot is None or existing_bot.token != token: + new_bot = self._build_bot(token) + self._bots[bot_id] = new_bot + return new_bot - .. note:: - To connect the bot to Telegram API and set up webhook, use :meth:`set_webhook`. - """ - bot = self._bots.get(extract_bot_id(token)) - if not bot: - bot = Bot(token=token, session=self.bot_config.session, default=self.bot_config.default) - self._bots[bot.id] = bot - return bot + return existing_bot + + def _build_bot(self, token: str) -> Bot: + """Build a new Bot instance from token.""" + return Bot(token=token, session=self._session, default=self.bot_config.default) async def set_webhook( self, @@ -96,22 +92,30 @@ async def set_webhook( :param request_timeout: Request timeout :return: Bot instance """ - bot = self.get_bot(token) - config = self._build_webhook_config( + + bot = await self._get_bot_by_token(token=token) + params = self._build_webhook_config( max_connections=max_connections, drop_pending_updates=drop_pending_updates, allowed_updates=allowed_updates, - ) - params = config.model_dump(exclude_none=True) + ).model_dump(exclude_none=True) if self.security is not None: - secret_token = await self.security.get_secret_token(bot=bot) + secret_token = await self.security.secret_token(bot_token=token) if secret_token is not None: params["secret_token"] = secret_token - await bot.set_webhook(url=self.routing.webhook_point(bot), request_timeout=request_timeout, **params) + await bot.set_webhook(url=await self.routing.webhook_url(bot), request_timeout=request_timeout, **params) return bot + async def remove_bot(self, bot_id: int) -> bool: + """Remove cached bot""" + bot = self._bots.get(bot_id) + if bot is None: + return False + del self._bots[bot_id] + return True + async def on_startup(self, app: Any, *args, bots: set[Bot] | None = None, **kwargs) -> None: # noqa: ARG002 all_bots = set(bots) | set(self._bots.values()) if bots else set(self._bots.values()) workflow_data = self._build_workflow_data(app=app, bots=all_bots, **kwargs) @@ -121,6 +125,7 @@ async def on_shutdown(self, app: Any, *args, **kwargs) -> None: # noqa: ARG002 workflow_data = self._build_workflow_data(app=app, bots=set(self._bots.values()), **kwargs) await self.dispatcher.emit_shutdown(**workflow_data) - for bot in self._bots.values(): - await bot.session.close() + if self.bot_config.session is None: + await self._session.close() + self._bots.clear() diff --git a/src/aiogram_webhook/routing/base.py b/src/aiogram_webhook/routing/base.py index ce0b778..b40db3c 100644 --- a/src/aiogram_webhook/routing/base.py +++ b/src/aiogram_webhook/routing/base.py @@ -16,12 +16,15 @@ class BaseRouting(ABC): def __init__(self, url: str) -> None: self.url = URL(url) - self.base = self.url.origin() - self.path = self.url.path + + @property + def webhook_path(self) -> str: + """Get route path for web framework registration.""" + return self.url.path @abstractmethod - def webhook_point(self, bot: Bot) -> str: - """Get the webhook URL for the given bot.""" + async def webhook_url(self, bot: Bot) -> str: + """Build webhook URL for the given bot.""" raise NotImplementedError @@ -33,6 +36,6 @@ def __init__(self, url: str, param: str = "bot_token") -> None: self.param = param @abstractmethod - def extract_token(self, bound_request: BoundRequest) -> str | None: - """Extract the bot token from the incoming request.""" + async def resolve_token(self, bound_request: BoundRequest) -> str | None: + """Resolve the bot token from the incoming request.""" raise NotImplementedError diff --git a/src/aiogram_webhook/routing/path.py b/src/aiogram_webhook/routing/path.py index 840321f..db55a69 100644 --- a/src/aiogram_webhook/routing/path.py +++ b/src/aiogram_webhook/routing/path.py @@ -1,5 +1,6 @@ from aiogram import Bot +from aiogram_webhook.adapters.base_adapter import BoundRequest from aiogram_webhook.routing.base import TokenRouting @@ -21,8 +22,8 @@ def __init__(self, url: str, param: str = "bot_token") -> None: f"Expected placeholder '{{{self.param}}}' in: {self.url_template}" ) - def webhook_point(self, bot: Bot) -> str: + async def webhook_url(self, bot: Bot) -> str: return self.url_template.format_map({self.param: bot.token}) - def extract_token(self, bound_request) -> str | None: + async def resolve_token(self, bound_request: BoundRequest) -> str | None: return bound_request.path_params.get(self.param) diff --git a/src/aiogram_webhook/routing/query.py b/src/aiogram_webhook/routing/query.py index a53fbf8..fea0d35 100644 --- a/src/aiogram_webhook/routing/query.py +++ b/src/aiogram_webhook/routing/query.py @@ -1,5 +1,6 @@ from aiogram import Bot +from aiogram_webhook.adapters.base_adapter import BoundRequest from aiogram_webhook.routing.base import TokenRouting @@ -11,8 +12,8 @@ class QueryRouting(TokenRouting): Example: https://example.com/webhook?token=123:ABC will extract the token from the query string. """ - def webhook_point(self, bot: Bot) -> str: + async def webhook_url(self, bot: Bot) -> str: return self.url.update_query({self.param: bot.token}).human_repr() - def extract_token(self, bound_request) -> str | None: + async def resolve_token(self, bound_request: BoundRequest) -> str | None: return bound_request.query_params.get(self.param) diff --git a/src/aiogram_webhook/routing/static.py b/src/aiogram_webhook/routing/static.py index 05b3c13..7a710e9 100644 --- a/src/aiogram_webhook/routing/static.py +++ b/src/aiogram_webhook/routing/static.py @@ -1,3 +1,5 @@ +from aiogram import Bot + from aiogram_webhook.routing.base import BaseRouting @@ -8,5 +10,5 @@ def __init__(self, url: str) -> None: super().__init__(url=url) self.url_template = self.url.human_repr() - def webhook_point(self, bot) -> str: # noqa: ARG002 + async def webhook_url(self, bot: Bot) -> str: # noqa: ARG002 return self.url_template diff --git a/src/aiogram_webhook/security/checks/check.py b/src/aiogram_webhook/security/checks/check.py index f2ac427..76a851c 100644 --- a/src/aiogram_webhook/security/checks/check.py +++ b/src/aiogram_webhook/security/checks/check.py @@ -1,6 +1,6 @@ from typing import Protocol -from aiogram import Bot +from aiogram import Dispatcher from aiogram_webhook.adapters.base_adapter import BoundRequest @@ -8,10 +8,12 @@ class SecurityCheck(Protocol): """Protocol for security check on webhook requests.""" - async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + async def verify(self, bot_token: str, bound_request: BoundRequest, dispatcher: Dispatcher) -> bool: """ Perform a security check. + :param bot_token: Bot token used by token-aware checks. + :param dispatcher: Dispatcher instance for dependency-aware checks. :return: True if the check passes, False otherwise. """ raise NotImplementedError diff --git a/src/aiogram_webhook/security/checks/ip.py b/src/aiogram_webhook/security/checks/ip.py index 66ea6a3..ba3b457 100644 --- a/src/aiogram_webhook/security/checks/ip.py +++ b/src/aiogram_webhook/security/checks/ip.py @@ -1,6 +1,8 @@ from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_address, ip_network from typing import Final +from aiogram import Dispatcher + from aiogram_webhook.adapters.base_adapter import BoundRequest from aiogram_webhook.security.checks.check import SecurityCheck @@ -43,7 +45,7 @@ def __init__(self, *ip_entries: IPNetwork | IPAddress | str, include_default: bo else: self._addresses.add(parsed) - async def verify(self, bot, bound_request: BoundRequest) -> bool: # noqa: ARG002 + async def verify(self, bot_token: str, bound_request: BoundRequest, dispatcher: Dispatcher) -> bool: # noqa: ARG002 raw_ip = self._get_client_ip(bound_request) if not raw_ip: return False diff --git a/src/aiogram_webhook/security/secret_token.py b/src/aiogram_webhook/security/secret_token.py index 470e340..030339f 100644 --- a/src/aiogram_webhook/security/secret_token.py +++ b/src/aiogram_webhook/security/secret_token.py @@ -1,32 +1,33 @@ import re from abc import ABC, abstractmethod from hmac import compare_digest +from typing import Final -from aiogram import Bot +from aiogram import Dispatcher from aiogram_webhook.adapters.base_adapter import BoundRequest SECRET_TOKEN_PATTERN = re.compile(r"^[A-Za-z0-9_-]{1,256}$") +SECRET_TOKEN_HEADER: Final[str] = "x-telegram-bot-api-secret-token" # noqa: S105 class SecretToken(ABC): """ - Abstract base class for secret token verification in webhook requests. + Base class for secret token verification in webhook requests. """ - secret_header: str = "x-telegram-bot-api-secret-token" # noqa: S105 + async def verify(self, bot_token: str, bound_request: BoundRequest, dispatcher: Dispatcher) -> bool: # noqa: ARG002 + incoming_secret_token = bound_request.headers.get(SECRET_TOKEN_HEADER) + if incoming_secret_token is None: + return False + return compare_digest(incoming_secret_token, await self.secret_token(bot_token)) @abstractmethod - async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: - """ - Verify the secret token in the incoming request. + async def secret_token(self, bot_token: str) -> str: """ - raise NotImplementedError + Return the webhook secret token associated with the given bot token. - @abstractmethod - def secret_token(self, bot: Bot) -> str: - """ - Return the secret token for the given bot. + :param bot_token: Bot token used to resolve expected secret token. """ raise NotImplementedError @@ -39,16 +40,10 @@ class StaticSecretToken(SecretToken): See: https://core.telegram.org/bots/api#setwebhook """ - def __init__(self, token: str) -> None: - if not SECRET_TOKEN_PATTERN.match(token): + def __init__(self, secret_token: str) -> None: + if not SECRET_TOKEN_PATTERN.match(secret_token): raise ValueError("Invalid secret token format. Must be 1-256 characters, only A-Z, a-z, 0-9, _, -.") - self._token = token - - async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: # noqa: ARG002 - incoming = bound_request.headers.get(self.secret_header) - if incoming is None: - return False - return compare_digest(incoming, self._token) + self.__secret_token = secret_token - def secret_token(self, bot: Bot) -> str: # noqa: ARG002 - return self._token + async def secret_token(self, bot_token: str) -> str: # noqa: ARG002 + return self.__secret_token diff --git a/src/aiogram_webhook/security/security.py b/src/aiogram_webhook/security/security.py index ee8ea7e..4ba203c 100644 --- a/src/aiogram_webhook/security/security.py +++ b/src/aiogram_webhook/security/security.py @@ -1,4 +1,4 @@ -from aiogram import Bot +from aiogram import Dispatcher from aiogram_webhook.adapters.base_adapter import BoundRequest from aiogram_webhook.security.checks.check import SecurityCheck @@ -16,29 +16,32 @@ def __init__(self, *checks: SecurityCheck, secret_token: SecretToken | None = No self._secret_token = secret_token self._checks: tuple[SecurityCheck, ...] = checks - async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + async def verify(self, bot_token: str, bound_request: BoundRequest, dispatcher: Dispatcher) -> bool: """ Verify the security of a webhook request. + :param bot_token: Bot token for webhook route and token-aware checks. + :param dispatcher: Dispatcher instance for dependency-aware checks. :return: True if the request passes security checks, False otherwise. """ if self._secret_token is not None: - ok = await self._secret_token.verify(bot=bot, bound_request=bound_request) + ok = await self._secret_token.verify(bot_token, bound_request, dispatcher=dispatcher) if not ok: return False for checker in self._checks: - if not await checker.verify(bot=bot, bound_request=bound_request): + if not await checker.verify(bot_token, bound_request, dispatcher=dispatcher): return False return True - async def get_secret_token(self, *, bot: Bot) -> str | None: + async def secret_token(self, bot_token: str) -> str | None: """ Get the secret token for the given bot, if configured. - :return: The secret token as a string. + :param bot_token: Bot token for which secret token should be resolved. + :return: The secret token as a string, or None if no secret-token provider is configured. """ if self._secret_token is None: return None - return self._secret_token.secret_token(bot=bot) + return await self._secret_token.secret_token(bot_token=bot_token) diff --git a/tests/conftest.py b/tests/conftest.py index 5c9dadc..6db2e50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ from ipaddress import IPv4Address import pytest -from aiogram import Bot +from aiogram import Bot, Dispatcher @pytest.fixture @@ -9,6 +9,11 @@ def bot(): return Bot("42:TEST") +@pytest.fixture +def dispatcher() -> Dispatcher: + return Dispatcher() + + @pytest.fixture def localhost_ip() -> IPv4Address: return IPv4Address("127.0.0.1") diff --git a/tests/fixtures/fixtures_checks.py b/tests/fixtures/fixtures_checks.py index 3fcf4a7..02ff822 100644 --- a/tests/fixtures/fixtures_checks.py +++ b/tests/fixtures/fixtures_checks.py @@ -1,16 +1,16 @@ -from aiogram import Bot +from aiogram import Dispatcher from aiogram_webhook.adapters.base_adapter import BoundRequest from aiogram_webhook.security.checks.check import SecurityCheck class PassingCheck(SecurityCheck): - async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + async def verify(self, bot_token: str, bound_request: BoundRequest, dispatcher: Dispatcher) -> bool: return True class FailingCheck(SecurityCheck): - async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + async def verify(self, bot_token: str, bound_request: BoundRequest, dispatcher: Dispatcher) -> bool: return False @@ -18,5 +18,5 @@ class ConditionalCheck(SecurityCheck): def __init__(self, condition: bool): self.condition = condition - async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + async def verify(self, bot_token: str, bound_request: BoundRequest, dispatcher: Dispatcher) -> bool: return self.condition diff --git a/tests/test_ip_check.py b/tests/test_ip_check.py index 319b0f6..7ae8406 100644 --- a/tests/test_ip_check.py +++ b/tests/test_ip_check.py @@ -22,10 +22,10 @@ "direct-no-ip", ], ) -async def test_ip_check_direct(allowed_ips, request_ip, expected, bot): +async def test_ip_check_direct(allowed_ips, request_ip, expected, dispatcher): req = DummyBoundRequest(DummyRequest(ip=request_ip)) ip_check = IPCheck(*allowed_ips, include_default=False) - assert await ip_check.verify(bot, req) is expected + assert await ip_check.verify("42:TEST", req, dispatcher=dispatcher) is expected @pytest.mark.asyncio @@ -50,11 +50,11 @@ async def test_ip_check_direct(allowed_ips, request_ip, expected, bot): "forwarded-no-header", ], ) -async def test_ip_check_forwarded(allowed_ips, x_forwarded_for, expected, bot): +async def test_ip_check_forwarded(allowed_ips, x_forwarded_for, expected, dispatcher): headers = {"X-Forwarded-For": x_forwarded_for} if x_forwarded_for is not None else None req = DummyBoundRequest(DummyRequest(ip="127.0.0.1", headers=headers)) ip_check = IPCheck(*allowed_ips, include_default=False) - assert await ip_check.verify(bot, req) is expected + assert await ip_check.verify("42:TEST", req, dispatcher=dispatcher) is expected @pytest.mark.asyncio @@ -75,11 +75,11 @@ async def test_ip_check_forwarded(allowed_ips, x_forwarded_for, expected, bot): "both-both-invalid", ], ) -async def test_ip_check_both_priority(allowed_ips, request_ip, x_forwarded_for, expected, bot): +async def test_ip_check_both_priority(allowed_ips, request_ip, x_forwarded_for, expected, dispatcher): headers = {"X-Forwarded-For": x_forwarded_for} req = DummyBoundRequest(DummyRequest(ip=request_ip, headers=headers)) ip_check = IPCheck(*allowed_ips, include_default=False) - assert await ip_check.verify(bot, req) is expected + assert await ip_check.verify("42:TEST", req, dispatcher=dispatcher) is expected @pytest.mark.asyncio @@ -94,8 +94,8 @@ async def test_ip_check_both_priority(allowed_ips, request_ip, x_forwarded_for, "edgecase-first-invalid", ], ) -async def test_ip_check_edge_cases(allowed_ips, request_ip, x_forwarded_for, expected, bot): +async def test_ip_check_edge_cases(allowed_ips, request_ip, x_forwarded_for, expected, dispatcher): headers = {"X-Forwarded-For": x_forwarded_for} req = DummyBoundRequest(DummyRequest(ip=request_ip, headers=headers)) ip_check = IPCheck(*allowed_ips, include_default=False) - assert await ip_check.verify(bot, req) is expected + assert await ip_check.verify("42:TEST", req, dispatcher=dispatcher) is expected diff --git a/tests/test_routing.py b/tests/test_routing.py index e470192..fbf4133 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -17,9 +17,10 @@ "https://example.com/webhook?foo=bar", ], ) -def test_static_routing(url, bot): +@pytest.mark.asyncio +async def test_static_routing(url, bot): routing = StaticRouting(url=url) - assert routing.webhook_point(bot) == url + assert await routing.webhook_url(bot) == url @pytest.mark.parametrize( @@ -53,11 +54,12 @@ def test_static_routing(url, bot): ], ids=["standard-param-present", "standard-param-missing", "custom-param-present", "custom-param-missing"], ) -def test_path_routing(url, param, token, path_params, expected_url, expected_token): +@pytest.mark.asyncio +async def test_path_routing(url, param, token, path_params, expected_url, expected_token): routing = PathRouting(url=url, param=param) - assert routing.webhook_point(Bot(token)) == expected_url + assert await routing.webhook_url(Bot(token)) == expected_url req = DummyBoundRequest(DummyRequest(path_params=path_params)) - assert routing.extract_token(req) == expected_token + assert await routing.resolve_token(req) == expected_token @pytest.mark.parametrize( @@ -139,9 +141,10 @@ def test_path_routing(url, param, token, path_params, expected_url, expected_tok "complex-params", ], ) -def test_query_routing(url, param, token, query_params, expected_url, expected_token): +@pytest.mark.asyncio +async def test_query_routing(url, param, token, query_params, expected_url, expected_token): routing = QueryRouting(url=url, param=param) - webhook_url = routing.webhook_point(Bot(token)) + webhook_url = await routing.webhook_url(Bot(token)) # Parse both URLs to compare query params (order may differ) expected = URL(expected_url) @@ -154,4 +157,4 @@ def test_query_routing(url, param, token, query_params, expected_url, expected_t # Check token extraction req = DummyBoundRequest(DummyRequest(query_params=query_params)) - assert routing.extract_token(req) == expected_token + assert await routing.resolve_token(req) == expected_token diff --git a/tests/test_secret_token.py b/tests/test_secret_token.py index fcaae30..300dbb4 100644 --- a/tests/test_secret_token.py +++ b/tests/test_secret_token.py @@ -1,6 +1,7 @@ import pytest from aiogram_webhook.security import Security, StaticSecretToken +from aiogram_webhook.security.secret_token import SECRET_TOKEN_HEADER from tests.fixtures import DummyBoundRequest, DummyRequest @@ -14,11 +15,11 @@ ], ids=["match", "mismatch", "none"], ) -async def test_security_secret_token(secret_token, request_token, expected, bot): +async def test_security_secret_token(secret_token, request_token, expected, dispatcher): sec = Security(secret_token=StaticSecretToken(secret_token)) - headers = {"x-telegram-bot-api-secret-token": request_token} if request_token is not None else {} + headers = {SECRET_TOKEN_HEADER: request_token} if request_token is not None else {} req = DummyBoundRequest(DummyRequest(headers=headers)) - assert await sec.verify(bot, req) is expected + assert await sec.verify("42:TEST", req, dispatcher=dispatcher) is expected @pytest.mark.asyncio @@ -30,6 +31,6 @@ async def test_security_secret_token(secret_token, request_token, expected, bot) ], ids=["with-secret", "without-secret"], ) -async def test_security_get_secret_token(secret_token, expected, bot): +async def test_security_secret_token_getter(secret_token, expected): sec = Security(secret_token=secret_token) - assert await sec.get_secret_token(bot=bot) == expected + assert await sec.secret_token(bot_token="42:TEST") == expected diff --git a/tests/test_security.py b/tests/test_security.py index 0cc4ab7..2e9b603 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,6 +1,6 @@ import pytest -from aiogram_webhook.security.secret_token import StaticSecretToken +from aiogram_webhook.security.secret_token import SECRET_TOKEN_HEADER, StaticSecretToken from aiogram_webhook.security.security import Security from tests.fixtures import DummyBoundRequest, DummyRequest, FailingCheck, PassingCheck @@ -37,10 +37,10 @@ "failing-last-passing", ], ) -async def test_security_checks(checks, expected, bot): +async def test_security_checks(checks, expected, dispatcher): sec = Security(*checks) req = DummyBoundRequest() - assert await sec.verify(bot, req) is expected + assert await sec.verify("42:TEST", req, dispatcher=dispatcher) is expected @pytest.mark.asyncio @@ -71,8 +71,8 @@ async def test_security_checks(checks, expected, bot): "no-checks-no-secret", ], ) -async def test_security_checks_and_secret_token(checks, secret_token, request_token, expected, bot): +async def test_security_checks_and_secret_token(checks, secret_token, request_token, expected, dispatcher): sec = Security(*checks, secret_token=secret_token) - headers = {"x-telegram-bot-api-secret-token": request_token} if request_token is not None else {} + headers = {SECRET_TOKEN_HEADER: request_token} if request_token is not None else {} req = DummyBoundRequest(DummyRequest(headers=headers)) - assert await sec.verify(bot, req) is expected + assert await sec.verify("42:TEST", req, dispatcher=dispatcher) is expected