From 3dc272aa8f6aa1078ee80741a822d1a45c33cf48 Mon Sep 17 00:00:00 2001 From: Ken Lauer Date: Fri, 6 May 2022 16:44:15 -0700 Subject: [PATCH 1/5] REF: use vendored dbd for record information in graph --- whatrecord/common.py | 143 +++++++++++----------------- whatrecord/db.py | 61 +++++++++++- whatrecord/graph.py | 21 +++- whatrecord/tests/test_v3_parsing.py | 14 +++ 4 files changed, 144 insertions(+), 95 deletions(-) diff --git a/whatrecord/common.py b/whatrecord/common.py index a5df93fb..07441221 100644 --- a/whatrecord/common.py +++ b/whatrecord/common.py @@ -614,64 +614,6 @@ def get_link_information(link_str: str) -> Tuple[str, List[str]]: LINK_TYPES = {"DBF_INLINK", "DBF_OUTLINK", "DBF_FWDLINK"} -COMMON_LINK_FIELDS = ("FLNK", "SDIS", "TSEL") - -# Generate by way of: Database.field_names_by_type(LINK_TYPES) -# Used when database definition files are not loaded; may not be complete -# or 100% accurate depending on EPICS version. -LINK_FIELDS_BY_RECORD = { - "aSub": ("FLNK", "INPA", "INPB", "INPC", "INPD", "INPE", "INPF", "INPG", - "INPH", "INPI", "INPJ", "INPK", "INPL", "INPM", "INPN", "INPO", - "INPP", "INPQ", "INPR", "INPS", "INPT", "INPU", "OUTA", "OUTB", - "OUTC", "OUTD", "OUTE", "OUTF", "OUTG", "OUTH", "OUTI", "OUTJ", - "OUTK", "OUTL", "OUTM", "OUTN", "OUTO", "OUTP", "OUTQ", "OUTR", - "OUTS", "OUTT", "OUTU", "SDIS", "SUBL", "TSEL"), - "aai": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "aao": ("FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "ai": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "ao": ("DOL", "FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "bi": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "bo": ("DOL", "FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "calc": ("FLNK", "INPA", "INPB", "INPC", "INPD", "INPE", "INPF", "INPG", - "INPH", "INPI", "INPJ", "INPK", "INPL", "SDIS", "TSEL"), - "calcout": ("FLNK", "INPA", "INPB", "INPC", "INPD", "INPE", "INPF", "INPG", - "INPH", "INPI", "INPJ", "INPK", "INPL", "OUT", "SDIS", "TSEL"), - "compress": ("FLNK", "INP", "SDIS", "TSEL"), - "dfanout": ("DOL", "FLNK", "OUTA", "OUTB", "OUTC", "OUTD", "OUTE", "OUTF", - "OUTG", "OUTH", "SDIS", "SELL", "TSEL"), - "event": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "fanout": ("FLNK", "LNK0", "LNK1", "LNK2", "LNK3", "LNK4", "LNK5", "LNK6", - "LNK7", "LNK8", "LNK9", "LNKA", "LNKB", "LNKC", "LNKD", "LNKE", - "LNKF", "SDIS", "SELL", "TSEL"), - "histogram": ("FLNK", "SDIS", "SIML", "SIOL", "SVL", "TSEL"), - "int64in": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "int64out": ("DOL", "FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "longin": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "longout": ("DOL", "FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "lsi": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "lso": ("DOL", "FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "mbbi": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "mbbiDirect": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "mbbo": ("DOL", "FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "mbboDirect": ("DOL", "FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "permissive": ("FLNK", "SDIS", "TSEL"), - "printf": ("FLNK", "INP0", "INP1", "INP2", "INP3", "INP4", "INP5", "INP6", - "INP7", "INP8", "INP9", "OUT", "SDIS", "TSEL"), - "sel": ("FLNK", "INPA", "INPB", "INPC", "INPD", "INPE", "INPF", "INPG", - "INPH", "INPI", "INPJ", "INPK", "INPL", "NVL", "SDIS", "TSEL"), - "seq": ("DOL0", "DOL1", "DOL2", "DOL3", "DOL4", "DOL5", "DOL6", "DOL7", - "DOL8", "DOL9", "DOLA", "DOLB", "DOLC", "DOLD", "DOLE", "DOLF", - "FLNK", "LNK0", "LNK1", "LNK2", "LNK3", "LNK4", "LNK5", "LNK6", - "LNK7", "LNK8", "LNK9", "LNKA", "LNKB", "LNKC", "LNKD", "LNKE", - "LNKF", "SDIS", "SELL", "TSEL"), - "state": ("FLNK", "SDIS", "TSEL"), - "stringin": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL"), - "stringout": ("DOL", "FLNK", "OUT", "SDIS", "SIML", "SIOL", "TSEL"), - "sub": ("FLNK", "INPA", "INPB", "INPC", "INPD", "INPE", "INPF", "INPG", - "INPH", "INPI", "INPJ", "INPK", "INPL", "SDIS", "TSEL"), - "subArray": ("FLNK", "INP", "SDIS", "TSEL"), - "waveform": ("FLNK", "INP", "SDIS", "SIML", "SIOL", "TSEL") -} @dataclass @@ -780,12 +722,61 @@ class RecordType: info: Dict[str, str] = field(default_factory=dict) is_grecord: bool = False + def get_links_for_record( + self, record: RecordInstance + ) -> Generator[Tuple[RecordField, str, List[str]], None, None]: + """ + Get all links - in, out, and forward links - for the given record. + + Parameters + ---------- + record : RecordInstance + Additional information, if the database definition wasn't loaded + with this instance. + + Yields + ------ + field : RecordField + link_text: str + link_info: str + """ + if record.record_type != self.name: + raise ValueError("Record types do not match") + + for field_type_info in self.get_fields_of_type(*LINK_TYPES): + field_instance = record.fields.get(field_type_info.name, None) + if field_instance and not isinstance(field_instance, PVAFieldReference): + try: + link, info = get_link_information(field_instance.value) + except ValueError: + continue + yield field_instance, link, info + + def get_fields_of_type(self, *types: str) -> Generator[RecordTypeField, None, None]: + """Get all fields of the matching type(s).""" + for fld in self.fields.values(): + if fld.type in types: + yield fld + + def get_link_fields( + self, + ) -> Generator[Tuple[RecordTypeField, str, Tuple[str, ...]], None, None]: + """ + Get all link fields - in, out, and forward links. + + Yields + ------ + field : RecordTypeField + """ + yield from self.get_fields_of_type(*LINK_TYPES) + @dataclass class RecordInstance: context: FullLoadContext name: str record_type: str + has_dbd_info: bool = False fields: Dict[str, AnyField] = field(default_factory=dict) info: Dict[StringWithContext, Any] = field(default_factory=dict) metadata: Dict[StringWithContext, Any] = field(default_factory=dict) @@ -832,9 +823,12 @@ def get_fields_of_type(self, *types) -> Generator[RecordField, None, None]: def get_links( self, - ) -> Generator[Tuple[RecordField, str, Tuple[str, ...]], None, None]: + ) -> Generator[Tuple[RecordField, str, List[str]], None, None]: """ - Get all links. + Get all links - in, out, and forward links - for this record. + + Requires that a dbd file was loaded when the record was created. + Alternatively, see :func:`RecordType.get_links_for_record`. Yields ------ @@ -849,34 +843,6 @@ def get_links( continue yield fld, link, info - def get_common_links( - self, - ) -> Generator[Tuple[RecordField, str, Tuple[str, ...]], None, None]: - """ - Without using a database definition, try to find links. - - This differs from ``get_links`` in that the other method requires - a dbd file to be loaded, whereas this will use a simple - but possibly - inaccurate - map of of record type to link fields. - - Yields - ------ - field : RecordField - link_text: str - link_info: str - """ - if self.is_pva: - return - - for name in LINK_FIELDS_BY_RECORD.get(self.record_type, COMMON_LINK_FIELDS): - fld = self.fields.get(name, None) - if fld is not None: - try: - link, info = get_link_information(fld.value) - except ValueError: - continue - yield fld, link, info - def to_summary(self) -> RecordInstanceSummary: """Return a summarized version of the record instance.""" return RecordInstanceSummary.from_record_instance(self) @@ -906,6 +872,7 @@ def update(self, other: RecordInstance) -> List[LinterMessage]: [alias for alias in other.aliases if alias not in self.aliases] ) + self.has_dbd_info = self.has_dbd_info or other.has_dbd_info if self.record_type != other.record_type: return [ LinterError( diff --git a/whatrecord/db.py b/whatrecord/db.py index 4e972ff8..584cf98e 100644 --- a/whatrecord/db.py +++ b/whatrecord/db.py @@ -5,12 +5,13 @@ import pathlib import typing from dataclasses import field -from typing import Any, Dict, FrozenSet, List, Mapping, Optional, Tuple, Union +from typing import (Any, Dict, FrozenSet, Generator, List, Mapping, Optional, + Tuple, Union) import apischema import lark -from . import transformer +from . import transformer, util from .common import (DatabaseDevice, DatabaseMenu, LinterError, LinterWarning, PVAFieldReference, RecordField, RecordInstance, RecordType, RecordTypeField, StringWithContext, dataclass) @@ -429,8 +430,9 @@ def record(self, rec_token, head, body): ) if record_type_info is None: # TODO lint error, if dbd loaded - ... + record.has_dbd_info = False else: + record.has_dbd_info = True for fld in record.fields.values(): field_info = record_type_info.fields.get(fld.name, None) if field_info is None: @@ -805,5 +807,58 @@ def from_multiple(cls, *items: _DatabaseSource) -> Database: return db + @classmethod + def from_vendored_dbd(cls, version: int = 3) -> Database: + """ + Load the vendored database definition file from whatrecord. + + This is a good fallback when you have a database file without a + corresponding database definition file. + + Parameters + ---------- + version : int, optional + Use the old V3 style or new V3 style database grammar by specifying + 3 or 4, respectively. Defaults to 3. + + Returns + ------- + db : Database + """ + if version <= 3: + return cls.from_file( + util.MODULE_PATH / "tests" / "iocs/v3_softIoc.dbd", + version=version, + ) + return cls.from_file( + util.MODULE_PATH / "tests" / "iocs" / "softIoc.dbd", + version=version, + ) + + def get_links_for_record( + self, + record: RecordInstance, + ) -> Generator[Tuple[RecordField, str, List[str]], None, None]: + """ + Get all links - in, out, and forward links. + + Parameters + ---------- + record : RecordInstance + Additional information, if the database definition wasn't loaded + with this instance. + + Yields + ------ + field : RecordField + link_text: str + link_info: str + """ + record_info = self.record_types.get(record.record_type, None) + if not record_info: + return + + yield from record_info.get_links_for_record(record) + _DatabaseSource = Union["LoadedIoc", "ShellState", Database, LinterResults] diff --git a/whatrecord/graph.py b/whatrecord/graph.py index b1508c64..83ed471a 100644 --- a/whatrecord/graph.py +++ b/whatrecord/graph.py @@ -189,6 +189,7 @@ def build_database_relations( database: Dict[str, RecordInstance], record_types: Optional[Dict[str, RecordType]] = None, aliases: Optional[Dict[str, str]] = None, + version: int = 3, ) -> PVRelations: """ Build a dictionary of PV relationships. @@ -205,6 +206,12 @@ def build_database_relations( record_types : dict, optional The database definitions to use for fields that are not defined in the database file. Dictionary of record type name to RecordType. + If not specified, the whatrecord-vendored database definition files + will be used. + + version : int, optional + Use the old V3 style or new V3 style database grammar by specifying + 3 or 4, respectively. Defaults to 3. Returns ------- @@ -212,6 +219,10 @@ def build_database_relations( Such that: ``info[pv1][pv2] = (field1, field2, info)`` And in reverse: ``info[pv2][pv1] = (field2, field1, info)`` """ + if not record_types: + dbd = Database.from_vendored_dbd(version=version) + record_types = dbd.record_types + aliases = aliases or {} warned = set() unset_ctx: FullLoadContext = (LoadContext("unknown", 0), ) @@ -219,12 +230,14 @@ def build_database_relations( # TODO: alias handling? for rec1 in database.values(): - if record_types: - # Use links as defined in the database definition + # Use links as defined in the database definition + if rec1.has_dbd_info: rec1_links = rec1.get_links() else: - # Fall back to static list of link fields - rec1_links = rec1.get_common_links() + rec1_rtype = record_types.get(rec1.record_type, None) + if rec1_rtype is None: + continue + rec1_links = rec1_rtype.get_links_for_record(rec1) for field1, link, info in rec1_links: # TODO: copied without thinking about implications diff --git a/whatrecord/tests/test_v3_parsing.py b/whatrecord/tests/test_v3_parsing.py index ff1ff43c..a4b4b777 100644 --- a/whatrecord/tests/test_v3_parsing.py +++ b/whatrecord/tests/test_v3_parsing.py @@ -307,3 +307,17 @@ def test_unquoted_warning(): message="Unquoted field value 'B'" ) ] + + +@pytest.mark.parametrize( + "version", [3, 4], +) +def test_load_vendored_database_smoke(version: int): + dbd = Database.from_vendored_dbd(version=version) + record_types = list(dbd.record_types.values()) + assert len(record_types) + record = record_types[0] + if version == 3: + assert "v3_softIoc.dbd" in record.context[0].name + else: + assert "softIoc.dbd" in record.context[0].name From 8c4ec90c3ac04014db715886332d174e31eced20 Mon Sep 17 00:00:00 2001 From: Ken Lauer Date: Fri, 6 May 2022 17:00:00 -0700 Subject: [PATCH 2/5] FIX: stray field names in record link graphs --- whatrecord/graph.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/whatrecord/graph.py b/whatrecord/graph.py index 83ed471a..d20f8a84 100644 --- a/whatrecord/graph.py +++ b/whatrecord/graph.py @@ -588,22 +588,18 @@ def add_database(self, database: Union[Dict[str, RecordInstance], Database]): for li in find_record_links( self.database.records, self.starting_records, relations=self.relations ): - for (rec, field) in ((li.record1, li.field1), (li.record2, li.field2)): - if rec.name not in self.nodes: - self.get_node(label=rec.name, text=field.name) - - src = self.get_node(li.record1.name) - dest = self.get_node(li.record2.name) + src = self.get_node(li.record1.name, text=" ") + dest = self.get_node(li.record2.name, text=" ") for field, node in [(li.field1, src), (li.field2, dest)]: if field.value or self.show_empty: text_line = self.field_format.format( field=field.name, value=field.value ) - if node.text and text_line not in node.text: - node.text = "\n".join((node.text, text_line)) - else: + if not node.text.strip(): node.text = text_line + elif text_line not in node.text: + node.text = "\n".join((node.text, text_line)) if li.field1.dtype == "DBF_INLINK": src, dest = dest, src From 4e979650707b3fcccf302de90765d02e85e0d7fe Mon Sep 17 00:00:00 2001 From: Ken Lauer Date: Fri, 6 May 2022 17:36:38 -0700 Subject: [PATCH 3/5] WIP: fix up record graphs --- whatrecord/common.py | 18 +++++++-- whatrecord/graph.py | 87 +++++++++++++++++++++++++++++--------------- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/whatrecord/common.py b/whatrecord/common.py index 07441221..65e09357 100644 --- a/whatrecord/common.py +++ b/whatrecord/common.py @@ -550,8 +550,21 @@ class RecordField: """, } - def update_unknowns(self, other: RecordField, *, unknown_values=None, - dbd=None): + def update_from_record_type( + self, + record_type: RecordType + ): + """Update field information given dbd-provided information.""" + record_type_field = record_type.fields.get(self.name, None) + if record_type_field is not None: + self.dtype = record_type_field.type + + def update_unknowns( + self, + other: RecordField, + *, + unknown_values: Optional[Sequence[str]] = None, + ): """ If this RecordField has some missing information ("unknown"), fill it in with information from the other field. @@ -566,7 +579,6 @@ def update_unknowns(self, other: RecordField, *, unknown_values=None, if ctx.name in unknown_values: # Even if the other context is unknown, let's take it anyway: self.context = other.context - # if dbd is not None: PVRelations = Dict[ diff --git a/whatrecord/graph.py b/whatrecord/graph.py index d20f8a84..6f24196a 100644 --- a/whatrecord/graph.py +++ b/whatrecord/graph.py @@ -185,6 +185,20 @@ def is_supported_link(link: str) -> bool: return False +def _get_links_for_record( + record: RecordInstance, record_types: Optional[Dict[str, RecordType]] = None +) -> Generator[Tuple[RecordField, str, List[str]], None, None]: + """ + Get links for the provided record, referring back to the record_types dict if necessary. + """ + if record.has_dbd_info: + yield from record.get_links() + elif record_types is not None: + rec1_rtype = record_types.get(record.record_type, None) + if rec1_rtype is not None: + yield from rec1_rtype.get_links_for_record(record) + + def build_database_relations( database: Dict[str, RecordInstance], record_types: Optional[Dict[str, RecordType]] = None, @@ -230,21 +244,18 @@ def build_database_relations( # TODO: alias handling? for rec1 in database.values(): - # Use links as defined in the database definition - if rec1.has_dbd_info: - rec1_links = rec1.get_links() - else: - rec1_rtype = record_types.get(rec1.record_type, None) - if rec1_rtype is None: - continue - rec1_links = rec1_rtype.get_links_for_record(rec1) - - for field1, link, info in rec1_links: + rec1_rtype = record_types.get(rec1.record_type, None) + for field1, link, info in _get_links_for_record( + rec1, record_types=record_types + ): # TODO: copied without thinking about implications # due to the removal of st.cmd context as an attempt to reduce field1 = copy.deepcopy(field1) # field1.context = rec1.context[:1] + field1.context + if not rec1.has_dbd_info and rec1_rtype: + field1.update_from_record_type(rec1_rtype) + if "." in link: link, field2 = link.split(".", 1) elif field1.name == "FLNK": @@ -254,15 +265,18 @@ def build_database_relations( rec2 = database.get(aliases.get(link, link), None) if rec2 is None: + # Case 1: The linked record is *not* in the database. + # TODO: switch to debug; this will be expensive later if not is_supported_link(link): continue - if link not in warned: - warned.add(link) + rec2_name = link + if rec2_name not in warned: + warned.add(rec2_name) logger.debug( "Linked record from %s.%s not in database: %s", - rec1.name, field1.name, link + rec1.name, field1.name, rec2_name ) field2 = RecordField( @@ -271,16 +285,20 @@ def build_database_relations( value="(unknown-record)", context=unset_ctx, ) - rec2_name = link elif field2 in rec2.fields: + # Case 2: The linked record is in the database and has a + # recognized field name. rec2_name = rec2.name # TODO: copied without thinking about implications field2 = copy.deepcopy(rec2.fields[field2]) # field2.context = rec2.context[:1] + field2.context - elif record_types: + else: + # Case 3: The linked record is in the database but does not + # have a recognized field name. rec2_name = rec2.name dbd_record_type = record_types.get(rec2.record_type, None) if dbd_record_type is None: + # Record type not in the database? field2 = RecordField( dtype="invalid", name=field2, @@ -288,6 +306,7 @@ def build_database_relations( context=unset_ctx, ) elif field2 not in dbd_record_type.fields: + # Field name invalid field2 = RecordField( dtype="invalid", name=field2, @@ -295,6 +314,7 @@ def build_database_relations( context=unset_ctx, ) else: + # Record and field found in provided record_types dbd_record_field = dbd_record_type.fields[field2] field2 = RecordField( dtype=dbd_record_field.type, @@ -302,14 +322,11 @@ def build_database_relations( value="", context=dbd_record_field.context, ) - else: - rec2_name = rec2.name - field2 = RecordField( - dtype="unknown", - name=field2, - value="", # unset or invalid, can't tell yet - context=unset_ctx, - ) + + if rec2 is not None: + rec2_type = record_types.get(rec2.record_type, None) + if rec2_type is not None: + field2.update_from_record_type(rec2_type) by_record[rec1.name][rec2_name].append((field1, field2, info)) by_record[rec2_name][rec1.name].append((field2, field1, info)) @@ -447,7 +464,8 @@ def get_items_to_update(): def find_record_links( database: Dict[str, RecordInstance], starting_records: List[str], - relations: Optional[PVRelations] = None + relations: Optional[PVRelations] = None, + record_types: Optional[Dict[str, RecordType]] = None, ) -> Generator[LinkInfo, None, None]: """ Get all related record links from a set of starting records. @@ -467,6 +485,10 @@ def find_record_links( Pre-built PV relationship dictionary. Generated from database if not provided. + record_types : dict, optional + The database definitions to use for fields that are not defined in the + database file. Dictionary of record type name to RecordType. + Yields ------ link_info : LinkInfo @@ -475,7 +497,7 @@ def find_record_links( checked = [] if relations is None: - relations = build_database_relations(database) + relations = build_database_relations(database, record_types=record_types) records_to_check = list(starting_records) @@ -592,7 +614,9 @@ def add_database(self, database: Union[Dict[str, RecordInstance], Database]): dest = self.get_node(li.record2.name, text=" ") for field, node in [(li.field1, src), (li.field2, dest)]: - if field.value or self.show_empty: + if field.name == "PROC": + ... + elif field.value or self.show_empty: text_line = self.field_format.format( field=field.name, value=field.value ) @@ -615,9 +639,12 @@ def add_database(self, database: Union[Dict[str, RecordInstance], Database]): break if (src, dest) not in set(self.edge_pairs): - edge_kw["xlabel"] = f"{li.field1.name}/{li.field2.name}" - if li.info: - edge_kw["xlabel"] += f"\n{' '.join(li.info)}" + if "DBF_FWDLINK" in (li.field1.dtype, li.field2.dtype): + edge_kw["xlabel"] = "FLNK" + else: + edge_kw["xlabel"] = f"{li.field1.name}/{li.field2.name}" + if li.info: + edge_kw["xlabel"] += f"\n{' '.join(li.info)}" self.add_edge(src.label, dest.label, **edge_kw) if not self.nodes: @@ -660,7 +687,7 @@ def graph_links( field_format: Optional[str] = None, text_format: Optional[str] = None, sort_fields: bool = True, - show_empty: bool = False, + show_empty: bool = True, relations: Optional[PVRelations] = None, record_types: Optional[Dict[str, RecordType]] = None, ) -> RecordLinkGraph: From 5dcdc55e22e308360aeef71de26025ee71996c30 Mon Sep 17 00:00:00 2001 From: Ken Lauer Date: Wed, 11 May 2022 14:32:22 -0700 Subject: [PATCH 4/5] MNT: clean up graphing + fix tests, hopefully --- whatrecord/common.py | 5 +- whatrecord/graph.py | 158 ++++++++++++++++++++------------- whatrecord/tests/test_graph.py | 18 +++- 3 files changed, 113 insertions(+), 68 deletions(-) diff --git a/whatrecord/common.py b/whatrecord/common.py index 65e09357..7eb25df6 100644 --- a/whatrecord/common.py +++ b/whatrecord/common.py @@ -581,8 +581,11 @@ def update_unknowns( self.context = other.context +# field1, field2, options (CA, CP, CPP, etc.) +FieldRelation = Tuple[RecordField, RecordField, List[str]] + PVRelations = Dict[ - str, Dict[str, List[Tuple[RecordField, RecordField, List[str]]]] + str, Dict[str, List[FieldRelation]] ] diff --git a/whatrecord/graph.py b/whatrecord/graph.py index 6f24196a..075764f7 100644 --- a/whatrecord/graph.py +++ b/whatrecord/graph.py @@ -199,6 +199,81 @@ def _get_links_for_record( yield from rec1_rtype.get_links_for_record(record) +_unset_ctx: FullLoadContext = (LoadContext("unknown", 0), ) + + +def _field_from_record_relation( + record: Optional[RecordInstance], + field_name: str, + link_text: str, + record_types: Dict[str, RecordType] +) -> Optional[RecordField]: + """ + Create a RecordField instance based on what we know, given the parameters. + + Parameters + ---------- + record : RecordInstance or None + An (optional) record instance, if available in the database. + + field_name : str + A field name for the record (field_name). + + link_text : str, optional + The link text - likely the record name - referring to ``record``. + + record_types : dict, optional + Record type information from a database definition. + """ + if record is None: + # Case 1: The linked record is *not* in the database. + + if not is_supported_link(link_text): + return None + + return RecordField( + dtype="unknown", + name=field_name, + value="(unknown-record)", + context=_unset_ctx, + ) + + if field_name in record.fields: + # Case 2: The linked record is in the database and has a + # recognized field name. + return copy.deepcopy(record.fields[field_name]) + + # Case 3: The linked record is in the database but does not + # have a recognized field name. + dbd_record_type = record_types.get(record.record_type, None) + if dbd_record_type is None: + # Record type not in the database? + return RecordField( + dtype="invalid", + name=field_name, + value="(invalid-record-type)", + context=_unset_ctx, + ) + + if field_name not in dbd_record_type.fields: + # Field name invalid + return RecordField( + dtype="invalid", + name=field_name, + value="(invalid-field)", + context=_unset_ctx, + ) + + # Record and field found in provided record_types + dbd_record_field = dbd_record_type.fields[field_name] + return RecordField( + dtype=dbd_record_field.type, + name=field_name, + value="", + context=dbd_record_field.context, + ) + + def build_database_relations( database: Dict[str, RecordInstance], record_types: Optional[Dict[str, RecordType]] = None, @@ -239,17 +314,13 @@ def build_database_relations( aliases = aliases or {} warned = set() - unset_ctx: FullLoadContext = (LoadContext("unknown", 0), ) by_record = collections.defaultdict(lambda: collections.defaultdict(list)) - # TODO: alias handling? for rec1 in database.values(): rec1_rtype = record_types.get(rec1.record_type, None) for field1, link, info in _get_links_for_record( rec1, record_types=record_types ): - # TODO: copied without thinking about implications - # due to the removal of st.cmd context as an attempt to reduce field1 = copy.deepcopy(field1) # field1.context = rec1.context[:1] + field1.context @@ -257,73 +328,34 @@ def build_database_relations( field1.update_from_record_type(rec1_rtype) if "." in link: - link, field2 = link.split(".", 1) + link, field2_name = link.split(".", 1) elif field1.name == "FLNK": - field2 = "PROC" + field2_name = "PROC" else: - field2 = "VAL" + field2_name = "VAL" - rec2 = database.get(aliases.get(link, link), None) - if rec2 is None: - # Case 1: The linked record is *not* in the database. + rec2_name = aliases.get(link, link) + rec2 = database.get(rec2_name, None) - # TODO: switch to debug; this will be expensive later - if not is_supported_link(link): - continue + field2 = _field_from_record_relation( + record=rec2, + field_name=field2_name, + link_text=link, + record_types=record_types, + ) - rec2_name = link - if rec2_name not in warned: - warned.add(rec2_name) - logger.debug( - "Linked record from %s.%s not in database: %s", - rec1.name, field1.name, rec2_name - ) + if field2 is None: + continue - field2 = RecordField( - dtype="unknown", - name=field2, - value="(unknown-record)", - context=unset_ctx, + if rec2 is None: + warned.add(rec2_name) + logger.debug( + "Linked record from %s.%s not in database: %s", + rec1.name, field1.name, rec2_name ) - elif field2 in rec2.fields: - # Case 2: The linked record is in the database and has a - # recognized field name. - rec2_name = rec2.name - # TODO: copied without thinking about implications - field2 = copy.deepcopy(rec2.fields[field2]) - # field2.context = rec2.context[:1] + field2.context else: - # Case 3: The linked record is in the database but does not - # have a recognized field name. - rec2_name = rec2.name - dbd_record_type = record_types.get(rec2.record_type, None) - if dbd_record_type is None: - # Record type not in the database? - field2 = RecordField( - dtype="invalid", - name=field2, - value="(invalid-record-type)", - context=unset_ctx, - ) - elif field2 not in dbd_record_type.fields: - # Field name invalid - field2 = RecordField( - dtype="invalid", - name=field2, - value="(invalid-field)", - context=unset_ctx, - ) - else: - # Record and field found in provided record_types - dbd_record_field = dbd_record_type.fields[field2] - field2 = RecordField( - dtype=dbd_record_field.type, - name=field2, - value="", - context=dbd_record_field.context, - ) - - if rec2 is not None: + # We may have updated information about the record field; + # but it's possible this is entirely unnecessary (TODO) rec2_type = record_types.get(rec2.record_type, None) if rec2_type is not None: field2.update_from_record_type(rec2_type) diff --git a/whatrecord/tests/test_graph.py b/whatrecord/tests/test_graph.py index cf0642f7..fdfd7dab 100644 --- a/whatrecord/tests/test_graph.py +++ b/whatrecord/tests/test_graph.py @@ -37,12 +37,15 @@ def create_record(record_type, record_name, fields, filename=None): return db.records[record_name] -def test_simple_graph(): +def test_simple_graph(dbd: Database): database = { "record_a": create_record("ai", "record_a", {"INP": "record_b CPP MS", "VAL": "10"}), "record_b": create_record("ao", "record_b", {"OUT": "record_c CA", "VAL": "20"}), } - relations = graph.build_database_relations(database) + relations = graph.build_database_relations( + database, + record_types=dbd.record_types, + ) print(database["record_a"]) assert relations["record_a"]["record_b"] == [ ( @@ -81,7 +84,7 @@ def dbd(): ) -def test_combine_relations(): +def test_combine_relations(dbd: Database): database_1 = { "record_a": create_record("ai", "record_a", {"INP": "record_b CPP MS", "VAL": "10"}), "record_b": create_record("ao", "record_b", {"OUT": "record_c CA", "VAL": "20"}), @@ -91,12 +94,16 @@ def test_combine_relations(): "record_d": create_record("ai", "record_d", {"INP": "", "VAL": "10"}), "record_e": create_record("ai", "record_e", {"INP": "record_a CP", "VAL": "10"}), } - relations = graph.build_database_relations(database_1) + relations = graph.build_database_relations( + database_1, + record_types=dbd.record_types, + ) graph.combine_relations( relations, database_1, graph.build_database_relations(database_2), database_2, + record_types=dbd.record_types, ) assert relations["record_a"]["record_b"] == [ @@ -245,6 +252,7 @@ def test_combine_with_alias(dbd: Database): "alias_a": "record_a", "alias_b": "record_b", }, + record_types=dbd.record_types, ) relations_2 = graph.build_database_relations( @@ -254,6 +262,7 @@ def test_combine_with_alias(dbd: Database): "alias_d": "record_d", "alias_e": "record_e", }, + record_types=dbd.record_types, ) graph.combine_relations( @@ -268,6 +277,7 @@ def test_combine_with_alias(dbd: Database): "alias_d": "record_d", "alias_e": "record_e", }, + record_types=dbd.record_types, ) assert relations["record_a"]["record_b"] == [ From be8722da0e999f01764d62a9e39ae6f157c08845 Mon Sep 17 00:00:00 2001 From: Ken Lauer Date: Wed, 11 May 2022 15:19:05 -0700 Subject: [PATCH 5/5] FIX: node text again for pv relations --- whatrecord/bin/graph.py | 5 +++++ whatrecord/graph.py | 44 ++++++++++++++++++++++++++++------------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/whatrecord/bin/graph.py b/whatrecord/bin/graph.py index 269f5782..16f54ea8 100644 --- a/whatrecord/bin/graph.py +++ b/whatrecord/bin/graph.py @@ -207,6 +207,11 @@ def main( if databases_only: graph = get_database_graph(*loaded_items, highlight=highlight) + if not graph.nodes: + logger.warning( + "No records found matching the highlight settings: %s", + highlight + ) else: try: item, = loaded_items diff --git a/whatrecord/graph.py b/whatrecord/graph.py index 075764f7..4a27a46c 100644 --- a/whatrecord/graph.py +++ b/whatrecord/graph.py @@ -34,6 +34,26 @@ class GraphNode: #: Highlight the node in the graph? highlighted: bool = False + def add_text_line(self, line: str, delimiter: str = "\n", only_unique: bool = True): + """ + Add a line of text to the node. + + Parameters + ---------- + line : str + The line to add. + + delimiter : str, optional + The between-line delimiter. + + only_unique : bool, optional + Only add unique lines to the text. + """ + if not self.text.strip(): + self.text = line + elif not only_unique or line not in self.text: + self.text = delimiter.join((self.text, line)) + def __hash__(self): return hash(self.id) @@ -82,7 +102,9 @@ def get_node( ) logger.debug("Created node %s", label) - self.nodes[label].text = text or self.nodes[label].text + if text and text.strip(): + self.nodes[label].add_text_line(text) + return self.nodes[label] def add_edge( @@ -571,8 +593,9 @@ class RecordLinkGraph(_GraphHelper): database: Database starting_records: List[str] + newline: str = '
' header_format: str = 'record({rtype}, "{name}")' - field_format: str = '{field:>4s}: "{value}"' + field_format: str = '{field}: "{value}"' text_format: str = ( f"{{header}}" f"{_GraphHelper.newline}" @@ -609,7 +632,7 @@ def __init__( field_format: Optional[str] = None, text_format: Optional[str] = None, sort_fields: bool = True, - show_empty: bool = False, + show_empty: bool = True, relations: Optional[PVRelations] = None, record_types: Optional[Dict[str, RecordType]] = None, ): @@ -649,13 +672,9 @@ def add_database(self, database: Union[Dict[str, RecordInstance], Database]): if field.name == "PROC": ... elif field.value or self.show_empty: - text_line = self.field_format.format( - field=field.name, value=field.value + node.add_text_line( + self.field_format.format(field=field.name, value=field.value) ) - if not node.text.strip(): - node.text = text_line - elif text_line not in node.text: - node.text = "\n".join((node.text, text_line)) if li.field1.dtype == "DBF_INLINK": src, dest = dest, src @@ -686,12 +705,11 @@ def add_database(self, database: Union[Dict[str, RecordInstance], Database]): self.get_node(rec_name) for node in self.nodes.values(): - field_lines = node.text if self.sort_fields: node.text = "\n".join(sorted(node.text.splitlines())) - if field_lines: - node.text += "\n" + if node.text.strip() and not node.text.endswith("\n\n"): + node.add_text_line("\n", only_unique=False) rec = self.database.records[node.label] header = self.header_format.format(rtype=rec.record_type, name=rec.name) @@ -808,8 +826,6 @@ def get_owner(rec): owner1 = get_owner(rec1) owner2 = get_owner(rec2) - # print(rec1_name, owner1, "|", rec2_name, owner2) - if owner1 != owner2: by_script[owner2][owner1].add(rec2_name) by_script[owner1][owner2].add(rec1_name)