from __future__ import annotations
import typing
from typing import Any
from typing import AnyStr
from typing import Dict
from typing import Protocol
from typing import Sequence
from typing import Type
from typing import runtime_checkable
import httpx
from httpx import Response
from ._assertions import RESPONSE_BYTESIZE_MISMATCH
from ._assertions import RESPONSE_ELAPSED_TIME_LESS_THAN
from ._assertions import RESPONSE_ENCODING_MISMATCH
from ._assertions import RESPONSE_HISTORY_LENGTH_MISMATCH
from ._assertions import RESPONSE_NO_HEADER
from ._assertions import RESPONSE_NO_JSON_PAYLOAD
from ._assertions import RESPONSE_NOT_HTTP1
from ._assertions import RESPONSE_NOT_HTTP2
from ._assertions import RESPONSE_STATUS_CODE_MISMATCH
from ._exceptions import RestpiteAssertionError
from ._schema import RestpiteSchema
from ._status_codes import StatusCodes
@runtime_checkable
class Curlable(Protocol):
    """
    Structural interface for objects that can describe themselves as a curl command.

    The protocol is decorated with `runtime_checkable`, so `isinstance(obj, Curlable)`
    succeeds for any object exposing a `curlify` attribute.
    """

    def curlify(self) -> str:
        """
        Converts the object into a curl string, used for recreating the request.
        """
        ...
class RestpiteResponse(Curlable):
    """
    Wrapper around an `httpx.Response` that adds chainable assertion helpers.

    Every public assertion method returns `self` on success so checks can be chained,
    and raises `RestpiteAssertionError` (via `_assertion_error`) on failure. Attributes
    that are not defined on this class are looked up on the wrapped response.
    """

    def __init__(self, httpx_response: Response) -> None:
        """
        :param httpx_response: The httpx response object to wrap
        """
        self.httpx_response = httpx_response
        # Converted through `StatusCodes` so comparisons and repr use the project's type.
        self.status_code = StatusCodes(httpx_response.status_code)  # noqa

    @property
    def request_method(self) -> str:
        """The method of the request that produced the wrapped response."""
        return self.httpx_response.request.method

    @property
    def url(self) -> typing.Optional[httpx.URL]:
        """The `url` attribute of the wrapped response."""
        return self.httpx_response.url

    @property
    def text(self) -> Any:
        """The `text` attribute of the wrapped response."""
        return self.httpx_response.text

    def json(self, **kwargs) -> Any:
        """
        Decodes the response body as json by delegating to the wrapped response.
        :param kwargs: Forwarded untouched to the wrapped response's `json` method
        """
        return self.httpx_response.json(**kwargs)

    def __getattr__(self, name: str) -> Any:
        """
        Fallback attribute lookup: anything not found on this wrapper is fetched from
        the wrapped httpx response.
        :param name: The attribute name that normal lookup failed to resolve
        :raises AttributeError: If the wrapped response is not set yet, or lacks `name`
        """
        # `__getattr__` only runs when normal lookup fails. If `httpx_response` itself is
        # missing (e.g. the instance was created without running `__init__`, as copy and
        # pickle do), reading `self.httpx_response` below would re-enter this method
        # forever, so fail fast with the conventional AttributeError instead.
        if name == "httpx_response":
            raise AttributeError(name)
        return getattr(self.httpx_response, name)

    def deserialize(
        self,
        schema: Type[RestpiteSchema],
        schema_kwargs: typing.Optional[Dict[Any, Any]] = None,
    ) -> Any:
        """
        Instantiates `schema` and returns the result of loading the response json with it.
        :param schema: The schema class (not an instance) to instantiate
        :param schema_kwargs: Keyword arguments for the schema constructor; omitted or
            `None` means the schema is constructed with no arguments
        """
        # TODO: Incorporate schemas from marshmallow here to allow custom deserialization?
        # TODO: Small implementation; plenty of work still to do here!
        return schema(**(schema_kwargs or {})).load(self.json())

    def validate(
        self,
        schema: Type[RestpiteSchema],
        schema_kwargs: typing.Optional[Dict[Any, Any]] = None,
    ) -> None:
        """
        Validate(s) response json against a pre-defined schema. This does not dump the
        response json through the schema; it only calls the schema's `validate` method.
        :param schema: The schema class (not an instance) to instantiate
        :param schema_kwargs: Keyword arguments for the schema constructor; omitted or
            `None` means the schema is constructed with no arguments
        """
        # NOTE(review): the return value of `validate` is discarded. Whether a failure
        # raises a `ValidationError` or is only reported through that return value
        # depends on `RestpiteSchema.validate` - confirm against its implementation.
        schema(**(schema_kwargs or {})).validate(data=self.json())

    # -------------------------------- HTTP RESPONSE STATUS CODE ASSERTIONS --------------------------------

    def had_success_status_code(self) -> RestpiteResponse:
        """
        Enforces that the response status code was a 2xx (Successful) based code. Note we do a wide assertion of
        200-299 but often these are limited to a very small subset of that range, should this change in future
        restpite will just work.
        """
        self._assert_response_code_in_range(range(200, 300))
        return self

    def had_redirect_status_code(self) -> RestpiteResponse:
        """
        Enforces that the response status code was a 3xx (Redirection) based code. Note we do a wide assertion of
        300-399 but often these are limited to a very small subset of that range, should this change in future
        restpite will just work.
        """
        self._assert_response_code_in_range(range(300, 400))
        return self

    def had_client_error_status_code(self) -> RestpiteResponse:
        """
        Enforces that the response status code was a 4xx (Client Error) based code. Note we do a wide assertion of
        400-499 but often these are limited to a very small subset of that range, should this change in future
        restpite will just work.
        """
        self._assert_response_code_in_range(range(400, 500))
        return self

    def had_server_error_status_code(self) -> RestpiteResponse:
        """
        Enforces that the response status code was a 5xx (Server Error) based code. Note we do a wide assertion of
        500-599 but often these are limited to a very small subset of that range, should this change in future
        restpite will just work.
        """
        self._assert_response_code_in_range(range(500, 600))
        return self

    def had_status(self, expected_code: int) -> RestpiteResponse:
        """
        Given a status code, matches the response against it.
        :param expected_code: The expected status code, it is the callers responsibility here
            to provide a 3 digit status code.
        :raises RestpiteAssertionError: If the response status code differs from expected_code
        """
        if expected_code != self.status_code:
            self._assertion_error(RESPONSE_STATUS_CODE_MISMATCH.format(self.status_code, expected_code))
        return self

    def _assert_response_code_in_range(self, expected_range: Sequence[int]) -> None:
        """
        Validates the wrapped response object had a status code inside a particular range
        :param expected_range: A sequence of integers to check the status code is in
        :raises RestpiteAssertionError: If the status code is not in expected_range
        """
        if self.status_code not in expected_range:
            self._assertion_error(f"Expected: {self.status_code} to be in: {expected_range} but it was not")

    # -------------------------------- HTTP RESPONSE HEADER ASSERTIONS -------------------------------------
    # ------------------------------------------------------------------------------------------------------

    def elapsed_time_was_less_than(self, seconds: float) -> RestpiteResponse:
        """
        Checks that the elapsed time reported by the wrapped response was strictly less
        than `seconds`.
        :param seconds: Upper bound (exclusive) in seconds; an elapsed time equal to it fails
        :raises RestpiteAssertionError: If the elapsed time was `seconds` or more
        """
        elapsed_seconds = self.httpx_response.elapsed.total_seconds()
        if elapsed_seconds >= seconds:
            self._assertion_error(RESPONSE_ELAPSED_TIME_LESS_THAN.format(elapsed_seconds, seconds))
        return self

    def had_total_number_of_bytes(self, num_bytes: int) -> RestpiteResponse:
        """
        Checks that the wrapped response's `num_bytes_downloaded` was equal to num_bytes.
        :param num_bytes: Expected number of bytes in the response
        :raises RestpiteAssertionError: If the byte counts differ
        """
        response_bytes = self.httpx_response.num_bytes_downloaded
        if response_bytes != num_bytes:
            self._assertion_error(RESPONSE_BYTESIZE_MISMATCH.format(response_bytes, num_bytes))
        return self

    def was_http1(self) -> RestpiteResponse:
        """
        Checks that the server's response was HTTP/1.1 (compared against exactly "HTTP/1.1").
        :raises RestpiteAssertionError: If the response reports any other http version
        """
        if self.httpx_response.http_version != "HTTP/1.1":
            self._assertion_error(RESPONSE_NOT_HTTP1.format(self.httpx_response.http_version))
        return self

    def was_http2(self) -> RestpiteResponse:
        """
        Checks that the server's response was truly HTTP/2
        :raises RestpiteAssertionError: If the response reports any other http version
        """
        if self.httpx_response.http_version != "HTTP/2":
            self._assertion_error(RESPONSE_NOT_HTTP2.format(self.httpx_response.http_version))
        return self

    def had_json_payload(self) -> RestpiteResponse:
        """
        Checks that the response contained a json payload, by attempting to decode it.
        :raises RestpiteAssertionError: If decoding the body raised a `ValueError`
        """
        try:
            self.json()
        except ValueError:
            self._assertion_error(RESPONSE_NO_JSON_PAYLOAD)
        return self

    def encoding_was(self, expected_encoding: str) -> RestpiteResponse:
        """
        Checks that the response encoding was a particular type
        :param expected_encoding: The encoding to compare against; the comparison is exact
        :raises RestpiteAssertionError: If the response encoding differs
        """
        if self.httpx_response.encoding != expected_encoding:
            self._assertion_error(RESPONSE_ENCODING_MISMATCH.format(self.httpx_response.encoding, expected_encoding))
        return self

    def had_history_length(self, expected_length: int) -> RestpiteResponse:
        """
        Checks that the response chain was of `expected_length` length.
        :param expected_length: Expected number of entries in the wrapped response's `history`
        :raises RestpiteAssertionError: If the history length differs
        """
        if len(self.httpx_response.history) != expected_length:
            self._assertion_error(
                RESPONSE_HISTORY_LENGTH_MISMATCH.format(len(self.httpx_response.history), expected_length)
            )
        return self

    # ------------------------------------------------------------------------------------------------------

    def _assertion_error(self, message: str) -> typing.NoReturn:
        """
        Responsible for raising the `RestpiteAssertionError` which will subsequently cause tests
        to fail. RestpiteAssertionError is a simple subclass of `AssertionError` with the aim
        in future to bolt on more functionality, currently it serves the same purpose.
        :param message: The failure message to attach to the error
        :raises RestpiteAssertionError: Always
        """
        raise RestpiteAssertionError(message)

    def curlify(self) -> str:
        """
        Placeholder for the `Curlable` contract; not implemented yet.
        :raises NotImplementedError: Always
        """
        # TODO: Debatable functionality
        raise NotImplementedError

    def __repr__(self) -> str:
        """Debug representation showing the status code repr and the response url."""
        # TODO: Debatable implementation!
        return f"<[{repr(self.status_code)} : {self.httpx_response.url}]>"