
Similarity between two curves using PyTorch #26

Open
Cram3r95 opened this issue Oct 24, 2022 · 20 comments

Comments

@Cram3r95

Hi guys,

I want to implement in my trainer a measure of similarity between my predicted trajectory and the GT trajectory. Here is an example:

[image: GT, observation, and predicted trajectories among other agents]

The GT is the red line, my observation is the yellow line (almost hidden by the other participants) and the green line is my prediction. The other agents are not used at this moment.

Now, in order to train my DL-based motion prediction algorithm, I am using the ADE, FDE and NLL losses w.r.t. the GT. Nevertheless, I think that if my prediction does not exactly match the GT but follows the same centerline (driving at a different velocity, for example), that should still count as a better prediction. E.g.:

[image: prediction following the same centerline as the GT]

This prediction does not match the GT (until the red diamond at the bottom), but at least the shapes of both curves are more or less the same.

How could I do that?

@cjekel
Owner

cjekel commented Oct 26, 2022

  • If I were doing this problem, I would probably just use MSE and MAE to compute the error on the path of the next 50 data points between GT and prediction.
  • I would not mind making a pytorch port of this library. That would be cool. Some people could use these as differentiable loss functions.
  • What is ADE and FDE?
  • Can you describe your data points for the prediction and the GT? Is it just a tensor of the X Y coordinates in a chronological order? Is there always a different number of data points between GT and your prediction?
  • This seems like a good case to use Partial DTW. Maybe I should implement a partial DTW algorithm?
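For that first bullet, a minimal sketch of what MSE/MAE on two equally sampled paths could look like (pred and gt are hypothetical N x 2 tensors of (x, y) points, not names from this library):

```python
import torch


def path_errors(pred, gt):
    """Pointwise MSE and MAE between two equally sampled N x 2 paths."""
    mse = torch.mean((pred - gt) ** 2)
    mae = torch.mean(torch.abs(pred - gt))
    return mse, mae


# example: 50 (x, y) points per path, offset by 1 everywhere
gt = torch.zeros(50, 2)
pred = torch.ones(50, 2)
mse, mae = path_errors(pred, gt)
```

Note this only works when both paths have the same number of points sampled at the same rate.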

@Cram3r95
Author

Hi @cjekel. First of all, thank you for your quick response!!

The input data may be a little noisy, but my DL algorithm must generalize to both noisy and non-noisy data in order to predict a multimodal trajectory. As expected, the multimodal prediction, considering both different directions (e.g. lanes: left, center, right) and different velocity profiles for the same lane (i.e. slower vs. faster), must lie in the drivable area (that is, the lanes). That's why, in addition to the ADE loss (Average Displacement Error, the displacement error of the whole prediction vs. the whole GT) and the FDE loss (Final Displacement Error, similar to ADE but only for the last point), I would like to evaluate the similarity of the prediction's shape to the corresponding centerline segment.

A powerpoint example:

[image: PowerPoint example of comparing a prediction's shape to a centerline]

In order to force the model to predict trajectories only in the drivable area.
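For reference, the ADE/FDE losses mentioned above can be sketched as follows (a minimal, untested sketch; here ADE is taken as the mean per-timestep L2 displacement, which is the common definition):

```python
import torch


def ade(pred, gt):
    """Average Displacement Error: mean L2 distance over all timesteps."""
    return torch.linalg.vector_norm(pred - gt, dim=-1).mean()


def fde(pred, gt):
    """Final Displacement Error: L2 distance at the final timestep only."""
    return torch.linalg.vector_norm(pred[-1] - gt[-1])


# example: 30 prediction points with a constant (3, 4) offset from the GT,
# so every per-timestep displacement is exactly 5
gt = torch.zeros(30, 2)
pred = gt + torch.tensor([3.0, 4.0])
```

Both require the prediction and GT to have the same number of points, which is why the centerline comparison needs something else.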

@cjekel
Owner

cjekel commented Oct 28, 2022

Are partial matches (between your prediction and GT) a good or bad thing in this case?

I can quickly write a PyTorch variant of DTW that you can try out. It's not going to be well tested. The autograd should work, and it will be impressive if it does, because DTW should have many discontinuous derivatives. That said, there is no reason not to try.

@cjekel
Owner

cjekel commented Oct 28, 2022

So this is a DTW distance in PyTorch:

import torch


def dtwtorch(exp_data, num_data):
    r"""
    Compute the Dynamic Time Warping distance.

    exp_data : torch tensor N x P
    num_data : torch tensor M x P
    """
    # pairwise Euclidean distances between all points on the two curves
    c = torch.cdist(exp_data, num_data, p=2)
    # d[i, j] accumulates the cost of the cheapest warping path ending at (i, j)
    d = torch.empty_like(c)
    d[0, 0] = c[0, 0]
    n, m = c.shape
    for i in range(1, n):
        d[i, 0] = d[i-1, 0] + c[i, 0]
    for j in range(1, m):
        d[0, j] = d[0, j-1] + c[0, j]
    for i in range(1, n):
        for j in range(1, m):
            d[i, j] = c[i, j] + min(d[i-1, j], d[i, j-1], d[i-1, j-1])
    return d[-1, -1]

How will you handle batching? Is GT going to be a fixed size, but your prediction going to be mini-batched? What will the shapes be for each tensor?

@cjekel
Owner

cjekel commented Oct 28, 2022

Here is an example of minimizing DTW to fit a model in pytorch https://github.com/cjekel/similarity_measures/blob/dtwtorch/torch_DTW_demo.ipynb The resulting fit is actually pretty good, although it is not cheap to compute the DTW distance and derivatives this way for 150 data points on each path.
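A minimal sketch of that kind of fitting loop (the notebook may differ in detail): it reuses the dtwtorch function from the previous comment and optimizes the points of a path to reduce the DTW distance to a target curve. The curve, initialization, and optimizer settings here are illustrative choices, not taken from the notebook.

```python
import torch


def dtwtorch(exp_data, num_data):
    """DTW distance between N x P and M x P tensors (from the comment above)."""
    c = torch.cdist(exp_data, num_data, p=2)
    d = torch.empty_like(c)
    d[0, 0] = c[0, 0]
    n, m = c.shape
    for i in range(1, n):
        d[i, 0] = d[i - 1, 0] + c[i, 0]
    for j in range(1, m):
        d[0, j] = d[0, j - 1] + c[0, j]
    for i in range(1, n):
        for j in range(1, m):
            d[i, j] = c[i, j] + min(d[i - 1, j], d[i, j - 1], d[i - 1, j - 1])
    return d[-1, -1]


# target: a quadratic curve sampled at 12 points
t = torch.linspace(0.0, 1.0, 12)
target = torch.stack([t, t ** 2], dim=1)

# start from a shifted copy of the target and optimize it back via autograd
pred = (target + torch.tensor([0.3, 0.5])).detach().requires_grad_(True)
opt = torch.optim.Adam([pred], lr=0.05)
initial = dtwtorch(pred, target).item()
for _ in range(30):
    opt.zero_grad()
    loss = dtwtorch(pred, target)
    loss.backward()
    opt.step()
final = dtwtorch(pred, target).item()
```

After the loop, `final` should be well below `initial`, even though each DTW evaluation is a Python double loop and therefore slow.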

@Cram3r95
Author

In my case I would have 30 points (the prediction) and the corresponding points of a centerline (the points between the waypoint closest to the first prediction point and the waypoint closest to the last prediction point), so I would say it is:

DTW for 30 points vs. roughly 10 to 30 waypoints. An example:

[image: centerline segment between the start (S) and end (E) waypoints vs. the prediction]

(Only the centerline between Start (S) and End (E), so only those waypoints, versus the prediction, which is made up of 30 points.)

Not 150 vs. 150. My idea is to integrate this as an additional loss function. Does that work for you?

@Cram3r95
Author

By the way, in your function, what do N, M and P represent?

I assume N and M are the number of points of each tensor (which may be different) and P is the data dimensionality (in my case 2, i.e. (x, y), for both curves). Am I right?

@cjekel
Owner

cjekel commented Oct 28, 2022

I was asking about the shapes and sizes because I was curious whether you wanted to match more than one centerline prediction at a time (like mini-batching), i.e. whether you wanted to match shapes of B x N x P and B x M x P.

I assume N and M are the number of points of each tensor (which may be different) and P is the data dimensionality (in my case 2, i.e. (x, y), for both curves). Am I right?

Yes that is correct.

@Cram3r95
Author

Yes, that could be an option. My idea is to predict N trajectories and then obtain the closest centerline segments for each prediction, e.g.:

[image: closest centerline segments for each predicted trajectory]

in order to force my model to produce shapes similar to the closest centerline.

@cjekel
Owner

cjekel commented Oct 29, 2022

It could also be interesting to port the Frechet distance to PyTorch as well.

@Cram3r95
Author

Cram3r95 commented Nov 1, 2022

Could you do that? I would really appreciate it.

@Cram3r95
Author

Cram3r95 commented Nov 3, 2022

Hi @cjekel , have you finished your Frechet distance in PyTorch?

@cjekel
Owner

cjekel commented Nov 3, 2022

I think it will look something like this:

import torch


def dftorch(exp_data, num_data):
    r"""
    Compute the discrete Frechet distance.

    exp_data : torch tensor N x P
    num_data : torch tensor M x P
    """
    n = len(exp_data)
    m = len(num_data)
    # ca[i, j] holds the Frechet distance between the first i+1 points of
    # exp_data and the first j+1 points of num_data
    ca = torch.full((n, m), -1.0, dtype=exp_data.dtype, device=exp_data.device)
    ca[0, 0] = torch.linalg.vector_norm(exp_data[0] - num_data[0])
    for i in range(1, n):
        ca[i, 0] = max(ca[i-1, 0],
                       torch.linalg.vector_norm(exp_data[i] - num_data[0]))
    for j in range(1, m):
        ca[0, j] = max(ca[0, j-1],
                       torch.linalg.vector_norm(exp_data[0] - num_data[j]))
    for i in range(1, n):
        for j in range(1, m):
            ca[i, j] = max(min(ca[i-1, j], ca[i, j-1], ca[i-1, j-1]),
                           torch.linalg.vector_norm(exp_data[i] - num_data[j]))
    return ca[n-1, m-1]

@Cram3r95
Author

Cram3r95 commented Nov 4, 2022

Nice @cjekel. And for my problem (computing the similarity between two curves with a different number of points), which would be the more suitable loss function: the Frechet distance or Dynamic Time Warping?

@cjekel
Owner

cjekel commented Nov 4, 2022

It depends. The discrete Frechet distance is a metric distance. If your curves are sampled evenly, it is analogous to the longest distance between the two curves. In my past projects, curves matched with DTW generally looked better. You'll have to play with both and report back. It's possible that one of them will be easier to learn/train a model with.
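To make the difference concrete, here is an untested sketch that evaluates both distances (using the dtwtorch and dftorch functions from the earlier comments) on two curves with the same geometry but different sampling. DTW accumulates a sum over the whole matching, while the discrete Frechet distance reports the single worst matched pair, so for identical geometry the Frechet distance stays small while DTW grows with the number of points:

```python
import torch


def dtwtorch(exp_data, num_data):
    """DTW distance (sum over the cheapest warping path), from the earlier comment."""
    c = torch.cdist(exp_data, num_data, p=2)
    d = torch.empty_like(c)
    d[0, 0] = c[0, 0]
    n, m = c.shape
    for i in range(1, n):
        d[i, 0] = d[i - 1, 0] + c[i, 0]
    for j in range(1, m):
        d[0, j] = d[0, j - 1] + c[0, j]
    for i in range(1, n):
        for j in range(1, m):
            d[i, j] = c[i, j] + min(d[i - 1, j], d[i, j - 1], d[i - 1, j - 1])
    return d[-1, -1]


def dftorch(exp_data, num_data):
    """Discrete Frechet distance (a max, not a sum), from the comment above."""
    n, m = len(exp_data), len(num_data)
    ca = torch.full((n, m), -1.0, dtype=exp_data.dtype)
    ca[0, 0] = torch.linalg.vector_norm(exp_data[0] - num_data[0])
    for i in range(1, n):
        ca[i, 0] = max(ca[i - 1, 0],
                       torch.linalg.vector_norm(exp_data[i] - num_data[0]))
    for j in range(1, m):
        ca[0, j] = max(ca[0, j - 1],
                       torch.linalg.vector_norm(exp_data[0] - num_data[j]))
    for i in range(1, n):
        for j in range(1, m):
            ca[i, j] = max(min(ca[i - 1, j], ca[i, j - 1], ca[i - 1, j - 1]),
                           torch.linalg.vector_norm(exp_data[i] - num_data[j]))
    return ca[n - 1, m - 1]


# same geometry (a sine arc), sampled differently: 20 even vs. 35 uneven points
t_a = torch.linspace(0.0, 1.0, 20)
t_b = torch.linspace(0.0, 1.0, 35) ** 2
curve_a = torch.stack([t_a, torch.sin(t_a)], dim=1)
curve_b = torch.stack([t_b, torch.sin(t_b)], dim=1)

dtw_val = dtwtorch(curve_a, curve_b)  # accumulates many small distances
df_val = dftorch(curve_a, curve_b)    # bounded by the worst matched pair
```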

@Cram3r95
Author

Cram3r95 commented Dec 2, 2022

Hi @cjekel, here I am again with this problem. I have a question after reading your comments again: would you not use DTW/Frechet, but only MSE and MAE, for my purpose?

That is, I have to force my model, given the corresponding latent state, to push the predictions towards the GT. In my case I would like to include DTW or Frechet to measure the similarity between my predictions and the closest centerlines. For example:

[image: three plausible centerlines (dashed black lines) and the corresponding predictions]

Here you may observe that I have three plausible centerlines (dashed black lines). The predictions should model the curvature of these centerlines. That's what I mean.

@cjekel
Owner

cjekel commented Dec 7, 2022

@Cram3r95

I have a question after reading your comments again: would you not use DTW/Frechet, but only MSE and MAE, for my purpose?

It depends.

I'd say if you can directly apply MSE or MAE, then you'd be better off using MSE or MAE. The traditional L1 and L2 norms will be much easier to optimize for.

Now there are plenty of problems where you can't directly apply MSE or MAE. For instance, if there is a different number of data points between your prediction and your target, or the sampling rate is inconsistent between the two. In these cases, you should probably use DTW or the Frechet distance.

@Cram3r95
Author

Cram3r95 commented Dec 8, 2022

Yes, of course, that's what I meant! Between my predictions and the ground truth there are the same number of data points (30), but not between the predictions and the centerlines (where I want to apply this function). Thank you for your comment.

@Cram3r95
Author

@cjekel How would you apply both to a batch? They are only designed for single cases at the moment. Is it possible to apply them to a whole batch?

@cjekel
Owner

cjekel commented Dec 19, 2022

@Cram3r95 it is going to be expensive.

One naive way would be

dtw_values = torch.zeros(y_true.shape[0])
for i in range(y_true.shape[0]):
    dtw_values[i] = dtwtorch(y_true[i], y_hat[i])
loss = dtw_values.mean()

where you basically compute the Frechet or DTW distance per element in the batch (pairing y_true[i] with y_hat[i]), then average all of the values.
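Put together as a self-contained sketch (assuming y_true and y_hat are batched B x N x P and B x M x P tensors; the element-wise pairing of y_true[i] with y_hat[i] is my reading of the intended batching):

```python
import torch


def dtwtorch(exp_data, num_data):
    """DTW distance between N x P and M x P tensors (from the earlier comment)."""
    c = torch.cdist(exp_data, num_data, p=2)
    d = torch.empty_like(c)
    d[0, 0] = c[0, 0]
    n, m = c.shape
    for i in range(1, n):
        d[i, 0] = d[i - 1, 0] + c[i, 0]
    for j in range(1, m):
        d[0, j] = d[0, j - 1] + c[0, j]
    for i in range(1, n):
        for j in range(1, m):
            d[i, j] = c[i, j] + min(d[i - 1, j], d[i, j - 1], d[i - 1, j - 1])
    return d[-1, -1]


def batched_dtw_loss(y_true, y_hat):
    """Mean DTW distance over a batch: y_true is B x N x P, y_hat is B x M x P."""
    dtw_values = torch.zeros(y_true.shape[0])
    for i in range(y_true.shape[0]):
        dtw_values[i] = dtwtorch(y_true[i], y_hat[i])
    return dtw_values.mean()


# sanity check: identical curve pairs should give (near) zero loss
y_true = torch.rand(4, 10, 2)
loss = batched_dtw_loss(y_true, y_true.clone())
```

Each batch element is still an O(N*M) Python loop, so this will be the bottleneck for large batches.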
