Skip to content

Commit

Permalink
feat: named database support (#439)
Browse files Browse the repository at this point in the history
* feat: named database support (#398)

* feat: Add named database support

* test: Use named db in system tests

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* Handle the case when client doesn't have database property

* fix: add custom routing headers

* Fixing tests for easier merge

* fixing code coverage

* addressing pr comments

* feat: Multi db test parametrization (#436)

* Feat: Parametrize the tests for multidb support

Remove "database" argument from Query and AggregationQuery constructors.
Use the "database" from the client instead.
Once set in the client, the "database" will be used throughout and cannot be re-set.

Parametrize the tests where-ever clients are used.

Use the `system-tests-named-db` in the system test.

* Add test case for when parent database name != child database name

* Update owlbot, removing the named db parameter

* Reverted test fixes

* fixing tests

* fix code coverage

* pr suggestion

* address pr comments

---------

Co-authored-by: Vishwaraj Anand <[email protected]>

---------

Co-authored-by: Bob "Wombat" Hogg <[email protected]>
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
Co-authored-by: Vishwaraj Anand <[email protected]>
Co-authored-by: meredithslota <[email protected]>
  • Loading branch information
5 people authored Jun 21, 2023
1 parent a12971c commit abf0060
Show file tree
Hide file tree
Showing 29 changed files with 1,842 additions and 921 deletions.
6 changes: 3 additions & 3 deletions google/cloud/datastore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
The main concepts with this API are:
- :class:`~google.cloud.datastore.client.Client`
which represents a project (string) and namespace (string) bundled with
a connection and has convenience methods for constructing objects with that
project / namespace.
which represents a project (string), database (string), and namespace
(string) bundled with a connection and has convenience methods for
constructing objects with that project/database/namespace.
- :class:`~google.cloud.datastore.entity.Entity`
which represents a single entity in the datastore
Expand Down
38 changes: 37 additions & 1 deletion google/cloud/datastore/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _request(
data,
base_url,
client_info,
database,
retry=None,
timeout=None,
):
Expand All @@ -84,6 +85,9 @@ def _request(
:type client_info: :class:`google.api_core.client_info.ClientInfo`
:param client_info: used to generate user agent.
:type database: str
:param database: The database to make the request for.
:type retry: :class:`google.api_core.retry.Retry`
:param retry: (Optional) retry policy for the request
Expand All @@ -101,6 +105,7 @@ def _request(
"User-Agent": user_agent,
connection_module.CLIENT_INFO_HEADER: user_agent,
}
_update_headers(headers, project, database)
api_url = build_api_url(project, method, base_url)

requester = http.request
Expand Down Expand Up @@ -136,6 +141,7 @@ def _rpc(
client_info,
request_pb,
response_pb_cls,
database,
retry=None,
timeout=None,
):
Expand Down Expand Up @@ -165,6 +171,9 @@ def _rpc(
:param response_pb_cls: The class used to unmarshall the response
protobuf.
:type database: str
:param database: The database to make the request for.
:type retry: :class:`google.api_core.retry.Retry`
:param retry: (Optional) retry policy for the request
Expand All @@ -177,7 +186,7 @@ def _rpc(
req_data = request_pb._pb.SerializeToString()
kwargs = _make_retry_timeout_kwargs(retry, timeout)
response = _request(
http, project, method, req_data, base_url, client_info, **kwargs
http, project, method, req_data, base_url, client_info, database, **kwargs
)
return response_pb_cls.deserialize(response)

Expand Down Expand Up @@ -236,6 +245,7 @@ def lookup(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.LookupRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
Expand All @@ -245,6 +255,7 @@ def lookup(self, request, retry=None, timeout=None):
self.client._client_info,
request_pb,
_datastore_pb2.LookupResponse,
database_id,
retry=retry,
timeout=timeout,
)
Expand All @@ -267,6 +278,7 @@ def run_query(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.RunQueryRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
Expand All @@ -276,6 +288,7 @@ def run_query(self, request, retry=None, timeout=None):
self.client._client_info,
request_pb,
_datastore_pb2.RunQueryResponse,
database_id,
retry=retry,
timeout=timeout,
)
Expand All @@ -300,6 +313,7 @@ def run_aggregation_query(self, request, retry=None, timeout=None):
request, _datastore_pb2.RunAggregationQueryRequest
)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
Expand All @@ -309,6 +323,7 @@ def run_aggregation_query(self, request, retry=None, timeout=None):
self.client._client_info,
request_pb,
_datastore_pb2.RunAggregationQueryResponse,
database_id,
retry=retry,
timeout=timeout,
)
Expand All @@ -331,6 +346,7 @@ def begin_transaction(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.BeginTransactionRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
Expand All @@ -340,6 +356,7 @@ def begin_transaction(self, request, retry=None, timeout=None):
self.client._client_info,
request_pb,
_datastore_pb2.BeginTransactionResponse,
database_id,
retry=retry,
timeout=timeout,
)
Expand All @@ -362,6 +379,7 @@ def commit(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.CommitRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
Expand All @@ -371,6 +389,7 @@ def commit(self, request, retry=None, timeout=None):
self.client._client_info,
request_pb,
_datastore_pb2.CommitResponse,
database_id,
retry=retry,
timeout=timeout,
)
Expand All @@ -393,6 +412,7 @@ def rollback(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.RollbackRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
Expand All @@ -402,6 +422,7 @@ def rollback(self, request, retry=None, timeout=None):
self.client._client_info,
request_pb,
_datastore_pb2.RollbackResponse,
database_id,
retry=retry,
timeout=timeout,
)
Expand All @@ -424,6 +445,7 @@ def allocate_ids(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.AllocateIdsRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
Expand All @@ -433,6 +455,7 @@ def allocate_ids(self, request, retry=None, timeout=None):
self.client._client_info,
request_pb,
_datastore_pb2.AllocateIdsResponse,
database_id,
retry=retry,
timeout=timeout,
)
Expand All @@ -455,6 +478,7 @@ def reserve_ids(self, request, retry=None, timeout=None):
"""
request_pb = _make_request_pb(request, _datastore_pb2.ReserveIdsRequest)
project_id = request_pb.project_id
database_id = request_pb.database_id

return _rpc(
self.client._http,
Expand All @@ -464,6 +488,18 @@ def reserve_ids(self, request, retry=None, timeout=None):
self.client._client_info,
request_pb,
_datastore_pb2.ReserveIdsResponse,
database_id,
retry=retry,
timeout=timeout,
)


def _update_headers(headers, project_id, database_id=None):
"""Update the request headers.
Pass the project id, or optionally the database_id if provided.
"""
headers["x-goog-request-params"] = f"project_id={project_id}"
if database_id:
headers[
"x-goog-request-params"
] = f"project_id={project_id}&database_id={database_id}"
30 changes: 17 additions & 13 deletions google/cloud/datastore/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def _next_page(self):

partition_id = entity_pb2.PartitionId(
project_id=self._aggregation_query.project,
database_id=self.client.database,
namespace_id=self._aggregation_query.namespace,
)

Expand All @@ -386,14 +387,15 @@ def _next_page(self):

if self._timeout is not None:
kwargs["timeout"] = self._timeout

request = {
"project_id": self._aggregation_query.project,
"partition_id": partition_id,
"read_options": read_options,
"aggregation_query": query_pb,
}
helpers.set_database_id_to_request(request, self.client.database)
response_pb = self.client._datastore_api.run_aggregation_query(
request={
"project_id": self._aggregation_query.project,
"partition_id": partition_id,
"read_options": read_options,
"aggregation_query": query_pb,
},
request=request,
**kwargs,
)

Expand All @@ -406,13 +408,15 @@ def _next_page(self):
query_pb = query_pb2.AggregationQuery()
query_pb._pb.CopyFrom(old_query_pb._pb) # copy for testability

request = {
"project_id": self._aggregation_query.project,
"partition_id": partition_id,
"read_options": read_options,
"aggregation_query": query_pb,
}
helpers.set_database_id_to_request(request, self.client.database)
response_pb = self.client._datastore_api.run_aggregation_query(
request={
"project_id": self._aggregation_query.project,
"partition_id": partition_id,
"read_options": read_options,
"aggregation_query": query_pb,
},
request=request,
**kwargs,
)

Expand Down
31 changes: 25 additions & 6 deletions google/cloud/datastore/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ def project(self):
"""
return self._client.project

@property
def database(self):
"""Getter for database in which the batch will run.
:rtype: :class:`str`
:returns: The database in which the batch will run.
"""
return self._client.database

@property
def namespace(self):
"""Getter for namespace in which the batch will run.
Expand Down Expand Up @@ -218,6 +227,9 @@ def put(self, entity):
if self.project != entity.key.project:
raise ValueError("Key must be from same project as batch")

if self.database != entity.key.database:
raise ValueError("Key must be from same database as batch")

if entity.key.is_partial:
entity_pb = self._add_partial_key_entity_pb()
self._partial_key_entities.append(entity)
Expand Down Expand Up @@ -245,6 +257,9 @@ def delete(self, key):
if self.project != key.project:
raise ValueError("Key must be from same project as batch")

if self.database != key.database:
raise ValueError("Key must be from same database as batch")

key_pb = key.to_protobuf()
self._add_delete_key_pb()._pb.CopyFrom(key_pb._pb)

Expand Down Expand Up @@ -281,13 +296,17 @@ def _commit(self, retry, timeout):
if timeout is not None:
kwargs["timeout"] = timeout

request = {
"project_id": self.project,
"mode": mode,
"transaction": self._id,
"mutations": self._mutations,
}

helpers.set_database_id_to_request(request, self._client.database)

commit_response_pb = self._client._datastore_api.commit(
request={
"project_id": self.project,
"mode": mode,
"transaction": self._id,
"mutations": self._mutations,
},
request=request,
**kwargs,
)

Expand Down
Loading

0 comments on commit abf0060

Please sign in to comment.