[Feature] Add a constrain_crf option to the SequenceLabelingModel #21

Open
1 task done
eduagarcia opened this issue May 17, 2023 · 0 comments

eduagarcia commented May 17, 2023

Is your feature request related to a problem?

There is already a CRF module in the repository called CRFwithConstraints that implements constraints for BIO and BIOES decoding, but it is only used in the definition of the TwoStageNERModel. I would like to use the constrained CRF in a simple BERT-CRF model with the SequenceLabelingModel class.
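For context, here is a minimal sketch of what "constraints for BIO decoding" usually means: some tag-to-tag transitions are declared invalid, so the decoder can never emit them. This is only an illustration of the general idea, not the actual CRFwithConstraints code; the helper names `split_tag`, `is_allowed_bio`, and `allowed_transitions` are hypothetical.

```python
from typing import List, Tuple


def split_tag(label: str) -> Tuple[str, str]:
    """Split 'B-PER' into ('B', 'PER'); the outside tag 'O' becomes ('O', '')."""
    if label == "O":
        return "O", ""
    prefix, _, entity = label.partition("-")
    return prefix, entity


def is_allowed_bio(prev: str, curr: str) -> bool:
    """Return True if `curr` may directly follow `prev` under the BIO scheme."""
    curr_prefix, curr_entity = split_tag(curr)
    # 'O' and 'B-X' can follow any tag.
    if curr_prefix in ("O", "B"):
        return True
    # 'I-X' must continue a span of the same entity type (after 'B-X' or 'I-X').
    prev_prefix, prev_entity = split_tag(prev)
    return prev_prefix in ("B", "I") and prev_entity == curr_entity


def allowed_transitions(labels: List[str]) -> List[Tuple[int, int]]:
    """List all (from_id, to_id) pairs that are valid, using list positions as ids."""
    return [
        (i, j)
        for i, prev in enumerate(labels)
        for j, curr in enumerate(labels)
        if is_allowed_bio(prev, curr)
    ]


labels = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]
allowed = allowed_transitions(labels)
```

With these rules, `O -> I-PER` and `B-LOC -> I-PER` are rejected, while `B-PER -> I-PER` is accepted. A plain CRF has to learn this from data and can still decode invalid sequences.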

Describe the solution you'd like.

I think the solution is simple: add a new argument called constrain_crf that activates the CRFwithConstraints module:

adaseq/models/sequence_labeling_model.py

@@ -10,7 +10,7 @@ from modelscope.utils.config import ConfigDict
 from adaseq.data.constant import PAD_LABEL_ID
 from adaseq.metainfo import Models, Pipelines, Tasks
 from adaseq.models.base import Model
-from adaseq.modules.decoders import CRF, PartialCRF
+from adaseq.modules.decoders import CRF, PartialCRF, CRFwithConstraints
 from adaseq.modules.dropouts import WordDropout
 from adaseq.modules.embedders import Embedder
 from adaseq.modules.encoders import Encoder
@@ -52,6 +52,7 @@ class SequenceLabelingModel(Model):
         mv_interpolation: Optional[float] = 0.5,
         partial: Optional[bool] = False,
         chunk: Optional[bool] = False,
+        constrain_crf: Optional[bool] = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
@@ -84,8 +85,14 @@ class SequenceLabelingModel(Model):
                 self.dropout = nn.Dropout(dropout)

         self.use_crf = use_crf
+        self.constrain_crf = constrain_crf
         if use_crf:
-            if partial:
+            if constrain_crf:
+                id2label_list = [v for k, v in self.id_to_label.items()]
+                self.crf = CRFwithConstraints(
+                    id2label_list, batch_first=True, add_constraint=True
+                )
+            elif partial:
                 self.crf = PartialCRF(self.num_labels, batch_first=True)
             else:
                 self.crf = CRF(self.num_labels, batch_first=True)
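One detail in the diff above: the label list is built from the iteration order of `self.id_to_label`, so it only lines up with the label ids if the dict happens to be ordered by id. A slightly more defensive sketch, assuming the keys are integer ids running from 0 to num_labels - 1 (I have not checked this against the code base, and `labels_in_id_order` is a hypothetical helper):

```python
from typing import Dict, List


def labels_in_id_order(id_to_label: Dict[int, str]) -> List[str]:
    """Return labels so that position i holds the label whose id is i."""
    ids = sorted(id_to_label)
    # Assumption: ids are contiguous and start at 0; fail loudly otherwise.
    if ids != list(range(len(ids))):
        raise ValueError("label ids must be contiguous and start at 0")
    return [id_to_label[i] for i in ids]


ordered = labels_in_id_order({2: "I-PER", 0: "O", 1: "B-PER"})
```

The result could then be passed to the CRF in place of `id2label_list`.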

To enable CRFwithConstraints from config.yaml, the model section would look something like:

model:
  type: sequence-labeling-model
  embedder:
    model_name_or_path: sijunhe/nezha-cn-base
  word_dropout: 0.0
  use_crf: true
  constrain_crf: true
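To make the intended precedence explicit, this small sketch mirrors the if/elif chain from the diff above: constrain_crf wins over partial, and neither has any effect unless use_crf is true. The function `decoder_name` is hypothetical and only returns the class name as a string; returning None for the non-CRF case is my shorthand, since that branch is not shown in the diff.

```python
from typing import Optional


def decoder_name(
    use_crf: bool, constrain_crf: bool = False, partial: bool = False
) -> Optional[str]:
    """Return the name of the CRF class the proposed branch order would select."""
    if not use_crf:
        return None  # no CRF decoder is built in this case
    if constrain_crf:
        return "CRFwithConstraints"
    if partial:
        return "PartialCRF"
    return "CRF"


# Flags from the example config above.
chosen = decoder_name(use_crf=True, constrain_crf=True)
```

Setting both constrain_crf and partial would silently pick the constrained CRF, so it may be worth warning about or rejecting that combination.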

Describe alternatives you've considered.

No response

Additional context.

No response

Code of Conduct

  • I agree to follow this project's Code of Conduct