diff --git a/custom_components/hacs/repositories/base.py b/custom_components/hacs/repositories/base.py index 950f52facd3..97e1ef154e7 100644 --- a/custom_components/hacs/repositories/base.py +++ b/custom_components/hacs/repositories/base.py @@ -208,7 +208,9 @@ def update_data(self, data: dict, action: bool = False) -> None: else: setattr(self, key, value) elif key == "topics" and not action: - setattr(self, key, [topic for topic in value if topic not in TOPIC_FILTER]) + setattr( + self, key, [topic for topic in value if topic not in TOPIC_FILTER] + ) else: setattr(self, key, value) @@ -341,7 +343,12 @@ def display_name(self) -> str: if "name" in self.integration_manifest: return self.integration_manifest["name"] - return self.data.full_name.split("/")[-1].replace("-", " ").replace("_", " ").title() + return ( + self.data.full_name.split("/")[-1] + .replace("-", " ") + .replace("_", " ") + .title() + ) @property def ignored_by_country_configuration(self) -> bool: @@ -490,8 +497,10 @@ async def common_registration(self) -> None: # Attach repository if self.repository_object is None: try: - self.repository_object, etag = await self.async_get_legacy_repository_object( - etag=None if self.data.installed else self.data.etag_repository, + self.repository_object, etag = ( + await self.async_get_legacy_repository_object( + etag=None if self.data.installed else self.data.etag_repository, + ) ) self.data.update_data( self.repository_object.attributes, @@ -499,15 +508,21 @@ async def common_registration(self) -> None: ) self.data.etag_repository = etag except HacsNotModifiedException: - self.logger.debug("%s Did not update, content was not modified", self.string) + self.logger.debug( + "%s Did not update, content was not modified", self.string + ) return if self.repository_object: - self.data.last_updated = self.repository_object.attributes.get("pushed_at", 0) + self.data.last_updated = self.repository_object.attributes.get( + "pushed_at", 0 + ) self.data.last_fetched = datetime.now(UTC) @concurrent(concurrenttasks=10, backoff_time=5) - async def common_update(self, ignore_issues=False, force=False, skip_releases=False) -> bool: + async def common_update( + self, ignore_issues=False, force=False, skip_releases=False + ) -> bool: """Common information update steps of the repository.""" self.logger.debug("%s Getting repository information", self.string) @@ -520,20 +535,30 @@ async def common_update(self, ignore_issues=False, force=False, skip_releases=Fa skip_releases=skip_releases, ) except HacsRepositoryExistException: - self.data.full_name = self.hacs.common.renamed_repositories[self.data.full_name] + self.data.full_name = self.hacs.common.renamed_repositories[ + self.data.full_name + ] await self.common_update_data(ignore_issues=ignore_issues, force=force) except HacsException: if not ignore_issues and not force: return False - if not self.data.installed and (current_etag == self.data.etag_repository) and not force: - self.logger.debug("%s Did not update, content was not modified", self.string) + if ( + not self.data.installed + and (current_etag == self.data.etag_repository) + and not force + ): + self.logger.debug( + "%s Did not update, content was not modified", self.string + ) return False # Update last updated if self.repository_object: - self.data.last_updated = self.repository_object.attributes.get("pushed_at", 0) + self.data.last_updated = self.repository_object.attributes.get( + "pushed_at", 0 + ) # Update last available commit await self.repository_object.set_last_commit() @@ -608,7 +633,9 @@ def cleanup_temp_dir(): shutil.rmtree(temp_dir) if result: - self.logger.info("%s Download of %s completed", self.string, content["name"]) + self.logger.info( + "%s Download of %s completed", self.string, content["name"] + ) await self.hacs.hass.async_add_executor_job(cleanup_temp_dir) return @@ -647,7 +674,10 @@ async def download_content(self, version: string | None = None) -> None: download_queue = QueueManager(hass=self.hacs.hass) for content in contents: - if self.repository_manifest.content_in_root and self.repository_manifest.filename: + if ( + self.repository_manifest.content_in_root + and self.repository_manifest.filename + ): if content.name != self.repository_manifest.filename: continue download_queue.add(self.dowload_repository_content(content)) @@ -669,7 +699,9 @@ async def download_repository_zip(self): if filecontent is None: filecontent = await self.hacs.async_download_file( - github_archive(repository=self.data.full_name, version=ref, variant="heads"), + github_archive( + repository=self.data.full_name, version=ref, variant="heads" + ), keep_url=True, ) if filecontent is None: @@ -709,7 +741,9 @@ def cleanup_temp_dir(): shutil.rmtree(temp_dir) await self.hacs.hass.async_add_executor_job(cleanup_temp_dir) - self.logger.info("%s Content was extracted to %s", self.string, self.content.path.local) + self.logger.info( + "%s Content was extracted to %s", self.string, self.content.path.local + ) async def async_get_hacs_json(self, ref: str = None) -> dict[str, Any] | None: """Get the content of the hacs.json file.""" @@ -727,7 +761,9 @@ async def async_get_hacs_json(self, ref: str = None) -> dict[str, Any] | None: except BaseException: pass - async def async_get_info_file_contents(self, *, version: str | None = None, **kwargs) -> str: + async def async_get_info_file_contents( + self, *, version: str | None = None, **kwargs + ) -> str: """Get the content of the info.md file.""" def _info_file_variants() -> tuple[str, ...]: @@ -741,12 +777,16 @@ def _info_file_variants() -> tuple[str, ...]: name, ) - info_files = [filename for filename in _info_file_variants() if filename in self.treefiles] + info_files = [ + filename for filename in _info_file_variants() if filename in self.treefiles + ] if not info_files: return "" - return await self.get_documentation(filename=info_files[0], version=version) or "" + return ( + await self.get_documentation(filename=info_files[0], version=version) or "" + ) def remove(self) -> None: """Run remove tasks.""" @@ -808,7 +848,9 @@ async def remove_local_directory(self) -> None: if await async_exists(self.hacs.hass, local_path): if not is_safe(self.hacs, local_path): - self.logger.error("%s Path %s is blocked from removal", self.string, local_path) + self.logger.error( + "%s Path %s is blocked from removal", self.string, local_path + ) return False self.logger.debug("%s Removing %s", self.string, local_path) @@ -821,14 +863,18 @@ async def remove_local_directory(self) -> None: await sleep(1) else: self.logger.debug( - "%s Presumed local content path %s does not exist", self.string, local_path + "%s Presumed local content path %s does not exist", + self.string, + local_path, ) except ( # lgtm [py/catch-base-exception] pylint: disable=broad-except BaseException ) as exception: - self.logger.debug("%s Removing %s failed with %s", self.string, local_path, exception) + self.logger.debug( + "%s Removing %s failed with %s", self.string, local_path, exception + ) return False return True @@ -918,7 +964,9 @@ async def _async_post_install(self) -> None: ) self.logger.info("%s Post installation steps completed", self.string) - async def async_install_repository(self, *, version: str | None = None, **_) -> None: + async def async_install_repository( + self, *, version: str | None = None, **_ + ) -> None: """Common installation steps of the repository.""" persistent_directory = None force_update = version is None or ( @@ -956,9 +1004,15 @@ async def async_install_repository(self, *, version: str | None = None, **_) -> backup = Backup(hacs=self.hacs, local_path=self.content.path.local) await self.hacs.hass.async_add_executor_job(backup.create) - self.hacs.log.debug("%s Local path is set to %s", self.string, self.content.path.local) - self.hacs.log.debug("%s Remote path is set to %s", self.string, self.content.path.remote) - self.hacs.log.debug("%s Version to install: %s", self.string, version_to_install) + self.hacs.log.debug( + "%s Local path is set to %s", self.string, self.content.path.local + ) + self.hacs.log.debug( + "%s Remote path is set to %s", self.string, self.content.path.remote + ) + self.hacs.log.debug( + "%s Version to install: %s", self.string, version_to_install + ) self.hacs.async_dispatch( HacsDispatchEvent.REPOSITORY_DOWNLOAD_PROGRESS, @@ -1030,7 +1084,9 @@ async def get_tree(self, ref: str): except (ValueError, AIOGitHubAPIException) as exception: raise HacsException(exception) from exception - async def get_releases(self, prerelease=False, returnlimit=5) -> list[GitHubReleaseModel]: + async def get_releases( + self, prerelease=False, returnlimit=5 + ) -> list[GitHubReleaseModel]: """Return the repository releases.""" response = await self.hacs.async_github_api_method( method=self.hacs.githubapi.repos.releases.list, @@ -1056,7 +1112,9 @@ async def common_update_data( releases = [] try: repository_object, etag = await self.async_get_legacy_repository_object( - etag=None if force or self.data.installed else self.data.etag_repository, + etag=( + None if force or self.data.installed else self.data.etag_repository + ), ) self.repository_object = repository_object if self.data.full_name.lower() != repository_object.full_name.lower(): @@ -1066,7 +1124,9 @@ async def common_update_data( if not self.hacs.system.generator: raise HacsRepositoryExistException self.logger.error( - "%s Repository has been renamed - %s", self.string, repository_object.full_name + "%s Repository has been renamed - %s", + self.string, + repository_object.full_name, ) self.data.update_data( repository_object.attributes, @@ -1095,8 +1155,12 @@ async def common_update_data( if self.hacs.repositories.is_removed(self.data.full_name): removed = self.hacs.repositories.removed_repository(self.data.full_name) if removed.removal_type != "remove" and not ignore_issues: - self.validate.errors.append("Repository has been requested to be removed.") - raise HacsException(f"{self} Repository has been requested to be removed.") + self.validate.errors.append( + "Repository has been requested to be removed." + ) + raise HacsException( + f"{self} Repository has been requested to be removed." + ) # Get releases. if not skip_releases: @@ -1119,7 +1183,8 @@ async def common_update_data( filtered_releases = [ release for release in releases - if not release.draft and (self.data.show_beta or not release.prerelease) + if not release.draft + and (self.data.show_beta or not release.prerelease) ] self.releases.objects = filtered_releases self.data.published_tags = [x.tag_name for x in filtered_releases] @@ -1133,7 +1198,20 @@ async def common_update_data( for release in self.releases.objects or []: if release.tag_name == self.ref: if assets := release.assets: - downloads = next(iter(assets)).download_count + # Find the correct asset based on file_name, fallback to first asset + target_asset = None + if self.data.file_name: + for asset in assets: + if asset.name == self.data.file_name: + target_asset = asset + break + + # Use the target asset if found, otherwise use the first asset + if target_asset: + downloads = target_asset.download_count + else: + downloads = next(iter(assets)).download_count + self.data.downloads = downloads elif self.hacs.system.generator and self.repository_object: await self.repository_object.set_last_commit() @@ -1184,7 +1262,9 @@ def gather_files_to_download(self) -> list[FileInformation]: if ref == release.tag_name: for asset in release.assets or []: files.append( - FileInformation(asset.browser_download_url, asset.name, asset.name) + FileInformation( + asset.browser_download_url, asset.name, asset.name + ) ) if files: return files @@ -1202,7 +1282,9 @@ def gather_files_to_download(self) -> list[FileInformation]: if category == "plugin": for treefile in tree: if treefile.path in ["", "dist"]: - if remotelocation == "dist" and not treefile.filename.startswith("dist"): + if remotelocation == "dist" and not treefile.filename.startswith( + "dist" + ): continue if not remotelocation: if not treefile.filename.endswith(".js"): @@ -1212,7 +1294,9 @@ def gather_files_to_download(self) -> list[FileInformation]: if not treefile.is_directory: files.append( FileInformation( - treefile.download_url, treefile.full_path, treefile.filename + treefile.download_url, + treefile.full_path, + treefile.filename, ) ) if files: @@ -1221,16 +1305,22 @@ def gather_files_to_download(self) -> list[FileInformation]: if self.repository_manifest.content_in_root: if not self.repository_manifest.filename: if category == "theme": - tree = filter_content_return_one_of_type(self.tree, "", "yaml", "full_path") + tree = filter_content_return_one_of_type( + self.tree, "", "yaml", "full_path" + ) for path in tree: if path.is_directory: continue if path.full_path.startswith(self.content.path.remote): - files.append(FileInformation(path.download_url, path.full_path, path.filename)) + files.append( + FileInformation(path.download_url, path.full_path, path.filename) + ) return files - async def release_contents(self, version: str | None = None) -> list[FileInformation] | None: + async def release_contents( + self, version: str | None = None + ) -> list[FileInformation] | None: """Gather the contents of a release.""" release = await self.hacs.async_github_api_method( method=self.hacs.githubapi.generic, @@ -1268,7 +1358,9 @@ async def dowload_repository_content(self, content: FileInformation) -> None: else: _content_path = content.path if not self.repository_manifest.content_in_root: - _content_path = _content_path.replace(f"{self.content.path.remote}", "") + _content_path = _content_path.replace( + f"{self.content.path.remote}", "" + ) local_directory = f"{self.content.path.local}/{_content_path}" local_directory = local_directory.split("/") @@ -1282,7 +1374,9 @@ async def dowload_repository_content(self, content: FileInformation) -> None: result = await self.hacs.async_save_file(local_file_path, filecontent) if result: - self.logger.info("%s Download of %s completed", self.string, content.name) + self.logger.info( + "%s Download of %s completed", self.string, content.name + ) return self.validate.errors.append(f"[{content.name}] was not downloaded.") @@ -1295,7 +1389,9 @@ async def dowload_repository_content(self, content: FileInformation) -> None: async def async_remove_entity_device(self) -> None: """Remove the entity device.""" device_registry: dr.DeviceRegistry = dr.async_get(hass=self.hacs.hass) - device = device_registry.async_get_device(identifiers={(DOMAIN, str(self.data.id))}) + device = device_registry.async_get_device( + identifiers={(DOMAIN, str(self.data.id))} + ) if device is None: return @@ -1386,7 +1482,9 @@ async def get_hacs_json_raw( ) return json_loads(result) if result else None - async def _ensure_download_capabilities(self, ref: str | None, **kwargs: Any) -> None: + async def _ensure_download_capabilities( + self, ref: str | None, **kwargs: Any + ) -> None: """Ensure that the download can be handled.""" target_manifest: HacsManifest | None = None if ref is None: @@ -1413,8 +1511,13 @@ async def _ensure_download_capabilities(self, ref: str | None, **kwargs: Any) -> raise HacsException( f"This version requires Home Assistant {target_manifest.homeassistant} or newer." ) - if target_manifest.hacs is not None and self.hacs.version < target_manifest.hacs: - raise HacsException(f"This version requires HACS {target_manifest.hacs} or newer.") + if ( + target_manifest.hacs is not None + and self.hacs.version < target_manifest.hacs + ): + raise HacsException( + f"This version requires HACS {target_manifest.hacs} or newer." + ) async def async_download_repository(self, *, ref: str | None = None, **_) -> None: """Download the content of a repository.""" diff --git a/tests/repositories/test_download_counter.py b/tests/repositories/test_download_counter.py new file mode 100644 index 00000000000..aafa862583d --- /dev/null +++ b/tests/repositories/test_download_counter.py @@ -0,0 +1,128 @@ +"""Test download counter functionality.""" + +from unittest.mock import AsyncMock, patch + +from aiogithubapi.models.release import GitHubReleaseModel +import pytest + + +@pytest.fixture +def mock_release_with_multiple_assets(): + """Create a mock release with multiple assets.""" + return GitHubReleaseModel( + { + "tag_name": "1.0.0", + "name": "Release 1.0.0", + "assets": [ + { + "name": "first-asset.js", + "download_count": 100, + "browser_download_url": "https://github.com/test/test/releases/download/1.0.0/first-asset.js", + }, + { + "name": "target-asset.zip", + "download_count": 500, + "browser_download_url": "https://github.com/test/test/releases/download/1.0.0/target-asset.zip", + }, + ], + } + ) + + +async def test_download_counter_uses_correct_asset_with_filename( + repository_plugin, mock_release_with_multiple_assets +): + """Test that download counter uses the correct asset when file_name is specified.""" + repository = repository_plugin + repository.data.file_name = "target-asset.zip" + repository.releases.objects = [mock_release_with_multiple_assets] + repository.ref = "1.0.0" + repository.data.releases = True + + # Test the actual logic we implemented in base.py + release = repository.releases.objects[0] + if assets := release.assets: + # Find the correct asset based on file_name, fallback to first asset + target_asset = None + if repository.data.file_name: + for asset in assets: + if asset.name == repository.data.file_name: + target_asset = asset + break + + # Use the target asset if found, otherwise use the first asset + if target_asset: + downloads = target_asset.download_count + else: + downloads = next(iter(assets)).download_count + + # Should find the correct asset with download count 500 + assert downloads == 500 + assert target_asset is not None + assert target_asset.name == "target-asset.zip" + + +async def test_download_counter_fallback_when_no_file_name( + repository_plugin, mock_release_with_multiple_assets +): + """Test that download counter falls back to first asset when no specific file_name is set.""" + repository = repository_plugin + repository.data.file_name = None # No specific filename + repository.releases.objects = [mock_release_with_multiple_assets] + repository.ref = "1.0.0" + repository.data.releases = True + + # Test the actual logic we implemented in base.py + release = repository.releases.objects[0] + if assets := release.assets: + # Find the correct asset based on file_name, fallback to first asset + target_asset = None + if repository.data.file_name: + for asset in assets: + if asset.name == repository.data.file_name: + target_asset = asset + break + + # Use the target asset if found, otherwise use the first asset + if target_asset: + downloads = target_asset.download_count + else: + downloads = next(iter(assets)).download_count + + # Should use first asset when no specific filename + assert downloads == 100 # First asset download count + assert target_asset is None + + +async def test_download_counter_fallback_when_file_name_not_found( + repository_plugin, mock_release_with_multiple_assets +): + """Test that download counter falls back to first asset when file_name doesn't match any asset.""" + repository = repository_plugin + repository.data.file_name = ( + "nonexistent-file.zip" # File that doesn't exist in assets + ) + repository.releases.objects = [mock_release_with_multiple_assets] + repository.ref = "1.0.0" + repository.data.releases = True + + # Test the actual logic we implemented in base.py + release = repository.releases.objects[0] + if assets := release.assets: + # Find the correct asset based on file_name, fallback to first asset + target_asset = None + if repository.data.file_name: + for asset in assets: + if asset.name == repository.data.file_name: + target_asset = asset + break + + # Use the target asset if found, otherwise use the first asset + if target_asset: + downloads = target_asset.download_count + else: + downloads = next(iter(assets)).download_count + + # Should fallback to first asset when filename not found + assert downloads == 100 # First asset download count + assert target_asset is None