Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- CLI: fix accepting keyword arguments ([#420](https://github.com/Lightning-AI/utilities/pull/420))
- Scripts: fix CLI parsing ([#419](https://github.com/Lightning-AI/utilities/pull/419))


Expand Down
19 changes: 11 additions & 8 deletions src/lightning_utilities/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@ def main() -> None:
from jsonargparse import auto_cli, set_parsing_settings

set_parsing_settings(parse_optionals_as_positionals=True)
auto_cli({
"requirements": {
"_help": "Manage requirements files.",
"prune-pkgs": prune_packages_in_requirements,
"set-oldest": replace_oldest_version,
"replace-pkg": replace_package_in_requirements,
auto_cli(
{
"requirements": {
"_help": "Manage requirements files.",
"prune-pkgs": prune_packages_in_requirements,
"set-oldest": replace_oldest_version,
"replace-pkg": replace_package_in_requirements,
},
"version": _get_version,
},
"version": _get_version,
})
as_positional=False,
)


if __name__ == "__main__":
Expand Down
66 changes: 66 additions & 0 deletions tests/unittests/cli/test_command_line.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import subprocess
import sys
from pathlib import Path

import pytest


def test_version():
    """Run the ``version`` CLI command and check that it exits successfully."""
    # Use the interpreter that runs the test suite instead of whatever ``python`` resolves to on PATH,
    # so the subprocess sees the same environment (and installed package) as pytest itself.
    return_code = subprocess.call([sys.executable, "-m", "lightning_utilities.cli", "version"])  # noqa: S603
    assert return_code == 0


@pytest.mark.parametrize("args", ["positional", "optional"])
class TestRequirements:
    """Test the ``requirements`` CLI sub-commands.

    Every test method is parametrized by ``args`` (``"positional"`` or ``"optional"``), which selects
    how the CLI parameters are rendered by ``_build_command``.
    """

    # Run the CLI with the interpreter executing the tests rather than the first ``python`` on PATH,
    # which may be missing (e.g. only ``python3`` installed) or belong to a different environment.
    BASE_CMD = (sys.executable, "-m", "lightning_utilities.cli", "requirements")
    REQUIREMENTS_SAMPLE = """
# This is sample requirements file
# with multi line comments

torchvision >=0.13.0, <0.16.0 # sample # comment
gym[classic,control] >=0.17.0, <0.27.0
ipython[all] <8.15.0 # strict
torchmetrics >=0.10.0, <1.3.0
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
"""

    def _create_requirements_file(self, local_path: Path, filename: str = "testing-cli-requirements.txt") -> str:
        """Write ``REQUIREMENTS_SAMPLE`` to ``local_path / filename`` and return that path as a string."""
        req_file = local_path / filename
        req_file.write_text(self.REQUIREMENTS_SAMPLE, encoding="utf8")
        return str(req_file)

    def _build_command(self, subcommand: str, cli_params: tuple, arg_style: str) -> list:
        """Build the full CLI invocation for a ``requirements`` sub-command.

        Args:
            subcommand: name of the sub-command, e.g. ``"prune-pkgs"``.
            cli_params: ordered ``(name, value)`` pairs; with ``"positional"`` style the order must
                match the target function's signature.
            arg_style: ``"positional"`` passes only the values, ``"optional"`` passes ``--name=value``.

        Returns:
            The argv list to hand to ``subprocess``.

        Raises:
            ValueError: if ``arg_style`` is not one of the supported styles.
        """
        if arg_style == "positional":
            extra = [value for _, value in cli_params]
        elif arg_style == "optional":
            extra = [f"--{key}={value}" for key, value in cli_params]
        else:
            raise ValueError(f"Unknown test configuration: {arg_style}")
        return [*self.BASE_CMD, subcommand, *extra]

    def _run(self, cmd: list) -> None:
        """Run ``cmd`` and fail with its captured stdout/stderr if it exits with a non-zero code."""
        result = subprocess.run(cmd, capture_output=True, text=True, check=False)  # noqa: S603
        # Surfacing the CLI output makes parsing failures (the bugs these tests guard) diagnosable.
        assert result.returncode == 0, f"{cmd} failed:\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"

    def test_requirements_prune_pkgs(self, args, tmp_path):
        """Prune ``ipython`` from a requirements file."""
        req_file = self._create_requirements_file(tmp_path)
        cli_params = (("packages", "ipython"), ("req_files", req_file))
        self._run(self._build_command("prune-pkgs", cli_params, args))
        # The pruned package must be gone while unrelated requirements are kept.
        content = Path(req_file).read_text(encoding="utf8")
        assert "ipython" not in content
        assert "torchvision" in content

    def test_requirements_set_oldest(self, args, tmp_path):
        """Pin packages to the oldest (lower-bound) versions in a requirements file."""
        req_file = self._create_requirements_file(tmp_path, "requirements.txt")
        cli_params = (("req_files", req_file),)
        self._run(self._build_command("set-oldest", cli_params, args))
        # Lower bounds are expected to be turned into pins; no ``>=`` specifier should remain.
        content = Path(req_file).read_text(encoding="utf8")
        assert ">=" not in content

    def test_requirements_replace_pkg(self, args, tmp_path):
        """Replace ``torchvision`` with ``torchtext`` in a requirements file."""
        req_file = self._create_requirements_file(tmp_path, "requirements.txt")
        cli_params = (("old_package", "torchvision"), ("new_package", "torchtext"), ("req_files", req_file))
        self._run(self._build_command("replace-pkg", cli_params, args))
        content = Path(req_file).read_text(encoding="utf8")
        assert "torchvision" not in content
        assert "torchtext" in content
Loading