diff --git a/nemo_curator/stages/text/deduplication/semantic.py b/nemo_curator/stages/text/deduplication/semantic.py
index d4e096dfa..7361fc4ad 100644
--- a/nemo_curator/stages/text/deduplication/semantic.py
+++ b/nemo_curator/stages/text/deduplication/semantic.py
@@ -21,7 +21,7 @@
 3. Optional duplicate removal based on identified duplicates
 """

-import os
+import posixpath
 import time
 from dataclasses import dataclass, field
 from typing import Any, Literal
@@ -179,14 +179,14 @@ def __post_init__(self):
         self.cache_path = self.cache_path or self.output_path

         # Intermediate paths
-        self.embeddings_path = os.path.join(self.cache_path, "embeddings")
-        self.semantic_dedup_path = os.path.join(self.cache_path, "semantic_dedup")
+        self.embeddings_path = posixpath.join(self.cache_path, "embeddings")
+        self.semantic_dedup_path = posixpath.join(self.cache_path, "semantic_dedup")

         # Output paths
-        self.duplicates_path = None if self.eps is None else os.path.join(self.output_path, "duplicates")
+        self.duplicates_path = None if self.eps is None else posixpath.join(self.output_path, "duplicates")
         self.deduplicated_output_path = (
-            None if not self.perform_removal else os.path.join(self.output_path, "deduplicated")
+            None if not self.perform_removal else posixpath.join(self.output_path, "deduplicated")
         )
-        self.id_generator_state_file = os.path.join(self.output_path, "semantic_id_generator.json")
+        self.id_generator_state_file = posixpath.join(self.output_path, "semantic_id_generator.json")

         self._validate_config()
diff --git a/nemo_curator/stages/text/download/base/download.py b/nemo_curator/stages/text/download/base/download.py
index 86687963b..6834efc52 100644
--- a/nemo_curator/stages/text/download/base/download.py
+++ b/nemo_curator/stages/text/download/base/download.py
@@ -13,11 +13,13 @@
 # limitations under the License.

 import os
+import posixpath
 import subprocess
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any

+import fsspec
 from loguru import logger

 from nemo_curator.stages.base import ProcessingStage
@@ -37,12 +39,14 @@ def __init__(self, download_dir: str, verbose: bool = False):
         """
         self._download_dir = download_dir
         self._verbose = verbose
-        os.makedirs(download_dir, exist_ok=True)
+        # Use fsspec for cloud-compatible directory creation
+        fs, _ = fsspec.core.url_to_fs(download_dir)
+        fs.makedirs(download_dir, exist_ok=True)

     def _check_s5cmd_installed(self) -> bool:
         """Check if s5cmd is installed."""
         try:
-            subprocess.run(["s5cmd", "version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False)  # noqa: S603, S607
+            subprocess.run(["s5cmd", "version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=False)  # noqa: S607
         except FileNotFoundError:
             return False
         else:
@@ -87,14 +91,20 @@ def download(self, url: str) -> str | None:
         """
         # Generate output filename
         output_name = self._get_output_filename(url)
-        output_file = os.path.join(self._download_dir, output_name)
+        output_file = posixpath.join(self._download_dir, output_name)
         temp_file = output_file + ".tmp"

+        # Use fsspec for cloud-compatible file operations
+        fs, _ = fsspec.core.url_to_fs(output_file)
+
         # If final file exists and is non-empty, assume it's complete
-        if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
-            if self._verbose:
-                logger.info(f"File: {output_file} exists. Not downloading")
-            return output_file
+        if fs.exists(output_file):
+            file_info = fs.info(output_file)
+            file_size = file_info.get("size", 0)
+            if file_size > 0:
+                if self._verbose:
+                    logger.info(f"File: {output_file} exists. Not downloading")
+                return output_file

         # Download to temporary file
         success, error_message = self._download_to_path(url, temp_file)
@@ -103,8 +113,16 @@ def download(self, url: str) -> str | None:
             # Download successful, atomically move temp file to final location
             os.rename(temp_file, output_file)
             if self._verbose:
-                file_size = os.path.getsize(output_file)
-                logger.info(f"Successfully downloaded to {output_file} ({file_size} bytes)")
+                # Try to get file size for logging, but don't fail if we can't
+                try:
+                    fs, _ = fsspec.core.url_to_fs(output_file)
+                    file_info = fs.info(output_file)
+                    file_size = file_info.get("size", 0)
+                    logger.info(f"Successfully downloaded to {output_file} ({file_size} bytes)")
+                except (OSError, KeyError, ValueError):
+                    # If we can't get file size, just log without size
+                    logger.info(f"Successfully downloaded to {output_file}")
+                    logger.debug(f"Could not retrieve file size for {output_file}")
             return output_file
         else:
             # Download failed
diff --git a/nemo_curator/stages/text/download/base/iterator.py b/nemo_curator/stages/text/download/base/iterator.py
index 9b5e93890..e4b7aecc5 100644
--- a/nemo_curator/stages/text/download/base/iterator.py
+++ b/nemo_curator/stages/text/download/base/iterator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
+import posixpath
 from abc import ABC, abstractmethod
 from collections.abc import Iterator
 from dataclasses import dataclass
@@ -89,8 +89,8 @@ def process(self, task: FileGroupTask) -> DocumentBatch:
                 if self.record_limit and record_count >= self.record_limit:
                     break
                 if self.add_filename_column:
-                    # TODO: Support cloud storage https://github.com/NVIDIA-NeMo/Curator/issues/779
-                    record_dict[self.filename_col] = os.path.basename(file_path)  # type: ignore[reportReturnType]
+                    # Use posixpath for cloud storage compatibility
+                    record_dict[self.filename_col] = posixpath.basename(file_path)  # type: ignore[reportReturnType]
                 records.append(record_dict)
                 record_count += 1

diff --git a/nemo_curator/stages/text/filters/fasttext_filter.py b/nemo_curator/stages/text/filters/fasttext_filter.py
index 75e6a070c..79df4b5fe 100644
--- a/nemo_curator/stages/text/filters/fasttext_filter.py
+++ b/nemo_curator/stages/text/filters/fasttext_filter.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-
 import fasttext
+import fsspec
 import numpy as np

 from nemo_curator.stages.text.filters.doc_filter import DocumentFilter
@@ -32,7 +31,8 @@ def __init__(self, model_path: str | None = None, label: str = "__label__hq", al
         self._name = "fasttext_quality_filter"

     def model_check_or_download(self) -> None:
-        if not os.path.exists(self._model_path):
+        fs, _ = fsspec.core.url_to_fs(self._model_path)
+        if not fs.exists(self._model_path):
             msg = f"Model file {self._model_path} not found"
             raise FileNotFoundError(msg)

@@ -66,7 +66,8 @@ def __init__(self, model_path: str | None = None, min_langid_score: float = 0.3)
         self._name = "lang_id"

     def model_check_or_download(self) -> None:
-        if not os.path.exists(self._model_path):
+        fs, _ = fsspec.core.url_to_fs(self._model_path)
+        if not fs.exists(self._model_path):
             msg = f"Model file {self._model_path} not found"
             raise FileNotFoundError(msg)
diff --git a/nemo_curator/stages/text/filters/heuristic_filter.py b/nemo_curator/stages/text/filters/heuristic_filter.py
index cfccce2c0..0124a6b64 100644
--- a/nemo_curator/stages/text/filters/heuristic_filter.py
+++ b/nemo_curator/stages/text/filters/heuristic_filter.py
@@ -16,6 +16,7 @@
 import tarfile
 from typing import Literal

+import fsspec
 import huggingface_hub
 import requests
 from platformdirs import user_cache_dir
@@ -789,7 +790,8 @@ def _download_histograms(self) -> None:
             raise requests.exceptions.RequestException(msg)

         # Open a file to write the content
-        os.makedirs(self._cache_dir, exist_ok=True)
+        fs, _ = fsspec.core.url_to_fs(self._cache_dir)
+        fs.makedirs(self._cache_dir, exist_ok=True)
         download_dest_path = os.path.join(self._cache_dir, "histograms.tar.gz")
         with open(download_dest_path, "wb") as file:
             file.write(response.content)
diff --git a/nemo_curator/stages/text/utils/text_utils.py b/nemo_curator/stages/text/utils/text_utils.py
index b706fad72..ec530b65d 100644
--- a/nemo_curator/stages/text/utils/text_utils.py
+++ b/nemo_curator/stages/text/utils/text_utils.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 import ast
-import os
+import posixpath
 import string
 import tokenize
 import warnings
@@ -167,7 +167,7 @@ def get_docstrings(source: str, module: str = "") -> list[str]:
     """Parse Python source code from file or string and print docstrings."""
     if hasattr(source, "read"):
         filename = getattr(source, "name", module)
-        module = os.path.splitext(os.path.basename(filename))[0]
+        module = posixpath.splitext(posixpath.basename(filename))[0]
         source = source.read()

     docstrings = sorted(parse_docstrings(source), key=lambda x: (NODE_TYPES.get(type(x[0])), x[1]))
diff --git a/tests/stages/text/test_cloud_compatibility_fixes.py b/tests/stages/text/test_cloud_compatibility_fixes.py
new file mode 100644
index 000000000..bbf56d42d
--- /dev/null
+++ b/tests/stages/text/test_cloud_compatibility_fixes.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Tests for cloud compatibility fixes in Text Components.
+
+This module tests that Text Components properly use fsspec and posixpath
+for cloud storage URIs instead of os/pathlib operations.
+"""
+
+import posixpath
+from io import StringIO
+from unittest.mock import Mock, patch
+
+from nemo_curator.stages.text.utils.text_utils import get_docstrings
+
+
+class TestTextUtilsCloudFixes:
+    """Test cloud compatibility fixes in text_utils.py."""
+
+    def test_get_docstrings_with_cloud_uri(self):
+        """Test that get_docstrings handles cloud URIs correctly."""
+        # Create a mock file-like object with a cloud URI name
+        source_code = '''
+def example_function():
+    """This is a docstring."""
+    pass
+'''
+        mock_file = StringIO(source_code)
+        # Simulate a cloud URI filename
+        mock_file.name = "s3://bucket/path/to/script.py"
+
+        # This should work without errors using posixpath operations
+        result = get_docstrings(mock_file)
+
+        # Verify the result contains the expected docstring
+        assert len(result) > 0
+        assert "This is a docstring." in str(result)
+
+    def test_get_docstrings_filename_extraction_patterns(self):
+        """Test filename extraction works with various cloud URI patterns."""
+        test_cases = [
+            ("s3://bucket/path/to/file.py", "file"),
+            ("gs://my-bucket/deep/nested/path/script.py", "script"),
+            ("abfs://container@account.dfs.core.windows.net/data/code.py", "code"),
+            ("https://example.com/api/v1/source.py", "source"),
+            ("/local/path/local_file.py", "local_file"),  # Local files should still work
+        ]
+
+        for uri, expected_module_name in test_cases:
+            # Test the pattern that's now used in the fixed code
+            module_name = posixpath.splitext(posixpath.basename(uri))[0]
+            assert module_name == expected_module_name, f"Failed for URI: {uri}"
+
+
+class TestSemanticDeduplicationCloudFixes:
+    """Test cloud compatibility fixes in semantic deduplication."""
+
+    def test_path_construction_patterns(self):
+        """Test that path construction uses posixpath for cloud compatibility."""
+        # Test the patterns now used in the fixed semantic.py
+        base_path = "s3://bucket/cache"
+        output_path = "gs://bucket/output"
+
+        # These are the patterns now used in the fixed code
+        embeddings_path = posixpath.join(base_path, "embeddings")
+        semantic_dedup_path = posixpath.join(base_path, "semantic_dedup")
+        duplicates_path = posixpath.join(output_path, "duplicates")
+        deduplicated_path = posixpath.join(output_path, "deduplicated")
+        state_file = posixpath.join(output_path, "semantic_id_generator.json")
+
+        # Verify the paths are constructed correctly
+        assert embeddings_path == "s3://bucket/cache/embeddings"
+        assert semantic_dedup_path == "s3://bucket/cache/semantic_dedup"
+        assert duplicates_path == "gs://bucket/output/duplicates"
+        assert deduplicated_path == "gs://bucket/output/deduplicated"
+        assert state_file == "gs://bucket/output/semantic_id_generator.json"
+
+    def test_complex_cloud_uri_handling(self):
+        """Test complex cloud URI scenarios."""
+        test_uris = [
+            "s3://my-bucket/path/with/multiple/levels/",
+            "gs://another-bucket/data/2024/01/15/",
+            "abfs://container@storage.dfs.core.windows.net/datasets/processed/",
+        ]
+
+        for base_uri in test_uris:
+            # Test subdirectory creation patterns
+            subdir = posixpath.join(base_uri, "embeddings")
+            assert subdir.startswith(base_uri)
+            assert subdir.endswith("embeddings")
+
+
+class TestDownloadCloudFixes:
+    """Test cloud compatibility fixes in download modules."""
+
+    @patch("fsspec.core.url_to_fs")
+    def test_download_file_operations(self, mock_url_to_fs: Mock) -> None:
+        """Test that download operations use fsspec for cloud URIs."""
+        # Mock fsspec filesystem
+        mock_fs = Mock()
+        mock_fs.makedirs.return_value = None
+        mock_fs.exists.return_value = True
+        mock_fs.info.return_value = {"size": 1024}
+        mock_url_to_fs.return_value = (mock_fs, "bucket/path")
+
+        # Import after patching to ensure mock is used
+        from nemo_curator.stages.text.download.base.download import DocumentDownloader
+
+        # Create a concrete subclass for testing
+        class TestDownloader(DocumentDownloader):
+            def _get_output_filename(self, _url: str) -> str:
+                return "test_file.txt"
+
+            def _download_to_path(self, _url: str, _path: str) -> tuple[bool, str]:
+                return True, ""
+
+        # Test cloud URI download directory
+        cloud_download_dir = "s3://test-bucket/downloads/"
+        downloader = TestDownloader(cloud_download_dir)
+
+        # Verify the downloader was created and fsspec was called for directory creation
+        assert downloader is not None
+        mock_url_to_fs.assert_called()
+        mock_fs.makedirs.assert_called_with(cloud_download_dir, exist_ok=True)
+
+    def test_filename_extraction_from_cloud_paths(self):
+        """Test filename extraction from cloud paths."""
+        test_cases = [
+            ("s3://bucket/path/to/file.txt", "file.txt"),
+            ("gs://my-bucket/data/document.pdf", "document.pdf"),
+            ("abfs://container@account.dfs.core.windows.net/files/archive.zip", "archive.zip"),
+            ("https://example.com/downloads/data.json", "data.json"),
+        ]
+
+        for cloud_path, expected_filename in test_cases:
+            # Test the pattern now used in the fixed iterator.py
+            filename = posixpath.basename(cloud_path)
+            assert filename == expected_filename, f"Failed for path: {cloud_path}"
+
+
+class TestFilterCloudFixes:
+    """Test cloud compatibility fixes in filter modules."""
+
+    @patch("fsspec.core.url_to_fs")
+    def test_fasttext_filter_model_check(self, mock_url_to_fs: Mock) -> None:
+        """Test that FastText filter uses fsspec for model file checks."""
+        # Mock fsspec filesystem
+        mock_fs = Mock()
+        mock_fs.exists.return_value = True
+        mock_url_to_fs.return_value = (mock_fs, "bucket/path")
+
+        # Import after patching
+        from nemo_curator.stages.text.filters.fasttext_filter import FastTextQualityFilter
+
+        # Test cloud URI model path
+        cloud_model_path = "s3://models/fasttext_quality.bin"
+        filter_instance = FastTextQualityFilter(model_path=cloud_model_path)
+
+        # This should not raise an exception with fsspec
+        filter_instance.model_check_or_download()
+
+        # Verify fsspec was used
+        mock_url_to_fs.assert_called_with(cloud_model_path)
+        mock_fs.exists.assert_called_with(cloud_model_path)
+
+    @patch("fsspec.core.url_to_fs")
+    def test_heuristic_filter_cache_directory(self, mock_url_to_fs: Mock) -> None:
+        """Test that heuristic filter uses fsspec for cache directory creation."""
+        # Mock fsspec filesystem
+        mock_fs = Mock()
+        mock_fs.makedirs.return_value = None
+        mock_url_to_fs.return_value = (mock_fs, "bucket/path")
+
+        # We can't easily test the full heuristic filter due to dependencies,
+        # but we can test the pattern directly
+        cache_dir = "s3://bucket/cache/"
+
+        # This is the pattern now used in the fixed code
+        fs, _ = mock_url_to_fs(cache_dir)
+        fs.makedirs(cache_dir, exist_ok=True)
+
+        # Verify fsspec was used
+        mock_url_to_fs.assert_called_with(cache_dir)
+        mock_fs.makedirs.assert_called_with(cache_dir, exist_ok=True)
+
+
+class TestCloudCompatibilityIntegration:
+    """Integration tests for cloud compatibility across components."""
+
+    def test_end_to_end_cloud_uri_patterns(self):
+        """Test that common cloud URI patterns work across all fixed components."""
+        cloud_uris = [
"s3://my-data-bucket/datasets/train/", + "gs://ml-models/embeddings/bert/", + "abfs://data@storage.dfs.core.windows.net/processed/", + "https://api.example.com/v1/data/", + ] + + for uri in cloud_uris: + # Test path construction (semantic deduplication pattern) + embeddings_path = posixpath.join(uri, "embeddings") + assert embeddings_path.startswith(uri) + + # Test filename extraction (download pattern) + test_file_path = posixpath.join(uri, "test_file.json") + filename = posixpath.basename(test_file_path) + assert filename == "test_file.json" + + # Test module name extraction (text_utils pattern) + script_path = posixpath.join(uri, "script.py") + module_name = posixpath.splitext(posixpath.basename(script_path))[0] + assert module_name == "script" + + def test_backward_compatibility_with_local_paths(self): + """Ensure fixes don't break local filesystem operations.""" + local_paths = [ + "/home/user/data/", + "./local_data/", + "../relative/path/", + "simple_filename.txt", + ] + + for path in local_paths: + # All the fixed patterns should work with local paths too + subpath = posixpath.join(path, "subdir") + filename = posixpath.basename(subpath) + + # These operations should succeed without errors + assert isinstance(subpath, str) + assert isinstance(filename, str) + + def test_error_handling_for_invalid_uris(self): + """Test that invalid URIs are handled gracefully.""" + invalid_uris = [ + "", + "invalid://bad-protocol/path", + "s3://", # Missing bucket + "gs:///no-bucket", + ] + + for uri in invalid_uris: + # The posixpath operations should not crash on invalid URIs + # and should handle invalid URIs gracefully + result = posixpath.basename(uri) + assert isinstance(result, str)