Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix get line join #3

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 41 additions & 63 deletions src/Core/Storage/Database.py
Original file line number Diff line number Diff line change
@@ -422,13 +422,15 @@ def __add_data(self,
def update(self,
table_name: str,
data: Dict[str, Any],
line_id: int = -1):
line_id: int = -1,
create_fields: bool = False):
"""
Update a line of a Table.

:param table_name: Name of the Table on which to perform the query.
:param data: Updated data of the line.
:param line_id: Index of the line to update.
:param create_fields: Create missing fields.
"""

# Check table existence
@@ -443,13 +445,20 @@ def update(self,

# Define the line index
nb_line = self.nb_lines(table_name=table_name)
if nb_line == 0:
self.add_data(table_name=table_name, data=data)
return
if line_id < 0:
line_id += nb_line + 1
elif line_id > nb_line:
line_id = nb_line

# Check fields existence
undefined_fields = set(fields_names) - set(table.fields())
if create_fields:
fields_to_create = [(undef, type(data[undef])) for undef in undefined_fields]
self.create_fields(table_name=table_name, fields=fields_to_create)
undefined_fields = set(fields_names) - set(table.fields())
if len(undefined_fields) > 0:
raise ValueError(f"[{self.__class__.__name__}] Some fields where not defined in table {table}."
f" As table {table} is non-empty, please define first the following fields :"
@@ -476,15 +485,13 @@ def update(self,
def get_line(self,
table_name: str,
fields: Optional[Union[str, List[str]]] = None,
line_id: int = -1,
joins: Optional[Union[str, List[str]]] = None):
line_id: int = -1):
"""
Get a line of a Table.

:param table_name: Name of the Table on which to perform the query.
:param fields: Name(s) of the Field(s) to request.
:param line_id: Index of the line to get.
:param joins: Name(s) of Table(s) to join to the selection.
"""

# Check the Table existence
@@ -495,19 +502,13 @@ def get_line(self,

# Define the fields to select
fields_selection = ()
if fields is not None:
fields_selection += (table.id,)
fields = [fields] if type(fields) == str else fields
for field in fields:
if field in table.fields():
fields_selection += (table.fields(only_names=False)[field],)
if joins is not None:
joins = [joins] if type(joins) == str else joins
for j in joins:
if j in self.__fk[table_name].values() and j not in fields:
field_name = list(self.__fk[table_name].keys())[
list(self.__fk[table_name].values()).index(j)]
fields_selection += (table.fields(only_names=False)[field_name],)
if fields is None:
fields = table.fields()
fields_selection += (table.id,)
fields = [fields] if type(fields) == str else fields
for field in fields:
if field in table.fields():
fields_selection += (table.fields(only_names=False)[field],)

# Define the index of the line to select
nb_line = self.nb_lines(table_name=table_name)
@@ -520,16 +521,11 @@ def get_line(self,
data = table.select(*fields_selection).where(table.id == line_id).dicts()[0]

# Join
if joins is not None:
joins = [joins] if type(joins) == str else joins
for j in joins:
if j in self.__fk[table_name].values():
field_name = list(self.__fk[table_name].keys())[list(self.__fk[table_name].values()).index(j)]
if field_name in data:
data[field_name] = self.get_line(table_name=j,
fields=fields,
line_id=data[field_name],
joins=j)
for field in fields:
if field in self.__fk[table_name].keys():
data[field] = self.get_line(table_name=self.__fk[table_name][field],
fields=fields,
line_id=data[field])

return data

@@ -538,7 +534,6 @@ def get_lines(self,
fields: Optional[Union[str, List[str]]] = None,
lines_id: Optional[List[int]] = None,
lines_range: Optional[List[int]] = None,
joins: Optional[Union[str, List[str]]] = None,
batched: bool = False):
"""
Get a set of lines of a Table.
@@ -547,7 +542,6 @@ def get_lines(self,
:param fields: Name(s) of the Field(s) to select.
:param lines_id: Indices of the lines to get. If not specified, 'lines_range' value will be used.
:param lines_range: Range of indices of the lines to get. If not specified, all lines will be selected.
:param joins: Name(s) of Table(s) to join to the selection.
:param batched: If True, data is returned as one batch per field. Otherwise, data is returned as list of lines.
"""

@@ -559,19 +553,13 @@ def get_lines(self,

# Define the fields to select
fields_selection = ()
if fields is not None:
fields_selection += (table.id,)
fields = [fields] if type(fields) == str else fields
for field in fields:
if field in table.fields():
fields_selection += (table.fields(only_names=False)[field],)
if joins is not None:
joins = [joins] if type(joins) == str else joins
for j in joins:
if j in self.__fk[table_name].values() and j not in fields:
field_name = list(self.__fk[table_name].keys())[
list(self.__fk[table_name].values()).index(j)]
fields_selection += (table.fields(only_names=False)[field_name],)
if fields is None:
fields = table.fields()
fields_selection += (table.id,)
fields = [fields] if type(fields) == str else fields
for field in fields:
if field in table.fields():
fields_selection += (table.fields(only_names=False)[field],)

# Define the indices of lines to select
if lines_id is None:
@@ -601,27 +589,17 @@ def get_lines(self,
lines = [line for line in query]

# Join
if joins is not None:
joins = [joins] if type(joins) == str else joins
for j in joins:
if j in self.__fk[table_name].values():
field_name = list(self.__fk[table_name].keys())[
list(self.__fk[table_name].values()).index(j)]
dict_keys = lines.keys() if batched else lines[0].keys()
if field_name in dict_keys:
lines_id = lines[field_name] if batched else [line[field_name] for line in lines]
data = self.get_lines(table_name=j,
fields=fields,
lines_id=lines_id,
joins=joins,
batched=batched)

if batched:
lines[field_name] = data
else:
for i, l in enumerate(data):
lines[i][field_name] = l

for field in fields:
if field in self.__fk[table_name].keys():
data = self.get_lines(table_name=self.__fk[table_name][field],
fields=fields,
lines_id=lines[field] if batched else [line[field] for line in lines],
batched=batched)
if batched:
lines[field] = data
else:
for i, l in enumerate(data):
lines[i][field] = l
return lines

def nb_lines(self,