Skip to content

Commit

Permalink
providers: add validate_record to validate record according to provid…
Browse files Browse the repository at this point in the history
…er [+]

- in turn this allows validation of publisher by DataciteProvider
- closes #1137
  • Loading branch information
fenekku committed Dec 9, 2022
1 parent a73e7c3 commit a88df89
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 29 deletions.
13 changes: 6 additions & 7 deletions invenio_rdm_records/services/components/pids.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,13 @@ def publish(self, identity, draft=None, record=None):
record_pids = copy(record.get("pids", {}))
draft_schemes = set(draft_pids.keys())
record_schemes = set(record_pids.keys())

# Determine schemes which are required, but not yet created.
missing_required_schemes = (
set(self.service.config.pids_required) - record_schemes - draft_schemes
)
required_schemes = set(self.service.config.pids_required)

# Validate the draft PIDs
self.service.pids.pid_manager.validate(draft_pids, record, raise_errors=True)

self.service.pids.pid_manager.validate_record(draft, raise_errors=True)

# Detect which PIDs on a published record that has been changed.
#
# Example: An external DOI (i.e. DOI not managed by us) can be changed
Expand All @@ -88,8 +86,9 @@ def publish(self, identity, draft=None, record=None):

self.service.pids.pid_manager.discard_all(changed_pids)

# Create all PIDs specified on draft or which PIDs schemes which are
# require
# Determine schemes which are required, but not yet created.
missing_required_schemes = required_schemes - record_schemes - draft_schemes
# Create all PIDs specified on draft and all missing required PIDs
pids = self.service.pids.pid_manager.create_all(
draft,
pids=draft_pids,
Expand Down
62 changes: 50 additions & 12 deletions invenio_rdm_records/services/pids/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
class PIDManager:
"""RDM PIDs Manager."""

def __init__(self, providers):
def __init__(self, providers, required_schemes=None):
"""Constructor for RecordService."""
self._providers = providers
self._required_schemes = required_schemes

def _get_provider(self, scheme, provider_name=None):
"""Get a provider."""
Expand All @@ -47,17 +48,6 @@ def _validate_pids_schemes(self, pids):
if unknown_schemes:
raise PIDSchemeNotSupportedError(unknown_schemes)

def _validate_pids(self, pids, record, errors):
"""Validate an iterator of PIDs.
This function assumes all pid schemes are supported by the system.
"""
for scheme, pid in pids.items():
provider = self._get_provider(scheme, pid.get("provider"))
success, val_errors = provider.validate(record=record, **pid)
if not success:
errors.append({"field": f"pids.{scheme}", "messages": val_errors})

def _validate_identifiers(self, pids, errors):
"""Validate and normalize identifiers."""
# TODO: Refactor to get it injected instead.
Expand Down Expand Up @@ -87,6 +77,17 @@ def _validate_identifiers(self, pids, errors):
for scheme, id_ in identifiers:
pids[scheme]["identifier"] = id_

def _validate_pids(self, pids, record, errors):
"""Validate an iterator of PIDs.
This function assumes all pid schemes are supported by the system.
"""
for scheme, pid in pids.items():
provider = self._get_provider(scheme, pid.get("provider"))
success, val_errors = provider.validate(record=record, **pid)
if not success:
errors.append({"field": f"pids.{scheme}", "messages": val_errors})

def validate(self, pids, record, errors=None, raise_errors=False):
"""Validate PIDs."""
errors = [] if errors is None else errors
Expand All @@ -97,6 +98,43 @@ def validate(self, pids, record, errors=None, raise_errors=False):
if raise_errors and errors:
raise ValidationError(message=errors)

def validate_record(self, record, raise_errors=False):
"""Validate the record according to the PIDs' requirements.
Here we check if the record is compatible from the point of
view of the pids...
- ... it contains
- ... it would contain according to configured required pids
The responsibility lies with each provider since they are the ones
that know their criteria for a record that is complete enough to get
a PID.
"""
errors = {}

# scheme, provider_name for record's pids
scheme_names = [
(scheme, pid.get("provider"))
for scheme, pid in record.get("pids", {}).items()
]
# scheme, None for required pids
scheme_names += [(scheme, None) for scheme in self._required_schemes]
providers = [
self._get_provider(scheme, provider_name)
for scheme, provider_name in scheme_names
]

for provider in providers:
success, provider_errors = provider.validate_record(record)
if not success:
# This is not perfect as one provider may override the error of another
# but a proper dict merging algorithm is out-of-bounds here and as long
# as an error is raised we are good.
errors.update(provider_errors)

if raise_errors and errors:
raise ValidationError(message=errors)

def read(self, scheme, identifier, provider_name):
"""Read a pid."""
provider = self._get_provider(scheme, provider_name)
Expand Down
13 changes: 13 additions & 0 deletions invenio_rdm_records/services/pids/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,16 @@ def validate(self, record, identifier=None, provider=None, **kwargs):
pass

return True, []

def validate_record(self, record):
"""Validate the record according to the provider's rules.
By default it always validates. Descendants should override this method with
their own validation logic if need be.
:param record: A record-like (draft or published record)
:returns: A tuple (success, errors). `success` is a bool that specifies
if the validation was successful. `errors` is an
error dict of the form: `{"<fieldA>": ["<msgA1>", ...], ...}`.
"""
return True, {}
20 changes: 20 additions & 0 deletions invenio_rdm_records/services/pids/providers/datacite.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,23 @@ def validate(self, record, identifier=None, provider=None, **kwargs):
errors.append(str(e))

return (True, []) if not errors else (False, errors)

def validate_record(self, record):
"""Validate the record according to DataCite rules.
We only add check for values not already covered by the record schema.
:returns: A tuple (success, errors). `success` is a bool that specifies
if the validation was successful. `errors` is an
error dict of the form: `{"<fieldA>": ["<msgA1>", ...], ...}`.
"""
errors = {}

if not record["metadata"].get("publisher"):
errors.update(
{
"metadata.publisher": [_("Missing data for required field.")],
}
)

return not bool(errors), errors
2 changes: 1 addition & 1 deletion invenio_rdm_records/services/pids/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _manager(self):
- limit side-effects (in case using pre-existing `pid_manager` would
cause some.)
"""
return self.manager_cls(self.config.pids_providers)
return self.manager_cls(self.config.pids_providers, self.config.pids_required)

@property
def pid_manager(self):
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,6 @@ def minimal_record():
"enabled": False, # Most tests don't care about files
},
"metadata": {
"publication_date": "2020-06-01",
"resource_type": {"id": "image-photo"},
"creators": [
{
"person_or_org": {
Expand All @@ -550,6 +548,10 @@ def minimal_record():
},
},
],
"publication_date": "2020-06-01",
# because DATACITE_ENABLED is True, this field is required
"publisher": "Acme Inc",
"resource_type": {"id": "image-photo"},
"title": "A Romans story",
},
}
Expand Down
1 change: 1 addition & 0 deletions tests/resources/serializers/test_dublincore_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_dublincorejson_serializer_minimal(running_app, updated_minimal_record):
"creators": ["Name", "Troy Inc."],
"dates": ["2020-06-01"],
"rights": ["info:eu-repo/semantics/openAccess"],
"publishers": ["Acme Inc"],
}

serializer = DublinCoreJSONSerializer()
Expand Down
4 changes: 4 additions & 0 deletions tests/services/components/test_pids_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def validate(self, record, identifier=None, provider=None, **kwargs):
errors.append("Identifier must be an integer.")
return (True, []) if not errors else (False, errors)

def validate_record(self, record):
"""Validate the record according to PID rules i.e. always good."""
return True, []


# configs

Expand Down
45 changes: 38 additions & 7 deletions tests/services/pids/providers/test_datacite_pid_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,24 +195,55 @@ def custom_format_func(*args):
assert datacite_provider.create(record).pid_value == expected_result


def test_datacite_provider_validation(record, mocker):
client = DataCiteClient("datacite")

# check with default func
def test_datacite_provider_validate(record):
current_app.config["DATACITE_PREFIX"] = "10.1000"
client = DataCiteClient("datacite")
datacite_provider = DataCitePIDProvider("datacite", client=client)

# Case - Valid identifier (doi)
success, errors = datacite_provider.validate(
record=record, identifier="10.1000/valid.1234", provider="datacite"
)
assert success
assert errors == []
assert [] == errors

# Case - Invalid identifier (doi)
success, errors = datacite_provider.validate(
record=record, identifier="10.2000/invalid.1234", provider="datacite"
)

assert not success
assert errors == [
expected = [
"Wrong DOI 10.2000 prefix provided, "
+ "it should be 10.1000 as defined in the rest client"
]
assert expected == errors


def test_datacite_provider_validate_record(record):
record["metadata"] = {"publisher": "Acme Inc"}
current_app.config["DATACITE_PREFIX"] = "10.1000"
client = DataCiteClient("datacite")
datacite_provider = DataCitePIDProvider("datacite", client=client)

# Case - valid new record without pids.doi
success, errors = datacite_provider.validate_record(record)
assert {} == errors
assert success

# Case - valid record with pre-existing pids.doi
record["pids"] = {
"doi": {"provider": "datacite", "identifier": "10.1000/pre-existing.1234"}
}
success, errors = datacite_provider.validate_record(record)
assert {} == errors
assert success

# Case - invalid record
del record["metadata"]["publisher"]
success, errors = datacite_provider.validate_record(record)
expected = {
"metadata.publisher": ["Missing data for required field."],
}

assert expected == errors
assert not success

0 comments on commit a88df89

Please sign in to comment.