Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/gans/pix2pix_facades.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tensorflow as tf

from ashpy import LogEvalMode
from ashpy.callbacks import LogImageGANCallback
from ashpy.losses.gan import (
AdversarialLossType,
Pix2PixLoss,
Expand Down Expand Up @@ -177,6 +178,8 @@ def main(
if not logdir.exists():
logdir.mkdir(parents=True)

callbacks = [LogImageGANCallback()]

trainer = AdversarialTrainer(
generator=generator,
discriminator=discriminator,
Expand All @@ -188,9 +191,10 @@ def main(
metrics=metrics,
logdir=logdir,
log_eval_mode=LogEvalMode.TEST,
callbacks=callbacks,
)

train_dataset = tf.data.Dataset.list_files(PATH + "train/*.jpg")
train_dataset = tf.data.Dataset.list_files(str(PATH / "train/*.jpg"))
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train)
train_dataset = train_dataset.batch(BATCH_SIZE)
Expand Down
7 changes: 4 additions & 3 deletions examples/gans/pix2pix_facades_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Input Pipeline taken from: https://www.tensorflow.org/beta/tutorials/generative/pix2pix
"""
import os
from pathlib import Path

import tensorflow as tf

Expand All @@ -31,7 +32,7 @@
from ashpy.models.convolutional.unet import FUNet
from ashpy.trainers.gan import AdversarialTrainer

from .pix2pix_facades import BATCH_SIZE, BUFFER_SIZE, IMG_WIDTH, PATH, load_image_train
from pix2pix_facades import BATCH_SIZE, BUFFER_SIZE, IMG_WIDTH, PATH, load_image_train


def main(
Expand Down Expand Up @@ -94,7 +95,7 @@ def main(
)

metrics = []
logdir = f'{"log"}/{dataset_name}/run_multi'
logdir = Path(f'{"log"}/{dataset_name}/run_multi')

if not logdir.exists():
logdir.mkdir(parents=True)
Expand All @@ -116,7 +117,7 @@ def main(
log_eval_mode=LogEvalMode.TEST,
)

train_dataset = tf.data.Dataset.list_files(PATH + "train/*.jpg")
train_dataset = tf.data.Dataset.list_files(str(PATH / "train/*.jpg"))
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train)
train_dataset = train_dataset.batch(BATCH_SIZE)
Expand Down
35 changes: 24 additions & 11 deletions src/ashpy/callbacks/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

"""Classifier callbacks."""

import tensorflow as tf
from ashpy.callbacks.counter_callback import CounterCallback
from ashpy.callbacks.events import Event
from ashpy.contexts import ClassifierContext
Expand All @@ -30,26 +30,29 @@
def __init__(
self,
event: Event = Event.ON_EPOCH_END,
name="log_classifier_callback",
name: str = "log_classifier_callback",
event_freq: int = 1,
input_is_zero_centered: bool = True,
):
"""
Initialize the LogClassifierCallback.

Args:
event: event to consider
event_freq: frequency of logging

name (str): name of the callback
event (ashpy.events.Event): event to consider
event_freq (int): frequency of logging
input_is_zero_centered (bool): if True, the callback assumes the input is in [-1,
1] if it is an image with type tf.float. If False, the callback assumes the input
is [0, 1] if type float, and [0, 255] if type is uint. If the input type is float
and the image is in [0, 1] use False. If the input type is uint this parameter is
ignored.
"""
super(LogClassifierCallback, self).__init__(
event=event,
fn=LogClassifierCallback._log_fn,
name=name,
event_freq=event_freq,
event=event, fn=self._log_fn, name=name, event_freq=event_freq,
)
self._input_is_zero_centered = input_is_zero_centered

Check warning on line 53 in src/ashpy/callbacks/classifier.py

View check run for this annotation

Codecov / codecov/patch

src/ashpy/callbacks/classifier.py#L53

Added line #L53 was not covered by tests

@staticmethod
def _log_fn(context: ClassifierContext) -> None:
def _log_fn(self, context: ClassifierContext) -> None:
"""
Log output of the image and label to Tensorboard.

Expand All @@ -60,5 +63,15 @@
input_tensor = context.current_batch[0]
out_label = context.current_batch[1]

rank = tf.rank(input_tensor)

Check warning on line 66 in src/ashpy/callbacks/classifier.py

View check run for this annotation

Codecov / codecov/patch

src/ashpy/callbacks/classifier.py#L66

Added line #L66 was not covered by tests

# if it is an image, check if we need to scale and shift
if (

Check warning on line 69 in src/ashpy/callbacks/classifier.py

View check run for this annotation

Codecov / codecov/patch

src/ashpy/callbacks/classifier.py#L69

Added line #L69 was not covered by tests
tf.equal(rank, 4)
and (input_tensor.dtype == tf.float32 or input_tensor.dtype == tf.float64)
and self._input_is_zero_centered
):
input_tensor = (input_tensor + 1) / 2

Check warning on line 74 in src/ashpy/callbacks/classifier.py

View check run for this annotation

Codecov / codecov/patch

src/ashpy/callbacks/classifier.py#L74

Added line #L74 was not covered by tests

log("input_x", input_tensor, context.global_step)
log("input_y", out_label, context.global_step)
4 changes: 4 additions & 0 deletions src/ashpy/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class Attention(tf.keras.Model):
# the same as the input shape
print(output.shape)

.. testoutput::

(1, 10, 10, 64)

* Inside a Model:

.. testcode::
Expand Down
4 changes: 4 additions & 0 deletions src/ashpy/layers/instance_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class InstanceNormalization(tf.keras.layers.Layer):
# the same as the input shape.
print(output.shape)

.. testoutput::

(1, 10, 10, 64)

* Inside a Model:

.. testcode::
Expand Down
8 changes: 5 additions & 3 deletions src/ashpy/losses/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@
class ClassifierLoss(Executor):
r"""Classifier Loss Executor using the classifier model, instantiated with a fn."""

def __init__(self, fn: tf.keras.losses.Loss) -> None:
def __init__(self, fn: tf.keras.losses.Loss, name: str = "ClassifierLoss") -> None:
r"""
Initialize :py:class:`ClassifierLoss`.

Args:
fn (:py:class:`tf.keras.losses.Loss`): Classification Loss function, should
take as input labels and prediction.
name (str): Name of the loss. It will be used for logging in Tensorboard.

Returns:
:py:obj:`None`

"""
super().__init__(fn)
super().__init__(fn, name=name)

@Executor.reduce_loss
def call(
Expand Down Expand Up @@ -69,4 +70,5 @@ def call(
lambda: loss,
lambda: tf.expand_dims(tf.expand_dims(loss, axis=-1), axis=-1),
)
return tf.reduce_mean(loss, axis=[1, 2])
loss = tf.reduce_mean(loss, axis=[1, 2])
return loss
59 changes: 49 additions & 10 deletions src/ashpy/losses/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@
class Executor:
"""Carry a function and the way of executing it. Given a context."""

def __init__(self, fn: tf.keras.losses.Loss = None) -> None:
def __init__(self, fn: tf.keras.losses.Loss = None, name: str = "loss") -> None:
"""
Initialize the Executor.

Args:
fn (:py:class:`tf.keras.losses.Loss`): A Keras Loss to execute.
name (str): Name of the loss. It will be be used for logging in TensorBoard.

Returns:
:py:obj:`None`
Expand All @@ -48,6 +49,13 @@
self._distribute_strategy = tf.distribute.get_strategy()
self._global_batch_size = -1
self._weight = lambda _: 1.0
self._name = name
self._loss_value = 0

@property
def name(self) -> str:
"""Return the name of the loss."""
return self._name

@property
def weight(self) -> Callable[..., float]:
Expand Down Expand Up @@ -153,17 +161,29 @@
:py:obj:`tf.Tensor`: Output Tensor.

"""
return self._weight(context.global_step) * self.call(context, **kwargs)
self._loss_value = self._weight(context.global_step) * self.call(
context, **kwargs
)
return self._loss_value

def __add__(self, other) -> SumExecutor:
def log(self, step: tf.Variable):
"""
Log the loss on Tensorboard.

Args:
step (tf.Variable): current training step.
"""
tf.summary.scalar(f"ashpy/losses/{self._name}", self._loss_value, step=step)

Check warning on line 176 in src/ashpy/losses/executor.py

View check run for this annotation

Codecov / codecov/patch

src/ashpy/losses/executor.py#L176

Added line #L176 was not covered by tests

def __add__(self, other: Union[SumExecutor, Executor]) -> SumExecutor:
"""Concatenate Executors together into a SumExecutor."""
if isinstance(other, SumExecutor):
other_executors = other.executors
else:
other_executors = [other]

all_executors = [self] + other_executors
return SumExecutor(all_executors)
return SumExecutor(all_executors, name=f"{self._name}+{other._name}")

def __mul__(self, other: Union[Callable[..., float], float, int, tf.Tensor]):
"""
Expand All @@ -185,8 +205,8 @@
self._weight = lambda step: weight(step) * __other(step)
return self

def __rmul__(self, other):
"""See `__mul__` method."""
def __rmul__(self, other: Union[SumExecutor, Executor]):
"""See ``__mul__`` method."""
return self * other


Expand All @@ -198,19 +218,20 @@
then summed together.
"""

def __init__(self, executors) -> None:
def __init__(self, executors: List[Executor], name: str = "LossSum") -> None:
"""
Initialize the SumExecutor.

Args:
executors (:py:obj:`list` of [:py:class:`ashpy.executors.Executor`]): Array of
:py:obj:`ashpy.executors.Executor` to sum evaluate and sum together.
name (str): Name of the loss. It will be used to log in TensorBoard.

Returns:
:py:obj:`None`

"""
super().__init__()
super().__init__(name=name)
self._executors = executors
self._global_batch_size = 1

Expand All @@ -219,6 +240,11 @@
"""Return the List of Executors."""
return self._executors

@property
def sublosses(self) -> List[Executor]:
"""Return the List of Executors."""
return self._executors

@Executor.global_batch_size.setter # pylint: disable=no-member
def global_batch_size(self, global_batch_size: int) -> None:
"""Set global batch size property."""
Expand All @@ -235,8 +261,21 @@
:py:classes:`tf.Tensor`: Output Tensor.

"""
result = tf.add_n([executor(*args, **kwargs) for executor in self._executors])
return result
self._loss_value = tf.add_n(

Check warning on line 264 in src/ashpy/losses/executor.py

View check run for this annotation

Codecov / codecov/patch

src/ashpy/losses/executor.py#L264

Added line #L264 was not covered by tests
[executor(*args, **kwargs) for executor in self._executors]
)
return self._loss_value

Check warning on line 267 in src/ashpy/losses/executor.py

View check run for this annotation

Codecov / codecov/patch

src/ashpy/losses/executor.py#L267

Added line #L267 was not covered by tests

def log(self, step: tf.Variable):
"""
Log the loss + all the sub-losses on Tensorboard.

Args:
step: current step
"""
super().log(step)
for executor in self._executors:
executor.log(step)

Check warning on line 278 in src/ashpy/losses/executor.py

View check run for this annotation

Codecov / codecov/patch

src/ashpy/losses/executor.py#L276-L278

Added lines #L276 - L278 were not covered by tests

def __add__(self, other: Union[SumExecutor, Executor]):
"""Concatenate Executors together into a SumExecutor."""
Expand Down
Loading