diff --git a/compass/_cli/finalize.py b/compass/_cli/finalize.py index 46f2dc509..735034d25 100644 --- a/compass/_cli/finalize.py +++ b/compass/_cli/finalize.py @@ -8,7 +8,7 @@ from rich.console import Console from compass.utilities import Directories -from compass.utilities.location import Jurisdiction +from compass.utilities.jurisdictions import Jurisdiction from compass.utilities.parsing import load_config from compass.utilities.finalize import save_run_meta, doc_infos_to_db, save_db from compass.scripts.process import _initialize_model_params @@ -116,16 +116,7 @@ def _compile_db(jurisdictions, dirs): doc_info = doc_info[0] all_doc_infos.append( - { - "ord_db_fp": ord_db_fp, - "source": doc_info.get("source"), - "date": ( - doc_info.get("effective_year"), - doc_info.get("effective_month"), - doc_info.get("effective_day"), - ), - "jurisdiction": jurisdiction, - } + {"ord_db_fp": ord_db_fp, "jurisdiction": jurisdiction} ) db, __ = doc_infos_to_db(all_doc_infos) diff --git a/compass/common/__init__.py b/compass/common/__init__.py index 8e26703fb..3de4d7959 100644 --- a/compass/common/__init__.py +++ b/compass/common/__init__.py @@ -2,7 +2,6 @@ from .base import ( EXTRACT_ORIGINAL_SETBACK_TEXT_PROMPT, - BaseTextExtractor, empty_output, llm_response_starts_with_no, llm_response_starts_with_yes, diff --git a/compass/common/base.py b/compass/common/base.py index d16a69b27..c284c0912 100644 --- a/compass/common/base.py +++ b/compass/common/base.py @@ -1,19 +1,13 @@ """Common ordinance extraction components""" -import asyncio import logging from datetime import datetime import networkx as nx -from elm import ApiBase from compass.common.tree import AsyncDecisionTree from compass.utilities import llm_response_as_json -from compass.utilities.enums import LLMUsageCategory -from compass.utilities.parsing import ( - merge_overlapping_texts, - clean_backticks_from_llm_response, -) + from compass.exceptions import COMPASSRuntimeError @@ -970,64 +964,3 @@ def setup_graph_permitted_use_districts(**kwargs): ), ) return G - - -class BaseTextExtractor: - """Base implementation for a text extractor""" - - SYSTEM_MESSAGE = ( - "You are a text extraction assistant. Your job is to extract only " - "verbatim, **unmodified** excerpts from provided legal or policy " - "documents. Do not interpret or paraphrase. Do not summarize. Only " - "return exactly copied segments that match the specified scope. If " - "the relevant content appears within a table, return the entire " - "table, including headers and footers, exactly as formatted." - ) - """System message for text extraction LLM calls""" - _USAGE_LABEL = LLMUsageCategory.DOCUMENT_ORDINANCE_SUMMARY - - def __init__(self, llm_caller): - """ - - Parameters - ---------- - llm_caller : LLMCaller - LLM Caller instance used to extract ordinance info with. 
- """ - self.llm_caller = llm_caller - - async def _process(self, text_chunks, instructions, is_valid_chunk): - """Perform extraction processing""" - logger.info( - "Extracting summary text from %d text chunks asynchronously...", - len(text_chunks), - ) - logger.debug("Model instructions are:\n%s", instructions) - outer_task_name = asyncio.current_task().get_name() - summaries = [ - asyncio.create_task( - self.llm_caller.call( - sys_msg=self.SYSTEM_MESSAGE, - content=f"{instructions}\n\n# TEXT #\n\n{chunk}", - usage_sub_label=self._USAGE_LABEL, - ), - name=outer_task_name, - ) - for chunk in text_chunks - ] - summary_chunks = await asyncio.gather(*summaries) - summary_chunks = [ - clean_backticks_from_llm_response(chunk) - for chunk in summary_chunks - if is_valid_chunk(chunk) - ] - - text_summary = merge_overlapping_texts(summary_chunks) - logger.debug( - "Final summary contains %d tokens", - ApiBase.count_tokens( - text_summary, - model=self.llm_caller.kwargs.get("model", "gpt-4"), - ), - ) - return text_summary diff --git a/compass/data/tx_water_districts.csv b/compass/data/tx_water_districts.csv index f1e543100..22f7acf68 100644 --- a/compass/data/tx_water_districts.csv +++ b/compass/data/tx_water_districts.csv @@ -38,7 +38,7 @@ Texas,,High Plains,Underground Water Conservation District,36, Texas,,Hill Country,Underground Water Conservation District,37, Texas,,Hudspeth County,Underground Water Conservation District,38, Texas,,Irion County,WConservation District,39, -Texas,,Jeff Davis,County Underground Water Conservation District,40, +Texas,,Jeff Davis County,Underground Water Conservation District,40, Texas,,Kenedy County,Groundwater Conservation District,41, Texas,,Kimble County,Groundwater Conservation District,42, Texas,,Kinney County,Groundwater Conservation District,43, diff --git a/compass/exceptions.py b/compass/exceptions.py index 82b3fb86b..0a833dca5 100644 --- a/compass/exceptions.py +++ b/compass/exceptions.py @@ -21,9 +21,17 @@ class COMPASSNotInitializedError(COMPASSError): """COMPASS not initialized error""" +class COMPASSTypeError(COMPASSError, TypeError): + """COMPASS TypeError""" + + class COMPASSValueError(COMPASSError, ValueError): """COMPASS ValueError""" class COMPASSRuntimeError(COMPASSError, RuntimeError): """COMPASS RuntimeError""" + + +class COMPASSPluginConfigurationError(COMPASSRuntimeError): + """COMPASS Plugin Configuration Error""" diff --git a/compass/extraction/__init__.py b/compass/extraction/__init__.py index 7a19c21bd..373bf198f 100644 --- a/compass/extraction/__init__.py +++ b/compass/extraction/__init__.py @@ -1,9 +1,9 @@ """Ordinance text extraction tooling""" from .apply import ( - check_for_ordinance_info, + check_for_relevant_text, extract_date, - extract_ordinance_text_with_llm, - extract_ordinance_text_with_ngram_validation, + extract_relevant_text_with_llm, + extract_relevant_text_with_ngram_validation, extract_ordinance_values, ) diff --git a/compass/extraction/apply.py b/compass/extraction/apply.py index 4ae5bfff4..024861bd8 100644 --- a/compass/extraction/apply.py +++ b/compass/extraction/apply.py @@ -19,25 +19,29 @@ _TEXT_OUT_CHAR_BUFFER = 1.05 -async def check_for_ordinance_info( +async def check_for_relevant_text( doc, model_config, heuristic, tech, - ordinance_text_collector_class, - permitted_use_text_collector_class=None, + text_collectors, usage_tracker=None, + min_chunks_to_process=3, ): - """Parse a single document for ordinance information + """Parse a single document for relevant text (e.g. 
ordinances) + + The results of the text parsing are stored in the documents attrs + under the respective text collector label. Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument A document instance (PDF, HTML, etc) potentially containing ordinance information. Note that if the document's ``attrs`` - has the ``"contains_ord_info"`` key, it will not be processed. - To force a document to be processed by this function, remove - that key from the documents ``attrs``. + has the relevant text output, the corresponding text collector + will not be run. To force a document to be processed by this + function, remove all previously collected text from the + document's ``attrs``. model_config : compass.llm.config.LLMConfig Configuration describing which LLM service, splitter, and call parameters should be used for extraction. @@ -47,38 +51,31 @@ async def check_for_ordinance_info( tech : str Technology of interest (e.g. "solar", "wind", etc). This is used to set up some document validation decision trees. - ordinance_text_collector_class : type - Collector class invoked to capture ordinance text chunks. - permitted_use_text_collector_class : type, optional - Collector class used to capture permitted-use districts text. - When ``None``, the permitted-use workflow is skipped. + text_collectors : iterable + Iterable of text collector classes to run during document + parsing. Each class must implement the + :class:`compass.plugin.interface.BaseTextCollector` interface. + If the document already contains text collected by a given + collector (i.e. the collector's ``OUT_LABEL`` is found in + ``doc.attrs``), that collector will be skipped. usage_tracker : UsageTracker, optional Optional tracker instance to monitor token usage during LLM calls. By default, ``None``. + min_chunks_to_process : int, optional + Minimum number of chunks to process before aborting due to text + failing the heuristic or deemed not legal (if applicable). + By default, ``3``. Returns ------- - elm.web.document.BaseDocument - Document that has been parsed for ordinance text. The results of - the parsing are stored in the documents attrs. In particular, - the attrs will contain a ``"contains_ord_info"`` key that - will be set to ``True`` if ordinance info was found in the text, - and ``False`` otherwise. If ``True``, the attrs will also - contain a ``"date"`` key containing the most recent date that - the ordinance was enacted (or a tuple of `None` if not found), - and an ``"ordinance_text"`` key containing the ordinance text - snippet. Note that the snippet may contain other info as well, - but should encapsulate all of the ordinance text. + bool + ``True`` if any text was collected by any of the text collectors + and ``False`` otherwise. Notes ----- - The function updates progress bar logging as chunks are processed - and sets ``contains_district_info`` when - ``permitted_use_text_collector_class`` is provided. + The function updates progress bar logging as chunks are processed. 
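        Examples
        --------
        A minimal usage sketch (``doc``, ``model_config``, and
        ``heuristic`` are assumed to be set up elsewhere; the small wind
        collector is used purely for illustration)::

            from compass.extraction import check_for_relevant_text
            from compass.extraction.small_wind.ordinance import (
                SmallWindOrdinanceTextCollector,
            )

            found = await check_for_relevant_text(
                doc,
                model_config,
                heuristic,
                tech="wind",
                text_collectors=[SmallWindOrdinanceTextCollector],
            )
            if found:
                print(doc.attrs[SmallWindOrdinanceTextCollector.OUT_LABEL])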
""" - if "contains_ord_info" in doc.attrs: - return doc - chunks = model_config.text_splitter.split_text(doc.text) chunk_parser = ParseChunksWithMemory(chunks, num_to_recall=2) legal_text_validator = ( @@ -93,52 +90,48 @@ async def check_for_ordinance_info( else None ) - ordinance_text_collector = ordinance_text_collector_class( - llm_service=model_config.llm_service, - usage_tracker=usage_tracker, - **model_config.llm_call_kwargs, - ) - callbacks = [ordinance_text_collector.check_chunk] - if permitted_use_text_collector_class is not None: - permitted_use_text_collector = permitted_use_text_collector_class( + collectors_to_run = [] + callbacks = [] + for collector_class in text_collectors: + if collector_class is None or collector_class.OUT_LABEL in doc.attrs: + continue + + collector = collector_class( llm_service=model_config.llm_service, usage_tracker=usage_tracker, **model_config.llm_call_kwargs, ) - callbacks.append(permitted_use_text_collector.check_chunk) + collectors_to_run.append(collector) + callbacks.append(collector.check_chunk) + + if not collectors_to_run: + logger.debug( + "No text collectors to run for document from %s", + doc.attrs.get("source", "unknown source"), + ) + return False await parse_by_chunks( chunk_parser, heuristic, legal_text_validator, callbacks=callbacks, - min_chunks_to_process=3, + min_chunks_to_process=min_chunks_to_process, ) - doc.attrs["contains_ord_info"] = ordinance_text_collector.contains_ord_info - if doc.attrs["contains_ord_info"]: - doc.attrs["ordinance_text"] = ordinance_text_collector.ordinance_text - logger.debug_to_file( - "Ordinance text for %s is:\n%s", - doc.attrs.get("source", "unknown source"), - doc.attrs["ordinance_text"], - ) - - if permitted_use_text_collector_class is not None: - doc.attrs["contains_district_info"] = ( - permitted_use_text_collector.contains_district_info - ) - if doc.attrs["contains_district_info"]: - doc.attrs["permitted_use_text"] = ( - permitted_use_text_collector.permitted_use_district_text - ) + found_text = False + for collector in collectors_to_run: + if text := collector.relevant_text: + found_text = True + doc.attrs[collector.OUT_LABEL] = text logger.debug_to_file( - "Permitted use text for %s is:\n%s", + "%r text for %s is:\n%s", + collector.OUT_LABEL, doc.attrs.get("source", "unknown source"), - doc.attrs["permitted_use_text"], + text, ) - return doc + return found_text async def extract_date(doc, model_config, usage_tracker=None): @@ -146,7 +139,7 @@ async def extract_date(doc, model_config, usage_tracker=None): Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument A document potentially containing date information. model_config : compass.llm.config.LLMConfig Configuration describing which LLM service, splitter, and call @@ -157,7 +150,7 @@ async def extract_date(doc, model_config, usage_tracker=None): Returns ------- - elm.web.document.BaseDocument + BaseDocument Document that has been parsed for dates. The results of the parsing are stored in the documents attrs. In particular, the attrs will contain a ``"date"`` key that will contain the @@ -189,21 +182,20 @@ async def extract_date(doc, model_config, usage_tracker=None): return doc -async def extract_ordinance_text_with_llm( +async def extract_relevant_text_with_llm( doc, text_splitter, extractor, original_text_key ): """Extract ordinance text from document using LLM Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument A document known to contain ordinance information. 
This means it - must contain an ``"ordinance_text"`` key in the attrs. You can - run :func:`check_for_ordinance_info` - to have this attribute populated automatically for documents - that are found to contain ordinance data. Note that if the - document's attrs does not contain the ``"ordinance_text"`` - key, you will get an error. + must contain the `original_text_key` key in the attrs. You can + run :func:`check_for_relevant_text` to have this attribute + populated automatically for documents that are found to contain + relevant extraction text. Note that if the document's attrs does + not contain the `original_text_key`, you will get an error. text_splitter : LCTextSplitter, optional Optional Langchain text splitter (or subclass instance), or any object that implements a `split_text` method. The method should @@ -217,7 +209,7 @@ async def extract_ordinance_text_with_llm( Returns ------- - elm.web.document.BaseDocument + BaseDocument Document that has been parsed for ordinance text. The results of the extraction are stored in the document's attrs. str @@ -225,7 +217,7 @@ async def extract_ordinance_text_with_llm( `doc.attrs` dictionary. """ - prev_meta_name = original_text_key # "ordinance_text" + prev_meta_name = original_text_key for meta_name, parser in extractor.parsers: doc.attrs[meta_name] = await _parse_if_input_text_not_empty( doc.attrs[prev_meta_name], @@ -239,7 +231,7 @@ async def extract_ordinance_text_with_llm( return doc, prev_meta_name -async def extract_ordinance_text_with_ngram_validation( +async def extract_relevant_text_with_ngram_validation( doc, text_splitter, extractor, @@ -261,13 +253,13 @@ async def extract_ordinance_text_with_ngram_validation( Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument A document known to contain ordinance information. This means it - must contain an ``"ordinance_text"`` key in the attrs. You can - run :func:`~compass.extraction.apply.check_for_ordinance_info` + must contain an ``"relevant_text"`` key in the attrs. You can + run :func:`~compass.extraction.apply.check_for_relevant_text` to have this attribute populated automatically for documents that are found to contain ordinance data. Note that if the - document's attrs does not contain the ``"ordinance_text"`` + document's attrs does not contain the ``"relevant_text"`` key, it will not be processed. text_splitter : LCTextSplitter, optional Optional Langchain text splitter (or subclass instance), or any @@ -301,7 +293,7 @@ async def extract_ordinance_text_with_ngram_validation( Returns ------- - elm.web.document.BaseDocument + BaseDocument Document that has been parsed for ordinance text. The results of the extraction are stored in the document's attrs. """ @@ -309,7 +301,8 @@ async def extract_ordinance_text_with_ngram_validation( msg = ( f"Input document has no {original_text_key!r} key or string " "does not contain information. Please run " - "`check_for_ordinance_info` prior to calling this method." + "`compass.extraction.check_for_relevant_text()` with the proper " + "text collector prior to calling this method." 
) warn(msg, COMPASSWarning) return doc @@ -358,9 +351,8 @@ async def _extract_with_ngram_check( ) best_score = 0 - out_text_key = "extracted_text" for attempt in range(1, num_tries + 1): - doc, out_text_key = await extract_ordinance_text_with_llm( + doc, out_text_key = await extract_relevant_text_with_llm( doc, text_splitter, extractor, original_text_key ) cleaned_text = doc.attrs[out_text_key] @@ -422,10 +414,10 @@ async def extract_ordinance_values(doc, parser, text_key, out_key): Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument A document known to contain ordinance text. This means it must contain an `text_key` key in the attrs. You can run - :func:`~compass.extraction.apply.extract_ordinance_text_with_llm` + :func:`~compass.extraction.apply.extract_relevant_text_with_llm` to have this attribute populated automatically for documents that are found to contain ordinance data. Note that if the document's attrs does not contain the `text_key` key, it will @@ -442,7 +434,7 @@ async def extract_ordinance_values(doc, parser, text_key, out_key): Returns ------- - elm.web.document.BaseDocument + BaseDocument Document that has been parsed for ordinance values. The results of the extraction are stored in the document's attrs. @@ -455,7 +447,7 @@ async def extract_ordinance_values(doc, parser, text_key, out_key): msg = ( f"Input document has no {text_key!r} key or string " "does not contain info. Please run " - "`extract_ordinance_text_with_llm` prior to calling this method." + "`extract_relevant_text_with_llm` prior to calling this method." ) warn(msg, COMPASSWarning) return doc diff --git a/compass/extraction/context.py b/compass/extraction/context.py new file mode 100644 index 000000000..7b6b9c833 --- /dev/null +++ b/compass/extraction/context.py @@ -0,0 +1,180 @@ +"""Extraction context for multi-document ordinance extraction""" + +from textwrap import shorten +from collections.abc import Iterable + +import pandas as pd + +from compass.services.threaded import FileMover +from compass.exceptions import COMPASSTypeError + + +class ExtractionContext: + """Context for extraction operations supporting multiple documents + + This class provides a Document-compatible interface for extraction + workflows that may involve one or more source documents. It tracks + chunk-level provenance to identify which document each text chunk + originated from, while maintaining compatibility with existing + extraction functions that expect Document-like objects + """ + + def __init__(self, documents=None, attrs=None): + """ + + Parameters + ---------- + documents : sequence of BaseDocument, optional + One or more source documents contributing to this context. + For single-document workflows (solar, wind), pass a list + with one document. For multi-document workflows (water + rights), pass all contributing documents + attrs : dict, optional + Context-level attributes for extraction metadata + (jurisdiction, tech type, etc.). 
By default, ``None`` + """ + self.attrs = attrs or {} + self._documents = _as_list(documents) + self._data_docs = [] + + @property + def text(self): + """str: Concatenated text from all documents""" + return "\n\n".join(doc.text for doc in self.documents) + + @property + def pages(self): + """list: Concatenated pages from all documents""" + return [page for doc in self.documents for page in doc.pages] + + @property + def num_documents(self): + """int: Number of source documents in this context""" + return len(self.documents) + + @property + def documents(self): + """list: List of documents that might contain relevant info""" + return self._documents + + @documents.setter + def documents(self, other): + self._documents = _as_list(other) + + @property + def data_docs(self): + """list: List of documents that contributed to extraction""" + return self._data_docs + + @data_docs.setter + def data_docs(self, other): + if not isinstance(other, list): + msg = "data_docs must be set to a *list* of documents" + raise COMPASSTypeError(msg) + + self._data_docs = other + + def __str__(self): + header = ( + f"{self.__class__.__name__} with {self.num_documents:,} document" + ) + if self.num_documents != 1: + header = f"{header}s" + + if self.num_documents > 0: + docs = "\n\t- ".join( + [ + d.attrs.get("source", "Unknown source") + for d in self.documents + ] + ) + header = f"{header}:\n\t- {docs}" + + data_docs = _data_docs_repr(self.data_docs) + attrs = _attrs_repr(self.attrs) + return f"{header}\n{data_docs}\n{attrs}" + + def __len__(self): + return self.num_documents + + def __getitem__(self, index): + return self.documents[index] + + def __iter__(self): + return iter(self.documents) + + def __bool__(self): + return bool(self.documents) + + async def mark_doc_as_data_source(self, doc, out_fn_stem=None): + """Mark a document as a data source for extraction + + Parameters + ---------- + doc : BaseDocument + Document to add as a data source + out_fn_stem : str, optional + Optional output filename stem for this document. If + provided, the document file will be moved from the + temporary directory to the output directory with this + filename stem and appropriate file suffix. + By default, ``None``. 
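        Examples
        --------
        A rough sketch (``doc_a`` and ``doc_b`` are assumed to be loaded
        documents, the ``FileMover`` service is assumed to be running
        when ``out_fn_stem`` is given, and the filename stem below is
        purely illustrative)::

            context = ExtractionContext(documents=[doc_a, doc_b])
            await context.mark_doc_as_data_source(
                doc_a, out_fn_stem="gcd_rules"
            )
            assert doc_a in context.data_docs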
+ """ + self._data_docs.append(doc) + if out_fn_stem is not None: + await _move_file_to_out_dir(doc, out_fn_stem) + + +async def _move_file_to_out_dir(doc, out_fn): + """Move PDF or HTML text file to output directory""" + out_fp = await FileMover.call(doc, out_fn) + doc.attrs["out_fp"] = out_fp + return doc + + +def _as_list(documents): + """Convert input to list""" + if documents is None: + return [] + if not isinstance(documents, Iterable): + return [documents] + return list(documents) + + +def _data_docs_repr(data_docs): + """String representation of data source documents""" + if not data_docs: + return "Registered Data Source Documents: None" + + data_docs = "\n\t- ".join( + [d.attrs.get("source", "Unknown source") for d in data_docs] + ) + + return f"Registered Data Source Documents:\n\t- {data_docs}" + + +def _attrs_repr(attrs): + """String representation of context attributes""" + if not attrs: + return "Attrs: None" + + attrs = { + k: ( + f"DataFrame with {len(v):,} rows" + if isinstance(v, pd.DataFrame) + else v + ) + for k, v in attrs.items() + } + + indent = max(len(k) for k in attrs) + 2 + width = max(10, 80 - (indent + 4)) + to_join = [] + for k, v in attrs.items(): + v_str = str(v) + if "\n" in v_str: + v_str = shorten(v_str, width=width) + to_join.append(f"{k:>{indent}}:\t{v_str}") + + attrs = "\n".join(to_join) + return f"Attrs:\n{attrs}" diff --git a/compass/extraction/date.py b/compass/extraction/date.py index 8c01d4176..bb505ad12 100644 --- a/compass/extraction/date.py +++ b/compass/extraction/date.py @@ -58,7 +58,7 @@ async def parse(self, doc): Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument Document with a `raw_pages` attribute. Returns diff --git a/compass/extraction/small_wind/__init__.py b/compass/extraction/small_wind/__init__.py index d21c264ca..2e3c05aee 100644 --- a/compass/extraction/small_wind/__init__.py +++ b/compass/extraction/small_wind/__init__.py @@ -1,53 +1,3 @@ -"""Small wind ordinance extraction utilities""" +"""Small wind ordinance extraction plugin""" -from .ordinance import ( - SmallWindHeuristic, - SmallWindOrdinanceTextCollector, - SmallWindOrdinanceTextExtractor, - SmallWindPermittedUseDistrictsTextCollector, - SmallWindPermittedUseDistrictsTextExtractor, -) -from .parse import ( - StructuredSmallWindOrdinanceParser, - StructuredSmallWindPermittedUseDistrictsParser, -) - - -SMALL_WIND_QUESTION_TEMPLATES = [ - "filetype:pdf {jurisdiction} wind energy conversion system ordinances", - "wind energy conversion system ordinances {jurisdiction}", - "{jurisdiction} wind WECS ordinance", - ( - "Where can I find the legal text for small wind energy " - "turbine zoning ordinances in {jurisdiction}?" - ), - ( - "What is the specific legal information regarding zoning " - "ordinances for small wind turbines in {jurisdiction}?" 
- ), -] - -BEST_SMALL_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS = { - "pdf": 92160, - "wecs": 46080, - "wind": 23040, - "zoning": 11520, - "ordinance": 5760, - r"renewable%20energy": 1440, - r"renewable+energy": 1440, - "renewable energy": 1440, - "planning": 720, - "plan": 360, - "government": 180, - "code": 60, - "area": 60, - r"land%20development": 15, - r"land+development": 15, - "land development": 15, - "land": 3, - "environment": 3, - "energy": 3, - "renewable": 3, - "municipal": 1, - "department": 1, -} +from .plugin import COMPASSSmallWindExtractor diff --git a/compass/extraction/small_wind/ordinance.py b/compass/extraction/small_wind/ordinance.py index fc838465e..426df3b4e 100644 --- a/compass/extraction/small_wind/ordinance.py +++ b/compass/extraction/small_wind/ordinance.py @@ -6,11 +6,12 @@ import logging -from compass.common import BaseTextExtractor -from compass.validation.content import Heuristic -from compass.llm.calling import StructuredLLMCaller +from compass.plugin.ordinance import ( + OrdinanceHeuristic, + OrdinanceTextCollector, + OrdinanceTextExtractor, +) from compass.utilities.enums import LLMUsageCategory -from compass.utilities.parsing import merge_overlapping_texts logger = logging.getLogger(__name__) @@ -32,7 +33,7 @@ _IGNORE_TYPES_LARGE = "large, utility-scale, for-sale, commercial" -class SmallWindHeuristic(Heuristic): +class SmallWindHeuristic(OrdinanceHeuristic): """Perform a heuristic check for mention of wind turbines in text""" NOT_TECH_WORDS = [ @@ -68,7 +69,7 @@ class SmallWindHeuristic(Heuristic): """Words and phrases that indicate text is NOT about WECS""" GOOD_TECH_KEYWORDS = ["wind", "setback"] """Words that indicate we should keep a chunk for analysis""" - GOOD_TECH_ACRONYMS = ["wecs", "wes", "swet", "pwet", "wef"] + GOOD_TECH_ACRONYMS = ["wecs", "wes", "swet", "pwet", "wef", "pwec", "swec"] """Acronyms for WECS that we want to capture""" GOOD_TECH_PHRASES = [ "small wecs", @@ -107,8 +108,6 @@ class SmallWindHeuristic(Heuristic): "front-of-meter wecs", "front-of-meter turbine", "front-of-meter wind", - "pwec", - "swecs", "wind energy conversion", "wind turbine", "wind tower", @@ -117,9 +116,12 @@ class SmallWindHeuristic(Heuristic): """Phrases that indicate text is about WECS""" -class SmallWindOrdinanceTextCollector(StructuredLLMCaller): +class SmallWindOrdinanceTextCollector(OrdinanceTextCollector): """Check text chunks for ordinances and collect them if they do""" + OUT_LABEL = "relevant_text" + """Identifier for text collected by this class""" + CONTAINS_ORD_PROMPT = ( "You extract structured data from text. Return your answer in JSON " "format (not markdown). 
Your JSON file must include exactly two " @@ -162,10 +164,6 @@ class SmallWindOrdinanceTextCollector(StructuredLLMCaller): ) """Prompt to check if chunk is for small WES""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._ordinance_chunks = {} - async def check_chunk(self, chunk_parser, ind): """Check a chunk at a given ind to see if it contains ordinance @@ -204,32 +202,11 @@ async def check_chunk(self, chunk_parser, ind): logger.debug("Text at ind %d is for small WECS", ind) - _store_chunk(chunk_parser, ind, self._ordinance_chunks) + self._store_chunk(chunk_parser, ind) logger.debug("Added text at ind %d to ordinances", ind) return True - @property - def contains_ord_info(self): - """bool: Flag indicating whether text contains ordinance info""" - return bool(self._ordinance_chunks) - - @property - def ordinance_text(self): - """str: Combined ordinance text from the individual chunks""" - logger.debug( - "Grabbing %d ordinance chunk(s) from original text at these " - "indices: %s", - len(self._ordinance_chunks), - list(self._ordinance_chunks), - ) - - text = [ - self._ordinance_chunks[ind] - for ind in sorted(self._ordinance_chunks) - ] - return merge_overlapping_texts(text) - async def _check_chunk_contains_ord(self, key, text_chunk): """Call LLM on a chunk of text to check for ordinance""" content = await self.call( @@ -251,9 +228,12 @@ async def _check_chunk_is_for_small_scale(self, key, text_chunk): return content.get(key, False) -class SmallWindPermittedUseDistrictsTextCollector(StructuredLLMCaller): +class SmallWindPermittedUseDistrictsTextCollector(OrdinanceTextCollector): """Check text chunks for permitted wind districts; collect them""" + OUT_LABEL = "permitted_use_text" + """Identifier for text collected by this class""" + DISTRICT_PROMPT = ( "You are a legal scholar that reads ordinance text and determines " "whether the text explicitly contains relevant information to " @@ -277,10 +257,6 @@ class SmallWindPermittedUseDistrictsTextCollector(StructuredLLMCaller): ) """Prompt to check if chunk contains info on permitted districts""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._district_chunks = {} - async def check_chunk(self, chunk_parser, ind): """Check a chunk to see if it contains permitted uses @@ -311,46 +287,28 @@ async def check_chunk(self, chunk_parser, ind): contains_district_info = content.get(key, False) if contains_district_info: - _store_chunk(chunk_parser, ind, self._district_chunks) + self._store_chunk(chunk_parser, ind) logger.debug("Text at ind %d contains district info", ind) return True logger.debug("Text at ind %d does not contain district info", ind) return False - @property - def contains_district_info(self): - """bool: Flag indicating whether text contains district info""" - return bool(self._district_chunks) - @property - def permitted_use_district_text(self): - """str: Combined permitted use districts text from the chunks""" - logger.debug( - "Grabbing %d permitted use chunk(s) from original text at " - "these indices: %s", - len(self._district_chunks), - list(self._district_chunks), - ) +class SmallWindOrdinanceTextExtractor(OrdinanceTextExtractor): + """Extract succinct ordinance text from input""" - text = [ - self._district_chunks[ind] for ind in sorted(self._district_chunks) - ] - return merge_overlapping_texts(text) + IN_LABEL = SmallWindOrdinanceTextCollector.OUT_LABEL + """Identifier for collected text ingested by this class""" + OUT_LABEL = "cleaned_text_for_extraction" + 
"""Identifier for ordinance text extracted by this class""" -class SmallWindOrdinanceTextExtractor(BaseTextExtractor): - """Extract succinct ordinance text from input + TASK_DESCRIPTION = "Extracting small wind ordinance text" + """Task description to show in progress bar""" - Purpose: - Extract relevant ordinance text from document. - Responsibilities: - 1. Extract portions from chunked document text relevant to - particular ordinance type (e.g. wind zoning for small wind - systems). - Key Relationships: - Uses a StructuredLLMCaller for LLM queries. - """ + TASK_ID = "ordinance_text_extraction" + """ID to use for this extraction for linking with LLM configs""" WIND_ENERGY_SYSTEM_FILTER_PROMPT = ( "# CONTEXT #\n" @@ -470,7 +428,6 @@ async def extract_wind_energy_system_section(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.WIND_ENERGY_SYSTEM_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) async def extract_small_wind_energy_system_section(self, text_chunks): @@ -491,7 +448,6 @@ async def extract_small_wind_energy_system_section(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.SMALL_WIND_ENERGY_SYSTEM_SECTION_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) @property @@ -510,24 +466,23 @@ def parsers(self): "wind_energy_systems_text", self.extract_wind_energy_system_section, ) - yield ( - "cleaned_ordinance_text", - self.extract_small_wind_energy_system_section, - ) + yield self.OUT_LABEL, self.extract_small_wind_energy_system_section + + +class SmallWindPermittedUseDistrictsTextExtractor(OrdinanceTextExtractor): + """Extract succinct permitted use districts text from input""" + IN_LABEL = SmallWindPermittedUseDistrictsTextCollector.OUT_LABEL + """Identifier for collected text ingested by this class""" -class SmallWindPermittedUseDistrictsTextExtractor(BaseTextExtractor): - """Extract succinct ordinance text from input + OUT_LABEL = "districts_text" + """Identifier for permitted use text extracted by this class""" - Purpose: - Extract relevant ordinance text from document. - Responsibilities: - 1. Extract portions from chunked document text relevant to - particular ordinance type (e.g. wind zoning for small wind - systems). - Key Relationships: - Uses a StructuredLLMCaller for LLM queries. - """ + TASK_DESCRIPTION = "Extracting small wind permitted use text" + """Task description to show in progress bar""" + + TASK_ID = "permitted_use_text_extraction" + """ID to use for this extraction for linking with LLM configs""" _USAGE_LABEL = LLMUsageCategory.DOCUMENT_PERMITTED_USE_DISTRICTS_SUMMARY @@ -640,7 +595,6 @@ async def extract_permitted_uses(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.PERMITTED_USES_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) async def extract_wes_permitted_uses(self, text_chunks): @@ -661,7 +615,6 @@ async def extract_wes_permitted_uses(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.WES_PERMITTED_USES_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) @property @@ -677,19 +630,4 @@ def parsers(self): outputs parsed text. 
""" yield "permitted_use_only_text", self.extract_permitted_uses - yield "districts_text", self.extract_wes_permitted_uses - - -def _valid_chunk(chunk): - """True if chunk has content""" - return chunk and "no relevant text" not in chunk.lower() - - -def _store_chunk(parser, chunk_ind, store): - """Store chunk and its neighbors if it is not already stored""" - for offset in range(1 - parser.num_to_recall, 2): - ind_to_grab = chunk_ind + offset - if ind_to_grab < 0 or ind_to_grab >= len(parser.text_chunks): - continue - - store.setdefault(ind_to_grab, parser.text_chunks[ind_to_grab]) + yield self.OUT_LABEL, self.extract_wes_permitted_uses diff --git a/compass/extraction/small_wind/parse.py b/compass/extraction/small_wind/parse.py index 7e1273f08..12ece59c4 100644 --- a/compass/extraction/small_wind/parse.py +++ b/compass/extraction/small_wind/parse.py @@ -8,7 +8,7 @@ import pandas as pd -from compass.llm.calling import BaseLLMCaller, ChatLLMCaller +from compass.plugin.ordinance import OrdinanceParser from compass.extraction.features import SetbackFeatures from compass.common import ( EXTRACT_ORIGINAL_SETBACK_TEXT_PROMPT, @@ -203,18 +203,9 @@ class SmallWindSetbackFeatures(SetbackFeatures): """Clarifications to add to feature prompts""" -class StructuredSmallWindParser(BaseLLMCaller): +class StructuredSmallWindParser(OrdinanceParser): """Base class for parsing structured data""" - def _init_chat_llm_caller(self, system_message): - """Initialize a ChatLLMCaller instance for the DecisionTree""" - return ChatLLMCaller( - self.llm_service, - system_message=system_message, - usage_tracker=self.usage_tracker, - **self.kwargs, - ) - async def _check_wind_turbine_type(self, text): """Get the small turbine size mentioned in the text""" logger.info("Checking turbine types...") @@ -256,6 +247,12 @@ class StructuredSmallWindOrdinanceParser(StructuredSmallWindParser): individual values. """ + IN_LABEL = "cleaned_text_for_extraction" + """Identifier for text ingested by this class""" + + OUT_LABEL = "ordinance_values" + """Identifier for structured ordinance data output by this class""" + async def parse(self, text): """Parse text and extract structure ordinance data @@ -566,6 +563,12 @@ class StructuredSmallWindPermittedUseDistrictsParser( individual values. 
""" + IN_LABEL = "districts_text" + """Identifier for text ingested by this class""" + + OUT_LABEL = "permitted_district_values" + """Identifier for structured ordinance data output by this class""" + _SMALL_WES_CLARIFICATION = ( "Small wind energy systems (AWES) may also be referred to as " "non-commercial wind energy systems, on-site wind energy systems, " diff --git a/compass/extraction/small_wind/plugin.py b/compass/extraction/small_wind/plugin.py new file mode 100644 index 000000000..e4d3de1b9 --- /dev/null +++ b/compass/extraction/small_wind/plugin.py @@ -0,0 +1,98 @@ +"""COMPASS wind extraction plugin""" + +from compass.plugin.interface import ExtractionPlugin +from compass.extraction.small_wind.ordinance import ( + SmallWindHeuristic, + SmallWindOrdinanceTextCollector, + SmallWindOrdinanceTextExtractor, + SmallWindPermittedUseDistrictsTextCollector, + SmallWindPermittedUseDistrictsTextExtractor, +) +from compass.extraction.small_wind.parse import ( + StructuredSmallWindOrdinanceParser, + StructuredSmallWindPermittedUseDistrictsParser, +) + +StructuredSmallWindOrdinanceParser.IN_LABEL = ( + SmallWindOrdinanceTextExtractor.OUT_LABEL +) +StructuredSmallWindPermittedUseDistrictsParser.IN_LABEL = ( + SmallWindPermittedUseDistrictsTextExtractor.OUT_LABEL +) + +SMALL_WIND_QUESTION_TEMPLATES = [ + "filetype:pdf {jurisdiction} wind energy conversion system ordinances", + "wind energy conversion system ordinances {jurisdiction}", + "{jurisdiction} wind WECS ordinance", + ( + "Where can I find the legal text for small wind energy " + "turbine zoning ordinances in {jurisdiction}?" + ), + ( + "What is the specific legal information regarding zoning " + "ordinances for small wind turbines in {jurisdiction}?" + ), +] + +BEST_SMALL_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS = { + "pdf": 92160, + "wecs": 46080, + "wind": 23040, + "zoning": 11520, + "ordinance": 5760, + r"renewable%20energy": 1440, + r"renewable+energy": 1440, + "renewable energy": 1440, + "planning": 720, + "plan": 360, + "government": 180, + "code": 60, + "area": 60, + r"land%20development": 15, + r"land+development": 15, + "land development": 15, + "land": 3, + "environment": 3, + "energy": 3, + "renewable": 3, + "municipal": 1, + "department": 1, +} + + +class COMPASSSmallWindExtractor(ExtractionPlugin): + """COMPASS small wind extraction plugin""" + + IDENTIFIER = "small wind" + """str: Identifier for extraction task """ + + QUESTION_TEMPLATES = SMALL_WIND_QUESTION_TEMPLATES + """list: List of search engine question templates for extraction""" + + WEBSITE_KEYWORDS = BEST_SMALL_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS + """list: List of keywords + + Keywords indicate links which should be prioritized when performing + a website scrape for a wind ordinance document. 
+ """ + + heuristic = SmallWindHeuristic() + """BaseHeuristic: Object with a ``check()`` method""" + + TEXT_COLLECTORS = [ + SmallWindOrdinanceTextCollector, + SmallWindPermittedUseDistrictsTextCollector, + ] + """Classes for collecting wind ordinance text chunks from docs""" + + TEXT_EXTRACTORS = [ + SmallWindOrdinanceTextExtractor, + SmallWindPermittedUseDistrictsTextExtractor, + ] + """Class for extracting cleaned ord text from collected text""" + + PARSERS = [ + StructuredSmallWindOrdinanceParser, + StructuredSmallWindPermittedUseDistrictsParser, + ] + """Class for parsing structured ordinance data from text""" diff --git a/compass/extraction/solar/__init__.py b/compass/extraction/solar/__init__.py index 332c84141..203addfc7 100644 --- a/compass/extraction/solar/__init__.py +++ b/compass/extraction/solar/__init__.py @@ -1,55 +1,3 @@ -"""Solar ordinance extraction utilities""" +"""Solar ordinance extraction plugin""" -from .ordinance import ( - SolarHeuristic, - SolarOrdinanceTextCollector, - SolarOrdinanceTextExtractor, - SolarPermittedUseDistrictsTextCollector, - SolarPermittedUseDistrictsTextExtractor, -) -from .parse import ( - StructuredSolarOrdinanceParser, - StructuredSolarPermittedUseDistrictsParser, -) - - -SOLAR_QUESTION_TEMPLATES = [ - "filetype:pdf {jurisdiction} solar energy conversion system ordinances", - "solar energy conversion system ordinances {jurisdiction}", - "{jurisdiction} solar energy farm ordinance", - ( - "Where can I find the legal text for commercial solar energy " - "conversion system zoning ordinances in {jurisdiction}?" - ), - ( - "What is the specific legal information regarding zoning " - "ordinances for commercial solar energy conversion systems in " - "{jurisdiction}?" - ), -] - -BEST_SOLAR_ORDINANCE_WEBSITE_URL_KEYWORDS = { - "pdf": 92160, - "secs": 46080, - "solar": 23040, - "zoning": 11520, - "ordinance": 5760, - r"renewable%20energy": 1440, - r"renewable+energy": 1440, - "renewable energy": 1440, - "planning": 720, - "plan": 360, - "government": 180, - "code": 60, - "area": 60, - r"land%20development": 15, - r"land+development": 15, - "land development": 15, - "land": 3, - "environment": 3, - "energy": 3, - "renewable": 3, - "municipal": 1, - "department": 1, - # TODO: add board??? 
-} +from .plugin import COMPASSSolarExtractor diff --git a/compass/extraction/solar/ordinance.py b/compass/extraction/solar/ordinance.py index 238c6d494..3679e519c 100644 --- a/compass/extraction/solar/ordinance.py +++ b/compass/extraction/solar/ordinance.py @@ -6,11 +6,12 @@ import logging -from compass.common import BaseTextExtractor -from compass.validation.content import Heuristic -from compass.llm.calling import StructuredLLMCaller +from compass.plugin.ordinance import ( + OrdinanceHeuristic, + OrdinanceTextCollector, + OrdinanceTextExtractor, +) from compass.utilities.enums import LLMUsageCategory -from compass.utilities.parsing import merge_overlapping_texts logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ ) -class SolarHeuristic(Heuristic): +class SolarHeuristic(OrdinanceHeuristic): """Perform a heuristic check for mention of solar farms in text""" NOT_TECH_WORDS = [ @@ -75,9 +76,12 @@ class SolarHeuristic(Heuristic): """Phrases that indicate text is about solar farms""" -class SolarOrdinanceTextCollector(StructuredLLMCaller): +class SolarOrdinanceTextCollector(OrdinanceTextCollector): """Check text chunks for ordinances and collect them if they do""" + OUT_LABEL = "relevant_text" + """Identifier for text collected by this class""" + CONTAINS_ORD_PROMPT = ( "You extract structured data from text. Return your answer in JSON " "format (not markdown). Your JSON file must include exactly two " @@ -90,7 +94,7 @@ class SolarOrdinanceTextCollector(StructuredLLMCaller): "energy systems (or solar panels). " "All restrictions should be enforceable - ignore any text that only " "provides a legal definition of the regulation. If the text does not " - f"specify any concrete {_SEARCH_TERMS_OR} for a wind energy system, " + f"specify any concrete {_SEARCH_TERMS_OR} for a solar energy system, " "set this key to `null`. The last key is '{key}', which is a boolean " "that is set to True if the text excerpt explicitly details " f"{_SEARCH_TERMS_OR} for a solar energy system and False otherwise." 
@@ -116,10 +120,6 @@ class SolarOrdinanceTextCollector(StructuredLLMCaller): ) """Prompt to check if chunk is for utility-scale SEF""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._ordinance_chunks = {} - async def check_chunk(self, chunk_parser, ind): """Check a chunk at a given ind to see if it contains ordinance @@ -158,32 +158,11 @@ async def check_chunk(self, chunk_parser, ind): logger.debug("Text at ind %d is for utility-scale SEF", ind) - _store_chunk(chunk_parser, ind, self._ordinance_chunks) + self._store_chunk(chunk_parser, ind) logger.debug("Added text at ind %d to ordinances", ind) return True - @property - def contains_ord_info(self): - """bool: Flag indicating whether text contains ordinance info""" - return bool(self._ordinance_chunks) - - @property - def ordinance_text(self): - """str: Combined ordinance text from the individual chunks""" - logger.debug( - "Grabbing %d ordinance chunk(s) from original text at these " - "indices: %s", - len(self._ordinance_chunks), - list(self._ordinance_chunks), - ) - - text = [ - self._ordinance_chunks[ind] - for ind in sorted(self._ordinance_chunks) - ] - return merge_overlapping_texts(text) - async def _check_chunk_contains_ord(self, key, text_chunk): """Call LLM on a chunk of text to check for ordinance""" content = await self.call( @@ -205,9 +184,12 @@ async def _check_chunk_is_for_utility_scale(self, key, text_chunk): return content.get(key, False) -class SolarPermittedUseDistrictsTextCollector(StructuredLLMCaller): +class SolarPermittedUseDistrictsTextCollector(OrdinanceTextCollector): """Check text chunks for permitted solar districts; collect them""" + OUT_LABEL = "permitted_use_text" + """Identifier for text collected by this class""" + DISTRICT_PROMPT = ( "You are a legal scholar that reads ordinance text and determines " "whether it explicitly contains relevant information to determine the " @@ -230,10 +212,6 @@ class SolarPermittedUseDistrictsTextCollector(StructuredLLMCaller): ) """Prompt to check if chunk contains info on permitted districts""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._district_chunks = {} - async def check_chunk(self, chunk_parser, ind): """Check a chunk to see if it contains permitted uses @@ -263,46 +241,28 @@ async def check_chunk(self, chunk_parser, ind): contains_district_info = content.get(key, False) if contains_district_info: - _store_chunk(chunk_parser, ind, self._district_chunks) + self._store_chunk(chunk_parser, ind) logger.debug("Text at ind %d contains district info", ind) return True logger.debug("Text at ind %d does not contain district info", ind) return False - @property - def contains_district_info(self): - """bool: Flag indicating whether text contains district info""" - return bool(self._district_chunks) - @property - def permitted_use_district_text(self): - """str: Combined permitted use districts text from the chunks""" - logger.debug( - "Grabbing %d permitted use chunk(s) from original text at these " - "indices: %s", - len(self._district_chunks), - list(self._district_chunks), - ) +class SolarOrdinanceTextExtractor(OrdinanceTextExtractor): + """Extract succinct ordinance text from input""" - text = [ - self._district_chunks[ind] for ind in sorted(self._district_chunks) - ] - return merge_overlapping_texts(text) + IN_LABEL = SolarOrdinanceTextCollector.OUT_LABEL + """Identifier for collected text ingested by this class""" + OUT_LABEL = "cleaned_text_for_extraction" + """Identifier for ordinance text extracted 
by this class""" -class SolarOrdinanceTextExtractor(BaseTextExtractor): - """Extract succinct ordinance text from input + TASK_DESCRIPTION = "Extracting solar ordinance text" + """Task description to show in progress bar""" - Purpose: - Extract relevant ordinance text from document. - Responsibilities: - 1. Extract portions from chunked document text relevant to - particular ordinance type (e.g. solar zoning for - utility-scale systems). - Key Relationships: - Uses a StructuredLLMCaller for LLM queries. - """ + TASK_ID = "ordinance_text_extraction" + """ID to use for this extraction for linking with LLM configs""" SOLAR_ENERGY_SYSTEM_FILTER_PROMPT = ( "# CONTEXT #\n" @@ -366,7 +326,6 @@ async def extract_solar_energy_system_section(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.SOLAR_ENERGY_SYSTEM_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) @property @@ -381,24 +340,23 @@ def parsers(self): Async function that takes a ``text_chunks`` input and outputs parsed text. """ - yield ( - "cleaned_ordinance_text", - self.extract_solar_energy_system_section, - ) + yield self.OUT_LABEL, self.extract_solar_energy_system_section + + +class SolarPermittedUseDistrictsTextExtractor(OrdinanceTextExtractor): + """Extract succinct permitted use districts text from input""" + + IN_LABEL = SolarPermittedUseDistrictsTextCollector.OUT_LABEL + """Identifier for collected text ingested by this class""" + OUT_LABEL = "districts_text" + """Identifier for permitted use text extracted by this class""" -class SolarPermittedUseDistrictsTextExtractor(BaseTextExtractor): - """Extract succinct ordinance text from input + TASK_DESCRIPTION = "Extracting solar permitted use text" + """Task description to show in progress bar""" - Purpose: - Extract relevant ordinance text from document. - Responsibilities: - 1. Extract portions from chunked document text relevant to - particular ordinance type (e.g. solar zoning for - utility-scale systems). - Key Relationships: - Uses a StructuredLLMCaller for LLM queries. - """ + TASK_ID = "permitted_use_text_extraction" + """ID to use for this extraction for linking with LLM configs""" _USAGE_LABEL = LLMUsageCategory.DOCUMENT_PERMITTED_USE_DISTRICTS_SUMMARY @@ -511,7 +469,6 @@ async def extract_permitted_uses(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.PERMITTED_USES_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) async def extract_sef_permitted_uses(self, text_chunks): @@ -532,7 +489,6 @@ async def extract_sef_permitted_uses(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.SEF_PERMITTED_USES_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) @property @@ -548,19 +504,4 @@ def parsers(self): outputs parsed text. 
""" yield "permitted_use_only_text", self.extract_permitted_uses - yield "districts_text", self.extract_sef_permitted_uses - - -def _valid_chunk(chunk): - """True if chunk has content""" - return chunk and "no relevant text" not in chunk.lower() - - -def _store_chunk(parser, chunk_ind, store): - """Store chunk and its neighbors if it is not already stored""" - for offset in range(1 - parser.num_to_recall, 2): - ind_to_grab = chunk_ind + offset - if ind_to_grab < 0 or ind_to_grab >= len(parser.text_chunks): - continue - - store.setdefault(ind_to_grab, parser.text_chunks[ind_to_grab]) + yield self.OUT_LABEL, self.extract_sef_permitted_uses diff --git a/compass/extraction/solar/parse.py b/compass/extraction/solar/parse.py index 9972deb42..89088a331 100644 --- a/compass/extraction/solar/parse.py +++ b/compass/extraction/solar/parse.py @@ -8,7 +8,7 @@ import pandas as pd -from compass.llm.calling import BaseLLMCaller, ChatLLMCaller +from compass.plugin.ordinance import OrdinanceParser from compass.extraction.features import SetbackFeatures from compass.common import ( EXTRACT_ORIGINAL_SETBACK_TEXT_PROMPT, @@ -150,18 +150,9 @@ } -class StructuredSolarParser(BaseLLMCaller): +class StructuredSolarParser(OrdinanceParser): """Base class for parsing structured data""" - def _init_chat_llm_caller(self, system_message): - """Initialize a ChatLLMCaller instance for the DecisionTree""" - return ChatLLMCaller( - self.llm_service, - system_message=system_message, - usage_tracker=self.usage_tracker, - **self.kwargs, - ) - async def _check_solar_farm_type(self, text): """Get the largest solar farm size mentioned in the text""" logger.info("Checking solar farm types") @@ -203,6 +194,12 @@ class StructuredSolarOrdinanceParser(StructuredSolarParser): individual values. """ + IN_LABEL = "cleaned_text_for_extraction" + """Identifier for text ingested by this class""" + + OUT_LABEL = "ordinance_values" + """Identifier for structured ordinance data output by this class""" + async def parse(self, text): """Parse text and extract structure ordinance data @@ -502,6 +499,12 @@ class StructuredSolarPermittedUseDistrictsParser(StructuredSolarParser): individual values. 
""" + IN_LABEL = "districts_text" + """Identifier for text ingested by this class""" + + OUT_LABEL = "permitted_district_values" + """Identifier for structured ordinance data output by this class""" + _LARGE_SEF_CLARIFICATION = ( "Large solar energy systems (SES) may also be referred to as solar " "panels, solar energy conversion systems (SECS), solar energy " diff --git a/compass/extraction/solar/plugin.py b/compass/extraction/solar/plugin.py new file mode 100644 index 000000000..8123f2acd --- /dev/null +++ b/compass/extraction/solar/plugin.py @@ -0,0 +1,99 @@ +"""COMPASS solar extraction plugin""" + +from compass.plugin.interface import ExtractionPlugin +from compass.extraction.solar.ordinance import ( + SolarHeuristic, + SolarOrdinanceTextCollector, + SolarOrdinanceTextExtractor, + SolarPermittedUseDistrictsTextCollector, + SolarPermittedUseDistrictsTextExtractor, +) +from compass.extraction.solar.parse import ( + StructuredSolarOrdinanceParser, + StructuredSolarPermittedUseDistrictsParser, +) + +StructuredSolarOrdinanceParser.IN_LABEL = SolarOrdinanceTextExtractor.OUT_LABEL +StructuredSolarPermittedUseDistrictsParser.IN_LABEL = ( + SolarPermittedUseDistrictsTextExtractor.OUT_LABEL +) + +SOLAR_QUESTION_TEMPLATES = [ + "filetype:pdf {jurisdiction} solar energy conversion system ordinances", + "solar energy conversion system ordinances {jurisdiction}", + "{jurisdiction} solar energy farm ordinance", + ( + "Where can I find the legal text for commercial solar energy " + "conversion system zoning ordinances in {jurisdiction}?" + ), + ( + "What is the specific legal information regarding zoning " + "ordinances for commercial solar energy conversion systems in " + "{jurisdiction}?" + ), +] + + +BEST_SOLAR_ORDINANCE_WEBSITE_URL_KEYWORDS = { + "pdf": 92160, + "secs": 46080, + "solar": 23040, + "zoning": 11520, + "ordinance": 5760, + r"renewable%20energy": 1440, + r"renewable+energy": 1440, + "renewable energy": 1440, + "planning": 720, + "plan": 360, + "government": 180, + "code": 60, + "area": 60, + r"land%20development": 15, + r"land+development": 15, + "land development": 15, + "land": 3, + "environment": 3, + "energy": 3, + "renewable": 3, + "municipal": 1, + "department": 1, + # TODO: add board??? +} + + +class COMPASSSolarExtractor(ExtractionPlugin): + """COMPASS solar extraction plugin""" + + IDENTIFIER = "solar" + """str: Identifier for extraction task """ + + QUESTION_TEMPLATES = SOLAR_QUESTION_TEMPLATES + """list: List of search engine question templates for extraction""" + + WEBSITE_KEYWORDS = BEST_SOLAR_ORDINANCE_WEBSITE_URL_KEYWORDS + """list: List of keywords + + Keywords indicate links which should be prioritized when performing + a website scrape for a wind ordinance document. 
+ """ + + heuristic = SolarHeuristic() + """BaseHeuristic: Object with a ``check()`` method""" + + TEXT_COLLECTORS = [ + SolarOrdinanceTextCollector, + SolarPermittedUseDistrictsTextCollector, + ] + """Classes for collecting wind ordinance text chunks from docs""" + + TEXT_EXTRACTORS = [ + SolarOrdinanceTextExtractor, + SolarPermittedUseDistrictsTextExtractor, + ] + """Class for extracting cleaned ord text from collected text""" + + PARSERS = [ + StructuredSolarOrdinanceParser, + StructuredSolarPermittedUseDistrictsParser, + ] + """Class for parsing structured ordinance data from text""" diff --git a/compass/extraction/water/__init__.py b/compass/extraction/water/__init__.py index 4141ecbbf..bf4008bb4 100644 --- a/compass/extraction/water/__init__.py +++ b/compass/extraction/water/__init__.py @@ -1,48 +1,3 @@ """Water ordinance extraction utilities""" -from .parse import StructuredWaterParser -from .ordinance import ( - WaterRightsHeuristic, - WaterRightsTextCollector, - WaterRightsTextExtractor, -) -from .processing import ( - build_corpus, - extract_water_rights_ordinance_values, - label_docs_no_legal_check, - write_water_rights_data_to_disk, -) - - -WATER_RIGHTS_QUESTION_TEMPLATES = [ - "{jurisdiction} rules", - "{jurisdiction} management plan", - "{jurisdiction} well permits", - "{jurisdiction} well permit requirements", - "requirements to drill a water well in {jurisdiction}", -] - -BEST_WATER_RIGHTS_ORDINANCE_WEBSITE_URL_KEYWORDS = { - "pdf": 92160, - "water": 46080, - "rights": 23040, - "zoning": 11520, - "ordinance": 5760, - r"renewable%20energy": 1440, - r"renewable+energy": 1440, - "renewable energy": 1440, - "planning": 720, - "plan": 360, - "government": 180, - "code": 60, - "area": 60, - r"land%20development": 15, - r"land+development": 15, - "land development": 15, - "land": 3, - "environment": 3, - "energy": 3, - "renewable": 3, - "municipal": 1, - "department": 1, -} +from .plugin import TexasWaterRightsExtractor diff --git a/compass/extraction/water/ordinance.py b/compass/extraction/water/ordinance.py deleted file mode 100644 index 49d5e8006..000000000 --- a/compass/extraction/water/ordinance.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Water ordinance document content collection and extraction - -These methods help filter down the document text to only the portions -relevant to water rights ordinances. -""" - -import logging - -from compass.common import BaseTextExtractor -from compass.llm.calling import StructuredLLMCaller -from compass.utilities.parsing import merge_overlapping_texts -from compass.utilities.enums import LLMUsageCategory - - -logger = logging.getLogger(__name__) - - -class WaterRightsHeuristic: - """NoOp heuristic check""" - - def check(self, *__, **___): # noqa: PLR6301 - """Always return ``True`` for water rights documents""" - return True - - -class WaterRightsTextCollector(StructuredLLMCaller): - """Check text chunks for ordinances and collect them if they do""" - - WELL_PERMITS_PROMPT = ( - "You extract structured data from text. Return your answer in JSON " - "format (not markdown). Your JSON file must include exactly three " - "keys. The first key is 'district_rules' which is a string summarizes " - "the rules associated with the groundwater conservation district. " - "The second key is 'well_requirements', which is a string that " - "summarizes the requirements for drilling a groundwater well. 
The " - "last key is '{key}', which is a boolean that is set to True if the " - "text excerpt provides substantive information related to the " - "groundwater conservation district's rules or management plans. " - ) - """Prompt to check if chunk contains water rights ordinance info""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._ordinance_chunks = {} - - @property - def contains_ord_info(self): - """bool: Flag indicating whether text contains ordinance info""" - return bool(self._ordinance_chunks) - - @property - def ordinance_text(self): - """str: Combined ordinance text from the individual chunks""" - logger.debug( - "Grabbing %d ordinance chunk(s) from original text at these " - "indices: %s", - len(self._ordinance_chunks), - list(self._ordinance_chunks), - ) - - text = [ - self._ordinance_chunks[ind] - for ind in sorted(self._ordinance_chunks) - ] - return merge_overlapping_texts(text) - - async def check_chunk(self, chunk_parser, ind): - """Check a chunk at a given ind to see if it contains ordinance - - Parameters - ---------- - chunk_parser : ParseChunksWithMemory - Instance that contains a ``parse_from_ind`` method. - ind : int - Index of the chunk to check. - - Returns - ------- - bool - Boolean flag indicating whether or not the text in the chunk - contains water rights ordinance text. - """ - contains_ord_info = await chunk_parser.parse_from_ind( - ind, - key="contains_ord_info", - llm_call_callback=self._check_chunk_contains_ord, - ) - - if contains_ord_info: - logger.debug( - "Text at ind %d contains water rights ordinance info", ind - ) - _store_chunk(chunk_parser, ind, self._ordinance_chunks) - else: - logger.debug( - "Text at ind %d does not contain water rights ordinance info", - ind, - ) - - return contains_ord_info - - async def _check_chunk_contains_ord(self, key, text_chunk): - """Call LLM on a chunk of text to check for ordinance""" - content = await self.call( - sys_msg=self.WELL_PERMITS_PROMPT.format(key=key), - content=text_chunk, - usage_sub_label=(LLMUsageCategory.DOCUMENT_CONTENT_VALIDATION), - ) - logger.debug("LLM response: %s", content) - return content.get(key, False) - - -class WaterRightsTextExtractor(BaseTextExtractor): - """No-Op text extractor""" - - @property - def parsers(self): - """Iterable of parsers provided by this extractor - - Yields - ------ - name : str - Name describing the type of text output by the parser. - parser : callable - Async function that takes a ``text_chunks`` input and - outputs parsed text. 
- """ - yield "cleaned_ordinance_text", merge_overlapping_texts - - -def _store_chunk(parser, chunk_ind, store): - """Store chunk and its neighbors if it is not already stored""" - for offset in range(1 - parser.num_to_recall, 2): - ind_to_grab = chunk_ind + offset - if ind_to_grab < 0 or ind_to_grab >= len(parser.text_chunks): - continue - - store.setdefault(ind_to_grab, parser.text_chunks[ind_to_grab]) diff --git a/compass/extraction/water/plugin.py b/compass/extraction/water/plugin.py new file mode 100644 index 000000000..eda687096 --- /dev/null +++ b/compass/extraction/water/plugin.py @@ -0,0 +1,292 @@ +"""COMPASS water rights extraction plugin""" + +import logging +from pathlib import Path + +import pandas as pd +from elm import EnergyWizard +from elm.embed import ChunkAndEmbed + +from compass.extraction import extract_date +from compass.plugin.base import BaseExtractionPlugin +from compass.utilities.enums import LLMTasks +from compass.utilities.parsing import extract_ord_year_from_doc_attrs +from compass.exceptions import COMPASSRuntimeError +from compass.extraction.water.parse import StructuredWaterParser + + +logger = logging.getLogger(__name__) + + +WATER_RIGHTS_QUESTION_TEMPLATES = [ + "{jurisdiction} rules", + "{jurisdiction} management plan", + "{jurisdiction} well permits", + "{jurisdiction} well permit requirements", + "requirements to drill a water well in {jurisdiction}", +] +BEST_WATER_RIGHTS_ORDINANCE_WEBSITE_URL_KEYWORDS = { + "pdf": 92160, + "water": 46080, + "rights": 23040, + "zoning": 11520, + "ordinance": 5760, + r"renewable%20energy": 1440, + r"renewable+energy": 1440, + "renewable energy": 1440, + "planning": 720, + "plan": 360, + "government": 180, + "code": 60, + "area": 60, + r"land%20development": 15, + r"land+development": 15, + "land development": 15, + "land": 3, + "environment": 3, + "energy": 3, + "renewable": 3, + "municipal": 1, + "department": 1, +} + + +class WaterRightsHeuristic: + """NoOp heuristic check""" + + def check(self, *__, **___): # noqa: PLR6301 + """Always return ``True`` for water rights documents""" + return True + + +class TexasWaterRightsExtractor(BaseExtractionPlugin): + """COMPASS solar extraction plugin""" + + IDENTIFIER = "tx water rights" + """str: Identifier for extraction task """ + + QUESTION_TEMPLATES = WATER_RIGHTS_QUESTION_TEMPLATES + """list: List of search engine question templates for extraction""" + + WEBSITE_KEYWORDS = BEST_WATER_RIGHTS_ORDINANCE_WEBSITE_URL_KEYWORDS + """list: List of keywords + + Keywords indicate links which should be prioritized when performing + a website scrape for a wind ordinance document. + """ + + heuristic = WaterRightsHeuristic() + """BaseHeuristic: Object with a ``check()`` method""" + + async def filter_docs( + self, + extraction_context, + need_jurisdiction_verification=True, # noqa: ARG002 + ): + """Filter down candidate documents before parsing + + Parameters + ---------- + extraction_context : ExtractionContext + Context containing candidate documents to be filtered. + Set the ``.documents`` attribute of this object to be the + iterable of documents that should be kept for parsing. + need_jurisdiction_verification : bool, optional + Whether to verify that documents pertain to the correct + jurisdiction. By default, ``True``. + + Returns + ------- + ExtractionContext + Context with filtered down documents. 
+ """ + model_config = self.model_configs.get( + LLMTasks.EMBEDDING, self.model_configs[LLMTasks.DEFAULT] + ) + _setup_endpoints(model_config) + + corpus = [] + for ind, doc in enumerate(extraction_context, start=1): + url = doc.attrs.get("source", "unknown source") + logger.info("Embedding %r", url) + obj = ChunkAndEmbed( + doc.text, + model=model_config.name, + tokens_per_chunk=model_config.text_splitter_chunk_size, + overlap=model_config.text_splitter_chunk_overlap, + split_on="\n", + ) + try: + embeddings = await obj.run_async( + rate_limit=model_config.llm_service_rate_limit + ) + if any(e is None for e in embeddings): + msg = ( + "Embeddings are ``None`` when building corpus for " + "water rights extraction!" + ) + raise COMPASSRuntimeError(msg) # noqa: TRY301 + + corpus.append( + pd.DataFrame( + { + "text": obj.text_chunks.chunks, + "embedding": embeddings, + } + ) + ) + + except Exception as e: # noqa: BLE001 + logger.info("could not embed %r with error: %s", url, e) + continue + + date_model_config = self.model_configs.get( + LLMTasks.DATE_EXTRACTION, self.model_configs[LLMTasks.DEFAULT] + ) + await extract_date(doc, date_model_config, self.usage_tracker) + + await extraction_context.mark_doc_as_data_source( + doc, out_fn_stem=f"{self.jurisdiction.full_name} {ind}" + ) + + if len(corpus) == 0: + logger.info( + "No documents returned for %s, skipping", + self.jurisdiction.full_name, + ) + return None + + extraction_context.attrs["corpus"] = pd.concat(corpus) + return extraction_context + + async def parse_docs_for_structured_data(self, extraction_context): + """Parse documents to extract structured data/information + + Parameters + ---------- + extraction_context : ExtractionContext + Context containing candidate documents to parse. + + Returns + ------- + ExtractionContext or None + Context with extracted data/information stored in the + ``.attrs`` dictionary, or ``None`` if no data was extracted. + """ + model_config = self.model_configs.get( + LLMTasks.DATA_EXTRACTION, self.model_configs[LLMTasks.DEFAULT] + ) + + logger.debug("Building energy wizard") + wizard = EnergyWizard( + extraction_context.attrs["corpus"], + model=model_config.name, + ) + + logger.debug("Calling parser class") + parser = StructuredWaterParser( + wizard=wizard, + location=self.jurisdiction.full_name, + llm_service=model_config.llm_service, + usage_tracker=self.usage_tracker, + **model_config.llm_call_kwargs, + ) + + data_df = await parser.parse() + data_df = _set_data_year(data_df, extraction_context) + data_df = _set_data_sources(data_df, extraction_context) + extraction_context.attrs["structured_data"] = data_df + extraction_context.attrs["out_data_fn"] = ( + f"{self.jurisdiction.full_name} Water Rights.csv" + ) + return extraction_context + + @classmethod + def save_structured_data(cls, doc_infos, out_dir): + """Write extracted water rights data to disk + + Parameters + ---------- + doc_infos : list of dict + List of dictionaries containing the following keys: + + - "jurisdiction": An initialized Jurisdiction object + representing the jurisdiction that was extracted. + - "ord_db_fp": A path to the extracted structured data + stored on disk, or ``None`` if no data was extracted. + + out_dir : path-like + Path to the output directory for the data. + + Returns + ------- + int + Number of unique water rights districts that information was + found/written for. 
+ """ + db = [] + for doc_info in doc_infos: + ord_db = pd.read_csv(doc_info["ord_db_fp"]) + if len(ord_db) == 0: + continue + + jurisdiction = doc_info["jurisdiction"] + ord_db["WCD_ID"] = jurisdiction.code + ord_db["county"] = jurisdiction.county + ord_db["state"] = jurisdiction.state + ord_db["subdivision"] = jurisdiction.subdivision_name + ord_db["jurisdiction_type"] = jurisdiction.type + + db.append(ord_db) + + if not db: + return 0 + + db = pd.concat([df.dropna(axis=1, how="all") for df in db], axis=0) + db.to_csv(Path(out_dir) / "water_rights.csv", index=False) + return len(db["WCD_ID"].unique()) + + +def _set_data_year(data_df, extraction_context): + """Set the ordinance year column in the data DataFrame""" + years = filter( + None, + [ + extract_ord_year_from_doc_attrs(doc.attrs) + for doc in extraction_context + ], + ) + if not years: + data_df["ord_year"] = None + else: + # TODO: is `max` the right one to use here? + data_df["ord_year"] = max(years) + return data_df + + +def _set_data_sources(data_df, extraction_context): + """Set the source column in the data DataFrame""" + sources = filter( + None, [doc.attrs.get("source") for doc in extraction_context] + ) + if not sources: + data_df["source"] = None + else: + data_df["source"] = " ;\n".join(sources) + return data_df + + +def _setup_endpoints(embedding_model_config): + """Set proper URLS for elm classes""" + ChunkAndEmbed.USE_CLIENT_EMBEDDINGS = True + EnergyWizard.USE_CLIENT_EMBEDDINGS = True + ChunkAndEmbed.EMBEDDING_MODEL = EnergyWizard.EMBEDDING_MODEL = ( + embedding_model_config.name + ) + + endpoint = embedding_model_config.client_kwargs["azure_endpoint"] + ChunkAndEmbed.EMBEDDING_URL = endpoint + ChunkAndEmbed.URL = endpoint + EnergyWizard.EMBEDDING_URL = endpoint + + EnergyWizard.URL = "openai.azure.com" # need to trigger Azure setup diff --git a/compass/extraction/water/processing.py b/compass/extraction/water/processing.py deleted file mode 100644 index 92d5daa92..000000000 --- a/compass/extraction/water/processing.py +++ /dev/null @@ -1,201 +0,0 @@ -"""Water ordinance structured parsing class""" - -import logging -from pathlib import Path - -import pandas as pd -from elm import EnergyWizard -from elm.embed import ChunkAndEmbed -from elm.web.document import PDFDocument - -from compass.utilities.enums import LLMTasks -from compass.exceptions import COMPASSRuntimeError - - -logger = logging.getLogger(__name__) - - -async def label_docs_no_legal_check(docs, **__): # noqa: RUF029 - """Label documents with the "don't check for legal status" flag - - Parameters - ---------- - docs : iterable of elm.web.document.PDFDocument - Documents to label. - - Returns - ------- - iterable of elm.web.document.PDFDocument - Input docs with the "check_if_legal_doc" attribute set to False. - """ - for doc in docs: - doc.attrs["check_if_legal_doc"] = False - return docs - - -async def build_corpus(docs, jurisdiction, model_configs, **__): - """Build knowledge corpus for water rights extraction - - Parameters - ---------- - docs : iterable of elm.web.document.PDFDocument - Documents to build corpus from. - jurisdiction : compass.utilities.location.Jurisdiction - Jurisdiction being processed. - model_configs : dict - Dictionary of model configurations for various LLM tasks. - - Returns - ------- - list or None - List containing a single PDFDocument with the corpus, or None - if no corpus could be built. - - Raises - ------ - COMPASSRuntimeError - If embeddings could not be generated. 
- """ - model_config = model_configs.get( - LLMTasks.EMBEDDING, model_configs[LLMTasks.DEFAULT] - ) - _setup_endpoints(model_config) - - corpus = [] - for doc in docs: - url = doc.attrs.get("source", "unknown source") - logger.info("Embedding %r", url) - obj = ChunkAndEmbed( - doc.text, - model=model_config.name, - tokens_per_chunk=model_config.text_splitter_chunk_size, - overlap=model_config.text_splitter_chunk_overlap, - split_on="\n", - ) - try: - embeddings = await obj.run_async(rate_limit=3e4) - if any(e is None for e in embeddings): - msg = ( - "Embeddings are ``None`` when building corpus for " - "water rights extraction!" - ) - raise COMPASSRuntimeError(msg) # noqa: TRY301 - - corpus.append( - pd.DataFrame( - {"text": obj.text_chunks.chunks, "embedding": embeddings} - ) - ) - - except Exception as e: # noqa: BLE001 - logger.info("could not embed %r with error: %s", url, e) - - if len(corpus) == 0: - logger.info( - "No documents returned for %s, skipping", jurisdiction.full_name - ) - return None - - corpus_doc = PDFDocument( - ["water extraction context"], attrs={"corpus": pd.concat(corpus)} - ) - return [corpus_doc] - - -async def extract_water_rights_ordinance_values( - corpus_doc, parser_class, out_key, usage_tracker, model_config, **__ -): - """Extract ordinance values from a temporary vector store. - - Parameters - ---------- - corpus_doc : elm.web.document.PDFDocument - Document containing the vector store corpus. - parser_class : type - Class used to parse the vector store. - out_key : str - Key used to store extracted values in the document attributes. - usage_tracker : compass.services.usage.UsageTracker - Instance of the UsageTracker class used to track LLM usage. - model_config : compass.llm.config.LLMConfig - Model configuration used for LLM calls. - - Returns - ------- - elm.web.document.PDFDocument - Document with extracted ordinance values stored in attributes. - """ - - logger.debug("Building energy wizard") - wizard = EnergyWizard(corpus_doc.attrs["corpus"], model=model_config.name) - - logger.debug("Calling parser class") - parser = parser_class( - wizard=wizard, - location=corpus_doc.attrs["jurisdiction_name"], - llm_service=model_config.llm_service, - usage_tracker=usage_tracker, - **model_config.llm_call_kwargs, - ) - corpus_doc.attrs[out_key] = await parser.parse() - return corpus_doc - - -def write_water_rights_data_to_disk(doc_infos, out_dir): - """Write extracted water rights data to disk - - Parameters - ---------- - doc_infos : list of dict - List of dictionaries containing extracted document information - and data file paths. - out_dir : path-like - Path to the output directory for the data. - - Returns - ------- - int - Number of unique water rights districts that information was - found/written for. 
- """ - db = [] - for doc_info in doc_infos: - ord_db = pd.read_csv(doc_info["ord_db_fp"]) - if len(ord_db) == 0: - continue - ord_db["source"] = doc_info.get("source") - - year, *__ = doc_info.get("date") or (None, None, None) - ord_db["ord_year"] = year if year is not None and year > 0 else None - - jurisdiction = doc_info["jurisdiction"] - ord_db["WCD_ID"] = jurisdiction.code - ord_db["county"] = jurisdiction.county - ord_db["state"] = jurisdiction.state - ord_db["subdivision"] = jurisdiction.subdivision_name - ord_db["jurisdiction_type"] = jurisdiction.type - - db.append(ord_db) - - if not db: - return 0 - - db = pd.concat([df.dropna(axis=1, how="all") for df in db], axis=0) - db.to_csv(Path(out_dir) / "water_rights.csv", index=False) - return len(db["WCD_ID"].unique()) - - -def _setup_endpoints(embedding_model_config): - """Set proper URLS for elm classes""" - ChunkAndEmbed.USE_CLIENT_EMBEDDINGS = True - EnergyWizard.USE_CLIENT_EMBEDDINGS = True - ChunkAndEmbed.EMBEDDING_MODEL = EnergyWizard.EMBEDDING_MODEL = ( - embedding_model_config.name - ) - - endpoint = embedding_model_config.client_kwargs["azure_endpoint"] - ChunkAndEmbed.EMBEDDING_URL = endpoint - ChunkAndEmbed.URL = endpoint - EnergyWizard.EMBEDDING_URL = endpoint - - EnergyWizard.URL = "openai.azure.com" # need to trigger Azure setup diff --git a/compass/extraction/wind/__init__.py b/compass/extraction/wind/__init__.py index f8169b9e6..f92547fbc 100644 --- a/compass/extraction/wind/__init__.py +++ b/compass/extraction/wind/__init__.py @@ -1,54 +1,3 @@ -"""Wind ordinance extraction utilities""" +"""Wind ordinance extraction plugin""" -from .ordinance import ( - WindHeuristic, - WindOrdinanceTextCollector, - WindOrdinanceTextExtractor, - WindPermittedUseDistrictsTextCollector, - WindPermittedUseDistrictsTextExtractor, -) -from .parse import ( - StructuredWindOrdinanceParser, - StructuredWindPermittedUseDistrictsParser, -) - - -WIND_QUESTION_TEMPLATES = [ - "filetype:pdf {jurisdiction} wind energy conversion system ordinances", - "wind energy conversion system ordinances {jurisdiction}", - "{jurisdiction} wind WECS ordinance", - ( - "Where can I find the legal text for commercial wind energy " - "conversion system zoning ordinances in {jurisdiction}?" - ), - ( - "What is the specific legal information regarding zoning " - "ordinances for commercial wind energy conversion systems in " - "{jurisdiction}?" 
- ), -] - -BEST_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS = { - "pdf": 92160, - "wecs": 46080, - "wind": 23040, - "zoning": 11520, - "ordinance": 5760, - r"renewable%20energy": 1440, - r"renewable+energy": 1440, - "renewable energy": 1440, - "planning": 720, - "plan": 360, - "government": 180, - "code": 60, - "area": 60, - r"land%20development": 15, - r"land+development": 15, - "land development": 15, - "land": 3, - "environment": 3, - "energy": 3, - "renewable": 3, - "municipal": 1, - "department": 1, -} +from .plugin import COMPASSWindExtractor diff --git a/compass/extraction/wind/ordinance.py b/compass/extraction/wind/ordinance.py index d09a30176..d1be4f93b 100644 --- a/compass/extraction/wind/ordinance.py +++ b/compass/extraction/wind/ordinance.py @@ -6,11 +6,12 @@ import logging -from compass.common import BaseTextExtractor -from compass.validation.content import Heuristic -from compass.llm.calling import StructuredLLMCaller +from compass.plugin.ordinance import ( + OrdinanceHeuristic, + OrdinanceTextCollector, + OrdinanceTextExtractor, +) from compass.utilities.enums import LLMUsageCategory -from compass.utilities.parsing import merge_overlapping_texts logger = logging.getLogger(__name__) @@ -32,7 +33,7 @@ _IGNORE_TYPES = "private, residential, micro, small, or medium sized" -class WindHeuristic(Heuristic): +class WindHeuristic(OrdinanceHeuristic): """Perform a heuristic check for mention of wind turbines in text""" NOT_TECH_WORDS = [ @@ -90,9 +91,12 @@ class WindHeuristic(Heuristic): """Phrases that indicate text is about WECS""" -class WindOrdinanceTextCollector(StructuredLLMCaller): +class WindOrdinanceTextCollector(OrdinanceTextCollector): """Check text chunks for ordinances and collect them if they do""" + OUT_LABEL = "relevant_text" + """Identifier for text collected by this class""" + CONTAINS_ORD_PROMPT = ( "You extract structured data from text. Return your answer in JSON " "format (not markdown). 
Your JSON file must include exactly two " @@ -133,10 +137,6 @@ class WindOrdinanceTextCollector(StructuredLLMCaller): ) """Prompt to check if chunk is for utility-scale WES""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._ordinance_chunks = {} - async def check_chunk(self, chunk_parser, ind): """Check a chunk at a given ind to see if it contains ordinance @@ -175,32 +175,11 @@ async def check_chunk(self, chunk_parser, ind): logger.debug("Text at ind %d is for utility-scale WECS", ind) - _store_chunk(chunk_parser, ind, self._ordinance_chunks) + self._store_chunk(chunk_parser, ind) logger.debug("Added text at ind %d to ordinances", ind) return True - @property - def contains_ord_info(self): - """bool: Flag indicating whether text contains ordinance info""" - return bool(self._ordinance_chunks) - - @property - def ordinance_text(self): - """str: Combined ordinance text from the individual chunks""" - logger.debug( - "Grabbing %d ordinance chunk(s) from original text at these " - "indices: %s", - len(self._ordinance_chunks), - list(self._ordinance_chunks), - ) - - text = [ - self._ordinance_chunks[ind] - for ind in sorted(self._ordinance_chunks) - ] - return merge_overlapping_texts(text) - async def _check_chunk_contains_ord(self, key, text_chunk): """Call LLM on a chunk of text to check for ordinance""" content = await self.call( @@ -222,9 +201,12 @@ async def _check_chunk_is_for_utility_scale(self, key, text_chunk): return content.get(key, False) -class WindPermittedUseDistrictsTextCollector(StructuredLLMCaller): +class WindPermittedUseDistrictsTextCollector(OrdinanceTextCollector): """Check text chunks for permitted wind districts; collect them""" + OUT_LABEL = "permitted_use_text" + """Identifier for text collected by this class""" + DISTRICT_PROMPT = ( "You are a legal scholar that reads ordinance text and determines " "whether the text explicitly contains relevant information to " @@ -247,10 +229,6 @@ class WindPermittedUseDistrictsTextCollector(StructuredLLMCaller): ) """Prompt to check if chunk contains info on permitted districts""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._district_chunks = {} - async def check_chunk(self, chunk_parser, ind): """Check a chunk to see if it contains permitted uses @@ -281,46 +259,28 @@ async def check_chunk(self, chunk_parser, ind): contains_district_info = content.get(key, False) if contains_district_info: - _store_chunk(chunk_parser, ind, self._district_chunks) + self._store_chunk(chunk_parser, ind) logger.debug("Text at ind %d contains district info", ind) return True logger.debug("Text at ind %d does not contain district info", ind) return False - @property - def contains_district_info(self): - """bool: Flag indicating whether text contains district info""" - return bool(self._district_chunks) - @property - def permitted_use_district_text(self): - """str: Combined permitted use districts text from the chunks""" - logger.debug( - "Grabbing %d permitted use chunk(s) from original text at " - "these indices: %s", - len(self._district_chunks), - list(self._district_chunks), - ) +class WindOrdinanceTextExtractor(OrdinanceTextExtractor): + """Extract succinct ordinance text from input""" - text = [ - self._district_chunks[ind] for ind in sorted(self._district_chunks) - ] - return merge_overlapping_texts(text) + IN_LABEL = WindOrdinanceTextCollector.OUT_LABEL + """Identifier for collected text ingested by this class""" + OUT_LABEL = "cleaned_text_for_extraction" + """Identifier 
for ordinance text extracted by this class""" -class WindOrdinanceTextExtractor(BaseTextExtractor): - """Extract succinct ordinance text from input + TASK_DESCRIPTION = "Extracting wind ordinance text" + """Task description to show in progress bar""" - Purpose: - Extract relevant ordinance text from document. - Responsibilities: - 1. Extract portions from chunked document text relevant to - particular ordinance type (e.g. wind zoning for utility-scale - systems). - Key Relationships: - Uses a StructuredLLMCaller for LLM queries. - """ + TASK_ID = "ordinance_text_extraction" + """ID to use for this extraction for linking with LLM configs""" WIND_ENERGY_SYSTEM_FILTER_PROMPT = ( "# CONTEXT #\n" @@ -438,7 +398,6 @@ async def extract_wind_energy_system_section(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.WIND_ENERGY_SYSTEM_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) async def extract_large_wind_energy_system_section(self, text_chunks): @@ -459,7 +418,6 @@ async def extract_large_wind_energy_system_section(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.LARGE_WIND_ENERGY_SYSTEM_SECTION_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) @property @@ -478,24 +436,23 @@ def parsers(self): "wind_energy_systems_text", self.extract_wind_energy_system_section, ) - yield ( - "cleaned_ordinance_text", - self.extract_large_wind_energy_system_section, - ) + yield self.OUT_LABEL, self.extract_large_wind_energy_system_section + + +class WindPermittedUseDistrictsTextExtractor(OrdinanceTextExtractor): + """Extract succinct permitted use districts text from input""" + IN_LABEL = WindPermittedUseDistrictsTextCollector.OUT_LABEL + """Identifier for collected text ingested by this class""" -class WindPermittedUseDistrictsTextExtractor(BaseTextExtractor): - """Extract succinct ordinance text from input + OUT_LABEL = "districts_text" + """Identifier for permitted use text extracted by this class""" - Purpose: - Extract relevant ordinance text from document. - Responsibilities: - 1. Extract portions from chunked document text relevant to - particular ordinance type (e.g. wind zoning for utility-scale - systems). - Key Relationships: - Uses a StructuredLLMCaller for LLM queries. - """ + TASK_DESCRIPTION = "Extracting wind permitted use text" + """Task description to show in progress bar""" + + TASK_ID = "permitted_use_text_extraction" + """ID to use for this extraction for linking with LLM configs""" _USAGE_LABEL = LLMUsageCategory.DOCUMENT_PERMITTED_USE_DISTRICTS_SUMMARY @@ -608,7 +565,6 @@ async def extract_permitted_uses(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.PERMITTED_USES_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) async def extract_wes_permitted_uses(self, text_chunks): @@ -629,7 +585,6 @@ async def extract_wes_permitted_uses(self, text_chunks): return await self._process( text_chunks=text_chunks, instructions=self.WES_PERMITTED_USES_FILTER_PROMPT, - is_valid_chunk=_valid_chunk, ) @property @@ -645,19 +600,4 @@ def parsers(self): outputs parsed text. 
""" yield "permitted_use_only_text", self.extract_permitted_uses - yield "districts_text", self.extract_wes_permitted_uses - - -def _valid_chunk(chunk): - """True if chunk has content""" - return chunk and "no relevant text" not in chunk.lower() - - -def _store_chunk(parser, chunk_ind, store): - """Store chunk and its neighbors if it is not already stored""" - for offset in range(1 - parser.num_to_recall, 2): - ind_to_grab = chunk_ind + offset - if ind_to_grab < 0 or ind_to_grab >= len(parser.text_chunks): - continue - - store.setdefault(ind_to_grab, parser.text_chunks[ind_to_grab]) + yield self.OUT_LABEL, self.extract_wes_permitted_uses diff --git a/compass/extraction/wind/parse.py b/compass/extraction/wind/parse.py index 21bd84f80..95c6f20a3 100644 --- a/compass/extraction/wind/parse.py +++ b/compass/extraction/wind/parse.py @@ -8,7 +8,7 @@ import pandas as pd -from compass.llm.calling import BaseLLMCaller, ChatLLMCaller +from compass.plugin.ordinance import OrdinanceParser from compass.extraction.features import SetbackFeatures from compass.common import ( EXTRACT_ORIGINAL_SETBACK_TEXT_PROMPT, @@ -140,18 +140,9 @@ } -class StructuredWindParser(BaseLLMCaller): +class StructuredWindParser(OrdinanceParser): """Base class for parsing structured data""" - def _init_chat_llm_caller(self, system_message): - """Initialize a ChatLLMCaller instance for the DecisionTree""" - return ChatLLMCaller( - self.llm_service, - system_message=system_message, - usage_tracker=self.usage_tracker, - **self.kwargs, - ) - async def _check_wind_turbine_type(self, text): """Get the largest turbine size mentioned in the text""" logger.info("Checking turbine types...") @@ -193,6 +184,12 @@ class StructuredWindOrdinanceParser(StructuredWindParser): individual values. """ + IN_LABEL = "cleaned_text_for_extraction" + """Identifier for text ingested by this class""" + + OUT_LABEL = "ordinance_values" + """Identifier for structured ordinance data output by this class""" + async def parse(self, text): """Parse text and extract structure ordinance data @@ -505,6 +502,12 @@ class StructuredWindPermittedUseDistrictsParser(StructuredWindParser): individual values. 
""" + IN_LABEL = "districts_text" + """Identifier for text ingested by this class""" + + OUT_LABEL = "permitted_district_values" + """Identifier for structured ordinance data output by this class""" + _LARGE_WES_CLARIFICATION = ( "Large wind energy systems (WES) may also be referred to as wind " "turbines, wind energy conversion systems (WECS), wind energy " diff --git a/compass/extraction/wind/plugin.py b/compass/extraction/wind/plugin.py new file mode 100644 index 000000000..1e22ffaa8 --- /dev/null +++ b/compass/extraction/wind/plugin.py @@ -0,0 +1,97 @@ +"""COMPASS wind extraction plugin""" + +from compass.plugin.interface import ExtractionPlugin +from compass.extraction.wind.ordinance import ( + WindHeuristic, + WindOrdinanceTextCollector, + WindOrdinanceTextExtractor, + WindPermittedUseDistrictsTextCollector, + WindPermittedUseDistrictsTextExtractor, +) +from compass.extraction.wind.parse import ( + StructuredWindOrdinanceParser, + StructuredWindPermittedUseDistrictsParser, +) + +StructuredWindOrdinanceParser.IN_LABEL = WindOrdinanceTextExtractor.OUT_LABEL +StructuredWindPermittedUseDistrictsParser.IN_LABEL = ( + WindPermittedUseDistrictsTextExtractor.OUT_LABEL +) + +WIND_QUESTION_TEMPLATES = [ + "filetype:pdf {jurisdiction} wind energy conversion system ordinances", + "wind energy conversion system ordinances {jurisdiction}", + "{jurisdiction} wind WECS ordinance", + ( + "Where can I find the legal text for commercial wind energy " + "conversion system zoning ordinances in {jurisdiction}?" + ), + ( + "What is the specific legal information regarding zoning " + "ordinances for commercial wind energy conversion systems in " + "{jurisdiction}?" + ), +] + +BEST_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS = { + "pdf": 92160, + "wecs": 46080, + "wind": 23040, + "zoning": 11520, + "ordinance": 5760, + r"renewable%20energy": 1440, + r"renewable+energy": 1440, + "renewable energy": 1440, + "planning": 720, + "plan": 360, + "government": 180, + "code": 60, + "area": 60, + r"land%20development": 15, + r"land+development": 15, + "land development": 15, + "land": 3, + "environment": 3, + "energy": 3, + "renewable": 3, + "municipal": 1, + "department": 1, +} + + +class COMPASSWindExtractor(ExtractionPlugin): + """COMPASS wind extraction plugin""" + + IDENTIFIER = "wind" + """str: Identifier for extraction task """ + + QUESTION_TEMPLATES = WIND_QUESTION_TEMPLATES + """list: List of search engine question templates for extraction""" + + WEBSITE_KEYWORDS = BEST_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS + """list: List of keywords + + Keywords indicate links which should be prioritized when performing + a website scrape for a wind ordinance document. 
+ """ + + heuristic = WindHeuristic() + """BaseHeuristic: Object with a ``check()`` method""" + + TEXT_COLLECTORS = [ + WindOrdinanceTextCollector, + WindPermittedUseDistrictsTextCollector, + ] + """Classes for collecting wind ordinance text chunks from docs""" + + TEXT_EXTRACTORS = [ + WindOrdinanceTextExtractor, + WindPermittedUseDistrictsTextExtractor, + ] + """Class for extracting cleaned ord text from collected text""" + + PARSERS = [ + StructuredWindOrdinanceParser, + StructuredWindPermittedUseDistrictsParser, + ] + """Class for parsing structured ordinance data from text""" diff --git a/compass/plugin/__init__.py b/compass/plugin/__init__.py new file mode 100644 index 000000000..f42cef353 --- /dev/null +++ b/compass/plugin/__init__.py @@ -0,0 +1,16 @@ +"""COMPASS plugin tools""" + +from .base import BaseExtractionPlugin +from .interface import ( + BaseHeuristic, + BaseTextCollector, + BaseTextExtractor, + BaseParser, + ExtractionPlugin, +) +from .ordinance import ( + OrdinanceHeuristic, + OrdinanceTextCollector, + OrdinanceTextExtractor, + OrdinanceParser, +) diff --git a/compass/plugin/base.py b/compass/plugin/base.py new file mode 100644 index 000000000..71957bc8c --- /dev/null +++ b/compass/plugin/base.py @@ -0,0 +1,148 @@ +"""Base COMPASS extraction plugin interface""" + +from abc import ABC, abstractmethod + +from compass.pb import COMPASS_PB +from compass.services.threaded import UsageUpdater +from compass.utilities import compute_total_cost_from_usage + + +class BaseExtractionPlugin(ABC): + """Base class for COMPASS extraction plugins + + This class provides the most extraction flexibility, but the + implementer must define most functionality on their own. + """ + + def __init__(self, jurisdiction, model_configs, usage_tracker=None): + """ + + Parameters + ---------- + jurisdiction : Jurisdiction + Jurisdiction for which extraction is being performed. + model_configs : dict + Dictionary where keys are + :class:`~compass.utilities.enums.LLMTasks` and values are + :class:`~compass.llm.config.LLMConfig` instances to be used + for those tasks. + usage_tracker : UsageTracker, optional + Usage tracker instance that can be used to record the LLM + call cost. By default, ``None``. + """ + self.jurisdiction = jurisdiction + self.model_configs = model_configs + self.usage_tracker = usage_tracker + + @property + @abstractmethod + def IDENTIFIER(self): # noqa: N802 + """str: Identifier for extraction task (e.g. "water rights")""" + raise NotImplementedError + + @property + @abstractmethod + def QUESTION_TEMPLATES(self): # noqa: N802 + """list: List of search engine question templates for extraction + + Question templates can contain the placeholder + ``{jurisdiction}`` which will be replaced with the full + jurisdiction name during the search engine query. + """ + raise NotImplementedError + + @property + @abstractmethod + def WEBSITE_KEYWORDS(self): # noqa: N802 + """list: List of keywords + + List of keywords that indicate links which should be prioritized + when performing a website scrape for a document. + """ + raise NotImplementedError + + @property + @abstractmethod + def heuristic(self): + """BaseHeuristic: Object with a ``check()`` method + + The ``check()`` method should accept a string of text and + return ``True`` if the text passes the heuristic check and + ``False`` otherwise. 
+ """ + raise NotImplementedError + + @abstractmethod + async def filter_docs( + self, extraction_context, need_jurisdiction_verification=True + ): + """Filter down candidate documents before parsing + + Parameters + ---------- + extraction_context : ExtractionContext + Context containing candidate documents to be filtered. + Set the ``.documents`` attribute of this object to be the + iterable of documents that should be kept for parsing. + need_jurisdiction_verification : bool, optional + Whether to verify that documents pertain to the correct + jurisdiction. By default, ``True``. + + Returns + ------- + ExtractionContext + Context with filtered down documents. + """ + raise NotImplementedError + + @abstractmethod + async def parse_docs_for_structured_data(self, extraction_context): + """Parse documents to extract structured data/information + + Parameters + ---------- + extraction_context : ExtractionContext + Context containing candidate documents to parse. + + Returns + ------- + ExtractionContext or None + Context with extracted data/information stored in the + ``.attrs`` dictionary, or ``None`` if no data was extracted. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def save_structured_data(cls, doc_infos, out_dir): + """Write combined extracted structured data to disk + + Parameters + ---------- + doc_infos : list of dict + List of dictionaries containing the following keys: + + - "jurisdiction": An initialized Jurisdiction object + representing the jurisdiction that was extracted. + - "ord_db_fp": A path to the extracted structured data + stored on disk, or ``None`` if no data was extracted. + + out_dir : path-like + Path to the output directory for the data. + + Returns + ------- + int + Number of jurisdictions for which data was successfully + extracted. + """ + raise NotImplementedError + + async def record_usage(self): + """Persist usage tracking data when a tracker is available""" + if self.usage_tracker is None: + return + + total_usage = await UsageUpdater.call(self.usage_tracker) + total_cost = compute_total_cost_from_usage(total_usage) + COMPASS_PB.update_total_cost(total_cost, replace=True) diff --git a/compass/plugin/interface.py b/compass/plugin/interface.py new file mode 100644 index 000000000..f73cff8cb --- /dev/null +++ b/compass/plugin/interface.py @@ -0,0 +1,740 @@ +"""COMPASS extraction plugin base class""" + +import asyncio +import logging +from itertools import chain +from abc import ABC, abstractmethod +from contextlib import contextmanager +from functools import cached_property + +import pandas as pd + +from compass.plugin.base import BaseExtractionPlugin +from compass.llm.calling import LLMCaller +from compass.extraction import ( + extract_ordinance_values, + extract_relevant_text_with_ngram_validation, +) +from compass.scripts.download import filter_ordinance_docs +from compass.services.threaded import CleanedFileWriter +from compass.utilities.enums import LLMTasks +from compass.utilities import ( + num_ordinances_dataframe, + doc_infos_to_db, + save_db, +) +from compass.utilities.parsing import extract_ord_year_from_doc_attrs +from compass.exceptions import COMPASSPluginConfigurationError +from compass.pb import COMPASS_PB + +logger = logging.getLogger(__name__) + + +EXCLUDE_FROM_ORD_DOC_CHECK = { + # if doc only contains these, it's not good enough to count as an + # ordinance. 
Note that prohibitions are explicitly not on this list + "color", + "decommissioning", + "lighting", + "visual impact", + "glare", + "repowering", + "fencing", + "climbing prevention", + "signage", + "soil", + "primary use districts", + "special use districts", + "accessory use districts", +} + + +class BaseHeuristic(ABC): + """Base class for a heuristic check""" + + @abstractmethod + def check(self, text): + """Check for mention of a tech in text (or text chunk) + + Parameters + ---------- + text : str + Input text that may or may not mention the technology of + interest. + + Returns + ------- + bool + ``True`` if the text passes the heuristic check and + ``False`` otherwise. + """ + raise NotImplementedError + + +class BaseTextCollector(ABC): + """Base class for text collectors that gather relevant text""" + + @property + @abstractmethod + def OUT_LABEL(self): # noqa: N802 + """str: Identifier for text collected by this class""" + raise NotImplementedError + + @property + @abstractmethod + def relevant_text(self): + """str: Combined relevant text from the individual chunks""" + raise NotImplementedError + + @abstractmethod + async def check_chunk(self, chunk_parser, ind): + """Check if a text chunk is relevant for extraction + + You should validate chunks like so:: + + is_correct_kind_of_text = await chunk_parser.parse_from_ind( + ind, + key="my_unique_validation_key", + llm_call_callback=my_async_llm_call_function, + ) + + where the `"key"` is unique to this particular validation (it + will be used to cache the validation result in the chunk + parser's memory) and `my_async_llm_call_function` is an async + function that takes in a key and text chunk and returns a + boolean indicating whether or not the text chunk passes the + validation. You can call `chunk_parser.parse_from_ind` as many + times as you want within this method, but be sure to use unique + keys for each validation. + + Parameters + ---------- + chunk_parser : ParseChunksWithMemory + Instance that contains a ``parse_from_ind`` method. + ind : int + Index of the chunk to check. + + Returns + ------- + bool + Boolean flag indicating whether or not the text in the chunk + contains information relevant to the extraction task. + + See Also + -------- + :func:`~compass.validation.content.ParseChunksWithMemory.parse_from_ind` + Method used to parse text from a chunk with memory of prior + chunk validations. + """ + raise NotImplementedError + + +class BaseTextExtractor(ABC): + """Extract succinct extraction text from input""" + + TASK_DESCRIPTION = "Condensing text for extraction" + """Task description to show in progress bar""" + + TASK_ID = "text_extraction" + """ID to use for this extraction for linking with LLM configs""" + + @property + @abstractmethod + def IN_LABEL(self): # noqa: N802 + """str: Identifier for text ingested by this class""" + raise NotImplementedError + + @property + @abstractmethod + def OUT_LABEL(self): # noqa: N802 + """str: Identifier for final text extracted by this class""" + raise NotImplementedError + + @property + @abstractmethod + def parsers(self): + """Generator: Generator of (key, extractor) pairs + + `extractor` should be an async callable that accepts a list of + text chunks and returns the shortened (succinct) text to be used + for extraction. The `key` should be a string identifier for the + text returned by the extractor. Multiple (key, extractor) pairs + can be chained in generator order to iteratively refine the + text for extraction. 
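Examples
--------
A two-stage refinement sketch (the first key and both method names are
illustrative; as in the wind extractors, the last pair yields
``self.OUT_LABEL``)::

    @property
    def parsers(self):
        yield "relevant_sections_text", self.extract_relevant_sections
        yield self.OUT_LABEL, self.extract_final_text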
+ """ + raise NotImplementedError + + +class BaseParser(ABC): + """Extract succinct extraction text from input""" + + TASK_ID = "data_extraction" + """ID to use for this extraction for linking with LLM configs""" + + @property + @abstractmethod + def IN_LABEL(self): # noqa: N802 + """str: Identifier for text ingested by this class""" + raise NotImplementedError + + @property + @abstractmethod + def OUT_LABEL(self): # noqa: N802 + """str: Identifier for final structured data output""" + raise NotImplementedError + + @abstractmethod + async def parse(self, text): + """Parse text and extract structured data + + Parameters + ---------- + text : str + Text which may or may not contain information relevant to + the current extraction. + + Returns + ------- + pandas.DataFrame or None + DataFrame containing structured extracted data. Can also + be ``None`` if no relevant values can be parsed from the + text. + """ + raise NotImplementedError + + +class ExtractionPlugin(BaseExtractionPlugin): + """Base class for COMPASS extraction plugins + + This class provides a good balance between ease of use and + extraction flexibility, allowing implementers to provide additional + functionality during the extraction process. + + Plugins can hook into various stages of the extraction pipeline + to modify behavior, add custom processing, or integrate with + external systems. + + Subclasses should implement the desired hooks and override + methods as needed. + """ + + @property + @abstractmethod + def IDENTIFIER(self): # noqa: N802 + """str: Identifier for extraction task (e.g. "water rights")""" + raise NotImplementedError + + @property + @abstractmethod + def QUESTION_TEMPLATES(self): # noqa: N802 + """list: List of search engine question templates for extraction + + Question templates can contain the placeholder + ``{jurisdiction}`` which will be replaced with the full + jurisdiction name during the search engine query. + """ + raise NotImplementedError + + @property + @abstractmethod + def WEBSITE_KEYWORDS(self): # noqa: N802 + """list: List of keywords + + List of keywords that indicate links which should be prioritized + when performing a website scrape for a document. + """ + raise NotImplementedError + + @property + @abstractmethod + def TEXT_COLLECTORS(self): # noqa: N802 + """list of BaseTextCollector: Classes to collect text + + Should be an iterable of one or more classes to collect text + for the extraction task. + """ + raise NotImplementedError + + @property + @abstractmethod + def TEXT_EXTRACTORS(self): # noqa: N802 + """list of BaseTextExtractor: Classes to condense text + + Should be an iterable of one or more classes to condense text in + preparation for the extraction task. + """ + raise NotImplementedError + + @property + @abstractmethod + def PARSERS(self): # noqa: N802 + """list of BaseParser: Classes to extract structured data + + Should be an iterable of one or more classes to extract + structured data from text. + """ + raise NotImplementedError + + @property + def heuristic(self): + """BaseHeuristic: Object with a ``check()`` method + + The ``check()`` method should accept a string of text and + return ``True`` if the text passes the heuristic check and + ``False`` otherwise. 
+ """ + raise NotImplementedError + + @classmethod + def save_structured_data(cls, doc_infos, out_dir): + """Write extracted water rights data to disk + + Parameters + ---------- + doc_infos : list of dict + List of dictionaries containing the following keys: + + - "jurisdiction": An initialized Jurisdiction object + representing the jurisdiction that was extracted. + - "ord_db_fp": A path to the extracted structured data + stored on disk, or ``None`` if no data was extracted. + + out_dir : path-like + Path to the output directory for the data. + + Returns + ------- + int + Number of unique jurisdictions that information was + found/written for. + """ + db, num_docs_found = doc_infos_to_db(doc_infos) + save_db(db, out_dir) + return num_docs_found + + def __init__(self, jurisdiction, model_configs, usage_tracker=None): + """ + + Parameters + ---------- + jurisdiction : Jurisdiction + Jurisdiction for which extraction is being performed. + model_configs : dict + Dictionary where keys are LLMTasks and values are LLMConfig + instances to be used for those tasks. + usage_tracker : UsageTracker, optional + Usage tracker instance that can be used to record the LLM + call cost. By default, ``None``. + """ + super().__init__( + jurisdiction=jurisdiction, + model_configs=model_configs, + usage_tracker=usage_tracker, + ) + + # TODO: This should happen during plugin registration + self._validate_in_out_keys() + + @cached_property + def producers(self): + """list: All classes that produce attributes on the doc""" + return chain(self.PARSERS, self.TEXT_EXTRACTORS, self.TEXT_COLLECTORS) + + @cached_property + def consumer_producer_pairs(self): + """list: Pairs of (consumer, producer) for IN/OUT validation""" + return [ + (self.PARSERS, chain(self.TEXT_EXTRACTORS, self.TEXT_COLLECTORS)), + (self.TEXT_EXTRACTORS, self.TEXT_COLLECTORS), + ] + + def _validate_in_out_keys(self): + """Validate that all IN_LABELs have matching OUT_LABELs""" + out_keys = {} + for producer in self.producers: + out_keys.setdefault(producer.OUT_LABEL, []).append(producer) + + dupes = {k: v for k, v in out_keys.items() if len(v) > 1} + if dupes: + formatted = "\n".join( + [ + f"{key}: {[cls.__name__ for cls in classes]}" + for key, classes in dupes.items() + ] + ) + msg = ( + "Multiple processing classes produce the same OUT_LABEL key:\n" + f"{formatted}" + ) + raise COMPASSPluginConfigurationError(msg) + + for consumers, producers in self.consumer_producer_pairs: + _validate_in_out_keys(consumers, producers) + + async def pre_filter_docs_hook(self, extraction_context): # noqa: PLR6301 + """Pre-process documents before running them through the filter + + Parameters + ---------- + extraction_context : ExtractionContext + Context with downloaded documents to process. + + Returns + ------- + ExtractionContext + Context with documents to be passed onto the filtering step. + """ + return extraction_context + + async def post_filter_docs_hook(self, extraction_context): # noqa: PLR6301 + """Post-process documents after running them through the filter + + Parameters + ---------- + extraction_context : ExtractionContext + Context with documents that passed the filtering step. + + Returns + ------- + ExtractionContext + Context with documents to be passed onto the parsing step. 
+ """ + return extraction_context + + async def extract_relevant_text(self, doc, extractor_class, model_config): + """Condense text for extraction task + + This method takes a text extractor and applies it to the + collected document chunks to get a concise version of the text + that can be used for structured data extraction. + + The extracted text will be stored in the ``.attrs`` dictionary + of the input document under the ``extractor_class.OUT_LABEL`` + key. + + Parameters + ---------- + doc : BaseDocument + Document containing text chunks to condense. + extractor_class : BaseTextExtractor + Class to use for text extraction. + model_config : LLMConfig + Configuration for the LLM model to use for text extraction. + """ + llm_caller = LLMCaller( + llm_service=model_config.llm_service, + usage_tracker=self.usage_tracker, + **model_config.llm_call_kwargs, + ) + extractor = extractor_class(llm_caller) + doc = await extract_relevant_text_with_ngram_validation( + doc, + model_config.text_splitter, + extractor, + original_text_key=extractor_class.IN_LABEL, + ) + await self._write_cleaned_text(doc) + + async def extract_ordinances_from_text( + self, doc, parser_class, model_config + ): + """Extract structured data from input text + + The extracted structured data will be stored in the ``.attrs`` + dictionary of the input document under the + ``parser_class.OUT_LABEL`` key. + + Parameters + ---------- + doc : BaseDocument + Document containing text to extract structured data from. + parser_class : BaseParser + Class to use for structured data extraction. + model_config : LLMConfig + Configuration for the LLM model to use for structured data + extraction. + """ + parser = parser_class( + llm_service=model_config.llm_service, + usage_tracker=self.usage_tracker, + **model_config.llm_call_kwargs, + ) + logger.info( + "Extracting %s...", parser_class.OUT_LABEL.replace("_", " ") + ) + await extract_ordinance_values( + doc, + parser, + text_key=parser_class.IN_LABEL, + out_key=parser_class.OUT_LABEL, + ) + + @classmethod + def get_structured_data_row_count(cls, data_df): + """Get the number of data rows extracted from a document + + Parameters + ---------- + data_df : pandas.DataFrame or None + DataFrame to check for extracted structured data. + + Returns + ------- + int + Number of data rows extracted from the document. + """ + if data_df is None: + return 0 + + return num_ordinances_dataframe( + data_df, exclude_features=EXCLUDE_FROM_ORD_DOC_CHECK + ) + + async def filter_docs( + self, extraction_context, need_jurisdiction_verification=True + ): + """Filter down candidate documents before parsing + + Parameters + ---------- + extraction_context : ExtractionContext + Context containing candidate documents to be filtered. + need_jurisdiction_verification : bool, optional + Whether to verify that documents pertain to the correct + jurisdiction. By default, ``True``. + + Returns + ------- + iterable of BaseDocument + Filtered documents or ``None`` if no documents remain. 
+ """ + if not extraction_context: + return None + + logger.debug( + "Passing %d document(s) in to `pre_filter_docs_hook` ", + extraction_context.num_documents, + ) + + docs = await self.pre_filter_docs_hook(extraction_context.documents) + logger.debug( + "%d document(s) remaining after `pre_filter_docs_hook` for " + "%s\n\t- %s", + len(docs), + self.jurisdiction.full_name, + "\n\t- ".join( + [doc.attrs.get("source", "Unknown source") for doc in docs] + ), + ) + + docs = await filter_ordinance_docs( + docs, + self.jurisdiction, + self.model_configs, + heuristic=self.heuristic, + tech=self.IDENTIFIER, + text_collectors=self.TEXT_COLLECTORS, + usage_tracker=self.usage_tracker, + check_for_correct_jurisdiction=need_jurisdiction_verification, + ) + + if not docs: + return None + + logger.debug( + "Passing %d document(s) in to `post_filter_docs_hook` ", len(docs) + ) + docs = await self.post_filter_docs_hook(docs) + logger.debug( + "%d document(s) remaining after `post_filter_docs_hook` for " + "%s\n\t- %s", + len(docs), + self.jurisdiction.full_name, + "\n\t- ".join( + [doc.attrs.get("source", "Unknown source") for doc in docs] + ), + ) + if not docs: + return None + + extraction_context.documents = docs + return extraction_context + + async def parse_docs_for_structured_data(self, extraction_context): + """Parse documents to extract structured data/information + + Parameters + ---------- + extraction_context : ExtractionContext + Context containing candidate documents to parse. + + Returns + ------- + ExtractionContext or None + Context with extracted data/information stored in the + ``.attrs`` dictionary, or ``None`` if no data was extracted. + """ + for doc_for_extraction in extraction_context: + data_df = await self.parse_single_doc_for_structured_data( + doc_for_extraction + ) + row_count = self.get_structured_data_row_count(data_df) + if row_count > 0: + await extraction_context.mark_doc_as_data_source( + doc_for_extraction, out_fn_stem=self.jurisdiction.full_name + ) + extraction_context.attrs["structured_data"] = data_df + logger.info( + "%d ordinance value(s) found in doc from %s for %s. ", + row_count, + doc_for_extraction.attrs.get("source", "unknown source"), + self.jurisdiction.full_name, + ) + return extraction_context + + logger.debug( + "No ordinances found; searched %d docs", + extraction_context.num_documents, + ) + return None + + async def parse_single_doc_for_structured_data(self, doc_for_extraction): + """Extract all possible structured data from a document + + This method is called from the default implementation of + `parse_docs_for_structured_data()` for each document that passed + filtering. If you overwrite`parse_docs_for_structured_data()``, + you can ignore this method. + + Parameters + ---------- + doc_for_extraction : BaseDocument + Document to extract structured data from. + + Returns + ------- + BaseDocument + Document with extracted structured data stored in the + ``.attrs`` dictionary. 
+ """ + with self._tracked_progress(): + tasks = [ + asyncio.create_task( + self._try_extract_ordinances( + doc_for_extraction, parser_class + ), + name=self.jurisdiction.full_name, + ) + for parser_class in filter(None, self.PARSERS) + ] + await asyncio.gather(*tasks) + + return self._concat_scrape_results(doc_for_extraction) + + async def _try_extract_ordinances(self, doc_for_extraction, parser_class): + """Apply a single extractor and parser to legal text""" + + if parser_class.IN_LABEL not in doc_for_extraction.attrs: + await self._run_text_extractors(doc_for_extraction, parser_class) + + model_config = self._get_model_config( + primary_key=parser_class.TASK_ID, + secondary_key=LLMTasks.DATA_EXTRACTION, + ) + await self.extract_ordinances_from_text( + doc_for_extraction, + parser_class=parser_class, + model_config=model_config, + ) + + await self.record_usage() + + async def _run_text_extractors(self, doc_for_extraction, parser_class): + """Run text extractor(s) on document to get text for a parser""" + te = [ + te + for te in self.TEXT_EXTRACTORS + if te.OUT_LABEL == parser_class.IN_LABEL + ] + if len(te) != 1: + msg = ( + f"Could not find unique text extractor for parser " + f"{parser_class.__name__} with IN_LABEL " + f"{parser_class.IN_LABEL!r}. Got matches: {te}" + ) + raise COMPASSPluginConfigurationError(msg) + + te = te[0] + model_config = self._get_model_config( + primary_key=te.TASK_ID, + secondary_key=LLMTasks.TEXT_EXTRACTION, + ) + logger.debug( + "Condensing text for extraction using %r for doc from %s", + te.__name__, + doc_for_extraction.attrs.get("source", "unknown source"), + ) + assert self._jsp is not None, "No progress bar set!" + task_id = self._jsp.add_task(te.TASK_DESCRIPTION) + await self.extract_relevant_text(doc_for_extraction, te, model_config) + await self.record_usage() + self._jsp.remove_task(task_id) + + @contextmanager + def _tracked_progress(self): + """Context manager to set up jurisdiction sub-progress bar""" + loc = self.jurisdiction.full_name + with COMPASS_PB.jurisdiction_sub_prog(loc) as self._jsp: + yield + + self._jsp = None + + def _concat_scrape_results(self, doc): + """Concatenate structured data from all parsers""" + data = [doc.attrs.get(p.OUT_LABEL, None) for p in self.PARSERS] + data = [df for df in data if df is not None and not df.empty] + if len(data) == 0: + return None + + data = data[0] if len(data) == 1 else pd.concat(data) + data["source"] = doc.attrs.get("source") + data["ord_year"] = extract_ord_year_from_doc_attrs(doc.attrs) + return data + + def _get_model_config(self, primary_key, secondary_key): + """Get model config: primary_key -> secondary_key -> default""" + if primary_key in self.model_configs: + return self.model_configs[primary_key] + return self.model_configs.get( + secondary_key, self.model_configs[LLMTasks.DEFAULT] + ) + + async def _write_cleaned_text(self, doc): + """Write cleaned text to `clean_files` dir""" + out_fp = await CleanedFileWriter.call(doc, self.jurisdiction.full_name) + doc.attrs["cleaned_fps"] = out_fp + return doc + + +def _validate_in_out_keys(consumers, producers): + """Validate that all IN_LABELs have matching OUT_LABELs""" + in_keys = {} + out_keys = {} + + for producer_class in producers: + out_keys.setdefault(producer_class.OUT_LABEL, []).append( + producer_class + ) + + for consumer_class in chain(consumers): + in_keys.setdefault(consumer_class.IN_LABEL, []).append(consumer_class) + + for in_key, classes in in_keys.items(): + formatted = f"{[cls.__name__ for cls in classes]}" + if in_key not 
in out_keys: + msg = ( + f"One or more processing classes require IN_LABEL " + f"{in_key!r}, which is not produced by any previous " + f"processing class: {formatted}" + ) + raise COMPASSPluginConfigurationError(msg) diff --git a/compass/plugin/ordinance.py b/compass/plugin/ordinance.py new file mode 100644 index 000000000..affa5fb88 --- /dev/null +++ b/compass/plugin/ordinance.py @@ -0,0 +1,279 @@ +"""Helper classes for ordinance plugins""" + +import asyncio +import logging +from abc import ABC, abstractmethod +from warnings import warn + +from elm import ApiBase + +from compass.llm.calling import ( + BaseLLMCaller, + ChatLLMCaller, + StructuredLLMCaller, +) +from compass.utilities.enums import LLMUsageCategory +from compass.utilities.ngrams import convert_text_to_sentence_ngrams +from compass.warn import COMPASSWarning +from compass.utilities.parsing import ( + merge_overlapping_texts, + clean_backticks_from_llm_response, +) +from compass.plugin.interface import ( + BaseHeuristic, + BaseTextCollector, + BaseTextExtractor, + BaseParser, +) + + +logger = logging.getLogger(__name__) + + +class OrdinanceHeuristic(BaseHeuristic, ABC): + """Perform a heuristic check for mention of a technology in text""" + + _GOOD_ACRONYM_CONTEXTS = [ + " {acronym} ", + " {acronym}\n", + " {acronym}.", + "\n{acronym} ", + "\n{acronym}.", + "\n{acronym}\n", + "({acronym} ", + " {acronym})", + ] + + def check(self, text, match_count_threshold=1): + """Check for mention of a tech in text + + This check first strips the text of any tech "look-alike" words + (e.g. "window", "windshield", etc. for "wind" technology). Then, + it checks for particular keywords, acronyms, and phrases that + pertain to the tech in the text. If enough keywords are mentioned + (as dictated by `match_count_threshold`), this check returns + ``True``. + + Parameters + ---------- + text : str + Input text that may or may not mention the technology of + interest. + match_count_threshold : int, optional + Threshold for the number of keyword/acronym/phrase matches. + The total match count must be strictly greater than this + value for the text to pass the heuristic check. By default, + ``1``. + + Returns + ------- + bool + ``True`` if the number of keywords/acronyms/phrases detected + exceeds the `match_count_threshold`.
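A rough usage sketch follows (the subclass and keyword values are purely illustrative, and any additional abstract members required by ``BaseHeuristic`` are omitted)::

    class _WindHeuristicSketch(OrdinanceHeuristic):
        NOT_TECH_WORDS = ["window", "windshield"]
        GOOD_TECH_KEYWORDS = ["wind energy", "turbine"]
        GOOD_TECH_ACRONYMS = ["WECS"]
        GOOD_TECH_PHRASES = ["wind energy conversion system"]

    text = "Each wind turbine and wind energy facility shall ..."
    _WindHeuristicSketch().check(text, match_count_threshold=1)
    # True: "wind energy" and "turbine" give two keyword matches,
    # which is strictly greater than the threshold of 1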
+ """ + heuristics_text = self._convert_to_heuristics_text(text) + total_keyword_matches = self._count_single_keyword_matches( + heuristics_text + ) + total_keyword_matches += self._count_acronym_matches(heuristics_text) + total_keyword_matches += self._count_phrase_matches(heuristics_text) + return total_keyword_matches > match_count_threshold + + def _convert_to_heuristics_text(self, text): + """Convert text for heuristic content parsing""" + heuristics_text = text.casefold() + for word in self.NOT_TECH_WORDS: + heuristics_text = heuristics_text.replace(word, "") + return heuristics_text + + def _count_single_keyword_matches(self, heuristics_text): + """Count number of good tech keywords that appear in text""" + return sum( + keyword in heuristics_text for keyword in self.GOOD_TECH_KEYWORDS + ) + + def _count_acronym_matches(self, heuristics_text): + """Count number of good tech acronyms that appear in text""" + acronym_matches = 0 + for context in self._GOOD_ACRONYM_CONTEXTS: + acronym_keywords = { + context.format(acronym=acronym) + for acronym in self.GOOD_TECH_ACRONYMS + } + acronym_matches = sum( + keyword in heuristics_text for keyword in acronym_keywords + ) + if acronym_matches > 0: + break + return acronym_matches + + def _count_phrase_matches(self, heuristics_text): + """Count number of good tech phrases that appear in text""" + text_ngrams = {} + total = 0 + for phrase in self.GOOD_TECH_PHRASES: + n = len(phrase.split(" ")) + if n <= 1: + msg = ( + "Make sure your GOOD_TECH_PHRASES contain at least 2 " + f"words! Got phrase: {phrase!r}" + ) + warn(msg, COMPASSWarning) + continue + + if n not in text_ngrams: + text_ngrams[n] = set( + convert_text_to_sentence_ngrams(heuristics_text, n) + ) + + test_ngrams = ( # fmt: off + convert_text_to_sentence_ngrams(phrase, n) + + convert_text_to_sentence_ngrams(f"{phrase}s", n) + ) + if any(t in text_ngrams[n] for t in test_ngrams): + total += 1 + + return total + + @property + @abstractmethod + def NOT_TECH_WORDS(self): # noqa: N802 + """:class:`~collections.abc.Iterable`: Not tech keywords""" + raise NotImplementedError + + @property + @abstractmethod + def GOOD_TECH_KEYWORDS(self): # noqa: N802 + """:class:`~collections.abc.Iterable`: Tech keywords""" + raise NotImplementedError + + @property + @abstractmethod + def GOOD_TECH_ACRONYMS(self): # noqa: N802 + """:class:`~collections.abc.Iterable`: Tech acronyms""" + raise NotImplementedError + + @property + @abstractmethod + def GOOD_TECH_PHRASES(self): # noqa: N802 + """:class:`~collections.abc.Iterable`: Tech phrases""" + raise NotImplementedError + + +class OrdinanceTextCollector(StructuredLLMCaller, BaseTextCollector): + """Base class for ordinance text collectors""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._chunks = {} + + @property + def relevant_text(self): + """str: Combined ordinance text from the individual chunks""" + if not self._chunks: + logger.debug( + "No relevant ordinance chunk(s) found in original text", + ) + return "" + + logger.debug( + "Grabbing %d ordinance chunk(s) from original text at these " + "indices: %s", + len(self._chunks), + list(self._chunks), + ) + + text = [self._chunks[ind] for ind in sorted(self._chunks)] + return merge_overlapping_texts(text) + + def _store_chunk(self, parser, chunk_ind): + """Store chunk and its neighbors if it is not already stored""" + for offset in range(1 - parser.num_to_recall, 2): + ind_to_grab = chunk_ind + offset + if ind_to_grab < 0 or ind_to_grab >= len(parser.text_chunks): + 
continue + + self._chunks.setdefault( + ind_to_grab, parser.text_chunks[ind_to_grab] + ) + + +class OrdinanceTextExtractor(BaseTextExtractor, ABC): + """Base implementation for a text extractor""" + + SYSTEM_MESSAGE = ( + "You are a text extraction assistant. Your job is to extract only " + "verbatim, **unmodified** excerpts from provided legal or policy " + "documents. Do not interpret or paraphrase. Do not summarize. Only " + "return exactly copied segments that match the specified scope. If " + "the relevant content appears within a table, return the entire " + "table, including headers and footers, exactly as formatted." + ) + """System message for text extraction LLM calls""" + _USAGE_LABEL = LLMUsageCategory.DOCUMENT_ORDINANCE_SUMMARY + + def __init__(self, llm_caller): + """ + + Parameters + ---------- + llm_caller : LLMCaller + LLM caller instance used to extract ordinance info. + """ + self.llm_caller = llm_caller + + async def _process(self, text_chunks, instructions, is_valid_chunk=None): + """Perform extraction processing""" + if is_valid_chunk is None: + is_valid_chunk = _valid_chunk + + logger.info( + "Extracting summary text from %d text chunks asynchronously...", + len(text_chunks), + ) + logger.debug("Model instructions are:\n%s", instructions) + outer_task_name = asyncio.current_task().get_name() + summaries = [ + asyncio.create_task( + self.llm_caller.call( + sys_msg=self.SYSTEM_MESSAGE, + content=f"{instructions}\n\n# TEXT #\n\n{chunk}", + usage_sub_label=self._USAGE_LABEL, + ), + name=outer_task_name, + ) + for chunk in text_chunks + ] + summary_chunks = await asyncio.gather(*summaries) + summary_chunks = [ + clean_backticks_from_llm_response(chunk) + for chunk in summary_chunks + if is_valid_chunk(chunk) + ] + + text_summary = merge_overlapping_texts(summary_chunks) + logger.debug( + "Final summary contains %d tokens", + ApiBase.count_tokens( + text_summary, + model=self.llm_caller.kwargs.get("model", "gpt-4"), + ), + ) + return text_summary + + +class OrdinanceParser(BaseLLMCaller, BaseParser): + """Base class for parsing structured data""" + + def _init_chat_llm_caller(self, system_message): + """Initialize a ChatLLMCaller instance for the DecisionTree""" + return ChatLLMCaller( + self.llm_service, + system_message=system_message, + usage_tracker=self.usage_tracker, + **self.kwargs, + ) + + +def _valid_chunk(chunk): + """True if chunk has content and is not a "no relevant text" reply""" + return chunk and "no relevant text" not in chunk.lower() diff --git a/compass/scripts/download.py b/compass/scripts/download.py index 33d1bbe22..7d5d3e1e0 100644 --- a/compass/scripts/download.py +++ b/compass/scripts/download.py @@ -16,7 +16,7 @@ ) from elm.web.utilities import filter_documents -from compass.extraction import check_for_ordinance_info, extract_date +from compass.extraction import check_for_relevant_text, extract_date from compass.services.threaded import TempFileCache, TempFileCachePB from compass.validation.location import ( DTreeJurisdictionValidator, @@ -57,9 +57,9 @@ async def download_known_urls( Returns ------- out_docs : list - List of :obj:`~elm.web.document.BaseDocument` instances - containing documents from the URL's, or an empty list if - something went wrong during the retrieval process. + List of BaseDocument instances containing documents from the + URLs, or an empty list if something went wrong during the + retrieval process.
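For reference, the ``IN_LABEL``/``OUT_LABEL`` contract enforced by ``_validate_in_out_keys`` in the plugin helpers above can be sketched with stand-in classes (real plugins would subclass the ``BaseTextExtractor``/``BaseParser`` interfaces and define their other required members)::

    class _ExtractorStub:
        OUT_LABEL = "wind_law_text"

    class _ParserStub:
        IN_LABEL = "wind_law_text"  # produced by _ExtractorStub, so validation passes

    _validate_in_out_keys(consumers=[_ParserStub], producers=[_ExtractorStub])

    class _OrphanParserStub:
        IN_LABEL = "solar_law_text"  # no producer emits this label

    # _validate_in_out_keys(consumers=[_OrphanParserStub], producers=[_ExtractorStub])
    # would raise COMPASSPluginConfigurationError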
Notes ----- @@ -112,9 +112,9 @@ async def load_known_docs(jurisdiction, fps, local_file_loader_kwargs=None): Returns ------- out_docs : list - List of :obj:`~elm.web.document.BaseDocument` instances - containing documents from the paths, or an empty list if - something went wrong during the retrieval process. + List of BaseDocument instances containing documents from the + paths, or an empty list if something went wrong during the + retrieval process. Notes ----- @@ -263,8 +263,8 @@ async def download_jurisdiction_ordinances_from_website( website : str URL of the jurisdiction website to search. heuristic : callable - Callable taking an :class:`elm.web.document.BaseDocument` and - returning ``True`` when the document should be kept. + Callable taking a BaseDocument and returning ``True`` when the + document should be kept. keyword_points : dict Dictionary of keyword points to use for scoring links. Keys are keywords, values are points to assign to links @@ -303,9 +303,9 @@ async def download_jurisdiction_ordinances_from_website( Returns ------- out_docs : list - List of :obj:`~elm.web.document.BaseDocument` instances - containing potential ordinance information, or an empty list if - no ordinance document was found. + List of BaseDocument instances containing potential ordinance + information, or an empty list if no ordinance document was + found. results : list, optional List of crawl4ai results containing metadata about the crawled pages. Only returned when ``return_c4ai_results`` evaluates to @@ -391,8 +391,8 @@ async def download_jurisdiction_ordinances_from_website_compass_crawl( website : str URL of the jurisdiction website to search. heuristic : callable - Callable taking an :class:`elm.web.document.BaseDocument` and - returning ``True`` when the document should be kept. + Callable taking a BaseDocument and returning ``True`` when the + document should be kept. keyword_points : dict Dictionary of keyword points to use for scoring links. Keys are keywords, values are points to assign to links @@ -424,9 +424,9 @@ async def download_jurisdiction_ordinances_from_website_compass_crawl( Returns ------- out_docs : list - List of :obj:`~elm.web.document.BaseDocument` instances - containing potential ordinance information, or an empty list if - no ordinance document was found. + List of BaseDocument instances containing potential ordinance + information, or an empty list if no ordinance document was + found. Notes ----- @@ -533,9 +533,8 @@ async def download_jurisdiction_ordinance_using_search_engine( Returns ------- list or None - List of :obj:`~elm.web.document.BaseDocument` instances possibly - containing ordinance information, or ``None`` if no ordinance - document was found. + List of BaseDocument instances possibly containing ordinance + information, or ``None`` if no ordinance document was found. Notes ----- @@ -589,8 +588,7 @@ async def filter_ordinance_docs( model_configs, heuristic, tech, - ordinance_text_collector_class, - permitted_use_text_collector_class=None, + text_collectors, usage_tracker=None, check_for_correct_jurisdiction=True, ): @@ -598,7 +596,7 @@ Parameters ---------- - docs : sequence of elm.web.document.BaseDocument + docs : sequence of BaseDocument Documents to screen for ordinance content. jurisdiction : Jurisdiction Location objects representing the jurisdiction. @@ -611,10 +609,13 @@ tech : str Technology of interest (e.g. "solar", "wind", etc).
This is used to set up some document validation decision trees. - ordinance_text_collector_class : type - Collector class used to extract ordinance text sections. - permitted_use_text_collector_class : type, optional - Collector class used to extract permitted-use text sections. + text_collectors : iterable + Iterable of text collector classes to run during document + parsing. Each class must implement the + :class:`compass.plugin.interface.BaseTextCollector` interface. + If the document already contains text collected by a given + collector (i.e. the collector's ``OUT_LABEL`` is found in + ``doc.attrs``), that collector will be skipped. usage_tracker : UsageTracker, optional Optional tracker instance to monitor token usage during LLM calls. By default, ``None``. @@ -625,9 +626,8 @@ async def filter_ordinance_docs( Returns ------- list or None - List of :obj:`~elm.web.document.BaseDocument` instances possibly - containing ordinance information, or ``None`` if no ordinance - document was found. + List of BaseDocument instances possibly containing ordinance + information, or ``None`` if no ordinance document was found. Notes ----- @@ -672,13 +672,12 @@ async def filter_ordinance_docs( ) docs = await filter_documents( docs, - validation_coroutine=_contains_ordinances, + validation_coroutine=_contains_relevant_text, task_name=jurisdiction.full_name, model_configs=model_configs, heuristic=heuristic, tech=tech, - ordinance_text_collector_class=ordinance_text_collector_class, - permitted_use_text_collector_class=permitted_use_text_collector_class, + text_collectors=text_collectors, usage_tracker=usage_tracker, ) if not docs: @@ -760,7 +759,7 @@ async def _down_select_docs_correct_jurisdiction( ) -async def _contains_ordinances( +async def _contains_relevant_text( doc, model_configs, usage_tracker=None, **kwargs ): """Determine whether a document contains ordinance information""" @@ -772,22 +771,21 @@ async def _contains_ordinances( "Checking doc for ordinance info (source: %r)...", doc.attrs.get("source", "unknown"), ) - doc = await check_for_ordinance_info( + found_text = await check_for_relevant_text( doc, model_config=model_config, usage_tracker=usage_tracker, **kwargs, ) - contains_ordinances = doc.attrs.get("contains_ord_info", False) - if contains_ordinances: - logger.debug("Detected ordinance info; parsing date...") + if found_text: + logger.debug("Detected relevant text; parsing date...") date_model_config = model_configs.get( LLMTasks.DATE_EXTRACTION, model_configs[LLMTasks.DEFAULT] ) doc = await extract_date( doc, date_model_config, usage_tracker=usage_tracker ) - return contains_ordinances + return found_text def _sort_final_ord_docs(all_ord_docs): diff --git a/compass/scripts/process.py b/compass/scripts/process.py index a311744ad..cbb8f67a0 100644 --- a/compass/scripts/process.py +++ b/compass/scripts/process.py @@ -9,9 +9,9 @@ from contextlib import AsyncExitStack, contextmanager from datetime import datetime, UTC -import pandas as pd from elm.web.utilities import get_redirected_url +from compass.extraction.context import ExtractionContext from compass.scripts.download import ( find_jurisdiction_website, download_known_urls, @@ -19,60 +19,14 @@ download_jurisdiction_ordinance_using_search_engine, download_jurisdiction_ordinances_from_website, download_jurisdiction_ordinances_from_website_compass_crawl, - filter_ordinance_docs, ) from compass.exceptions import COMPASSValueError, COMPASSError -from compass.extraction import ( - extract_ordinance_values, - 
extract_ordinance_text_with_ngram_validation, -) -from compass.extraction.solar import ( - SolarHeuristic, - SolarOrdinanceTextCollector, - SolarOrdinanceTextExtractor, - SolarPermittedUseDistrictsTextCollector, - SolarPermittedUseDistrictsTextExtractor, - StructuredSolarOrdinanceParser, - StructuredSolarPermittedUseDistrictsParser, - SOLAR_QUESTION_TEMPLATES, - BEST_SOLAR_ORDINANCE_WEBSITE_URL_KEYWORDS, -) -from compass.extraction.wind import ( - WindHeuristic, - WindOrdinanceTextCollector, - WindOrdinanceTextExtractor, - WindPermittedUseDistrictsTextCollector, - WindPermittedUseDistrictsTextExtractor, - StructuredWindOrdinanceParser, - StructuredWindPermittedUseDistrictsParser, - WIND_QUESTION_TEMPLATES, - BEST_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS, -) -from compass.extraction.small_wind import ( - SmallWindHeuristic, - SmallWindOrdinanceTextCollector, - SmallWindOrdinanceTextExtractor, - SmallWindPermittedUseDistrictsTextCollector, - SmallWindPermittedUseDistrictsTextExtractor, - StructuredSmallWindOrdinanceParser, - StructuredSmallWindPermittedUseDistrictsParser, - SMALL_WIND_QUESTION_TEMPLATES, - BEST_SMALL_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS, -) -from compass.extraction.water import ( - build_corpus, - extract_water_rights_ordinance_values, - label_docs_no_legal_check, - write_water_rights_data_to_disk, - WaterRightsHeuristic, - WaterRightsTextCollector, - WaterRightsTextExtractor, - StructuredWaterParser, - WATER_RIGHTS_QUESTION_TEMPLATES, - BEST_WATER_RIGHTS_ORDINANCE_WEBSITE_URL_KEYWORDS, -) +from compass.extraction.wind import COMPASSWindExtractor +from compass.extraction.solar import COMPASSSolarExtractor +from compass.extraction.small_wind import COMPASSSmallWindExtractor +from compass.extraction.water.plugin import TexasWaterRightsExtractor from compass.validation.location import JurisdictionWebsiteValidator -from compass.llm import LLMCaller, OpenAIConfig +from compass.llm import OpenAIConfig from compass.services.cpu import ( PDFLoader, OCRPDFLoader, @@ -97,18 +51,15 @@ from compass.utilities import ( LLM_COST_REGISTRY, compile_run_summary_message, - doc_infos_to_db, load_all_jurisdiction_info, load_jurisdictions_from_fp, - num_ordinances_dataframe, - save_db, save_run_meta, Directories, ProcessKwargs, - TechSpec, + compute_total_cost_from_usage, ) from compass.utilities.enums import LLMTasks -from compass.utilities.location import Jurisdiction +from compass.utilities.jurisdictions import jurisdictions_from_df from compass.utilities.logs import ( LocationFileLog, LogListener, @@ -121,46 +72,12 @@ logger = logging.getLogger(__name__) -EXCLUDE_FROM_ORD_DOC_CHECK = { - # if doc only contains these, it's not good enough to count as an - # ordinance. 
Note that prohibitions are explicitly not on this list - "color", - "decommissioning", - "lighting", - "visual impact", - "glare", - "repowering", - "fencing", - "climbing prevention", - "signage", - "soil", - "primary use districts", - "special use districts", - "accessory use districts", -} -_TEXT_EXTRACTION_TASKS = { - WindOrdinanceTextExtractor: "Extracting wind ordinance text", - WindPermittedUseDistrictsTextExtractor: ( - "Extracting wind permitted use text" - ), - SolarOrdinanceTextExtractor: "Extracting solar ordinance text", - SolarPermittedUseDistrictsTextExtractor: ( - "Extracting solar permitted use text" - ), - SmallWindOrdinanceTextExtractor: ("Extracting small wind ordinance text"), - SmallWindPermittedUseDistrictsTextExtractor: ( - "Extracting small wind permitted use text" - ), - WaterRightsTextExtractor: "Extracting water rights ordinance text", +EXTRACTION_REGISTRY = { + COMPASSWindExtractor.IDENTIFIER.casefold(): COMPASSWindExtractor, + COMPASSSolarExtractor.IDENTIFIER.casefold(): COMPASSSolarExtractor, + COMPASSSmallWindExtractor.IDENTIFIER.casefold(): COMPASSSmallWindExtractor, + TexasWaterRightsExtractor.IDENTIFIER.casefold(): TexasWaterRightsExtractor, } -_JUR_COLS = [ - "Jurisdiction Type", - "State", - "County", - "Subdivision", - "FIPS", - "Website", -] MAX_CONCURRENT_SEARCH_ENGINE_QUERIES = 10 @@ -219,7 +136,7 @@ async def process_jurisdictions_with_openai( # noqa: PLR0917, PLR0913 CSV file, all downloaded ordinance documents (PDFs and HTML), usage metadata, and default subdirectories for logs and intermediate outputs (unless otherwise specified). - tech : {"wind", "solar", "small wind"} + tech : {"wind", "solar", "small wind", "tx water rights"} Label indicating which technology type is being processed. jurisdiction_fp : path-like Path to a CSV file specifying the jurisdictions to process. @@ -645,9 +562,12 @@ def tpe_kwargs(self): return _configure_thread_pool_kwargs(self.process_kwargs.tpe_kwargs) @cached_property - def tech_specs(self): - """TechSpec: TechSpec for the current technology""" - return _compile_tech_specs(self.tech) + def extractor_class(self): + """obj: Extractor class for the specified technology""" + if self.tech.casefold() not in EXTRACTION_REGISTRY: + msg = f"Unknown tech input: {self.tech}" + raise COMPASSValueError(msg) + return EXTRACTION_REGISTRY[self.tech.casefold()] @cached_property def _base_services(self): @@ -704,25 +624,25 @@ async def run(self, jurisdiction_fp): terminal and may include color-coded cost information if the terminal supports it. 
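For orientation, the user-supplied technology label resolves to an extractor class through ``EXTRACTION_REGISTRY`` (see ``extractor_class`` above), and a hypothetical plugin would register itself the same way. A minimal sketch, assuming the registered ``IDENTIFIER`` strings match the documented tech labels and with ``MyTechExtractor`` as a purely hypothetical class::

    extractor_cls = EXTRACTION_REGISTRY["solar"]  # -> COMPASSSolarExtractor

    # a new plugin hooks in by registering its casefolded IDENTIFIER
    EXTRACTION_REGISTRY[MyTechExtractor.IDENTIFIER.casefold()] = MyTechExtractor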
""" - jurisdictions = _load_jurisdictions_to_process(jurisdiction_fp) + jurisdictions_df = _load_jurisdictions_to_process(jurisdiction_fp) - num_jurisdictions = len(jurisdictions) + num_jurisdictions = len(jurisdictions_df) COMPASS_PB.create_main_task(num_jurisdictions=num_jurisdictions) start_date = datetime.now(UTC) - doc_infos, total_cost = await self._run_all(jurisdictions) + doc_infos, total_cost = await self._run_all(jurisdictions_df) doc_infos = [ di for di in doc_infos if di is not None and di.get("ord_db_fp") is not None ] - if self.tech_specs.save_db_callback is not None: - num_docs_found = self.tech_specs.save_db_callback( + if doc_infos: + num_docs_found = self.extractor_class.save_structured_data( doc_infos, self.dirs.out ) else: - num_docs_found = _write_data_to_disk(doc_infos, self.dirs.out) + num_docs_found = 0 total_time = save_run_meta( self.dirs, @@ -746,33 +666,24 @@ async def run(self, jurisdiction_fp): ) return run_msg - async def _run_all(self, jurisdictions): + async def _run_all(self, jurisdictions_df): """Process all jurisdictions while required services run""" services = [model.llm_service for model in set(self.models.values())] services += self._base_services _ = self.file_loader_kwargs # init loader kwargs once _ = self.local_file_loader_kwargs # init local loader kwargs once - logger.info("Processing %d jurisdiction(s)", len(jurisdictions)) + logger.info("Processing %d jurisdiction(s)", len(jurisdictions_df)) async with RunningAsyncServices(services): tasks = [] - for __, row in jurisdictions.iterrows(): - jur_type, state, county, sub, fips, website = row[_JUR_COLS] - jurisdiction = Jurisdiction( - subdivision_type=jur_type, - state=state, - county=county, - subdivision_name=sub, - code=fips, - ) + for jurisdiction in jurisdictions_from_df(jurisdictions_df): usage_tracker = UsageTracker( jurisdiction.full_name, usage_from_response ) task = asyncio.create_task( self._processed_jurisdiction_info_with_pb( jurisdiction, - website, - self.known_local_docs.get(fips), - self.known_doc_urls.get(fips), + self.known_local_docs.get(jurisdiction.code), + self.known_doc_urls.get(jurisdiction.code), usage_tracker=usage_tracker, ), name=jurisdiction.full_name, @@ -793,23 +704,30 @@ async def _processed_jurisdiction_info_with_pb( jurisdiction, *args, **kwargs ) - async def _processed_jurisdiction_info(self, *args, **kwargs): + async def _processed_jurisdiction_info( + self, jurisdiction, *args, **kwargs + ): """Convert processed document to minimal metadata""" - doc = await self._process_jurisdiction_with_logging(*args, **kwargs) + extraction_context = await self._process_jurisdiction_with_logging( + jurisdiction, *args, **kwargs + ) - if doc is None or isinstance(doc, Exception): + if extraction_context is None or isinstance( + extraction_context, Exception + ): return None - keys = ["source", "date", "jurisdiction", "ord_db_fp"] - doc_info = {key: doc.attrs.get(key) for key in keys} + doc_info = { + "jurisdiction": jurisdiction, + "ord_db_fp": extraction_context.attrs.get("ord_db_fp"), + } logger.debug("Saving the following doc info:\n%s", doc_info) return doc_info async def _process_jurisdiction_with_logging( self, jurisdiction, - jurisdiction_website, known_local_docs=None, known_doc_urls=None, usage_tracker=None, @@ -823,7 +741,11 @@ async def _process_jurisdiction_with_logging( ): task = asyncio.create_task( _SingleJurisdictionRunner( - self.tech_specs, + self.extractor_class( + jurisdiction=jurisdiction, + model_configs=self.models, + usage_tracker=usage_tracker, + ), 
jurisdiction, self.models, self.web_search_params, @@ -834,7 +756,6 @@ async def _process_jurisdiction_with_logging( browser_semaphore=self.browser_semaphore, crawl_semaphore=self.crawl_semaphore, search_engine_semaphore=self.search_engine_semaphore, - jurisdiction_website=jurisdiction_website, perform_se_search=self.perform_se_search, perform_website_search=self.perform_website_search, usage_tracker=usage_tracker, @@ -842,16 +763,16 @@ async def _process_jurisdiction_with_logging( name=jurisdiction.full_name, ) try: - doc, *__ = await asyncio.gather(task) + extraction_context, *__ = await asyncio.gather(task) except KeyboardInterrupt: raise except Exception as e: msg = "Encountered error of type %r while processing %s:" err_type = type(e) logger.exception(msg, err_type, jurisdiction.full_name) - doc = None + extraction_context = None - return doc + return extraction_context class _SingleJurisdictionRunner: @@ -859,7 +780,7 @@ class _SingleJurisdictionRunner: def __init__( # noqa: PLR0913 self, - tech_specs, + extractor, jurisdiction, models, web_search_params, @@ -871,12 +792,11 @@ def __init__( # noqa: PLR0913 browser_semaphore=None, crawl_semaphore=None, search_engine_semaphore=None, - jurisdiction_website=None, perform_se_search=True, perform_website_search=True, usage_tracker=None, ): - self.tech_specs = tech_specs + self.extractor = extractor self.jurisdiction = jurisdiction self.models = models self.web_search_params = web_search_params @@ -888,9 +808,9 @@ def __init__( # noqa: PLR0913 self.crawl_semaphore = crawl_semaphore self.search_engine_semaphore = search_engine_semaphore self.usage_tracker = usage_tracker - self.jurisdiction_website = jurisdiction_website self.perform_se_search = perform_se_search self.perform_website_search = perform_website_search + self.jurisdiction_website = jurisdiction.website_url self.validate_user_website_input = True self._jsp = None @@ -915,29 +835,32 @@ async def run(self): Returns ------- - elm.web.document.BaseDocument or None + BaseDocument or None Document containing ordinance information, or ``None`` when no valid ordinance content was identified. 
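At a high level, this runner hands candidate documents to the per-jurisdiction extractor instance built in ``_process_jurisdiction_with_logging`` above; the calls made further below reduce to roughly this sketch (only the argument and method names visible in this changeset are used, everything else is an assumption)::

    extractor = COMPASSWindExtractor(
        jurisdiction=jurisdiction,
        model_configs=models,
        usage_tracker=usage_tracker,
    )
    context = await extractor.filter_docs(
        ExtractionContext(documents=docs),
        need_jurisdiction_verification=True,
    )
    if context:
        context = await extractor.parse_docs_for_structured_data(context)
    await extractor.record_usage()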
""" start_time = time.monotonic() - doc = None + extraction_context = None logger.info( "Kicking off processing for jurisdiction: %s", self.jurisdiction.full_name, ) try: - doc = await self._run() + extraction_context = await self._run() finally: - await self._record_usage() + await self.extractor.record_usage() await _record_jurisdiction_info( - self.jurisdiction, doc, start_time, self.usage_tracker + self.jurisdiction, + extraction_context, + start_time, + self.usage_tracker, ) logger.info( "Completed processing for jurisdiction: %s", self.jurisdiction.full_name, ) - return doc + return extraction_context async def _run(self): """Search for documents and parse them for ordinances""" @@ -946,22 +869,22 @@ async def _run(self): "Checking local docs for jurisdiction: %s", self.jurisdiction.full_name, ) - doc = await self._try_find_ordinances( + extraction_context = await self._try_find_ordinances( method=self._load_known_local_documents, ) - if doc is not None: - return doc + if extraction_context is not None: + return extraction_context if self.known_doc_urls: logger.debug( "Checking known URLs for jurisdiction: %s", self.jurisdiction.full_name, ) - doc = await self._try_find_ordinances( + extraction_context = await self._try_find_ordinances( method=self._download_known_url_documents, ) - if doc is not None: - return doc + if extraction_context is not None: + return extraction_context if self.perform_se_search: logger.debug( @@ -969,36 +892,41 @@ async def _run(self): "jurisdiction: %s", self.jurisdiction.full_name, ) - doc = await self._try_find_ordinances( + extraction_context = await self._try_find_ordinances( method=self._find_documents_using_search_engine, ) - if doc is not None: - return doc + if extraction_context is not None: + return extraction_context if self.perform_website_search: logger.debug( "Collecting documents from the jurisdiction website for: %s", self.jurisdiction.full_name, ) - doc = await self._try_find_ordinances( + extraction_context = await self._try_find_ordinances( method=self._find_documents_from_website, ) - if doc is not None: - return doc + if extraction_context is not None: + return extraction_context return None async def _try_find_ordinances(self, method, *args, **kwargs): """Execute a retrieval method and parse resulting documents""" - docs = await method(*args, **kwargs) - if docs is None: + extraction_context = await method(*args, **kwargs) + if extraction_context is None: return None COMPASS_PB.update_jurisdiction_task( self.jurisdiction.full_name, description="Extracting structured data...", ) - return await self._parse_docs_for_ordinances(docs) + context = await self.extractor.parse_docs_for_structured_data( + extraction_context + ) + await self._write_out_structured_data(extraction_context) + logger.debug("Final extraction context:\n%s", context) + return context async def _load_known_local_documents(self): """Load ordinance documents from known local file paths""" @@ -1015,20 +943,17 @@ async def _load_known_local_documents(self): _add_known_doc_attrs_to_all_docs( docs, self.known_local_docs, key="source_fp" ) - docs = await self._filter_down_docs( - docs, check_for_correct_jurisdiction=False + extraction_context = await self._filter_docs( + docs, need_jurisdiction_verification=False ) - if not docs: + if not extraction_context: return None - for doc in docs: - doc.attrs["jurisdiction"] = self.jurisdiction - doc.attrs["jurisdiction_name"] = self.jurisdiction.full_name - doc.attrs["jurisdiction_website"] = None - doc.attrs["compass_crawl"] = False 
+ extraction_context.attrs["jurisdiction_website"] = None + extraction_context.attrs["compass_crawl"] = False - await self._record_usage() - return docs + await self.extractor.record_usage() + return extraction_context async def _download_known_url_documents(self): """Download ordinance documents from pre-specified URLs""" @@ -1046,25 +971,22 @@ async def _download_known_url_documents(self): _add_known_doc_attrs_to_all_docs( docs, self.known_doc_urls, key="source" ) - docs = await self._filter_down_docs( - docs, check_for_correct_jurisdiction=False + extraction_context = await self._filter_docs( + docs, need_jurisdiction_verification=False ) - if not docs: + if not extraction_context: return None - for doc in docs: - doc.attrs["jurisdiction"] = self.jurisdiction - doc.attrs["jurisdiction_name"] = self.jurisdiction.full_name - doc.attrs["jurisdiction_website"] = None - doc.attrs["compass_crawl"] = False + extraction_context.attrs["jurisdiction_website"] = None + extraction_context.attrs["compass_crawl"] = False - await self._record_usage() - return docs + await self.extractor.record_usage() + return extraction_context async def _find_documents_using_search_engine(self): """Search the web for ordinance docs using search engines""" docs = await download_jurisdiction_ordinance_using_search_engine( - self.tech_specs.questions, + self.extractor.QUESTION_TEMPLATES, self.jurisdiction, num_urls=self.web_search_params.num_urls_to_check_per_jurisdiction, file_loader_kwargs=self.file_loader_kwargs, @@ -1073,20 +995,17 @@ async def _find_documents_using_search_engine(self): url_ignore_substrings=self.web_search_params.url_ignore_substrings, **self.web_search_params.se_kwargs, ) - docs = await self._filter_down_docs( - docs, check_for_correct_jurisdiction=True + extraction_context = await self._filter_docs( + docs, need_jurisdiction_verification=True ) - if not docs: + if not extraction_context: return None - for doc in docs: - doc.attrs["jurisdiction"] = self.jurisdiction - doc.attrs["jurisdiction_name"] = self.jurisdiction.full_name - doc.attrs["jurisdiction_website"] = None - doc.attrs["compass_crawl"] = False + extraction_context.attrs["jurisdiction_website"] = None + extraction_context.attrs["compass_crawl"] = False - await self._record_usage() - return docs + await self.extractor.record_usage() + return extraction_context async def _find_documents_from_website(self): """Search the jurisdiction website for ordinance documents""" @@ -1099,24 +1018,23 @@ async def _find_documents_from_website(self): return None self.jurisdiction_website = website - docs, scrape_results = await self._try_elm_crawl() + extraction_context, scrape_results = await self._try_elm_crawl() found_with_compass_crawl = False - if not docs: - docs = await self._try_compass_crawl(scrape_results) + if not extraction_context: + extraction_context = await self._try_compass_crawl(scrape_results) found_with_compass_crawl = True - if not docs: + if not extraction_context: return None - for doc in docs: - doc.attrs["jurisdiction"] = self.jurisdiction - doc.attrs["jurisdiction_name"] = self.jurisdiction.full_name - doc.attrs["jurisdiction_website"] = self.jurisdiction_website - doc.attrs["compass_crawl"] = found_with_compass_crawl + extraction_context.attrs["jurisdiction_website"] = ( + self.jurisdiction_website + ) + extraction_context.attrs["compass_crawl"] = found_with_compass_crawl - await self._record_usage() - return docs + await self.extractor.record_usage() + return extraction_context async def 
_validate_jurisdiction_website(self): """Validate a user-supplied jurisdiction website URL""" @@ -1175,18 +1093,18 @@ async def _try_elm_crawl(self): ) out = await download_jurisdiction_ordinances_from_website( self.jurisdiction_website, - heuristic=self.tech_specs.heuristic, - keyword_points=self.tech_specs.website_url_keyword_points, + heuristic=self.extractor.heuristic, + keyword_points=self.extractor.WEBSITE_KEYWORDS, file_loader_kwargs=self.file_loader_kwargs_no_ocr, crawl_semaphore=self.crawl_semaphore, pb_jurisdiction_name=self.jurisdiction.full_name, return_c4ai_results=True, ) docs, scrape_results = out - docs = await self._filter_down_docs( - docs, check_for_correct_jurisdiction=True + extraction_context = await self._filter_docs( + docs, need_jurisdiction_verification=True ) - return docs, scrape_results + return extraction_context, scrape_results async def _try_compass_crawl(self, scrape_results): """Crawl the jurisdiction website using the COMPASS crawler""" @@ -1196,305 +1114,44 @@ async def _try_compass_crawl(self, scrape_results): docs = ( await download_jurisdiction_ordinances_from_website_compass_crawl( self.jurisdiction_website, - heuristic=self.tech_specs.heuristic, - keyword_points=self.tech_specs.website_url_keyword_points, + heuristic=self.extractor.heuristic, + keyword_points=self.extractor.WEBSITE_KEYWORDS, file_loader_kwargs=self.file_loader_kwargs_no_ocr, already_visited=checked_urls, crawl_semaphore=self.crawl_semaphore, pb_jurisdiction_name=self.jurisdiction.full_name, ) ) - return await self._filter_down_docs( - docs, check_for_correct_jurisdiction=True + return await self._filter_docs( + docs, need_jurisdiction_verification=True ) - async def _filter_down_docs(self, docs, check_for_correct_jurisdiction): - """Filter down candidate documents before parsing""" - if docs and self.tech_specs.post_download_docs_hook is not None: - logger.debug( - "%d document(s) passed in to `post_download_docs_hook` for " - "%s\n\t- %s", - len(docs), - self.jurisdiction.full_name, - "\n\t- ".join( - [doc.attrs.get("source", "Unknown source") for doc in docs] - ), - ) - - docs = await self.tech_specs.post_download_docs_hook( - docs, - jurisdiction=self.jurisdiction, - model_configs=self.models, - usage_tracker=self.usage_tracker, - ) - logger.info( - "%d document(s) remaining after `post_download_docs_hook` for " - "%s\n\t- %s", - len(docs), - self.jurisdiction.full_name, - "\n\t- ".join( - [doc.attrs.get("source", "Unknown source") for doc in docs] - ), - ) - - docs = await filter_ordinance_docs( - docs, - self.jurisdiction, - self.models, - heuristic=self.tech_specs.heuristic, - tech=self.tech_specs.name, - ordinance_text_collector_class=( - self.tech_specs.ordinance_text_collector - ), - permitted_use_text_collector_class=( - self.tech_specs.permitted_use_text_collector - ), - usage_tracker=self.usage_tracker, - check_for_correct_jurisdiction=check_for_correct_jurisdiction, - ) - - if docs and self.tech_specs.post_filter_docs_hook is not None: - logger.debug( - "Passing %d document(s) in to `post_filter_docs_hook` ", - len(docs), - ) - docs = await self.tech_specs.post_filter_docs_hook( - docs, - jurisdiction=self.jurisdiction, - model_configs=self.models, - usage_tracker=self.usage_tracker, - ) - logger.info( - "%d document(s) remaining after `post_filter_docs_hook` for " - "%s\n\t- %s", - len(docs), - self.jurisdiction.full_name, - "\n\t- ".join( - [doc.attrs.get("source", "Unknown source") for doc in docs] - ), - ) - - return docs or None - - async def 
_parse_docs_for_ordinances(self, docs): - """Parse candidate documents in order until ordinances found""" - for possible_ord_doc in docs: - doc = await self._try_extract_all_ordinances(possible_ord_doc) - ord_count = self._get_ordinance_count(doc) - if ord_count > 0: - doc = await _move_files(doc) - logger.info( - "%d ordinance value(s) found in doc from %s for %s. " - "Outputs are here: '%s'", - ord_count, - possible_ord_doc.attrs.get("source", "unknown source"), - self.jurisdiction.full_name, - doc.attrs["ord_db_fp"], - ) - return doc - - logger.debug("No ordinances found; searched %d docs", len(docs)) - return None - - def _get_ordinance_count(self, doc): - """Get the number of ordinances extracted from a document""" - if doc is None or doc.attrs.get("ordinance_values") is None: - return 0 - - ord_df = doc.attrs["ordinance_values"] - - if self.tech_specs.num_ordinances_in_df_callback is not None: - return self.tech_specs.num_ordinances_in_df_callback(ord_df) - - return num_ordinances_dataframe( - ord_df, exclude_features=EXCLUDE_FROM_ORD_DOC_CHECK - ) - - async def _try_extract_all_ordinances(self, possible_ord_doc): - """Extract both ordinance values and permitted-use districts""" - with self._tracked_progress(): - tasks = [ - asyncio.create_task( - self._try_extract_ordinances(possible_ord_doc, **kwargs), - name=self.jurisdiction.full_name, - ) - for kwargs in self._extraction_task_kwargs - ] - - docs = await asyncio.gather(*tasks) - - return _concat_scrape_results(docs[0]) + async def _filter_docs(self, docs, need_jurisdiction_verification): + if not docs: + return None - @property - def _extraction_task_kwargs(self): - """list: Dictionaries describing extraction task config""" - tasks = [ - { - "extractor_class": self.tech_specs.ordinance_text_extractor, - "original_text_key": "ordinance_text", - "cleaned_text_key": "cleaned_ordinance_text", - "text_model": self.models.get( - LLMTasks.ORDINANCE_TEXT_EXTRACTION, - self.models[LLMTasks.DEFAULT], - ), - "parser_class": self.tech_specs.structured_ordinance_parser, - "out_key": "ordinance_values", - "value_model": self.models.get( - LLMTasks.ORDINANCE_VALUE_EXTRACTION, - self.models[LLMTasks.DEFAULT], - ), - } - ] - if ( - self.tech_specs.permitted_use_text_extractor is None - or self.tech_specs.structured_permitted_use_parser is None - ): - return tasks - - tasks.append( - { - "extractor_class": ( - self.tech_specs.permitted_use_text_extractor - ), - "original_text_key": "permitted_use_text", - "cleaned_text_key": "districts_text", - "text_model": self.models.get( - LLMTasks.PERMITTED_USE_TEXT_EXTRACTION, - self.models[LLMTasks.DEFAULT], - ), - "parser_class": ( - self.tech_specs.structured_permitted_use_parser - ), - "out_key": "permitted_district_values", - "value_model": self.models.get( - LLMTasks.PERMITTED_USE_VALUE_EXTRACTION, - self.models[LLMTasks.DEFAULT], - ), - } + extraction_context = ExtractionContext(documents=docs) + return await self.extractor.filter_docs( + extraction_context, + need_jurisdiction_verification=need_jurisdiction_verification, ) - return tasks - async def _try_extract_ordinances( - self, - possible_ord_doc, - extractor_class, - original_text_key, - cleaned_text_key, - parser_class, - out_key, - text_model, - value_model, - ): - """Apply a single extractor and parser to legal text""" - logger.debug( - "Checking for ordinances in doc from %s", - possible_ord_doc.attrs.get("source", "unknown source"), - ) - assert self._jsp is not None, "No progress bar set!" 
- task_id = self._jsp.add_task(_TEXT_EXTRACTION_TASKS[extractor_class]) - doc = await _extract_ordinance_text( - possible_ord_doc, - extractor_class=extractor_class, - original_text_key=original_text_key, - usage_tracker=self.usage_tracker, - model_config=text_model, - ) - await self._record_usage() - self._jsp.remove_task(task_id) - if self.tech_specs.extract_ordinances_callback is None: - out = await _extract_ordinances_from_text( - doc, - parser_class=parser_class, - text_key=cleaned_text_key, - out_key=out_key, - usage_tracker=self.usage_tracker, - model_config=value_model, - ) - else: - out = await self.tech_specs.extract_ordinances_callback( - doc, - parser_class=parser_class, - text_key=cleaned_text_key, - out_key=out_key, - usage_tracker=self.usage_tracker, - model_config=value_model, - ) - await self._record_usage() - return out - - async def _record_usage(self): - """Persist usage tracking data when a tracker is available""" - if self.usage_tracker is None: + async def _write_out_structured_data(self, extraction_context): + """Write cleaned text to `jurisdiction_dbs` dir""" + if extraction_context.attrs.get("structured_data") is None: return - total_usage = await UsageUpdater.call(self.usage_tracker) - total_cost = _compute_total_cost_from_usage(total_usage) - COMPASS_PB.update_total_cost(total_cost, replace=True) - - -def _compile_tech_specs(tech): - """Compile `TechSpec` tuple based on the user `tech` input""" - if tech.casefold() == "wind": - return TechSpec( - "wind", - WIND_QUESTION_TEMPLATES, - WindHeuristic(), - WindOrdinanceTextCollector, - WindOrdinanceTextExtractor, - WindPermittedUseDistrictsTextCollector, - WindPermittedUseDistrictsTextExtractor, - StructuredWindOrdinanceParser, - StructuredWindPermittedUseDistrictsParser, - BEST_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS, - ) - if tech.casefold() == "solar": - return TechSpec( - "solar", - SOLAR_QUESTION_TEMPLATES, - SolarHeuristic(), - SolarOrdinanceTextCollector, - SolarOrdinanceTextExtractor, - SolarPermittedUseDistrictsTextCollector, - SolarPermittedUseDistrictsTextExtractor, - StructuredSolarOrdinanceParser, - StructuredSolarPermittedUseDistrictsParser, - BEST_SOLAR_ORDINANCE_WEBSITE_URL_KEYWORDS, - ) - if tech.casefold() == "small wind": - return TechSpec( - "small wind", - SMALL_WIND_QUESTION_TEMPLATES, - SmallWindHeuristic(), - SmallWindOrdinanceTextCollector, - SmallWindOrdinanceTextExtractor, - SmallWindPermittedUseDistrictsTextCollector, - SmallWindPermittedUseDistrictsTextExtractor, - StructuredSmallWindOrdinanceParser, - StructuredSmallWindPermittedUseDistrictsParser, - BEST_SMALL_WIND_ORDINANCE_WEBSITE_URL_KEYWORDS, - ) + out_fn = extraction_context.attrs.get("out_data_fn") + if out_fn is None: + out_fn = f"{self.jurisdiction.full_name} Ordinances.csv" - if tech.casefold() == "water rights": - return TechSpec( - "water rights", - WATER_RIGHTS_QUESTION_TEMPLATES, - WaterRightsHeuristic(), - WaterRightsTextCollector, - WaterRightsTextExtractor, - None, - None, - StructuredWaterParser, - None, - BEST_WATER_RIGHTS_ORDINANCE_WEBSITE_URL_KEYWORDS, - label_docs_no_legal_check, - build_corpus, - extract_water_rights_ordinance_values, - len, - write_water_rights_data_to_disk, + out_fp = await OrdDBFileWriter.call(extraction_context, out_fn) + logger.info( + "Structured data for %s stored here: '%s'", + self.jurisdiction.full_name, + out_fp, ) - - msg = f"Unknown tech input: {tech}" - raise COMPASSValueError(msg) + extraction_context.attrs["ord_db_fp"] = out_fp def _setup_main_logging(log_dir, level, listener, 
keep_async_logs): @@ -1521,7 +1178,8 @@ def _log_exec_info(called_args, steps): log_versions(logger) logger.info( - "Using the following processing step(s):\n\t%s", " -> ".join(steps) + "Using the following document acquisition step(s):\n\t%s", + " -> ".join(steps), ) normalized_args = convert_paths_to_strings(called_args) @@ -1629,78 +1287,14 @@ def _configure_file_loader_kwargs(file_loader_kwargs): return file_loader_kwargs -async def _extract_ordinance_text( - doc, extractor_class, original_text_key, usage_tracker, model_config -): - """Extract text pertaining to ordinance of interest""" - llm_caller = LLMCaller( - llm_service=model_config.llm_service, - usage_tracker=usage_tracker, - **model_config.llm_call_kwargs, - ) - extractor = extractor_class(llm_caller) - doc = await extract_ordinance_text_with_ngram_validation( - doc, - model_config.text_splitter, - extractor, - original_text_key=original_text_key, - ) - return await _write_cleaned_text(doc) - - -async def _extract_ordinances_from_text( - doc, parser_class, text_key, out_key, usage_tracker, model_config +async def _record_jurisdiction_info( + loc, extraction_context, start_time, usage_tracker ): - """Extract values from ordinance text""" - parser = parser_class( - llm_service=model_config.llm_service, - usage_tracker=usage_tracker, - **model_config.llm_call_kwargs, - ) - logger.info("Extracting %s...", out_key.replace("_", " ")) - return await extract_ordinance_values( - doc, parser, text_key=text_key, out_key=out_key - ) - - -async def _move_files(doc): - """Move files to output folders, if applicable""" - doc = await _move_file_to_out_dir(doc) - return await _write_ord_db(doc) - - -async def _move_file_to_out_dir(doc): - """Move PDF or HTML text file to output directory""" - out_fp = await FileMover.call(doc) - doc.attrs["out_fp"] = out_fp - return doc - - -async def _write_cleaned_text(doc): - """Write cleaned text to `clean_files` dir""" - out_fp = await CleanedFileWriter.call(doc) - doc.attrs["cleaned_fps"] = out_fp - return doc - - -async def _write_ord_db(doc): - """Write cleaned text to `jurisdiction_dbs` dir""" - out_fp = await OrdDBFileWriter.call(doc) - doc.attrs["ord_db_fp"] = out_fp - return doc - - -def _write_data_to_disk(doc_infos, out_dir): - """Write extracted data to disk""" - db, num_docs_found = doc_infos_to_db(doc_infos) - save_db(db, out_dir) - return num_docs_found - - -async def _record_jurisdiction_info(loc, doc, start_time, usage_tracker): """Record info about jurisdiction""" seconds_elapsed = time.monotonic() - start_time - await JurisdictionUpdater.call(loc, doc, seconds_elapsed, usage_tracker) + await JurisdictionUpdater.call( + loc, extraction_context, seconds_elapsed, usage_tracker + ) def _setup_pytesseract(exe_fp): @@ -1711,52 +1305,13 @@ def _setup_pytesseract(exe_fp): pytesseract.pytesseract.tesseract_cmd = exe_fp -def _concat_scrape_results(doc): - data = [ - doc.attrs.get(key, None) - for key in ["ordinance_values", "permitted_district_values"] - ] - data = [df for df in data if df is not None and not df.empty] - if len(data) == 0: - return doc - - if len(data) == 1: - doc.attrs["scraped_values"] = data[0] - return doc - - doc.attrs["scraped_values"] = pd.concat(data) - return doc - - async def _compute_total_cost(): """Compute total cost from tracked usage""" total_usage = await UsageUpdater.call(None) if not total_usage: return 0 - return _compute_total_cost_from_usage(total_usage) - - -def _compute_total_cost_from_usage(tracked_usage): - """Compute total cost from total tracked usage""" 
- - total_cost = 0 - for usage in tracked_usage.values(): - totals = usage.get("tracker_totals", {}) - for model, total_usage in totals.items(): - model_costs = LLM_COST_REGISTRY.get(model, {}) - total_cost += ( - total_usage.get("prompt_tokens", 0) - / 1e6 - * model_costs.get("prompt", 0) - ) - total_cost += ( - total_usage.get("response_tokens", 0) - / 1e6 - * model_costs.get("response", 0) - ) - - return total_cost + return compute_total_cost_from_usage(total_usage) def _add_known_doc_attrs_to_all_docs(docs, doc_infos, key): diff --git a/compass/services/openai.py b/compass/services/openai.py index c8f080490..a089d6cc9 100644 --- a/compass/services/openai.py +++ b/compass/services/openai.py @@ -8,7 +8,7 @@ from compass.services.base import LLMService from compass.services.usage import TimeBoundedUsageTracker -from compass.utilities import LLM_COST_REGISTRY +from compass.utilities import cost_for_model from compass.utilities.enums import LLMUsageCategory from compass.pb import COMPASS_PB @@ -202,16 +202,12 @@ def _update_pb_cost(self, response): if response is None: return - model_costs = LLM_COST_REGISTRY.get(self.model_name, {}) - prompt_cost = ( - response.usage.prompt_tokens / 1e6 * model_costs.get("prompt", 0) + response_cost = cost_for_model( + self.model_name, + response.usage.prompt_tokens, + response.usage.completion_tokens, ) - response_cost = ( - response.usage.completion_tokens - / 1e6 - * model_costs.get("response", 0) - ) - COMPASS_PB.update_total_cost(prompt_cost + response_cost) + COMPASS_PB.update_total_cost(response_cost) @async_retry_with_exponential_backoff( base_delay=1, diff --git a/compass/services/threaded.py b/compass/services/threaded.py index c5c79bf62..c32e53ba9 100644 --- a/compass/services/threaded.py +++ b/compass/services/threaded.py @@ -18,10 +18,7 @@ from compass import COMPASS_DEBUG_LEVEL from compass.services.base import Service -from compass.utilities import ( - LLM_COST_REGISTRY, - num_ordinances_in_doc, -) +from compass.utilities import compute_cost_from_totals from compass.pb import COMPASS_PB @@ -53,7 +50,7 @@ def _compute_sha256(file_path): return f"sha256:{m.hexdigest()}" -def _move_file(doc, out_dir): +def _move_file(doc, out_dir, out_fn=None): """Move a file from a temp directory to an output directory""" cached_fp = doc.attrs.get("cache_fn") if cached_fp is None: @@ -61,7 +58,7 @@ def _move_file(doc, out_dir): cached_fp = Path(cached_fp) date = datetime.now().strftime("%Y_%m_%d") - out_fn = doc.attrs.get("jurisdiction_name", cached_fp.stem) + out_fn = out_fn or cached_fp.stem out_fn = out_fn.replace(",", "").replace(" ", "_") out_fn = f"{out_fn}_downloaded_{date}" if not out_fn.endswith(cached_fp.suffix): @@ -72,9 +69,8 @@ def _move_file(doc, out_dir): return out_fp -def _write_cleaned_file(doc, out_dir): +def _write_cleaned_file(doc, out_dir, jurisdiction_name=None): """Write cleaned ordinance text to directory""" - jurisdiction_name = doc.attrs.get("jurisdiction_name") if jurisdiction_name is None: return None @@ -83,7 +79,9 @@ def _write_cleaned_file(doc, out_dir): _write_interim_cleaned_files(doc, out_dir, jurisdiction_name) key_to_fp = { - "cleaned_ordinance_text": f"{jurisdiction_name} Ordinance Summary.txt", + "cleaned_text_for_extraction": ( + f"{jurisdiction_name} Cleaned Text.txt" + ), "districts_text": f"{jurisdiction_name} Districts.txt", } out_paths = [] @@ -102,7 +100,7 @@ def _write_cleaned_file(doc, out_dir): def _write_interim_cleaned_files(doc, out_dir, jurisdiction_name): """Write intermediate output texts to file; helpful 
for debugging""" key_to_fp = { - "ordinance_text": f"{jurisdiction_name} Ordinance Original text.txt", + "relevant_text": f"{jurisdiction_name} Ordinance Original text.txt", "wind_energy_systems_text": ( f"{jurisdiction_name} Wind Ordinance text.txt" ), @@ -124,15 +122,14 @@ def _write_interim_cleaned_files(doc, out_dir, jurisdiction_name): (out_dir / fn).write_text(text, encoding="utf-8") -def _write_ord_db(doc, out_dir): +def _write_ord_db(extraction_context, out_dir, out_fn=None): """Write parsed ordinance database to directory""" - ord_db = doc.attrs.get("scraped_values") - jurisdiction_name = doc.attrs.get("jurisdiction_name") + ord_db = extraction_context.attrs.get("structured_data") - if ord_db is None or jurisdiction_name is None: + if ord_db is None or out_fn is None: return None - out_fp = Path(out_dir) / f"{jurisdiction_name} Ordinances.csv" + out_fp = Path(out_dir) / out_fn ord_db.to_csv(out_fp, index=False) return out_fp @@ -209,7 +206,7 @@ async def process(self, doc, file_content, make_name_unique=False): Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument Document containing meta information about the file. Must have a "source" key in the ``attrs`` dict containing the URL, which will be converted to a file name using @@ -249,7 +246,7 @@ async def process(self, doc, file_content, make_name_unique=False): Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument Document containing meta information about the file. Must have a "source" key in the ``attrs`` dict containing the URL, which will be converted to a file name using @@ -304,16 +301,19 @@ def can_process(self): """bool: Always ``True`` (limiting is handled by asyncio)""" return True - async def process(self, doc): + async def process(self, doc, *args): """Store file in out directory Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument Document containing meta information about the file. Must have relevant processing keys in the ``attrs`` dict, otherwise the file may not be stored in the output directory. + args + Additional positional argument pairs to pass to the + processing function. Returns ------- @@ -322,7 +322,11 @@ async def process(self, doc): """ loop = asyncio.get_running_loop() return await loop.run_in_executor( - self.pool, _PROCESSING_FUNCTIONS[self._PROCESS], doc, self.out_dir + self.pool, + _PROCESSING_FUNCTIONS[self._PROCESS], + doc, + self.out_dir, + *args, ) @property @@ -428,7 +432,11 @@ def can_process(self): return not self._is_processing async def process( - self, jurisdiction, doc, seconds_elapsed, usage_tracker=None + self, + jurisdiction, + extraction_context, + seconds_elapsed, + usage_tracker=None, ): """Record jurisdiction metadata in the tracking file @@ -438,12 +446,12 @@ async def process( ---------- jurisdiction : Jurisdiction The jurisdiction instance to record. - doc : elm.web.document.BaseDocument or None - Document containing meta information about the jurisdiction. - Must have relevant processing keys in the ``attrs`` dict, - otherwise the jurisdiction may not be recorded properly. - If ``None``, the jurisdiction is assumed not to have been - found. + extraction_context : ExtractionContext + Context containing meta information about the jurisdiction + under extraction. Must have relevant processing keys in the + ``attrs`` dict, otherwise the jurisdiction may not be + recorded properly. If ``None``, the jurisdiction is assumed + not to have been found. 
seconds_elapsed : int or float Total number of seconds it took to look for (and possibly parse) this document. @@ -459,7 +467,7 @@ async def process( _dump_jurisdiction_info, self.jurisdiction_fp, jurisdiction, - doc, + extraction_context, seconds_elapsed, usage_tracker, ) @@ -517,7 +525,7 @@ def _dump_usage(fp, tracker): def _dump_jurisdiction_info( - fp, jurisdiction, doc, seconds_elapsed, usage_tracker + fp, jurisdiction, extraction_context, seconds_elapsed, usage_tracker ): """Dump jurisdiction info to an existing file""" if not Path(fp).exists(): @@ -543,16 +551,20 @@ def _dump_jurisdiction_info( } if usage_tracker is not None: - cost = _compute_jurisdiction_cost(usage_tracker) + cost = compute_cost_from_totals(usage_tracker.totals) new_info["cost"] = cost or None - if doc is not None and num_ordinances_in_doc(doc) > 0: + if extraction_context is not None and extraction_context.data_docs: new_info["found"] = True - new_info["documents"] = [_compile_doc_info(doc)] - new_info["jurisdiction_website"] = doc.attrs.get( + new_info["documents"] = [ + _compile_doc_info(doc) for doc in extraction_context.data_docs + ] + new_info["jurisdiction_website"] = extraction_context.attrs.get( "jurisdiction_website" ) - new_info["compass_crawl"] = doc.attrs.get("compass_crawl", False) + new_info["compass_crawl"] = extraction_context.attrs.get( + "compass_crawl", False + ) jurisdiction_info["jurisdictions"].append(new_info) with Path.open(fp, "w", encoding="utf-8") as fh: @@ -572,8 +584,8 @@ def _compile_doc_info(doc): "checksum": doc.attrs.get("checksum"), "is_pdf": isinstance(doc, PDFDocument), "from_ocr": doc.attrs.get("from_ocr", False), - "ordinance_text_ngram_score": doc.attrs.get( - "ordinance_text_ngram_score" + "relevant_text_ngram_score": doc.attrs.get( + "relevant_text_ngram_score" ), "permitted_use_text_ngram_score": doc.attrs.get( "permitted_use_text_ngram_score" @@ -581,26 +593,6 @@ def _compile_doc_info(doc): } -def _compute_jurisdiction_cost(usage_tracker): - """Compute total cost from total tracked usage""" - - total_cost = 0 - for model, total_usage in usage_tracker.totals.items(): - model_costs = LLM_COST_REGISTRY.get(model, {}) - total_cost += ( - total_usage.get("prompt_tokens", 0) - / 1e6 - * model_costs.get("prompt", 0) - ) - total_cost += ( - total_usage.get("response_tokens", 0) - / 1e6 - * model_costs.get("response", 0) - ) - - return total_cost - - def _read_html_file(html_fp, **kwargs): """Default read HTML function (runs in main thread)""" text = Path(html_fp).read_text(encoding="utf-8") diff --git a/compass/utilities/__init__.py b/compass/utilities/__init__.py index b5053a963..30c6cda8c 100644 --- a/compass/utilities/__init__.py +++ b/compass/utilities/__init__.py @@ -1,6 +1,12 @@ """Ordinance utilities""" from .base import Directories, title_preserving_caps +from .costs import ( + LLM_COST_REGISTRY, + cost_for_model, + compute_cost_from_totals, + compute_total_cost_from_usage, +) from .finalize import ( compile_run_summary_message, doc_infos_to_db, @@ -15,11 +21,10 @@ extract_ord_year_from_doc_attrs, llm_response_as_json, merge_overlapping_texts, - num_ordinances_in_doc, num_ordinances_dataframe, ordinances_bool_index, ) -from .nt import ProcessKwargs, TechSpec +from .nt import ProcessKwargs from .io import load_local_docs @@ -45,35 +50,3 @@ " ", "", ] - - -LLM_COST_REGISTRY = { - "o1": {"prompt": 15, "response": 60}, - "o3-mini": {"prompt": 1.1, "response": 4.4}, - "gpt-4.5": {"prompt": 75, "response": 150}, - "gpt-4o": {"prompt": 2.5, "response": 10}, - "gpt-4o-mini": 
{"prompt": 0.15, "response": 0.6}, - "gpt-4.1": {"prompt": 2, "response": 8}, - "gpt-4.1-mini": {"prompt": 0.4, "response": 1.6}, - "gpt-4.1-nano": {"prompt": 0.1, "response": 0.4}, - "gpt-5": {"prompt": 1.25, "response": 10}, - "gpt-5-mini": {"prompt": 0.25, "response": 2}, - "gpt-5-nano": {"prompt": 0.05, "response": 0.4}, - "gpt-5-chat-latest": {"prompt": 1.25, "response": 10}, - "egswaterord-gpt4.1-mini": {"prompt": 0.4, "response": 1.6}, - "wetosa-gpt-4o": {"prompt": 2.5, "response": 10}, - "wetosa-gpt-4o-mini": {"prompt": 0.15, "response": 0.6}, - "wetosa-gpt-4.1": {"prompt": 2, "response": 8}, - "wetosa-gpt-4.1-mini": {"prompt": 0.4, "response": 1.6}, - "wetosa-gpt-4.1-nano": {"prompt": 0.1, "response": 0.4}, - "wetosa-gpt-5": {"prompt": 1.25, "response": 10}, - "wetosa-gpt-5-mini": {"prompt": 0.25, "response": 2}, - "wetosa-gpt-5-nano": {"prompt": 0.05, "response": 0.4}, - "wetosa-gpt-5-chat-latest": {"prompt": 1.25, "response": 10}, - "text-embedding-ada-002": {"prompt": 0.10}, -} -"""LLM Costs registry - -The registry maps model names to a dictionary that contains the cost -(in $/million tokens) for both prompt and response tokens. -""" diff --git a/compass/utilities/costs.py b/compass/utilities/costs.py new file mode 100644 index 000000000..36b8e6d34 --- /dev/null +++ b/compass/utilities/costs.py @@ -0,0 +1,120 @@ +"""COMPASS cost computation utilities""" + +LLM_COST_REGISTRY = { + "o1": {"prompt": 15, "response": 60}, + "o3-mini": {"prompt": 1.1, "response": 4.4}, + "gpt-4.5": {"prompt": 75, "response": 150}, + "gpt-4o": {"prompt": 2.5, "response": 10}, + "gpt-4o-mini": {"prompt": 0.15, "response": 0.6}, + "gpt-4.1": {"prompt": 2, "response": 8}, + "gpt-4.1-mini": {"prompt": 0.4, "response": 1.6}, + "gpt-4.1-nano": {"prompt": 0.1, "response": 0.4}, + "gpt-5": {"prompt": 1.25, "response": 10}, + "gpt-5-mini": {"prompt": 0.25, "response": 2}, + "gpt-5-nano": {"prompt": 0.05, "response": 0.4}, + "gpt-5-chat-latest": {"prompt": 1.25, "response": 10}, + "compassop-gpt-4o": {"prompt": 2.5, "response": 10}, + "compassop-gpt-4o-mini": {"prompt": 0.15, "response": 0.6}, + "compassop-gpt-4.1": {"prompt": 2, "response": 8}, + "compassop-gpt-4.1-mini": {"prompt": 0.4, "response": 1.6}, + "compassop-gpt-4.1-nano": {"prompt": 0.1, "response": 0.4}, + "compassop-gpt-5": {"prompt": 1.25, "response": 10}, + "compassop-gpt-5-mini": {"prompt": 0.25, "response": 2}, + "compassop-gpt-5-nano": {"prompt": 0.05, "response": 0.4}, + "compassop-gpt-5-chat-latest": {"prompt": 1.25, "response": 10}, + "egswaterord-gpt4.1-mini": {"prompt": 0.4, "response": 1.6}, + "wetosa-gpt-4o": {"prompt": 2.5, "response": 10}, + "wetosa-gpt-4o-mini": {"prompt": 0.15, "response": 0.6}, + "wetosa-gpt-4.1": {"prompt": 2, "response": 8}, + "wetosa-gpt-4.1-mini": {"prompt": 0.4, "response": 1.6}, + "wetosa-gpt-4.1-nano": {"prompt": 0.1, "response": 0.4}, + "wetosa-gpt-5": {"prompt": 1.25, "response": 10}, + "wetosa-gpt-5-mini": {"prompt": 0.25, "response": 2}, + "wetosa-gpt-5-nano": {"prompt": 0.05, "response": 0.4}, + "wetosa-gpt-5-chat-latest": {"prompt": 1.25, "response": 10}, + "text-embedding-ada-002": {"prompt": 0.10}, +} +"""LLM Costs registry + +The registry maps model names to a dictionary that contains the cost +(in $/million tokens) for both prompt and response tokens. +""" + + +def cost_for_model(model_name, prompt_tokens, completion_tokens): + """Compute the API costs for a model given the token usage + + Parameters + ---------- + model_name : str + Name of the model. 
Needs to be registered as a key in + :obj:`LLM_COST_REGISTRY` for this method to return a non-zero + value. + prompt_tokens, completion_tokens : int + Number of prompt and completion tokens used, respectively. + + Returns + ------- + float + Total cost based on the token usage. + """ + model_costs = LLM_COST_REGISTRY.get(model_name, {}) + prompt_cost = prompt_tokens / 1e6 * model_costs.get("prompt", 0) + response_cost = completion_tokens / 1e6 * model_costs.get("response", 0) + return prompt_cost + response_cost + + +def compute_cost_from_totals(totals): + """Compute total cost from total tracked usage + + Parameters + ---------- + totals : dict + Dictionary where keys are model names and their corresponding + usage statistics are values. Each usage statistics dictionary + should contain "prompt_tokens" and "response_tokens" keys + indicating the number of tokens used for prompts and responses, + respectively. This dictionary is typically obtained from the + `tracker_totals` property of a + :class:`compass.services.usage.UsageTracker` instance. + + Returns + ------- + float + Total cost based on the tracked usage. + """ + return sum( + cost_for_model( + model, + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("response_tokens", 0), + ) + for model, usage in totals.items() + ) + + +def compute_total_cost_from_usage(tracked_usage): + """Compute total cost from total tracked usage + + Parameters + ---------- + tracked_usage : compass.services.usage.UsageTracker or dict + Dictionary where keys are usage categories (typically + jurisdiction names) and values are dictionaries containing usage + details. The usage details dictionaries should have a + "tracker_totals" key, which maps to another dictionary. This + innermost dictionary should have model names as keys and their + corresponding usage statistics as values. Each usage statistics + dictionary should contain "prompt_tokens" and "response_tokens" + keys indicating the number of tokens used for prompts and + responses, respectively. + + Returns + ------- + float + Total LLM cost based on the tracked usage. + """ + return sum( + compute_cost_from_totals(usage.get("tracker_totals", {})) + for usage in tracked_usage.values() + ) diff --git a/compass/utilities/enums.py b/compass/utilities/enums.py index d36a732f2..4ea9250f0 100644 --- a/compass/utilities/enums.py +++ b/compass/utilities/enums.py @@ -65,6 +65,9 @@ class LLMTasks(StrEnum): so downstream monitoring does not require additional translation. """ + DATA_EXTRACTION = auto() + """Default Data extraction task""" + DATE_EXTRACTION = LLMUsageCategory.DATE_EXTRACTION """Date extraction task""" @@ -90,6 +93,13 @@ class LLMTasks(StrEnum): EMBEDDING = auto() """Text chunk embedding task""" + TEXT_EXTRACTION = auto() + """Default Text extraction task + + This task represents the extraction/summarization of text containing + information to be extracted into structured data. + """ + JURISDICTION_MAIN_WEBSITE_VALIDATION = ( LLMUsageCategory.JURISDICTION_MAIN_WEBSITE_VALIDATION ) diff --git a/compass/utilities/finalize.py b/compass/utilities/finalize.py index c3b9424b1..b4b9696e6 100644 --- a/compass/utilities/finalize.py +++ b/compass/utilities/finalize.py @@ -10,7 +10,6 @@ from compass import __version__ as compass_version from compass.utilities.parsing import ( - extract_ord_year_from_doc_attrs, num_ordinances_dataframe, ordinances_bool_index, ) @@ -144,7 +143,8 @@ def doc_infos_to_db(doc_infos): results. 
Each dictionary must contain ``"ord_db_fp"`` (path to a parsed CSV), ``"source"`` (document URL), ``"date"`` (tuple of year, month, day, with ``None`` allowed), and ``"jurisdiction"`` - (a :class:`~compass.utilities.location.Jurisdiction` instance). + (a :class:`~compass.utilities.jurisdictions.Jurisdiction` + instance). Returns ------- @@ -175,7 +175,7 @@ def doc_infos_to_db(doc_infos): if num_ordinances_dataframe(ord_db) == 0: continue - results = _db_results(ord_db, doc_info) + results = _db_results(ord_db, doc_info["jurisdiction"]) results = _formatted_db(results) db.append(results) @@ -219,13 +219,9 @@ def save_db(db, out_dir): quant_db.to_csv(out_dir / "quantitative_ordinances.csv", index=False) -def _db_results(results, doc_info): +def _db_results(results, jurisdiction): """Extract results from doc attrs to DataFrame""" - results["source"] = doc_info.get("source") - results["ord_year"] = extract_ord_year_from_doc_attrs(doc_info) - - jurisdiction = doc_info["jurisdiction"] results["FIPS"] = jurisdiction.code results["county"] = jurisdiction.county results["state"] = jurisdiction.state diff --git a/compass/utilities/io.py b/compass/utilities/io.py index 02f3e9c50..f7a4cd3b7 100644 --- a/compass/utilities/io.py +++ b/compass/utilities/io.py @@ -23,7 +23,7 @@ async def load_local_docs(fps, **kwargs): Returns ------- - list of elm.web.document.BaseDocument + list of BaseDocument Non-empty loaded documents corresponding to the supplied filepaths. Empty results (e.g., unreadable files) are filtered out of the returned list. diff --git a/compass/utilities/jurisdictions.py b/compass/utilities/jurisdictions.py index dfe2797b7..9ce6100b1 100644 --- a/compass/utilities/jurisdictions.py +++ b/compass/utilities/jurisdictions.py @@ -3,6 +3,7 @@ import logging from warnings import warn import importlib.resources +from functools import cached_property import numpy as np import pandas as pd @@ -16,6 +17,165 @@ importlib.resources.files("compass") / "data" / "conus_jurisdictions.csv", importlib.resources.files("compass") / "data" / "tx_water_districts.csv", } +_JUR_COLS = [ + "Jurisdiction Type", + "State", + "County", + "Subdivision", + "FIPS", + "Website", +] +_JURISDICTION_TYPES_AS_PREFIXES = { + "town", + "township", + "city", + "borough", + "village", + "unorganized territory", +} + + +class Jurisdiction: + """Model a geographic jurisdiction used throughout COMPASS + + The class normalizes casing for location components and provides + convenience properties for rendering jurisdiction names with + correct prefixes. It is designed to align with ordinance validation + logic that expects consistent casing and phrasing across states, + counties, and municipal subdivisions. + + Notes + ----- + Instances compare case-insensitively for type and state, while the + county and subdivision name comparisons preserve their stored + casing. Hashing and ``str`` conversions defer to the full display + name generated by :attr:`full_name`. + """ + + def __init__( + self, + subdivision_type, + state, + county=None, + subdivision_name=None, + code=None, + website_url=None, + ): + """ + + Parameters + ---------- + subdivision_type : str + Type of subdivision that this jurisdiction represents. + Typical values are "state", "county", "town", "city", + "borough", "parish", "township", etc. + state : str + Name of the state containing the jurisdiction. + county : str, optional + Name of the county containing the jurisdiction, if + applicable. If the jurisdiction represents a state, leave + this input unspecified. 
If the jurisdiction represents a + county or a subdivision within a county, provide the county + name here. + + .. IMPORTANT:: Make sure this input is capitalized properly! + + By default, ``None``. + subdivision_name : str, optional + Name of the subdivision that the jurisdiction represents, if + applicable. If the jurisdiction represents a state or + county, leave this input unspecified. Otherwise, provide the + jurisdiction name here. + + .. IMPORTANT:: Make sure this input is capitalized properly! + + By default, ``None``. + code : int or str, optional + Optional jurisdiction code (typically FIPS or similar). + By default, ``None``. + website_url : str, optional + Optional URL for the jurisdiction's main website. + By default, ``None``. + """ + self.type = subdivision_type.title() + self.state = state.title() + self.county = county + self.subdivision_name = subdivision_name + self.code = code + self.website_url = website_url + + @cached_property + def full_name(self): + """str: Comma-separated jurisdiction display name""" + name_parts = [ + self.full_subdivision_phrase, + self.full_county_phrase, + self.state, + ] + + return ", ".join(filter(None, name_parts)) + + @cached_property + def full_name_the_prefixed(self): + """str: Full location name prefixed with ``the`` as needed""" + if self.type.casefold() == "state": + return f"the state of {self.state}" + + if self.type.casefold() in _JURISDICTION_TYPES_AS_PREFIXES: + return f"the {self.full_name}" + + return self.full_name + + @cached_property + def full_subdivision_phrase(self): + """str: Subdivision phrase for the jurisdiction or empty str""" + if not self.subdivision_name: + return "" + + if self.type.casefold() in _JURISDICTION_TYPES_AS_PREFIXES: + return f"{self.type} of {self.subdivision_name}" + + return f"{self.subdivision_name} {self.type}" + + @cached_property + def full_subdivision_phrase_the_prefixed(self): + """str: Subdivision phrase prefixed with ``the`` as needed""" + if self.type.casefold() in _JURISDICTION_TYPES_AS_PREFIXES: + return f"the {self.full_subdivision_phrase}" + + return self.full_subdivision_phrase + + @cached_property + def full_county_phrase(self): + """str: County phrase for the jurisdiction or empty str""" + if not self.county: + return "" + + if not self.subdivision_name: + return f"{self.county} {self.type}" + + return f"{self.county} County" + + def __repr__(self): + return str(self) + + def __str__(self): + return self.full_name + + def __eq__(self, other): + if isinstance(other, self.__class__): + return ( + self.type.casefold() == other.type.casefold() + and self.state.casefold() == other.state.casefold() + and self.county == other.county + and self.subdivision_name == other.subdivision_name + ) + if isinstance(other, str): + return self.full_name.casefold() == other.casefold() + return False + + def __hash__(self): + return hash(self.full_name.casefold()) def load_all_jurisdiction_info(): @@ -132,6 +292,38 @@ def load_jurisdictions_from_fp(jurisdiction_fp): return _format_jurisdiction_df_for_output(jurisdictions) +def jurisdictions_from_df(jurisdiction_info=None): + """Convert rows DataFrame into Jurisdiction instances + + Parameters + ---------- + jurisdiction_info : pandas.DataFrame, optional + DataFrame containing jurisdiction info with columns: + ``["Jurisdiction Type", "State", "County", "Subdivision", + "FIPS", "Website"]``. If ``None``, this info is loaded using + :func:`load_all_jurisdiction_info`. By default, ``None``. 
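(Editorial aside, not part of the original docstring: a minimal sketch of the objects this generator yields, exercising the phrasing helpers on the ``Jurisdiction`` class defined above::

    Jurisdiction("county", "colorado", county="Boulder").full_name
    # -> "Boulder County, Colorado"
    Jurisdiction(
        "town", "new york", county="Albany", subdivision_name="Colonie"
    ).full_name_the_prefixed
    # -> "the Town of Colonie, Albany County, New York"

The place names are illustrative only.)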
+ + Yields + ------ + Jurisdiction + Jurisdiction instance built from each row of the input + DataFrame. + """ + if jurisdiction_info is None: + jurisdiction_info = load_all_jurisdiction_info() + + for __, row in jurisdiction_info.iterrows(): + jur_type, state, county, sub, fips, website = row[_JUR_COLS] + yield Jurisdiction( + subdivision_type=jur_type, + state=state, + county=county, + subdivision_name=sub, + code=fips, + website_url=website, + ) + + def _validate_jurisdiction_input(jurisdictions): """Throw error if user is missing required columns""" if "State" not in jurisdictions: diff --git a/compass/utilities/location.py b/compass/utilities/location.py deleted file mode 100644 index 6ef8ece35..000000000 --- a/compass/utilities/location.py +++ /dev/null @@ -1,151 +0,0 @@ -"""COMPASS Ordinance jurisdiction specification utilities""" - -from functools import cached_property - - -JURISDICTION_TYPES_AS_PREFIXES = { - "town", - "township", - "city", - "borough", - "village", - "unorganized territory", -} - - -class Jurisdiction: - """Model a geographic jurisdiction used throughout COMPASS - - The class normalizes casing for location components and provides - convenience properties for rendering jurisdiction names with - correct prefixes. It is designed to align with ordinance validation - logic that expects consistent casing and phrasing across states, - counties, and municipal subdivisions. - - Notes - ----- - Instances compare case-insensitively for type and state, while the - county and subdivision name comparisons preserve their stored - casing. Hashing and ``str`` conversions defer to the full display - name generated by :attr:`full_name`. - """ - - def __init__( - self, - subdivision_type, - state, - county=None, - subdivision_name=None, - code=None, - ): - """ - - Parameters - ---------- - subdivision_type : str - Type of subdivision that this jurisdiction represents. - Typical values are "state", "county", "town", "city", - "borough", "parish", "township", etc. - state : str - Name of the state containing the jurisdiction. - county : str, optional - Name of the county containing the jurisdiction, if - applicable. If the jurisdiction represents a state, leave - this input unspecified. If the jurisdiction represents a - county or a subdivision within a county, provide the county - name here. - - .. IMPORTANT:: Make sure this input is capitalized properly! - - By default, ``None``. - subdivision_name : str, optional - Name of the subdivision that the jurisdiction represents, if - applicable. If the jurisdiction represents a state or - county, leave this input unspecified. Otherwise, provide the - jurisdiction name here. - - .. IMPORTANT:: Make sure this input is capitalized properly! - - By default, ``None``. - code : int or str, optional - Optional jurisdiction code (typically FIPS or similar). - By default, ``None``. 
- """ - self.type = subdivision_type.title() - self.state = state.title() - self.county = county - self.subdivision_name = subdivision_name - self.code = code - - @cached_property - def full_name(self): - """str: Comma-separated jurisdiction display name""" - name_parts = [ - self.full_subdivision_phrase, - self.full_county_phrase, - self.state, - ] - - return ", ".join(filter(None, name_parts)) - - @cached_property - def full_name_the_prefixed(self): - """str: Full location name prefixed with ``the`` as needed""" - if self.type.casefold() == "state": - return f"the state of {self.state}" - - if self.type.casefold() in JURISDICTION_TYPES_AS_PREFIXES: - return f"the {self.full_name}" - - return self.full_name - - @cached_property - def full_subdivision_phrase(self): - """str: Subdivision phrase for the jurisdiction or empty str""" - if not self.subdivision_name: - return "" - - if self.type.casefold() in JURISDICTION_TYPES_AS_PREFIXES: - return f"{self.type} of {self.subdivision_name}" - - return f"{self.subdivision_name} {self.type}" - - @cached_property - def full_subdivision_phrase_the_prefixed(self): - """str: Subdivision phrase prefixed with ``the`` as needed""" - if self.type.casefold() in JURISDICTION_TYPES_AS_PREFIXES: - return f"the {self.full_subdivision_phrase}" - - return self.full_subdivision_phrase - - @cached_property - def full_county_phrase(self): - """str: County phrase for the jurisdiction or empty str""" - if not self.county: - return "" - - if not self.subdivision_name: - return f"{self.county} {self.type}" - - return f"{self.county} County" - - def __repr__(self): - return str(self) - - def __str__(self): - return self.full_name - - def __eq__(self, other): - if isinstance(other, self.__class__): - return ( - self.type.casefold() == other.type.casefold() - and self.state.casefold() == other.state.casefold() - and self.county == other.county - and self.subdivision_name == other.subdivision_name - ) - if isinstance(other, str): - return self.full_name.casefold() == other.casefold() - return False - - def __hash__(self): - return hash(self.full_name.casefold()) diff --git a/compass/utilities/nt.py b/compass/utilities/nt.py index 5fed9bca1..c823e39af 100644 --- a/compass/utilities/nt.py +++ b/compass/utilities/nt.py @@ -41,58 +41,3 @@ Maximum number of jurisdictions processed simultaneously. By default, ``25``. """ - -TechSpec = namedtuple( - "TechSpec", - [ - "name", - "questions", - "heuristic", - "ordinance_text_collector", - "ordinance_text_extractor", - "permitted_use_text_collector", - "permitted_use_text_extractor", - "structured_ordinance_parser", - "structured_permitted_use_parser", - "website_url_keyword_points", - "post_download_docs_hook", - "post_filter_docs_hook", - "extract_ordinances_callback", - "num_ordinances_in_df_callback", - "save_db_callback", - ], - defaults=[None, None, None, None, None], -) -TechSpec.__doc__ = """Bundle extraction configuration for a technology - -Parameters ----------- -name : str - Display name for the technology (e.g., ``"solar"``). -questions : dict - Prompt templates or question sets used during extraction. -heuristic : callable - Function implementing heuristic filters prior to LLM invocation. -ordinance_text_collector : callable - Callable that gathers candidate ordinance text spans. -ordinance_text_extractor : callable - Callable that extracts relevant ordinance snippets. -permitted_use_text_collector : callable - Callable that gathers candidate permitted-use text spans. 
-permitted_use_text_extractor : callable - Callable that extracts permitted-use content. -structured_ordinance_parser : callable - Callable that transforms ordinance text into structured values. -structured_permitted_use_parser : callable - Callable that transforms permitted-use text into structured values. -website_url_keyword_points : dict or None - Weightings for scoring website URLs during search. -post_download_docs_hook : callable or None - Optional async function to filter/process downloaded documents. -post_filter_docs_hook : callable or None - Optional async function to filter/process filtered documents. -extract_ordinances_callback : callable or None - Optional async function to extract ordinance data from documents. -save_db_callback : callable or None - Optional **sync** function to save ordinance database to disk. -""" diff --git a/compass/utilities/parsing.py b/compass/utilities/parsing.py index 7e3c02156..f87d74c6f 100644 --- a/compass/utilities/parsing.py +++ b/compass/utilities/parsing.py @@ -141,31 +141,6 @@ def extract_ord_year_from_doc_attrs(doc_attrs): return year if year is not None and year > 0 else None -def num_ordinances_in_doc(doc, exclude_features=None): - """Count the number of ordinance entries on a document - - Parameters - ---------- - doc : elm.web.document.BaseDocument - Document potentially containing ordinances for a jurisdiction. - If no ordinance values are found, this function returns ``0``. - exclude_features : iterable of str, optional - Optional features to exclude from ordinance count. - By default, ``None``. - - Returns - ------- - int - Number of ordinance rows represented in ``doc``. - """ - if doc is None or doc.attrs.get("ordinance_values") is None: - return 0 - - return num_ordinances_dataframe( - doc.attrs["ordinance_values"], exclude_features=exclude_features - ) - - def num_ordinances_dataframe(data, exclude_features=None): """Count ordinance rows contained in a DataFrame diff --git a/compass/validation/graphs.py b/compass/validation/graphs.py index 66e8c0318..e73b93f6e 100644 --- a/compass/validation/graphs.py +++ b/compass/validation/graphs.py @@ -278,7 +278,7 @@ def setup_graph_correct_jurisdiction_type(jurisdiction, **kwargs): Parameters ---------- - jurisdiction : compass.utilities.location.Jurisdiction + jurisdiction : Jurisdiction Target jurisdiction descriptor that guides prompt wording. **kwargs Additional keyword arguments forwarded to @@ -509,7 +509,7 @@ def setup_graph_correct_jurisdiction_from_url(jurisdiction, **kwargs): Parameters ---------- - jurisdiction : compass.utilities.location.Jurisdiction + jurisdiction : Jurisdiction Jurisdiction descriptor supplying state, county, and subdivision phrases used in prompts. **kwargs diff --git a/compass/validation/location.py b/compass/validation/location.py index 25d994ae6..6339d0ae9 100644 --- a/compass/validation/location.py +++ b/compass/validation/location.py @@ -36,7 +36,7 @@ def __init__(self, jurisdiction, **kwargs): Parameters ---------- - jurisdiction : compass.utilities.location.Jurisdiction + jurisdiction : Jurisdiction Jurisdiction descriptor with the target location attributes. **kwargs Additional keyword arguments forwarded to @@ -125,7 +125,7 @@ def __init__(self, jurisdiction, **kwargs): Parameters ---------- - jurisdiction : compass.utilities.location.Jurisdiction + jurisdiction : Jurisdiction Jurisdiction descriptor identifying expected applicability. 
**kwargs Additional keyword arguments forwarded to @@ -226,11 +226,11 @@ async def check(self, doc, jurisdiction): Parameters ---------- - doc : elm.web.document.BaseDocument + doc : BaseDocument Document to evaluate. The validator expects ``doc.raw_pages`` and, when available, a ``doc.attrs['source']`` URL for supplemental URL validation. - jurisdiction : compass.utilities.location.Jurisdiction + jurisdiction : Jurisdiction Target jurisdiction descriptor capturing the required location attributes. @@ -345,7 +345,7 @@ async def check(self, url, jurisdiction): ---------- url : str URL to inspect. Empty values return ``False`` immediately. - jurisdiction : compass.utilities.location.Jurisdiction + jurisdiction : Jurisdiction Target jurisdiction descriptor used to frame the validation prompts. diff --git a/docs/source/conf.py b/docs/source/conf.py index f2745b009..c5f4d1e8a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -348,15 +348,21 @@ def setup(app): "ApiBase": ":class:`~elm.base.ApiBase`", # objects from COMPASS "AsyncDecisionTree": ":class:`~compass.common.tree.AsyncDecisionTree`", - "Jurisdiction": ":class:`~compass.utilities.location.Jurisdiction`", - "LLMCaller": ":class:`~compass.llm.calling.LLMCaller`", + "BaseDocument": ":class:`elm.web.document.BaseDocument`", + "BaseParser": ":class:`~compass.plugin.interface.BaseParser`", + "BaseTextExtractor": ":class:`~compass.plugin.interface.BaseTextExtractor`", "ChatLLMCaller": ":class:`~compass.llm.calling.ChatLLMCaller`", - "StructuredLLMCaller": ":class:`~compass.llm.calling.StructuredLLMCaller`", - "Service": ":class:`~compass.services.base.Service`", + "ExtractionContext": ":class:`~compass.extraction.context.ExtractionContext`", + "Jurisdiction": ":class:`~compass.utilities.jurisdictions.Jurisdiction`", + "LLMCaller": ":class:`~compass.llm.calling.LLMCaller`", + "LLMConfig": ":class:`~compass.llm.config.LLMConfig`", "LLMService": ":class:`~compass.services.base.LLMService`", + "LLMTasks": ":class:`~compass.utilities.enums.LLMTasks`", "OpenAIService": ":class:`~compass.services.openai.OpenAIService`", + "ParseChunksWithMemory": ":class:`~compass.validation.content.ParseChunksWithMemory`", + "Service": ":class:`~compass.services.base.Service`", + "StructuredLLMCaller": ":class:`~compass.llm.calling.StructuredLLMCaller`", "TimeBoundedUsageTracker": ":class:`~compass.services.usage.TimeBoundedUsageTracker`", "UsageTracker": ":class:`~compass.services.usage.UsageTracker`", - "ParseChunksWithMemory": ":class:`~compass.validation.content.ParseChunksWithMemory`", "WindOrdinanceTextExtractor": ":class:`~compass.extraction.wind.ordinance.WindOrdinanceTextExtractor`", } diff --git a/examples/parse_existing_docs/README.rst b/examples/parse_existing_docs/README.rst index c2a212f10..b61646ccc 100644 --- a/examples/parse_existing_docs/README.rst +++ b/examples/parse_existing_docs/README.rst @@ -23,8 +23,8 @@ The ordinance extraction pipeline is composed of several components, which are v .. mermaid:: flowchart LR - A -->|Ordinance Content Checker
(check_for_ordinance_info)| B - B -->|Ordinance Text Extractor<br>(extract_ordinance_text_with_llm)| C + A -->|Ordinance Content Checker<br>(check_for_relevant_text)| B + B -->|Ordinance Text Extractor<br>(extract_relevant_text_with_llm)| C C -->|Ordinance Extractor<br>(extract_ordinance_values)| D A@{ shape: lined-document, label: "Ordinance Document<br>(PDFDocument)" } B@{ shape: docs, label: "Ordinance Text Chunks
(str)"} @@ -121,16 +121,16 @@ The first functional step in extraction is to determine whether the document con Even if you're fairly sure that it does, this step is essential because it saves the text chunks from the document that contain ordinance information, enabling the next step — ordinance text extraction. -To do this, we'll use the :func:`~compass.extraction.apply.check_for_ordinance_info` function. This function uses a +To do this, we'll use the :func:`~compass.extraction.apply.check_for_relevant_text` function. This function uses a combination of keyword heuristics and LLM evaluation to identify ordinance content and collect it into a new field in the document. Here's how that might look: .. code-block:: python - from compass.extraction.apply import check_for_ordinance_info + from compass.extraction.apply import check_for_relevant_text from compass.extraction.solar import SolarHeuristic, SolarOrdinanceTextCollector - doc = await check_for_ordinance_info( + doc = await check_for_relevant_text( doc, model_config=llm_config, heuristic=SolarHeuristic(), @@ -161,15 +161,15 @@ Once we've located the general sections where the ordinances are mentioned, we'l The identified chunks are often too broad to use directly in downstream processing, so we'll pass them through another LLM-powered step that filters the content to only the most relevant ordinance language. -We'll do that using the :func:`~compass.extraction.apply.extract_ordinance_text_with_llm` function: +We'll do that using the :func:`~compass.extraction.apply.extract_relevant_text_with_llm` function: .. code-block:: python from compass.llm import LLMCaller - from compass.extraction.apply import extract_ordinance_text_with_llm + from compass.extraction.apply import extract_relevant_text_with_llm from compass.extraction.solar import SolarOrdinanceTextExtractor - doc, ord_text_key = await extract_ordinance_text_with_llm( + doc, ord_text_key = await extract_relevant_text_with_llm( doc, llm_config.text_splitter, extractor=SolarOrdinanceTextExtractor( @@ -182,7 +182,7 @@ This step reads the raw text chunks stored in ``doc.attrs["ordinance_text"]`` an the ordinance language itself. The first argument to this function is the ordinance document, which must contain an ``"ordinance_text"`` key in its ``doc.attrs`` dictionary. This key holds the concatenated text chunks identified as likely containing ordinance information. It's automatically added for us by the -:func:`~compass.extraction.apply.check_for_ordinance_info` function — assuming ordinance text is present. +:func:`~compass.extraction.apply.check_for_relevant_text` function — assuming ordinance text is present. Next, we pass in the text splitter instance, which will be used to divide the concatenated text into smaller chunks. 
We also provide a :class:`~compass.extraction.solar.ordinance.SolarOrdinanceTextExtractor` instance, which performs the diff --git a/examples/parse_existing_docs/parse_pdf.py b/examples/parse_existing_docs/parse_pdf.py index 4943d335b..d94da0d20 100644 --- a/examples/parse_existing_docs/parse_pdf.py +++ b/examples/parse_existing_docs/parse_pdf.py @@ -13,17 +13,17 @@ from elm.utilities import validate_azure_api_params from compass.llm import LLMCaller, OpenAIConfig -from compass.extraction.solar import ( - SolarOrdinanceTextExtractor, +from compass.extraction.solar.plugin import ( SolarHeuristic, SolarOrdinanceTextCollector, + SolarOrdinanceTextExtractor, StructuredSolarOrdinanceParser, ) from compass.services.provider import RunningAsyncServices from compass.extraction.apply import ( extract_ordinance_values, - check_for_ordinance_info, - extract_ordinance_text_with_llm, + check_for_relevant_text, + extract_relevant_text_with_llm, ) from compass.utilities.logs import AddLocationFilter from compass.utilities.enums import LLMTasks @@ -63,13 +63,12 @@ async def _extract_ordinances(doc, model_configs): LLMTasks.DOCUMENT_CONTENT_VALIDATION, model_configs[LLMTasks.DEFAULT], ) - doc = await check_for_ordinance_info( + doc = await check_for_relevant_text( doc, model_config=model_config, heuristic=SolarHeuristic(), tech="solar", - ordinance_text_collector_class=SolarOrdinanceTextCollector, - permitted_use_text_collector_class=None, + text_collectors=[SolarOrdinanceTextCollector], ) logger.info("Extracting ordinance text from document...") @@ -77,13 +76,13 @@ async def _extract_ordinances(doc, model_configs): LLMTasks.ORDINANCE_TEXT_EXTRACTION, model_configs[LLMTasks.DEFAULT], ) - doc, ord_text_key = await extract_ordinance_text_with_llm( + doc, ord_text_key = await extract_relevant_text_with_llm( doc, model_config.text_splitter, extractor=SolarOrdinanceTextExtractor( LLMCaller(llm_service=model_config.llm_service) ), - original_text_key="ordinance_text", + original_text_key=SolarOrdinanceTextExtractor.IN_LABEL, ) logger.info( @@ -99,7 +98,7 @@ async def _extract_ordinances(doc, model_configs): llm_service=model_config.llm_service ), text_key=ord_text_key, - out_key="ordinance_values", + out_key=StructuredSolarOrdinanceParser.OUT_LABEL, ) @@ -163,9 +162,10 @@ async def _extract_ordinances(doc, model_configs): # save outputs ( - doc.attrs["ordinance_values"] + doc.attrs[StructuredSolarOrdinanceParser.OUT_LABEL] .drop(columns=["quantitative"], errors="ignore") .to_csv(fp_ord, index=False) ) - with Path(fp_txt_ord_text).open("w", encoding="utf-8") as fh: - fh.write(doc.attrs["cleaned_ordinance_text"]) + Path(fp_txt_ord_text).write_text( + doc.attrs[SolarOrdinanceTextExtractor.OUT_LABEL], encoding="utf-8" + ) diff --git a/tests/python/unit/extraction/test_extraction_context.py b/tests/python/unit/extraction/test_extraction_context.py new file mode 100644 index 000000000..57ad9ff68 --- /dev/null +++ b/tests/python/unit/extraction/test_extraction_context.py @@ -0,0 +1,424 @@ +"""COMPASS extraction context tests""" + +from pathlib import Path + +import pandas as pd +import pytest +from elm.web.document import PDFDocument, HTMLDocument + +from compass.extraction.context import ( + ExtractionContext, + _as_list, + _attrs_repr, + _data_docs_repr, + _move_file_to_out_dir, +) +from compass.exceptions import COMPASSTypeError +from compass.services.threaded import FileMover + + +def test_extraction_context_iter_empty(): + """Test empty ExtractionContext iteration""" + for __ in ExtractionContext(): + msg = 
"Should not iterate over any documents" + raise AssertionError(msg) + + +def test_extraction_context_iter_non_iterable(): + """Test non-iterable ExtractionContext iteration""" + test_doc = PDFDocument([]) + for x in ExtractionContext(test_doc): + assert isinstance(x, PDFDocument) + assert x is test_doc + + +@pytest.mark.parametrize( + "test_input", (("a", "b"), ["a", "b"], {"a": 1, "b": 2}) +) +def test_extraction_context_iter_sequence(test_input): + """Test non-sequence ExtractionContext iteration""" + for x, y in zip(ExtractionContext(test_input), test_input, strict=True): + assert x == y + + +def test_extraction_context_set(): + """Test non-sequence ExtractionContext iteration""" + test_input = {"a", "b"} + test = ExtractionContext(test_input) + assert set(test) == test_input + + +def test_extraction_context_text_empty(): + """Test text property with empty context""" + ctx = ExtractionContext() + assert not ctx.text + + +def test_extraction_context_text_single_doc(): + """Test text property with single document""" + doc = PDFDocument(["page one", "page two"]) + ctx = ExtractionContext(doc) + assert ctx.text == "page one\npage two" + + +def test_extraction_context_text_multiple_docs(): + """Test text property concatenates multiple documents""" + doc1 = PDFDocument(["doc1 page1", "doc1 page2"]) + doc2 = HTMLDocument(["

doc2 content

"]) + ctx = ExtractionContext([doc1, doc2]) + expected = "doc1 page1\ndoc1 page2\n\ndoc2 content\n\n" + assert ctx.text == expected + + +def test_extraction_context_pages_empty(): + """Test pages property with empty context""" + ctx = ExtractionContext() + assert ctx.pages == [] + + +def test_extraction_context_pages_single_doc(): + """Test pages property with single document""" + doc = PDFDocument(["page 1", "page 2", "page 3"]) + ctx = ExtractionContext(doc) + assert ctx.pages == ["page 1", "page 2", "page 3"] + + +def test_extraction_context_pages_multiple_docs(): + """Test pages property flattens multiple documents""" + doc1 = PDFDocument(["doc1 p1", "doc1 p2"]) + doc2 = PDFDocument(["doc2 p1"]) + doc3 = HTMLDocument(["doc3 content"]) + ctx = ExtractionContext([doc1, doc2, doc3]) + assert ctx.pages == ["doc1 p1", "doc1 p2", "doc2 p1", "doc3 content"] + + +def test_extraction_context_num_documents(): + """Test num_documents property""" + assert ExtractionContext().num_documents == 0 + assert ExtractionContext(PDFDocument([])).num_documents == 1 + doc_list = [PDFDocument([]), HTMLDocument([""]), PDFDocument([])] + assert ExtractionContext(doc_list).num_documents == 3 + + +def test_extraction_context_documents_getter(): + """Test documents property getter""" + doc1 = PDFDocument(["test"]) + doc2 = HTMLDocument(["html"]) + ctx = ExtractionContext([doc1, doc2]) + assert ctx.documents == [doc1, doc2] + assert ctx.documents[0] is doc1 + assert ctx.documents[1] is doc2 + + +def test_extraction_context_documents_setter(): + """Test documents property setter""" + ctx = ExtractionContext() + assert ctx.documents == [] + + doc = PDFDocument(["page"]) + ctx.documents = doc + assert ctx.documents == [doc] + + doc_list = [PDFDocument([]), HTMLDocument([])] + ctx.documents = doc_list + assert ctx.documents == doc_list + + +def test_extraction_context_data_docs_getter(): + """Test data_docs property getter""" + ctx = ExtractionContext() + assert ctx.data_docs == [] + + ctx._data_docs = [PDFDocument([])] + assert len(ctx.data_docs) == 1 + + +def test_extraction_context_data_docs_setter_valid(): + """Test data_docs property setter with valid list""" + ctx = ExtractionContext() + doc_list = [PDFDocument([]), HTMLDocument([])] + ctx.data_docs = doc_list + assert ctx.data_docs == doc_list + + +def test_extraction_context_data_docs_setter_invalid(): + """Test data_docs property setter raises for non-list""" + ctx = ExtractionContext() + + with pytest.raises(COMPASSTypeError, match="must be set to a \\*list\\*"): + ctx.data_docs = PDFDocument([]) + + with pytest.raises(COMPASSTypeError, match="must be set to a \\*list\\*"): + ctx.data_docs = {"doc": PDFDocument([])} + + with pytest.raises(COMPASSTypeError, match="must be set to a \\*list\\*"): + ctx.data_docs = (PDFDocument([]),) + + +def test_extraction_context_str_empty(): + """Test string representation of empty context""" + ctx = ExtractionContext() + result = str(ctx) + assert "ExtractionContext with 0 documents" in result + assert "Registered Data Source Documents: None" in result + assert "Attrs: None" in result + + +def test_extraction_context_str_single_doc(): + """Test string representation with single document""" + doc = PDFDocument(["test"]) + doc.attrs["source"] = "http://example.com/doc.pdf" + ctx = ExtractionContext(doc) + result = str(ctx) + assert "ExtractionContext with 1 document" in result + assert "http://example.com/doc.pdf" in result + + +def test_extraction_context_str_multiple_docs(): + """Test string representation with multiple 
documents""" + doc1 = PDFDocument(["test"]) + doc1.attrs["source"] = "source1.pdf" + doc2 = HTMLDocument(["html"]) + doc2.attrs["source"] = "source2.html" + ctx = ExtractionContext([doc1, doc2]) + result = str(ctx) + assert "ExtractionContext with 2 documents" in result + assert "source1.pdf" in result + assert "source2.html" in result + + +def test_extraction_context_str_with_data_docs(): + """Test string representation with registered data docs""" + doc1 = PDFDocument(["test"]) + doc1.attrs["source"] = "main.pdf" + ctx = ExtractionContext(doc1) + + data_doc = PDFDocument(["data"]) + data_doc.attrs["source"] = "data_source.pdf" + ctx.data_docs = [data_doc] + + result = str(ctx) + assert "Registered Data Source Documents:" in result + assert "data_source.pdf" in result + + +def test_extraction_context_str_with_attrs(): + """Test string representation with attributes""" + attrs = {"jurisdiction": "Test County", "year": 2025} + ctx = ExtractionContext(attrs=attrs) + result = str(ctx) + assert "Attrs:" in result + assert "jurisdiction" in result + assert "Test County" in result + assert "year" in result + assert "2025" in result + + +def test_extraction_context_str_with_dataframe_attr(): + """Test string representation with DataFrame attribute""" + df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + ctx = ExtractionContext(attrs={"table_data": df}) + result = str(ctx) + assert "Attrs:" in result + assert "table_data" in result + assert "DataFrame with 3 rows" in result + + +def test_extraction_context_str_with_multiline_attr(): + """Test string representation with multiline attribute""" + long_text = "Line 1\nLine 2\nLine 3\n" + "x" * 200 + ctx = ExtractionContext(attrs={"long_text": long_text}) + result = str(ctx) + assert "Attrs:" in result + assert "long_text" in result + assert len(result) < len(long_text) + 100 + + +def test_extraction_context_len(): + """Test __len__ returns number of documents""" + assert len(ExtractionContext()) == 0 + assert len(ExtractionContext(PDFDocument([]))) == 1 + assert len(ExtractionContext([PDFDocument([]), HTMLDocument([])])) == 2 + + +def test_extraction_context_getitem(): + """Test __getitem__ indexing""" + doc1 = PDFDocument(["first"]) + doc2 = HTMLDocument(["second"]) + ctx = ExtractionContext([doc1, doc2]) + + assert ctx[0] is doc1 + assert ctx[1] is doc2 + assert ctx[-1] is doc2 + + +def test_extraction_context_bool(): + """Test __bool__ conversion""" + assert not ExtractionContext() + assert not ExtractionContext(None) + assert ExtractionContext(PDFDocument([])) + assert ExtractionContext([PDFDocument([])]) + + +@pytest.mark.asyncio +async def test_mark_doc_as_data_source_no_file_move(): + """Test marking document without file moving""" + ctx = ExtractionContext() + doc = PDFDocument(["test content"]) + doc.attrs["source"] = "test.pdf" + + await ctx.mark_doc_as_data_source(doc) + + assert doc in ctx.data_docs + assert len(ctx.data_docs) == 1 + assert "out_fp" not in doc.attrs + + +@pytest.mark.asyncio +async def test_mark_doc_as_data_source_with_file_move(monkeypatch, tmp_path): + """Test marking document with file moving""" + out_file = tmp_path / "output.pdf" + + async def fake_file_mover(doc_arg, out_fn): # noqa + assert out_fn == "output.pdf" + return out_file + + monkeypatch.setattr(FileMover, "call", fake_file_mover) + + ctx = ExtractionContext() + doc = PDFDocument(["test content"]) + doc.attrs["source"] = "test.pdf" + + await ctx.mark_doc_as_data_source(doc, out_fn_stem="output.pdf") + + assert doc in ctx.data_docs + assert 
len(ctx.data_docs) == 1 + assert doc.attrs["out_fp"] == out_file + + +@pytest.mark.asyncio +async def test_move_file_to_out_dir(monkeypatch, tmp_path): + """Test _move_file_to_out_dir helper""" + output_path = tmp_path / "moved.pdf" + + async def fake_mover(doc_arg, out_fn): # noqa + assert out_fn == "output_name.pdf" + return output_path + + monkeypatch.setattr(FileMover, "call", fake_mover) + + doc = PDFDocument(["content"]) + doc.attrs["source"] = "original.pdf" + + result = await _move_file_to_out_dir(doc, "output_name.pdf") + + assert result is doc + assert doc.attrs["out_fp"] == output_path + + +@pytest.mark.parametrize( + "input_val", + [ + None, + PDFDocument([]), + [PDFDocument([])], + ["a", "b", "c"], + ("x", "y"), + {"key": "value"}, + ], +) +def test_as_list_conversions(input_val): + """Test _as_list helper with various inputs""" + result = _as_list(input_val) + assert isinstance(result, list) + if input_val is None: + assert result == [] + elif isinstance(input_val, (list, tuple)): + assert result == list(input_val) + else: + assert len(result) == 1 + + +def test_as_list_preserves_type(): + """Test _as_list converts to list properly""" + result = _as_list(("a", "b")) + assert isinstance(result, list) + assert not isinstance(result, tuple) + + +def test_data_docs_repr_empty(): + """Test _data_docs_repr with empty list""" + result = _data_docs_repr([]) + assert result == "Registered Data Source Documents: None" + + +def test_data_docs_repr_with_docs(): + """Test _data_docs_repr with documents""" + doc1 = PDFDocument(["test"]) + doc1.attrs["source"] = "source1.pdf" + doc2 = HTMLDocument(["html"]) + doc2.attrs["source"] = "source2.html" + + result = _data_docs_repr([doc1, doc2]) + assert "Registered Data Source Documents:" in result + assert "source1.pdf" in result + assert "source2.html" in result + + +def test_data_docs_repr_missing_source(): + """Test _data_docs_repr with missing source attribute""" + doc = PDFDocument(["test"]) + result = _data_docs_repr([doc]) + assert "Unknown source" in result + + +def test_attrs_repr_empty(): + """Test _attrs_repr with empty dict""" + result = _attrs_repr({}) + assert result == "Attrs: None" + + +def test_attrs_repr_simple_values(): + """Test _attrs_repr with simple key-value pairs""" + attrs = {"jurisdiction": "Test County", "year": 2025, "active": True} + result = _attrs_repr(attrs) + assert "Attrs:" in result + assert "jurisdiction" in result + assert "Test County" in result + assert "year" in result + assert "2025" in result + assert "active" in result + + +def test_attrs_repr_with_dataframe(): + """Test _attrs_repr formats DataFrames""" + df = pd.DataFrame({"a": [1, 2, 3, 4, 5]}) + attrs = {"my_table": df} + result = _attrs_repr(attrs) + assert "Attrs:" in result + assert "my_table" in result + assert "DataFrame with 5 rows" in result + + +def test_attrs_repr_with_multiline_string(): + """Test _attrs_repr shortens multiline strings""" + long_text = "\n".join([f"Line {i}" for i in range(50)]) + attrs = {"description": long_text} + result = _attrs_repr(attrs) + assert "Attrs:" in result + assert "description" in result + assert len(result) < len(long_text) + 50 + + +def test_attrs_repr_formatting_alignment(): + """Test _attrs_repr aligns values properly""" + attrs = {"short": "val", "very_long_key": "value2"} + result = _attrs_repr(attrs) + lines = result.split("\n")[1:] + assert len(lines) == 2 + assert "\t" in lines[0] + assert "\t" in lines[1] + + +if __name__ == "__main__": + pytest.main(["-q", "--show-capture=all", Path(__file__), 
"-rapP"]) diff --git a/tests/python/unit/plugin/test_plugin_interface.py b/tests/python/unit/plugin/test_plugin_interface.py new file mode 100644 index 000000000..fd79ff86b --- /dev/null +++ b/tests/python/unit/plugin/test_plugin_interface.py @@ -0,0 +1,176 @@ +"""COMPASS web crawling tests""" + +from pathlib import Path + +import pytest + +from compass.plugin.interface import ExtractionPlugin +from compass.exceptions import COMPASSPluginConfigurationError + + +def test_plugin_validation_parse_key_same(): + """Test plugin interface validation logic""" + + class COLL1: + OUT_LABEL = "collected" + + class EXT1: + IN_LABEL = "collected" + OUT_LABEL = "extracted" + + class EXT2: + IN_LABEL = "collected" + OUT_LABEL = "extracted_2" + + class PARS1: + IN_LABEL = "extracted" + OUT_LABEL = "parsed_1" + + class PARS2: + IN_LABEL = "collected" + OUT_LABEL = "parsed_1" + + class MYPlugin(ExtractionPlugin): + TEXT_COLLECTORS = [COLL1] + TEXT_EXTRACTORS = [EXT1, EXT2] + PARSERS = [PARS1, PARS2] + + IDENTIFIER = "test" + WEBSITE_KEYWORDS = [] + QUESTION_TEMPLATES = [] + heuristic = None + + with pytest.raises( + COMPASSPluginConfigurationError, + match="Multiple processing classes produce the same OUT_LABEL key", + ): + MYPlugin(None, None, None) + + +def test_plugin_validation_extract_key_same(): + """Test plugin interface validation logic""" + + class COLL1: + OUT_LABEL = "collected" + + class EXT1: + IN_LABEL = "collected" + OUT_LABEL = "extracted" + + class EXT2: + IN_LABEL = "collected" + OUT_LABEL = "extracted" + + class PARS1: + IN_LABEL = "extracted" + OUT_LABEL = "parsed_1" + + class PARS2: + IN_LABEL = "collected" + OUT_LABEL = "parsed_2" + + class MYPlugin(ExtractionPlugin): + TEXT_COLLECTORS = [COLL1] + TEXT_EXTRACTORS = [EXT1, EXT2] + PARSERS = [PARS1, PARS2] + + IDENTIFIER = "test" + WEBSITE_KEYWORDS = [] + QUESTION_TEMPLATES = [] + heuristic = None + + with pytest.raises( + COMPASSPluginConfigurationError, + match="Multiple processing classes produce the same OUT_LABEL key", + ): + MYPlugin(None, None, None) + + +def test_plugin_validation_no_in_key_for_extract(): + """Test plugin interface validation logic""" + + class COLL1: + OUT_LABEL = "collected" + + class EXT1: + IN_LABEL = "collected" + OUT_LABEL = "extracted" + + class EXT2: + IN_LABEL = "collected_2" + OUT_LABEL = "extracted_1" + + class PARS1: + IN_LABEL = "extracted" + OUT_LABEL = "parsed_1" + + class PARS2: + IN_LABEL = "collected" + OUT_LABEL = "parsed_2" + + class MYPlugin(ExtractionPlugin): + TEXT_COLLECTORS = [COLL1] + TEXT_EXTRACTORS = [EXT1, EXT2] + PARSERS = [PARS1, PARS2] + + IDENTIFIER = "test" + WEBSITE_KEYWORDS = [] + QUESTION_TEMPLATES = [] + heuristic = None + + with pytest.raises( + COMPASSPluginConfigurationError, + match=( + r"One or more processing classes require IN_LABEL 'collected_2', " + r"which is not produced by any previous processing class: " + r"\['EXT2'\]" + ), + ): + MYPlugin(None, None, None) + + +def test_plugin_validation_no_in_key_for_parse(): + """Test plugin interface validation logic""" + + class COLL1: + OUT_LABEL = "collected" + + class EXT1: + IN_LABEL = "collected" + OUT_LABEL = "extracted" + + class EXT2: + IN_LABEL = "collected" + OUT_LABEL = "extracted_1" + + class PARS1: + IN_LABEL = "extracted" + OUT_LABEL = "parsed_1" + + class PARS2: + IN_LABEL = "collected_2" + OUT_LABEL = "parsed_2" + + class MYPlugin(ExtractionPlugin): + TEXT_COLLECTORS = [COLL1] + TEXT_EXTRACTORS = [EXT1, EXT2] + PARSERS = [PARS1, PARS2] + + IDENTIFIER = "test" + WEBSITE_KEYWORDS = [] + QUESTION_TEMPLATES = [] 
+ heuristic = None + + with pytest.raises( + COMPASSPluginConfigurationError, + match=( + r"One or more processing classes require IN_LABEL 'collected_2', " + r"which is not produced by any previous processing class: " + r"\['PARS2'\]" + ), + ): + MYPlugin(None, None, None) + + +if __name__ == "__main__": + pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"]) diff --git a/tests/python/unit/scripts/test_process.py b/tests/python/unit/scripts/test_process.py index e9411c943..70ce550d4 100644 --- a/tests/python/unit/scripts/test_process.py +++ b/tests/python/unit/scripts/test_process.py @@ -274,7 +274,7 @@ async def test_process_steps_logged( assert result == f"processed {jurisdiction_fp}" assert_message_was_logged( - "Using the following processing step(s):", log_level="INFO" + "Using the following document acquisition step(s):", log_level="INFO" ) assert_message_was_logged(" -> ".join(expected_steps), log_level="INFO") diff --git a/tests/python/unit/services/test_services_threaded.py b/tests/python/unit/services/test_services_threaded.py index af45efdf1..566bbecd6 100644 --- a/tests/python/unit/services/test_services_threaded.py +++ b/tests/python/unit/services/test_services_threaded.py @@ -12,6 +12,7 @@ import pytest from elm.web.document import HTMLDocument +from compass.extraction.context import ExtractionContext from compass.services import threaded from compass.services.provider import RunningAsyncServices from compass.services.threaded import ( @@ -161,12 +162,10 @@ def test_move_file_uses_jurisdiction_name(tmp_path): cached_fp.write_text("content", encoding="utf-8") doc = HTMLDocument(["payload"]) - doc.attrs.update( - {"cache_fn": cached_fp, "jurisdiction_name": "Test County, ST"} - ) + doc.attrs.update({"cache_fn": cached_fp}) date = datetime.now().strftime("%Y_%m_%d") - moved_fp = threaded._move_file(doc, out_dir) + moved_fp = threaded._move_file(doc, out_dir, out_fn="Test County, ST") expected_name = f"Test_County_ST_downloaded_{date}.pdf" assert moved_fp.name == expected_name @@ -202,19 +201,21 @@ def test_write_cleaned_file_with_debug(tmp_path, monkeypatch): doc.attrs.update( { "jurisdiction_name": "Sample Jurisdiction", - "cleaned_ordinance_text": "clean", + "cleaned_text_for_extraction": "clean", "districts_text": "districts", - "ordinance_text": "orig", + "relevant_text": "orig", "permitted_use_text": "perm", "permitted_use_only_text": None, } ) monkeypatch.setattr(threaded, "COMPASS_DEBUG_LEVEL", 1, raising=False) - outputs = threaded._write_cleaned_file(doc, tmp_path) + outputs = threaded._write_cleaned_file( + doc, tmp_path, jurisdiction_name="Sample Jurisdiction" + ) expected_files = { - "Sample Jurisdiction Ordinance Summary.txt", + "Sample Jurisdiction Cleaned Text.txt", "Sample Jurisdiction Districts.txt", } assert {fp.name for fp in outputs} == expected_files @@ -229,7 +230,7 @@ def test_write_cleaned_file_without_jurisdiction_returns_none(tmp_path): """If jurisdiction name missing, cleaned file writer should do nothing""" doc = HTMLDocument(["payload"]) - doc.attrs["cleaned_ordinance_text"] = "clean" + doc.attrs["cleaned_text_for_extraction"] = "clean" assert threaded._write_cleaned_file(doc, tmp_path) is None @@ -237,15 +238,12 @@ def test_write_cleaned_file_skips_missing_section(tmp_path): """Missing sections should be skipped instead of erroring""" doc = HTMLDocument(["payload"]) - doc.attrs.update( - { - "jurisdiction_name": "Partial", - "cleaned_ordinance_text": "clean", - } - ) + doc.attrs.update({"cleaned_text_for_extraction": "clean"}) - 
outputs = threaded._write_cleaned_file(doc, tmp_path) - assert [fp.name for fp in outputs] == ["Partial Ordinance Summary.txt"] + outputs = threaded._write_cleaned_file( + doc, tmp_path, jurisdiction_name="Partial" + ) + assert [fp.name for fp in outputs] == ["Partial Cleaned Text.txt"] def test_write_ord_db_creates_csv(tmp_path): @@ -259,15 +257,10 @@ def test_write_ord_db_creates_csv(tmp_path): "other": [1], } ) - doc = HTMLDocument(["payload"]) - doc.attrs.update( - { - "jurisdiction_name": "Sample", - "scraped_values": df, - } + context = ExtractionContext(attrs={"structured_data": df}) + out_fp = threaded._write_ord_db( + context, tmp_path, out_fn="Sample Ordinances.csv" ) - - out_fp = threaded._write_ord_db(doc, tmp_path) assert out_fp.exists() assert ( out_fp.read_text(encoding="utf-8") @@ -278,8 +271,8 @@ def test_write_ord_db_creates_csv(tmp_path): def test_write_ord_db_requires_data(tmp_path): """Ord database writer returns None when data missing""" - doc = HTMLDocument(["payload"]) - assert threaded._write_ord_db(doc, tmp_path) is None + context = ExtractionContext() + assert threaded._write_ord_db(context, tmp_path, "") is None @pytest.mark.asyncio @@ -340,8 +333,7 @@ async def test_cleaned_file_writer_process(tmp_path, monkeypatch): doc = HTMLDocument(["payload"]) doc.attrs.update( { - "jurisdiction_name": "Writer", - "cleaned_ordinance_text": "clean", + "cleaned_text_for_extraction": "clean", "districts_text": "district", } ) @@ -349,12 +341,12 @@ async def test_cleaned_file_writer_process(tmp_path, monkeypatch): writer = CleanedFileWriter(tmp_path) assert writer.can_process is True writer.acquire_resources() - outputs = await writer.process(doc) + outputs = await writer.process(doc, "Writer") writer.release_resources() assert sorted(fp.name for fp in outputs) == [ + "Writer Cleaned Text.txt", "Writer Districts.txt", - "Writer Ordinance Summary.txt", ] @@ -369,17 +361,12 @@ async def test_ord_db_file_writer_process(tmp_path): "summary": ["s"], } ) - doc = HTMLDocument(["payload"]) - doc.attrs.update( - { - "jurisdiction_name": "Ord", - "scraped_values": df, - } - ) + context = ExtractionContext(attrs={"structured_data": df}) writer = OrdDBFileWriter(tmp_path) writer.acquire_resources() - out_fp = await writer.process(doc) + + out_fp = await writer.process(context, "Ord") writer.release_resources() assert out_fp.exists() @@ -464,10 +451,8 @@ async def test_jurisdiction_updater_process(tmp_path): "out_fp": tmp_path / "ord" / "doc.pdf", "checksum": "sha256:abc", "from_ocr": True, - "ordinance_text_ngram_score": 0.9, + "relevant_text_ngram_score": 0.9, "permitted_use_text_ngram_score": 0.8, - "jurisdiction_website": "http://jurisdiction.gov", - "compass_crawl": True, "ordinance_values": pd.DataFrame( { "feature": ["setback"], @@ -477,6 +462,14 @@ async def test_jurisdiction_updater_process(tmp_path): ), } ) + context = ExtractionContext( + doc, + attrs={ + "jurisdiction_website": "http://jurisdiction.gov", + "compass_crawl": True, + }, + ) + context.data_docs = [doc] tracker = SimpleNamespace( totals={ @@ -495,7 +488,7 @@ async def test_jurisdiction_updater_process(tmp_path): code="00002", ) - await updater.process(jur2, doc, 12.5, tracker) + await updater.process(jur2, context, 12.5, tracker) with jurisdiction_fp.open(encoding="utf-8") as fh: data = json.load(fh) @@ -513,25 +506,6 @@ async def test_jurisdiction_updater_process(tmp_path): updater.release_resources() -def test_compute_jurisdiction_cost_uses_registry(): - """Ensure model costs are computed using registry values""" 
- - tracker = SimpleNamespace( - totals={ - "gpt-4o": { - "prompt_tokens": 1_000_000, - "response_tokens": 1_000_000, - } - } - ) - assert threaded._compute_jurisdiction_cost(tracker) == pytest.approx(12.5) - - tracker_unknown = SimpleNamespace( - totals={"unknown": {"prompt_tokens": 1_000_000}} - ) - assert threaded._compute_jurisdiction_cost(tracker_unknown) == 0 - - def test_dump_usage_without_tracker_returns_existing_data(tmp_path): """_dump_usage should return existing data unchanged when tracker absent""" diff --git a/tests/python/unit/test_exceptions.py b/tests/python/unit/test_exceptions.py index 2cde1b63e..33e80ddc9 100644 --- a/tests/python/unit/test_exceptions.py +++ b/tests/python/unit/test_exceptions.py @@ -10,9 +10,11 @@ from compass.exceptions import ( COMPASSError, + COMPASSTypeError, COMPASSValueError, COMPASSNotInitializedError, COMPASSRuntimeError, + COMPASSPluginConfigurationError, ) @@ -55,6 +57,10 @@ def test_exceptions_log_uncaught_error(assert_message_was_logged): COMPASSNotInitializedError, [COMPASSError, COMPASSNotInitializedError], ), + ( + COMPASSTypeError, + [COMPASSError, TypeError, COMPASSTypeError], + ), ( COMPASSValueError, [COMPASSError, ValueError, COMPASSValueError], @@ -63,6 +69,15 @@ def test_exceptions_log_uncaught_error(assert_message_was_logged): COMPASSRuntimeError, [COMPASSError, RuntimeError, COMPASSRuntimeError], ), + ( + COMPASSPluginConfigurationError, + [ + COMPASSError, + RuntimeError, + COMPASSRuntimeError, + COMPASSPluginConfigurationError, + ], + ), ], ) def test_catching_error_by_type( diff --git a/tests/python/unit/utilities/test_utilities_costs.py b/tests/python/unit/utilities/test_utilities_costs.py new file mode 100644 index 000000000..158843b40 --- /dev/null +++ b/tests/python/unit/utilities/test_utilities_costs.py @@ -0,0 +1,520 @@ +"""Tests for COMPASS cost computation utilities""" + +from pathlib import Path + +import pytest + +from compass.utilities.costs import ( + LLM_COST_REGISTRY, + compute_cost_from_totals, + compute_total_cost_from_usage, + cost_for_model, +) + + +@pytest.mark.parametrize( + "model_name,prompt_tokens,completion_tokens,expected", + [ + ("gpt-4o", 1_000_000, 1_000_000, 12.5), + ("gpt-4o-mini", 1_000_000, 1_000_000, 0.75), + ("o1", 500_000, 500_000, 37.5), + ("gpt-5-nano", 2_000_000, 1_000_000, 0.5), + ("unknown-model", 1_000_000, 1_000_000, 0.0), + ("gpt-4o", 0, 0, 0.0), + ("gpt-4o", 100_000, 0, 0.25), + ("gpt-4o", 0, 100_000, 1.0), + ("gpt-4o", 2_500_000, 3_000_000, 36.25), + ("compassop-gpt-4.1-nano", 1_000_000, 500_000, 0.3), + ("wetosa-gpt-5-mini", 500_000, 500_000, 1.125), + ], +) +def test_cost_for_model_known_models( + model_name, prompt_tokens, completion_tokens, expected +): + """Test `cost_for_model` with various known models and token counts""" + result = cost_for_model(model_name, prompt_tokens, completion_tokens) + assert result == pytest.approx(expected) + + +@pytest.mark.parametrize( + "model_name,prompt_tokens,completion_tokens", + [ + ("", 1_000_000, 1_000_000), + ("GPT-4O", 1_000_000, 1_000_000), + ("gpt-4o-MINI", 1_000_000, 1_000_000), + ("gpt4o", 1_000_000, 1_000_000), + ], +) +def test_cost_for_model_case_sensitivity_and_unknown( + model_name, prompt_tokens, completion_tokens +): + """Test `cost_for_model` returns zero for bad inputs""" + result = cost_for_model(model_name, prompt_tokens, completion_tokens) + assert result == 0.0 + + +def test_cost_for_model_with_embedding_model(): + """Test `cost_for_model` with embedding-only model""" + result = cost_for_model("text-embedding-ada-002", 
1_000_000, 0) + assert result == pytest.approx(0.10) + + +def test_cost_for_model_with_large_token_counts(): + """Test `cost_for_model` handles very large token counts accurately""" + result = cost_for_model("gpt-4o", 100_000_000, 50_000_000) + assert result == pytest.approx(750.0) + + +@pytest.mark.parametrize( + "totals,expected", + [ + ( + { + "gpt-4o": { + "prompt_tokens": 1_000_000, + "response_tokens": 500_000, + } + }, + 7.5, + ), + ( + { + "gpt-4o": { + "prompt_tokens": 1_000_000, + "response_tokens": 500_000, + }, + "gpt-4o-mini": { + "prompt_tokens": 2_000_000, + "response_tokens": 1_000_000, + }, + }, + 8.4, + ), + ({}, 0.0), + ( + {"gpt-4o": {}}, + 0.0, + ), + ( + {"gpt-4o": {"prompt_tokens": 1_000_000}}, + 2.5, + ), + ( + {"gpt-4o": {"response_tokens": 1_000_000}}, + 10.0, + ), + ( + { + "gpt-4o": { + "prompt_tokens": 500_000, + "response_tokens": 200_000, + }, + "unknown-model": { + "prompt_tokens": 1_000_000, + "response_tokens": 1_000_000, + }, + }, + 3.25, + ), + ( + { + "o1": {"prompt_tokens": 1_000_000, "response_tokens": 500_000}, + "gpt-5-nano": { + "prompt_tokens": 2_000_000, + "response_tokens": 1_000_000, + }, + "gpt-4.1-mini": { + "prompt_tokens": 500_000, + "response_tokens": 500_000, + }, + }, + 46.5, + ), + ], +) +def test_compute_cost_from_totals(totals, expected): + """Test `compute_cost_from_totals` with various total configurations""" + result = compute_cost_from_totals(totals) + assert result == pytest.approx(expected) + + +def test_compute_cost_from_totals_with_extra_keys(): + """Test `compute_cost_from_totals` ignores extra keys in usage dict""" + totals = { + "gpt-4o": { + "prompt_tokens": 1_000_000, + "response_tokens": 500_000, + "extra_key": "ignored", + "another_key": 999, + } + } + result = compute_cost_from_totals(totals) + assert result == pytest.approx(7.5) + + +@pytest.mark.parametrize( + "tracked_usage,expected", + [ + ( + { + "location1": { + "tracker_totals": { + "gpt-4o": { + "prompt_tokens": 1_000_000, + "response_tokens": 500_000, + } + } + } + }, + 7.5, + ), + ( + { + "location1": { + "tracker_totals": { + "gpt-4o": { + "prompt_tokens": 1_000_000, + "response_tokens": 500_000, + } + } + }, + "location2": { + "tracker_totals": { + "gpt-4o-mini": { + "prompt_tokens": 2_000_000, + "response_tokens": 1_000_000, + } + } + }, + }, + 8.4, + ), + ({}, 0.0), + ( + {"location1": {}}, + 0.0, + ), + ( + {"location1": {"tracker_totals": {}}}, + 0.0, + ), + ( + { + "location1": { + "tracker_totals": { + "gpt-4o": { + "prompt_tokens": 500_000, + "response_tokens": 200_000, + } + } + }, + "location2": { + "tracker_totals": { + "unknown-model": { + "prompt_tokens": 1_000_000, + "response_tokens": 1_000_000, + } + } + }, + }, + 3.25, + ), + ( + { + "new_york_county": { + "tracker_totals": { + "gpt-4o": { + "prompt_tokens": 800_000, + "response_tokens": 400_000, + }, + "gpt-4o-mini": { + "prompt_tokens": 1_500_000, + "response_tokens": 750_000, + }, + } + }, + "california_county": { + "tracker_totals": { + "o1": { + "prompt_tokens": 500_000, + "response_tokens": 250_000, + }, + } + }, + "texas_county": { + "tracker_totals": { + "gpt-5-nano": { + "prompt_tokens": 3_000_000, + "response_tokens": 2_000_000, + }, + } + }, + }, + 30.125, + ), + ], +) +def test_compute_total_cost_from_usage(tracked_usage, expected): + """Test `compute_total_cost_from_usage` with various usage configs""" + result = compute_total_cost_from_usage(tracked_usage) + assert result == pytest.approx(expected) + + +def test_compute_total_cost_from_usage_with_extra_keys(): + """Test 
`compute_total_cost_from_usage` ignores extra keys in usage dict""" + tracked_usage = { + "location1": { + "tracker_totals": { + "gpt-4o": { + "prompt_tokens": 1_000_000, + "response_tokens": 500_000, + } + }, + "extra_field": "ignored", + "timestamp": "2026-01-01", + } + } + result = compute_total_cost_from_usage(tracked_usage) + assert result == pytest.approx(7.5) + + +def test_integration_single_jurisdiction_workflow(): + """Test complete workflow from model costs to total tracked usage""" + prompt_tokens = 1_000_000 + completion_tokens = 500_000 + + model_cost = cost_for_model("gpt-4o", prompt_tokens, completion_tokens) + assert model_cost == pytest.approx(7.5) + + totals = { + "gpt-4o": { + "prompt_tokens": prompt_tokens, + "response_tokens": completion_tokens, + } + } + totals_cost = compute_cost_from_totals(totals) + assert totals_cost == pytest.approx(7.5) + assert totals_cost == pytest.approx(model_cost) + + tracked_usage = {"jurisdiction1": {"tracker_totals": totals}} + total_cost = compute_total_cost_from_usage(tracked_usage) + assert total_cost == pytest.approx(7.5) + assert total_cost == pytest.approx(totals_cost) + + +def test_integration_multi_jurisdiction_multi_model_workflow(): + """Test complete workflow with multiple jurisdictions and models""" + jurisdiction_configs = [ + ("california", "gpt-4o", 1_000_000, 500_000), + ("texas", "gpt-4o-mini", 2_000_000, 1_000_000), + ("new_york", "o1", 500_000, 250_000), + ] + + expected_individual_costs = [] + tracked_usage = {} + + for jurisdiction, model, prompt, completion in jurisdiction_configs: + individual_cost = cost_for_model(model, prompt, completion) + expected_individual_costs.append(individual_cost) + + totals = { + model: {"prompt_tokens": prompt, "response_tokens": completion} + } + totals_cost = compute_cost_from_totals(totals) + assert totals_cost == pytest.approx(individual_cost) + + tracked_usage[jurisdiction] = {"tracker_totals": totals} + + total_cost = compute_total_cost_from_usage(tracked_usage) + expected_total = sum(expected_individual_costs) + assert total_cost == pytest.approx(expected_total) + assert total_cost == pytest.approx(30.9) + + +def test_integration_mixed_known_unknown_models(): + """Test integration with mix of known and unknown models""" + totals = { + "gpt-4o": {"prompt_tokens": 500_000, "response_tokens": 200_000}, + "unknown-model-1": { + "prompt_tokens": 1_000_000, + "response_tokens": 1_000_000, + }, + "gpt-4o-mini": { + "prompt_tokens": 1_000_000, + "response_tokens": 500_000, + }, + "unknown-model-2": { + "prompt_tokens": 500_000, + "response_tokens": 500_000, + }, + } + + totals_cost = compute_cost_from_totals(totals) + + tracked_usage = {"jurisdiction": {"tracker_totals": totals}} + total_cost = compute_total_cost_from_usage(tracked_usage) + + assert totals_cost == pytest.approx(total_cost) + assert total_cost == pytest.approx(3.7) + + +def test_llm_cost_registry_structure(): + """Test LLM_COST_REGISTRY has expected structure""" + assert isinstance(LLM_COST_REGISTRY, dict) + assert len(LLM_COST_REGISTRY) > 0 + + for model_name, costs in LLM_COST_REGISTRY.items(): + assert isinstance(model_name, str) + assert len(model_name) > 0 + assert isinstance(costs, dict) + assert "prompt" in costs + assert isinstance(costs["prompt"], (int, float)) + assert costs["prompt"] > 0 + + +def test_llm_cost_registry_response_costs(): + """Test models with response costs have valid values""" + models_with_response = [ + model + for model, costs in LLM_COST_REGISTRY.items() + if "response" in costs + ] + + 
assert len(models_with_response) > 0
+
+    for model in models_with_response:
+        response_cost = LLM_COST_REGISTRY[model]["response"]
+        assert isinstance(response_cost, (int, float))
+        assert response_cost > 0
+
+
+def test_llm_cost_registry_embedding_models():
+    """Test embedding models have prompt cost but may lack response cost"""
+    embedding_model = "text-embedding-ada-002"
+    assert embedding_model in LLM_COST_REGISTRY
+    assert "prompt" in LLM_COST_REGISTRY[embedding_model]
+    assert "response" not in LLM_COST_REGISTRY[embedding_model]
+
+
+def test_llm_cost_registry_model_name_patterns():
+    """Test registry contains expected model name patterns"""
+    model_names = list(LLM_COST_REGISTRY.keys())
+
+    assert any("gpt-4o" in name for name in model_names)
+    assert any("gpt-5" in name for name in model_names)
+    assert any("compassop" in name for name in model_names)
+    assert any("wetosa" in name for name in model_names)
+
+
+def test_llm_cost_registry_response_higher_than_prompt():
+    """Test response costs are typically higher than prompt costs"""
+    models_with_both = [
+        model
+        for model, costs in LLM_COST_REGISTRY.items()
+        if "response" in costs and "prompt" in costs
+    ]
+
+    higher_response_count = sum(
+        1
+        for model in models_with_both
+        if LLM_COST_REGISTRY[model]["response"]
+        > LLM_COST_REGISTRY[model]["prompt"]
+    )
+
+    assert higher_response_count > len(models_with_both) * 0.8
+
+
+def test_llm_cost_registry_no_negative_costs():
+    """Test registry contains no negative cost values"""
+    for model_name, costs in LLM_COST_REGISTRY.items():
+        for cost_type, cost_value in costs.items():
+            assert cost_value >= 0, (
+                f"Negative cost for {model_name}.{cost_type}"
+            )
+
+
+def test_cost_for_model_with_negative_tokens():
+    """Test `cost_for_model` passes negative token counts straight through"""
+    result = cost_for_model("gpt-4o", -1_000_000, -500_000)
+    assert result == pytest.approx(-7.5)
+
+
+def test_compute_cost_from_totals_with_negative_tokens():
+    """Test `compute_cost_from_totals` with negative token values"""
+    totals = {
+        "gpt-4o": {"prompt_tokens": -1_000_000, "response_tokens": 500_000}
+    }
+    result = compute_cost_from_totals(totals)
+    assert result == pytest.approx(2.5)
+
+
+def test_cost_calculation_precision():
+    """Test cost calculations maintain precision with small token counts"""
+    result = cost_for_model("gpt-4o", 1, 1)
+    expected = (1 / 1e6 * 2.5) + (1 / 1e6 * 10)
+    assert result == pytest.approx(expected)
+    assert result == pytest.approx(0.0000125)
+
+
+def test_compute_total_cost_from_usage_deeply_nested():
+    """Test `compute_total_cost_from_usage` with realistic nested structure"""
+    tracked_usage = {
+        "jurisdiction_1": {
+            "tracker_totals": {
+                "gpt-4o": {
+                    "prompt_tokens": 500_000,
+                    "response_tokens": 250_000,
+                },
+                "gpt-4o-mini": {
+                    "prompt_tokens": 1_000_000,
+                    "response_tokens": 500_000,
+                },
+            },
+            "metadata": {"runtime": 120.5},
+        },
+        "jurisdiction_2": {
+            "tracker_totals": {
+                "gpt-5-nano": {
+                    "prompt_tokens": 2_000_000,
+                    "response_tokens": 1_000_000,
+                },
+            },
+            "metadata": {"runtime": 95.3},
+        },
+    }
+
+    result = compute_total_cost_from_usage(tracked_usage)
+    expected = (
+        (500_000 / 1e6 * 2.5 + 250_000 / 1e6 * 10)
+        + (1_000_000 / 1e6 * 0.15 + 500_000 / 1e6 * 0.6)
+        + (2_000_000 / 1e6 * 0.05 + 1_000_000 / 1e6 * 0.4)
+    )
+    assert result == pytest.approx(expected)
+    assert result == pytest.approx(4.7)
+
+
+def test_compute_jurisdiction_cost_uses_registry():
+    """Ensure model costs are computed using registry values"""
+
+    tracker = {
+        "jurisdiction_1": {
+            
"tracker_totals": { + "gpt-4o": { + "prompt_tokens": 1_000_000, + "response_tokens": 1_000_000, + } + } + } + } + assert compute_total_cost_from_usage(tracker) == pytest.approx(12.5) + + tracker_unknown = { + "jurisdiction_1": { + "tracker_totals": {"unknown": {"prompt_tokens": 1_000_000}} + } + } + assert compute_total_cost_from_usage(tracker_unknown) == 0 + + +if __name__ == "__main__": + pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"]) diff --git a/tests/python/unit/utilities/test_utilities_finalize.py b/tests/python/unit/utilities/test_utilities_finalize.py index 2d01fca3a..1fd456458 100644 --- a/tests/python/unit/utilities/test_utilities_finalize.py +++ b/tests/python/unit/utilities/test_utilities_finalize.py @@ -173,6 +173,8 @@ def test_doc_infos_to_db_compiles_and_formats(tmp_path): "value": 100, "units": "ft", "adder": 300, + "source": "http://example.com/valid", + "ord_year": 2022, } ] ).to_csv(valid_csv, index=False) @@ -196,8 +198,6 @@ def test_doc_infos_to_db_compiles_and_formats(tmp_path): }, { "ord_db_fp": valid_csv, - "source": "http://example.com/valid", - "date": (2022, 3, 4), "jurisdiction": jurisdiction, }, ] @@ -289,16 +289,9 @@ def test_db_results_populates_jurisdiction_fields(): subdivision_name="Subdivision B", type="city", ) - doc_info = { - "source": "http://example.com", - "date": (2021, 5, 6), - "jurisdiction": jurisdiction, - } - result = finalize._db_results(base_df.copy(), doc_info) + result = finalize._db_results(base_df.copy(), jurisdiction) row = result.iloc[0] - assert row["source"] == "http://example.com" - assert row["ord_year"] == 2021 assert row["FIPS"] == "54321" assert row["county"] == "County B" assert row["jurisdiction_type"] == "city" diff --git a/tests/python/unit/utilities/test_utilities_jurisdictions.py b/tests/python/unit/utilities/test_utilities_jurisdictions.py index f089f7528..b30f3e8dc 100644 --- a/tests/python/unit/utilities/test_utilities_jurisdictions.py +++ b/tests/python/unit/utilities/test_utilities_jurisdictions.py @@ -10,6 +10,9 @@ load_all_jurisdiction_info, load_jurisdictions_from_fp, jurisdiction_websites, + jurisdictions_from_df, + Jurisdiction, + _JURISDICTION_TYPES_AS_PREFIXES, ) from compass.exceptions import COMPASSValueError from compass.warn import COMPASSWarning @@ -197,5 +200,546 @@ def test_load_jurisdictions_no_repeated_townships_and_counties(tmp_path): assert {type(val) for val in jurisdictions["FIPS"]} == {int} +def test_basic_state_properties(): + """Test basic properties for ``Jurisdiction`` class for a state""" + + state = Jurisdiction("state", state="Colorado") + + assert repr(state) == "Colorado" + assert state.full_name == "Colorado" + assert state.full_name == str(state) + + assert not state.full_county_phrase + assert not state.full_subdivision_phrase + + assert state == Jurisdiction("state", state="cOlORAdo") + assert state != Jurisdiction("city", state="Colorado") + + assert state == "Colorado" + assert state == "colorado" + + +def test_basic_county_properties(): + """Test basic properties for ``Jurisdiction`` class for a county""" + + county = Jurisdiction("county", county="Box Elder", state="Utah") + + assert repr(county) == "Box Elder County, Utah" + assert county.full_name == "Box Elder County, Utah" + assert county.full_name == str(county) + + assert county.full_county_phrase == "Box Elder County" + assert not county.full_subdivision_phrase + + assert county != Jurisdiction("county", county="Box elder", state="uTah") + assert county != Jurisdiction("city", county="Box Elder", 
state="Utah") + + assert county == "Box Elder County, Utah" + assert county == "Box elder county, Utah" + + +def test_basic_parish_properties(): + """Test basic properties for ``Jurisdiction`` class for a parish""" + + parish = Jurisdiction("parish", county="Assumption", state="Louisiana") + + assert repr(parish) == "Assumption Parish, Louisiana" + assert parish.full_name == "Assumption Parish, Louisiana" + assert parish.full_name == str(parish) + + assert parish.full_county_phrase == "Assumption Parish" + assert not parish.full_subdivision_phrase + + assert parish == Jurisdiction( + "parish", county="Assumption", state="lOuisiana" + ) + assert parish != Jurisdiction( + "parish", county="assumption", state="lOuisiana" + ) + assert parish != Jurisdiction( + "county", county="Assumption", state="Louisiana" + ) + + assert parish == "Assumption Parish, Louisiana" + assert parish == "assumption parish, lOuisiana" + + +@pytest.mark.parametrize("jt", ["town", "city", "borough", "township"]) +def test_basic_town_properties(jt): + """Test basic properties for ``Jurisdiction`` class for a town""" + + town = Jurisdiction( + jt, county="Jefferson", state="Colorado", subdivision_name="Golden" + ) + + assert repr(town) == f"{jt.title()} of Golden, Jefferson County, Colorado" + assert ( + town.full_name == f"{jt.title()} of Golden, Jefferson County, Colorado" + ) + assert town.full_name == str(town) + assert town.full_county_phrase == "Jefferson County" + assert town.full_subdivision_phrase == f"{jt.title()} of Golden" + + assert town == Jurisdiction( + jt, county="Jefferson", state="colorado", subdivision_name="Golden" + ) + assert town != Jurisdiction( + jt, county="jefferson", state="colorado", subdivision_name="Golden" + ) + assert town != Jurisdiction( + jt, county="Jefferson", state="colorado", subdivision_name="golden" + ) + assert town != Jurisdiction( + "county", + county="Jefferson", + state="Colorado", + subdivision_name="Golden", + ) + + assert town == f"{jt.title()} of Golden, Jefferson County, Colorado" + assert town == f"{jt.title()} of golden, jefferson county, colorado" + + +def test_atypical_subdivision_properties(): + """Test basic properties for ``Jurisdiction`` class for a subdivision""" + + gore = Jurisdiction( + "gore", county="Chittenden", state="Vermont", subdivision_name="Buels" + ) + + assert repr(gore) == "Buels Gore, Chittenden County, Vermont" + assert gore.full_name == "Buels Gore, Chittenden County, Vermont" + assert gore.full_name == str(gore) + assert gore.full_county_phrase == "Chittenden County" + assert gore.full_subdivision_phrase == "Buels Gore" + + assert gore == Jurisdiction( + "gore", county="Chittenden", state="vermont", subdivision_name="Buels" + ) + assert gore != Jurisdiction( + "gore", county="chittenden", state="vermont", subdivision_name="Buels" + ) + assert gore != Jurisdiction( + "gore", county="Chittenden", state="vermont", subdivision_name="buels" + ) + assert gore != Jurisdiction( + "county", + county="Chittenden", + state="Vermont", + subdivision_name="Buels", + ) + + assert gore == "Buels Gore, Chittenden County, Vermont" + assert gore == "buels gOre, chittENden county, vermonT" + + +def test_city_no_county(): + """Test ``Jurisdiction`` for a city with no county""" + + gore = Jurisdiction("city", "Maryland", subdivision_name="Baltimore") + + assert repr(gore) == "City of Baltimore, Maryland" + assert gore.full_name == "City of Baltimore, Maryland" + assert gore.full_name == str(gore) + + assert not gore.full_county_phrase + assert 
gore.full_subdivision_phrase == "City of Baltimore" + + assert gore == Jurisdiction( + "city", "maryland", subdivision_name="Baltimore" + ) + assert gore != Jurisdiction( + "city", "maryland", subdivision_name="baltimore" + ) + assert gore != Jurisdiction( + "county", "maryland", subdivision_name="baltimore" + ) + + assert gore == "City of Baltimore, Maryland" + assert gore == "ciTy of baltiMore, maryland" + + +def test_full_name_the_prefixed_property(): + """Test ``Jurisdiction.full_name_the_prefixed`` property""" + + state = Jurisdiction("state", state="Colorado") + assert state.full_name_the_prefixed == "the state of Colorado" + + county = Jurisdiction("county", state="Colorado", county="Jefferson") + assert county.full_name_the_prefixed == "Jefferson County, Colorado" + + city = Jurisdiction( + "city", state="Colorado", county="Jefferson", subdivision_name="Golden" + ) + assert ( + city.full_name_the_prefixed + == "the City of Golden, Jefferson County, Colorado" + ) + + for st in _JURISDICTION_TYPES_AS_PREFIXES: + jur = Jurisdiction(st, state="Colorado", subdivision_name="Test") + assert ( + jur.full_name_the_prefixed == f"the {st.title()} of Test, Colorado" + ) + + jur = Jurisdiction(st, state="Colorado", subdivision_name="Test") + assert jur.full_name_the_prefixed == f"the {st.title()} of Test, Colorado" + + jur = Jurisdiction( + "census county division", + state="Colorado", + county="Test a", + subdivision_name="Test b", + ) + + assert ( + jur.full_name_the_prefixed + == "Test b Census County Division, Test a County, Colorado" + ) + + +def test_full_subdivision_phrase_the_prefixed_property(): + """Test ``Jurisdiction.full_subdivision_phrase_the_prefixed`` property""" + + for st in _JURISDICTION_TYPES_AS_PREFIXES: + jur = Jurisdiction(st, state="Colorado", subdivision_name="Test") + assert ( + jur.full_subdivision_phrase_the_prefixed + == f"the {st.title()} of Test" + ) + + jur = Jurisdiction( + "census county division", + state="Colorado", + county="Test a", + subdivision_name="Test b", + ) + + assert ( + jur.full_subdivision_phrase_the_prefixed + == "Test b Census County Division" + ) + + +def test_jurisdictions_from_df_basic(): + """Test ``jurisdictions_from_df`` generator with various row types""" + + jurisdictions_df = pd.DataFrame( + { + "Jurisdiction Type": ["state", "county", "city"], + "State": ["Colorado", "Utah", "Texas"], + "County": [None, "Box Elder", "Travis"], + "Subdivision": [None, None, "Austin"], + "FIPS": [8, 49003, 48453], + "Website": [ + "https://colorado.gov", + "https://boxeldercounty.org", + "https://austintexas.gov", + ], + } + ) + + jurisdictions = list(jurisdictions_from_df(jurisdictions_df)) + + assert len(jurisdictions) == 3 + + state_jur = jurisdictions[0] + assert state_jur.type == "State" + assert state_jur.state == "Colorado" + assert state_jur.county is None + assert state_jur.subdivision_name is None + assert state_jur.code == 8 + assert state_jur.website_url == "https://colorado.gov" + assert state_jur.full_name == "Colorado" + + county_jur = jurisdictions[1] + assert county_jur.type == "County" + assert county_jur.state == "Utah" + assert county_jur.county == "Box Elder" + assert county_jur.subdivision_name is None + assert county_jur.code == 49003 + assert county_jur.website_url == "https://boxeldercounty.org" + assert county_jur.full_name == "Box Elder County, Utah" + + city_jur = jurisdictions[2] + assert city_jur.type == "City" + assert city_jur.state == "Texas" + assert city_jur.county == "Travis" + assert city_jur.subdivision_name == 
"Austin" + assert city_jur.code == 48453 + assert city_jur.website_url == "https://austintexas.gov" + assert city_jur.full_name == "City of Austin, Travis County, Texas" + + +def test_jurisdictions_from_df_with_none_values(): + """Test ``jurisdictions_from_df`` handles None/missing values properly""" + + jurisdictions_df = pd.DataFrame( + { + "Jurisdiction Type": ["county"], + "State": ["Indiana"], + "County": ["Decatur"], + "Subdivision": [None], + "FIPS": [18031], + "Website": [None], + } + ) + + jurisdictions = list(jurisdictions_from_df(jurisdictions_df)) + + assert len(jurisdictions) == 1 + jur = jurisdictions[0] + assert jur.type == "County" + assert jur.state == "Indiana" + assert jur.county == "Decatur" + assert jur.subdivision_name is None + assert jur.code == 18031 + assert jur.website_url is None + + +def test_jurisdictions_from_df_texas_water_districts(): + """Test ``jurisdictions_from_df`` with Texas water district pattern""" + + jurisdictions_df = pd.DataFrame( + { + "Jurisdiction Type": [ + "Authority & Groundwater District", + "Aquifer Conservation District", + ], + "State": ["Texas", "Texas"], + "County": [None, None], + "Subdivision": [ + "Bandera County River", + "Barton Springs/Edwards", + ], + "FIPS": [1, 2], + "Website": [ + "https://bcragd.org", + "https://www.bseacd.org", + ], + } + ) + + jurisdictions = list(jurisdictions_from_df(jurisdictions_df)) + + assert len(jurisdictions) == 2 + + district1 = jurisdictions[0] + assert district1.type == "Authority & Groundwater District" + assert district1.state == "Texas" + assert district1.county is None + assert district1.subdivision_name == "Bandera County River" + assert district1.code == 1 + assert ( + district1.full_name + == "Bandera County River Authority & Groundwater District, Texas" + ) + + district2 = jurisdictions[1] + assert district2.type == "Aquifer Conservation District" + assert district2.state == "Texas" + assert district2.county is None + assert district2.subdivision_name == "Barton Springs/Edwards" + assert district2.code == 2 + assert ( + district2.full_name + == "Barton Springs/Edwards Aquifer Conservation District, Texas" + ) + + +def test_jurisdiction_equality_with_non_string_non_jurisdiction(): + """Test ``Jurisdiction.__eq__`` with incompatible types returns False""" + + jur = Jurisdiction("county", state="Colorado", county="Jefferson") + + assert jur is not None + assert jur != 42 + assert jur != ["Jefferson County", "Colorado"] + assert jur != {"county": "Jefferson", "state": "Colorado"} + + +def test_jurisdiction_hash_consistency(): + """Test that equal jurisdictions have the same hash""" + + jur1 = Jurisdiction("county", state="Colorado", county="Jefferson") + jur2 = Jurisdiction("county", state="colorado", county="Jefferson") + jur3 = Jurisdiction("County", state="COLORADO", county="Jefferson") + + assert hash(jur1) == hash(jur2) == hash(jur3) + + jur_set = {jur1, jur2, jur3} + assert len(jur_set) == 1 + + +def test_jurisdiction_hash_different_types_same_name(): + """Test that jurisdictions hash correctly""" + + county = Jurisdiction("county", state="Virginia", county="Alexandria") + city = Jurisdiction( + "city", state="Virginia", subdivision_name="Alexandria" + ) + + assert county.full_name != city.full_name + assert hash(county) != hash(city) + + +def test_full_county_phrase_with_subdivision(): + """Test ``full_county_phrase`` when subdivision exists and county exists""" + + jur = Jurisdiction( + "town", + state="Maine", + county="Aroostook", + subdivision_name="Perham", + ) + + assert 
jur.full_county_phrase == "Aroostook County" + assert jur.full_subdivision_phrase == "Town of Perham" + assert jur.full_name == "Town of Perham, Aroostook County, Maine" + + +def test_full_subdivision_phrase_non_prefix_type(): + """Test ``full_subdivision_phrase`` with non-prefix jurisdiction type""" + + gore = Jurisdiction( + "gore", state="Vermont", county="Chittenden", subdivision_name="Buels" + ) + + assert gore.full_subdivision_phrase == "Buels Gore" + assert "gore" not in _JURISDICTION_TYPES_AS_PREFIXES + + +def test_full_name_the_prefixed_non_prefix_type(): + """Test ``full_name_the_prefixed`` for non-state, non-prefix types""" + + parish = Jurisdiction("parish", state="Louisiana", county="Assumption") + assert parish.full_name_the_prefixed == "Assumption Parish, Louisiana" + + gore = Jurisdiction( + "gore", state="Vermont", county="Chittenden", subdivision_name="Buels" + ) + assert ( + gore.full_name_the_prefixed == "Buels Gore, Chittenden County, Vermont" + ) + + +def test_jurisdiction_websites_custom_dataframe(): + """Test ``jurisdiction_websites`` with explicitly passed DataFrame""" + + custom_df = pd.DataFrame( + { + "County": ["Test County", "Another County"], + "State": ["Colorado", "Utah"], + "Subdivision": [None, None], + "Jurisdiction Type": ["county", "county"], + "FIPS": [99001, 99002], + "Website": ["https://test.gov", "https://another.gov"], + } + ) + + websites = jurisdiction_websites(jurisdiction_info=custom_df) + + assert len(websites) == 2 + assert websites[99001] == "https://test.gov" + assert websites[99002] == "https://another.gov" + + +def test_load_jurisdictions_from_fp_missing_jurisdiction_type(tmp_path): + """Test error when Subdivision provided without Jurisdiction Type column""" + + test_jurisdiction_fp = tmp_path / "out.csv" + input_jurisdictions = pd.DataFrame( + { + "County": ["Aroostook"], + "State": ["Maine"], + "Subdivision": ["Perham"], + } + ) + input_jurisdictions.to_csv(test_jurisdiction_fp) + + with pytest.raises(COMPASSValueError) as exc_info: + load_jurisdictions_from_fp(test_jurisdiction_fp) + + error_msg = str(exc_info.value) + assert "Jurisdiction Type" in error_msg + assert "Subdivision" in error_msg + assert "must have" in error_msg + + +def test_load_jurisdictions_from_fp_warning_message_content(tmp_path): + """Test that warning message contains key information""" + + test_jurisdiction_fp = tmp_path / "out.csv" + input_jurisdictions = pd.DataFrame( + { + "County": ["Fake County", "Another Fake"], + "State": ["Colorado", "Texas"], + } + ) + input_jurisdictions.to_csv(test_jurisdiction_fp) + + with pytest.warns(COMPASSWarning) as record: + jurisdictions = load_jurisdictions_from_fp(test_jurisdiction_fp) + + assert len(record) == 1 + warning_msg = str(record[0].message) + + assert "not found" in warning_msg.lower() + assert "Fake County" in warning_msg + assert "Another Fake" in warning_msg + assert "Colorado" in warning_msg + assert "Texas" in warning_msg + assert ( + "spelling" in warning_msg.lower() + or "capitalization" in warning_msg.lower() + ) + + assert len(jurisdictions) == 0 + + +def test_load_jurisdictions_from_fp_whitespace_trimming(tmp_path): + """Test that input jurisdictions have whitespace trimmed""" + + test_jurisdiction_fp = tmp_path / "out.csv" + input_jurisdictions = pd.DataFrame( + { + "County": [" Decatur ", "Wharton"], + "State": [" Indiana ", " Texas "], + } + ) + input_jurisdictions.to_csv(test_jurisdiction_fp) + + jurisdictions = load_jurisdictions_from_fp(test_jurisdiction_fp) + + assert len(jurisdictions) == 2 + 
assert "Decatur" in set(jurisdictions["County"]) + assert "Wharton" in set(jurisdictions["County"]) + assert "Indiana" in set(jurisdictions["State"]) + assert "Texas" in set(jurisdictions["State"]) + + +def test_load_jurisdictions_from_fp_subdivision_whitespace_trimming(tmp_path): + """Test whitespace trimming for subdivisions and jurisdiction types""" + + test_jurisdiction_fp = tmp_path / "out.csv" + input_jurisdictions = pd.DataFrame( + { + "County": [" Aroostook "], + "State": [" Maine "], + "Subdivision": [" Perham "], + "Jurisdiction Type": [" town "], + } + ) + input_jurisdictions.to_csv(test_jurisdiction_fp) + + jurisdictions = load_jurisdictions_from_fp(test_jurisdiction_fp) + + assert len(jurisdictions) == 1 + assert jurisdictions.iloc[0]["County"] == "Aroostook" + assert jurisdictions.iloc[0]["State"] == "Maine" + assert jurisdictions.iloc[0]["Subdivision"] == "Perham" + assert jurisdictions.iloc[0]["Jurisdiction Type"] == "town" + + if __name__ == "__main__": pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"]) diff --git a/tests/python/unit/utilities/test_utilities_location.py b/tests/python/unit/utilities/test_utilities_location.py deleted file mode 100644 index debcd50bc..000000000 --- a/tests/python/unit/utilities/test_utilities_location.py +++ /dev/null @@ -1,235 +0,0 @@ -"""COMPASS Ordinance Location utility tests""" - -from pathlib import Path - -import pytest - -from compass.utilities.location import ( - Jurisdiction, - JURISDICTION_TYPES_AS_PREFIXES, -) - - -def test_basic_state_properties(): - """Test basic properties for ``Jurisdiction`` class for a state""" - - state = Jurisdiction("state", state="Colorado") - - assert repr(state) == "Colorado" - assert state.full_name == "Colorado" - assert state.full_name == str(state) - - assert not state.full_county_phrase - assert not state.full_subdivision_phrase - - assert state == Jurisdiction("state", state="cOlORAdo") - assert state != Jurisdiction("city", state="Colorado") - - assert state == "Colorado" - assert state == "colorado" - - -def test_basic_county_properties(): - """Test basic properties for ``Jurisdiction`` class for a county""" - - county = Jurisdiction("county", county="Box Elder", state="Utah") - - assert repr(county) == "Box Elder County, Utah" - assert county.full_name == "Box Elder County, Utah" - assert county.full_name == str(county) - - assert county.full_county_phrase == "Box Elder County" - assert not county.full_subdivision_phrase - - assert county != Jurisdiction("county", county="Box elder", state="uTah") - assert county != Jurisdiction("city", county="Box Elder", state="Utah") - - assert county == "Box Elder County, Utah" - assert county == "Box elder county, Utah" - - -def test_basic_parish_properties(): - """Test basic properties for ``Jurisdiction`` class for a parish""" - - parish = Jurisdiction("parish", county="Assumption", state="Louisiana") - - assert repr(parish) == "Assumption Parish, Louisiana" - assert parish.full_name == "Assumption Parish, Louisiana" - assert parish.full_name == str(parish) - - assert parish.full_county_phrase == "Assumption Parish" - assert not parish.full_subdivision_phrase - - assert parish == Jurisdiction( - "parish", county="Assumption", state="lOuisiana" - ) - assert parish != Jurisdiction( - "parish", county="assumption", state="lOuisiana" - ) - assert parish != Jurisdiction( - "county", county="Assumption", state="Louisiana" - ) - - assert parish == "Assumption Parish, Louisiana" - assert parish == "assumption parish, lOuisiana" - - 
-@pytest.mark.parametrize("jt", ["town", "city", "borough", "township"]) -def test_basic_town_properties(jt): - """Test basic properties for ``Jurisdiction`` class for a town""" - - town = Jurisdiction( - jt, county="Jefferson", state="Colorado", subdivision_name="Golden" - ) - - assert repr(town) == f"{jt.title()} of Golden, Jefferson County, Colorado" - assert ( - town.full_name == f"{jt.title()} of Golden, Jefferson County, Colorado" - ) - assert town.full_name == str(town) - assert town.full_county_phrase == "Jefferson County" - assert town.full_subdivision_phrase == f"{jt.title()} of Golden" - - assert town == Jurisdiction( - jt, county="Jefferson", state="colorado", subdivision_name="Golden" - ) - assert town != Jurisdiction( - jt, county="jefferson", state="colorado", subdivision_name="Golden" - ) - assert town != Jurisdiction( - jt, county="Jefferson", state="colorado", subdivision_name="golden" - ) - assert town != Jurisdiction( - "county", - county="Jefferson", - state="Colorado", - subdivision_name="Golden", - ) - - assert town == f"{jt.title()} of Golden, Jefferson County, Colorado" - assert town == f"{jt.title()} of golden, jefferson county, colorado" - - -def test_atypical_subdivision_properties(): - """Test basic properties for ``Jurisdiction`` class for a subdivision""" - - gore = Jurisdiction( - "gore", county="Chittenden", state="Vermont", subdivision_name="Buels" - ) - - assert repr(gore) == "Buels Gore, Chittenden County, Vermont" - assert gore.full_name == "Buels Gore, Chittenden County, Vermont" - assert gore.full_name == str(gore) - assert gore.full_county_phrase == "Chittenden County" - assert gore.full_subdivision_phrase == "Buels Gore" - - assert gore == Jurisdiction( - "gore", county="Chittenden", state="vermont", subdivision_name="Buels" - ) - assert gore != Jurisdiction( - "gore", county="chittenden", state="vermont", subdivision_name="Buels" - ) - assert gore != Jurisdiction( - "gore", county="Chittenden", state="vermont", subdivision_name="buels" - ) - assert gore != Jurisdiction( - "county", - county="Chittenden", - state="Vermont", - subdivision_name="Buels", - ) - - assert gore == "Buels Gore, Chittenden County, Vermont" - assert gore == "buels gOre, chittENden county, vermonT" - - -def test_city_no_county(): - """Test ``Jurisdiction`` for a city with no county""" - - gore = Jurisdiction("city", "Maryland", subdivision_name="Baltimore") - - assert repr(gore) == "City of Baltimore, Maryland" - assert gore.full_name == "City of Baltimore, Maryland" - assert gore.full_name == str(gore) - - assert not gore.full_county_phrase - assert gore.full_subdivision_phrase == "City of Baltimore" - - assert gore == Jurisdiction( - "city", "maryland", subdivision_name="Baltimore" - ) - assert gore != Jurisdiction( - "city", "maryland", subdivision_name="baltimore" - ) - assert gore != Jurisdiction( - "county", "maryland", subdivision_name="baltimore" - ) - - assert gore == "City of Baltimore, Maryland" - assert gore == "ciTy of baltiMore, maryland" - - -def test_full_name_the_prefixed_property(): - """Test ``Jurisdiction.full_name_the_prefixed`` property""" - - state = Jurisdiction("state", state="Colorado") - assert state.full_name_the_prefixed == "the state of Colorado" - - county = Jurisdiction("county", state="Colorado", county="Jefferson") - assert county.full_name_the_prefixed == "Jefferson County, Colorado" - - city = Jurisdiction( - "city", state="Colorado", county="Jefferson", subdivision_name="Golden" - ) - assert ( - city.full_name_the_prefixed - == "the City of 
Golden, Jefferson County, Colorado" - ) - - for st in JURISDICTION_TYPES_AS_PREFIXES: - jur = Jurisdiction(st, state="Colorado", subdivision_name="Test") - assert ( - jur.full_name_the_prefixed == f"the {st.title()} of Test, Colorado" - ) - - jur = Jurisdiction(st, state="Colorado", subdivision_name="Test") - assert jur.full_name_the_prefixed == f"the {st.title()} of Test, Colorado" - - jur = Jurisdiction( - "census county division", - state="Colorado", - county="Test a", - subdivision_name="Test b", - ) - - assert ( - jur.full_name_the_prefixed - == "Test b Census County Division, Test a County, Colorado" - ) - - -def test_full_subdivision_phrase_the_prefixed_property(): - """Test ``Jurisdiction.full_subdivision_phrase_the_prefixed`` property""" - - for st in JURISDICTION_TYPES_AS_PREFIXES: - jur = Jurisdiction(st, state="Colorado", subdivision_name="Test") - assert ( - jur.full_subdivision_phrase_the_prefixed - == f"the {st.title()} of Test" - ) - - jur = Jurisdiction( - "census county division", - state="Colorado", - county="Test a", - subdivision_name="Test b", - ) - - assert ( - jur.full_subdivision_phrase_the_prefixed - == "Test b Census County Division" - ) - - -if __name__ == "__main__": - pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"]) diff --git a/tests/python/unit/utilities/test_utilities_parsing.py b/tests/python/unit/utilities/test_utilities_parsing.py index 6a1e4b4d8..f96e1d529 100644 --- a/tests/python/unit/utilities/test_utilities_parsing.py +++ b/tests/python/unit/utilities/test_utilities_parsing.py @@ -16,7 +16,6 @@ load_config, merge_overlapping_texts, num_ordinances_dataframe, - num_ordinances_in_doc, ordinances_bool_index, ) from compass.exceptions import COMPASSValueError @@ -100,52 +99,6 @@ def test_extract_ord_year_from_doc_attrs(doc_attrs, expected): assert extract_ord_year_from_doc_attrs(doc_attrs) == expected -def test_num_ordinances_in_doc_none(): - """Test `num_ordinances_in_doc` with None document""" - - assert num_ordinances_in_doc(None) == 0 - - -def test_num_ordinances_in_doc_no_ordinance_values(): - """Test `num_ordinances_in_doc` with document missing ordinance_values""" - - doc = MagicMock() - doc.attrs = {} - assert num_ordinances_in_doc(doc) == 0 - - -def test_num_ordinances_in_doc_with_ordinances(): - """Test `num_ordinances_in_doc` with valid ordinances""" - - doc = MagicMock() - doc.attrs = { - "ordinance_values": pd.DataFrame( - { - "feature": ["setback", "height", "noise"], - "value": [100, 200, None], - "summary": ["test", None, "test"], - } - ) - } - assert num_ordinances_in_doc(doc) == 3 - - -def test_num_ordinances_in_doc_with_exclude_features(): - """Test `num_ordinances_in_doc` with excluded features""" - - doc = MagicMock() - doc.attrs = { - "ordinance_values": pd.DataFrame( - { - "feature": ["setback", "height", "noise"], - "value": [100, 200, 300], - "summary": ["test", "test", "test"], - } - ) - } - assert num_ordinances_in_doc(doc, exclude_features=["noise"]) == 2 - - def test_num_ordinances_dataframe_empty(): """Test `num_ordinances_dataframe` with empty DataFrame""" diff --git a/tests/python/unit/validation/test_validation_graphs.py b/tests/python/unit/validation/test_validation_graphs.py index 8cc83700e..9e99bb50e 100644 --- a/tests/python/unit/validation/test_validation_graphs.py +++ b/tests/python/unit/validation/test_validation_graphs.py @@ -4,7 +4,7 @@ import pytest -from compass.utilities.location import Jurisdiction +from compass.utilities.jurisdictions import Jurisdiction from compass.validation.graphs 
import ( setup_graph_correct_jurisdiction_type, setup_graph_correct_jurisdiction_from_url, diff --git a/tests/python/unit/validation/test_validation_location.py b/tests/python/unit/validation/test_validation_location.py index c51c76fc4..254b4f050 100644 --- a/tests/python/unit/validation/test_validation_location.py +++ b/tests/python/unit/validation/test_validation_location.py @@ -8,7 +8,7 @@ from elm.web.document import PDFDocument from elm.utilities.parse import read_pdf_ocr -from compass.utilities.location import Jurisdiction +from compass.utilities.jurisdictions import Jurisdiction from compass.validation.location import ( JurisdictionValidator, DTreeJurisdictionValidator,