"""
This script generates a pip-compatible requirements file (e.g. requirements/core-requirements.txt)
based on a YAML dependencies specification (e.g. requirements/core-requirements.yaml). These
requirements files are then read by `setup.py` to specify dependencies for MLflow wheel
installation.
"""

import argparse
import sys

import yaml

# Comment banner written at the top of every generated requirements file (followed by a blank
# line) so readers know the file must be regenerated rather than edited by hand.
AUTOGENERATED_HEADER = """\
# This file is automatically generated by running the dev/generate_requirements.py script. Do not
# manually modify this file. Instead, modify the corresponding requirements YAML file and run the
# dev/generate_requirements.py script to regenerate this file.
"""


def parse_args(args):
    """Parse the command-line arguments of the requirements generator.

    Both options are mandatory; argparse exits with status 2 if either is missing.

    Args:
        args: Argument strings to parse, excluding the program name (e.g. ``sys.argv[1:]``).

    Returns:
        An ``argparse.Namespace`` exposing ``requirements_yaml_location`` and
        ``requirements_txt_output_location``.
    """
    description = (
        "Generate a pip-compatible requirements.txt file from an MLflow"
        " requirements.yaml specification"
    )
    parser = argparse.ArgumentParser(description=description)

    # (flag, help text) pairs; every option is required.
    option_specs = (
        (
            "--requirements-yaml-location",
            "Local file path of the requirements.yaml specification.",
        ),
        (
            "--requirements-txt-output-location",
            "Local output path for the generated requirements.txt file.",
        ),
    )
    for flag, help_text in option_specs:
        parser.add_argument(flag, required=True, help=help_text)

    return parser.parse_args(args)


def _require(condition, message):
    """Raise ``AssertionError`` with ``message`` unless ``condition`` is truthy.

    Unlike an ``assert`` statement, this check is not removed when Python runs with ``-O``.
    ``AssertionError`` is kept as the exception type for compatibility with the original
    ``assert``-based validation.
    """
    if not condition:
        raise AssertionError(message)


def _is_list_of_str(value):
    """Return True if ``value`` is a list whose elements are all strings."""
    return isinstance(value, list) and all(isinstance(item, str) for item in value)


def validate_requirements_yaml(requirements_yaml):
    """Check that a parsed requirements.yaml specification has the expected structure.

    The specification must be a dictionary whose values are entry dictionaries. Each entry must
    define a string ``pip_release`` and an integer ``max_major_version``. It may also define a
    string ``minimum``, a list of strings ``unsupported``, a list of strings ``extras`` and a
    string ``markers``.

    Args:
        requirements_yaml: The object produced by loading the requirements YAML file.

    Raises:
        AssertionError: If the specification or any of its entries is malformed.
    """
    _require(
        isinstance(requirements_yaml, dict),
        "requirements.yaml contents must be a YAML dictionary",
    )
    for package_entry in requirements_yaml.values():
        _require(
            isinstance(package_entry, dict),
            "Entry in requirements.yaml does not have required dictionary"
            f" structure: {package_entry}",
        )
        _require(
            isinstance(package_entry.get("pip_release"), str),
            "Entry in requirements.yaml does not define a valid 'pip_release'"
            f" string value: {package_entry}",
        )
        max_major_version = package_entry.get("max_major_version")
        # bool is a subclass of int, so YAML `true`/`false` must be rejected explicitly;
        # otherwise `true` would silently produce a "<2" upper bound.
        _require(
            isinstance(max_major_version, int) and not isinstance(max_major_version, bool),
            "Entry in requirements.yaml does not specify a valid 'max_major_version'"
            f" integer value: {package_entry}",
        )
        if "minimum" in package_entry:
            _require(
                isinstance(package_entry["minimum"], str),
                "Entry in requirements.yaml contains an invalid 'minimum' version"
                f" string specification: {package_entry}",
            )
        if "unsupported" in package_entry:
            _require(
                _is_list_of_str(package_entry["unsupported"]),
                "Entry in requirements.yaml contains an invalid 'unsupported' versions"
                " specification. Unsupported versions should be specified as lists of strings:"
                f" {package_entry}",
            )
        if "extras" in package_entry:
            # Extras are comma-joined when rendering; a bare string would be split into
            # single characters (e.g. "gcp" -> "[g,c,p]"), so require a list of strings.
            _require(
                _is_list_of_str(package_entry["extras"]),
                "Entry in requirements.yaml contains an invalid 'extras'"
                " specification. Extras should be specified as lists of strings:"
                f" {package_entry}",
            )
        if "markers" in package_entry:
            _require(
                isinstance(package_entry["markers"], str),
                "Entry in requirements.yaml contains invalid 'markers' string"
                f" specification: {package_entry}",
            )


def generate_requirements_txt_content(requirements_yaml):
    """Render requirements entries as pip requirement lines.

    Each entry becomes ``<pip_release>[<extras>]<specifiers>[; <markers>]``. The comma-joined
    version specifiers appear in this order:

    - ``<(max_major_version + 1)``
    - ``>=minimum``, if a non-empty minimum is given
    - one ``!=version`` for each unsupported version

    Args:
        requirements_yaml: Mapping of package keys to entry dictionaries, in the structure
            checked by ``validate_requirements_yaml``.

    Returns:
        The requirement lines joined with newlines, without a trailing newline.
    """
    lines = []
    for entry in requirements_yaml.values():
        release = entry["pip_release"]

        extras_list = entry.get("extras")
        extras_part = f"[{','.join(extras_list)}]" if extras_list else ""

        # The upper bound excludes the next major version entirely.
        specifiers = [f"<{entry['max_major_version'] + 1}"]
        minimum = entry.get("minimum")
        if minimum:
            specifiers.append(f">={minimum}")
        specifiers.extend(f"!={version}" for version in entry.get("unsupported", []))

        marker_expr = entry.get("markers")
        marker_part = f"; {marker_expr}" if marker_expr else ""

        lines.append(f"{release}{extras_part}{','.join(specifiers)}{marker_part}")

    return "\n".join(lines)


def main(args):
    """Generate a requirements.txt file from a requirements.yaml specification.

    The YAML spec is loaded with PyYAML's ``SafeLoader`` and validated. The rendered
    requirements are then written after ``AUTOGENERATED_HEADER`` and a blank line, ending
    with a trailing newline.

    Args:
        args: Command-line argument strings, excluding the program name, as accepted by
            ``parse_args``.

    Raises:
        AssertionError: If the specification fails ``validate_requirements_yaml``.
    """
    args = parse_args(args)
    # Use an explicit encoding so reading does not depend on the platform's locale encoding.
    with open(args.requirements_yaml_location, encoding="utf-8") as f:
        requirements_yaml = yaml.load(f, Loader=yaml.SafeLoader)
    validate_requirements_yaml(requirements_yaml)
    requirements_txt_content = generate_requirements_txt_content(requirements_yaml)
    # newline="\n" writes LF endings on every platform, so regenerating the checked-in file
    # on Windows does not introduce CRLF line-ending churn.
    with open(
        args.requirements_txt_output_location, "w", encoding="utf-8", newline="\n"
    ) as f:
        # Write requirements file content with a trailing newline
        f.write(AUTOGENERATED_HEADER + "\n")
        f.write(requirements_txt_content + "\n")


# Script entry point: pass the command-line arguments, minus the program name, to main().
if __name__ == "__main__":
    main(sys.argv[1:])
