diff --git a/src/github_app.py b/src/github_app.py index c18b6a9..8d60742 100644 --- a/src/github_app.py +++ b/src/github_app.py @@ -10,6 +10,10 @@ import jwt import requests +TOKEN_REQUEST_TIMEOUT_SECONDS = 10 +TOKEN_REQUEST_MAX_ATTEMPTS = 3 +TOKEN_REQUEST_BACKOFF_SECONDS = 1 + class GithubAppToken: def __init__(self, private_key, app_id) -> None: @@ -19,12 +23,7 @@ def __init__(self, private_key, app_id) -> None: # configured by the GitHub App and expire after one hour. @contextlib.contextmanager def get_token(self, installation_id: int) -> Generator[str, None, None]: - req = requests.post( - url=f"https://api.github.com/app/installations/{installation_id}/access_tokens", - headers=self.headers, - ) - req.raise_for_status() - resp = req.json() + resp = self._create_installation_access_token(installation_id) try: # This token expires in an hour yield resp["token"] @@ -32,8 +31,30 @@ def get_token(self, installation_id: int) -> Generator[str, None, None]: requests.delete( "https://api.github.com/installation/token", headers={"Authorization": f"token {resp['token']}"}, + timeout=TOKEN_REQUEST_TIMEOUT_SECONDS, ) + def _create_installation_access_token(self, installation_id: int) -> dict: + for attempt in range(1, TOKEN_REQUEST_MAX_ATTEMPTS + 1): + try: + req = requests.post( + url=f"https://api.github.com/app/installations/{installation_id}/access_tokens", + headers=self.headers, + timeout=TOKEN_REQUEST_TIMEOUT_SECONDS, + ) + req.raise_for_status() + return req.json() + except ( + requests.exceptions.ConnectionError, + requests.exceptions.SSLError, + requests.exceptions.Timeout, + ): + if attempt == TOKEN_REQUEST_MAX_ATTEMPTS: + raise + time.sleep(TOKEN_REQUEST_BACKOFF_SECONDS * attempt) + + raise RuntimeError("Failed to mint GitHub App installation token after retries") + def get_jwt_token(self, private_key, app_id): payload = { # issued at time, 60 seconds in the past to allow for clock drift diff --git a/tests/test_github_app.py b/tests/test_github_app.py new file mode 100644 index 0000000..5858ead --- /dev/null +++ b/tests/test_github_app.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from unittest.mock import Mock +from unittest.mock import call +from unittest.mock import patch + +import pytest +import requests + +from src.github_app import GithubAppToken +from src.github_app import TOKEN_REQUEST_BACKOFF_SECONDS +from src.github_app import TOKEN_REQUEST_MAX_ATTEMPTS +from src.github_app import TOKEN_REQUEST_TIMEOUT_SECONDS + + +def _build_token_manager() -> GithubAppToken: + with patch.object(GithubAppToken, "get_authentication_header", return_value={}): + return GithubAppToken(private_key="irrelevant", app_id="1") + + +def test_get_token_retries_ssl_errors_and_returns_token(): + token_manager = _build_token_manager() + success_response = Mock() + success_response.raise_for_status.return_value = None + success_response.json.return_value = {"token": "test-token"} + + with ( + patch( + "src.github_app.requests.post", + side_effect=[requests.exceptions.SSLError("tls"), success_response], + ) as mock_post, + patch("src.github_app.requests.delete") as mock_delete, + patch("src.github_app.time.sleep") as mock_sleep, + ): + with token_manager.get_token(123) as token: + assert token == "test-token" + + assert mock_post.call_count == 2 + mock_sleep.assert_called_once_with(TOKEN_REQUEST_BACKOFF_SECONDS) + mock_delete.assert_called_once_with( + "https://api.github.com/installation/token", + headers={"Authorization": "token test-token"}, + timeout=TOKEN_REQUEST_TIMEOUT_SECONDS, + ) + + +def test_get_token_raises_after_max_transient_failures(): + token_manager = _build_token_manager() + + with ( + patch( + "src.github_app.requests.post", + side_effect=requests.exceptions.Timeout("slow network"), + ) as mock_post, + patch("src.github_app.requests.delete") as mock_delete, + patch("src.github_app.time.sleep") as mock_sleep, + ): + with pytest.raises(requests.exceptions.Timeout): + with token_manager.get_token(123): + pass + + assert mock_post.call_count == TOKEN_REQUEST_MAX_ATTEMPTS + assert mock_sleep.call_args_list == [ + call(TOKEN_REQUEST_BACKOFF_SECONDS), + call(TOKEN_REQUEST_BACKOFF_SECONDS * 2), + ] + mock_delete.assert_not_called() + + +def test_get_token_does_not_retry_http_error(): + token_manager = _build_token_manager() + failed_response = Mock() + failed_response.raise_for_status.side_effect = requests.exceptions.HTTPError("500") + + with ( + patch("src.github_app.requests.post", return_value=failed_response) as mock_post, + patch("src.github_app.requests.delete") as mock_delete, + patch("src.github_app.time.sleep") as mock_sleep, + ): + with pytest.raises(requests.exceptions.HTTPError): + with token_manager.get_token(123): + pass + + assert mock_post.call_count == 1 + mock_sleep.assert_not_called() + mock_delete.assert_not_called()