diff --git a/geemap/ai.py b/geemap/ai.py index 4a4d20c60c..21da8eb3b9 100644 --- a/geemap/ai.py +++ b/geemap/ai.py @@ -1,4 +1,5 @@ """This module contains functions for interacting with AI models. + The Genie class source code is adapted from the ee_genie.ipynb at . Credit to the original author Simon Ilyushchenko (). The DataExplorer class source code is adapted from . @@ -76,13 +77,13 @@ class Genie(widgets.VBox): Credit to the original author Simon Ilyushchenko (). Args: - project (Optional[str], optional): Google Cloud project ID. Defaults to None. - google_api_key (Optional[str], optional): Google API key. Defaults to None. - gemini_model (str, optional): The Gemini model to use. Defaults to "gemini-1.5-flash". + project: Google Cloud project ID. Defaults to None. + google_api_key: Google API key. Defaults to None. + gemini_model: The Gemini model to use. Defaults to "gemini-1.5-flash". For a list of available models, see https://bit.ly/4fKfXW7. - target_score (float, optional): The target score for the model. Defaults to 0.8. - widget_height (str, optional): The height of the widget. Defaults to "600px". - initialize_ee (bool, optional): Whether to initialize Earth Engine. Defaults to True. + target_score: The target score for the model. Defaults to 0.8. + widget_height: The height of the widget. Defaults to "600px". + initialize_ee: Whether to initialize Earth Engine. Defaults to True. Raises: ValueError: If the project ID or Google API key is not provided. @@ -97,8 +98,6 @@ def __init__( widget_height: str = "600px", initialize_ee: bool = True, ) -> None: - # Initialization - if project is None: project = coreutils.get_env_var("EE_PROJECT_ID") or coreutils.get_env_var( "GOOGLE_PROJECT_ID" @@ -125,8 +124,7 @@ def __init__( # Score to aim for (on the 0-1 scale). The exact meaning of what "score" means # is left to the LLM. - # Count of analysis rounds - + # Count of analysis rounds. 
self.iteration = 1 self.map_dirty = False @@ -138,7 +136,7 @@ def __init__( image_model = genai.GenerativeModel(gemini_model) - # UI widget definitions + # UI widget definitions. # We define the widgets early because some functions will write to the debug # and/or chat panels. @@ -196,7 +194,7 @@ def __init__( with chat_output: print("CHAT COLUMN\n") - # Simple functions that LLM will call + # Simple functions that LLM will call. def set_center(x: float, y: float, zoom: int) -> str: """Sets the map center to the given coordinates and zoom level and @@ -205,7 +203,6 @@ def set_center(x: float, y: float, zoom: int) -> str: print(f"SET_CENTER({x}, {y}, {zoom})\n") m.set_center(x, y) m.zoom = zoom - # global map_dirty self.map_dirty = True return ( "Do not call any more functions in this request to let geemap bounds " @@ -287,10 +284,10 @@ def inner_monologue(thoughts: str) -> str: print(f"THOUGHTS:\n {thoughts}\n") return "success" - # Functions for textual analysis of images + # Functions for textual analysis of images. def _lat_lon_to_tile(lon, lat, zoom_level): - # Convert latitude and longitude to Mercator coordinates + # Convert latitude and longitude to Mercator coordinates. 
x_merc = (lon + 180) / 360 y_merc = ( 1 @@ -327,49 +324,6 @@ def analyze_image(additional_instructions: str = "") -> str: return str(e) def _analyze_image(additional_instructions: str = "") -> str: - # bounds = m.bounds - # s, w = bounds[0] - # n, e = bounds[1] - # zoom = int(m.zoom) - - # min_tile_x, max_tile_y = _lat_lon_to_tile(w, s, zoom) - # max_tile_x, min_tile_y = _lat_lon_to_tile(e, n, zoom) - # min_tile_x = max(0, min_tile_x) - # max_tile_x = min(2**zoom - 1, max_tile_x) - # min_tile_y = max(0, min_tile_y) - # max_tile_y = min(2**zoom - 1, max_tile_y) - - # with debug_output: - # if additional_instructions: - # print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n") - # else: - # print("RUNNING IMAGE ANALYSIS...\n") - - # layers = list(m.ee_layer_dict.values()) - # if not layers: - # return "No data layers loaded" - # url_template = layers[-1]["ee_layer"].url - # tile_width = 256 - # tile_height = 256 - # image_width = (max_tile_x - min_tile_x + 1) * tile_width - # image_height = (max_tile_y - min_tile_y + 1) * tile_height - - # # Create a new blank image - # image = PIL.Image.new("RGB", (image_width, image_height)) - - # for y in range(min_tile_y, max_tile_y + 1): - # for x in range(min_tile_x, max_tile_x + 1): - # tile_url = str.format(url_template, x=x, y=y, z=zoom) - # # print(tile_url) - # tile_img = PIL.Image.open(io.BytesIO(get_image(tile_url))) - - # offset_x = (x - min_tile_x) * tile_width - # offset_y = (y - min_tile_y) * tile_height - # image.paste(tile_img, (offset_x, offset_y)) - - # width, height = image.size - # num_bands = len(image.getbands()) - with debug_output: if additional_instructions: print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n") @@ -393,7 +347,7 @@ def _analyze_image(additional_instructions: str = "") -> str: file.close() # Skip an LLM call when we can simply tell that something is wrong. - # (Also, LLMs might hallucinate on uniform images.) + # Also, LLMs might hallucinate on uniform images. 
if image_min == image_max: return ( f"The image tile has a single uniform color with value " @@ -410,8 +364,8 @@ def _analyze_image(additional_instructions: str = "") -> str: Avoid making assumptions about the specific geographic location, time period, or cause of the observed features. Focus solely on the literal contents of the image itself. Clearly indicate which features look natural, which look human-made, - and which look like image artifacts. (Eg, a completely straight blue line - is unlikely to be a river.) + and which look like image artifacts, e.g, a completely straight blue line + is unlikely to be a river. If the image is ambiguous or unclear, state so directly. Do not speculate or hypothesize beyond what is directly visible. @@ -479,13 +433,15 @@ def score_response( query: str, visualization_parameters: str, analysis: str ) -> str: """Returns how well the given analysis describes a map tile returned for - the given query. The analysis starts with a number between 0 and 1. + the given query. + + The analysis starts with a number between 0 and 1. - Arguments: - query: user-specified query - visualization_parameters: description of the bands used and visualization - parameters applied to the map tile - analysis: the textual description of the map tile + Args: + query: User-specified query. + visualization_parameters: Description of the bands used and + visualization parameters applied to the map tile. + analysis: The textual description of the map tile. """ with debug_output: print(f"VIZ PARAMS: {visualization_parameters}\n") @@ -507,18 +463,16 @@ def score_response( print(f"UNEXPECTED SCORE RESPONSE: {e}") return result - # Main prompt for the agent - system_prompt = f""" - The client is running in a Python notebook with a geemap Map displayed. 
- When composing Python code, do not use getMapId - just return the single-line - layer definition like 'ee.Image("USGS/SRTMGL1_003")' that we will pass to + The client is running in a Python notebook with a geemap Map displayed. When + composing Python code, do not use getMapId - just return the single-line layer + definition like 'ee.Image("USGS/SRTMGL1_003")' that we will pass to Map.addLayer(). Do not escape quotation marks in Python code. - Be sure to use Python, not Javascript, syntax for keyword parameters in - Python code (that is, "function(arg=value)") Using the provided functions, - respond to the user command following below (or respond why it's not possible). - If you get an Earth Engine error, attempt to fix it and then try again. + Be sure to use Python, not Javascript, syntax for keyword parameters in Python + code (that is, "function(arg=value)") Using the provided functions, respond to + the user command following below (or respond why it's not possible). If you get + an Earth Engine error, attempt to fix it and then try again. Before you choose a dataset, think about what kind of dataset would be most suitable for the query. Also think about what zoom level would be suitable for @@ -531,18 +485,16 @@ def score_response( monlogue function why you chose a specific dataset, zoom level and map location. Prefer mosaicing image collections using the mosaic() function, don't get - individual images from collections via - 'first()'. Choose a tile size and zoom level that will ensure the - tile has enough pixels in it to avoid graininess, but not so many - that processing becomes very expensive. Do not use wide date ranges - with collections that have many images, but remember that Landsat and - Sentinel-2 have revisit period of several days. Do not use sample - locations - try to come up with actual locations that are relevant to - the request. + individual images from collections via 'first()'. 
Choose a tile size and zoom + level that will ensure the tile has enough pixels in it to avoid graininess, but + not so many that processing becomes very expensive. Do not use wide date ranges + with collections that have many images, but remember that Landsat and Sentinel-2 + have revisit period of several days. Do not use sample locations - try to come + up with actual locations that are relevant to the request. Use Landsat Collection 2, not Landsat Collection 1 ids. If you are getting - repeated errors when filtering by a time range, read the dataset description - to confirm that the dataset has data for the selected range. + repeated errors when filtering by a time range, read the dataset description to + confirm that the dataset has data for the selected range. Important: after using the set_center() function, just say that you have called this function and wait for the user to hit enter, after which you should @@ -552,30 +504,28 @@ def score_response( Once the map is updated and the user told you to proceed, call the analyze_image function() to describe the image for the same location that will be shown in geemap. If you pass additional instructions to analyze_image(), do not disclose - what the image is supposed to be to discourage hallucinations - you can only tell - the analysis function to pay attention to specific areas (eg, center or top left) - or shapes (eg, a line at the bottom) in the image. You can also tell the analysis - function about the chosen bands, color palette and min/max visualization - parameters, if any, to help it interpret the colors correctly. If the image - turns out to be uniform in color with no features, - use min/max visualization parameters to enhance contrast. + what the image is supposed to be to discourage hallucinations - you can only + tell the analysis function to pay attention to specific areas (eg, center or top + left) or shapes (eg, a line at the bottom) in the image. 
You can also tell the + analysis function about the chosen bands, color palette and min/max + visualization parameters, if any, to help it interpret the colors correctly. If + the image turns out to be uniform in color with no features, use min/max + visualization parameters to enhance contrast. Frequently call the inner_monologue() functions to tell the user about your current thought process. This is a good time to reflect if you have been running into repeated errors of the same kind, and if so, to try a different approach. When you are done, call the score_response() function to evaluate the analysis. - You can also tell the scoring function about the chosen bands, color palette - and min/max visualization parameters, if any. If the analysis score is below - {target_score}, - keep trying to find and show a better image. You might have to change the dataset, - map location, zoom level, date range, bands, or other parameters - think about - what went wrong in the previous attempt and make the change that's most likely - to improve the score. + You can also tell the scoring function about the chosen bands, color palette and + min/max visualization parameters, if any. If the analysis score is below + {target_score}, keep trying to find and show a better image. You might have to + change the dataset, map location, zoom level, date range, bands, or other + parameters - think about what went wrong in the previous attempt and make the + change that's most likely to improve the score. """ - # Class for LLM chat with function calling - + # Class for LLM chat with function calling. 
gemini_tools = [ set_center, show_layer, @@ -589,7 +539,7 @@ class Gemini: """Gemini LLM.""" def __init__( - self, system_prompt, tools=None, model_name="gemini-1.5-pro-latest" + self, system_prompt, tools=None, model_name="gemini-3-pro-preview" ): if not tools: tools = [] @@ -697,7 +647,7 @@ def chat(self, question: str, temperature=0) -> str: model = Gemini(system_prompt, gemini_tools, model_name=gemini_model) analysis_model = Gemini(scoring_system_prompt, model_name=gemini_model) - # UI functions + # UI functions. def set_cursor_waiting(): js_code = """ @@ -743,7 +693,7 @@ def on_submit(widget): # UI layout - # Arrange the chat history and input in a vertical box + # Arrange the chat history and input in a vertical box. chat_ui = widgets.VBox( [image_widget, chat_output], layout=widgets.Layout(width="420px", height=widget_height), @@ -751,7 +701,7 @@ def on_submit(widget): chat_output.layout = widgets.Layout( width="400px" - ) # Fixed width for the left control + ) # Fixed width for the left control. m.layout = widgets.Layout(flex="1 1 auto", height=widget_height) table = widgets.HBox( @@ -778,13 +728,11 @@ def matches_interval( """Checks if the collection's datetime interval matches the query datetime interval. Args: - collection_interval (tuple[datetime.datetime, datetime.datetime]): - Temporal interval of the collection. - query_interval (tuple[datetime.datetime, datetime.datetime]): A tuple - with the query interval start and end. + collection_interval: Temporal interval of the collection. + query_interval: A tuple with the query interval start and end. Returns: - bool: True if the datetime interval matches, False otherwise. + True if the datetime interval matches, False otherwise. """ start_query, end_query = query_interval start_collection, end_collection = collection_interval @@ -801,15 +749,14 @@ def matches_datetime( """Checks if the collection's datetime interval matches the query datetime. 
Args: - collection_interval (tuple[datetime.datetime, Optional[datetime.datetime]]): - Temporal interval of the collection. - query_datetime (datetime.datetime): A datetime coming from a query. + collection_interval: Temporal interval of the collection. + query_datetime: A datetime coming from a query. Returns: - bool: True if the datetime interval matches, False otherwise. + True if the datetime interval matches, False otherwise. """ if collection_interval[1] is None: - # End date should always be set in STAC JSON files, but just in case... + # End date should always be set in STAC JSON files, but just in case. end_date = datetime.datetime.now(tz=datetime.UTC) else: end_date = collection_interval[1] @@ -820,15 +767,14 @@ def matches_datetime( stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), retry=tenacity.retry_if_exception_type(LayerException), - # before_sleep=lambda retry_state: print(f"LayerException occurred. Retrying in 1 seconds... (Attempt {retry_state.attempt_number}/3)") ) def run_ee_code(code: str, ee: Any, geemap_instance: Map) -> None: """Executes Earth Engine Python code within the context of a geemap instance. Args: - code (str): The Earth Engine Python code to execute. - ee (Any): The Earth Engine module. - geemap_instance (Map): The geemap instance. + code: The Earth Engine Python code to execute. + ee: The Earth Engine module. + geemap_instance: The geemap instance. Raises: Exception: Re-raises any exception encountered during code execution. @@ -837,10 +783,11 @@ def run_ee_code(code: str, ee: Any, geemap_instance: Map) -> None: # geemap appears to have some stray print statements. _ = io.StringIO() with redirect_stdout(_): - # Note that sometimes the geemap code uses both 'Map' and 'm' to refer to a map instance. + # Note that sometimes the geemap code uses both 'Map' and 'm' to refer to + # a map instance. 
exec(code, {"ee": ee, "Map": geemap_instance, "m": geemap_instance}) except Exception: - # Re-raise the exception with the original traceback + # Re-raise the exception with the original traceback. exc_type, exc_value, exc_traceback = sys.exc_info() raise exc_value.with_traceback(exc_traceback) @@ -855,11 +802,7 @@ class BBox: north: float def is_global(self) -> bool: - """Checks if the bounding box is global. - - Returns: - bool: True if the bounding box is global, False otherwise. - """ + """Returns True if the bounding box is global, False otherwise.""" return ( self.west == -180 and self.south == -90 @@ -872,10 +815,10 @@ def from_list(cls, bbox_list: list[float]) -> "BBox": """Constructs a BBox from a list of four numbers [west, south, east, north]. Args: - bbox_list (List[float]): List of four numbers representing the bounding box. + bbox_list: List of four numbers representing the bounding box. Returns: - BBox: The constructed BBox object. + The constructed BBox object. Raises: ValueError: If the coordinates are not in the correct order. @@ -896,7 +839,7 @@ def to_list(self) -> list[float]: """Converts the BBox to a list of four numbers [west, south, east, north]. Returns: - List[float]: List of four numbers representing the bounding box. + List of four numbers representing the bounding box. """ return [self.west, self.south, self.east, self.north] @@ -906,10 +849,10 @@ def intersects(self, query_bbox: "BBox") -> bool: Doesn't handle bboxes extending past the antimeridian. Args: - query_bbox (BBox): Bounding box from the query. + query_bbox: Bounding box from the query. Returns: - bool: True if the two bounding boxes intersect, False otherwise. + True if the two bounding boxes intersect, False otherwise. 
""" return ( query_bbox.west < self.east @@ -920,7 +863,7 @@ def intersects(self, query_bbox: "BBox") -> bool: class Collection: - """A simple wrapper for a STAC Collection..""" + """A simple wrapper for a STAC Collection.""" stac_json: dict[str, Any] @@ -928,21 +871,21 @@ def __init__(self, stac_json: dict[str, Any]) -> None: """Initializes the Collection. Args: - stac_json (Dict[str, Any]): The STAC JSON of the collection. + stac_json: The STAC JSON of the collection. """ self.stac_json = stac_json if stac_json.get("gee:status") == "deprecated": - # Set the STAC 'deprecated' field that we don't set in the jsonnet files + # Set the STAC 'deprecated' field that we don't set in the jsonnet files. stac_json["deprecated"] = True def __getitem__(self, item: str) -> Any: """Gets an item from the STAC JSON. Args: - item (str): The key of the item to get. + item: The key of the item to get. Returns: - Any: The value of the item. + The value of the item. """ return self.stac_json[item] @@ -950,11 +893,11 @@ def get(self, item: str, default: Any | None = None) -> Any | None: """Matches dict's get by returning None if there is no item. Args: - item (str): The key of the item to get. - default (Optional[Any]): The default value to return if the item is not found. Defaults to None. + item: The key of the item to get. + default: Value to return if the item is not found. Defaults to None. Returns: - Optional[Any]: The value of the item or the default value. + The value of the item or the default value. """ return self.stac_json.get(item, default) @@ -962,7 +905,7 @@ def public_id(self) -> str: """Gets the public ID of the collection. Returns: - str: The public ID of the collection. + The public ID of the collection. """ return self["id"] @@ -970,7 +913,7 @@ def hyphen_id(self) -> str: """Gets the hyphenated ID of the collection. Returns: - str: The hyphenated ID of the collection. + The hyphenated ID of the collection. 
""" return self["id"].replace("/", "_") @@ -978,7 +921,7 @@ def get_dataset_type(self) -> str: """Gets the dataset type of the collection. Returns: - str: The dataset type of the collection. + The dataset type of the collection. """ return self["gee:type"] @@ -998,8 +941,7 @@ def datetime_interval( """Returns datetime objects representing temporal extents. Returns: - Iterable[tuple[datetime.datetime, Optional[datetime.datetime]]]: - An iterable of tuples representing temporal extents. + An iterable of tuples representing temporal extents. Raises: ValueError: If the temporal interval start is not found. @@ -1018,47 +960,27 @@ def datetime_interval( yield (start_date, end_date) def start(self) -> datetime.datetime: - """Gets the start datetime of the collection. - - Returns: - datetime.datetime: The start datetime of the collection. - """ + """Returns the start datetime of the collection.""" return list(self.datetime_interval())[0][0] def start_str(self) -> str: - """Gets the start datetime of the collection as a string. - - Returns: - str: The start datetime of the collection as a string. - """ + """Returns the start datetime of the collection as a string.""" if not self.start(): return "" return self.start().strftime("%Y-%m-%d") def end(self) -> datetime.datetime | None: - """Gets the end datetime of the collection. - - Returns: - Optional[datetime.datetime]: The end datetime of the collection. - """ + """Returns the end datetime of the collection.""" return list(self.datetime_interval())[0][1] def end_str(self) -> str: - """Gets the end datetime of the collection as a string. - - Returns: - str: The end datetime of the collection as a string. - """ + """Returns the end datetime of the collection as a string.""" if not self.end(): return "" return self.end().strftime("%Y-%m-%d") def bbox_list(self) -> Sequence[BBox]: - """Gets the bounding boxes of the collection. - - Returns: - Sequence[BBox]: A sequence of bounding boxes. 
- """ + """Returns a sequence of bounding boxes.""" if "extent" not in self.stac_json: # Assume global if nothing listed. return (BBox(-180, -90, 180, 90),) @@ -1067,22 +989,14 @@ def bbox_list(self) -> Sequence[BBox]: ) def bands(self) -> list[dict[str, Any]]: - """Gets the bands of the collection. - - Returns: - List[Dict[str, Any]]: A list of dictionaries representing the bands. - """ + """Returns a list of dictionaries representing the bands.""" summaries = self.stac_json.get("summaries") if not summaries: return [] return summaries.get("eo:bands", []) def spatial_resolution_m(self) -> float: - """Gets the spatial resolution of the collection in meters. - - Returns: - float: The spatial resolution of the collection in meters. - """ + """Returns the spatial resolution of the collection in meters.""" summaries = self.stac_json.get("summaries") if not summaries: return -1 @@ -1098,22 +1012,14 @@ def spatial_resolution_m(self) -> float: return -1 def temporal_resolution_str(self) -> str: - """Gets the temporal resolution of the collection as a string. - - Returns: - str: The temporal resolution of the collection as a string. - """ + """Returns the temporal resolution of the collection as a string.""" interval_dict = self.stac_json.get("gee:interval") if not interval_dict: return "" return f"{interval_dict['interval']} {interval_dict['unit']}" def python_code(self) -> str: - """Gets the Python code sample for the collection. - - Returns: - str: The Python code sample for the collection. - """ + """Returns the Python code sample for the collection.""" code = self.stac_json.get("code") if not code: return "" @@ -1124,7 +1030,7 @@ def set_python_code(self, code: str) -> None: """Sets the Python code sample for the collection. Args: - code (str): The Python code sample to set. + code: The Python code sample to set. 
""" if not code: self.stac_json["code"] = {"js_code": "", "py_code": code} @@ -1135,7 +1041,7 @@ def set_js_code(self, code: str) -> None: """Sets the JavaScript code sample for the collection. Args: - code (str): The JavaScript code sample to set. + code: The JavaScript code sample to set. """ if not code: return "" @@ -1143,10 +1049,7 @@ def set_js_code(self, code: str) -> None: self.stac_json["code"] = {"js_code": "", "py_code": code} def image_preview_url(self) -> str: - """Gets the URL of the preview image for the collection. - - Returns: - str: The URL of the preview image for the collection. + """Returns the URL of the preview image for the collection. Raises: ValueError: If no preview image is found. @@ -1161,11 +1064,7 @@ def image_preview_url(self) -> str: raise ValueError(f"No preview image found for {id}") def catalog_url(self) -> str: - """Gets the URL of the catalog for the collection. - - Returns: - str: The URL of the catalog for the collection. - """ + """Returns the URL of the catalog for the collection.""" links = self.stac_json["links"] for link in links: if "rel" in link and link["rel"] == "catalog": @@ -1180,7 +1079,6 @@ def catalog_url(self) -> str: return "" -# @title class CollectionList() class CollectionList(Sequence[Collection]): """List of stac.Collections; can be filtered to return a smaller sublist.""" @@ -1259,17 +1157,17 @@ def filter_by_bounding_box(self, query_bbox: BBox): def start_str(self) -> datetime.datetime: return self.start().strftime("%Y-%m-%d") - def sort_by_spatial_resolution(self, reverse=False): - """ - Sorts the collections based on their spatial resolution. + def sort_by_spatial_resolution(self, reverse: bool = False): + """Sorts the collections based on their spatial resolution. + Collections with spatial_resolution_m() == -1 are pushed to the end. Args: - reverse (bool): If True, sort in descending order (highest resolution first). - If False (default), sort in ascending order (lowest resolution first). 
+ reverse: If True, sort in descending order (highest resolution first). If + False (default), sort in ascending order (lowest resolution first). Returns: - CollectionList: A new CollectionList instance with sorted collections. + A new CollectionList instance with sorted collections. """ def sort_key(collection): @@ -1282,20 +1180,15 @@ def sort_key(collection): return self.__class__(sorted_collections) def limit(self, n: int): - """ - Returns a new CollectionList containing the first n entries. + """Returns a new CollectionList containing the first n entries. Args: - n (int): The number of entries to include in the new list. - - Returns: - CollectionList: A new CollectionList instance with at most n collections. + n: The number of entries to include in the new list. """ return self.__class__(self._collections[:n]) - def to_df(self): + def to_df(self) -> pd.DataFrame: """Converts a collection list to a dataframe with a select set of fields.""" - rows = [] for col in self._collections: # Remove text in parens in dataset name. @@ -1323,7 +1216,7 @@ def __init__(self, storage_client: storage.Client) -> None: """Initializes the Catalog with collections loaded from Google Cloud Storage. Args: - storage_client (storage.Client): The Google Cloud Storage client. + storage_client: The Google Cloud Storage client. """ self.collections = CollectionList(self._load_collections(storage_client)) @@ -1331,10 +1224,7 @@ def get_collection(self, id: str) -> Collection: """Returns the collection with the given id. Args: - id (str): The ID of the collection. - - Returns: - Collection: The collection with the given ID. + id: The ID of the collection. Raises: ValueError: If no collection with the given ID is found. @@ -1364,7 +1254,7 @@ def _read_file(self, file_blob: storage.Blob) -> Collection: """Reads the contents of a file from the specified bucket. Args: - file_blob (storage.Blob): The blob representing the file. + file_blob: The blob representing the file. 
Returns: Collection: The collection created from the file contents. @@ -1376,10 +1266,10 @@ def _read_files(self, file_blobs: list[storage.Blob]) -> list[Collection]: """Processes files in parallel. Args: - file_blobs (List[storage.Blob]): The list of file blobs. + file_blobs: The list of file blobs. Returns: - List[Collection]: The list of collections created from the file contents. + The list of collections created from the file contents. """ collections = [] with futures.ThreadPoolExecutor(max_workers=10) as executor: @@ -1394,10 +1284,10 @@ def _load_collections(self, storage_client: storage.Client) -> Sequence[Collecti """Loads all EE STAC JSON files from GCS, with datetimes as objects. Args: - storage_client (storage.Client): The Google Cloud Storage client. + storage_client: The Google Cloud Storage client. Returns: - Sequence[Collection]: A tuple of collections loaded from the files. + A tuple of collections loaded from the files. """ bucket = storage_client.get_bucket("earthengine-stac") files = [ @@ -1428,12 +1318,11 @@ def _load_all_code_samples( """Loads js + py example scripts from GCS into dict keyed by dataset ID. Args: - storage_client (storage.Client): The Google Cloud Storage client. + storage_client: The Google Cloud Storage client. Returns: - Dict[str, Dict[str, str]]: A dictionary mapping dataset IDs to their code samples. + A dictionary mapping dataset IDs to their code samples. """ - # Get json file from GCS bucket # 'gs://earthengine-catalog/catalog/example_scripts.json' bucket = storage_client.get_bucket("earthengine-catalog") @@ -1441,8 +1330,7 @@ def _load_all_code_samples( file_contents = blob.download_as_string().decode() data = json.loads(file_contents) - # Flatten json to get a map from ID (using '_' rather than '/') to code - # sample. + # Flatten json to get a map from ID (using '_' rather than '/') to code sample. 
all_datasets_by_provider = data[0]["contents"] code_samples_dict = {} for provider in all_datasets_by_provider: @@ -1461,10 +1349,10 @@ def _make_python_code_sample(self, js_code: str) -> str: """Converts EE JS code into python. Args: - js_code (str): The JavaScript code to convert. + js_code: The JavaScript code to convert. Returns: - str: The converted Python code. + The converted Python code. """ # geemap appears to have some stray print statements. @@ -1490,7 +1378,7 @@ def __init__(self, embeddings_dict: dict[str, list[float]]) -> None: """Initializes the PrecomputedEmbeddings. Args: - embeddings_dict (Dict[str, List[float]]): A dictionary mapping texts to their embeddings. + embeddings_dict: A dictionary mapping texts to their embeddings. """ self.embeddings_dict = embeddings_dict self.model = TextEmbeddingModel.from_pretrained("google/text-embedding-004") @@ -1499,10 +1387,10 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embeds a list of documents. Args: - texts (List[str]): The list of texts to embed. + texts: The list of texts to embed. Returns: - List[List[float]]: The list of embeddings. + The list of embeddings. """ return [self.embeddings_dict[text] for text in texts] @@ -1510,10 +1398,10 @@ def embed_query(self, text: str) -> list[float]: """Embeds a query text. Args: - text (str): The query text to embed. + text: The query text to embed. Returns: - List[float]: The embedding of the query text. + The embedding of the query text. """ embeddings = self.model.get_embeddings([text]) return embeddings[0].values @@ -1523,32 +1411,32 @@ def make_langchain_index(embeddings_df: pd.DataFrame) -> VectorStoreIndexWrapper """Creates an index from a dataframe of precomputed embeddings. Args: - embeddings_df (pd.DataFrame): The dataframe containing precomputed embeddings. + embeddings_df: The dataframe containing precomputed embeddings. Returns: - VectorStoreIndexWrapper: The vector store index wrapper. + The vector store index wrapper. 
""" - # Create a dictionary mapping texts to their embeddings + # Create a dictionary mapping texts to their embeddings. embeddings_dict = dict(zip(embeddings_df["id"], embeddings_df["embedding"])) - # Create our custom embeddings class + # Create our custom embeddings class. precomputed_embeddings = PrecomputedEmbeddings(embeddings_dict) - # Create Langchain Document objects + # Create Langchain Document objects. documents = [] for index, row in embeddings_df.iterrows(): page_content = row["id"] metadata = {"summary": row["summary"], "name": row["name"]} documents.append(Document(page_content=page_content, metadata=metadata)) - # Create the VectorstoreIndexCreator + # Create the VectorstoreIndexCreator. index_creator = VectorstoreIndexCreator(embedding=precomputed_embeddings) # Create the index return index_creator.from_documents(documents) -# Wrap Langchain embeddings in our own EE dataset wrapper +# Wrap Langchain embeddings in our own EE dataset wrapper. class EarthEngineDatasetIndex: """Class for indexing and searching Earth Engine datasets.""" @@ -1588,17 +1476,17 @@ def find_top_matches( """Retrieve relevant datasets from the Earth Engine data catalog. Args: - query (str): The kind of data being searched for, e.g., 'population'. - results (int): The number of datasets to return. Defaults to 10. - threshold (float): The maximum dot product between the query and catalog embeddings. + query: The kind of data being searched for, e.g., 'population'. + results: The number of datasets to return. Defaults to 10. + threshold: The maximum dot product between the query and catalog embeddings. Defaults to 0.7. - bounding_box (Optional[List[float]]): The spatial bounding box for the query, + bounding_box: The spatial bounding box for the query, in the format [lon1, lat1, lon2, lon2]. Defaults to None. - temporal_interval (Optional[tuple[datetime.datetime, datetime.datetime]]): - Temporal constraints as a tuple of datetime objects. Defaults to None. 
+ temporal_interval: Temporal constraints as a tuple of datetime objects. + Defaults to None. Returns: - CollectionList: A list of collections that match the query. + A list of collections that match the query. """ similar_docs = self.index.vectorstore.similarity_search_with_score( query, llm=self.llm, k=results @@ -1625,17 +1513,17 @@ def find_top_matches_with_score_df( """Retrieve relevant datasets and their match scores as a DataFrame. Args: - query (str): The kind of data being searched for, e.g., 'population'. - results (int): The number of datasets to return. Defaults to 20. - threshold (float): The maximum dot product between the query and catalog embeddings. + query: The kind of data being searched for, e.g., 'population'. + results: The number of datasets to return. Defaults to 20. + threshold: The maximum dot product between the query and catalog embeddings. Defaults to 0.7. - bounding_box (Optional[List[float]]): The spatial bounding box for the query, - in the format [lon1, lat1, lon2, lon2]. Defaults to None. - temporal_interval (Optional[tuple[datetime.datetime, datetime.datetime]]): - Temporal constraints as a tuple of datetime objects. Defaults to None. + bounding_box: The spatial bounding box for the query, in the format [lon1, + lat1, lon2, lon2]. Defaults to None. + temporal_interval: Temporal constraints as a tuple of datetime + objects. Defaults to None. Returns: - pd.DataFrame: A DataFrame containing the dataset IDs and their match scores. + A DataFrame containing the dataset IDs and their match scores. """ scores_df = self.ids_to_match_scores_df( query, results, bounding_box, temporal_interval @@ -1656,15 +1544,15 @@ def ids_to_match_scores_df( """Convert dataset IDs and match scores to a DataFrame. Args: - query (str): The kind of data being searched for, e.g., 'population'. - results (int): The number of datasets to return. 
- bounding_box (Optional[List[float]]): The spatial bounding box for the query, - in the format [lon1, lat1, lon2, lon2]. Defaults to None. - temporal_interval (Optional[tuple[datetime.datetime, datetime.datetime]]): - Temporal constraints as a tuple of datetime objects. Defaults to None. + query: The kind of data being searched for, e.g., 'population'. + results: The number of datasets to return. + bounding_box: The spatial bounding box for the query, in the format [lon1, + lat1, lon2, lon2]. Defaults to None. + temporal_interval: Temporal constraints as a tuple of datetime + objects. Defaults to None. Returns: - pd.DataFrame: A DataFrame containing the dataset IDs and their match scores. + A DataFrame containing the dataset IDs and their match scores. """ similar_docs = self.index.vectorstore.similarity_search_with_score( query, llm=self.llm, k=results @@ -1680,22 +1568,21 @@ def explain_relevance( query: str, dataset_id: str, catalog: Catalog, - model_name: str = "gemini-1.5-pro-latest", + model_name: str = "gemini-3-pro-preview", stream: bool = False, ) -> str: """Prompts LLM to explain the relevance of a dataset to a query. Args: - query (str): The user's query. - dataset_id (str): The ID of the dataset. - catalog (Catalog): The catalog containing the dataset. - model_name (str): The name of the model to use. Defaults to "gemini-1.5-pro-latest". - stream (bool): Whether to stream the response. Defaults to False. + query: The user's query. + dataset_id: The ID of the dataset. + catalog: The catalog containing the dataset. + model_name: The name of the model to use. Defaults to "gemini-3-pro-preview". + stream: Whether to stream the response. Defaults to False. Returns: - str: The explanation of the dataset's relevance to the query. + The explanation of the dataset's relevance to the query. 
""" - stac_json = catalog.get_collection(dataset_id).stac_json return explain_relevance_from_stac_json(query, stac_json, model_name, stream) @@ -1710,33 +1597,32 @@ def explain_relevance( def explain_relevance_from_stac_json( query: str, stac_json: dict[str, Any], - model_name: str = "gemini-1.5-pro-latest", + model_name: str = "gemini-3-pro-preview", stream: bool = False, ) -> str: """Prompts LLM to explain the relevance of a dataset to a query using its STAC JSON. Args: - query (str): The user's query. - stac_json (Dict[str, Any]): The STAC JSON of the dataset. - model_name (str): The name of the model to use. Defaults to "gemini-1.5-pro-latest". - stream (bool): Whether to stream the response. Defaults to False. + query: The user's query. + stac_json: The STAC JSON of the dataset. + model_name: The name of the model to use. Defaults to "gemini-3-pro-preview". + stream: Whether to stream the response. Defaults to False. Returns: - str: The explanation of the dataset's relevance to the query. + The explanation of the dataset's relevance to the query. """ stac_json_str = json.dumps(stac_json) prompt = f""" - I am an Earth Engine user contemplating using a dataset to support - my investigation of the following query. Provide a concise, paragraph-long - summary explaining why this dataset may be a good fit for my use case. - If it does not seem like an appropriate dataset, say so. - If relevant, call attention to a max of 3 bands that may be of particular interest. - Weigh the tradeoffs between temporal and spatial resolution, particularly - if the original query specifies regions of interest, time periods, or - frequency of data collection. If I have not specified any - spatial constraints, do your best based on the nature of their query. For example, - if I'm wanting to study something small, like buildings, I will likely need good spatial resolution. + I am an Earth Engine user contemplating using a dataset to support my investigation of + the following query. 
Provide a concise, paragraph-long summary explaining why this + dataset may be a good fit for my use case. If it does not seem like an appropriate + dataset, say so. If relevant, call attention to a max of 3 bands that may be of + particular interest. Weigh the tradeoffs between temporal and spatial resolution, + particularly if the original query specifies regions of interest, time periods, or + frequency of data collection. If I have not specified any spatial constraints, do your + best based on the nature of their query. For example, if I'm wanting to study + something small, like buildings, I will likely need good spatial resolution. Here is the original query: {query} @@ -1758,33 +1644,33 @@ def explain_relevance_from_stac_json( (requests.exceptions.RequestException, ConnectionError) ), ) -def is_valid_question(question: str, model_name: str = "gemini-1.5-pro-latest") -> bool: +def is_valid_question(question: str, model_name: str = "gemini-3-pro-preview") -> bool: """Filters out questions that cannot be answered by a dataset search tool. Args: - question (str): The user's question. - model_name (str): The name of the model to use. Defaults to "gemini-1.5-pro-latest". + question: The user's question. + model_name: The name of the model to use. Defaults to "gemini-3-pro-preview". Returns: - bool: True if the question is valid, False otherwise. + True if the question is valid, False otherwise. """ prompt = f""" - You are a tool whose job is to determine whether or not the following question - relates even in a small way to geospatial datasets. Please provide a single - word answer either True or False. + You are a tool whose job is to determine whether or not the following question relates + even in a small way to geospatial datasets. Please provide a single word answer + either True or False. - For example, if the original query is "hello" - you should answer False. 
If - the original query is "cheese futures" you should still answer True because - the user could be interested in cheese production, and therefore agricultural - land where cattle might be raised. + For example, if the original query is "hello" - you should answer False. If the + original query is "cheese futures" you should still answer True because the user could + be interested in cheese production, and therefore agricultural land where cattle might + be raised. Here is the original query: {question} """ model = genai.GenerativeModel(model_name) response = model.generate_content(prompt) - # Err on the side of returning True + # Err on the side of returning True. return response.text.lower().strip() != "false" @@ -1801,15 +1687,15 @@ class CodeThoughts(typing_extensions.TypedDict): ), ) def fix_ee_python_code( - code: str, ee: Any, geemap_instance: Map, model_name: str = "gemini-1.5-pro-latest" + code: str, ee: Any, geemap_instance: Map, model_name: str = "gemini-3-pro-preview" ) -> str: """Asks a model to do ee python code correction in the event of error. Args: - code (str): The Earth Engine Python code to fix. - ee (Any): The Earth Engine module. - geemap_instance (Map): The geemap instance. - model_name (str): The name of the model to use. Defaults to "gemini-1.5-pro-latest". + code: The Earth Engine Python code to fix. + ee: The Earth Engine module. + geemap_instance: The geemap instance. + model_name: The name of the model to use. Defaults to "gemini-3-pro-preview". Returns: str: The corrected Earth Engine Python code. @@ -1825,14 +1711,13 @@ def create_error_prompt(code: str, error_msg: str) -> str: {code} ``` - I have encountered the following error, please fix it. In 1-2 sentences, - explain your debugging thought process in the 'thoughts' field. Note that - the setOptions method exists only in the ee javascript library. Code - referencing that method can be removed. + I have encountered the following error, please fix it. 
In 1-2 sentences, explain + your debugging thought process in the 'thoughts' field. Note that the setOptions + method exists only in the ee javascript library. Code referencing that method can + be removed. - Include the complete revised code snippet in the code field. - Do not provide any other comentary in the code field. Do not add any new - imports to the code snippet. + Include the complete revised code snippet in the code field. Do not provide any + other comentary in the code field. Do not add any new imports to the code snippet. {error_msg} """ @@ -1850,7 +1735,6 @@ def create_error_prompt(code: str, error_msg: str) -> str: while total_attempts < max_attempts and broken: try: run_ee_code(code, ee, geemap_instance) - # logging.warning(f'Code success! after {total_attempts} try.') return code except Exception as e: logging.warning("Code execution error, asking Gemini for help.") @@ -1894,20 +1778,19 @@ def __init__(self, query: str, collections: CollectionList) -> None: """Initializes the DatasetSearchInterface. Args: - query (str): The search query string. - collections (CollectionList): The list of dataset collections. + query: The search query string. + collections: The list of dataset collections. """ - self.query = query self.collections = collections - # Create the output widgets + # Create the output widgets. self.code_output = widgets.Output(layout=widgets.Layout(width="50%")) self.details_output = widgets.Output( layout=widgets.Layout(height="300px", width="100%") ) - # Initialize dataset table + # Initialize dataset table. table_html = self._build_table_html(collections) self.dataset_table = widgets.HTML(value=table_html) @@ -1916,7 +1799,7 @@ def __init__(self, query: str, collections: CollectionList) -> None: self._dataset_select_js_code = self._dataset_select_js_code(_callback_id) # self._dataset_select_js_code(_callback_id) - # Initialize map + # Initialize map. 
self.map_output = widgets.Output(layout=widgets.Layout(width="100%")) self.geemap_instance = geemap.Map(height="600px", width="100%") @@ -1949,13 +1832,13 @@ def display(self): ), ) - # Create the vertical box for code and details + # Create the vertical box for code and details. self.details_code_box = widgets.VBox( [details_widget, code_widget], layout=widgets.Layout(width="50%", height="600px"), ) - # Create a horizontal box for map and details/code + # Create a horizontal box for map and details/code. map_details_code_box = widgets.HBox( [self.map_widget, self.details_code_box], layout=widgets.Layout( @@ -1971,7 +1854,7 @@ def display(self): ), ) - # Add debug panel to the main layout + # Add debug panel to the main layout. main_layout = widgets.VBox( [ # title, @@ -2066,10 +1949,10 @@ def _build_table_html(self, collections: CollectionList) -> str: """Builds the HTML for the dataset table. Args: - collections (CollectionList): The list of dataset collections. + collections: The list of dataset collections. Returns: - str: The HTML string for the dataset table. + The HTML string for the dataset table. """ table_html = """ @@ -2101,7 +1984,7 @@ def update_outputs(self, selected_dataset: str) -> None: """Updates the output widgets based on the selected dataset. Args: - selected_dataset (str): The ID of the selected dataset. + selected_dataset: The ID of the selected dataset. """ collection = self.collections.filter_by_ids([selected_dataset]) @@ -2150,10 +2033,10 @@ def _dataset_select_js_code(self, callback_id: str) -> str: """Generates JavaScript code for handling dataset selection. Args: - callback_id (str): The callback ID for the dataset selection. + callback_id: The callback ID for the dataset selection. Returns: - str: The JavaScript code as a string. + The JavaScript code as a string. 
""" return Template( syntax.javascript( @@ -2200,6 +2083,7 @@ def _dataset_select_js_code(self, callback_id: str) -> str: class DatasetExplorer: """A widget for exploring Earth Engine datasets. + The DataExplorer class source code is adapted from . Credit to the original author Renee Johnston () """ @@ -2209,17 +2093,17 @@ def __init__( project_id: str = "GOOGLE_PROJECT_ID", google_api_key: str = "GOOGLE_API_KEY", vertex_ai_zone: str = "us-central1", - model: str = "gemini-1.5-pro", + model: str = "gemini-3-pro-preview", embeddings_cloud_path: str = "gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl", ) -> None: """Initializes the DatasetExplorer. Args: - project_id (str): Google Cloud project ID. Defaults to "GOOGLE_PROJECT_ID". - google_api_key (str): Google API key. Defaults to "GOOGLE_API_KEY". - vertex_ai_zone (str): Vertex AI zone. Defaults to "us-central1". - model (str): Model name for ChatGoogleGenerativeAI. Defaults to "gemini-1.5-pro". - embeddings_cloud_path (str): Cloud path to the embeddings file. + project_id: Google Cloud project ID. Defaults to "GOOGLE_PROJECT_ID". + google_api_key: Google API key. Defaults to "GOOGLE_API_KEY". + vertex_ai_zone: Vertex AI zone. Defaults to "us-central1". + model: Model name for ChatGoogleGenerativeAI. Defaults to "gemini-3-pro-preview". + embeddings_cloud_path: Cloud path to the embeddings file. Defaults to "gs://earthengine-catalog/embeddings/catalog_embeddings.jsonl". """ @@ -2275,15 +2159,15 @@ def load_embeddings( self.ee_index = EarthEngineDatasetIndex(catalog, langchain_index, llm) - def show(self, query: str | None = None, **kwargs: Any) -> widgets.VBox: + def show(self, query: str | None = None, **kwargs) -> widgets.VBox: """Displays a query interface for searching datasets. Args: - query (Optional[str]): The initial query string. Defaults to None. - **kwargs (Any): Additional keyword arguments for widget styling. + query: The initial query string. Defaults to None. 
+ **kwargs: Additional keyword arguments for widget styling. Returns: - widgets.VBox: A VBox containing the query input and output display. + A VBox containing the query input and output display. """ output.no_vertical_scroll() @@ -2328,7 +2212,7 @@ def on_query_change(text: widgets.Text) -> None: """Handles the event when the query text is submitted. Args: - text (widgets.Text): The text widget containing the query. + text: The text widget containing the query. """ output_widget.clear_output() with output_widget: