diff --git a/tests/test_updater_ng.py b/tests/test_updater_ng.py index 3b183cfbb3..b853d40dda 100644 --- a/tests/test_updater_ng.py +++ b/tests/test_updater_ng.py @@ -18,7 +18,6 @@ from securesystemslib.signer import SSlibSigner from tests import utils -from tuf import ngclient from tuf.api import exceptions from tuf.api.metadata import ( Metadata, @@ -28,6 +27,7 @@ Targets, Timestamp, ) +from tuf.ngclient import Updater, UpdaterConfig logger = logging.getLogger(__name__) @@ -107,7 +107,7 @@ def setUp(self) -> None: self.dl_dir = tempfile.mkdtemp(dir=self.tmp_test_dir) # Creating a repository instance. The test cases will use this client # updater to refresh metadata, fetch target files, etc. - self.updater = ngclient.Updater( + self.updater = Updater( metadata_dir=self.client_directory, metadata_base_url=self.metadata_url, target_dir=self.dl_dir, @@ -242,16 +242,14 @@ def test_implicit_refresh_with_only_local_root(self) -> None: def test_both_target_urls_not_set(self) -> None: # target_base_url = None and Updater._target_base_url = None - updater = ngclient.Updater( - self.client_directory, self.metadata_url, self.dl_dir - ) + updater = Updater(self.client_directory, self.metadata_url, self.dl_dir) info = TargetFile(1, {"sha256": ""}, "targetpath") with self.assertRaises(ValueError): updater.download_target(info) def test_no_target_dir_no_filepath(self) -> None: # filepath = None and Updater.target_dir = None - updater = ngclient.Updater(self.client_directory, self.metadata_url) + updater = Updater(self.client_directory, self.metadata_url) info = TargetFile(1, {"sha256": ""}, "targetpath") with self.assertRaises(ValueError): updater.find_cached_target(info) @@ -323,6 +321,27 @@ def test_non_existing_target_file(self) -> None: with self.assertRaises(exceptions.DownloadHTTPError): self.updater.download_target(info) + def test_user_agent(self) -> None: + # test default + self.updater.refresh() + session = next(iter(self.updater._fetcher._sessions.values())) + ua = session.headers["User-Agent"] + self.assertEqual(ua[:4], "tuf/") + + # test custom UA + updater = Updater( + self.client_directory, + self.metadata_url, + self.dl_dir, + self.targets_url, + config=UpdaterConfig(app_user_agent="MyApp/1.2.3"), + ) + updater.refresh() + session = next(iter(updater._fetcher._sessions.values())) + ua = session.headers["User-Agent"] + + self.assertEqual(ua[:16], "MyApp/1.2.3 tuf/") + if __name__ == "__main__": utils.configure_test_logging(sys.argv) diff --git a/tuf/ngclient/_internal/requests_fetcher.py b/tuf/ngclient/_internal/requests_fetcher.py index 1994729fe0..c931b85a0f 100644 --- a/tuf/ngclient/_internal/requests_fetcher.py +++ b/tuf/ngclient/_internal/requests_fetcher.py @@ -10,7 +10,7 @@ # can be moved out of _internal once sigstore-python 1.0 is not relevant. import logging -from typing import Dict, Iterator, Tuple +from typing import Dict, Iterator, Optional, Tuple from urllib import parse # Imports @@ -35,7 +35,10 @@ class RequestsFetcher(FetcherInterface): """ def __init__( - self, socket_timeout: int = 30, chunk_size: int = 400000 + self, + socket_timeout: int = 30, + chunk_size: int = 400000, + app_user_agent: Optional[str] = None, ) -> None: # http://docs.python-requests.org/en/master/user/advanced/#session-objects: # @@ -56,6 +59,7 @@ def __init__( # Default settings self.socket_timeout: int = socket_timeout # seconds self.chunk_size: int = chunk_size # bytes + self.app_user_agent = app_user_agent def _fetch(self, url: str) -> Iterator[bytes]: """Fetch the contents of HTTP/HTTPS url from a remote server. @@ -138,6 +142,8 @@ def _get_session(self, url: str) -> requests.Session: self._sessions[session_index] = session ua = f"tuf/{tuf.__version__} {session.headers['User-Agent']}" + if self.app_user_agent is not None: + ua = f"{self.app_user_agent} {ua}" session.headers["User-Agent"] = ua logger.debug("Made new session %s", session_index) diff --git a/tuf/ngclient/config.py b/tuf/ngclient/config.py index 943018fca4..8019c4d26d 100644 --- a/tuf/ngclient/config.py +++ b/tuf/ngclient/config.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from enum import Flag, unique +from typing import Optional @unique @@ -39,6 +40,8 @@ class UpdaterConfig: envelope_type: Configures deserialization and verification mode of TUF metadata. Per default, it is treated as traditional canonical JSON -based TUF Metadata. + app_user_agent: Application user agent, e.g. "MyApp/1.0.0". This will be + prefixed to ngclient user agent when the default fetcher is used. """ max_root_rotations: int = 32 @@ -49,3 +52,4 @@ class UpdaterConfig: targets_max_length: int = 5000000 # bytes prefix_targets_with_hash: bool = True envelope_type: EnvelopeType = EnvelopeType.METADATA + app_user_agent: Optional[str] = None diff --git a/tuf/ngclient/updater.py b/tuf/ngclient/updater.py index 666e54d320..145074aaa9 100644 --- a/tuf/ngclient/updater.py +++ b/tuf/ngclient/updater.py @@ -93,11 +93,15 @@ def __init__( else: self._target_base_url = _ensure_trailing_slash(target_base_url) - # Read trusted local root metadata - data = self._load_local_metadata(Root.type) - self._fetcher = fetcher or requests_fetcher.RequestsFetcher() self.config = config or UpdaterConfig() + if fetcher is not None: + self._fetcher = fetcher + else: + self._fetcher = requests_fetcher.RequestsFetcher( + app_user_agent=self.config.app_user_agent + ) + supported_envelopes = [EnvelopeType.METADATA, EnvelopeType.SIMPLE] if self.config.envelope_type not in supported_envelopes: raise ValueError( @@ -105,6 +109,9 @@ def __init__( f"got '{self.config.envelope_type}'" ) + # Read trusted local root metadata + data = self._load_local_metadata(Root.type) + self._trusted_set = trusted_metadata_set.TrustedMetadataSet( data, self.config.envelope_type )