Implementing Adaptive Loss Weights via Callback - tf.compat.v1 #1085

Open
PhilippBrendel opened this issue Dec 10, 2022 · 12 comments

@PhilippBrendel

PhilippBrendel commented Dec 10, 2022

Hi everyone,

I've read in other issues (e.g., #215 and #908) that adaptive loss weights are not a high priority for DeepXDE, but I'd still like to test some approaches, as I see considerable potential for my current use case.
However, implementing this via a callback like the following does not work for me so far (see the error message below).

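```python
import deepxde as dde


class LossWeightCallback(dde.callbacks.Callback):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def on_epoch_end(self):
        ...
        # Recompile with the new weights (this is what triggers the error below)
        self.model.compile("adam", lr=1e-3, decay=None,
                           loss_weights=[1, 1, 1, 1])
```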
Error message:
Training model...

Step      Train loss                                  Test loss                                   Test metric
0         [9.15e+04, 2.41e+00, 1.07e-07, 2.08e-06]    [3.41e+04, 9.30e-01, 0.00e+00, 0.00e+00]    []
Compiling model...
'compile' took 9.111588 s

Traceback (most recent call last):
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\client\session.py", line 1375, in _do_call
    return fn(*args)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\client\session.py", line 1359, in _run_fn
    return self._call_tf_sessionrun(options, feed_dict, fetch_list,
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\client\session.py", line 1451, in _call_tf_sessionrun
    return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value beta1_power_1
         [[{{node beta1_power_1/read}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "maxwell_quasistatic.py", line 1689, in <module>
    pinn.train()
  File "maxwell_quasistatic.py", line 1282, in train
    loss_hist, train_state = self.model.train(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\utils\internal.py", line 22, in wrapper
    result = f(*args, **kwargs)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\model.py", line 589, in train
    self._train_sgd(iterations, display_every)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\model.py", line 606, in _train_sgd
    self._train_step(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\model.py", line 505, in _train_step
    self.sess.run(self.train_step, feed_dict=feed_dict)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\client\session.py", line 967, in run
    result = self._run(None, fetches, feed_dict, options_ptr,
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\client\session.py", line 1190, in _run
    results = self._do_run(handle, final_targets, final_fetches,
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\client\session.py", line 1368, in _do_run
    return self._do_call(_run_fn, feeds, fetches, targets, options,
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\client\session.py", line 1394, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value beta1_power_1
         [[node beta1_power_1/read (defined at D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\optimizers\tensorflow_compat_v1\optimizers.py:58) ]]

Original stack trace for 'beta1_power_1/read':
  File "maxwell_quasistatic.py", line 1689, in <module>
    pinn.train()
  File "maxwell_quasistatic.py", line 1282, in train
    loss_hist, train_state = self.model.train(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\utils\internal.py", line 22, in wrapper
    result = f(*args, **kwargs)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\model.py", line 589, in train
    self._train_sgd(iterations, display_every)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\model.py", line 618, in _train_sgd
    self.callbacks.on_epoch_end()
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\callbacks.py", line 78, in on_epoch_end
    callback.on_epoch_end()
  File "D:\pinns\src\deepxde\utils\callbacks.py", line 312, in on_epoch_end
    self.pinn_obj.model.compile("adam", lr=1e-3, decay=None,
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\utils\internal.py", line 22, in wrapper
    result = f(*args, **kwargs)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\model.py", line 124, in compile
    self._compile_tensorflow_compat_v1(lr, loss_fn, decay, loss_weights)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\model.py", line 177, in _compile_tensorflow_compat_v1
    self.train_step = optimizers.get(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\deepxde\optimizers\tensorflow_compat_v1\optimizers.py", line 58, in get
    train_op = optim.minimize(loss, global_step=global_step)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\training\optimizer.py", line 412, in minimize
    return self.apply_gradients(grads_and_vars, global_step=global_step,
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\training\optimizer.py", line 597, in apply_gradients
    self._create_slots(var_list)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\training\adam.py", line 131, in _create_slots
    self._create_non_slot_variable(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\training\optimizer.py", line 830, in _create_non_slot_variable
    v = variable_scope.variable(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\variables.py", line 260, in __call__
    return cls._variable_v1_call(*args, **kwargs)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\variables.py", line 206, in _variable_v1_call
    return previous_getter(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\variables.py", line 199, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2620, in default_variable_creator
    return variables.RefVariable(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\variables.py", line 264, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\variables.py", line 1656, in __init__
    self._init_from_args(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\variables.py", line 1861, in _init_from_args
    self._snapshot = array_ops.identity(self._variable, name="read")
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\util\dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\array_ops.py", line 287, in identity
    ret = gen_array_ops.identity(input, name=name)
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 3941, in identity
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 748, in _apply_op_helper
    op = g._create_op_internal(op_type_name, inputs, dtypes=None,
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\framework\ops.py", line 3528, in _create_op_internal
    ret = Operation(
  File "D:\Anaconda3\envs\deepxde_tf24\lib\site-packages\tensorflow\python\framework\ops.py", line 1990, in __init__
    self._traceback = tf_stack.extract_stack()

I'm not an expert on core TensorFlow (especially not TF1), so if anyone could advise me on what I'm doing wrong or how to fix this, I'd really appreciate it!

Cheers,
Philipp

P.S. There are some more comments in #331, but I don't think they apply to my problem, as I'm not interested in using gradients for the weights, at least initially.

@PhilippBrendel
Author

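```python
from deepxde.backend import tf  # tf.compat.v1 under the tensorflow.compat.v1 backend
...
self.model.compile("adam", lr=1e-3, decay=None,
                   loss_weights=[1, 1, 1, 1])
# Recompiling creates new Adam variables (e.g. beta1_power_1) that must be initialized.
# Caveat: this runs the initializers of *all* global variables, network weights included.
self.model.sess.run(tf.global_variables_initializer())
```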

Sorry for posting this a little early (in hindsight), but maybe this solution will be helpful for others as well.

@sai-karthikeya-vemuri

Hello Philipp (@PhilippBrendel) and @lululxvi,

Could you share your implementation of the adaptive loss weights class?

I am trying to apply the same approach to the 2D wave equation. Since both an FNN and an MsFFN seem to fail, I would like to try adaptive weights.

Thanks in advance!

@pescap
Contributor

pescap commented Aug 23, 2023

Hi, I'll use this issue to centralize information regarding adaptive loss weighting:

Other references of interest:

[Screenshot from 2023-08-23]

Note: "lambda slightly improves the accuracy" (Fig. 13). Judging from Fig. 13, however, it does not seem to be very effective.

As stated by @lululxvi in #215, "based on my experience, fixed and adaptive weights have similar effects. As you can see in the papers you mentioned, the adaptive weights quickly converge to a fixed number, and thus fixed weights are basically sufficient. Also, it is recommended to use hard constraints for BC/IC, see FAQ".
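To illustrate the hard-constraint recommendation: instead of penalizing a boundary condition with a weighted loss term, the network output is transformed so that the condition holds exactly, and that loss term (together with its weight) disappears. The sketch below is a minimal illustration under assumptions: a 1D problem on [0, 1] with u(0) = u(1) = 0, and a random one-layer NumPy function standing in for the network. In DeepXDE such a transform is usually attached with the network's `apply_output_transform` method.

```python
import numpy as np

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(1, 16)), rng.normal(size=16)
W2 = rng.normal(size=(16, 1))


def network(x):
    """Stand-in for an (untrained) network N(x): one tanh hidden layer."""
    return np.tanh(x @ W1 + b1) @ W2


def constrained_output(x):
    """Hard constraint: u(x) = x * (1 - x) * N(x) gives u(0) = u(1) = 0 for any N."""
    return x * (1.0 - x) * network(x)


x = np.array([[0.0], [0.5], [1.0]])
u = constrained_output(x)  # u at x = 0 and x = 1 is exactly zero, whatever N is
```

Because the boundary values no longer depend on training, there is no boundary-loss weight left to tune, which is why hard constraints sidestep much of the weighting problem for BC/IC terms.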

I would definitely implement adaptive weighting as callbacks. I'll try to work out a structure for implementing a simple weighting scheme. I think adaptive weighting can be useful for more involved losses (e.g., with 4-5 terms).

@pescap
Contributor

pescap commented Aug 26, 2023

Hi! I have been exploring this issue.

For all these adaptive weighting techniques, we want to be able to update loss_weights during training with a callback, without having to compile the model again.

So loss_weights should be initialized at the first epoch.

The callback would then adapt the weights in on_epoch_end.

For example, in tensorflow.compat.v1 the multiplication is performed here:

deepxde/deepxde/model.py, lines 181-182 (commit 683682c):

```python
if loss_weights is not None:
    losses *= loss_weights
```

So, as a first step, we would comment out these two lines and move the multiplication to the beginning of the train function.

Then, with a few changes, we could define and use self.loss_weights so that we can update the weights during training.
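A minimal, framework-free sketch of this pattern might look as follows. `ToyModel` and `SetLossWeights` are hypothetical names, not DeepXDE classes: the model keeps `loss_weights` as a mutable attribute and applies it whenever the total loss is evaluated, so a callback can change the weights in `on_epoch_end` without recompiling. With the tensorflow.compat.v1 backend, the weights would additionally need to live in the graph (for example in a `tf.Variable` updated through an assign op, or a placeholder fed at every step), because changing a Python attribute alone is invisible to an already-built graph.

```python
import numpy as np


class ToyModel:
    """Hypothetical stand-in for a model whose total loss is a weighted sum of terms."""

    def __init__(self, loss_weights):
        # Stored on the model instead of being folded into the loss at compile time.
        self.loss_weights = np.asarray(loss_weights, dtype=float)

    def total_loss(self, losses):
        # The current weights are read every time the total loss is evaluated.
        return float(np.sum(self.loss_weights * np.asarray(losses, dtype=float)))


class SetLossWeights:
    """Hypothetical callback that overwrites the model's weights at the end of an epoch."""

    def __init__(self, model, new_weights):
        self.model = model
        self.new_weights = np.asarray(new_weights, dtype=float)

    def on_epoch_end(self):
        self.model.loss_weights = self.new_weights


model = ToyModel([1.0, 1.0])
callback = SetLossWeights(model, [10.0, 1.0])
before = model.total_loss([2.0, 3.0])  # 1*2 + 1*3
callback.on_epoch_end()
after = model.total_loss([2.0, 3.0])   # 10*2 + 1*3
```

The key design choice is that the multiplication happens at evaluation time rather than at compile time, which is exactly what moving the two lines out of `compile` is meant to achieve.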

Do you agree, @lululxvi?

Also, would you prefer that I move this discussion to a new issue, or could you please reopen this one?

@lululxvi
Owner

@pescap Yes, that sounds good.

lululxvi reopened this Sep 18, 2023
@pescap
Contributor

pescap commented Oct 3, 2023

  • I'll start with a simple loss-balancing algorithm (SoftAdapt).
  • SoftAdapt uses the ratio of each loss term's value at the current iteration to its value at the previous iteration, which makes it simple to implement.
  • This will let me set up the infrastructure for loss-balancing implementations.

As a first step, I am working on making it easier to update loss_weights during training (see #1511).
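Following the ratio-based description above, a SoftAdapt-style update can be sketched in a few lines of NumPy. This is an illustrative sketch under assumptions, not the implementation from #1511 or #1586: the temperature `beta`, the `eps` guard, the hypothetical `SoftAdaptTracker` helper, and the choice to rescale the weights so they sum to the number of loss terms (so equal rates reproduce `loss_weights=[1, 1, 1, 1]`) are all my own. Formulations in the literature also differ in whether they use ratios or differences of successive losses and in how they normalize them.

```python
import numpy as np


def softadapt_weights(prev_losses, curr_losses, beta=1.0, eps=1e-12):
    """SoftAdapt-style weights from the ratio L_i(t) / L_i(t-1) (sketch).

    Terms whose loss decreases more slowly (ratio closer to or above 1) get
    larger weights. The result is rescaled so the weights sum to the number
    of terms.
    """
    prev = np.asarray(prev_losses, dtype=float)
    curr = np.asarray(curr_losses, dtype=float)
    rates = curr / (prev + eps)  # eps guards against division by zero
    z = beta * rates
    z = z - z.max()              # shift for a numerically stable softmax
    w = np.exp(z)
    return len(w) * w / w.sum()


class SoftAdaptTracker:
    """Hypothetical helper: call update() once per epoch with the current loss terms."""

    def __init__(self, beta=1.0):
        self.beta = beta
        self.prev = None

    def update(self, losses):
        losses = np.array(losses, dtype=float)  # copy, so later edits can't alter history
        if self.prev is None:
            weights = np.ones_like(losses)      # no history yet: uniform weights
        else:
            weights = softadapt_weights(self.prev, losses, self.beta)
        self.prev = losses
        return weights


tracker = SoftAdaptTracker(beta=1.0)
w0 = tracker.update([1.0, 1.0, 1.0])  # first call: uniform weights
w1 = tracker.update([0.5, 0.9, 1.0])  # the term that stalled gets the largest weight
```

In a callback, `update()` would be fed the per-term training losses at the end of each epoch, and the returned weights handed to whatever mechanism ends up storing `loss_weights` on the model.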

@jdellag
Contributor

jdellag commented Nov 19, 2023

Really looking forward to this implementation!

@pescap
Contributor

pescap commented Dec 6, 2023

Working on this in #1586

@haison19952013

  • I have experience with this type of work from my project, where I implemented my own PINN library in TensorFlow 2.0.
  • If you're interested in implementing this feature in TensorFlow 2.0, I'm happy to lend a hand.

@pescap
Copy link
Contributor

pescap commented Feb 1, 2024

Hi! Thank you for the offer! We could start with the 'tensorflow' backend. How would you suggest we proceed? We could discuss it in the DeepXDE Slack workspace.

@haison19952013

Yes, could you give me the Slack ID or URL?

@pescap
Contributor

pescap commented Feb 2, 2024

Can you please send an email to @lululxvi asking him to add you?


6 participants