扫码登录,获取cookies

This commit is contained in:
2026-03-09 16:10:29 +08:00
parent 754e720ba7
commit 8229208165
7775 changed files with 1150053 additions and 208 deletions

View File

@@ -0,0 +1,9 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

View File

@@ -0,0 +1,685 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import re
from typing import NamedTuple, Optional, Tuple, Union
from hypothesis import assume, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.utils import _calc_p_continue
from hypothesis.internal.coverage import check_function
from hypothesis.internal.validation import check_type, check_valid_interval
from hypothesis.strategies._internal.utils import defines_strategy
from hypothesis.utils.conventions import UniqueIdentifier, not_set
# Public API of this helper module (strategies plus their shared validators).
__all__ = [
    "NDIM_MAX",
    "Shape",
    "BroadcastableShapes",
    "BasicIndex",
    "check_argument",
    "order_check",
    "check_valid_dims",
    "array_shapes",
    "valid_tuple_axes",
    "broadcastable_shapes",
    "mutually_broadcastable_shapes",
    "MutuallyBroadcastableShapesStrategy",
    "BasicIndexStrategy",
]
# An array shape is a tuple of integers, one entry per dimension.
Shape = Tuple[int, ...]
# We silence flake8 here because it disagrees with mypy about `ellipsis` (`type(...)`)
BasicIndex = Tuple[Union[int, slice, None, "ellipsis"], ...] # noqa: F821
class BroadcastableShapes(NamedTuple):
    """Result of the broadcastable-shape strategies: the generated input
    shapes, plus the shape obtained by broadcasting them all together."""

    # The N generated broadcast-compatible shapes.
    input_shapes: Tuple[Shape, ...]
    # The shape produced by broadcasting input_shapes with the base shape.
    result_shape: Shape
@check_function
def check_argument(condition, fail_message, *f_args, **f_kwargs):
    """Raise InvalidArgument with the formatted message unless *condition* holds."""
    if condition:
        return
    raise InvalidArgument(fail_message.format(*f_args, **f_kwargs))
@check_function
def order_check(name, floor, min_, max_):
    """Validate that ``floor <= min_<name> <= max_<name>``, else InvalidArgument."""
    if min_ < floor:
        raise InvalidArgument(f"min_{name} must be at least {floor} but was {min_}")
    if max_ < min_:
        raise InvalidArgument(f"min_{name}={min_} is larger than max_{name}={max_}")
# 32 is a dimension limit specific to NumPy, and does not necessarily apply to
# other array/tensor libraries. Historically these strategies were built for the
# NumPy extra, so it's nice to keep these limits, and it's seemingly unlikely
# someone would want to generate >32 dim arrays anyway.
# See https://github.com/HypothesisWorks/hypothesis/pull/3067.
NDIM_MAX = 32
@check_function
def check_valid_dims(dims, name):
    """Reject dimension counts above the NDIM_MAX limit with InvalidArgument."""
    if dims <= NDIM_MAX:
        return
    raise InvalidArgument(
        f"{name}={dims}, but Hypothesis does not support arrays with "
        f"more than {NDIM_MAX} dimensions"
    )
@defines_strategy()
def array_shapes(
    *,
    min_dims: int = 1,
    max_dims: Optional[int] = None,
    min_side: int = 1,
    max_side: Optional[int] = None,
) -> st.SearchStrategy[Shape]:
    """Return a strategy for array shapes (tuples of int >= 1).

    * ``min_dims`` is the smallest length that the generated shape can possess.
    * ``max_dims`` is the largest length that the generated shape can possess,
      defaulting to ``min_dims + 2``.
    * ``min_side`` is the smallest size that a dimension can possess.
    * ``max_side`` is the largest size that a dimension can possess,
      defaulting to ``min_side + 5``.
    """
    check_type(int, min_dims, "min_dims")
    check_type(int, min_side, "min_side")
    check_valid_dims(min_dims, "min_dims")
    if max_dims is None:
        # Default upper bound, capped at Hypothesis' dimension limit.
        max_dims = min(min_dims + 2, NDIM_MAX)
    check_type(int, max_dims, "max_dims")
    check_valid_dims(max_dims, "max_dims")
    if max_side is None:
        max_side = min_side + 5
    check_type(int, max_side, "max_side")
    order_check("dims", 0, min_dims, max_dims)
    order_check("side", 0, min_side, max_side)
    side_strategy = st.integers(min_side, max_side)
    shape_strategy = st.lists(side_strategy, min_size=min_dims, max_size=max_dims)
    return shape_strategy.map(tuple)
@defines_strategy()
def valid_tuple_axes(
    ndim: int,
    *,
    min_size: int = 0,
    max_size: Optional[int] = None,
) -> st.SearchStrategy[Tuple[int, ...]]:
    """All tuples will have a length >= ``min_size`` and <= ``max_size``. The default
    value for ``max_size`` is ``ndim``.

    Examples from this strategy shrink towards an empty tuple, which render most
    sequential functions as no-ops.

    The following are some examples drawn from this strategy.

    .. code-block:: pycon

        >>> [valid_tuple_axes(3).example() for i in range(4)]
        [(-3, 1), (0, 1, -1), (0, 2), (0, -2, 2)]

    ``valid_tuple_axes`` can be joined with other strategies to generate
    any type of valid axis object, i.e. integers, tuples, and ``None``:

    .. code-block:: python

        any_axis_strategy = none() | integers(-ndim, ndim - 1) | valid_tuple_axes(ndim)
    """
    check_type(int, ndim, "ndim")
    check_type(int, min_size, "min_size")
    if max_size is None:
        max_size = ndim
    check_type(int, max_size, "max_size")
    order_check("size", 0, min_size, max_size)
    check_valid_interval(max_size, ndim, "max_size", "ndim")

    # Draw from [0, 2*ndim) and fold the upper half onto the negative axes,
    # so that examples shrink towards non-negative axis indices.
    axes = st.integers(0, max(0, 2 * ndim - 1)).map(
        lambda drawn: drawn - 2 * ndim if drawn >= ndim else drawn
    )
    # Deduplicate by the axis actually referenced (-1 and ndim-1 are the same).
    return st.lists(
        axes, min_size=min_size, max_size=max_size, unique_by=lambda axis: axis % ndim
    ).map(tuple)
@defines_strategy()
def broadcastable_shapes(
    shape: Shape,
    *,
    min_dims: int = 0,
    max_dims: Optional[int] = None,
    min_side: int = 1,
    max_side: Optional[int] = None,
) -> st.SearchStrategy[Shape]:
    """Return a strategy for shapes that are broadcast-compatible with the
    provided shape.

    Examples from this strategy shrink towards a shape with length ``min_dims``.
    The size of an aligned dimension shrinks towards size ``1``. The size of an
    unaligned dimension shrink towards ``min_side``.

    * ``shape`` is a tuple of integers.
    * ``min_dims`` is the smallest length that the generated shape can possess.
    * ``max_dims`` is the largest length that the generated shape can possess,
      defaulting to ``max(len(shape), min_dims) + 2``.
    * ``min_side`` is the smallest size that an unaligned dimension can possess.
    * ``max_side`` is the largest size that an unaligned dimension can possess,
      defaulting to 2 plus the size of the largest aligned dimension.

    The following are some examples drawn from this strategy.

    .. code-block:: pycon

        >>> [broadcastable_shapes(shape=(2, 3)).example() for i in range(5)]
        [(1, 3), (), (2, 3), (2, 1), (4, 1, 3), (3, )]
    """
    check_type(tuple, shape, "shape")
    check_type(int, min_side, "min_side")
    check_type(int, min_dims, "min_dims")
    check_valid_dims(min_dims, "min_dims")
    # BUG FIX: strict bound-checking applies when the *user* supplied a bound.
    # The previous condition was inverted (`is None`), which raised
    # InvalidArgument for satisfiable default bounds and silently reduced
    # user-supplied max_dims instead of rejecting it.  Compare the sibling
    # `mutually_broadcastable_shapes`, whose reduction loop likewise runs
    # only when the bound was defaulted.
    strict_check = max_side is not None or max_dims is not None
    if max_dims is None:
        max_dims = min(max(len(shape), min_dims) + 2, NDIM_MAX)
    check_type(int, max_dims, "max_dims")
    check_valid_dims(max_dims, "max_dims")
    if max_side is None:
        # Default: 2 plus the largest aligned dimension (or min_side if larger).
        max_side = max(shape[-max_dims:] + (min_side,)) + 2
    check_type(int, max_side, "max_side")
    order_check("dims", 0, min_dims, max_dims)
    order_check("side", 0, min_side, max_side)
    if strict_check:
        dims = max_dims
        bound_name = "max_dims"
    else:
        dims = min_dims
        bound_name = "min_dims"
    # check for unsatisfiable min_side
    if not all(min_side <= s for s in shape[::-1][:dims] if s != 1):
        raise InvalidArgument(
            f"Given shape={shape}, there are no broadcast-compatible "
            f"shapes that satisfy: {bound_name}={dims} and min_side={min_side}"
        )
    # check for unsatisfiable [min_side, max_side]
    if not (
        min_side <= 1 <= max_side or all(s <= max_side for s in shape[::-1][:dims])
    ):
        # Message fix: this function's parameter is `shape`, not `base_shape`.
        raise InvalidArgument(
            f"Given shape={shape}, there are no broadcast-compatible "
            f"shapes that satisfy all of {bound_name}={dims}, "
            f"min_side={min_side}, and max_side={max_side}"
        )
    if not strict_check:
        # reduce max_dims to exclude unsatisfiable dimensions
        for n, s in zip(range(max_dims), shape[::-1]):
            if s < min_side and s != 1:
                max_dims = n
                break
            elif not (min_side <= 1 <= max_side or s <= max_side):
                max_dims = n
                break
    # Delegate to the N-shapes strategy with N=1 and unwrap the single shape.
    return MutuallyBroadcastableShapesStrategy(
        num_shapes=1,
        base_shape=shape,
        min_dims=min_dims,
        max_dims=max_dims,
        min_side=min_side,
        max_side=max_side,
    ).map(lambda x: x.input_shapes[0])
# See https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html
# Implementation based on numpy.lib.function_base._parse_gufunc_signature
# with minor upgrades to handle numeric and optional dimensions. Examples:
#
# add (),()->() binary ufunc
# sum1d (i)->() reduction
# inner1d (i),(i)->() vector-vector multiplication
# matmat (m,n),(n,p)->(m,p) matrix multiplication
# vecmat (n),(n,p)->(p) vector-matrix multiplication
# matvec (m,n),(n)->(m) matrix-vector multiplication
# matmul (m?,n),(n,p?)->(m?,p?) combination of the four above
# cross1d (3),(3)->(3) cross product with frozen dimensions
#
# Note that while no examples of such usage are given, Numpy does allow
# generalised ufuncs that have *multiple output arrays*. This is not
# currently supported by Hypothesis - please contact us if you would use it!
#
# We are unsure if gufuncs allow frozen dimensions to be optional, but it's
# easy enough to support here - and so we will unless we learn otherwise.
_DIMENSION = r"\w+\??" # Note that \w permits digits too!
# {0,31} caps each shape at 32 (= NDIM_MAX) comma-separated dimensions.
_SHAPE = rf"\((?:{_DIMENSION}(?:,{_DIMENSION}){{0,31}})?\)"
_ARGUMENT_LIST = f"{_SHAPE}(?:,{_SHAPE})*"
_SIGNATURE = rf"^{_ARGUMENT_LIST}->{_SHAPE}$"
# Like _SIGNATURE, but permitting several output shapes (not yet supported).
_SIGNATURE_MULTIPLE_OUTPUT = rf"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
class _GUfuncSig(NamedTuple):
    """Parsed gufunc signature: per-argument core shapes and the output shape,
    each shape being a tuple of dimension-name strings."""

    input_shapes: Tuple[Shape, ...]
    result_shape: Shape
def _hypothesis_parse_gufunc_signature(signature):
    """Parse a generalised-ufunc signature string into a ``_GUfuncSig``.

    Raises InvalidArgument for multiple-output signatures, shapes with more
    than NDIM_MAX dimensions, frozen optional dimensions, output-only
    unfrozen dimensions, and anything that fails the signature regex.
    """
    # Disable all_checks to better match the Numpy version, for testing
    if not re.match(_SIGNATURE, signature):
        if re.match(_SIGNATURE_MULTIPLE_OUTPUT, signature):
            raise InvalidArgument(
                "Hypothesis does not yet support generalised ufunc signatures "
                "with multiple output arrays - mostly because we don't know of "
                "anyone who uses them! Please get in touch with us to fix that."
                f"\n ({signature=})"
            )
        if re.match(
            (
                # Taken from np.lib.function_base._SIGNATURE
                r"^\((?:\w+(?:,\w+)*)?\)(?:,\((?:\w+(?:,\w+)*)?\))*->"
                r"\((?:\w+(?:,\w+)*)?\)(?:,\((?:\w+(?:,\w+)*)?\))*$"
            ),
            signature,
        ):
            raise InvalidArgument(
                f"{signature=} matches Numpy's regex for gufunc signatures, "
                f"but contains shapes with more than {NDIM_MAX} dimensions and is thus invalid."
            )
        raise InvalidArgument(f"{signature!r} is not a valid gufunc signature")
    input_shapes, output_shapes = (
        tuple(tuple(re.findall(_DIMENSION, a)) for a in re.findall(_SHAPE, arg_list))
        for arg_list in signature.split("->")
    )
    # _SIGNATURE permits exactly one output shape.
    assert len(output_shapes) == 1
    result_shape = output_shapes[0]
    # Check that there are no names in output shape that do not appear in inputs.
    # (kept out of parser function for easier generation of test values)
    # We also disallow frozen optional dimensions - this is ambiguous as there is
    # no way to share an un-named dimension between shapes. Maybe just padding?
    # Anyway, we disallow it pending clarification from upstream.
    for shape in (*input_shapes, result_shape):
        for name in shape:
            try:
                int(name.strip("?"))
                if "?" in name:
                    # BUG FIX: the final line below was not an f-string, so
                    # "{signature=}" appeared literally in the error message;
                    # also fixed the "you known" typo.
                    raise InvalidArgument(
                        f"Got dimension {name!r}, but handling of frozen optional dimensions "
                        "is ambiguous. If you know how this should work, please "
                        f"contact us to get this fixed and documented ({signature=})."
                    )
            except ValueError:
                # Not an integer => a named dimension; it must appear in an input.
                names_in = {n.strip("?") for shp in input_shapes for n in shp}
                names_out = {n.strip("?") for n in result_shape}
                if name.strip("?") in (names_out - names_in):
                    # BUG FIX: these lines were missing the f prefix, so both
                    # {name!r} and {signature=} were emitted literally.
                    raise InvalidArgument(
                        f"The {name!r} dimension only appears in the output shape, and is "
                        f"not frozen, so the size is not determined ({signature=})."
                    ) from None
    return _GUfuncSig(input_shapes=input_shapes, result_shape=result_shape)
@defines_strategy()
def mutually_broadcastable_shapes(
    *,
    num_shapes: Union[UniqueIdentifier, int] = not_set,
    signature: Union[UniqueIdentifier, str] = not_set,
    base_shape: Shape = (),
    min_dims: int = 0,
    max_dims: Optional[int] = None,
    min_side: int = 1,
    max_side: Optional[int] = None,
) -> st.SearchStrategy[BroadcastableShapes]:
    """Return a strategy for a specified number of shapes N that are
    mutually-broadcastable with one another and with the provided base shape.

    * ``num_shapes`` is the number of mutually broadcast-compatible shapes to generate.
    * ``base_shape`` is the shape against which all generated shapes can broadcast.
      The default shape is empty, which corresponds to a scalar and thus does
      not constrain broadcasting at all.
    * ``min_dims`` is the smallest length that the generated shape can possess.
    * ``max_dims`` is the largest length that the generated shape can possess,
      defaulting to ``max(len(shape), min_dims) + 2``.
    * ``min_side`` is the smallest size that an unaligned dimension can possess.
    * ``max_side`` is the largest size that an unaligned dimension can possess,
      defaulting to 2 plus the size of the largest aligned dimension.

    The strategy will generate a :obj:`python:typing.NamedTuple` containing:

    * ``input_shapes`` as a tuple of the N generated shapes.
    * ``result_shape`` as the resulting shape produced by broadcasting the N shapes
      with the base shape.

    The following are some examples drawn from this strategy.

    .. code-block:: pycon

        >>> # Draw three shapes where each shape is broadcast-compatible with (2, 3)
        ... strat = mutually_broadcastable_shapes(num_shapes=3, base_shape=(2, 3))
        >>> for _ in range(5):
        ...     print(strat.example())
        BroadcastableShapes(input_shapes=((4, 1, 3), (4, 2, 3), ()), result_shape=(4, 2, 3))
        BroadcastableShapes(input_shapes=((3,), (1, 3), (2, 3)), result_shape=(2, 3))
        BroadcastableShapes(input_shapes=((), (), ()), result_shape=())
        BroadcastableShapes(input_shapes=((3,), (), (3,)), result_shape=(3,))
        BroadcastableShapes(input_shapes=((1, 2, 3), (3,), ()), result_shape=(1, 2, 3))
    """
    arg_msg = "Pass either the `num_shapes` or the `signature` argument, but not both."
    if num_shapes is not not_set:
        check_argument(signature is not_set, arg_msg)
        check_type(int, num_shapes, "num_shapes")
        assert isinstance(num_shapes, int)  # for mypy
        parsed_signature = None
        sig_dims = 0
    else:
        check_argument(signature is not not_set, arg_msg)
        if signature is None:
            raise InvalidArgument(
                "Expected a string, but got invalid signature=None. "
                "(maybe .signature attribute of an element-wise ufunc?)"
            )
        check_type(str, signature, "signature")
        parsed_signature = _hypothesis_parse_gufunc_signature(signature)
        all_shapes = (*parsed_signature.input_shapes, parsed_signature.result_shape)
        # The shortest core shape bounds how many loop dims we may add.
        sig_dims = min(len(s) for s in all_shapes)
        num_shapes = len(parsed_signature.input_shapes)
    if num_shapes < 1:
        raise InvalidArgument(f"num_shapes={num_shapes} must be at least 1")
    check_type(tuple, base_shape, "base_shape")
    check_type(int, min_side, "min_side")
    check_type(int, min_dims, "min_dims")
    check_valid_dims(min_dims, "min_dims")
    # Strict bound-checking applies only when the user supplied max_dims;
    # a defaulted bound may be reduced below to stay satisfiable.
    strict_check = max_dims is not None
    if max_dims is None:
        # Leave headroom for the signature's core dimensions.
        max_dims = min(max(len(base_shape), min_dims) + 2, NDIM_MAX - sig_dims)
    check_type(int, max_dims, "max_dims")
    check_valid_dims(max_dims, "max_dims")
    if max_side is None:
        max_side = max(base_shape[-max_dims:] + (min_side,)) + 2
    check_type(int, max_side, "max_side")
    order_check("dims", 0, min_dims, max_dims)
    order_check("side", 0, min_side, max_side)
    if signature is not None and max_dims > NDIM_MAX - sig_dims:
        # BUG FIX: this message previously interpolated the *signature* where
        # the max_dims value belonged, and the concatenated strings produced
        # "dimensionlimit" with no separating space.
        raise InvalidArgument(
            f"max_dims={max_dims} would exceed the {NDIM_MAX}-dimension "
            "limit Hypothesis imposes on array shapes, "
            f"given signature={parsed_signature!r}"
        )
    if strict_check:
        dims = max_dims
        bound_name = "max_dims"
    else:
        dims = min_dims
        bound_name = "min_dims"
    # check for unsatisfiable min_side
    if not all(min_side <= s for s in base_shape[::-1][:dims] if s != 1):
        raise InvalidArgument(
            f"Given base_shape={base_shape}, there are no broadcast-compatible "
            f"shapes that satisfy: {bound_name}={dims} and min_side={min_side}"
        )
    # check for unsatisfiable [min_side, max_side]
    if not (
        min_side <= 1 <= max_side or all(s <= max_side for s in base_shape[::-1][:dims])
    ):
        raise InvalidArgument(
            f"Given base_shape={base_shape}, there are no broadcast-compatible "
            f"shapes that satisfy all of {bound_name}={dims}, "
            f"min_side={min_side}, and max_side={max_side}"
        )
    if not strict_check:
        # reduce max_dims to exclude unsatisfiable dimensions
        for n, s in zip(range(max_dims), base_shape[::-1]):
            if s < min_side and s != 1:
                max_dims = n
                break
            elif not (min_side <= 1 <= max_side or s <= max_side):
                max_dims = n
                break
    return MutuallyBroadcastableShapesStrategy(
        num_shapes=num_shapes,
        signature=parsed_signature,
        base_shape=base_shape,
        min_dims=min_dims,
        max_dims=max_dims,
        min_side=min_side,
        max_side=max_side,
    )
class MutuallyBroadcastableShapesStrategy(st.SearchStrategy):
    """Draws a ``BroadcastableShapes``: ``num_shapes`` shapes that broadcast
    with each other and with ``base_shape``, plus the broadcasted result.
    When a parsed gufunc ``signature`` is given, core dimensions are drawn
    from it and broadcastable loop dimensions are prepended."""

    def __init__(
        self,
        num_shapes,
        signature=None,
        base_shape=(),
        min_dims=0,
        max_dims=None,
        min_side=1,
        max_side=None,
    ):
        super().__init__()
        self.base_shape = base_shape
        # Strategy for drawing a single dimension size.
        self.side_strat = st.integers(min_side, max_side)
        self.num_shapes = num_shapes
        self.signature = signature
        self.min_dims = min_dims
        self.max_dims = max_dims
        self.min_side = min_side
        self.max_side = max_side
        # Whether generated dimensions may take size 1 (i.e. may broadcast).
        self.size_one_allowed = self.min_side <= 1 <= self.max_side

    def do_draw(self, data):
        # We don't usually have a gufunc signature; do the common case first & fast.
        if self.signature is None:
            return self._draw_loop_dimensions(data)
        # When we *do*, draw the core dims, then draw loop dims, and finally combine.
        core_in, core_res = self._draw_core_dimensions(data)
        # If some core shape has omitted optional dimensions, it's an error to add
        # loop dimensions to it. We never omit core dims if min_dims >= 1.
        # This ensures that we respect Numpy's gufunc broadcasting semantics and user
        # constraints without needing to check whether the loop dims will be
        # interpreted as an invalid substitute for the omitted core dims.
        # We may implement this check later!
        use = [None not in shp for shp in core_in]
        loop_in, loop_res = self._draw_loop_dimensions(data, use=use)

        def add_shape(loop, core):
            # Prepend loop dims to core dims, dropping omitted (None) entries
            # and capping the total length at NDIM_MAX.
            return tuple(x for x in (loop + core)[-NDIM_MAX:] if x is not None)

        return BroadcastableShapes(
            input_shapes=tuple(add_shape(l_in, c) for l_in, c in zip(loop_in, core_in)),
            result_shape=add_shape(loop_res, core_res),
        )

    def _draw_core_dimensions(self, data):
        # Draw gufunc core dimensions, with None standing for optional dimensions
        # that will not be present in the final shape. We track omitted dims so
        # that we can do an accurate per-shape length cap.
        dims = {}
        shapes = []
        for shape in (*self.signature.input_shapes, self.signature.result_shape):
            shapes.append([])
            for name in shape:
                if name.isdigit():
                    # Frozen dimension: the size is fixed by the signature.
                    shapes[-1].append(int(name))
                    continue
                if name not in dims:
                    dim = name.strip("?")
                    dims[dim] = data.draw(self.side_strat)
                    if self.min_dims == 0 and not data.draw_boolean(7 / 8):
                        # Omit the optional variant of this dimension.
                        dims[dim + "?"] = None
                    else:
                        dims[dim + "?"] = dims[dim]
                shapes[-1].append(dims[name])
        # Last entry is the output shape; the rest are per-argument inputs.
        return tuple(tuple(s) for s in shapes[:-1]), tuple(shapes[-1])

    def _draw_loop_dimensions(self, data, use=None):
        # All shapes are handled in column-major order; i.e. they are reversed
        base_shape = self.base_shape[::-1]
        result_shape = list(base_shape)
        shapes = [[] for _ in range(self.num_shapes)]
        if use is None:
            use = [True for _ in range(self.num_shapes)]
        else:
            assert len(use) == self.num_shapes
            assert all(isinstance(x, bool) for x in use)
        # Continuation probability targeting an average length midway between
        # min_dims and max_dims.
        _gap = self.max_dims - self.min_dims
        p_keep_extending_shape = _calc_p_continue(desired_avg=_gap / 2, max_size=_gap)
        for dim_count in range(1, self.max_dims + 1):
            dim = dim_count - 1
            # We begin by drawing a valid dimension-size for the given
            # dimension. This restricts the variability across the shapes
            # at this dimension such that they can only choose between
            # this size and a singleton dimension.
            if len(base_shape) < dim_count or base_shape[dim] == 1:
                # dim is unrestricted by the base-shape: shrink to min_side
                dim_side = data.draw(self.side_strat)
            elif base_shape[dim] <= self.max_side:
                # dim is aligned with non-singleton base-dim
                dim_side = base_shape[dim]
            else:
                # only a singleton is valid in alignment with the base-dim
                dim_side = 1
            allowed_sides = sorted([1, dim_side])  # shrink to 0 when available
            for shape_id, shape in enumerate(shapes):
                # Populating this dimension-size for each shape, either
                # the drawn size is used or, if permitted, a singleton
                # dimension.
                if dim <= len(result_shape) and self.size_one_allowed:
                    # aligned: shrink towards size 1
                    side = data.draw(st.sampled_from(allowed_sides))
                else:
                    side = dim_side
                # Use a trick where where a biased coin is queried to see
                # if the given shape-tuple will continue to be grown. All
                # of the relevant draws will still be made for the given
                # shape-tuple even if it is no longer being added to.
                # This helps to ensure more stable shrinking behavior.
                if self.min_dims < dim_count:
                    use[shape_id] &= data.draw_boolean(p_keep_extending_shape)
                if use[shape_id]:
                    shape.append(side)
                    if len(result_shape) < len(shape):
                        result_shape.append(shape[-1])
                    elif shape[-1] != 1 and result_shape[dim] == 1:
                        # A non-singleton side overrides a broadcast-1 result dim.
                        result_shape[dim] = shape[-1]
            if not any(use):
                break
        # Trim the result to the longest shape actually involved.
        result_shape = result_shape[: max(map(len, [self.base_shape, *shapes]))]
        assert len(shapes) == self.num_shapes
        assert all(self.min_dims <= len(s) <= self.max_dims for s in shapes)
        assert all(self.min_side <= s <= self.max_side for side in shapes for s in side)
        return BroadcastableShapes(
            input_shapes=tuple(tuple(reversed(shape)) for shape in shapes),
            result_shape=tuple(reversed(result_shape)),
        )
class BasicIndexStrategy(st.SearchStrategy):
    """Draws a basic index (ints, slices, Ellipsis, and np.newaxis/None) for an
    array of the given ``shape``, such that the indexed result has between
    ``min_dims`` and ``max_dims`` dimensions."""

    def __init__(
        self,
        shape,
        min_dims,
        max_dims,
        allow_ellipsis,
        allow_newaxis,
        allow_fewer_indices_than_dims,
    ):
        self.shape = shape
        self.min_dims = min_dims
        self.max_dims = max_dims
        self.allow_ellipsis = allow_ellipsis
        self.allow_newaxis = allow_newaxis
        # allow_fewer_indices_than_dims=False will disable generating indices
        # that don't cover all axes, i.e. indices that will flat index arrays.
        # This is necessary for the Array API as such indices are not supported.
        self.allow_fewer_indices_than_dims = allow_fewer_indices_than_dims

    def do_draw(self, data):
        # General plan: determine the actual selection up front with a straightforward
        # approach that shrinks well, then complicate it by inserting other things.
        result = []
        for dim_size in self.shape:
            if dim_size == 0:
                # Zero-length axis: only a full slice is generated for it.
                result.append(slice(None))
                continue
            # Either a scalar index into this axis, or a slice of it.
            strategy = st.integers(-dim_size, dim_size - 1) | st.slices(dim_size)
            result.append(data.draw(strategy))
        # Insert some number of new size-one dimensions if allowed
        result_dims = sum(isinstance(idx, slice) for idx in result)
        while (
            self.allow_newaxis
            and result_dims < self.max_dims
            and (result_dims < self.min_dims or data.draw(st.booleans()))
        ):
            i = data.draw(st.integers(0, len(result)))
            result.insert(i, None)  # Note that `np.newaxis is None`
            result_dims += 1
        # Check that we'll have the right number of dimensions; reject if not.
        # It's easy to do this by construction if you don't care about shrinking,
        # which is really important for array shapes. So we filter instead.
        assume(self.min_dims <= result_dims <= self.max_dims)
        # This is a quick-and-dirty way to insert ..., xor shorten the indexer,
        # but it means we don't have to do any structural analysis.
        if self.allow_ellipsis and data.draw(st.booleans()):
            # Choose an index; then replace all adjacent whole-dimension slices.
            i = j = data.draw(st.integers(0, len(result)))
            while i > 0 and result[i - 1] == slice(None):
                i -= 1
            while j < len(result) and result[j] == slice(None):
                j += 1
            result[i:j] = [Ellipsis]
        elif self.allow_fewer_indices_than_dims:  # pragma: no cover
            # Drop trailing full slices, each with probability 7/8.
            while result[-1:] == [slice(None, None)] and data.draw(st.integers(0, 7)):
                result.pop()
        if len(result) == 1 and data.draw(st.booleans()):
            # Sometimes generate bare element equivalent to a length-one tuple
            return result[0]
        return tuple(result)

View File

@@ -0,0 +1,225 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
Write patches which add @example() decorators for discovered test cases.
Requires `hypothesis[codemods,ghostwriter]` installed, i.e. black and libcst.
This module is used by Hypothesis' builtin pytest plugin for failing examples
discovered during testing, and by HypoFuzz for _covering_ examples discovered
during fuzzing.
"""
import difflib
import hashlib
import inspect
import re
import sys
from ast import literal_eval
from contextlib import suppress
from datetime import date, datetime, timedelta, timezone
from pathlib import Path
import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from hypothesis.configuration import storage_directory
from hypothesis.version import __version__
try:
import black
except ImportError:
black = None # type: ignore
# git format-patch style mail header.  `__version__` is interpolated now via
# the f-string; the doubled braces leave `{when}`/`{msg}` as fields for the
# later HEADER.format(...) call in make_patch.
HEADER = f"""\
From HEAD Mon Sep 17 00:00:00 2001
From: Hypothesis {__version__} <no-reply@hypothesis.works>
Date: {{when:%a, %d %b %Y %H:%M:%S}}
Subject: [PATCH] {{msg}}
---
"""
# Commit-message subject used for patches that record failing examples.
FAIL_MSG = "discovered failure"
_space_only_re = re.compile("^ +$", re.MULTILINE)
_leading_space_re = re.compile("(^[ ]*)(?:[^ \n])", re.MULTILINE)


def dedent(text):
    """Strip the common leading-space prefix from *text*.

    Simplified textwrap.dedent, for valid Python source code only.  Returns
    the dedented text together with the prefix that was removed.
    """
    # Blank out whitespace-only lines so they can't shorten the common prefix.
    cleaned = _space_only_re.sub("", text)
    # The shortest run of leading spaces on any non-blank line is the prefix.
    prefix = min(_leading_space_re.findall(cleaned), key=len)
    dedented = re.sub(r"(?m)^" + prefix, "", cleaned)
    return dedented, prefix
def indent(text: str, prefix: str) -> str:
    """Prepend *prefix* to every line of *text*, preserving line endings."""
    lines = text.splitlines(keepends=True)
    return "".join([prefix + line for line in lines])
class AddExamplesCodemod(VisitorBasedCodemodCommand):
    DESCRIPTION = "Add explicit examples to failing tests."

    def __init__(self, context, fn_examples, strip_via=(), dec="example", width=88):
        """Add @example() decorator(s) for failing test(s).

        `code` is the source code of the module where the test functions are defined.
        `fn_examples` is a dict of function name to list-of-failing-examples.
        """
        assert fn_examples, "This codemod does nothing without fn_examples."
        super().__init__(context)
        # Parsed decorator callable, e.g. `example` or `hypothesis.example`.
        self.decorator_func = cst.parse_expression(dec)
        self.line_length = width
        # Matcher for existing `@example(...).via("...")` decorators whose via
        # string is listed in strip_via; leave_FunctionDef removes those.
        value_in_strip_via = m.MatchIfTrue(lambda x: literal_eval(x.value) in strip_via)
        self.strip_matching = m.Call(
            m.Attribute(m.Call(), m.Name("via")),
            [m.Arg(m.SimpleString() & value_in_strip_via)],
        )
        # Codemod the failing examples to Call nodes usable as decorators
        self.fn_examples = {
            k: tuple(self.__call_node_to_example_dec(ex, via) for ex, via in nodes)
            for k, nodes in fn_examples.items()
        }

    def __call_node_to_example_dec(self, node, via):
        # If we have black installed, remove trailing comma, _unless_ there's a comment
        node = node.with_changes(
            func=self.decorator_func,
            args=[
                a.with_changes(
                    comma=a.comma
                    if m.findall(a.comma, m.Comment())
                    else cst.MaybeSentinel.DEFAULT
                )
                for a in node.args
            ]
            if black
            else node.args,
        )
        # Note: calling a method on a decorator requires PEP-614, i.e. Python 3.9+,
        # but plumbing two cases through doesn't seem worth the trouble :-/
        via = cst.Call(
            func=cst.Attribute(node, cst.Name("via")),
            args=[cst.Arg(cst.SimpleString(repr(via)))],
        )
        if black:  # pragma: no branch
            # Re-format the decorator expression with black for readable output.
            pretty = black.format_str(
                cst.Module([]).code_for_node(via),
                mode=black.FileMode(line_length=self.line_length),
            )
            via = cst.parse_expression(pretty.strip())
        return cst.Decorator(via)

    def leave_FunctionDef(self, _, updated_node):
        # Drop decorators matched by strip_matching, then append the new
        # example decorators registered for this function name (if any).
        return updated_node.with_changes(
            # TODO: improve logic for where in the list to insert this decorator
            decorators=tuple(
                d
                for d in updated_node.decorators
                # `findall()` to see through the identity function workaround on py38
                if not m.findall(d, self.strip_matching)
            )
            + self.fn_examples.get(updated_node.name.value, ())
        )
def get_patch_for(func, failing_examples, *, strip_via=()):
    """Return a (filename, before_source, after_source) triple which adds
    @example(...) decorators for *failing_examples* to *func*, or None if
    no patch can be produced.

    `failing_examples` is an iterable of (example-call source, via-string)
    pairs; entries whose source does not parse as a call are skipped.
    """
    # Skip this if we're unable to find the location or source of this function.
    try:
        module = sys.modules[func.__module__]
        fname = Path(module.__file__).relative_to(Path.cwd())
        before = inspect.getsource(func)
    except Exception:
        return None
    # The printed examples might include object reprs which are invalid syntax,
    # so we parse here and skip over those. If _none_ are valid, there's no patch.
    call_nodes = []
    for ex, via in set(failing_examples):
        with suppress(Exception):
            node = cst.parse_expression(ex)
            assert isinstance(node, cst.Call), node
            # Check for st.data(), which doesn't support explicit examples
            data = m.Arg(m.Call(m.Name("data"), args=[m.Arg(m.Ellipsis())]))
            if m.matches(node, m.Call(args=[m.ZeroOrMore(), data, m.ZeroOrMore()])):
                return None
            call_nodes.append((node, via))
    if not call_nodes:
        return None
    # Use a qualified `hypothesis.example` when the module imports hypothesis
    # itself but not `given` (so a bare `example` name would be undefined).
    if (
        module.__dict__.get("hypothesis") is sys.modules["hypothesis"]
        and "given" not in module.__dict__  # more reliably present than `example`
    ):
        decorator_func = "hypothesis.example"
    else:
        decorator_func = "example"
    # Do the codemod and return a triple containing location and replacement info.
    dedented, prefix = dedent(before)
    try:
        node = cst.parse_module(dedented)
    except Exception:  # pragma: no cover
        # inspect.getsource() sometimes returns a decorator alone, which is invalid
        return None
    after = AddExamplesCodemod(
        CodemodContext(),
        fn_examples={func.__name__: call_nodes},
        strip_via=strip_via,
        dec=decorator_func,
        width=88 - len(prefix),  # to match Black's default formatting
    ).transform_module(node)
    return (str(fname), before, indent(after.code, prefix=prefix))
def make_patch(triples, *, msg="Hypothesis: add explicit examples", when=None):
    """Create a patch for (fname, before, after) triples."""
    assert triples, "attempted to create empty patch"
    when = when or datetime.now(tz=timezone.utc)

    # Group the replacements by the file they apply to.
    changes_by_file = {}
    for fname, before, after in triples:
        changes_by_file.setdefault(Path(fname), []).append((before, after))

    diffs = [HEADER.format(msg=msg, when=when)]
    for path, replacements in sorted(changes_by_file.items()):
        # Apply each replacement in turn to a copy of the on-disk source.
        original = updated = path.read_text(encoding="utf-8")
        for before, after in replacements:
            updated = updated.replace(before.rstrip(), after.rstrip(), 1)
        file_diff = difflib.unified_diff(
            original.splitlines(keepends=True),
            updated.splitlines(keepends=True),
            fromfile=str(path),
            tofile=str(path),
        )
        diffs.append("".join(file_diff))
    return "".join(diffs)
def save_patch(patch: str, *, slug: str = "") -> Path:  # pragma: no cover
    """Write *patch* under the Hypothesis storage directory, and return its
    path relative to the current working directory."""
    assert re.fullmatch(r"|[a-z]+-", slug), f"malformed {slug=}"
    # Hash the patch with the Date: header removed, so identical content
    # saved at different times still maps to the same filename hash.
    stripped = re.sub(r"^Date: .+?$", "", patch, count=1, flags=re.MULTILINE)
    digest = hashlib.sha1(stripped.encode()).hexdigest()[:8]
    today = date.today().isoformat()
    out = Path(storage_directory("patches", f"{today}--{slug}{digest}.patch"))
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(patch, encoding="utf-8")
    return out.relative_to(Path.cwd())
def gc_patches(slug: str = "") -> None:  # pragma: no cover
    """Delete saved patch files that are more than one week old."""
    oldest_allowed = date.today() - timedelta(days=7)
    patch_dir = Path(storage_directory("patches"))
    for patch_file in patch_dir.glob(f"????-??-??--{slug}????????.patch"):
        # Filenames start with the ISO date they were saved on.
        saved_on = date.fromisoformat(patch_file.stem.split("--")[0])
        if saved_on < oldest_allowed:
            patch_file.unlink()

File diff suppressed because it is too large. (Load Diff)

View File

@@ -0,0 +1,345 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
.. _hypothesis-cli:
----------------
hypothesis[cli]
----------------
::
$ hypothesis --help
Usage: hypothesis [OPTIONS] COMMAND [ARGS]...
Options:
--version Show the version and exit.
-h, --help Show this message and exit.
Commands:
codemod `hypothesis codemod` refactors deprecated or inefficient code.
fuzz [hypofuzz] runs tests with an adaptive coverage-guided fuzzer.
write `hypothesis write` writes property-based tests for you!
This module requires the :pypi:`click` package, and provides Hypothesis' command-line
interface, for e.g. :doc:`'ghostwriting' tests <ghostwriter>` via the terminal.
It's also where `HypoFuzz <https://hypofuzz.com/>`__ adds the :command:`hypothesis fuzz`
command (`learn more about that here <https://hypofuzz.com/docs/quickstart.html>`__).
"""
import builtins
import importlib
import inspect
import sys
import types
from difflib import get_close_matches
from functools import partial
from multiprocessing import Pool
from pathlib import Path
# `pytest` is optional here: it is only consulted below to pick the default
# --style for ghostwritten tests (pytest-style if available, else unittest).
try:
    import pytest
except ImportError:
    pytest = None  # type: ignore

# Template for the error shown when an optional CLI dependency is missing;
# formatted with the missing package's name.
MESSAGE = """
The Hypothesis command-line interface requires the `{}` package,
which you do not have installed. Run:
python -m pip install --upgrade 'hypothesis[cli]'
and try again.
"""

try:
    import click
except ImportError:

    # Fallback entry point: without click we cannot build the real CLI,
    # so tell the user how to install it and exit with an error status.
    def main():
        """If `click` is not installed, tell the user to install it then exit."""
        sys.stderr.write(MESSAGE.format("click"))
        sys.exit(1)

else:
    # Ensure that Python scripts in the current working directory are importable,
    # on the principle that Ghostwriter should 'just work' for novice users. Note
    # that we append rather than prepend to the module search path, so this will
    # never shadow the stdlib or installed packages.
    sys.path.append(".")

    # Real entry point: a click command group providing -h/--help and --version;
    # subcommands are attached below via @main.command().
    @click.group(context_settings={"help_option_names": ("-h", "--help")})
    @click.version_option()
    def main():
        pass
    def obj_name(s: str) -> object:
        """This "type" imports whatever object is named by a dotted string.

        Used as a click argument converter: resolves ``module``,
        ``module.attr``, or ``module.Class.attr``, raising click.UsageError
        with a helpful message (including close-match suggestions) on failure.
        """
        s = s.strip()
        # A path separator is a strong hint the user passed a file path
        # rather than a dotted module name.
        if "/" in s or "\\" in s:
            raise click.UsageError(
                "Remember that the ghostwriter should be passed the name of a module, not a path."
            ) from None
        # Happy path: the whole string names an importable module.
        try:
            return importlib.import_module(s)
        except ImportError:
            pass
        classname = None
        if "." not in s:
            # Bare name: treat it as a builtin, e.g. `hypothesis write sorted`.
            modulename, module, funcname = "builtins", builtins, s
        else:
            modulename, funcname = s.rsplit(".", 1)
            try:
                module = importlib.import_module(modulename)
            except ImportError as err:
                # Perhaps the prefix is `package.module.Class`; retry one
                # level up with the last component treated as a class name.
                try:
                    modulename, classname = modulename.rsplit(".", 1)
                    module = importlib.import_module(modulename)
                except (ImportError, ValueError):
                    if s.endswith(".py"):
                        raise click.UsageError(
                            "Remember that the ghostwriter should be passed the name of a module, not a file."
                        ) from None
                    raise click.UsageError(
                        f"Failed to import the {modulename} module for introspection. "
                        "Check spelling and your Python import path, or use the Python API?"
                    ) from err

        def describe_close_matches(
            module_or_class: types.ModuleType, objname: str
        ) -> str:
            # Suggest likely intended public names when the lookup fails.
            public_names = [
                name for name in vars(module_or_class) if not name.startswith("_")
            ]
            matches = get_close_matches(objname, public_names)
            if matches:
                return f" Closest matches: {matches!r}"
            else:
                return ""

        if classname is None:
            # Plain `module.attr` lookup.
            try:
                return getattr(module, funcname)
            except AttributeError as err:
                if funcname == "py":
                    # Likely attempted to pass a local file (Eg., "myscript.py") instead of a module name
                    raise click.UsageError(
                        "Remember that the ghostwriter should be passed the name of a module, not a file."
                        f"\n\tTry: hypothesis write {s[:-3]}"
                    ) from None
                raise click.UsageError(
                    f"Found the {modulename!r} module, but it doesn't have a "
                    f"{funcname!r} attribute."
                    + describe_close_matches(module, funcname)
                ) from err
        else:
            # `module.Class.attr` lookup: resolve the class, then the attribute.
            try:
                func_class = getattr(module, classname)
            except AttributeError as err:
                raise click.UsageError(
                    f"Found the {modulename!r} module, but it doesn't have a "
                    f"{classname!r} class." + describe_close_matches(module, classname)
                ) from err
            try:
                return getattr(func_class, funcname)
            except AttributeError as err:
                if inspect.isclass(func_class):
                    func_class_is = "class"
                else:
                    func_class_is = "attribute"
                raise click.UsageError(
                    f"Found the {modulename!r} module and {classname!r} {func_class_is}, "
                    f"but it doesn't have a {funcname!r} attribute."
                    + describe_close_matches(func_class, funcname)
                ) from err
    def _refactor(func, fname):
        """Apply *func* (a str -> str codemod) to the file *fname* in place.

        Returns a human-readable "skipping ..." message on recoverable
        failure, or None on success / no-op.
        """
        try:
            oldcode = Path(fname).read_text(encoding="utf-8")
        except (OSError, UnicodeError) as err:
            # Permissions or encoding issue, or file deleted, etc.
            return f"skipping {fname!r} due to {err}"
        if "hypothesis" not in oldcode:
            return  # This is a fast way to avoid running slow no-op codemods
        try:
            newcode = func(oldcode)
        except Exception as err:
            # Imported lazily: libcst is an optional dependency of the CLI,
            # and we only need the exception type to classify the failure.
            from libcst import ParserSyntaxError

            if isinstance(err, ParserSyntaxError):
                from hypothesis.extra._patching import indent

                # Unparseable file: report and skip rather than abort the run.
                msg = indent(str(err).replace("\n\n", "\n"), " ").strip()
                return f"skipping {fname!r} due to {msg}"
            raise
        # Only write back (and touch the file's mtime) if something changed.
        if newcode != oldcode:
            Path(fname).write_text(newcode, encoding="utf-8")
    @main.command()  # type: ignore  # Click adds the .command attribute
    @click.argument("path", type=str, required=True, nargs=-1)
    def codemod(path):
        """`hypothesis codemod` refactors deprecated or inefficient code.

        It adapts `python -m libcst.tool`, removing many features and config options
        which are rarely relevant for this purpose. If you need more control, we
        encourage you to use the libcst CLI directly; if not this one is easier.

        PATH is the file(s) or directories of files to format in place, or
        "-" to read from stdin and write to stdout.
        """
        # Codemods need libcst, which is an optional extra; fail with advice.
        try:
            from libcst.codemod import gather_files

            from hypothesis.extra import codemods
        except ImportError:
            sys.stderr.write(
                "You are missing required dependencies for this option. Run:\n\n"
                " python -m pip install --upgrade hypothesis[codemods]\n\n"
                "and try again."
            )
            sys.exit(1)

        # Special case for stdin/stdout usage
        if "-" in path:
            if len(path) > 1:
                raise Exception(
                    "Cannot specify multiple paths when reading from stdin!"
                )
            print("Codemodding from stdin", file=sys.stderr)
            print(codemods.refactor(sys.stdin.read()))
            return 0

        # Find all the files to refactor, and then codemod them
        files = gather_files(path)
        errors = set()
        if len(files) <= 1:
            errors.add(_refactor(codemods.refactor, *files))
        else:
            # Multiple files: fan out across processes; result order is
            # irrelevant since we only collect skip/error messages.
            with Pool() as pool:
                for msg in pool.imap_unordered(
                    partial(_refactor, codemods.refactor), files
                ):
                    errors.add(msg)
        errors.discard(None)  # successful refactors return None
        for msg in errors:
            print(msg, file=sys.stderr)
        # Exit status 1 if anything was skipped or failed, else 0.
        return 1 if errors else 0
    @main.command()  # type: ignore  # Click adds the .command attribute
    @click.argument("func", type=obj_name, required=True, nargs=-1)
    @click.option(
        "--roundtrip",
        "writer",
        flag_value="roundtrip",
        help="start by testing write/read or encode/decode!",
    )
    @click.option(
        "--equivalent",
        "writer",
        flag_value="equivalent",
        help="very useful when optimising or refactoring code",
    )
    @click.option(
        "--errors-equivalent",
        "writer",
        flag_value="errors-equivalent",
        help="--equivalent, but also allows consistent errors",
    )
    @click.option(
        "--idempotent",
        "writer",
        flag_value="idempotent",
        help="check that f(x) == f(f(x))",
    )
    @click.option(
        "--binary-op",
        "writer",
        flag_value="binary_operation",
        help="associativity, commutativity, identity element",
    )
    # Note: we deliberately omit a --ufunc flag, because the magic()
    # detection of ufuncs is both precise and complete.
    @click.option(
        "--style",
        type=click.Choice(["pytest", "unittest"]),
        default="pytest" if pytest else "unittest",
        help="pytest-style function, or unittest-style method?",
    )
    @click.option(
        "-e",
        "--except",
        "except_",
        type=obj_name,
        multiple=True,
        help="dotted name of exception(s) to ignore",
    )
    @click.option(
        "--annotate/--no-annotate",
        default=None,
        help="force ghostwritten tests to be type-annotated (or not). "
        "By default, match the code to test.",
    )
    def write(func, writer, except_, style, annotate):  # \b disables autowrap
        """`hypothesis write` writes property-based tests for you!

        Type annotations are helpful but not required for our advanced introspection
        and templating logic. Try running the examples below to see how it works:

        \b
        hypothesis write gzip
        hypothesis write numpy.matmul
        hypothesis write pandas.from_dummies
        hypothesis write re.compile --except re.error
        hypothesis write --equivalent ast.literal_eval eval
        hypothesis write --roundtrip json.dumps json.loads
        hypothesis write --style=unittest --idempotent sorted
        hypothesis write --binary-op operator.add
        """
        # NOTE: if you want to call this function from Python, look instead at the
        # ``hypothesis.extra.ghostwriter`` module. Click-decorated functions have
        # a different calling convention, and raise SystemExit instead of returning.
        kwargs = {"except_": except_ or (), "style": style, "annotate": annotate}
        # Map the flag (or its absence) to the matching ghostwriter function,
        # adjusting for argument count where a mode needs 2+ functions.
        if writer is None:
            writer = "magic"
        elif writer == "idempotent" and len(func) > 1:
            raise click.UsageError("Test functions for idempotence one at a time.")
        elif writer == "roundtrip" and len(func) == 1:
            writer = "idempotent"
        elif "equivalent" in writer and len(func) == 1:
            writer = "fuzz"
        if writer == "errors-equivalent":
            writer = "equivalent"
            kwargs["allow_same_errors"] = True

        # The ghostwriter needs `black` to format its output.
        try:
            from hypothesis.extra import ghostwriter
        except ImportError:
            sys.stderr.write(MESSAGE.format("black"))
            sys.exit(1)
        code = getattr(ghostwriter, writer)(*func, **kwargs)

        # Pretty-print with rich if available, degrading gracefully to plain
        # output if rich is missing or highlighting fails for any reason.
        try:
            from rich.console import Console
            from rich.syntax import Syntax

            from hypothesis.utils.terminal import guess_background_color
        except ImportError:
            print(code)
        else:
            try:
                theme = "default" if guess_background_color() == "light" else "monokai"
                code = Syntax(code, "python", background_color="default", theme=theme)
                Console().print(code, soft_wrap=True)
            except Exception:
                print("# Error while syntax-highlighting code", file=sys.stderr)
                print(code)

View File

@@ -0,0 +1,284 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
.. _codemods:
--------------------
hypothesis[codemods]
--------------------
This module provides codemods based on the :pypi:`LibCST` library, which can
both detect *and automatically fix* issues with code that uses Hypothesis,
including upgrading from deprecated features to our recommended style.
You can run the codemods via our CLI::
$ hypothesis codemod --help
Usage: hypothesis codemod [OPTIONS] PATH...
`hypothesis codemod` refactors deprecated or inefficient code.
It adapts `python -m libcst.tool`, removing many features and config
options which are rarely relevant for this purpose. If you need more
control, we encourage you to use the libcst CLI directly; if not this one
is easier.
PATH is the file(s) or directories of files to format in place, or "-" to
read from stdin and write to stdout.
Options:
-h, --help Show this message and exit.
Alternatively you can use ``python -m libcst.tool``, which offers more control
at the cost of additional configuration (adding ``'hypothesis.extra'`` to the
``modules`` list in ``.libcst.codemod.yaml``) and `some issues on Windows
<https://github.com/Instagram/LibCST/issues/435>`__.
.. autofunction:: refactor
"""
import functools
import importlib
from inspect import Parameter, signature
from typing import ClassVar, List
import libcst as cst
import libcst.matchers as m
from libcst.codemod import VisitorBasedCodemodCommand
def refactor(code: str) -> str:
    """Update a source code string from deprecated to modern Hypothesis APIs.

    This may not fix *all* the deprecation warnings in your code, but we're
    confident that it will be easier than doing it all by hand.

    We recommend using the CLI, but if you want a Python function here it is.
    """
    shared_context = cst.codemod.CodemodContext()
    module = cst.parse_module(code)
    # Order matters: the positional-to-keyword fix runs first so that e.g.
    # st.complex_numbers(None) becomes min_magnitude=None before the
    # min-magnitude fix inspects it.
    for command_class in (
        HypothesisFixPositionalKeywonlyArgs,
        HypothesisFixComplexMinMagnitude,
        HypothesisFixHealthcheckAll,
        HypothesisFixCharactersArguments,
    ):
        module = command_class(shared_context).transform_module(module)
    return module.code
def match_qualname(name):
    """Return a libcst matcher that matches nodes whose fully-qualified
    name (per metadata) is exactly *name*."""
    # We use the metadata to get qualname instead of matching directly on function
    # name, because this handles some scope and "from x import y as z" issues.
    return m.MatchMetadataIfTrue(
        cst.metadata.QualifiedNameProvider,
        # If there are multiple possible qualnames, e.g. due to conditional imports,
        # be conservative. Better to leave the user to fix a few things by hand than
        # to break their code while attempting to refactor it!
        lambda qualnames: all(n.name == name for n in qualnames),
    )
class HypothesisFixComplexMinMagnitude(VisitorBasedCodemodCommand):
    """Fix a deprecated min_magnitude=None argument for complex numbers::

    st.complex_numbers(min_magnitude=None) -> st.complex_numbers(min_magnitude=0)

    Note that this should be run *after* ``HypothesisFixPositionalKeywonlyArgs``,
    in order to handle ``st.complex_numbers(None)``.
    """

    DESCRIPTION = "Fix a deprecated min_magnitude=None argument for complex numbers."
    METADATA_DEPENDENCIES = (cst.metadata.QualifiedNameProvider,)

    # Only fire inside calls that resolve to st.complex_numbers.
    @m.call_if_inside(
        m.Call(metadata=match_qualname("hypothesis.strategies.complex_numbers"))
    )
    def leave_Arg(self, original_node, updated_node):
        # Rewrite the single argument `min_magnitude=None` to `min_magnitude=0`;
        # all other arguments pass through unchanged.
        if m.matches(
            updated_node, m.Arg(keyword=m.Name("min_magnitude"), value=m.Name("None"))
        ):
            return updated_node.with_changes(value=cst.Integer("0"))
        return updated_node
@functools.lru_cache
def get_fn(import_path):
    """Resolve a dotted ``module.attribute`` path to the object it names.

    Results are cached, since the same qualnames recur across a codemod run.
    """
    module_name, attr_name = import_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
class HypothesisFixPositionalKeywonlyArgs(VisitorBasedCodemodCommand):
    """Fix positional arguments for newly keyword-only parameters, e.g.::

    st.fractions(0, 1, 9) -> st.fractions(0, 1, max_denominator=9)

    Applies to a majority of our public API, since keyword-only parameters are
    great but we couldn't use them until after we dropped support for Python 2.
    """

    DESCRIPTION = "Fix positional arguments for newly keyword-only parameters."
    METADATA_DEPENDENCIES = (cst.metadata.QualifiedNameProvider,)

    # Qualified names of every public function whose trailing parameters
    # became keyword-only; membership here gates the rewrite below.
    kwonly_functions = (
        "hypothesis.target",
        "hypothesis.find",
        "hypothesis.extra.lark.from_lark",
        "hypothesis.extra.numpy.arrays",
        "hypothesis.extra.numpy.array_shapes",
        "hypothesis.extra.numpy.unsigned_integer_dtypes",
        "hypothesis.extra.numpy.integer_dtypes",
        "hypothesis.extra.numpy.floating_dtypes",
        "hypothesis.extra.numpy.complex_number_dtypes",
        "hypothesis.extra.numpy.datetime64_dtypes",
        "hypothesis.extra.numpy.timedelta64_dtypes",
        "hypothesis.extra.numpy.byte_string_dtypes",
        "hypothesis.extra.numpy.unicode_string_dtypes",
        "hypothesis.extra.numpy.array_dtypes",
        "hypothesis.extra.numpy.nested_dtypes",
        "hypothesis.extra.numpy.valid_tuple_axes",
        "hypothesis.extra.numpy.broadcastable_shapes",
        "hypothesis.extra.pandas.indexes",
        "hypothesis.extra.pandas.series",
        "hypothesis.extra.pandas.columns",
        "hypothesis.extra.pandas.data_frames",
        "hypothesis.provisional.domains",
        "hypothesis.stateful.run_state_machine_as_test",
        "hypothesis.stateful.rule",
        "hypothesis.stateful.initialize",
        "hypothesis.strategies.floats",
        "hypothesis.strategies.lists",
        "hypothesis.strategies.sets",
        "hypothesis.strategies.frozensets",
        "hypothesis.strategies.iterables",
        "hypothesis.strategies.dictionaries",
        "hypothesis.strategies.characters",
        "hypothesis.strategies.text",
        "hypothesis.strategies.from_regex",
        "hypothesis.strategies.binary",
        "hypothesis.strategies.fractions",
        "hypothesis.strategies.decimals",
        "hypothesis.strategies.recursive",
        "hypothesis.strategies.complex_numbers",
        "hypothesis.strategies.shared",
        "hypothesis.strategies.uuids",
        "hypothesis.strategies.runner",
        "hypothesis.strategies.functions",
        "hypothesis.strategies.datetimes",
        "hypothesis.strategies.times",
    )

    def leave_Call(self, original_node, updated_node):
        """Convert positional to keyword arguments."""
        metadata = self.get_metadata(cst.metadata.QualifiedNameProvider, original_node)
        qualnames = {qn.name for qn in metadata}
        # If this isn't one of our known functions, or it has no posargs, stop there.
        # We also require an unambiguous (single) qualname, and skip calls whose
        # callee is itself a call (we couldn't introspect those reliably).
        if (
            len(qualnames) != 1
            or not qualnames.intersection(self.kwonly_functions)
            or not m.matches(
                updated_node,
                m.Call(
                    func=m.DoesNotMatch(m.Call()),
                    args=[m.Arg(keyword=None), m.ZeroOrMore()],
                ),
            )
        ):
            return updated_node

        # Get the actual function object so that we can inspect the signature.
        # This does e.g. incur a dependency on Numpy to fix Numpy-dependent code,
        # but having a single source of truth about the signatures is worth it.
        try:
            params = signature(get_fn(*qualnames)).parameters.values()
        except ModuleNotFoundError:
            return updated_node

        # st.floats() has a new allow_subnormal kwonly argument not at the end,
        # so we do a bit more of a dance here.
        if qualnames == {"hypothesis.strategies.floats"}:
            params = [p for p in params if p.name != "allow_subnormal"]

        # More arguments than parameters: malformed call, leave it alone.
        if len(updated_node.args) > len(params):
            return updated_node

        # Create new arg nodes with the newly required keywords
        assign_nospace = cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )
        # Only rewrite bare positional args that land on keyword-only params;
        # existing keywords and *star args pass through untouched.
        newargs = [
            arg
            if arg.keyword or arg.star or p.kind is not Parameter.KEYWORD_ONLY
            else arg.with_changes(keyword=cst.Name(p.name), equal=assign_nospace)
            for p, arg in zip(params, updated_node.args)
        ]
        return updated_node.with_changes(args=newargs)
class HypothesisFixHealthcheckAll(VisitorBasedCodemodCommand):
    """Replace Healthcheck.all() with list(Healthcheck)"""

    DESCRIPTION = "Replace Healthcheck.all() with list(Healthcheck)"

    # Matches exactly `Healthcheck.all()` with no arguments; the call is
    # rewritten in place to `list(Healthcheck)`.
    @m.leave(m.Call(func=m.Attribute(m.Name("Healthcheck"), m.Name("all")), args=[]))
    def replace_healthcheck(self, original_node, updated_node):
        return updated_node.with_changes(
            func=cst.Name("list"),
            args=[cst.Arg(value=cst.Name("Healthcheck"))],
        )
class HypothesisFixCharactersArguments(VisitorBasedCodemodCommand):
    """Fix deprecated white/blacklist arguments to characters::

    st.characters(whitelist_categories=...) -> st.characters(categories=...)
    st.characters(blacklist_categories=...) -> st.characters(exclude_categories=...)
    st.characters(whitelist_characters=...) -> st.characters(include_characters=...)
    st.characters(blacklist_characters=...) -> st.characters(exclude_characters=...)

    Additionally, we drop `exclude_categories=` if `categories=` is present,
    because this argument is always redundant (or an error).
    """

    DESCRIPTION = "Fix deprecated white/blacklist arguments to characters."
    METADATA_DEPENDENCIES = (cst.metadata.QualifiedNameProvider,)

    # Old keyword name -> new keyword name.
    _replacements: ClassVar = {
        "whitelist_categories": "categories",
        "blacklist_categories": "exclude_categories",
        "whitelist_characters": "include_characters",
        "blacklist_characters": "exclude_characters",
    }

    # Fire only on st.characters(...) calls that use at least one old keyword.
    @m.leave(
        m.Call(
            metadata=match_qualname("hypothesis.strategies.characters"),
            args=[
                m.ZeroOrMore(),
                m.Arg(keyword=m.OneOf(*map(m.Name, _replacements))),
                m.ZeroOrMore(),
            ],
        ),
    )
    def fn(self, original_node, updated_node):
        # Update to the new names
        newargs = []
        for arg in updated_node.args:
            kw = self._replacements.get(arg.keyword.value, arg.keyword.value)
            newargs.append(arg.with_changes(keyword=cst.Name(kw)))
        # Drop redundant exclude_categories, which is now an error
        if any(m.matches(arg, m.Arg(keyword=m.Name("categories"))) for arg in newargs):
            ex = m.Arg(keyword=m.Name("exclude_categories"))
            newargs = [a for a in newargs if m.matches(a, ~ex)]
        return updated_node.with_changes(args=newargs)

View File

@@ -0,0 +1,64 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
--------------------
hypothesis[dateutil]
--------------------
This module provides :pypi:`dateutil <python-dateutil>` timezones.
You can use this strategy to make :func:`~hypothesis.strategies.datetimes`
and :func:`~hypothesis.strategies.times` produce timezone-aware values.
"""
import datetime as dt
from dateutil import tz, zoneinfo # type: ignore
from hypothesis import strategies as st
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
__all__ = ["timezones"]
def __zone_sort_key(zone):
    """Sort by absolute UTC offset at reference date,
    positive first, with ties broken by name.

    Returns a (abs(offset), -offset, name) tuple usable as a sort key.
    """
    assert zone is not None
    offset = zone.utcoffset(dt.datetime(2000, 1, 1))
    if offset is None:
        # tzinfo.utcoffset() may return None for zones that cannot determine
        # an offset.  Use a large *timedelta* sentinel (not the int 999) so
        # that such zones sort last: a bare int here would make the key tuple
        # incomparable with the timedelta-based keys of ordinary zones, and
        # sorted() would raise TypeError.
        offset = dt.timedelta(hours=999)
    return (abs(offset), -offset, str(zone))
@cacheable
@defines_strategy()
def timezones() -> st.SearchStrategy[dt.tzinfo]:
    """Any timezone from :pypi:`dateutil <python-dateutil>`.

    This strategy minimises to UTC, or the timezone with the smallest offset
    from UTC as of 2000-01-01, and is designed for use with
    :py:func:`~hypothesis.strategies.datetimes`.

    Note that the timezones generated by the strategy may vary depending on the
    configuration of your machine. See the dateutil documentation for more
    information.
    """
    # Sort every known zone by |offset| (see __zone_sort_key), so that
    # sampled_from shrinks towards the smallest offset from UTC.
    all_timezones = sorted(
        (tz.gettz(t) for t in zoneinfo.get_zonefile_instance().zones),
        key=__zone_sort_key,
    )
    # UTC is prepended explicitly so it is always the minimal example.
    all_timezones.insert(0, tz.UTC)
    # We discard Nones in the list comprehension because Mypy knows that
    # tz.gettz may return None. However this should never happen for known
    # zone names, so we assert that it's impossible first.
    assert None not in all_timezones
    return st.sampled_from([z for z in all_timezones if z is not None])

View File

@@ -0,0 +1,30 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
# Re-export the public Django-integration API from its private submodules.
from hypothesis.extra.django._fields import from_field, register_field_strategy
from hypothesis.extra.django._impl import (
    LiveServerTestCase,
    StaticLiveServerTestCase,
    TestCase,
    TransactionTestCase,
    from_form,
    from_model,
)

# Public API of hypothesis.extra.django.
__all__ = [
    "LiveServerTestCase",
    "StaticLiveServerTestCase",
    "TestCase",
    "TransactionTestCase",
    "from_field",
    "from_model",
    "register_field_strategy",
    "from_form",
]

View File

@@ -0,0 +1,343 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import re
import string
from datetime import timedelta
from decimal import Decimal
from functools import lru_cache
from typing import Any, Callable, Dict, Type, TypeVar, Union
import django
from django import forms as df
from django.contrib.auth.forms import UsernameField
from django.core.validators import (
validate_ipv4_address,
validate_ipv6_address,
validate_ipv46_address,
)
from django.db import models as dm
from hypothesis import strategies as st
from hypothesis.errors import InvalidArgument, ResolutionFailed
from hypothesis.internal.validation import check_type
from hypothesis.provisional import urls
from hypothesis.strategies import emails
AnyField = Union[dm.Field, df.Field]
F = TypeVar("F", bound=AnyField)
def numeric_bounds_from_validators(
    field, min_value=float("-inf"), max_value=float("inf")
):
    """Return the tightest ``(min, max)`` bounds implied by *field*'s validators.

    Starts from the given bounds and narrows them with every Min/MaxValue
    validator attached to the field.
    """
    lower, upper = min_value, max_value
    for validator in field.validators:
        if isinstance(validator, django.core.validators.MinValueValidator):
            lower = max(lower, validator.limit_value)
        elif isinstance(validator, django.core.validators.MaxValueValidator):
            upper = min(upper, validator.limit_value)
    return lower, upper
def integers_for_field(min_value, max_value):
    """Return a ``field -> strategy`` resolver for integer columns.

    The column's native [min_value, max_value] range is narrowed further by
    any Min/MaxValue validators on the specific field instance.
    """

    def strategy_for(field):
        lower, upper = numeric_bounds_from_validators(field, min_value, max_value)
        return st.integers(lower, upper)

    return strategy_for
@lru_cache
def timezones():
    """Return a timezone strategy matching the project's Django settings.

    Cached: the settings are process-global, so the answer never changes.
    """
    # From Django 4.0, the default is to use zoneinfo instead of pytz.
    assert getattr(django.conf.settings, "USE_TZ", False)
    if getattr(django.conf.settings, "USE_DEPRECATED_PYTZ", True):
        from hypothesis.extra.pytz import timezones
    else:
        from hypothesis.strategies import timezones
    return timezones()
# Mapping of field types, to strategy objects or functions of (type) -> strategy
_FieldLookUpType = Dict[
    Type[AnyField],
    Union[st.SearchStrategy, Callable[[Any], st.SearchStrategy]],
]

# Baseline lookup table for both model (dm.*) and form (df.*) fields.
# Extended below via @register_for, and at runtime by users through
# register_field_strategy().  The integer bounds mirror the database
# column ranges for each field type.
_global_field_lookup: _FieldLookUpType = {
    dm.SmallIntegerField: integers_for_field(-32768, 32767),
    dm.IntegerField: integers_for_field(-2147483648, 2147483647),
    dm.BigIntegerField: integers_for_field(-9223372036854775808, 9223372036854775807),
    dm.PositiveIntegerField: integers_for_field(0, 2147483647),
    dm.PositiveSmallIntegerField: integers_for_field(0, 32767),
    dm.BooleanField: st.booleans(),
    dm.DateField: st.dates(),
    dm.EmailField: emails(),
    dm.FloatField: st.floats(),
    dm.NullBooleanField: st.one_of(st.none(), st.booleans()),
    dm.URLField: urls(),
    dm.UUIDField: st.uuids(),
    df.DateField: st.dates(),
    df.DurationField: st.timedeltas(),
    df.EmailField: emails(),
    df.FloatField: lambda field: st.floats(
        *numeric_bounds_from_validators(field), allow_nan=False, allow_infinity=False
    ),
    df.IntegerField: integers_for_field(-2147483648, 2147483647),
    df.NullBooleanField: st.one_of(st.none(), st.booleans()),
    df.URLField: urls(),
    df.UUIDField: st.uuids(),
}

# IPv6 addresses as strings, in both the compressed and fully-exploded forms.
_ipv6_strings = st.one_of(
    st.ip_addresses(v=6).map(str),
    st.ip_addresses(v=6).map(lambda addr: addr.exploded),
)
def register_for(field_type):
    """Decorator factory: register the wrapped resolver function as the
    strategy source for *field_type* in the global lookup table."""

    def register(func):
        _global_field_lookup[field_type] = func
        return func

    return register
@register_for(dm.DateTimeField)
@register_for(df.DateTimeField)
def _for_datetime(field):
    # Generate timezone-aware datetimes iff the project enables USE_TZ.
    if getattr(django.conf.settings, "USE_TZ", False):
        return st.datetimes(timezones=timezones())
    return st.datetimes()


def using_sqlite():
    """True if the default database is SQLite, False otherwise, and None
    if the Django settings are not (yet) configured."""
    try:
        return (
            getattr(django.conf.settings, "DATABASES", {})
            .get("default", {})
            .get("ENGINE", "")
            .endswith(".sqlite3")
        )
    except django.core.exceptions.ImproperlyConfigured:
        return None
@register_for(dm.TimeField)
def _for_model_time(field):
    # SQLITE supports TZ-aware datetimes, but not TZ-aware times.
    if getattr(django.conf.settings, "USE_TZ", False) and not using_sqlite():
        return st.times(timezones=timezones())
    return st.times()


@register_for(df.TimeField)
def _for_form_time(field):
    # Form fields are not database-backed, so no SQLite restriction here.
    if getattr(django.conf.settings, "USE_TZ", False):
        return st.times(timezones=timezones())
    return st.times()
@register_for(dm.DurationField)
def _for_duration(field):
    # SQLite stores timedeltas as six bytes of microseconds
    if using_sqlite():
        delta = timedelta(microseconds=2**47 - 1)
        return st.timedeltas(-delta, delta)
    return st.timedeltas()
@register_for(dm.SlugField)
@register_for(df.SlugField)
def _for_slug(field):
    # Slugs are ASCII letters and digits; optional fields may also be empty.
    field_is_optional = getattr(field, "blank", False) or not getattr(
        field, "required", True
    )
    return st.text(
        alphabet=string.ascii_letters + string.digits,
        min_size=0 if field_is_optional else 1,
        max_size=field.max_length,
    )
@register_for(dm.GenericIPAddressField)
def _for_model_ip(field):
    # Model fields declare which protocol(s) they accept.
    return {
        "ipv4": st.ip_addresses(v=4).map(str),
        "ipv6": _ipv6_strings,
        "both": st.ip_addresses(v=4).map(str) | _ipv6_strings,
    }[field.protocol.lower()]


@register_for(df.GenericIPAddressField)
def _for_form_ip(field):
    # the IP address form fields have no direct indication of which type
    # of address they want, so direct comparison with the validator
    # function has to be used instead. Sorry for the potato logic here
    if validate_ipv46_address in field.default_validators:
        return st.ip_addresses(v=4).map(str) | _ipv6_strings
    if validate_ipv4_address in field.default_validators:
        return st.ip_addresses(v=4).map(str)
    if validate_ipv6_address in field.default_validators:
        return _ipv6_strings
    raise ResolutionFailed(f"No IP version validator on {field=}")
@register_for(dm.DecimalField)
@register_for(df.DecimalField)
def _for_decimal(field):
    min_value, max_value = numeric_bounds_from_validators(field)
    # Largest magnitude representable with max_digits total digits and
    # decimal_places of them after the point, e.g. (4, 2) -> 99.99.
    bound = Decimal(10**field.max_digits - 1) / (10**field.decimal_places)
    return st.decimals(
        min_value=max(min_value, -bound),
        max_value=min(max_value, bound),
        places=field.decimal_places,
    )
def length_bounds_from_validators(field):
    """Return ``(min_size, max_size)`` length bounds for *field*.

    Defaults to at least one character and at most ``field.max_length``
    (which may be None), then narrows using any Min/MaxLength validators.
    """
    lower, upper = 1, field.max_length
    for validator in field.validators:
        if isinstance(validator, django.core.validators.MinLengthValidator):
            lower = max(lower, validator.limit_value)
        elif isinstance(validator, django.core.validators.MaxLengthValidator):
            # `upper` may still be None (unbounded); the validator then caps it.
            upper = min(upper or validator.limit_value, validator.limit_value)
    return lower, upper
@register_for(dm.BinaryField)
def _for_binary(field):
    min_size, max_size = length_bounds_from_validators(field)
    # Optional fields may also be the empty bytestring.
    if getattr(field, "blank", False) or not getattr(field, "required", True):
        return st.just(b"") | st.binary(min_size=min_size, max_size=max_size)
    return st.binary(min_size=min_size, max_size=max_size)
@register_for(dm.CharField)
@register_for(dm.TextField)
@register_for(df.CharField)
@register_for(df.RegexField)
@register_for(UsernameField)
def _for_text(field):
    # We can infer a vastly more precise strategy by considering the
    # validators as well as the field type. This is a minimal proof of
    # concept, but we intend to leverage the idea much more heavily soon.
    # See https://github.com/HypothesisWorks/hypothesis-python/issues/1116
    regexes = [
        re.compile(v.regex, v.flags) if isinstance(v.regex, str) else v.regex
        for v in field.validators
        if isinstance(v, django.core.validators.RegexValidator) and not v.inverse_match
    ]
    if regexes:
        # This strategy generates according to one of the regexes, and
        # filters using the others. It can therefore learn to generate
        # from the most restrictive and filter with permissive patterns.
        # Not maximally efficient, but it makes pathological cases rarer.
        # If you want a challenge: extend https://qntm.org/greenery to
        # compute intersections of the full Python regex language.
        return st.one_of(*(st.from_regex(r) for r in regexes))
    # If there are no (usable) regexes, we use a standard text strategy.
    min_size, max_size = length_bounds_from_validators(field)
    # Exclude NUL (rejected by databases) and surrogates; the filter ensures
    # the *stripped* text still meets the minimum length, since Django strips
    # leading/trailing whitespace on form CharFields.
    strategy = st.text(
        alphabet=st.characters(exclude_characters="\x00", exclude_categories=("Cs",)),
        min_size=min_size,
        max_size=max_size,
    ).filter(lambda s: min_size <= len(s.strip()))
    if getattr(field, "blank", False) or not getattr(field, "required", True):
        return st.just("") | strategy
    return strategy
@register_for(df.BooleanField)
def _for_form_boolean(field):
    # A required form BooleanField only validates when ticked, so True is the
    # single acceptable value; optional ones may be either.
    return st.just(True) if field.required else st.booleans()
def register_field_strategy(
    field_type: Type[AnyField], strategy: st.SearchStrategy
) -> None:
    """Add an entry to the global field-to-strategy lookup used by
    :func:`~hypothesis.extra.django.from_field`.

    ``field_type`` must be a subtype of :class:`django.db.models.Field` or
    :class:`django.forms.Field`, which must not already be registered.
    ``strategy`` must be a :class:`~hypothesis.strategies.SearchStrategy`.

    Raises InvalidArgument on any violation of those requirements.
    """
    # Validate all arguments before mutating the global lookup table.
    if not issubclass(field_type, (dm.Field, df.Field)):
        raise InvalidArgument(f"{field_type=} must be a subtype of Field")
    check_type(st.SearchStrategy, strategy, "strategy")
    if field_type in _global_field_lookup:
        raise InvalidArgument(
            f"{field_type=} already has a registered "
            f"strategy ({_global_field_lookup[field_type]!r})"
        )
    # AutoFields are assigned by the database, so user strategies make no sense.
    if issubclass(field_type, dm.AutoField):
        raise InvalidArgument("Cannot register a strategy for an AutoField")
    _global_field_lookup[field_type] = strategy
def from_field(field: F) -> st.SearchStrategy[Union[F, None]]:
    """Return a strategy for values that fit the given field.

    This function is used by :func:`~hypothesis.extra.django.from_form` and
    :func:`~hypothesis.extra.django.from_model` for any fields that require
    a value, or for which you passed ``...`` (:obj:`python:Ellipsis`) to infer
    a strategy from an annotation.

    It's pretty similar to the core :func:`~hypothesis.strategies.from_type`
    function, with a subtle but important difference: ``from_field`` takes a
    Field *instance*, rather than a Field *subtype*, so that it has access to
    instance attributes such as string length and validators.
    """
    check_type((dm.Field, df.Field), field, "field")
    if getattr(field, "choices", False):
        # Fields with explicit choices: sample from the declared values,
        # flattening any optgroups into their member choices.
        choices: list = []
        for value, name_or_optgroup in field.choices:
            if isinstance(name_or_optgroup, (list, tuple)):
                choices.extend(key for key, _ in name_or_optgroup)
            else:
                choices.append(value)
        # form fields automatically include an empty choice, strip it out
        if "" in choices:
            choices.remove("")
        min_size = 1
        # Re-add the empty value only where the field genuinely allows it.
        if isinstance(field, (dm.CharField, dm.TextField)) and field.blank:
            choices.insert(0, "")
        elif isinstance(field, (df.Field)) and not field.required:
            choices.insert(0, "")
            min_size = 0
        strategy = st.sampled_from(choices)
        if isinstance(field, (df.MultipleChoiceField, df.TypedMultipleChoiceField)):
            strategy = st.lists(st.sampled_from(choices), min_size=min_size)
    else:
        # No choices: resolve via the global lookup table, which may hold
        # either a strategy or a (field) -> strategy function.
        if type(field) not in _global_field_lookup:
            if getattr(field, "null", False):
                return st.none()
            raise ResolutionFailed(f"Could not infer a strategy for {field!r}")
        strategy = _global_field_lookup[type(field)]  # type: ignore
        if not isinstance(strategy, st.SearchStrategy):
            strategy = strategy(field)
    assert isinstance(strategy, st.SearchStrategy)
    if field.validators:
        # Discard generated values that the field's own validators reject.
        def validate(value):
            try:
                field.run_validators(value)
                return True
            except django.core.exceptions.ValidationError:
                return False

        strategy = strategy.filter(validate)

    # Nullable model fields may also be None.
    if getattr(field, "null", False):
        return st.none() | strategy
    return strategy

View File

@@ -0,0 +1,217 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
import sys
import unittest
from functools import partial
from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union
from django import forms as df, test as dt
from django.contrib.staticfiles import testing as dst
from django.core.exceptions import ValidationError
from django.db import IntegrityError, models as dm
from hypothesis import reject, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.extra.django._fields import from_field
from hypothesis.strategies._internal.utils import defines_strategy
if sys.version_info >= (3, 10):
from types import EllipsisType as EllipsisType
elif TYPE_CHECKING:
from builtins import ellipsis as EllipsisType
else:
EllipsisType = type(Ellipsis)
ModelT = TypeVar("ModelT", bound=dm.Model)
class HypothesisTestCase:
    """Mixin that makes Django test cases cooperate with Hypothesis.

    Django's per-test fixtures (transactions etc.) must be set up and torn
    down once per *example*, not once per test, so we hook them into
    Hypothesis' setup_example/teardown_example protocol.
    """

    def setup_example(self):
        # Run Django's per-test setup before each Hypothesis example.
        self._pre_setup()

    def teardown_example(self, example):
        # Run Django's per-test teardown after each Hypothesis example.
        self._post_teardown()

    def __call__(self, result=None):
        method = getattr(self, self._testMethodName)
        if not getattr(method, "is_hypothesis_test", False):
            # Plain tests keep Django's usual setup/teardown wrapper.
            return dt.SimpleTestCase.__call__(self, result)
        # Hypothesis tests manage fixtures via setup_example/teardown_example,
        # so bypass Django's wrapper and use the raw unittest protocol.
        return unittest.TestCase.__call__(self, result)
class TestCase(HypothesisTestCase, dt.TestCase):
    # Hypothesis-aware drop-in for django.test.TestCase.
    pass
class TransactionTestCase(HypothesisTestCase, dt.TransactionTestCase):
    # Hypothesis-aware drop-in for django.test.TransactionTestCase.
    pass
class LiveServerTestCase(HypothesisTestCase, dt.LiveServerTestCase):
    # Hypothesis-aware drop-in for django.test.LiveServerTestCase.
    pass
class StaticLiveServerTestCase(HypothesisTestCase, dst.StaticLiveServerTestCase):
    # Hypothesis-aware drop-in for staticfiles' StaticLiveServerTestCase.
    pass
@defines_strategy()
def from_model(
    model: Type[ModelT], /, **field_strategies: Union[st.SearchStrategy, EllipsisType]
) -> st.SearchStrategy[ModelT]:
    """Return a strategy for examples of ``model``.
    .. warning::
        Hypothesis creates saved models. This will run inside your testing
        transaction when using the test runner, but if you use the dev console
        this will leave debris in your database.
    ``model`` must be an subclass of :class:`~django:django.db.models.Model`.
    Strategies for fields may be passed as keyword arguments, for example
    ``is_staff=st.just(False)``. In order to support models with fields named
    "model", this is a positional-only parameter.
    Hypothesis can often infer a strategy based the field type and validators,
    and will attempt to do so for any required fields. No strategy will be
    inferred for an :class:`~django:django.db.models.AutoField`, nullable field,
    foreign key, or field for which a keyword
    argument is passed to ``from_model()``. For example,
    a Shop type with a foreign key to Company could be generated with::
        shop_strategy = from_model(Shop, company=from_model(Company))
    Like for :func:`~hypothesis.strategies.builds`, you can pass
    ``...`` (:obj:`python:Ellipsis`) as a keyword argument to infer a strategy for
    a field which has a default value instead of using the default.
    """
    if not issubclass(model, dm.Model):
        raise InvalidArgument(f"{model=} must be a subtype of Model")
    fields_by_name = {f.name: f for f in model._meta.concrete_fields}
    # Replace explicit `...` markers with strategies inferred from the field.
    for name, value in sorted(field_strategies.items()):
        if value is ...:
            field_strategies[name] = from_field(fields_by_name[name])
    # Infer strategies for required fields the caller didn't cover: not
    # auto-created and without a default value.
    for name, field in sorted(fields_by_name.items()):
        if (
            name not in field_strategies
            and not field.auto_created
            and field.default is dm.fields.NOT_PROVIDED
        ):
            field_strategies[name] = from_field(field)
    for field in field_strategies:
        if model._meta.get_field(field).primary_key:
            # The primary key is generated as part of the strategy. We
            # want to find any existing row with this primary key and
            # overwrite its contents.
            kwargs = {field: field_strategies.pop(field)}
            kwargs["defaults"] = st.fixed_dictionaries(field_strategies)  # type: ignore
            return _models_impl(st.builds(model.objects.update_or_create, **kwargs))
    # The primary key is not generated as part of the strategy, so we
    # just match against any row that has the same value for all
    # fields.
    return _models_impl(st.builds(model.objects.get_or_create, **field_strategies))
@st.composite
def _models_impl(draw, strat):
    """Draw a saved model instance, rejecting examples the database refuses."""
    try:
        # get_or_create / update_or_create return an (instance, created) pair.
        instance, _created = draw(strat)
        return instance
    except IntegrityError:
        # Let Hypothesis retry with different field values.
        reject()
@defines_strategy()
def from_form(
    form: Type[df.Form],
    form_kwargs: Optional[dict] = None,
    **field_strategies: Union[st.SearchStrategy, EllipsisType],
) -> st.SearchStrategy[df.Form]:
    """Return a strategy for examples of ``form``.
    ``form`` must be an subclass of :class:`~django:django.forms.Form`.
    Strategies for fields may be passed as keyword arguments, for example
    ``is_staff=st.just(False)``.
    Hypothesis can often infer a strategy based the field type and validators,
    and will attempt to do so for any required fields. No strategy will be
    inferred for a disabled field or field for which a keyword argument
    is passed to ``from_form()``.
    This function uses the fields of an unbound ``form`` instance to determine
    field strategies, any keyword arguments needed to instantiate the unbound
    ``form`` instance can be passed into ``from_form()`` as a dict with the
    keyword ``form_kwargs``. E.g.::
        shop_strategy = from_form(Shop, form_kwargs={"company_id": 5})
    Like for :func:`~hypothesis.strategies.builds`, you can pass
    ``...`` (:obj:`python:Ellipsis`) as a keyword argument to infer a strategy for
    a field which has a default value instead of using the default.
    """
    # currently unsupported:
    # ComboField
    # FilePathField
    # FileField
    # ImageField
    form_kwargs = form_kwargs or {}
    if not issubclass(form, df.BaseForm):
        raise InvalidArgument(f"{form=} must be a subtype of Form")
    # Forms are a little bit different from models. Model classes have
    # all their fields defined, whereas forms may have different fields
    # per-instance. So, we ought to instantiate the form and get the
    # fields from the instance, thus we need to accept the kwargs for
    # instantiation as well as the explicitly defined strategies
    unbound_form = form(**form_kwargs)
    fields_by_name = {}
    for name, field in unbound_form.fields.items():
        if isinstance(field, df.MultiValueField):
            # PS: So this is a little strange, but MultiValueFields must
            # have their form data encoded in a particular way for the
            # values to actually be picked up by the widget instances'
            # ``value_from_datadict``.
            # E.g. if a MultiValueField named 'mv_field' has 3
            # sub-fields then the ``value_from_datadict`` will look for
            # 'mv_field_0', 'mv_field_1', and 'mv_field_2'. Here I'm
            # decomposing the individual sub-fields into the names that
            # the form validation process expects
            for i, _field in enumerate(field.fields):
                fields_by_name[f"{name}_{i}"] = _field
        else:
            fields_by_name[name] = field
    # Replace explicit `...` markers with strategies inferred from the field.
    for name, value in sorted(field_strategies.items()):
        if value is ...:
            field_strategies[name] = from_field(fields_by_name[name])
    # Infer strategies for remaining enabled fields the caller didn't cover.
    for name, field in sorted(fields_by_name.items()):
        if name not in field_strategies and not field.disabled:
            field_strategies[name] = from_field(field)
    return _forms_impl(
        st.builds(
            partial(form, **form_kwargs),
            data=st.fixed_dictionaries(field_strategies),  # type: ignore
        )
    )
@st.composite
def _forms_impl(draw, strat):
    """Draw a bound form, rejecting examples whose data fails validation."""
    try:
        bound = draw(strat)
    except ValidationError:
        # Let Hypothesis retry with different form data.
        reject()
    else:
        return bound

View File

@@ -0,0 +1,53 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
-----------------------
hypothesis[dpcontracts]
-----------------------
This module provides tools for working with the :pypi:`dpcontracts` library,
because `combining contracts and property-based testing works really well
<https://hillelwayne.com/talks/beyond-unit-tests/>`_.
It requires ``dpcontracts >= 0.4``.
"""
from dpcontracts import PreconditionError
from hypothesis import reject
from hypothesis.errors import InvalidArgument
from hypothesis.internal.reflection import proxies
def fulfill(contract_func):
    """Decorate ``contract_func`` to reject calls which violate preconditions,
    and retry them with different arguments.

    This is a convenience function for testing internal code that uses
    :pypi:`dpcontracts`, to automatically filter out arguments that would be
    rejected by the public interface before triggering a contract error.

    This can be used as ``builds(fulfill(func), ...)`` or in the body of the
    test e.g. ``assert fulfill(func)(*args)``.
    """
    # dpcontracts attaches this attribute to every function it wraps; its
    # absence means there are no preconditions for us to filter on.
    if not hasattr(contract_func, "__contract_wrapped_func__"):
        raise InvalidArgument(
            f"{contract_func.__name__} has no dpcontracts preconditions"
        )

    @proxies(contract_func)
    def wrapped(*args, **kwargs):
        try:
            return contract_func(*args, **kwargs)
        except PreconditionError:
            # A violated precondition means these arguments were invalid
            # inputs, not a bug - ask Hypothesis for a different example.
            reject()

    return wrapped

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,217 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
----------------
hypothesis[lark]
----------------
This extra can be used to generate strings matching any context-free grammar,
using the `Lark parser library <https://github.com/lark-parser/lark>`_.
It currently only supports Lark's native EBNF syntax, but we plan to extend
this to support other common syntaxes such as ANTLR and :rfc:`5234` ABNF.
Lark already `supports loading grammars
<https://lark-parser.readthedocs.io/en/latest/nearley.html>`_
from `nearley.js <https://nearley.js.org/>`_, so you may not have to write
your own at all.
"""
from inspect import signature
from typing import Dict, Optional
import lark
from lark.grammar import NonTerminal, Terminal
from hypothesis import strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.utils import calc_label_from_name
from hypothesis.internal.validation import check_type
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
__all__ = ["from_lark"]
def get_terminal_names(terminals, rules, ignore_names):
    """Get names of all terminals in the grammar.

    The arguments are the results of calling ``Lark.grammar.compile()``,
    so you would think that the ``terminals`` and ``ignore_names`` would
    have it all... but they omit terminals created with ``@declare``,
    which appear only in the expansion(s) of nonterminals.
    """
    names = set(ignore_names)
    names.update(t.name for t in terminals)
    # Sweep every rule expansion for Terminal symbols declared only there.
    for rule in rules:
        for symbol in rule.expansion:
            if isinstance(symbol, Terminal):
                names.add(symbol.name)
    return names
class LarkStrategy(st.SearchStrategy):
    """Low-level strategy implementation wrapping a Lark grammar.
    See ``from_lark`` for details.
    """
    def __init__(self, grammar, start, explicit):
        assert isinstance(grammar, lark.lark.Lark)
        if start is None:
            start = grammar.options.start
        if not isinstance(start, list):
            start = [start]
        self.grammar = grammar
        # This is a total hack, but working around the changes is a nicer user
        # experience than breaking for anyone who doesn't instantly update their
        # installation of Lark alongside Hypothesis.
        compile_args = signature(grammar.grammar.compile).parameters
        if "terminals_to_keep" in compile_args:
            terminals, rules, ignore_names = grammar.grammar.compile(start, ())
        elif "start" in compile_args:  # pragma: no cover
            # Support lark <= 0.10.0, without the terminals_to_keep argument.
            terminals, rules, ignore_names = grammar.grammar.compile(start)
        else:  # pragma: no cover
            # This branch is to support lark <= 0.7.1, without the start argument.
            terminals, rules, ignore_names = grammar.grammar.compile()
        # Map every symbol name (nonterminal origins and terminals) to the
        # symbol object we will use when drawing.
        self.names_to_symbols = {}
        for r in rules:
            t = r.origin
            self.names_to_symbols[t.name] = t
        for t in terminals:
            self.names_to_symbols[t.name] = Terminal(t.name)
        self.start = st.sampled_from([self.names_to_symbols[s] for s in start])
        self.ignored_symbols = tuple(self.names_to_symbols[n] for n in ignore_names)
        # Each terminal is generated by matching its own regexp exactly.
        self.terminal_strategies = {
            t.name: st.from_regex(t.pattern.to_regexp(), fullmatch=True)
            for t in terminals
        }
        unknown_explicit = set(explicit) - get_terminal_names(
            terminals, rules, ignore_names
        )
        if unknown_explicit:
            raise InvalidArgument(
                "The following arguments were passed as explicit_strategies, "
                "but there is no such terminal production in this grammar: "
                + repr(sorted(unknown_explicit))
            )
        # Caller-supplied strategies override the regex-derived ones.
        self.terminal_strategies.update(explicit)
        # Group alternative expansions by nonterminal name; sorting by length
        # puts shorter expansions first so examples shrink towards them.
        nonterminals = {}
        for rule in rules:
            nonterminals.setdefault(rule.origin.name, []).append(tuple(rule.expansion))
        for v in nonterminals.values():
            v.sort(key=len)
        self.nonterminal_strategies = {
            k: st.sampled_from(v) for k, v in nonterminals.items()
        }
        # Lazily-populated cache of example labels, one per rule name.
        self.__rule_labels = {}
    def do_draw(self, data):
        # Recursively expand a drawn start symbol, accumulating terminal
        # strings into `state`, then join them into the final output.
        state = []
        start = data.draw(self.start)
        self.draw_symbol(data, start, state)
        return "".join(state)
    def rule_label(self, name):
        # Cache the label so we only compute it once per rule name.
        try:
            return self.__rule_labels[name]
        except KeyError:
            return self.__rule_labels.setdefault(
                name, calc_label_from_name(f"LARK:{name}")
            )
    def draw_symbol(self, data, symbol, draw_state):
        """Append text for ``symbol`` to ``draw_state``, recursing on rules."""
        if isinstance(symbol, Terminal):
            try:
                strategy = self.terminal_strategies[symbol.name]
            except KeyError:
                raise InvalidArgument(
                    "Undefined terminal %r. Generation does not currently support "
                    "use of %%declare unless you pass `explicit`, a dict of "
                    'names-to-strategies, such as `{%r: st.just("")}`'
                    % (symbol.name, symbol.name)
                ) from None
            draw_state.append(data.draw(strategy))
        else:
            assert isinstance(symbol, NonTerminal)
            data.start_example(self.rule_label(symbol.name))
            expansion = data.draw(self.nonterminal_strategies[symbol.name])
            for e in expansion:
                self.draw_symbol(data, e, draw_state)
                # Optionally emit ignorable tokens (e.g. whitespace) between
                # the symbols of an expansion.
                self.gen_ignore(data, draw_state)
            data.stop_example()
    def gen_ignore(self, data, draw_state):
        # With probability 1/4, inject one of the grammar's %ignore symbols.
        if self.ignored_symbols and data.draw_boolean(1 / 4):
            emit = data.draw(st.sampled_from(self.ignored_symbols))
            self.draw_symbol(data, emit, draw_state)
    def calc_has_reusable_values(self, recur):
        # Generated strings are immutable, so examples can be safely reused.
        return True
def check_explicit(name):
    """Build a validator asserting that values drawn from ``name`` are strings."""

    def checker(value):
        check_type(str, value, "value drawn from " + name)
        return value

    return checker
@cacheable
@defines_strategy(force_reusable_values=True)
def from_lark(
    grammar: lark.lark.Lark,
    *,
    start: Optional[str] = None,
    explicit: Optional[Dict[str, st.SearchStrategy[str]]] = None,
) -> st.SearchStrategy[str]:
    """A strategy for strings accepted by the given context-free grammar.
    ``grammar`` must be a ``Lark`` object, which wraps an EBNF specification.
    The Lark EBNF grammar reference can be found
    `here <https://lark-parser.readthedocs.io/en/latest/grammar.html>`_.
    ``from_lark`` will automatically generate strings matching the
    nonterminal ``start`` symbol in the grammar, which was supplied as an
    argument to the Lark class. To generate strings matching a different
    symbol, including terminals, you can override this by passing the
    ``start`` argument to ``from_lark``. Note that Lark may remove unreachable
    productions when the grammar is compiled, so you should probably pass the
    same value for ``start`` to both.
    Currently ``from_lark`` does not support grammars that need custom lexing.
    Any lexers will be ignored, and any undefined terminals from the use of
    ``%declare`` will result in generation errors. To define strategies for
    such terminals, pass a dictionary mapping their name to a corresponding
    strategy as the ``explicit`` argument.
    The :pypi:`hypothesmith` project includes a strategy for Python source,
    based on a grammar and careful post-processing.
    """
    check_type(lark.lark.Lark, grammar, "grammar")
    if explicit is None:
        explicit = {}
    else:
        check_type(dict, explicit, "explicit")
        # Wrap each user strategy so that non-string draws raise a helpful
        # error naming the offending entry, instead of failing downstream.
        explicit = {
            k: v.map(check_explicit(f"explicit[{k!r}]={v!r}"))
            for k, v in explicit.items()
        }
    return LarkStrategy(grammar, start, explicit)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from hypothesis.extra.pandas.impl import (
column,
columns,
data_frames,
indexes,
range_indexes,
series,
)
__all__ = ["indexes", "range_indexes", "series", "column", "columns", "data_frames"]

View File

@@ -0,0 +1,756 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from collections import OrderedDict, abc
from copy import copy
from datetime import datetime, timedelta
from typing import Any, List, Optional, Sequence, Set, Union
import attr
import numpy as np
import pandas
from hypothesis import strategies as st
from hypothesis._settings import note_deprecation
from hypothesis.control import reject
from hypothesis.errors import InvalidArgument
from hypothesis.extra import numpy as npst
from hypothesis.internal.conjecture import utils as cu
from hypothesis.internal.coverage import check, check_function
from hypothesis.internal.reflection import get_pretty_function_description
from hypothesis.internal.validation import (
check_type,
check_valid_interval,
check_valid_size,
try_convert,
)
from hypothesis.strategies._internal.strategies import Ex, check_strategy
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
try:
from pandas.core.arrays.integer import IntegerDtype
except ImportError:
IntegerDtype = ()
def dtype_for_elements_strategy(s):
    """Return a shared strategy for the dtype pandas infers for values of ``s``."""

    def infer_dtype(value):
        # Ask pandas what dtype a one-element Series of this value gets.
        return pandas.Series([value]).dtype

    cache_key = ("hypothesis.extra.pandas.dtype_for_elements_strategy", s)
    return st.shared(s.map(infer_dtype), key=cache_key)
def infer_dtype_if_necessary(dtype, values, elements, draw):
    """Return ``dtype`` unchanged, unless it is None and no values were
    drawn - in which case infer one from the ``elements`` strategy."""
    if dtype is not None or values:
        return dtype
    return draw(dtype_for_elements_strategy(elements))
@check_function
def elements_and_dtype(elements, dtype, source=None):
    """Validate and normalise an (elements, dtype) pair, returning both.
    Either argument may be None, but not both; validation errors mention
    ``source`` (e.g. "columns[0].") so the caller's argument is identified.
    """
    if source is None:
        prefix = ""
    else:
        prefix = f"{source}."
    if elements is not None:
        check_strategy(elements, f"{prefix}elements")
    else:
        with check("dtype is not None"):
            if dtype is None:
                raise InvalidArgument(
                    f"At least one of {prefix}elements or {prefix}dtype must be provided."
                )
    with check("isinstance(dtype, CategoricalDtype)"):
        if pandas.api.types.CategoricalDtype.is_dtype(dtype):
            raise InvalidArgument(
                f"{prefix}dtype is categorical, which is currently unsupported"
            )
    # Reject nullable-integer dtype *classes* (e.g. pandas.Int64Dtype rather
    # than pandas.Int64Dtype()) - np.dtype would treat them as object.
    if isinstance(dtype, type) and issubclass(dtype, IntegerDtype):
        raise InvalidArgument(
            f"Passed {dtype=} is a dtype class, please pass in an instance of this class."
            "Otherwise it would be treated as dtype=object"
        )
    # A type that numpy maps to object (and isn't literally `object`) is
    # almost certainly a mistake, e.g. passing `datetime` directly.
    if isinstance(dtype, type) and np.dtype(dtype).kind == "O" and dtype is not object:
        err_msg = f"Passed {dtype=} is not a valid Pandas dtype."
        if issubclass(dtype, datetime):
            err_msg += ' To generate valid datetimes, pass `dtype="datetime64[ns]"`'
            raise InvalidArgument(err_msg)
        elif issubclass(dtype, timedelta):
            err_msg += ' To generate valid timedeltas, pass `dtype="timedelta64[ns]"`'
            raise InvalidArgument(err_msg)
        note_deprecation(
            f"{err_msg} We'll treat it as "
            "dtype=object for now, but this will be an error in a future version.",
            since="2021-12-31",
            has_codemod=False,
            stacklevel=1,
        )
    if isinstance(dtype, st.SearchStrategy):
        raise InvalidArgument(
            f"Passed {dtype=} is a strategy, but we require a concrete dtype "
            "here. See https://stackoverflow.com/q/74355937 for workaround patterns."
        )
    # Resolve string names like "Int64" to nullable-integer dtype instances;
    # `list` is the fallback when IntegerDtype has no __subclasses__.
    _get_subclasses = getattr(IntegerDtype, "__subclasses__", list)
    dtype = {t.name: t() for t in _get_subclasses()}.get(dtype, dtype)
    if isinstance(dtype, IntegerDtype):
        # Nullable integer: generate with the plain numpy dtype plus None.
        is_na_dtype = True
        dtype = np.dtype(dtype.name.lower())
    elif dtype is not None:
        is_na_dtype = False
        dtype = try_convert(np.dtype, dtype, "dtype")
    else:
        is_na_dtype = False
    if elements is None:
        elements = npst.from_dtype(dtype)
        if is_na_dtype:
            elements = st.none() | elements
    elif dtype is not None:
        # User supplied both: coerce each drawn element to the dtype, with a
        # helpful error if the value can't be represented.
        def convert_element(value):
            if is_na_dtype and value is None:
                return None
            name = f"draw({prefix}elements)"
            try:
                return np.array([value], dtype=dtype)[0]
            except (TypeError, ValueError):
                raise InvalidArgument(
                    "Cannot convert %s=%r of type %s to dtype %s"
                    % (name, value, type(value).__name__, dtype.str)
                ) from None
        elements = elements.map(convert_element)
    assert elements is not None
    return elements, dtype
class ValueIndexStrategy(st.SearchStrategy):
    """Strategy generating a :class:`pandas.Index` of drawn element values,
    with optional uniqueness and a dtype inferred when none was given."""
    def __init__(self, elements, dtype, min_size, max_size, unique, name):
        super().__init__()
        self.elements = elements          # strategy for individual values
        self.dtype = dtype                # target dtype, or None to infer
        self.min_size = min_size
        self.max_size = max_size
        self.unique = unique              # whether values must be distinct
        self.name = name                  # strategy for the index's name
    def do_draw(self, data):
        result = []
        seen = set()
        iterator = cu.many(
            data,
            min_size=self.min_size,
            max_size=self.max_size,
            average_size=(self.min_size + self.max_size) / 2,
        )
        while iterator.more():
            elt = data.draw(self.elements)
            if self.unique:
                # Duplicates don't count towards the drawn size.
                if elt in seen:
                    iterator.reject()
                    continue
                seen.add(elt)
            result.append(elt)
        dtype = infer_dtype_if_necessary(
            dtype=self.dtype, values=result, elements=self.elements, draw=data.draw
        )
        # tupleize_cols=False keeps tuple elements as plain values rather
        # than turning them into a MultiIndex.
        return pandas.Index(
            result, dtype=dtype, tupleize_cols=False, name=data.draw(self.name)
        )
# Default cap on generated index/series lengths when no max_size is given.
DEFAULT_MAX_SIZE = 10
@cacheable
@defines_strategy()
def range_indexes(
    min_size: int = 0,
    max_size: Optional[int] = None,
    name: st.SearchStrategy[Optional[str]] = st.none(),
) -> st.SearchStrategy[pandas.RangeIndex]:
    """Provides a strategy which generates an :class:`~pandas.Index` whose
    values are 0, 1, ..., n for some n.
    Arguments:
    * min_size is the smallest number of elements the index can have.
    * max_size is the largest number of elements the index can have. If None
      it will default to some suitable value based on min_size.
    * name is the name of the index. If st.none(), the index will have no name.
    """
    check_valid_size(min_size, "min_size")
    check_valid_size(max_size, "max_size")
    if max_size is None:
        # 2**63 - 1 is the largest length a RangeIndex can represent.
        max_size = min([min_size + DEFAULT_MAX_SIZE, 2**63 - 1])
    check_valid_interval(min_size, max_size, "min_size", "max_size")
    # Pass the argument name so a validation error identifies which
    # parameter was bad, consistent with check_strategy(index, "index")
    # usage elsewhere in this module.
    check_strategy(name, "name")
    return st.builds(pandas.RangeIndex, st.integers(min_size, max_size), name=name)
@cacheable
@defines_strategy()
def indexes(
    *,
    elements: Optional[st.SearchStrategy[Ex]] = None,
    dtype: Any = None,
    min_size: int = 0,
    max_size: Optional[int] = None,
    unique: bool = True,
    name: st.SearchStrategy[Optional[str]] = st.none(),
) -> st.SearchStrategy[pandas.Index]:
    """Provides a strategy for producing a :class:`pandas.Index`.
    Arguments:
    * elements is a strategy which will be used to generate the individual
      values of the index. If None, it will be inferred from the dtype. Note:
      even if the elements strategy produces tuples, the generated value
      will not be a MultiIndex, but instead be a normal index whose elements
      are tuples.
    * dtype is the dtype of the resulting index. If None, it will be inferred
      from the elements strategy. At least one of dtype or elements must be
      provided.
    * min_size is the minimum number of elements in the index.
    * max_size is the maximum number of elements in the index. If None then it
      will default to a suitable small size. If you want larger indexes you
      should pass a max_size explicitly.
    * unique specifies whether all of the elements in the resulting index
      should be distinct.
    * name is a strategy for strings or ``None``, which will be passed to
      the :class:`pandas.Index` constructor.
    """
    check_valid_size(min_size, "min_size")
    check_valid_size(max_size, "max_size")
    check_valid_interval(min_size, max_size, "min_size", "max_size")
    check_type(bool, unique, "unique")
    # Normalise/validate the elements-vs-dtype combination before drawing.
    elements, dtype = elements_and_dtype(elements, dtype)
    if max_size is None:
        max_size = min_size + DEFAULT_MAX_SIZE
    return ValueIndexStrategy(elements, dtype, min_size, max_size, unique, name)
@defines_strategy()
def series(
    *,
    elements: Optional[st.SearchStrategy[Ex]] = None,
    dtype: Any = None,
    index: Optional[st.SearchStrategy[Union[Sequence, pandas.Index]]] = None,
    fill: Optional[st.SearchStrategy[Ex]] = None,
    unique: bool = False,
    name: st.SearchStrategy[Optional[str]] = st.none(),
) -> st.SearchStrategy[pandas.Series]:
    """Provides a strategy for producing a :class:`pandas.Series`.
    Arguments:
    * elements: a strategy that will be used to generate the individual
      values in the series. If None, we will attempt to infer a suitable
      default from the dtype.
    * dtype: the dtype of the resulting series and may be any value
      that can be passed to :class:`numpy.dtype`. If None, will use
      pandas's standard behaviour to infer it from the type of the elements
      values. Note that if the type of values that comes out of your
      elements strategy varies, then so will the resulting dtype of the
      series.
    * index: If not None, a strategy for generating indexes for the
      resulting Series. This can generate either :class:`pandas.Index`
      objects or any sequence of values (which will be passed to the
      Index constructor).
      You will probably find it most convenient to use the
      :func:`~hypothesis.extra.pandas.indexes` or
      :func:`~hypothesis.extra.pandas.range_indexes` function to produce
      values for this argument.
    * name: is a strategy for strings or ``None``, which will be passed to
      the :class:`pandas.Series` constructor.
    Usage:
    .. code-block:: pycon
        >>> series(dtype=int).example()
        0   -2001747478
        1    1153062837
    """
    if index is None:
        index = range_indexes()
    else:
        check_strategy(index, "index")
    elements, np_dtype = elements_and_dtype(elements, dtype)
    index_strategy = index
    # if it is converted to an object, use object for series type
    if (
        np_dtype is not None
        and np_dtype.kind == "O"
        and not isinstance(dtype, IntegerDtype)
    ):
        dtype = np_dtype
    @st.composite
    def result(draw):
        index = draw(index_strategy)
        if len(index) > 0:
            # Draw the raw values as an object array, regardless of whether a
            # dtype was specified; pandas performs any conversion when the
            # Series is constructed below. (The dtype-given and dtype-inferred
            # cases previously duplicated this draw in two identical branches,
            # one with a redundant list() around the .tolist() result.)
            result_data = draw(
                npst.arrays(
                    dtype=object,
                    elements=elements,
                    shape=len(index),
                    fill=fill,
                    unique=unique,
                )
            ).tolist()
            return pandas.Series(result_data, index=index, dtype=dtype, name=draw(name))
        else:
            # Empty series: there are no values to infer a dtype from, so fall
            # back to asking pandas what dtype single examples would get.
            return pandas.Series(
                (),
                index=index,
                dtype=dtype
                if dtype is not None
                else draw(dtype_for_elements_strategy(elements)),
                name=draw(name),
            )
    return result()
@attr.s(slots=True)
class column:
    """Data object for describing a column in a DataFrame.
    Arguments:
    * name: the column name, or None to default to the column position. Must
      be hashable, but can otherwise be any value supported as a pandas column
      name.
    * elements: the strategy for generating values in this column, or None
      to infer it from the dtype.
    * dtype: the dtype of the column, or None to infer it from the element
      strategy. At least one of dtype or elements must be provided.
    * fill: A default value for elements of the column. See
      :func:`~hypothesis.extra.numpy.arrays` for a full explanation.
    * unique: If all values in this column should be distinct.
    """
    # All attributes default to "unspecified"; data_frames() validates and
    # fills them in via elements_and_dtype / fill_for.
    name = attr.ib(default=None)
    elements = attr.ib(default=None)
    # repr uses the pretty-printer so strategies/dtypes display readably.
    dtype = attr.ib(default=None, repr=get_pretty_function_description)
    fill = attr.ib(default=None)
    unique = attr.ib(default=False)
def columns(
    names_or_number: Union[int, Sequence[str]],
    *,
    dtype: Any = None,
    elements: Optional[st.SearchStrategy[Ex]] = None,
    fill: Optional[st.SearchStrategy[Ex]] = None,
    unique: bool = False,
) -> List[column]:
    """A convenience function for producing a list of :class:`column` objects
    of the same general shape.

    The names_or_number argument is either a sequence of values, the
    elements of which will be used as the name for individual column
    objects, or a number, in which case that many unnamed columns will
    be created. All other arguments are passed through verbatim to
    create the columns.
    """
    if isinstance(names_or_number, (int, float)):
        # A count: create that many anonymous columns (named by position).
        names: List[Union[int, str, None]] = [None] * names_or_number
    else:
        names = list(names_or_number)

    def build(col_name):
        return column(
            name=col_name, dtype=dtype, elements=elements, fill=fill, unique=unique
        )

    return [build(n) for n in names]
@defines_strategy()
def data_frames(
columns: Optional[Sequence[column]] = None,
*,
rows: Optional[st.SearchStrategy[Union[dict, Sequence[Any]]]] = None,
index: Optional[st.SearchStrategy[Ex]] = None,
) -> st.SearchStrategy[pandas.DataFrame]:
"""Provides a strategy for producing a :class:`pandas.DataFrame`.
Arguments:
* columns: An iterable of :class:`column` objects describing the shape
of the generated DataFrame.
* rows: A strategy for generating a row object. Should generate
either dicts mapping column names to values or a sequence mapping
column position to the value in that position (note that unlike the
:class:`pandas.DataFrame` constructor, single values are not allowed
here. Passing e.g. an integer is an error, even if there is only one
column).
At least one of rows and columns must be provided. If both are
provided then the generated rows will be validated against the
columns and an error will be raised if they don't match.
Caveats on using rows:
* In general you should prefer using columns to rows, and only use
rows if the columns interface is insufficiently flexible to
describe what you need - you will get better performance and
example quality that way.
* If you provide rows and not columns, then the shape and dtype of
the resulting DataFrame may vary. e.g. if you have a mix of int
and float in the values for one column in your row entries, the
column will sometimes have an integral dtype and sometimes a float.
* index: If not None, a strategy for generating indexes for the
resulting DataFrame. This can generate either :class:`pandas.Index`
objects or any sequence of values (which will be passed to the
Index constructor).
You will probably find it most convenient to use the
:func:`~hypothesis.extra.pandas.indexes` or
:func:`~hypothesis.extra.pandas.range_indexes` function to produce
values for this argument.
Usage:
The expected usage pattern is that you use :class:`column` and
:func:`columns` to specify a fixed shape of the DataFrame you want as
follows. For example the following gives a two column data frame:
.. code-block:: pycon
>>> from hypothesis.extra.pandas import column, data_frames
>>> data_frames([
... column('A', dtype=int), column('B', dtype=float)]).example()
A B
0 2021915903 1.793898e+232
1 1146643993 inf
2 -2096165693 1.000000e+07
If you want the values in different columns to interact in some way you
can use the rows argument. For example the following gives a two column
DataFrame where the value in the first column is always at most the value
in the second:
.. code-block:: pycon
>>> from hypothesis.extra.pandas import column, data_frames
>>> import hypothesis.strategies as st
>>> data_frames(
... rows=st.tuples(st.floats(allow_nan=False),
... st.floats(allow_nan=False)).map(sorted)
... ).example()
0 1
0 -3.402823e+38 9.007199e+15
1 -1.562796e-298 5.000000e-01
You can also combine the two:
.. code-block:: pycon
>>> from hypothesis.extra.pandas import columns, data_frames
>>> import hypothesis.strategies as st
>>> data_frames(
... columns=columns(["lo", "hi"], dtype=float),
... rows=st.tuples(st.floats(allow_nan=False),
... st.floats(allow_nan=False)).map(sorted)
... ).example()
lo hi
0 9.314723e-49 4.353037e+45
1 -9.999900e-01 1.000000e+07
2 -2.152861e+134 -1.069317e-73
(Note that the column dtype must still be specified and will not be
inferred from the rows. This restriction may be lifted in future).
Combining rows and columns has the following behaviour:
* The column names and dtypes will be used.
* If the column is required to be unique, this will be enforced.
* Any values missing from the generated rows will be provided using the
column's fill.
* Any values in the row not present in the column specification (if
dicts are passed, if there are keys with no corresponding column name,
if sequences are passed if there are too many items) will result in
InvalidArgument being raised.
"""
if index is None:
index = range_indexes()
else:
check_strategy(index, "index")
index_strategy = index
if columns is None:
if rows is None:
raise InvalidArgument("At least one of rows and columns must be provided")
else:
@st.composite
def rows_only(draw):
    # Build the DataFrame purely from the ``rows`` strategy: draw an
    # index, then one row per index entry.  Column labels and dtypes are
    # whatever pandas derives from the row values themselves.
    index = draw(index_strategy)

    @check_function
    def row():
        # Draw a single row and check that it is iterable, so a bad rows
        # strategy fails with a clear error instead of a pandas traceback.
        result = draw(rows)
        check_type(abc.Iterable, result, "draw(row)")
        return result

    if len(index) > 0:
        return pandas.DataFrame([row() for _ in index], index=index)
    else:
        # If we haven't drawn any rows we need to draw one row and
        # then discard it so that we get a consistent shape for the
        # DataFrame.
        base = pandas.DataFrame([row()])
        return base.drop(0)
return rows_only()
assert columns is not None
cols = try_convert(tuple, columns, "columns")
rewritten_columns = []
column_names: Set[str] = set()
for i, c in enumerate(cols):
check_type(column, c, f"columns[{i}]")
c = copy(c)
if c.name is None:
label = f"columns[{i}]"
c.name = i
else:
label = c.name
try:
hash(c.name)
except TypeError:
raise InvalidArgument(
f"Column names must be hashable, but columns[{i}].name was "
f"{c.name!r} of type {type(c.name).__name__}, which cannot be hashed."
) from None
if c.name in column_names:
raise InvalidArgument(f"duplicate definition of column name {c.name!r}")
column_names.add(c.name)
c.elements, _ = elements_and_dtype(c.elements, c.dtype, label)
if c.dtype is None and rows is not None:
raise InvalidArgument(
"Must specify a dtype for all columns when combining rows with columns."
)
c.fill = npst.fill_for(
fill=c.fill, elements=c.elements, unique=c.unique, name=label
)
rewritten_columns.append(c)
if rows is None:
@st.composite
def just_draw_columns(draw):
    # Build the DataFrame from the column specifications alone (no ``rows``
    # strategy): draw an index, then populate each column from its own
    # elements/fill strategies.
    index = draw(index_strategy)
    # Freeze the drawn index so every filled column below shares it.
    local_index_strategy = st.just(index)

    data = OrderedDict((c.name, None) for c in rewritten_columns)

    # Depending on how the columns are going to be generated we group
    # them differently to get better shrinking. For columns with fill
    # enabled, the elements can be shrunk independently of the size,
    # so we can just shrink by shrinking the index then shrinking the
    # length and are generally much more free to move data around.
    # For columns with no filling the problem is harder, and drawing
    # them like that would result in rows being very far apart from
    # each other in the underlying data stream, which gets in the way
    # of shrinking. So what we do is reorder and draw those columns
    # row wise, so that the values of each row are next to each other.
    # This makes life easier for the shrinker when deleting blocks of
    # data.
    columns_without_fill = [c for c in rewritten_columns if c.fill.is_empty]

    if columns_without_fill:
        # Pre-allocate object-typed Series of the right length; cells are
        # assigned one by one in the row-wise loop below.
        for c in columns_without_fill:
            data[c.name] = pandas.Series(
                np.zeros(shape=len(index), dtype=object),
                index=index,
                dtype=c.dtype,
            )
        # Track values already used in each unique column.
        seen = {c.name: set() for c in columns_without_fill if c.unique}

        for i in range(len(index)):
            for c in columns_without_fill:
                if c.unique:
                    # Up to 5 attempts to draw a fresh value; if all
                    # collide, reject() discards this whole test case.
                    for _ in range(5):
                        value = draw(c.elements)
                        if value not in seen[c.name]:
                            seen[c.name].add(value)
                            break
                    else:
                        reject()
                else:
                    value = draw(c.elements)
                try:
                    data[c.name][i] = value
                except ValueError as err:  # pragma: no cover
                    # This just works in Pandas 1.4 and later, but gives
                    # a confusing error on previous versions.
                    if c.dtype is None and not isinstance(
                        value, (float, int, str, bool, datetime, timedelta)
                    ):
                        raise ValueError(
                            f"Failed to add {value=} to column "
                            f"{c.name} with dtype=None. Maybe passing "
                            "dtype=object would help?"
                        ) from err
                    # Unclear how this could happen, but users find a way...
                    raise

    # Columns with a usable fill strategy are drawn wholesale via series(),
    # sharing the index drawn above.
    for c in rewritten_columns:
        if not c.fill.is_empty:
            data[c.name] = draw(
                series(
                    index=local_index_strategy,
                    dtype=c.dtype,
                    elements=c.elements,
                    fill=c.fill,
                    unique=c.unique,
                )
            )

    return pandas.DataFrame(data, index=index)
return just_draw_columns()
else:
@st.composite
def assign_rows(draw):
    # Combine ``rows`` with the column specifications: draw an index, then
    # draw one row per entry, normalising dicts to lists, enforcing
    # per-column uniqueness, and padding short rows from each column's fill.
    index = draw(index_strategy)

    # Pre-allocate the frame with the declared dtypes; rows are written
    # into it positionally via .iloc below.
    result = pandas.DataFrame(
        OrderedDict(
            (
                c.name,
                pandas.Series(
                    np.zeros(dtype=c.dtype, shape=len(index)), dtype=c.dtype
                ),
            )
            for c in rewritten_columns
        ),
        index=index,
    )

    # Cache of fill values already drawn for dict rows, keyed by column
    # position, so a missing key reuses one fill value per column.
    fills = {}

    any_unique = any(c.unique for c in rewritten_columns)
    if any_unique:
        # One "seen" set per unique column (None for non-unique), with the
        # trailing Nones trimmed so zip() below stops early.
        all_seen = [set() if c.unique else None for c in rewritten_columns]
        while all_seen[-1] is None:
            all_seen.pop()

    for row_index in range(len(index)):
        # Up to 5 attempts per row to satisfy uniqueness; if all fail,
        # reject() discards this whole test case.
        for _ in range(5):
            original_row = draw(rows)
            row = original_row
            if isinstance(row, dict):
                # Convert dict rows to a positional list, filling missing
                # columns from their fill strategy (cached in ``fills``).
                as_list = [None] * len(rewritten_columns)
                for i, c in enumerate(rewritten_columns):
                    try:
                        as_list[i] = row[c.name]
                    except KeyError:
                        try:
                            as_list[i] = fills[i]
                        except KeyError:
                            if c.fill.is_empty:
                                raise InvalidArgument(
                                    f"Empty fill strategy in {c!r} cannot "
                                    f"complete row {original_row!r}"
                                ) from None
                            fills[i] = draw(c.fill)
                            as_list[i] = fills[i]
                # Any key that doesn't match a declared column is an error.
                for k in row:
                    if k not in column_names:
                        # NOTE(review): the trailing ')' in this message
                        # looks like a typo, but it is harmless output text.
                        raise InvalidArgument(
                            "Row %r contains column %r not in columns %r)"
                            % (row, k, [c.name for c in rewritten_columns])
                        )
                row = as_list
            if any_unique:
                # Check the candidate row against all unique columns before
                # accepting it; on a collision, retry with a fresh draw.
                has_duplicate = False
                for seen, value in zip(all_seen, row):
                    if seen is None:
                        continue
                    if value in seen:
                        has_duplicate = True
                        break
                    seen.add(value)
                if has_duplicate:
                    continue
            row = list(try_convert(tuple, row, "draw(rows)"))

            if len(row) > len(rewritten_columns):
                raise InvalidArgument(
                    f"Row {original_row!r} contains too many entries. Has "
                    f"{len(row)} but expected at most {len(rewritten_columns)}"
                )
            # Pad short rows from each remaining column's fill strategy.
            while len(row) < len(rewritten_columns):
                c = rewritten_columns[len(row)]
                if c.fill.is_empty:
                    raise InvalidArgument(
                        f"Empty fill strategy in {c!r} cannot "
                        f"complete row {original_row!r}"
                    )
                row.append(draw(c.fill))
            result.iloc[row_index] = row
            break
        else:
            reject()
    return result
return assign_rows()

View File

@@ -0,0 +1,19 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
Stub for users who manually load our pytest plugin.
The plugin implementation is now located in a top-level module outside the main
hypothesis tree, so that Pytest can load the plugin without thereby triggering
the import of Hypothesis itself (and thus loading our own plugins).
"""
from _hypothesis_pytestplugin import * # noqa

View File

@@ -0,0 +1,54 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""
----------------
hypothesis[pytz]
----------------
This module provides :pypi:`pytz` timezones.
You can use this strategy to make
:py:func:`hypothesis.strategies.datetimes` and
:py:func:`hypothesis.strategies.times` produce timezone-aware values.
"""
import datetime as dt
import pytz
from pytz.tzfile import StaticTzInfo # type: ignore # considered private by typeshed
from hypothesis import strategies as st
from hypothesis.strategies._internal.utils import cacheable, defines_strategy
__all__ = ["timezones"]
@cacheable
@defines_strategy()
def timezones() -> st.SearchStrategy[dt.tzinfo]:
    """Any timezone in the Olson database, as a pytz tzinfo object.

    This strategy minimises to UTC, or the smallest possible fixed
    offset, and is designed for use with
    :py:func:`hypothesis.strategies.datetimes`.
    """
    reference = dt.datetime(2000, 1, 1)
    everything = [pytz.timezone(name) for name in pytz.all_timezones]
    # Zones with a constant UTC offset are simpler than DST-observing ones,
    # and the smaller the absolute offset the simpler still - with UTC the
    # simplest of all, so it comes first and becomes the shrink target.
    fixed_offset: list = [pytz.UTC]
    fixed_offset.extend(
        sorted(
            (tz for tz in everything if isinstance(tz, StaticTzInfo)),
            key=lambda tz: abs(tz.utcoffset(reference)),
        )
    )
    # Zones whose UTC offset has varied over time; best ordered by name.
    variable_offset = [tz for tz in everything if tz not in fixed_offset]
    return st.sampled_from(fixed_offset + variable_offset)

View File

@@ -0,0 +1,78 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from contextlib import contextmanager
from datetime import timedelta
from typing import Iterable
from redis import Redis
from hypothesis.database import ExampleDatabase
from hypothesis.internal.validation import check_type
class RedisExampleDatabase(ExampleDatabase):
    """Store Hypothesis examples as sets in the given :class:`~redis.Redis` datastore.

    This is particularly useful for shared databases, as per the recipe
    for a :class:`~hypothesis.database.MultiplexedDatabase`.

    .. note::

        If a test has not been run for ``expire_after``, those examples will be allowed
        to expire. The default time-to-live persists examples between weekly runs.
    """

    def __init__(
        self,
        redis: Redis,
        *,
        expire_after: timedelta = timedelta(days=8),
        key_prefix: bytes = b"hypothesis-example:",
    ):
        # Fail fast on misconfiguration, before any Redis traffic happens.
        check_type(Redis, redis, "redis")
        check_type(timedelta, expire_after, "expire_after")
        check_type(bytes, key_prefix, "key_prefix")
        self.redis = redis
        # Time-to-live applied (and refreshed) on every key we touch.
        self._expire_after = expire_after
        # Prepended to every key, namespacing our entries in the datastore.
        self._prefix = key_prefix

    def __repr__(self) -> str:
        return (
            f"RedisExampleDatabase({self.redis!r}, expire_after={self._expire_after!r})"
        )

    @contextmanager
    def _pipeline(self, *reset_expire_keys, transaction=False, auto_execute=True):
        # Context manager to batch updates and expiry reset, reducing TCP roundtrips.
        # After the caller's commands are queued, a TTL refresh is appended for
        # each key in ``reset_expire_keys``; the batch is sent on exit unless
        # ``auto_execute=False``, in which case the caller runs ``pipe.execute()``.
        pipe = self.redis.pipeline(transaction=transaction)
        yield pipe
        for key in reset_expire_keys:
            pipe.expire(self._prefix + key, self._expire_after)
        if auto_execute:
            pipe.execute()

    def fetch(self, key: bytes) -> Iterable[bytes]:
        # Queue the read, let _pipeline append the TTL refresh, then execute
        # both ourselves and lazily yield the stored examples.
        with self._pipeline(key, auto_execute=False) as pipe:
            pipe.smembers(self._prefix + key)
        yield from pipe.execute()[0]

    def save(self, key: bytes, value: bytes) -> None:
        # Redis sets give us free deduplication of saved examples.
        with self._pipeline(key) as pipe:
            pipe.sadd(self._prefix + key, value)

    def delete(self, key: bytes, value: bytes) -> None:
        with self._pipeline(key) as pipe:
            pipe.srem(self._prefix + key, value)

    def move(self, src: bytes, dest: bytes, value: bytes) -> None:
        # Remove-then-add in one pipeline; TTL is refreshed on both keys.
        with self._pipeline(src, dest) as pipe:
            pipe.srem(self._prefix + src, value)
            pipe.sadd(self._prefix + dest, value)