Text-to-small molecule deterministic vs. nondeterministic behavior #7

Open

amelie-iska opened this issue Mar 26, 2024 · 9 comments

Labels: good first issue (Good for newcomers)

@amelie-iska commented Mar 26, 2024

I was wondering whether there is a way to make the model less deterministic. I've tried adjusting the number of beams in beam search, adding functionality to adjust the model's temperature, and various other things in an attempt to make generation non-deterministic, but none of it seems to work. As an example, I have tried the following:

from transformers import T5Tokenizer, T5ForConditionalGeneration
import selfies as sf

tokenizer = T5Tokenizer.from_pretrained("QizhiPei/biot5-base-text2mol", model_max_length=540)
model = T5ForConditionalGeneration.from_pretrained('QizhiPei/biot5-base-text2mol')

task_definition = 'Definition: You are given a molecule description in English. Your job is to generate the molecule SELFIES that fits the description.\n\n'
text_input = "The molecule is a monocarboxylic acid anion obtained by deprotonation of the carboxy and sulfino groups of 3-sulfinopropionic acid. Major microspecies at pH 7.3 It is an organosulfinate oxoanion and a monocarboxylic acid anion. It is a conjugate base of a 3-sulfinopropionic acid."
task_input = f'Now complete the following example -\nInput: {text_input}\nOutput: '

model_input = task_definition + task_input
input_ids = tokenizer(model_input, return_tensors="pt").input_ids

generation_config = model.generation_config
generation_config.max_length = 512
generation_config.num_beams = 5
generation_config.do_sample = True
generation_config.top_k = 50
generation_config.top_p = 0.75
generation_config.temperature = 0.95

# Set the number of sequences to generate, must be <= num_beams
num_sequences = 3
generation_config.num_return_sequences = num_sequences

outputs = model.generate(input_ids, generation_config=generation_config)

print(f"Top {num_sequences} generated sequences:")
for i, output in enumerate(outputs, start=1):
    output_selfies = tokenizer.decode(output, skip_special_tokens=True).replace(' ', '')
    output_smiles = sf.decoder(output_selfies)
    print(f"Sequence {i}:")
    print("Generated SELFIES:", output_selfies)
    print("Generated SMILES:", output_smiles)
    print()

It seems that no matter how I adjust this (turning the number of beams down to one or up to higher values, for example), the output does not change. Even changing the prompt slightly does not seem to change the output. Perhaps the model has learned too rigid a representation of the text-to-molecule mapping? Or perhaps I am overlooking something?

QizhiPei added the "good first issue" label on Mar 26, 2024
@QizhiPei (Owner)

I tried some other test cases with an increased num_beams and observed different outputs. Adjusting diversity_penalty can also produce different outputs. Could you try more cases with the settings above?

Feel free to reach out with any further questions or for discussion.

@amelie-iska (Author)

Would you be willing to provide your settings?

@QizhiPei (Owner) commented Mar 26, 2024

text_input = "The molecule is an amino disaccharide consisting of beta-D-galactopyranose and 2-acetamido-2-deoxy-D-galactopyranose resicues joined in sequence by a (1->6) glycosidic bond. It is an amino disaccharide and a member of acetamides. It derives from a beta-D-galactose and a N-acetyl-D-galactosamine."

...

generation_config.max_length = 512
generation_config.num_beams = 5  # or 25

num_sequences = 5  # or 25
generation_config.num_return_sequences = num_sequences

@QizhiPei (Owner)

generation_config = model.generation_config
generation_config.max_length = 512
generation_config.num_beams = 5

generation_config.do_sample = False
generation_config.num_beam_groups = 5
generation_config.diversity_penalty = 10.0

num_sequences = 5
generation_config.num_return_sequences = num_sequences

@amelie-iska (Author) commented Mar 26, 2024

Thanks for the response! I'm still getting deterministic output. Any ideas why?

Also, your first example above returns a warning:

/opt/home/usr/miniconda3/envs/biot5/lib/python3.8/site-packages/transformers/generation/configuration_utils.py:367:
UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.95`
-- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.
UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.75`
-- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.

@QizhiPei (Owner) commented Mar 26, 2024

Sorry for the error. I have updated the config. You can try the two settings above.
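
For reference, here is a minimal sketch of what the two cleaned-up settings could look like. It builds fresh GenerationConfig objects so leftover sampling flags such as temperature and top_p do not carry over; the exact values are assumptions based on the snippets above, not the edited config itself.

from transformers import GenerationConfig

# Setting 1 (sketch): plain beam search returning several candidates.
beam_config = GenerationConfig(
    max_length=512,
    num_beams=25,
    num_return_sequences=5,
)

# Setting 2 (sketch): diverse (group) beam search; do_sample must remain False,
# and num_beams must be divisible by num_beam_groups.
diverse_config = GenerationConfig(
    max_length=512,
    num_beams=5,
    num_beam_groups=5,
    diversity_penalty=10.0,
    num_return_sequences=5,
)

outputs = model.generate(input_ids, generation_config=diverse_config)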

amelie-iska changed the title from "Text-to-protein deterministic vs. nondeterministic behavior" to "Text-to-small molecule deterministic vs. nondeterministic behavior" on Mar 27, 2024
@amelie-iska (Author)

I am still getting deterministic behavior from the model. I've tried your recommendations, as well as top-k sampling, top-p sampling, temperature sampling, and a couple of others, and the model seems to generate the same molecule each time regardless of what I do. It would be very helpful to have the GitHub repo updated with some example scripts for nondeterministic sampling methods (as well as some variations on beam search).

@QizhiPei (Owner) commented Mar 29, 2024

Thanks for your suggestions. I will add some example scripts for nondeterministic generation settings.

However, it doesn't make sense that we get different outputs with the same generation settings.
Could you kindly provide your outputs for the above two settings?

For the first setting, I get the following outputs:

Top 5 generated sequences:
Sequence 1:
Generated SELFIES: [C][C][=Branch1][C][=O][N][C@H1][C][Branch1][C][O][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@H1][Branch1][C][O][C@@H1][Ring2][Ring1][Branch1][O]
Generated SMILES: CC(=O)N[C@H1]1C(O)O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@H1](O)[C@@H1]1O

Sequence 2:
Generated SELFIES: [C][C][=Branch1][C][=O][N][C@H1][C][Branch1][C][O][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@@H1][Branch1][C][O][C@@H1][Ring2][Ring1][Branch1][O]
Generated SMILES: CC(=O)N[C@H1]1C(O)O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@@H1](O)[C@@H1]1O

Sequence 3:
Generated SELFIES: [C][C][=Branch1][C][=O][N][C@H1][C][Branch1][C][O][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@@H1][Branch1][C][O][C@H1][Ring2][Ring1][Branch1][O]
Generated SMILES: CC(=O)N[C@H1]1C(O)O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@@H1](O)[C@H1]1O

Sequence 4:
Generated SELFIES: [C][C][=Branch1][C][=O][N][C@H1][C][Branch1][C][O][O][C@H1][Branch3][Ring1][C][#Branch2][C][O][C@@H1][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring2][Ring1][Branch1][N][C][Branch1][C][C][=O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring2][Ring1][S][O]
Generated SMILES: CC(=O)N[C@H1]C(O)O[C@H1]CO[C@@H1]1O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@H1](O)[C@H1](O)[C@H1]1NC(C)=O

Sequence 5:
Generated SELFIES: [C][C][=Branch1][C][=O][N][C@H1][C][Branch1][C][O][O][C@H1][Branch3][Ring1][C][#Branch2][C][O][C@@H1][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring2][Ring1][Branch1][N][C][Branch1][C][C][=O][C@@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring2][Ring1][S][O]
Generated SMILES: CC(=O)N[C@H1]C(O)O[C@H1]CO[C@@H1]1O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@H1](O)[C@H1](O)[C@H1]1NC(C)=O

For the second:

Top 5 generated sequences:
Sequence 1:
Generated SELFIES: [C][C][=Branch1][C][=O][N][C@H1][C][Branch1][C][O][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@H1][Branch1][C][O][C@@H1][Ring2][Ring1][Branch1][O]
Generated SMILES: CC(=O)N[C@H1]1C(O)O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@H1](O)[C@@H1]1O

Sequence 2:
Generated SELFIES: [C][=Branch1][C][=O][N][C@H1][C][Branch1][C][O][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@H1][Branch1][C][O][C@@H1][Ring2][Ring1][Branch1][O]
Generated SMILES: C(=O)N[C@H1]1C(O)O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@H1](O)[C@@H1]1O

Sequence 3:
Generated SELFIES: [C][C][=Branch1][C][=O][N][C@@H1][C@@H1][Branch2][Ring1][Branch1][O][C][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@@H1][Branch1][C][O][C@@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][O][C@H1][Ring2][Ring1][S][O]
Generated SMILES: CC(=O)N[C@@H1]1[C@@H1](OC2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@@H1](O)[C@@H1](CO[C@@H1]3O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]3O)O[C@H1]1O

Sequence 4:
Generated SELFIES: [=C][N][C@H1][C][Branch1][C][O][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@@H1][Branch1][C][O][C@@H1][Ring2][Ring1][Branch1][N][C][Branch1][C][C][=O]
Generated SMILES: CN[C@H1]1C(O)O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@@H1](O)[C@H1](O)[C@H1]2O)[C@@H1](O)[C@@H1]1NC(C)=O

Sequence 5:
Generated SELFIES: [C][C][=Branch1][C][=O][N][C@H1][C][Branch1][C][O][O][C@H1][Branch2][Ring1][=Branch1][C][O][C@@H1][O][C@H1][Branch1][Ring1][C][O][C@H1][Branch1][C][O][C@H1][Branch1][C][O][C@H1][Ring1][#Branch2][O][C@@H1][Branch1][C][O][C@H1][Ring2][Ring1][Branch1][O]
Generated SMILES: CC(=O)N[C@H1]1C(O)O[C@H1](CO[C@@H1]2O[C@H1](CO)[C@H1](O)[C@H1](O)[C@H1]2O)[C@@H1](O)[C@H1]1O

@amelie-iska (Author) commented Mar 29, 2024

I have several scripts: one each for top-k, top-p, and temperature sampling, along with a couple of variations on beam search. For temperature sampling, for example, I have the following:

from transformers import T5Tokenizer, T5ForConditionalGeneration
import selfies as sf

model_name = "QizhiPei/biot5-base-text2mol"
tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=512)  
model = T5ForConditionalGeneration.from_pretrained(model_name)

def generate_mol_from_text(text_input, tokenizer, model, **kwargs):
    # Tokenize input text
    task_definition = 'Definition: You are given a molecule description in English. Your job is to generate the molecule SELFIES that fits the description.\n\n'
    task_input = f'Now complete the following example -\nInput: {text_input}\nOutput: '
    model_input = task_definition + task_input
    input_ids = tokenizer(model_input, return_tensors="pt").input_ids

    # Generate molecule SELFIES 
    generation_config = model.generation_config
    generation_config.update(**kwargs)
    generation_config.max_length = 512
    generation_config.do_sample = True  # Enable sampling-based generation
    outputs = model.generate(input_ids, generation_config=generation_config)

    # Decode SELFIES and convert to SMILES
    selfies = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(' ', '')
    smiles = sf.decoder(selfies)

    return selfies, smiles

text_input = "A monocarboxylic acid anion obtained by deprotonation of the carboxy and sulfino groups of 3-sulfinopropionic acid. Major microspecies at pH 7.3. It is an organosulfinate oxoanion and a monocarboxylic acid anion. It is a conjugate base of a 3-sulfinopropionic acid."

selfies, smiles = generate_mol_from_text(text_input, tokenizer, model, temperature=0.7)

print("Generated SELFIES:", selfies)
print("Generated SMILES:", smiles)

For top-k sampling we can change the line

selfies, smiles = generate_mol_from_text(text_input, tokenizer, model, temperature=0.7)

to

selfies, smiles = generate_mol_from_text(text_input, tokenizer, model, do_sample=True, top_k=25)

For top-p sampling we can change it to:

selfies, smiles = generate_mol_from_text(text_input, tokenizer, model, do_sample=True, top_p=0.9)

For diverse beam search we can set do_sample=False and use something like:

selfies, smiles = generate_mol_from_text(text_input, tokenizer, model, num_beams=25, num_beam_groups=5, diversity_penalty=1.0)
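
(A hedged aside, not part of the original scripts: because generate_mol_from_text above re-enables do_sample=True after applying the kwargs, the do_sample=False passed here gets overridden. Since group beam search in transformers requires do_sample=False and num_beams divisible by num_beam_groups, it may be easier to test diverse beam search by calling model.generate directly, e.g.:)

# Sketch: direct call that bypasses the helper's do_sample=True override
# (input_ids tokenized the same way as inside the helper).
outputs = model.generate(
    input_ids,
    max_length=512,
    do_sample=False,        # group beam search requires sampling to be off
    num_beams=25,
    num_beam_groups=5,      # num_beams must be divisible by num_beam_groups
    diversity_penalty=1.0,
    num_return_sequences=5,
)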

So currently I have five scripts for text-to-small molecule (t2m):

  • t2m_top_k.py
  • t2m_top_p.py
  • t2m_temp.py
  • t2m_diverse_beam.py
  • t2m_beam.py

Vanilla beam search is deterministic and should yield the same results each time, assuming the settings are the same. The other sampling methods should be nondeterministic and should yield a new molecule each time we run a forward pass, right? I've also tried setting random seeds and ensuring CUDA is not behaving deterministically, as in the following script:

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import selfies as sf
from rdkit import Chem
import random
import numpy as np

# Function to check molecule validity
def is_valid_molecule(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        return mol is not None
    except:
        return False

# Set a new random seed for each generation to ensure diversity
def set_random_seed(seed_value=None):
    if seed_value is None:
        seed_value = random.randint(0, 10000)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)
    torch.backends.cudnn.deterministic = False

# Initialize tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("QizhiPei/biot5-base-text2mol", model_max_length=540)
model = T5ForConditionalGeneration.from_pretrained('QizhiPei/biot5-base-text2mol')

# Input description
task_definition = 'Definition: You are given a molecule description in English. Your job is to generate the molecule SELFIES that fits the description.\n\n'
text_input = "The molecule is a substituted phenethylamine derivative bearing structural resemblance to both 3,4-methylenedioxymethamphetamine and 4-bromo-2,5-dimethoxyphenethylamine. The core scaffold consists of a phenethylamine backbone with an α-methyl group on the ethylamine side chain, conferring structural similarity to MDMA. A 3,4-methylenedioxy ring is fused to the phenyl ring, a key feature of MDMA thought to contribute to its entactogenic effects. The phenyl ring is further substituted at the 2 and 5 positions with methoxy groups, a substitution pattern characteristic of the 2C-X family of psychedelics like 2C-B. To optimize for a shorter 2-hour duration of action while retaining anxiolytic efficacy, the 4-position bromine of 2C-B is eschewed in favor of a less lipophilic and more polar moiety such as an alcohol, ketone, or alkyl ether. This decrease in lipophilicity is predicted to hasten metabolism and clearance. Furthermore, the amine is modified with a bulky N-alkyl group to attenuate activity at the 5-HT2A receptor, minimizing psychedelic effects in favor of anxiolytic activity putatively mediated by 5-HT1A agonism. Promising N-substituents may include tert-butyl, neopentyl, or adamantyl groups. In summary, the hybrid molecule is N-neopentyl-2,5-dimethoxy-3,4-methylenedioxyphenethylamine, likely exhibiting a 2-hour duration and anxiolytic efficacy mediated through MDMA-like entactogenic effects and 5-HT1A agonism. Preparation of the molecule could be achieved by reductive amination of the corresponding ketone precursor with neopentylamine."
task_input = f'Now complete the following example -\nInput: {text_input}\nOutput: '

model_input = task_definition + task_input
input_ids = tokenizer(model_input, return_tensors="pt").input_ids

# Ensure diverse generation by setting a new random seed
set_random_seed()

# Generation configuration for sampling
generation_params = {
    "max_length": 512,
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "temperature": 1.2,
    "num_return_sequences": 3
}

# Generate output
outputs = model.generate(input_ids, **generation_params)

print("Generated sequences:")
for i, output in enumerate(outputs):
    output_selfies = tokenizer.decode(output, skip_special_tokens=True).replace(' ', '')
    try:
        output_smiles = sf.decoder(output_selfies)
        if is_valid_molecule(output_smiles):
            print(f"Sequence {i + 1}:")
            print("Generated SELFIES:", output_selfies)
            print("Generated SMILES:", output_smiles)
            print()
    except sf.exceptions.DecoderError:
        continue
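
(A leaner variant of the same check, sketched here rather than taken from one of the scripts above: skip generation_config entirely and pass sampling arguments per call, so nothing lingering on model.generation_config can interfere, then count distinct decodings across repeated runs.)

# Sketch: pure multinomial sampling (num_beams=1, do_sample=True) with arguments
# passed directly to generate(); repeated runs would normally yield varied outputs.
distinct = set()
for _ in range(5):
    out = model.generate(
        input_ids,
        max_length=512,
        do_sample=True,
        num_beams=1,
        temperature=1.0,
    )
    distinct.add(tokenizer.decode(out[0], skip_special_tokens=True))

print(f"{len(distinct)} distinct outputs across 5 sampled runs")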

If nondeterministic sampling is possible with this model, my concern is that the model has collapsed as described in the first three paragraphs of Section 2 of the I-JEPA paper.

It's possible this intuition does not apply to the BioT5+ architecture, but joint modeling usually requires contrastive loss (or some other trick) to prevent the energy landscape of the model from being too flat. For example, the I-JEPA paper mentions the following in addition to contrastive loss training:

"non-contrastive losses that minimize the informational redundancy across embeddings, and clustering-based approaches that maximize the entropy of the average embedding.

I'm still not convinced it's not an error on my part though, so any discussion or help you might provide is very much appreciated. How difficult would it be to retrain using a contrastive loss? This might help immensely.
