diff --git a/agentscope-core/src/main/java/io/agentscope/core/formatter/MediaUtils.java b/agentscope-core/src/main/java/io/agentscope/core/formatter/MediaUtils.java index 62ed3530f..6bca3945c 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/formatter/MediaUtils.java +++ b/agentscope-core/src/main/java/io/agentscope/core/formatter/MediaUtils.java @@ -27,6 +27,7 @@ import java.nio.file.Files; import java.nio.file.InvalidPathException; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Base64; import java.util.List; import javax.imageio.ImageIO; @@ -58,7 +59,7 @@ private MediaUtils() { /** * Check if a URL is a local file path (not a URL with protocol scheme). - * Returns true for paths without http://, https://, ftp://, or file:// prefixes. + * Returns true for paths without http://, https://, ftp://, file:// or oss:// prefixes. * Used to distinguish local files from remote URLs for different processing paths. * * @param url The URL or file path to check @@ -349,14 +350,34 @@ public static String inferAudioFormatFromMediaType(String mediaType) { * Extract file extension from path or URL. */ public static String getExtension(String path) { - if (path == null) { + if (path == null || path.isBlank()) { return ""; } - int dotIndex = path.lastIndexOf('.'); - int slashIndex = Math.max(path.lastIndexOf('/'), path.lastIndexOf('\\')); - // Ensure the dot is after the last slash (not part of directory name) - if (dotIndex > slashIndex && dotIndex < path.length() - 1) { - return path.substring(dotIndex + 1); + + Path fileNamePath; + try { + if (isLocalFile(path)) { + // treat as file + fileNamePath = Paths.get(path).normalize().getFileName(); + } else { + // treat as url + URI uri = URI.create(path).normalize(); + fileNamePath = Paths.get(uri.getPath()).getFileName(); + } + } catch (Exception e) { + log.warn("Invalid path: {}", path, e); + return ""; + } + + if (fileNamePath == null) { + return ""; + } + + String fileName = fileNamePath.toString(); + int dotIndex = fileName.lastIndexOf('.'); + // Ensure the dot exists and is not the last character + if (dotIndex != -1 && dotIndex < fileName.length() - 1) { + return fileName.substring(dotIndex + 1); } return ""; } diff --git a/agentscope-core/src/test/java/io/agentscope/core/formatter/MediaUtilsTest.java b/agentscope-core/src/test/java/io/agentscope/core/formatter/MediaUtilsTest.java index 3f10397bb..8bfb93594 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/formatter/MediaUtilsTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/formatter/MediaUtilsTest.java @@ -329,4 +329,77 @@ void testUrlToRgbaImageInputStreamWithFile() throws IOException { assertNotNull(is); is.close(); } + + @Test + @DisplayName("Should return empty string when path is invalid") + void testGetExtensionWithInvalidPath() { + assertEquals("", MediaUtils.getExtension(null)); + assertEquals("", MediaUtils.getExtension("")); + assertEquals("", MediaUtils.getExtension(" ")); + assertEquals("", MediaUtils.getExtension("/")); + assertEquals("", MediaUtils.getExtension("\\")); + assertEquals("", MediaUtils.getExtension("/home/user")); + assertEquals("", MediaUtils.getExtension("C:\\Users\\Administrator")); + assertEquals("", MediaUtils.getExtension("https://abc")); + assertEquals("", MediaUtils.getExtension("https://abc[@]123")); + assertEquals("", MediaUtils.getExtension("https://example.com/abc")); + assertEquals("", MediaUtils.getExtension("https://example.com/abc/")); + assertEquals("", MediaUtils.getExtension("https://example.com/img.png.")); + } + + @Test + @DisplayName("Should get extension with all types") + void testGetExtensionWithAllTypes() { + assertEquals("png", MediaUtils.getExtension("https://example.com/img.png")); + assertEquals("wav", MediaUtils.getExtension("https://example.com/audio.wav")); + assertEquals("mp3", MediaUtils.getExtension("https://example.com/audio.mp3")); + assertEquals("mp4", MediaUtils.getExtension("https://example.com/video.mp4")); + assertEquals("xxx", MediaUtils.getExtension("https://example.com/xxx.xxx")); + } + + @Test + @DisplayName("Should get last extension when have nested types") + void testGetExtensionWithNestedTypes() { + assertEquals("gz", MediaUtils.getExtension("https://example.com/img.png.tar.gz")); + assertEquals("zip", MediaUtils.getExtension("https://example.com/img.png.zip")); + } + + @Test + @DisplayName("Should get extension with masks") + void testGetExtensionWithMasks() { + assertEquals("png", MediaUtils.getExtension("https://example.com/img.png?id=1&type=png")); + assertEquals("png", MediaUtils.getExtension("https://example.com/img.png#section1.1")); + assertEquals( + "png", + MediaUtils.getExtension("https://example.com/img.png?id=1&type=png#section1.1")); + assertEquals("png", MediaUtils.getExtension("https://example.com/img.png?v=2.0.0")); + assertEquals( + "png", MediaUtils.getExtension("https://example.com/img-123&type=png&v=2.0.0.png")); + + // verify the file name include special masks + assertEquals("png", MediaUtils.getExtension("/home/user/img-123&type=png&v=2.0.0.png")); + assertEquals( + "png", + MediaUtils.getExtension("C:\\Users\\Administrator\\img_123&type=png&v=2.0.0.png")); + } + + @Test + @DisplayName("Should get extension with all supported protocol") + void testGetExtensionWithSupportedProtocol() { + assertEquals( + "png", + MediaUtils.getExtension("http://example.com/img.png?id=1&type=png#section1.1")); + assertEquals( + "png", + MediaUtils.getExtension("https://example.com/img.png?id=1&type=png#section1.1")); + assertEquals( + "png", + MediaUtils.getExtension("oss://example.com/img.png?id=1&type=png#section1.1")); + assertEquals( + "png", + MediaUtils.getExtension("file://example.com/img.png?id=1&type=png#section1.1")); + assertEquals( + "png", + MediaUtils.getExtension("ftp://example.com/img.png?id=1&type=png#section1.1")); + } }