Skip to content

Commit

Permalink
Use safe_infer in _unpack_args and _unpack_keywords (#2117)
Browse files Browse the repository at this point in the history
  • Loading branch information
tusharsadhwani authored Apr 15, 2023
1 parent c1f8ca8 commit 495581f
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 29 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ Release date: TBA

* Remove dependency on ``wrapt``.

* ``CallSite._unpack_args`` and ``CallSite._unpack_keywords`` now use ``safe_infer()`` for
better inference and fewer false positives.

Closes pylint-dev/pylint#8544

What's New in astroid 2.15.3?
=============================
Expand Down
27 changes: 4 additions & 23 deletions astroid/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from astroid.bases import Instance
from astroid.context import CallContext, InferenceContext
from astroid.exceptions import InferenceError, NoDefault
from astroid.helpers import safe_infer
from astroid.util import Uninferable, UninferableBase


Expand Down Expand Up @@ -91,27 +92,14 @@ def _unpack_keywords(self, keywords, context: InferenceContext | None = None):
for name, value in keywords:
if name is None:
# Then it's an unpacking operation (**)
try:
inferred = next(value.infer(context=context))
except InferenceError:
values[name] = Uninferable
continue
except StopIteration:
continue

inferred = safe_infer(value, context=context)
if not isinstance(inferred, nodes.Dict):
# Not something we can work with.
values[name] = Uninferable
continue

for dict_key, dict_value in inferred.items:
try:
dict_key = next(dict_key.infer(context=context))
except InferenceError:
values[name] = Uninferable
continue
except StopIteration:
continue
dict_key = safe_infer(dict_key, context=context)
if not isinstance(dict_key, nodes.Const):
values[name] = Uninferable
continue
Expand All @@ -134,14 +122,7 @@ def _unpack_args(self, args, context: InferenceContext | None = None):
context.extra_context = self.argument_context_map
for arg in args:
if isinstance(arg, nodes.Starred):
try:
inferred = next(arg.value.infer(context=context))
except InferenceError:
values.append(Uninferable)
continue
except StopIteration:
continue

inferred = safe_infer(arg.value, context=context)
if isinstance(inferred, UninferableBase):
values.append(Uninferable)
continue
Expand Down
55 changes: 49 additions & 6 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,22 @@

import pytest

from astroid import Slice, arguments, helpers, nodes, objects, test_utils, util
from astroid import (
Slice,
Uninferable,
arguments,
helpers,
nodes,
objects,
test_utils,
util,
)
from astroid import decorators as decoratorsmod
from astroid.arguments import CallSite
from astroid.bases import BoundMethod, Instance, UnboundMethod, UnionType
from astroid.builder import AstroidBuilder, _extract_single_node, extract_node, parse
from astroid.const import PY38_PLUS, PY39_PLUS, PY310_PLUS
from astroid.context import InferenceContext
from astroid.context import CallContext, InferenceContext
from astroid.exceptions import (
AstroidTypeError,
AttributeInferenceError,
Expand Down Expand Up @@ -1443,10 +1452,9 @@ def get_context_data(self, **kwargs):
"""
node = extract_node(code)
assert isinstance(node, nodes.NodeNG)
result = node.inferred()
assert len(result) == 2
assert isinstance(result[0], nodes.Dict)
assert result[1] is util.Uninferable
results = node.inferred()
assert len(results) == 2
assert all(isinstance(result, nodes.Dict) for result in results)

def test_python25_no_relative_import(self) -> None:
ast = resources.build_file("data/package/absimport.py")
Expand Down Expand Up @@ -5296,6 +5304,41 @@ def test_duplicated_keyword_arguments(self) -> None:
site = self._call_site_from_call(ast_node)
self.assertIn("f", site.duplicated_keywords)

def test_call_site_uninferable(self) -> None:
code = """
def get_nums():
nums = ()
if x == '1':
nums = (1, 2)
return nums
def add(x, y):
return x + y
nums = get_nums()
if x:
kwargs = {1: bar}
else:
kwargs = {}
if nums:
add(*nums)
print(**kwargs)
"""
# Test that `*nums` argument should be Uninferable
ast = parse(code, __name__)
*_, add_call, print_call = list(ast.nodes_of_class(nodes.Call))
nums_arg = add_call.args[0]
add_call_site = self._call_site_from_call(add_call)
self.assertEqual(add_call_site._unpack_args([nums_arg]), [Uninferable])

print_call_site = self._call_site_from_call(print_call)
keywords = CallContext(print_call.args, print_call.keywords).keywords
self.assertEqual(
print_call_site._unpack_keywords(keywords), {None: Uninferable}
)


class ObjectDunderNewTest(unittest.TestCase):
def test_object_dunder_new_is_inferred_if_decorator(self) -> None:
Expand Down

0 comments on commit 495581f

Please sign in to comment.