Uses torchmetrics for metric computation #284
base: master
Conversation
Closes CUNY-CL#158.

* Loss is computed as before, but streamlined somewhat.
* `torchmetrics`' implementation of exact match accuracy is lightly adapted. This does everything in tensor-land and should keep things on devices. My tests confirm that accuracy is EXACTLY what it was before.
* A `torchmetrics`-compatible implementation of symbol error rate (here defined as the edit distance divided by the sum of target lengths) is added here. This is heavily documented and it is compatible with our existing implementation. The hot inner loop is still on CPU, but as mentioned in the documentation, this is probably the best option, and I don't observe any obvious performance penalty when enabling it. Note that the old computation was not as normally defined: it gave the number of edits per word, not the number of edits per gold target symbol.
* We do away with the `evaluation` module altogether. Rather, we treat the metric objects as nullables living in the base class, a design adapted from UDTube. The CLI interface is unimpacted, and my side-by-side shows the metrics are exactly the same as before this change.
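For reviewers, here is a minimal sketch of the shape such a `torchmetrics`-compatible symbol error rate can take; the class and helper names and the `pad_idx` convention are illustrative assumptions, not the code in this diff:

```python
import torch
import torchmetrics


def _edit_distance(pred: list, gold: list) -> int:
    # Classic dynamic-programming Levenshtein distance; this is the
    # hot inner loop that stays on CPU.
    row = list(range(len(gold) + 1))
    for i, p in enumerate(pred, 1):
        new_row = [i]
        for j, g in enumerate(gold, 1):
            new_row.append(
                min(row[j] + 1, new_row[j - 1] + 1, row[j - 1] + (p != g))
            )
        row = new_row
    return row[-1]


class SymbolErrorRate(torchmetrics.Metric):
    """Sum of edit distances divided by the sum of gold target lengths."""

    def __init__(self, pad_idx: int, **kwargs):
        super().__init__(**kwargs)
        self.pad_idx = pad_idx
        # States live on the module's device and reduce correctly
        # across processes.
        self.add_state("edits", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("length", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # preds and target: (batch_size, seq_len) tensors of symbol indices.
        for pred_row, gold_row in zip(preds, target):
            pred = [s for s in pred_row.tolist() if s != self.pad_idx]
            gold = [s for s in gold_row.tolist() if s != self.pad_idx]
            self.edits += _edit_distance(pred, gold)
            self.length += len(gold)

    def compute(self) -> torch.Tensor:
        return self.edits / self.length
```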
Nice!
Cool, I think when we added our implementation, torchmetrics was requiring strings, right? Glad to see this.
So we go from a denominator of 1 (I am interpreting word as sequence, since we do not do pretokenization) to a denominator of the length of the gold symbols? Makes sense if so, but I cannot really remember what is standard.
Generally awesome!! Looking at the implementation, I am wondering why we replace the generic set of evals with specific metric attributes and metric booleans? What are the benefits of this (where I see the downsides as adding some bloat to the code, and lots of steps for including new metrics)?
@@ -64,11 +66,12 @@ def __init__(
    vocab_size,
    # All of these have keyword defaults.
    beam_width=defaults.BEAM_WIDTH,
    compute_accuracy=False,
Why are we initializing these here since we define a property with a method below?
@property
def num_parameters(self) -> int:
    return sum(part.numel() for part in self.parameters())

def compute_accuracy(self) -> bool:
There is always a piece of me that thinks naming boolean properties like this in a dynamically typed language is confusing (it sounds like a method for computing the accuracy). Ofc, alternatives that try to be cute are not my favorite either (e.g., `should_compute_accuracy`/`is_compute_accuracy`), so idk if I'd argue for a change, but I always want to call it out in case someone has a better idea :D.
I'll try `has_accuracy` instead.
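For the record, the renamed property over a nullable metric object would presumably read something like this sketch (the `accuracy` attribute name is an assumption):

```python
@property
def has_accuracy(self) -> bool:
    # True iff the nullable accuracy metric was requested at init time.
    return self.accuracy is not None
```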
dropout_layer: nn.Dropout
eval_metrics: Set[evaluators.Evaluator]
Ok I made a top-level comment about this. Looking at the code and thinking about it, I suppose if we wanted to avoid all of the metric-specific properties/null checks, we would need to implement and maintain a generic `Metric`, which I suppose might create friction for implementing new metrics, so I agree with this change. Adding a new metric probably only requires adding code in 3 places in this file (beyond implementing the metric and updating the CLI), right?
Returns:
    Dict: averaged metrics over all validation steps.

def test_step(
This is new, yes? I guess it does the same thing as validation but specifies "test" mode?
Yeah it's something you can do in LightningCLI.
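For readers unfamiliar with the hook: `test_step` mirrors `validation_step` but logs under the "test" prefix, roughly like this sketch (the forward and loss calls here are assumptions; `_log_loss` and `_update_metrics` appear in this diff):

```python
def test_step(self, batch: data.PaddedBatch, batch_idx: int) -> None:
    # Mirrors validation_step, but logs under the "test" prefix.
    predictions = self(batch)  # assumed forward signature
    loss = self.loss_func(predictions, batch.target.padded)  # assumed helper
    self._log_loss(loss, len(batch), "test")
    self._update_metrics(predictions, batch.target.padded)
```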
)
return loss

def _reset_metrics(self) -> None:
Is there a link to the torchmetrics docs explaining this that we could put here?
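The relevant primitive is `torchmetrics.Metric.reset()`, which clears a metric's accumulated state. A sketch of what `_reset_metrics` presumably does (the attribute names are assumptions):

```python
def _reset_metrics(self) -> None:
    # torchmetrics objects accumulate state across update() calls;
    # reset() clears that state so each epoch starts fresh.
    if self.accuracy is not None:
        self.accuracy.reset()
    if self.ser is not None:
        self.ser.reset()
```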
def validation_step(self, batch: data.PaddedBatch, batch_idx: int) -> Dict:
def on_validation_epoch_start(self) -> None:
Shouldn't this be inherited? Or does some class in its inheritance override this so we have to override it back?
def on_validation_epoch_start(self) -> None:
    self._reset_metrics()

def validation_step(self, batch: data.PaddedBatch, batch_idx: int) -> None:
This also looks the same as the base method.
self._log_loss(loss, len(batch), "val")
self._update_metrics(predictions, batch.target.padded)

def on_validation_epoch_end(self) -> None:
ditto
Just to say: this isn't really ready for review yet. It depends on a lot of other small changes which I'll make first. I thought I could do it all in one go: I was wrong.
Yep. It wasn't too hard to make our own.
Actually I misspoke here somewhat because I misread the code: the denominator was the length of the tensor; now it's the length of the string the tensor denotes.
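A worked example of the difference, with made-up numbers:

```python
# Hypothetical: gold string "cats" (4 symbols) padded to a tensor of length 8.
edits = 1
old_ser = edits / 8  # 0.125: denominator was the tensor length, pads included
new_ser = edits / 4  # 0.25: denominator is the length of the denoted string
```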
Benefits: data lives on the accelerator, like loss data does. I actually don't think the steps to add metrics are meaningfully more difficult than what we had previously, so I don't see any big downsides either. It also can be documented without much trouble.
I originally piloted with that. Most things in the Torch universe (including loss functions, but also everything in …)
Sorry for my delay. This mostly all makes sense to me. "90% of my debugging here involves transpositions and making shape conform" -- this has typically been my experience in general when writing torch code :). I am mostly trying to suggest that having a yoyodyne default assumption of what shape tensors are in would be nice -- and I think it would make it conceptually easier to visualize tensors as you code. If I follow correctly, though, the decisions you made in reshaping sound very reasonable.
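One lightweight way to encode such a library-wide convention (purely a hypothetical sketch, not existing yoyodyne code) is a checked helper:

```python
import torch


def check_logits_shape(logits: torch.Tensor, batch_size: int) -> None:
    # Hypothetical convention: logits are (batch_size, seq_len, vocab_size).
    assert logits.ndim == 3, f"expected 3-D logits, got {logits.ndim}-D"
    assert logits.size(0) == batch_size, "batch dimension must come first"
```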