Contrastive Learning on multi-label datasets #2537

Open
ycouble opened this issue Mar 11, 2024 · 3 comments

ycouble commented Mar 11, 2024

Hello,

(Cross posting this between SetFit and sentence-transformers)

We're investigating the possibility of using SetFit for customer service message classification.

Ours is a multi-label case, since customers often have more than one request in a single message.
During SetFit's training phase, the texts and labels are passed to Sentence Transformers' SentenceLabelDataset.
The contrastive examples are created based on the full combination of labels, not on the intersection of labels: e.g. examples labeled [1, 1, 0] and [1, 0, 0] will be pushed apart by contrastive learning, and only pairs with the exact same label vector [1, 1, 0] will be pulled together.

This can be somewhat counterproductive in SetFit: with a one-vs-rest classifier head, for example, examples sharing at least one common label should be close to each other.
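
To make the difference concrete, here is a toy sketch (not the actual SetFit / sentence-transformers pairing code) contrasting the two pairing policies on multi-label vectors:

from itertools import combinations

# three toy examples with multi-label vectors
examples = [("a", (1, 1, 0)), ("b", (1, 0, 0)), ("c", (1, 1, 0))]

# current behaviour: a pair is "positive" only if the full label vectors are identical
exact_positives = [
    (x[0], y[0]) for x, y in combinations(examples, 2) if x[1] == y[1]
]
print(exact_positives)   # [('a', 'c')] -- 'b' is pushed away from both

# one-vs-rest-friendly behaviour: a pair is "positive" if the examples share at least one label
shared_positives = [
    (x[0], y[0]) for x, y in combinations(examples, 2) if any(i and j for i, j in zip(x[1], y[1]))
]
print(shared_positives)  # [('a', 'b'), ('a', 'c'), ('b', 'c')] -- all share the first label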

We were wondering whether that behaviour was chosen deliberately, and why. Do you have experience dealing with this type of data, and did you use a workaround? Would you be interested in a contribution to support this type of use case?

Cheers,

ir2718 commented Mar 14, 2024

Hi,

I think a workaround for this is possible in sentence-transformers, but let me first check whether I understand the problem. You would like to create a dataset such that examples labeled [1, 1, 0] and [0, 1, 0] have their embeddings pulled closer because of the second and third labels, but pushed further apart because of the first label?

ycouble commented Mar 14, 2024

Hi,

I haven't figured out the best way to do it, nor its implications, but your suggestion seems interesting and matches my need.

My point is to allow class values to be considered individually (as your suggestion does) instead of globally (as in the current implementation, where every value of every class has to match for Sentence Transformers to bring the embeddings closer).

ir2718 commented Mar 16, 2024

Here's a somewhat hacky approach which you can use:

from sentence_transformers import InputExample, SentenceTransformer, losses
from sentence_transformers.evaluation import BinaryClassificationEvaluator, SequentialEvaluator
from torch.utils.data import DataLoader
from itertools import combinations
from sklearn.model_selection import train_test_split

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# label index 0: sentiment - 0 sad, 1 happy
# label index 1: topic     - 0 not football, 1 football
dataset = [
    ("This is happy sentence about football", [1, 1]),
    ("This is sad sentence about football", [0, 1]),
    ("This is happy sentence but not about football", [1, 0]),
    ("This is sad sentence not about football", [0, 0]),
    
    ("This is another happy sentence about football", [1, 1]),
    ("This is another sad sentence about football", [0, 1]),
    ("This is another happy sentence but not about football", [1, 0]),
    ("This is another sad sentence but not about football", [0, 0]),
]

# This function pairs sentences that share the same value at a given label index,
# meaning you can generate all pairs of positives or all pairs of negatives for that label.
def generate_examples(dataset, label_idx, label_value):
    valid_idx = [idx for idx, x in enumerate(dataset) if x[1][label_idx] == label_value]
    examples = [
        InputExample(texts=(dataset[idx1][0], dataset[idx2][0]), label=label_value)
        for idx1, idx2 in combinations(valid_idx, 2)
    ]
    return examples

# pairs built from the sentiment label (index 0)
sentiment_pos_ex = generate_examples(dataset, 0, 1)
sentiment_neg_ex = generate_examples(dataset, 0, 0)
sentiment_ex = sentiment_pos_ex + sentiment_neg_ex
sentiment_train, sentiment_val = train_test_split(sentiment_ex, test_size=0.33)
sentiment_dataloader = DataLoader(sentiment_train, shuffle=True)  # train split only, so the evaluator sees held-out pairs

# pairs built from the topic label (index 1)
topic_pos_ex = generate_examples(dataset, 1, 1)
topic_neg_ex = generate_examples(dataset, 1, 0)
topic_ex = topic_pos_ex + topic_neg_ex
topic_train, topic_val = train_test_split(topic_ex, test_size=0.33)
topic_dataloader = DataLoader(topic_train, shuffle=True)  # train split only here as well

loss = losses.ContrastiveLoss(model)

sentiment_evaluator = BinaryClassificationEvaluator(
    sentences1=[x.texts[0] for x in sentiment_val],
    sentences2=[x.texts[1] for x in sentiment_val],
    labels=[x.label for x in sentiment_val]
)
topic_evaluator = BinaryClassificationEvaluator(
    sentences1=[x.texts[0] for x in topic_val],
    sentences2=[x.texts[1] for x in topic_val],
    labels=[x.label for x in topic_val]
)
combined_evaluator = SequentialEvaluator(
    evaluators=[sentiment_evaluator, topic_evaluator],
    main_score_function=lambda x: sum(x) / len(x) # tune this score function for your particular use case
)

model.fit(
    train_objectives=[(sentiment_dataloader, loss), (topic_dataloader, loss)],
    evaluator=combined_evaluator,
    epochs=10,
    evaluation_steps=50,
    warmup_steps=10,
    output_path="./multi_label_test",
)

The code optimizes two objectives at the same time using two different optimizers, so it's not exactly the same as combining the losses and then backpropagating, although I think that could be done by creating a custom loss function. I hope this helps.
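
If you want to go the custom-loss route instead of two train_objectives, here is a minimal sketch of the idea, assuming the sentence-transformers v2.x fit API used above (where a loss is an nn.Module taking sentence_features and labels, and list-valued InputExample labels are collated into a (batch, num_classes) tensor). The class name MultiLabelContrastiveLoss and the per-class "agreement" labels are my own invention, not part of the library; each per-class term just mirrors the standard contrastive loss formula (cosine distance, margin 0.5):

import torch
from torch import nn
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer

class MultiLabelContrastiveLoss(nn.Module):
    # Hypothetical loss: each pair carries one 0/1 label per class, 1 meaning the two
    # sentences agree on that class. One contrastive term is computed per class and averaged.
    def __init__(self, model, margin=0.5):
        super().__init__()
        self.model = model
        self.margin = margin

    def forward(self, sentence_features, labels):
        emb_a = self.model(sentence_features[0])["sentence_embedding"]
        emb_b = self.model(sentence_features[1])["sentence_embedding"]
        dist = 1 - torch.cosine_similarity(emb_a, emb_b)   # cosine distance, shape (batch,)
        labels = labels.float()                            # shape (batch, num_classes)
        # pull together where the pair agrees on a class, push apart (up to the margin) where it disagrees
        pos_term = labels * dist.unsqueeze(1).pow(2)
        neg_term = (1 - labels) * torch.relu(self.margin - dist.unsqueeze(1)).pow(2)
        return 0.5 * (pos_term + neg_term).mean()

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
train_examples = [
    # label = per-class agreement of the two texts ([sentiment, topic])
    InputExample(texts=("This is happy sentence about football",
                        "This is another happy sentence about football"), label=[1, 1]),
    InputExample(texts=("This is happy sentence about football",
                        "This is another sad sentence about football"), label=[0, 1]),
    InputExample(texts=("This is happy sentence about football",
                        "This is sad sentence not about football"), label=[0, 0]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
model.fit(
    train_objectives=[(train_dataloader, MultiLabelContrastiveLoss(model))],
    epochs=1,
    warmup_steps=0,
)

With this formulation, the agreement and disagreement signals for every class contribute to a single backward pass per batch, instead of alternating between two optimizers.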
