Source code for automation_file.remote.http_download

"""SSRF-guarded HTTP downloader."""

from __future__ import annotations

import contextlib
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import requests
from tqdm import tqdm

from automation_file.core.checksum import verify_checksum
from automation_file.core.progress import (
    CancelledException,
    ProgressReporter,
    progress_registry,
)
from automation_file.core.retry import retry_on_transient
from automation_file.exceptions import RetryExhaustedException, UrlValidationException
from automation_file.logging_config import file_automation_logger
from automation_file.remote.url_validator import validate_http_url

_DEFAULT_TIMEOUT_SECONDS = 15
_DEFAULT_CHUNK_SIZE = 1024 * 64
_MAX_RESPONSE_BYTES = 20 * 1024 * 1024

_RETRIABLE_EXCEPTIONS = (
    requests.exceptions.ConnectionError,
    requests.exceptions.Timeout,
    requests.exceptions.ChunkedEncodingError,
)


@dataclass(frozen=True)
class _StreamContext:
    """Bundle of per-download stream knobs; avoids 10+ positional args."""

    write_mode: str
    chunk_size: int
    start_byte: int
    total_size: int
    max_bytes: int
    reporter: ProgressReporter | None
    token: Any


@retry_on_transient(max_attempts=3, backoff_base=0.5, retriable=_RETRIABLE_EXCEPTIONS)
def _open_stream(
    file_url: str,
    timeout: int,
    start_byte: int = 0,
) -> requests.Response:
    headers = {"Range": f"bytes={start_byte}-"} if start_byte > 0 else None
    response = requests.get(
        file_url,
        stream=True,
        timeout=timeout,
        allow_redirects=False,
        headers=headers,
    )
    response.raise_for_status()
    return response


def _resume_layout(target: Path, resume: bool) -> tuple[Path, int, str]:
    part_path = target.with_suffix(target.suffix + ".part") if resume else target
    start_byte = part_path.stat().st_size if resume and part_path.exists() else 0
    write_mode = "ab" if start_byte > 0 else "wb"
    return part_path, start_byte, write_mode


def _open_stream_logged(file_url: str, timeout: int, start_byte: int) -> requests.Response | None:
    try:
        return _open_stream(file_url, timeout, start_byte=start_byte)
    except RetryExhaustedException as error:
        file_automation_logger.error("download_file retries exhausted: %r", error)
    except requests.exceptions.HTTPError as error:
        file_automation_logger.error("download_file HTTP error: %r", error)
    except requests.exceptions.RequestException as error:
        file_automation_logger.error("download_file request error: %r", error)
    return None


def _stream_to_disk(
    response: requests.Response,
    part_path: Path,
    target: Path,
    ctx: _StreamContext,
) -> int | None:
    """Stream ``response`` into ``part_path``. Returns bytes written, or None on failure."""
    written = ctx.start_byte
    with (
        open(part_path, ctx.write_mode) as output,  # pylint: disable=unspecified-encoding
        _progress(ctx.total_size, str(target)) as progress,
    ):
        if ctx.start_byte > 0:
            progress.update(ctx.start_byte)
        for chunk in response.iter_content(chunk_size=ctx.chunk_size):
            if ctx.token is not None:
                ctx.token.raise_if_cancelled()
            if not chunk:
                continue
            written += len(chunk)
            if written > ctx.max_bytes:
                file_automation_logger.error(
                    "download_file aborted: stream exceeded %d bytes", ctx.max_bytes
                )
                return None
            output.write(chunk)
            progress.update(len(chunk))
            if ctx.reporter is not None:
                ctx.reporter.update(len(chunk))
    return written


def _verify_and_finalize(
    target: Path,
    expected_sha256: str | None,
    reporter: ProgressReporter | None,
) -> bool:
    if expected_sha256 and not verify_checksum(target, expected_sha256):
        file_automation_logger.error("download_file checksum mismatch for %s; removing", target)
        with contextlib.suppress(OSError):
            target.unlink()
        if reporter is not None:
            reporter.finish(status="checksum_failed")
        return False
    if reporter is not None:
        reporter.finish(status="done")
    return True


def _reject_oversize(total_size: int, max_bytes: int) -> bool:
    if total_size > max_bytes:
        file_automation_logger.error(
            "download_file rejected: content-length %d > %d", total_size, max_bytes
        )
        return True
    return False


def _make_reporter(
    progress_name: str | None,
    total_size: int,
    start_byte: int,
) -> tuple[ProgressReporter | None, Any]:
    if not progress_name:
        return None, None
    reporter, token = progress_registry.create(progress_name, total=total_size or None)
    if start_byte > 0:
        reporter.update(start_byte)
    return reporter, token


def _run_stream(
    response: requests.Response,
    part_path: Path,
    target: Path,
    ctx: _StreamContext,
    progress_name: str | None,
) -> int | None:
    try:
        written = _stream_to_disk(response, part_path, target, ctx)
    except CancelledException:
        file_automation_logger.warning("download_file cancelled: %s", progress_name)
        if ctx.reporter is not None:
            ctx.reporter.finish(status="cancelled")
        return None
    except OSError as error:
        file_automation_logger.error("download_file write error: %r", error)
        if ctx.reporter is not None:
            ctx.reporter.finish(status="error")
        return None
    if written is None and ctx.reporter is not None:
        ctx.reporter.finish(status="aborted")
    return written


# pylint: disable-next=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-return-statements
[docs] def download_file( file_url: str, file_name: str, chunk_size: int = _DEFAULT_CHUNK_SIZE, timeout: int = _DEFAULT_TIMEOUT_SECONDS, max_bytes: int = _MAX_RESPONSE_BYTES, progress_name: str | None = None, resume: bool = False, expected_sha256: str | None = None, ) -> bool: """Download ``file_url`` to ``file_name`` with progress display. Validates the URL against SSRF rules, disables redirects, enforces a size cap, retries transient network errors up to three times, and uses default TLS verification. Pass ``progress_name`` to register the transfer with :data:`~automation_file.core.progress.progress_registry` so it can be polled or cancelled from the GUI. When ``resume=True`` the download streams into ``<file_name>.part`` and, if that file already exists from a previous interrupted run, sends ``Range: bytes=<existing>-`` to continue where it stopped (append mode); the ``.part`` file is atomically renamed to ``file_name`` on success. Pass ``expected_sha256`` to verify the finished file; a mismatch is logged, the download is removed, and the function returns ``False``. Returns True on success. """ try: validate_http_url(file_url) except UrlValidationException as error: file_automation_logger.error("download_file rejected URL: %r", error) return False target = Path(file_name) part_path, start_byte, write_mode = _resume_layout(target, resume) response = _open_stream_logged(file_url, timeout, start_byte) if response is None: return False total_size = start_byte + int(response.headers.get("content-length", 0)) if _reject_oversize(total_size, max_bytes): return False reporter, token = _make_reporter(progress_name, total_size, start_byte) ctx = _StreamContext( write_mode=write_mode, chunk_size=chunk_size, start_byte=start_byte, total_size=total_size, max_bytes=max_bytes, reporter=reporter, token=token, ) written = _run_stream(response, part_path, target, ctx, progress_name) if written is None: return False if resume and part_path != target: os.replace(part_path, target) if not _verify_and_finalize(target, expected_sha256, reporter): return False file_automation_logger.info("download_file: %s -> %s (%d bytes)", file_url, target, written) return True
class _NullBar: def update(self, _n: int) -> None: ... def __enter__(self): return self def __exit__(self, exc_type, exc, tb): return False def _progress(total: int, label: str): if total > 0: return tqdm(total=total, unit="B", unit_scale=True, desc=label) return _NullBar()