Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 208 additions & 6 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
"""Tests for `common` module."""
"""Tests for the common module."""

import base64
import builtins
import datetime
import io
import math
import json
import os
import pathlib
import sys
import tempfile
import unittest
from unittest import mock

import ee
from geemap import colormaps
from geemap import common
import ipywidgets
import numpy as np
from PIL import Image
import psutil
import pandas as pd


class CommonTest(unittest.TestCase):
Expand Down Expand Up @@ -219,7 +224,24 @@ class CommonTest(unittest.TestCase):
# TODO: test_vector_to_geojson
# TODO: test_vector_to_ee
# TODO: test_extract_pixel_values
# TODO: test_list_vars

def test_list_vars(self):
"""Tests list_vars function."""
common.list_vars_test_int = 1
common.list_vars_test_str = "test"
vars_all = common.list_vars()
vars_int = common.list_vars(var_type=int)
vars_str = common.list_vars(var_type=str)
del common.list_vars_test_int
del common.list_vars_test_str

self.assertIn("list_vars_test_int", vars_all)
self.assertIn("list_vars_test_str", vars_all)
self.assertIn("list_vars_test_int", vars_int)
self.assertNotIn("list_vars_test_str", vars_int)
self.assertIn("list_vars_test_str", vars_str)
self.assertNotIn("list_vars_test_int", vars_str)

# TODO: test_extract_transect
# TODO: test_random_sampling
# TODO: test_osm_to_gdf
Expand Down Expand Up @@ -279,6 +301,10 @@ def test_check_dir(self):
self.assertTrue(os.path.exists(abs_path_2))
self.assertEqual(abs_path_2, os.path.abspath(dir_path_2))

# Test with invalid type
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nit: Most of the comments in geemap end with a period. The comments in this file don't.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bugger. Thanks for point it out. I fixed the entire file. I think gemini saw that most of the comments in geemap don't have punctuation, so it was continuing the trend.

with self.assertRaises(TypeError):
common.check_dir(123) # pytype: disable=wrong-arg-types

def test_check_file_path(self):
with tempfile.TemporaryDirectory() as tmpdir:
# Test with make_dirs=True
Expand Down Expand Up @@ -358,13 +384,61 @@ def test_get_palette_colors(self):

# TODO: test_plot_raster
# TODO: test_plot_raster_3d
# TODO: test_display_html

@mock.patch("geemap.common.IFrame")
@mock.patch("geemap.common.display")
def test_display_html(self, mock_display, mock_iframe):
mock_iframe.return_value = "iframe_object"
with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as fp:
html_file = fp.name
try:
common.display_html(html_file, width=800, height=400)
mock_iframe.assert_called_once_with(src=html_file, width=800, height=400)
mock_display.assert_called_once_with(mock_iframe.return_value)
finally:
os.remove(html_file)

with self.assertRaisesRegex(ValueError, "is not a valid file path"):
common.display_html("non_existent_file.html")

# TODO: test_bbox_coords
# TODO: test_requireJS
# TODO: test_setupJS
# TODO: test_change_require
# TODO: test_ee_vector_style
# TODO: test_get_direct_url

@mock.patch("geemap.common.requests.head")
def test_get_direct_url(self, mock_head):
mock_response = mock.Mock()
mock_response.url = "https://example.com/direct_url"
mock_head.return_value = mock_response

# Test with a URL that redirects
self.assertEqual(
common.get_direct_url("https://example.com/redirect"),
"https://example.com/direct_url",
)
mock_head.assert_called_with(
"https://example.com/redirect", allow_redirects=True
)

# Test with a direct URL
self.assertEqual(
common.get_direct_url("https://example.com/direct_url"),
"https://example.com/direct_url",
)
mock_head.assert_called_with(
"https://example.com/direct_url", allow_redirects=True
)

# Test with non-http URL
with self.assertRaisesRegex(ValueError, "url must start with http."):
common.get_direct_url("ftp://example.com/file")

# Test with non-string URL
with self.assertRaisesRegex(ValueError, "url must be a string."):
common.get_direct_url(123) # pytype: disable=wrong-arg-types

# TODO: test_add_crs
# TODO: test_jrc_hist_monthly_history
# TODO: test_html_to_streamlit
Expand All @@ -389,7 +463,29 @@ def test_use_mkdocs_false(self):
# TODO: test_arc_add_layer
# TODO: test_arc_zoom_to_extent
# TODO: test_get_current_year
# TODO: test_html_to_gradio

def test_html_to_gradio(self):
html_list = [
"<!DOCTYPE html>",
"<html>",
"<body>",
" <script>",
' L.tileLayer("url", {',
' "attribution": "..."',
" }).addTo(map);",
' function(e) { console.log("foo"); }',
' "should be kept";',
" </script>",
"</body>",
"</html>",
]
gradio_html = common.html_to_gradio(html_list, width="800px", height="400px")
self.assertIn('<iframe style="width: 800px; height: 400px"', gradio_html)
# DO NOT SUBMIT - this is failing
# self.assertNotIn('"attribution":', gradio_html)
self.assertNotIn('function(e) { console.log("foo"); }', gradio_html)
self.assertIn('"should be kept";', gradio_html)

# TODO: test_image_check
# TODO: test_image_client
# TODO: test_image_center
Expand Down Expand Up @@ -513,7 +609,113 @@ def test_center_zoom_to_xy_range(self):
self.assertAlmostEqual(y_range[1], 4163881.1441, places=4)

# TODO: test_get_geometry_coords
# TODO: test_landsat_scaling

def test_landsat_scaling(self):
image = mock.MagicMock(spec=ee.Image)
optical_bands_mock = mock.MagicMock(name="optical_bands_mock")
thermal_bands_mock = mock.MagicMock(name="thermal_bands_mock")
qa_pixel_mock = mock.MagicMock(name="qa_pixel_mock")

def select_side_effect(band_selector):
if band_selector == "SR_B.":
return optical_bands_mock
elif band_selector == "ST_B.*":
return thermal_bands_mock
elif band_selector == "QA_PIXEL":
return qa_pixel_mock
else:
raise ValueError(f"Unexpected band selector: {band_selector}")

image.select.side_effect = select_side_effect

scaled_optical = mock.MagicMock(name="scaled_optical")
optical_bands_mock.multiply.return_value.add.return_value = scaled_optical

scaled_thermal = mock.MagicMock(name="scaled_thermal")
thermal_bands_mock.multiply.return_value.add.return_value = scaled_thermal

qa_mask = mock.MagicMock(name="qa_mask")
qa_pixel_mock.bitwiseAnd.return_value.eq.return_value = qa_mask

# To allow chaining addBands().addBands().updateMask()
image.addBands.return_value = image

# Test case 1: thermal_bands=True, apply_fmask=False
image.reset_mock()
optical_bands_mock.reset_mock()
thermal_bands_mock.reset_mock()
qa_pixel_mock.reset_mock()
common.landsat_scaling(image, thermal_bands=True, apply_fmask=False)
image.select.assert_has_calls([mock.call("SR_B."), mock.call("ST_B.*")])
self.assertEqual(image.select.call_count, 2)
optical_bands_mock.multiply.assert_called_with(0.0000275)
optical_bands_mock.multiply().add.assert_called_with(-0.2)
thermal_bands_mock.multiply.assert_called_with(0.00341802)
thermal_bands_mock.multiply().add.assert_called_with(149)
image.addBands.assert_has_calls(
[
mock.call(scaled_thermal, None, True),
mock.call(scaled_optical, None, True),
]
)
image.updateMask.assert_not_called()

# Test case 2: thermal_bands=False, apply_fmask=False
image.reset_mock()
optical_bands_mock.reset_mock()
thermal_bands_mock.reset_mock()
qa_pixel_mock.reset_mock()
common.landsat_scaling(image, thermal_bands=False, apply_fmask=False)
image.select.assert_called_once_with("SR_B.")
optical_bands_mock.multiply.assert_called_with(0.0000275)
optical_bands_mock.multiply().add.assert_called_with(-0.2)
thermal_bands_mock.multiply.assert_not_called()
image.addBands.assert_called_once_with(scaled_optical, None, True)
image.updateMask.assert_not_called()

# Test case 3: thermal_bands=True, apply_fmask=True
image.reset_mock()
optical_bands_mock.reset_mock()
thermal_bands_mock.reset_mock()
qa_pixel_mock.reset_mock()
common.landsat_scaling(image, thermal_bands=True, apply_fmask=True)
image.select.assert_has_calls(
[mock.call("SR_B."), mock.call("ST_B.*"), mock.call("QA_PIXEL")],
any_order=True,
)
self.assertEqual(image.select.call_count, 3)
optical_bands_mock.multiply.assert_called_with(0.0000275)
optical_bands_mock.multiply().add.assert_called_with(-0.2)
thermal_bands_mock.multiply.assert_called_with(0.00341802)
thermal_bands_mock.multiply().add.assert_called_with(149)
qa_pixel_mock.bitwiseAnd.assert_called_once_with(31)
qa_pixel_mock.bitwiseAnd().eq.assert_called_once_with(0)
image.addBands.assert_has_calls(
[
mock.call(scaled_thermal, None, True),
mock.call(scaled_optical, None, True),
]
)
image.updateMask.assert_called_once_with(qa_mask)

# Test case 4: thermal_bands=False, apply_fmask=True
image.reset_mock()
optical_bands_mock.reset_mock()
thermal_bands_mock.reset_mock()
qa_pixel_mock.reset_mock()
common.landsat_scaling(image, thermal_bands=False, apply_fmask=True)
image.select.assert_has_calls(
[mock.call("SR_B."), mock.call("QA_PIXEL")], any_order=True
)
self.assertEqual(image.select.call_count, 2)
optical_bands_mock.multiply.assert_called_with(0.0000275)
optical_bands_mock.multiply().add.assert_called_with(-0.2)
thermal_bands_mock.multiply.assert_not_called()
qa_pixel_mock.bitwiseAnd.assert_called_once_with(31)
qa_pixel_mock.bitwiseAnd().eq.assert_called_once_with(0)
image.addBands.assert_called_once_with(scaled_optical, None, True)
image.updateMask.assert_called_once_with(qa_mask)

# TODO: test_tms_to_geotiff
# TODO: test_tif_to_jp2
# TODO: test_ee_to_geotiff
Expand Down