# This file is part of Pynguin.
#
# SPDX-FileCopyrightText: 2019–2026 Pynguin Contributors
#
# SPDX-License-Identifier: MIT
#
"""Provides analyses for the subject module, based on the module and its AST."""
from __future__ import annotations
import abc
import builtins
import dataclasses
import enum
import functools
import importlib
import inspect
import itertools
import json
import logging
import queue
import types
import typing
from collections import defaultdict
from pathlib import Path
from types import (
BuiltinFunctionType,
FunctionType,
GenericAlias,
MethodDescriptorType,
ModuleType,
WrapperDescriptorType,
)
from typing import Any
import astroid
from astroid.nodes import Assign, AsyncFunctionDef, ClassDef, FunctionDef, Lambda, Module
import pynguin.configuration as config
import pynguin.utils.statistics.stats as stat
import pynguin.utils.typetracing as tt
from pynguin.analyses.type_inference import (
ANY_STR,
HintInference,
InferenceProvider,
LLMInference,
NoInference,
TypeEvalPyInference,
)
from pynguin.utils.llm import LLMProvider
if config.configuration.pynguinml.ml_testing_enabled or typing.TYPE_CHECKING:
import pynguin.utils.pynguinml.ml_testing_resources as tr
from pynguin.analyses.generator import GeneratorProvider, RandomGeneratorProvider
from pynguin.analyses.modulecomplexity import mccabe_complexity
from pynguin.analyses.syntaxtree import (
FunctionDescription,
astroid_to_ast,
get_class_node_from_ast,
get_function_description,
get_function_node_from_ast,
)
from pynguin.analyses.typesystem import (
ANY,
AnyType,
Instance,
NoneType,
ProperType,
TupleType,
TypeInfo,
TypeSystem,
TypeVisitor,
UnionType,
Unsupported,
)
from pynguin.configuration import TypeInferenceStrategy
from pynguin.ga.operators.selection import RandomSelection, RankSelection
from pynguin.utils import randomness
from pynguin.utils.exceptions import (
ConstraintValidationError,
ConstructionFailedException,
CoroutineFoundException,
)
from pynguin.utils.generic.genericaccessibleobject import (
GenericAccessibleObject,
GenericCallableAccessibleObject,
GenericConstructor,
GenericEnum,
GenericFunction,
GenericMethod,
)
from pynguin.utils.orderedset import OrderedSet
from pynguin.utils.statistics.runtimevariable import RuntimeVariable
from pynguin.utils.type_utils import COLLECTIONS, PRIMITIVES, get_class_that_defined_method
from pynguin.utils.typeevalpy_json_schema import ParsedTypeEvalPyData, parse_json, provide_json
if config.configuration.pynguinml.ml_testing_enabled or typing.TYPE_CHECKING:
import pynguin.utils.pynguinml.ml_testing_resources as tr
if typing.TYPE_CHECKING:
from collections.abc import Callable, Sequence
import pynguin.ga.algorithms.archive as arch
import pynguin.ga.computations as ff
from pynguin.instrumentation.tracer import SubjectProperties
from pynguin.utils.pynguinml.mlparameter import MLParameter
AstroidFunctionDef: typing.TypeAlias = AsyncFunctionDef | FunctionDef
LOGGER = logging.getLogger(__name__)
# A set of modules that shall be blacklisted from analysis (keep them sorted to ease
# future manipulations or looking up module names of this set!!!):
# The modules that are listed here are not prohibited from execution, but Pynguin will
# not consider any classes or functions from these modules for generating inputs to
# other routines
MODULE_BLACKLIST = frozenset((
"__future__",
"_frozen_importlib",
"_thread",
"abc",
"argparse",
"asyncio",
"atexit",
"builtins",
"cmd",
"code",
"codeop",
"collections.abc",
"compileall",
"concurrent",
"concurrent.futures",
"configparser",
"contextlib",
"contextvars",
"copy",
"copyreg",
"csv",
"ctypes",
"dbm",
"dis",
"filecmp",
"fileinput",
"fnmatch",
"functools",
"gc",
"getopt",
"getpass",
"glob",
"importlib",
"io",
"itertools",
"linecache",
"logging",
"logging.config",
"logging.handlers",
"marshal",
"mmap",
"multiprocessing",
"multiprocessing.shared_memory",
"netrc",
"operator",
"os",
"os.path",
"pathlib",
"pickle",
"pickletools",
"plistlib",
"py_compile",
"queue",
"random",
"reprlib",
"sched",
"secrets",
"select",
"selectors",
"shelve",
"shutil",
"signal",
"six", # Not from STDLIB
"socket",
"sre_compile",
"sre_parse",
"ssl",
"stat",
"subprocess",
"sys",
"tarfile",
"tempfile",
"threading",
"timeit",
"trace",
"traceback",
"tracemalloc",
"types",
"typing",
"warnings",
"weakref",
))
# Blacklist for methods.
METHOD_BLACKLIST = frozenset(("time.sleep",))
def _is_blacklisted(element: Any) -> bool:
"""Checks if the given element belongs to the blacklist.
Args:
element: The element to check
Returns:
Is the element blacklisted?
"""
module_blacklist = set(MODULE_BLACKLIST).union(config.configuration.ignore_modules)
method_blacklist = set(METHOD_BLACKLIST).union(config.configuration.ignore_methods)
try:
if inspect.ismodule(element):
return element.__name__ in module_blacklist
if inspect.isclass(element):
if element.__module__ == "builtins" and (
element in PRIMITIVES or element in COLLECTIONS
):
# Allow some builtin types
return False
return element.__module__ in module_blacklist
if inspect.isfunction(element):
# Some modules can be run standalone using a main function or provide a small
# set of tests ('test'). We don't want to include those functions.
# Importing certain modules such as inspect, that use or import C-functions can
# lead to __module__ being None. We want to exclude these functions as well.
return (
element.__module__ is None
or element.__module__ in module_blacklist
or element.__qualname__.startswith((
"main",
"test",
))
or f"{element.__module__}.{element.__qualname__}" in method_blacklist
)
except Exception: # noqa: BLE001
LOGGER.warning(
"Could not check if %s is blacklisted. Assuming it is not.", element, exc_info=True
)
# Something that is not supported yet.
return False
C_MODULE_WHITELIST = frozenset((
# === Basic C modules (interpreter startup) ===
"abc",
"ast",
"codecs",
"collections",
"enum",
"functools",
"imp",
"io",
"locale",
"operator",
"signal",
"sitebuiltins",
"stat",
"thread",
"tracemalloc",
"weakref",
"builtins",
"errno",
"marshal",
"sys",
"time",
"sre",
"symtable",
"warnings",
"string",
"re",
"inspect",
"tokenize",
# === Common Data Structures & Algorithms ===
"array",
"bisect",
"heapq",
"itertools",
# === Math & Random ===
"math",
"random",
"statistics",
# === Data Serialization & Formats ===
"csv",
"json",
"pickle",
"struct",
"elementtree",
"pyexpat",
"binascii",
# === Hashing & Cryptography ===
"hashlib",
"ssl",
"blake2",
"md5",
"sha3",
"unicodedata",
# === Compression ===
"zlib",
"bz2",
"lzma",
# === Concurrency & Interoperability ===
# "multiprocessing", # Explicitly not included
# "ctypes", # Explicitly not included
# "asyncio", # Explicitly not included
# === Networking ===
"socket",
"select",
# === Database ===
"sqlite3",
# === GUI ===
"tkinter",
# === Introspection & Debugging ===
"gc",
"faulthandler",
# === Platform: POSIX/Unix-like ===
"posix",
# "posixsubprocess", # Explicitly not included
"fcntl",
"grp",
"pwd",
"resource",
"termios",
# === Platform: Windows ===
"winapi",
"msvcrt",
# === Platform: macOS ===
"scproxy",
# === Platform: Cross-platform ===
"mmap",
# --- Other ---
"queue",
"decimal",
"uuid",
"datetime",
"zoneinfo",
"shlex",
"calendar",
"yaml",
"email",
"syslog",
"dataclasses",
"pprint",
"difflib",
"cmath",
"hmac",
))
def _c_is_whitelisted(element: ModuleType) -> bool:
"""Checks if the given element belongs to the C module whitelist.
Args:
element: The element to check
Returns:
Is the element whitelisted?
"""
c_module_whitelist = set(C_MODULE_WHITELIST)
try:
module_name = element.__name__
top_level = module_name.split(".")[0]
public_top_level = top_level.lstrip("_")
return public_top_level in c_module_whitelist
except Exception: # noqa: BLE001
LOGGER.warning(
"Could not check if %s is whitelisted. Assuming it is not.", element, exc_info=True
)
# Something that is not supported yet.
return False
def _handle_c_modules(
c_extensions: set[str],
) -> None:
"""Handles the C extensions in the subject module.
Args:
c_extensions: The set of C extensions.
"""
subprocess_mode_recommended = len(c_extensions) > 0
if config.configuration.subprocess_if_recommended:
config.configuration.subprocess = subprocess_mode_recommended
if config.configuration.subprocess:
LOGGER.info(
"Subprocess mode is set to %s because the subject module uses "
"the following C extensions: %s. ",
config.configuration.subprocess,
", ".join(sorted(c_extensions)),
)
else:
LOGGER.debug(
"Subprocess mode is set to %s because the subject module does not use "
"any C extensions. ",
config.configuration.subprocess,
)
elif not config.configuration.subprocess and subprocess_mode_recommended:
LOGGER.warning(
"You are using threaded execution mode, but the subject module "
"uses the following C extensions: %s. "
"This may lead to unexpected behavior, consider using "
"subprocess mode instead.",
", ".join(sorted(c_extensions)),
)
# Store the discovered C extensions in the statistics
stat.track_output_variable(RuntimeVariable.CExtensionModules, str(sorted(c_extensions)))
stat.track_output_variable(RuntimeVariable.SubprocessMode, str(config.configuration.subprocess))
@dataclasses.dataclass
class _ModuleParseResult:
"""A data wrapper for an imported and parsed module."""
linenos: int
module_name: str
module: ModuleType
syntax_tree: Module | None
[docs]
def import_module(module_name: str) -> ModuleType:
"""Imports a module by name.
Unlike the built-in :py:func:`importlib.import_module`, this function also supports
importing module aliases.
Args:
module_name: The fully-qualified name of the module
Returns:
The imported module
"""
try:
return importlib.import_module(module_name)
except ModuleNotFoundError as error:
try:
package_name, submodule_name = module_name.rsplit(".", 1)
except ValueError as e:
raise error from e
try:
package = import_module(package_name)
except ModuleNotFoundError as e:
raise error from e
try:
submodule = getattr(package, submodule_name)
except AttributeError as e:
raise error from e
if not inspect.ismodule(submodule):
raise error
return submodule
[docs]
def read_module_ast(module_path: str, module_name: str) -> tuple[Module, str]:
"""Reads the AST of the module and returns it along with its source code.
Args:
module_path: The path of the module.
module_name: The name of the module.
Raises:
OSError: if the module file cannot be read.
AstroidError: if an error occurs during the creation of the AST.
Returns:
A tuple containing the AST and the source code.
"""
source_code = Path(module_path).read_text(encoding="utf-8")
syntax_tree = astroid.parse(code=source_code, module_name=module_name, path=module_path)
return syntax_tree, source_code
[docs]
def parse_module(module_name: str) -> _ModuleParseResult:
"""Parses a module and extracts its module-type and AST.
If the source code is not available it is not possible to build an AST. In this
case the respective field of the :py:class:`_ModuleParseResult` will contain the
value ``None``. This is the case, for example, for modules written in native code,
for example, in C.
Args:
module_name: The fully-qualified name of the module
Returns:
A tuple of the imported module type and its optional AST
"""
module = import_module(module_name)
syntax_tree: Module | None = None
linenos: int = -1
try:
module_path = inspect.getsourcefile(module)
assert module_path is not None, f"Could not determine the path of module {module}"
syntax_tree, source_code = read_module_ast(module_path, module_name)
except (
TypeError, # from `inspect.getsourcefile`
AssertionError, # from `assert`
OSError,
astroid.AstroidError,
) as error:
LOGGER.debug(
f"Could not retrieve source code for module {module_name} " # noqa: G004
f"({error}). "
f"Cannot derive syntax tree to allow Pynguin using more precise analysis."
)
else:
linenos = len(source_code.splitlines())
return _ModuleParseResult(
linenos=linenos,
module_name=module_name,
module=module,
syntax_tree=syntax_tree,
)
[docs]
class TestCluster(abc.ABC): # noqa: PLR0904
"""Interface for a test cluster."""
@property
@abc.abstractmethod
def type_system(self) -> TypeSystem:
"""Provides the inheritance graph."""
@property
@abc.abstractmethod
def linenos(self) -> int:
"""Provide the number of source code lines."""
[docs]
@abc.abstractmethod
def log_cluster_statistics(self) -> None:
"""Log the signatures of all seen callables."""
[docs]
@abc.abstractmethod
def add_generator(self, generator: GenericAccessibleObject) -> None:
"""Add the given accessible as a generator.
Args:
generator: The accessible object
"""
[docs]
@abc.abstractmethod
def add_accessible_object_under_test(
self, objc: GenericAccessibleObject, data: CallableData
) -> None:
"""Add accessible object to the objects under test.
Args:
objc: The accessible object
data: The function-description data
"""
[docs]
@abc.abstractmethod
def add_modifier(self, typ: TypeInfo, obj: GenericAccessibleObject) -> None:
"""Add a modifier.
A modifier is something that can be used to modify the given type,
for example, a method.
Args:
typ: The type that can be modified
obj: The accessible that can modify
"""
@property
@abc.abstractmethod
def accessible_objects_under_test(self) -> OrderedSet[GenericAccessibleObject]:
"""Provides all accessible objects under test."""
@property
@abc.abstractmethod
def function_data_for_accessibles(
self,
) -> dict[GenericAccessibleObject, CallableData]:
"""Provides all function data for all accessibles."""
[docs]
@abc.abstractmethod
def add_ml_data(self, obj: GenericAccessibleObject, data: MLCallableData) -> None:
"""Provides ML data for a accessible."""
[docs]
@abc.abstractmethod
def get_ml_data_for(self, generic_accessible: GenericAccessibleObject) -> MLCallableData | None:
"""Provides ML data for a accessible."""
[docs]
@abc.abstractmethod
def num_accessible_objects_under_test(self) -> int:
"""Provide the number of accessible objects under test.
Useful to check whether there is even something to test.
"""
[docs]
@abc.abstractmethod
def get_generators_for(self, typ: ProperType) -> OrderedSet[GenericAccessibleObject]:
"""Retrieve all known generators for the given type.
Args:
typ: The type we want to have the generators for
Returns:
The set of all generators for that type.
"""
[docs]
@abc.abstractmethod
def get_modifiers_for(self, typ: ProperType) -> OrderedSet[GenericAccessibleObject]:
"""Get all known modifiers for a type.
Args:
typ: The type
Returns:
The set of all accessibles that can modify the type # noqa: DAR202
"""
@property
@abc.abstractmethod
def generators(self) -> dict[ProperType, OrderedSet[GenericAccessibleObject]]:
"""Provides all available generators."""
@property
@abc.abstractmethod
def modifiers(self) -> dict[TypeInfo, OrderedSet[GenericAccessibleObject]]:
"""Provides all available modifiers."""
[docs]
@abc.abstractmethod
def get_random_accessible(self) -> GenericAccessibleObject | None:
"""Provides a random accessible of the unit under test.
Returns:
A random accessible, or None if there is none # noqa: DAR202
"""
[docs]
@abc.abstractmethod
def get_random_call_for(self, typ: ProperType) -> GenericAccessibleObject:
"""Get a random modifier for the given type.
Args:
typ: The type
Returns:
A random modifier for that type # noqa: DAR202
Raises:
ConstructionFailedException: if no modifiers for the type
exist# noqa: DAR402
"""
[docs]
@abc.abstractmethod
def get_all_generatable_types(self) -> list[ProperType]:
"""Provides all types that can be generated.
This includes primitives and collections.
Returns:
A list of all types that can be generated # noqa: DAR202
"""
[docs]
@abc.abstractmethod
def select_concrete_type(self, typ: ProperType) -> ProperType:
"""Select a concrete type from the given type.
This is required, for example, when handling union types. Currently, only
unary types, Any, and Union are handled.
Args:
typ: An optional type
Returns:
An optional type # noqa: DAR202
"""
[docs]
@abc.abstractmethod
def track_statistics_values(self, tracking_fun: Callable[[RuntimeVariable, Any], None]) -> None:
"""Track statistics values from the test cluster and its items.
Args:
tracking_fun: The tracking function as a callback.
"""
[docs]
@abc.abstractmethod
def update_return_type(
self, accessible: GenericCallableAccessibleObject, new_type: ProperType
) -> None:
"""Update the return for the given accessible to the new seen type.
Args:
accessible: the accessible that was observed
new_type: the new return type
"""
[docs]
@abc.abstractmethod
def update_parameter_knowledge(
self,
accessible: GenericCallableAccessibleObject,
param_name: str,
knowledge: tt.UsageTraceNode,
) -> None:
"""Update the knowledge about the parameter of the given accessible.
Args:
accessible: the accessible that was observed.
param_name: the parameter name for which we have new information.
knowledge: the new information.
"""
[docs]
@dataclasses.dataclass
class SignatureInfo:
"""Another utility class to group information per callable."""
# A dictionary mapping parameter names and to their developer annotated parameters
# types.
# Does not include self, etc.
annotated_parameter_types: dict[str, str] = dataclasses.field(default_factory=dict)
# Similar to above, but with guessed parameters types.
# Contains multiples type guesses.
guessed_parameter_types: dict[str, list[str]] = dataclasses.field(default_factory=dict)
# Needed to compute top-n accuracy in the evaluation.
# Elements are of form (A,B); A is a guess, B is an annotated type.
# (A,B) is only present, when A is a base type match of B.
# If it is present, it points to the partial type match between A and B.
partial_type_matches: dict[str, str] = dataclasses.field(default_factory=dict)
# Annotated return type, if Any.
# Does not include constructors.
annotated_return_type: str | None = None
# Recorded return type, if Any.
recorded_return_type: str | None = None
[docs]
@dataclasses.dataclass
class TypeGuessingStats:
"""Class to gather some type guessing related statistics."""
# Number of constructors in the MUT.
number_of_constructors: int = 0
# Maps names of callables to a signature info object.
signature_infos: dict[str, SignatureInfo] = dataclasses.field(
default_factory=lambda: defaultdict(SignatureInfo)
)
def _serialize_helper(obj):
"""Utility to deal with non-serializable types.
Args:
obj: The object to serialize
Returns:
A serializable object.
"""
if isinstance(obj, set):
return list(obj)
if isinstance(obj, SignatureInfo):
return dataclasses.asdict(obj)
return obj
[docs]
class ModuleTestCluster(TestCluster): # noqa: PLR0904
"""A test cluster for a module.
Contains all methods/constructors/functions and all required transitive
dependencies.
"""
def __init__(self, linenos: int) -> None: # noqa: D107
self.__type_system = TypeSystem()
self.__linenos = linenos
self.generator_provider = self._setup_generator_selection()
# Modifier belong to a certain class, not type.
self.__modifiers: dict[TypeInfo, OrderedSet[GenericAccessibleObject]] = defaultdict(
OrderedSet
)
self.__accessible_objects_under_test: OrderedSet[GenericAccessibleObject] = OrderedSet()
self.__function_data_for_accessibles: dict[GenericAccessibleObject, CallableData] = {}
self.__ml_data_for_accessibles: dict[GenericAccessibleObject, MLCallableData] = {}
# Keep track of all callables, this is only for statistics purposes.
self.__callables: OrderedSet[GenericCallableAccessibleObject] = OrderedSet()
def _setup_generator_selection(self) -> GeneratorProvider:
if (
config.configuration.generator_selection.generator_selection_algorithm
== config.Selection.RANK_SELECTION
):
return GeneratorProvider(
self.__type_system,
RankSelection(config.configuration.generator_selection.generator_selection_bias),
)
if (
config.configuration.generator_selection.generator_selection_algorithm
== config.Selection.RANDOM_SELECTION
):
return RandomGeneratorProvider(
self.__type_system,
RandomSelection(),
)
raise ValueError(
"Unsupported generator selection algorithm: "
f"{config.configuration.generator_selection.generator_selection_algorithm}"
)
[docs]
def log_cluster_statistics(self) -> None: # noqa: D102
stats = TypeGuessingStats()
for accessible in self.__accessible_objects_under_test:
if isinstance(accessible, GenericCallableAccessibleObject):
accessible.inferred_signature.log_stats_and_guess_signature(
accessible.is_constructor(), str(accessible), stats
)
stat.track_output_variable(
RuntimeVariable.SignatureInfos,
json.dumps(
stats.signature_infos,
default=_serialize_helper,
),
)
stat.track_output_variable(
RuntimeVariable.NumberOfConstructors,
str(stats.number_of_constructors),
)
self.__write_type_eval_py_output(stats)
def __write_type_eval_py_output(self, stats: TypeGuessingStats):
# Create a folder for the inferred types
signatures_folder = Path(config.configuration.statistics_output.report_dir) / "signatures"
signatures_folder.mkdir(parents=True, exist_ok=True)
project_folder = signatures_folder / config.configuration.statistics_output.project_name
project_folder.mkdir(parents=True, exist_ok=True)
module_folder = project_folder / config.configuration.module_name.split(".")[-1]
module_folder.mkdir(parents=True, exist_ok=True)
# Dump the captured type information to a JSON file
types_json = (
module_folder / f"{config.configuration.module_name.split('.')[-1]}_result.json"
)
types_json.write_text(
provide_json(
f"{config.configuration.module_name.split('.')[-1]}.py",
self.__accessible_objects_under_test,
self.__function_data_for_accessibles,
stats,
),
encoding="utf-8",
)
def _drop_generator(self, accessible: GenericCallableAccessibleObject):
gens = self.generator_provider.get_for_type(accessible.generated_type())
if gens is None or len(gens) == 0:
return
gens.discard(accessible)
if len(gens) == 0:
self.generator_provider.remove_all_generators_for(accessible.generated_type())
@staticmethod
def _add_or_make_union(
old_type: ProperType, new_type: ProperType, max_size: int = 5
) -> UnionType:
if isinstance(old_type, UnionType):
items = old_type.items
if len(items) >= max_size or new_type in items:
return old_type
new_type = UnionType(tuple(sorted((*items, new_type))))
elif old_type in {ANY, new_type}:
new_type = UnionType((new_type,))
else:
new_type = UnionType(tuple(sorted((old_type, new_type))))
return new_type
[docs]
def update_return_type( # noqa: D102
self, accessible: GenericCallableAccessibleObject, new_type: ProperType
) -> None:
# Loosely map runtime type to proper type
old_type = accessible.inferred_signature.return_type
new_type = self._add_or_make_union(old_type, new_type)
if old_type == new_type:
# No change
return
self._drop_generator(accessible)
# Must invalidate entire cache, because subtype relationship might also change
# the return values which are not new_type or old_type
self.generator_provider.clear_generator_cache()
self.get_all_generatable_types.cache_clear()
accessible.inferred_signature.return_type = new_type
self.generator_provider.add_for_type(new_type, accessible)
[docs]
def update_parameter_knowledge( # noqa: D102
self,
accessible: GenericCallableAccessibleObject,
param_name: str,
knowledge: tt.UsageTraceNode,
) -> None:
# Store new data
accessible.inferred_signature.usage_trace[param_name].merge(knowledge)
@property
def type_system(self) -> TypeSystem:
"""Provides the type system.
Returns:
The type system.
"""
return self.__type_system
@property
def linenos(self) -> int: # noqa: D102
return self.__linenos
[docs]
def add_generator(self, generator: GenericAccessibleObject) -> None: # noqa: D102
if isinstance(generator, GenericCallableAccessibleObject):
self.__callables.add(generator)
self.generator_provider.add(generator)
[docs]
def add_accessible_object_under_test( # noqa: D102
self, objc: GenericAccessibleObject, data: CallableData
) -> None:
self.__accessible_objects_under_test.add(objc)
self.__function_data_for_accessibles[objc] = data
[docs]
def add_modifier( # noqa: D102
self, typ: TypeInfo, obj: GenericAccessibleObject
) -> None:
if isinstance(obj, GenericCallableAccessibleObject):
self.__callables.add(obj)
self.__modifiers[typ].add(obj)
@property
def accessible_objects_under_test( # noqa: D102
self,
) -> OrderedSet[GenericAccessibleObject]:
return self.__accessible_objects_under_test
@property
def function_data_for_accessibles( # noqa: D102
self,
) -> dict[GenericAccessibleObject, CallableData]:
return self.__function_data_for_accessibles
[docs]
def add_ml_data(self, obj: GenericAccessibleObject, data: MLCallableData) -> None: # noqa: D102
self.__ml_data_for_accessibles[obj] = data
[docs]
def get_ml_data_for(self, obj: GenericAccessibleObject) -> MLCallableData | None: # noqa: D102
return self.__ml_data_for_accessibles.get(obj)
[docs]
def num_accessible_objects_under_test(self) -> int: # noqa: D102
return len(self.__accessible_objects_under_test)
[docs]
def get_generators_for( # noqa: D102
self, typ: ProperType
) -> OrderedSet[GenericAccessibleObject]:
return self.generator_provider.get_for_type(typ)
class _FindModifiers(TypeVisitor[OrderedSet[GenericAccessibleObject]]):
"""A visitor to find all modifiers for the given type."""
def __init__(self, cluster: TestCluster):
self.cluster = cluster
def visit_any_type(self, left: AnyType) -> OrderedSet[GenericAccessibleObject]:
# If it's Any just take everything.
return OrderedSet(itertools.chain.from_iterable(self.cluster.modifiers.values()))
def visit_none_type(self, left: NoneType) -> OrderedSet[GenericAccessibleObject]:
return OrderedSet()
def visit_instance(self, left: Instance) -> OrderedSet[GenericAccessibleObject]:
result: OrderedSet[GenericAccessibleObject] = OrderedSet()
for type_info in self.cluster.type_system.get_superclasses(left.type):
result.update(self.cluster.modifiers[type_info])
return result
def visit_tuple_type(self, left: TupleType) -> OrderedSet[GenericAccessibleObject]:
return OrderedSet()
def visit_union_type(self, left: UnionType) -> OrderedSet[GenericAccessibleObject]:
result: OrderedSet[GenericAccessibleObject] = OrderedSet()
for element in left.items:
result.update(element.accept(self)) # type: ignore[arg-type]
return result
def visit_unsupported_type(self, left: Unsupported) -> OrderedSet[GenericAccessibleObject]:
raise NotImplementedError("This type shall not be used during runtime")
[docs]
def get_modifiers_for( # noqa: D102
self, typ: ProperType
) -> OrderedSet[GenericAccessibleObject]:
return typ.accept(self._FindModifiers(self))
@property
def generators( # noqa: D102
self,
) -> dict[ProperType, OrderedSet[GenericAccessibleObject]]:
return self.generator_provider.get_all()
@property
def modifiers( # noqa: D102
self,
) -> dict[TypeInfo, OrderedSet[GenericAccessibleObject]]:
return self.__modifiers
[docs]
def get_random_accessible(self) -> GenericAccessibleObject | None: # noqa: D102
if self.num_accessible_objects_under_test() == 0:
return None
return randomness.choice(self.__accessible_objects_under_test)
[docs]
def get_random_call_for( # noqa: D102
self, typ: ProperType
) -> GenericAccessibleObject:
accessible_objects = self.get_modifiers_for(typ)
if len(accessible_objects) == 0:
raise ConstructionFailedException(f"No modifiers for {typ}")
return randomness.choice(accessible_objects)
[docs]
@functools.lru_cache(maxsize=128)
def get_all_generatable_types(self) -> list[ProperType]: # noqa: D102
generatable = self.generator_provider.get_all_types()
generatable.update(self.type_system.primitive_proper_types)
generatable.update(self.type_system.collection_proper_types)
return list(generatable)
[docs]
def select_concrete_type(self, typ: ProperType) -> ProperType: # noqa: D102
if isinstance(typ, AnyType):
typ = randomness.choice(self.get_all_generatable_types())
if isinstance(typ, UnionType):
typ = self.select_concrete_type(randomness.choice(typ.items))
return typ
[docs]
def track_statistics_values( # noqa: D102
self, tracking_fun: Callable[[RuntimeVariable, Any], None]
) -> None:
tracking_fun(
RuntimeVariable.AccessibleObjectsUnderTest,
self.num_accessible_objects_under_test(),
)
tracking_fun(RuntimeVariable.GeneratableTypes, len(self.get_all_generatable_types()))
cyclomatic_complexities = self.__compute_cyclomatic_complexities(
self.function_data_for_accessibles.values()
)
if cyclomatic_complexities is not None:
tracking_fun(RuntimeVariable.McCabeAST, json.dumps(cyclomatic_complexities))
tracking_fun(RuntimeVariable.LineNos, self.__linenos)
@staticmethod
def __compute_cyclomatic_complexities(
callable_data: typing.Iterable[CallableData],
) -> list[int]:
# Collect complexities only for callables that had an AST. Their minimal
# complexity is 1, the value None symbolises a callable that had no AST present,
# either because there is none or because it is an implicitly added function,
# such as a default constructor or the constructor of a base class.
return [
item.cyclomatic_complexity
for item in callable_data
if item.cyclomatic_complexity is not None
]
[docs]
class FilteredModuleTestCluster(TestCluster): # noqa: PLR0904
"""A test cluster wrapping another test cluster.
Delegates most methods to the wrapped delegate. This cluster filters out
accessible objects under test that are already fully covered, in order to focus
the search on areas that are not yet fully covered.
"""
@property
def generator_provider(self) -> GeneratorProvider: # noqa: D102
return self.__delegate.generator_provider
@property
def type_system(self) -> TypeSystem: # noqa: D102
return self.__delegate.type_system
[docs]
def update_return_type( # noqa: D102
self, accessible: GenericCallableAccessibleObject, new_type: ProperType
) -> None:
self.__delegate.update_return_type(accessible, new_type)
[docs]
def update_parameter_knowledge( # noqa: D102
self,
accessible: GenericCallableAccessibleObject,
param_name: str,
knowledge: tt.UsageTraceNode,
) -> None:
self.__delegate.update_parameter_knowledge(accessible, param_name, knowledge)
@property
def linenos(self) -> int: # noqa: D102
return self.__delegate.linenos
[docs]
def log_cluster_statistics(self) -> None: # noqa: D102
self.__delegate.log_cluster_statistics()
[docs]
def add_generator(self, generator: GenericAccessibleObject) -> None: # noqa: D102
self.__delegate.add_generator(generator)
[docs]
def add_accessible_object_under_test( # noqa: D102
self, objc: GenericAccessibleObject, data: CallableData
) -> None:
self.__delegate.add_accessible_object_under_test(objc, data)
[docs]
def add_modifier( # noqa: D102
self, typ: TypeInfo, obj: GenericAccessibleObject
) -> None:
self.__delegate.add_modifier(typ, obj)
@property
def function_data_for_accessibles( # noqa: D102
self,
) -> dict[GenericAccessibleObject, CallableData]:
return self.__delegate.function_data_for_accessibles
[docs]
def add_ml_data(self, obj: GenericAccessibleObject, data: MLCallableData) -> None: # noqa: D102
self.__delegate.add_ml_data(obj, data)
[docs]
def get_ml_data_for(self, obj: GenericAccessibleObject) -> MLCallableData | None: # noqa: D102
return self.__delegate.get_ml_data_for(obj)
[docs]
def track_statistics_values( # noqa: D102
self, tracking_fun: Callable[[RuntimeVariable, Any], None]
) -> None:
self.__delegate.track_statistics_values(tracking_fun)
def __init__( # noqa: D107
self,
delegate: ModuleTestCluster,
archive: arch.Archive,
subject_properties: SubjectProperties,
targets: OrderedSet[ff.TestCaseFitnessFunction],
) -> None:
self.__delegate = delegate
self.__subject_properties = subject_properties
existing_code_objects = {
metadata.code_object: code_object_id
for code_object_id, metadata in subject_properties.existing_code_objects.items()
}
self.__code_object_id_to_accessible_objects = {
existing_code_objects[acc.callable.__code__]: acc
for acc in delegate.accessible_objects_under_test
if isinstance(acc, GenericCallableAccessibleObject)
and hasattr(acc.callable, "__code__")
and acc.callable.__code__ in existing_code_objects
}
# Checking for __code__ is necessary, because the __init__ of a class that
# does not define __init__ points to some internal CPython stuff.
self.__accessible_to_targets: dict[GenericCallableAccessibleObject, OrderedSet] = {
acc: OrderedSet() for acc in self.__code_object_id_to_accessible_objects.values()
}
for target in targets:
if (acc := self.__get_accessible_object_for_target(target)) is not None:
targets_for_acc = self.__accessible_to_targets[acc]
targets_for_acc.add(target)
# Get informed by archive when a target is covered
archive.add_on_target_covered(self.on_target_covered)
def __get_accessible_object_for_target(
self, target: ff.TestCaseFitnessFunction
) -> GenericCallableAccessibleObject | None:
code_object_id: int | None = target.code_object_id
while code_object_id is not None:
if (
acc := self.__code_object_id_to_accessible_objects.get(code_object_id, None)
) is not None:
return acc
code_object_id = self.__subject_properties.existing_code_objects[
code_object_id
].parent_code_object_id
return None
[docs]
def on_target_covered(self, target: ff.TestCaseFitnessFunction) -> None:
"""A callback function to get informed by an archive when a target is covered.
Args:
target: The newly covered target
"""
acc = self.__get_accessible_object_for_target(target)
if acc is not None:
targets_for_acc = self.__accessible_to_targets.get(acc)
assert targets_for_acc is not None
targets_for_acc.remove(target)
if len(targets_for_acc) == 0:
self.__accessible_to_targets.pop(acc)
LOGGER.debug(
"Removed %s from test cluster because all targets within it have been covered.",
acc,
)
@property
def accessible_objects_under_test( # noqa: D102
self,
) -> OrderedSet[GenericAccessibleObject]:
accessibles = self.__accessible_to_targets.keys()
if len(accessibles) == 0:
# Should never happen, just in case everything is already covered?
return self.__delegate.accessible_objects_under_test
return OrderedSet(accessibles)
[docs]
def num_accessible_objects_under_test(self) -> int: # noqa: D102
return self.__delegate.num_accessible_objects_under_test()
[docs]
def get_generators_for( # noqa: D102
self, typ: ProperType
) -> OrderedSet[GenericAccessibleObject]:
return self.__delegate.get_generators_for(typ)
[docs]
def get_modifiers_for( # noqa: D102
self, typ: ProperType
) -> OrderedSet[GenericAccessibleObject]:
return self.__delegate.get_modifiers_for(typ)
@property
def generators( # noqa: D102
self,
) -> dict[ProperType, OrderedSet[GenericAccessibleObject]]:
return self.__delegate.generators
@property
def modifiers( # noqa: D102
self,
) -> dict[TypeInfo, OrderedSet[GenericAccessibleObject]]:
return self.__delegate.modifiers
[docs]
def get_random_accessible(self) -> GenericAccessibleObject | None: # noqa: D102
accessibles = self.__accessible_to_targets.keys()
if len(accessibles) == 0:
return self.__delegate.get_random_accessible()
return randomness.choice(OrderedSet(accessibles))
[docs]
def get_random_call_for( # noqa: D102
self, typ: ProperType
) -> GenericAccessibleObject:
return self.__delegate.get_random_call_for(typ)
[docs]
def get_all_generatable_types(self) -> list[ProperType]: # noqa: D102
return self.__delegate.get_all_generatable_types()
[docs]
def select_concrete_type(self, typ: ProperType) -> ProperType: # noqa: D102
return self.__delegate.select_concrete_type(typ)
def __get_mccabe_complexity(tree: AstroidFunctionDef | None) -> int | None:
if tree is None:
return None
try:
return mccabe_complexity(astroid_to_ast(tree))
except SyntaxError:
return None
def __is_constructor(method_name: str) -> bool:
return method_name == "__init__"
def __is_annotate(method_name: str) -> bool:
return method_name == "__annotate_func__"
def __is_protected(method_name: str) -> bool:
return method_name.startswith("_") and not method_name.startswith("__")
def __is_private(method_name: str) -> bool:
return method_name.startswith("__") and not method_name.endswith("__")
def __is_method_defined_in_class(class_: type | types.UnionType, method: object) -> bool:
return class_ == get_class_that_defined_method(method)
[docs]
@dataclasses.dataclass
class CallableData:
"""Provides all information on callables.
While the accessible is available for every callable, the other fields are only
filled for methods that are available in (Python) source code because their
information is retrieved from the abstract syntax tree.
Attributes:
accessible: the accessible object itself
tree: the AST of the callable, if any
description: the function description of the callable, if any
cyclomatic_complexity: the McCabe cyclomatic complexity of the callable, if any
"""
accessible: GenericAccessibleObject
tree: AstroidFunctionDef | None
description: FunctionDescription | None
cyclomatic_complexity: int | None
[docs]
@dataclasses.dataclass
class MLCallableData:
"""Provides ML-specific information on callables.
Attributes:
parameters: A dictionary of parameters, if any
generation_order: The generation order of the callable (can be empty)
"""
parameters: dict[str, MLParameter | None]
generation_order: list[str]
def _get_lambda_assigned_name(module_tree, lambda_lineno) -> str | None:
"""Retrieve the variable name of a lambda assignment.
Example:
For a lambda defined at line 10:
y = lambda: 42
this function will return "y" if the lambda node starts at line 10.
"""
for node in module_tree.body:
if isinstance(node, Assign) and len(node.targets) == 1:
target = node.targets[0]
if (
hasattr(target, "name")
and isinstance(node.value, Lambda)
and node.value.lineno == lambda_lineno
):
return target.name
return None
def __analyse_function(
*,
func_name: str,
func: FunctionType,
type_inference_provider: InferenceProvider,
module_tree: Module | None,
test_cluster: ModuleTestCluster,
add_to_test: bool,
) -> None:
if __is_private(func_name) or __is_protected(func_name):
LOGGER.debug("Skipping function %s from analysis", func_name)
return
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
if add_to_test:
raise CoroutineFoundException("Found coroutine in SUT: %s", func_name)
# Coroutine outside the SUT are not problematic, just exclude them.
LOGGER.debug("Skipping coroutine %s outside of SUT", func_name)
return
LOGGER.debug("Analysing function %s", func_name)
inferred_signature = test_cluster.type_system.infer_type_info(
func,
type_inference_provider=type_inference_provider,
)
func_ast = get_function_node_from_ast(module_tree, func_name)
description = get_function_description(func_ast)
expected_exceptions = description.raises if description is not None else set()
cyclomatic_complexity = __get_mccabe_complexity(func_ast)
if getattr(func, "__name__", None) == "<lambda>":
if lambda_assigned_name := _get_lambda_assigned_name(
module_tree, func.__code__.co_firstlineno
):
func_name = lambda_assigned_name
func.__name__ = lambda_assigned_name
else:
# If the lambda itself has no name, we must not add it to the test cluster
# or else it will cause an exception during test export.
return
generic_function = GenericFunction(func, inferred_signature, expected_exceptions, func_name)
if config.configuration.pynguinml.ml_testing_enabled and module_tree is not None:
parameters: dict[str, MLParameter | None] = {}
generation_order: list[str] = []
try:
parameters, generation_order = tr.load_and_process_constraints(
module_tree.name, func_name, list(inferred_signature.original_parameters.keys())
)
except ConstraintValidationError as e:
LOGGER.warning("ConstraintValidationError occurred: %s. Skipping.", e)
ml_data = MLCallableData(
parameters=parameters,
generation_order=generation_order,
)
test_cluster.add_ml_data(generic_function, ml_data)
function_data = CallableData(
accessible=generic_function,
tree=func_ast,
description=description,
cyclomatic_complexity=cyclomatic_complexity,
)
test_cluster.add_generator(generic_function)
if add_to_test:
test_cluster.add_accessible_object_under_test(generic_function, function_data)
def __analyse_class(
*,
type_info: TypeInfo,
type_inference_provider: InferenceProvider,
module_tree: Module | None,
test_cluster: ModuleTestCluster,
add_to_test: bool,
) -> None:
LOGGER.debug("Analysing class %s", type_info)
class_ast = get_class_node_from_ast(module_tree, type_info.name)
__add_symbols(class_ast, type_info)
if type_info.raw_type is tuple:
# Tuple is problematic...
return
constructor_ast = get_function_node_from_ast(class_ast, "__init__")
description = get_function_description(constructor_ast)
expected_exceptions = description.raises if description is not None else set()
cyclomatic_complexity = __get_mccabe_complexity(constructor_ast)
if issubclass(type_info.raw_type, enum.Enum): # type: ignore[arg-type]
generic: GenericEnum | GenericConstructor = GenericEnum(type_info)
if isinstance(generic, GenericEnum) and len(generic.names) == 0:
LOGGER.debug(
"Skipping enum %s from test cluster, it has no fields.",
type_info.full_name,
)
return
else:
generic = GenericConstructor(
type_info,
test_cluster.type_system.infer_type_info(
type_info.raw_type.__init__, # type: ignore[misc]
type_inference_provider=type_inference_provider,
),
expected_exceptions,
)
generic.inferred_signature.return_type = test_cluster.type_system.convert_type_hint(
type_info.raw_type
)
if (
config.configuration.pynguinml.ml_testing_enabled
and type_info.raw_type.__module__ != "builtins"
and not isinstance(generic, GenericEnum)
):
parameters: dict[str, MLParameter | None] = {}
generation_order: list[str] = []
try:
parameters, generation_order = tr.load_and_process_constraints(
type_info.module,
type_info.name,
list(generic.inferred_signature.original_parameters.keys()),
)
except ConstraintValidationError as e:
LOGGER.warning("ConstraintValidationError occurred: %s. Skipping.", e)
ml_data = MLCallableData(
parameters=parameters,
generation_order=generation_order,
)
test_cluster.add_ml_data(generic, ml_data)
method_data = CallableData(
accessible=generic,
tree=constructor_ast,
description=description,
cyclomatic_complexity=cyclomatic_complexity,
)
if not (
type_info.is_abstract
or type_info.raw_type in COLLECTIONS
or type_info.raw_type in PRIMITIVES
):
# Don't add constructors for abstract classes and for builtins. We generate
# the latter ourselves.
test_cluster.add_generator(generic)
if add_to_test:
test_cluster.add_accessible_object_under_test(generic, method_data)
try:
methods_with_names = inspect.getmembers(type_info.raw_type, inspect.isfunction)
except Exception as ex: # noqa: BLE001
LOGGER.error("Could not get members for class %s: %s", type_info.full_name, str(ex))
return
for method_name, method in methods_with_names:
__analyse_method(
type_info=type_info,
method_name=method_name,
method=method,
type_inference_provider=type_inference_provider,
class_tree=class_ast,
test_cluster=test_cluster,
add_to_test=add_to_test,
)
# Some symbols are not interesting for us.
IGNORED_SYMBOLS: set[str] = {
"__new__",
"__init__",
"__del__",
"__repr__",
"__str__",
"__sizeof__",
"__getattribute__",
"__getattr__",
}
def __add_symbols(class_ast: ClassDef | None, type_info: TypeInfo) -> None:
"""Tries to infer what symbols can be found on an instance of the given class.
We also try to infer what attributes are defined in '__init__'.
Args:
class_ast: The AST Node of the class.
type_info: The type info.
"""
if class_ast is not None:
type_info.instance_attributes.update(tuple(class_ast.instance_attrs))
type_info.attributes.update(type_info.instance_attributes)
type_info.attributes.update(tuple(vars(type_info.raw_type)))
type_info.attributes.difference_update(IGNORED_SYMBOLS)
def __analyse_method(
*,
type_info: TypeInfo,
method_name: str,
method: (FunctionType | BuiltinFunctionType | WrapperDescriptorType | MethodDescriptorType),
type_inference_provider: InferenceProvider,
class_tree: ClassDef | None,
test_cluster: ModuleTestCluster,
add_to_test: bool,
) -> None:
if (
__is_annotate(method_name)
or __is_private(method_name)
or __is_protected(method_name)
or __is_constructor(method_name)
or not __is_method_defined_in_class(type_info.raw_type, method)
):
LOGGER.debug("Skipping method %s from analysis", method_name)
return
if inspect.iscoroutinefunction(method) or inspect.isasyncgenfunction(method):
if add_to_test:
raise CoroutineFoundException("Found coroutine in SUT: %s", method_name)
# Coroutine outside the SUT are not problematic, just exclude them.
LOGGER.debug("Skipping coroutine %s outside of SUT", method_name)
return
LOGGER.debug("Analysing method %s.%s", type_info.full_name, method_name)
inferred_signature = test_cluster.type_system.infer_type_info(
method,
type_inference_provider=type_inference_provider,
)
method_ast = get_function_node_from_ast(class_tree, method_name)
description = get_function_description(method_ast)
expected_exceptions = description.raises if description is not None else set()
cyclomatic_complexity = __get_mccabe_complexity(method_ast)
generic_method = GenericMethod(
type_info, method, inferred_signature, expected_exceptions, method_name
)
if config.configuration.pynguinml.ml_testing_enabled:
parameters: dict[str, MLParameter | None] = {}
generation_order: list[str] = []
callable_name = type_info.name + "." + method_name
try:
parameters, generation_order = tr.load_and_process_constraints(
type_info.module, callable_name, list(inferred_signature.original_parameters.keys())
)
except ConstraintValidationError as e:
LOGGER.warning("ConstraintValidationError occurred: %s. Skipping.", e)
ml_data = MLCallableData(
parameters=parameters,
generation_order=generation_order,
)
test_cluster.add_ml_data(generic_method, ml_data)
method_data = CallableData(
accessible=generic_method,
tree=method_ast,
description=description,
cyclomatic_complexity=cyclomatic_complexity,
)
test_cluster.add_generator(generic_method)
test_cluster.add_modifier(type_info, generic_method)
if add_to_test:
test_cluster.add_accessible_object_under_test(generic_method, method_data)
class _ParseResults(dict): # noqa: FURB189
def __missing__(self, key):
# Parse module on demand
res = self[key] = parse_module(key)
return res
def __resolve_dependencies(
root_module: _ModuleParseResult,
type_inference_provider: InferenceProvider,
test_cluster: ModuleTestCluster,
) -> None:
parse_results: dict[str, _ModuleParseResult] = _ParseResults()
parse_results[root_module.module_name] = root_module
# Provide a set of seen modules, classes and functions for fixed-point iteration
seen_modules: set[ModuleType] = set()
seen_classes: set[Any] = set()
seen_functions: set[Any] = set()
# Set of C-extension modules that are not whitelisted
dangerous_c_modules: set[str] = set()
# Always analyse builtins
__analyse_included_classes(
module=builtins,
root_module_name=root_module.module_name,
type_inference_provider=type_inference_provider,
test_cluster=test_cluster,
seen_classes=seen_classes,
parse_results=parse_results,
)
test_cluster.type_system.enable_numeric_tower()
# Start with root module, i.e., the module under test.
wait_list: queue.SimpleQueue[ModuleType] = queue.SimpleQueue()
wait_list.put(root_module.module)
while not wait_list.empty():
current_module = wait_list.get()
if current_module in seen_modules:
# Skip the module, we have already analysed it before
continue
if _is_blacklisted(current_module):
# Don't include anything from the blacklist
continue
# Check if the module contains c extensions that are not whitelisted
dangerous_c_modules.update(
__check_c_modules(
module=current_module,
)
)
# Analyze all classes found in the current module
__analyse_included_classes(
module=current_module,
root_module_name=root_module.module_name,
type_inference_provider=type_inference_provider,
test_cluster=test_cluster,
seen_classes=seen_classes,
parse_results=parse_results,
)
# Analyze all functions found in the current module
__analyse_included_functions(
module=current_module,
root_module_name=root_module.module_name,
type_inference_provider=type_inference_provider,
test_cluster=test_cluster,
seen_functions=seen_functions,
parse_results=parse_results,
)
# Collect the modules that are included by this module and add
# them for further processing.
for included_module in filter(inspect.ismodule, vars(current_module).values()):
wait_list.put(included_module)
# Take care that we know for future iterations that we have already analysed
# this module before
seen_modules.add(current_module)
LOGGER.info("Analyzed project to create test cluster")
LOGGER.info("Modules: %5i", len(seen_modules))
LOGGER.info("Functions: %5i", len(seen_functions))
LOGGER.info("Classes: %5i", len(seen_classes))
_handle_c_modules(dangerous_c_modules)
test_cluster.type_system.push_attributes_down()
def __analyse_included_classes(
*,
module: ModuleType,
root_module_name: str,
type_inference_provider: InferenceProvider,
test_cluster: ModuleTestCluster,
parse_results: dict[str, _ModuleParseResult],
seen_classes: set[type],
) -> None:
values = list(vars(module).values())
work_list = list(
filter(
lambda x: inspect.isclass(x) and not _is_blacklisted(x),
values,
)
)
# TODO(fk) inner classes?
while len(work_list) > 0:
current = work_list.pop(0)
if current in seen_classes:
continue
seen_classes.add(current)
type_info = test_cluster.type_system.to_type_info(current)
# Skip if the class is _ObjectProxyMethods, as it is broken
# since __module__ is not well defined on it.
if isinstance(current.__module__, property):
LOGGER.info("Skipping class that has a property __module__: %s", current)
continue
# Skip some C-extension modules that are not publicly accessible.
try:
results = parse_results[current.__module__]
except ModuleNotFoundError as error:
if getattr(current, "__file__", None) is None or Path(current.__file__).suffix in {
".so",
".pyd",
}:
LOGGER.info("C-extension module not found: %s", current.__module__)
continue
raise error
__analyse_class(
type_info=type_info,
type_inference_provider=type_inference_provider,
module_tree=results.syntax_tree,
test_cluster=test_cluster,
add_to_test=current.__module__ == root_module_name,
)
if hasattr(current, "__bases__"):
for base in current.__bases__:
# TODO(fk) base might be an instance.
# Ignored for now.
# Probably store Instance in graph instead of TypeInfo?
if isinstance(base, GenericAlias):
base = base.__origin__ # noqa: PLW2901
base_info = test_cluster.type_system.to_type_info(base)
test_cluster.type_system.add_subclass_edge(
super_class=base_info, sub_class=type_info
)
work_list.append(base)
def __analyse_included_functions(
*,
module: ModuleType,
root_module_name: str,
type_inference_provider: InferenceProvider,
test_cluster: ModuleTestCluster,
parse_results: dict[str, _ModuleParseResult],
seen_functions: set,
) -> None:
for current in filter(
lambda x: inspect.isfunction(x) and not _is_blacklisted(x),
vars(module).values(),
):
if current in seen_functions:
continue
seen_functions.add(current)
__analyse_function(
func_name=current.__qualname__,
func=current,
type_inference_provider=type_inference_provider,
module_tree=parse_results[current.__module__].syntax_tree,
test_cluster=test_cluster,
add_to_test=current.__module__ == root_module_name,
)
def __check_c_modules(
*,
module: ModuleType,
) -> set[str]:
"""Return the names of modules containing non-whitelisted C extensions.
Args:
module: The module to check
Returns:
A set of module names with non-whitelisted C code.
"""
non_whitelisted_modules = set()
# If the whole module file looks like a binary extension:
module_file = getattr(module, "__file__", "")
if module_file and Path(module_file).suffix in {".so", ".pyd"}:
if not _c_is_whitelisted(module):
non_whitelisted_modules.add(module.__name__)
return non_whitelisted_modules
# If the module is Python, inspect its members too:
for element in vars(module).values():
if inspect.isfunction(element) or inspect.isclass(element):
try:
inspect.getsource(element)
# Source is available => likely pure Python.
except Exception: # noqa: BLE001
# No source => likely compiled or builtin.
if not _c_is_whitelisted(module):
non_whitelisted_modules.add(module.__name__)
return non_whitelisted_modules
[docs]
def analyse_module(
parsed_module: _ModuleParseResult,
type_inference_strategy: TypeInferenceStrategy = TypeInferenceStrategy.TYPE_HINTS,
) -> ModuleTestCluster:
"""Analyses a module to build a test cluster.
Args:
parsed_module: The parsed module
type_inference_strategy: The type inference strategy to use.
Returns:
A test cluster for the module
"""
test_cluster = ModuleTestCluster(linenos=parsed_module.linenos)
type_provider = get_type_provider(
type_inference_strategy, parsed_module.module, test_cluster.type_system
)
__resolve_dependencies(
root_module=parsed_module,
type_inference_provider=type_provider,
test_cluster=test_cluster,
)
collect_provider_metrics(type_provider)
return test_cluster
[docs]
def generate_test_cluster(
module_name: str,
type_inference_strategy: TypeInferenceStrategy = TypeInferenceStrategy.TYPE_HINTS,
) -> ModuleTestCluster:
"""Generates a new test cluster from the given module.
Args:
module_name: The name of the root module
type_inference_strategy: Which type-inference strategy to use
Returns:
A new test cluster for the given module
"""
return analyse_module(parse_module(module_name), type_inference_strategy)
[docs]
def get_type_provider(
type_inference_strategy: TypeInferenceStrategy, module: ModuleType, type_system: TypeSystem
) -> InferenceProvider:
"""Get the initialised inference provider for the given strategy.
Args:
type_inference_strategy: The type inference strategy to use
module: The module to analyse (only needed for LLM-based inference)
type_system: The type system to use
Returns:
The type inference provider for the given strategy
"""
match type_inference_strategy:
case TypeInferenceStrategy.LLM:
callables = _collect_public_callables(module)
return LLMInference(callables, LLMProvider.OPENAI, type_system)
case TypeInferenceStrategy.TYPE_HINTS:
return HintInference()
case TypeInferenceStrategy.NONE:
return NoInference()
case TypeInferenceStrategy.TYPEEVALPY:
typeevalpy_data = _load_typeevalpy_data()
if typeevalpy_data is None:
LOGGER.warning(
"TypeEvalPy strategy selected but no valid data found. "
"Falling back to NoInference."
)
return NoInference()
return TypeEvalPyInference(typeevalpy_data)
case _:
LOGGER.error(
"Unknown type inference strategy: '%s'. Falling back to NoInference.",
type_inference_strategy,
)
return NoInference()
def _load_typeevalpy_data() -> ParsedTypeEvalPyData | None:
"""Load TypeEvalPy data from the configured JSON file path.
Returns:
Parsed TypeEvalPy data if available, None otherwise
"""
json_path = config.configuration.type_inference.typeevalpy_json_path
if not json_path:
return None
try:
return parse_json(json_path)
except (FileNotFoundError, ValueError, OSError) as e:
LOGGER.warning("Failed to load TypeEvalPy data from %s: %s", json_path, e)
return None
def _collect_public_callables(module: ModuleType) -> Sequence[Callable[..., Any]]:
"""Collects a list of all public accessibles in a module."""
callables = []
seen = set()
def add(obj):
if id(obj) not in seen:
callables.append(obj)
seen.add(id(obj))
for name, obj in vars(module).items():
if name.startswith("_") and name != "__init__":
continue
if inspect.isfunction(obj) and obj.__module__ == module.__name__:
add(obj)
for cls_name, cls in vars(module).items():
if (
cls_name.startswith("_")
or not inspect.isclass(cls)
or cls.__module__ != module.__name__
):
continue
for meth_name, member in inspect.getmembers(cls, predicate=inspect.isfunction):
if not meth_name.startswith("_"):
add(member)
return callables
[docs]
def collect_provider_metrics(typ_provider: InferenceProvider):
"""Collects metrics from the given type inference provider.
currently, this only works for LLM-based providers.
When using an LLM-based provider, collect the raw inferred
parameter strings per callable and the annotated parameter strings
as JSON and store in the LLMInferredSignatures runtime variable.
Other providers default all metrics to zero.
"""
metrics = typ_provider.get_metrics()
stat.track_output_variable(
RuntimeVariable.TypeInferenceInferredParameters, metrics.get("successful_inferences", 0)
)
stat.track_output_variable(
RuntimeVariable.TypeInferenceFailedParameters, metrics.get("failed_inferences", 0)
)
stat.track_output_variable(
RuntimeVariable.TypeInferenceLLMCalls, metrics.get("sent_requests", 0)
)
stat.track_output_variable(
RuntimeVariable.TypeInferenceLLMTime, metrics.get("total_setup_time", 0.0)
)
if isinstance(typ_provider, LLMInference):
try:
inferred_signatures: dict[str, dict] = {}
inference_map = typ_provider.get_inference_map()
callables = typ_provider.get_callables()
for func in callables:
# Build a stable key: module.qualname when possible
module_part = getattr(func, "__module__", "")
qualname_part = getattr(func, "__qualname__", None)
key = str(func) if qualname_part is None else f"{module_part}.{qualname_part}"
# Annotated parameter strings (prior), fallback to typing.Any
try:
prior = typ_provider.prior_types_for(func)
except (
Exception # noqa: BLE001
) as exc: # narrow catch for unexpected provider failures
LOGGER.debug("Could not obtain prior types for %s: %s", func, exc)
prior = {}
guessed_raw = inference_map.get(func, {})
params = set(prior.keys()) | set(guessed_raw.keys())
annotated_parameter_types: dict[str, str] = {}
guessed_parameter_types: dict[str, list[str]] = {}
for p in params:
if p in {"*args", "**kwargs"}:
annotated_parameter_types[p] = ANY_STR
else:
annotated_parameter_types[p] = prior.get(p, ANY_STR)
guess = guessed_raw.get(p, "")
if isinstance(guess, str) and guess.strip():
guessed_parameter_types[p] = [guess.strip()]
else:
guessed_parameter_types[p] = []
inferred_signatures[key] = {
"annotated_parameter_types": annotated_parameter_types,
"guessed_parameter_types": guessed_parameter_types,
"partial_type_matches": {},
"annotated_return_type": None,
"recorded_return_type": None,
}
stat.track_output_variable(
RuntimeVariable.LLMInferredSignatures, json.dumps(inferred_signatures)
)
except Exception as exc: # Catch at top-level to ensure metrics don't break analysis
LOGGER.exception("Could not collect LLM inferred signatures: %s", exc)
else:
LOGGER.debug(
"Type inference provider is not LLM-based, skipping inferred signatures collection."
)