diff --git a/piptools/resolver.py b/piptools/resolver.py index f9d3d1250..3fa736c6d 100644 --- a/piptools/resolver.py +++ b/piptools/resolver.py @@ -206,6 +206,7 @@ def __init__( prereleases: Optional[bool] = False, clear_caches: bool = False, allow_unsafe: bool = False, + cuts: Optional[Set[str]] = None, ) -> None: """ This class resolves a given set of constraints (a collection of @@ -220,6 +221,7 @@ def __init__( self.clear_caches = clear_caches self.allow_unsafe = allow_unsafe self.unsafe_constraints: Set[InstallRequirement] = set() + self.cuts = cuts or set() options = self.repository.options if "legacy-resolver" not in options.deprecated_features_enabled: @@ -353,7 +355,14 @@ def _resolve_one_round(self) -> Tuple[bool, Set[InstallRequirement]]: their_constraints: List[InstallRequirement] = [] with log.indentation(): for best_match in best_matches: - their_constraints.extend(self._iter_dependencies(best_match)) + for dep in self._iter_dependencies(best_match): + # Must iterate even if we're going to cut all dependencies, + # so our dependency cache gets populated. + pair = f"{best_match.name}:{dep.name}" + if best_match.name in self.cuts or pair in self.cuts: + log.debug(f"Cutting dependency {pair}") + else: + their_constraints.append(dep) # Grouping constraints to make clean diff between rounds theirs = set(self._group_constraints(their_constraints)) diff --git a/piptools/scripts/compile.py b/piptools/scripts/compile.py index 0c98c6b2d..3f6134e17 100755 --- a/piptools/scripts/compile.py +++ b/piptools/scripts/compile.py @@ -245,6 +245,13 @@ def _get_default_option(option_name: str) -> Any: default=True, help="Add options to generated file", ) +@click.option( + "--cut-deps", + multiple=True, + help="Ignore a package's dependencies. May be used more than once. " + "Pass just the package name to ignore all of its dependencies. " + "Pass pkg-name:dep-name to ignore just one dependency.", +) def cli( ctx: click.Context, verbose: int, @@ -279,6 +286,7 @@ def cli( resolver_name: str, emit_index_url: bool, emit_options: bool, + cut_deps: Tuple[str, ...], ) -> None: """Compiles requirements.txt from requirements.in specs.""" log.verbosity = verbose - quiet @@ -483,6 +491,7 @@ def cli( cache=DependencyCache(cache_dir), clear_caches=rebuild, allow_unsafe=allow_unsafe, + cuts=set(cut_deps), ) results = resolver.resolve(max_rounds=max_rounds) hashes = resolver.resolve_hashes(results) if generate_hashes else None diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 0e9b2fef1..67ac1d04b 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -250,6 +250,81 @@ def test_resolver__allows_unsafe_deps( assert output == {str(line) for line in expected} +@pytest.mark.parametrize( + ( + "input", + "cuts", + "expected", + ), + ( + # No cuts, get all recursive dependencies. + ( + ["flask==0.10.1"], + set(), + { + "markupsafe==0.23 (from jinja2==2.7.3->flask==0.10.1)", + "itsdangerous==0.24 (from flask==0.10.1)", + "werkzeug==0.10.4 (from flask==0.10.1)", + "flask==0.10.1", + "jinja2==2.7.3 (from flask==0.10.1)", + }, + ), + # Cut all of flask's dependencies. Get only flask. + ( + ["flask==0.10.1"], + {"flask"}, + { + "flask==0.10.1", + }, + ), + # Cut flask's dependency on Jinja2. Get the remaining dependencies. + ( + ["flask==0.10.1"], + {"flask:Jinja2"}, + { + "itsdangerous==0.24 (from flask==0.10.1)", + "werkzeug==0.10.4 (from flask==0.10.1)", + "flask==0.10.1", + }, + ), + # Again cut flask's dependency on Jinja2, but now also install another + # package that depends on Jinja2. Now we do get Jinja2, among others. + ( + ["flask==0.10.1", "ipython[notebook]==2.1.0"], + {"flask:Jinja2"}, + { + "jinja2==2.7.3 (from ipython[notebook]==2.1.0)", + "tornado==3.2.2 (from ipython[notebook]==2.1.0)", + "itsdangerous==0.24 (from flask==0.10.1)", + "markupsafe==0.23 (from jinja2==2.7.3->ipython[notebook]==2.1.0)", + "pyzmq==2.1.12 (from ipython[notebook]==2.1.0)", + "ipython[notebook]==2.1.0", + "gnureadline==6.3.3 (from ipython[notebook]==2.1.0)", + "flask==0.10.1", + "werkzeug==0.10.4 (from flask==0.10.1)", + }, + ), + ), +) +def test_resolver__cut_deps( + resolver, + from_line, + input, + cuts, + expected, +): + input = [line if isinstance(line, tuple) else (line, False) for line in input] + input = [from_line(req[0], constraint=req[1]) for req in input] + resolver = resolver( + input, + cuts=cuts, + ) + output = resolver.resolve() + output = {str(line) for line in output} + + assert output == expected + + def test_resolver__max_number_rounds_reached(resolver, from_line): """ Resolver should raise an exception if max round has been reached.