Getting Started


Molecular Foundation Models are demonstrating impressive performance, but current models use tokenizers that fail to represent all of chemistry, inherently limiting their performance. In particular, Atom-wise tokenizers emit a single token for any bracketed atom, triggering a combinatorial explosion of the vocabulary size. Capturing all variants of carbon atoms alone would require 75,600 tokens, or nearly a quarter of GPT-4o’s vocabulary (Wadell et al.).

The problem is that most atoms are bracketed. Any element outside the organic subset, as well as chiral centers, isotopes, and charged species, is encoded as a bracketed atom. Bracketed atoms encode the nuclear, electronic, and geometric features that are critical to numerous widely-used compounds.

Smirk fixes this by tokenizing SMILES encodings all the way down to their constituent elements, enabling complete coverage of OpenSMILES with a vocabulary of just 167 tokens.

Check out the paper for all the details; otherwise, let’s see it in action!

🐍 Installation is easy with pre-built binaries on PyPI and GitHub. Just run: pip install smirk

Installing from source? See installing from source for instructions.

!python -m pip install smirk transformers rdkit

First steps

🤗 smirk subclasses Hugging Face’s PreTrainedTokenizerBase for seamless compatibility and leverages Tokenizers for raw Rust-powered speed. No need to learn another framework; everything works out of the box 🎁

from smirk import SmirkTokenizerFast

# Just import and tokenize!
smirk = SmirkTokenizerFast()
smirk("CC(=O)Nc1ccc(O)cc1")
{'input_ids': [45, 45, 4, 22, 102, 5, 93, 153, 12, 153, 153, 153, 4, 102, 5, 153, 153, 12], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
# Batch Tokenization with Padding
batch = smirk([
    "C[C@@H]1CCCCCCCCCCCCC(=O)C1",
    "O=C(O)C[C@H](N)C(=O)N[C@H](C(=O)OC)Cc1ccccc1",
    "CN(C)S[N][Re@OH18]([C][O])([C][O])([C][O])([C][O])[C][O]"
], padding="longest")
batch
{'input_ids': [[45, 148, 45, 24, 71, 150, 12, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 4, 22, 102, 5, 45, 12, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162, 162], [102, 22, 45, 4, 102, 5, 45, 148, 45, 23, 71, 150, 4, 93, 5, 45, 4, 22, 102, 5, 93, 148, 45, 23, 71, 150, 4, 45, 4, 22, 102, 5, 102, 45, 5, 45, 153, 12, 153, 153, 153, 153, 153, 12, 162, 162, 162, 162, 162, 162, 162, 162, 162], [45, 93, 4, 45, 5, 122, 148, 93, 150, 148, 116, 26, 12, 19, 150, 4, 148, 45, 150, 148, 102, 150, 5, 4, 148, 45, 150, 148, 102, 150, 5, 4, 148, 45, 150, 148, 102, 150, 5, 4, 148, 45, 150, 148, 102, 150, 5, 148, 45, 150, 148, 102, 150]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
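In the batch above, shorter sequences are padded to the length of the longest member (the trailing 162s), and the attention mask zeroes those positions out. The idea can be sketched in plain Python (illustrative only, not smirk’s actual implementation; 162 is simply the pad id visible in the output above):

```python
def pad_longest(batch_ids, pad_id=162):
    """Pad each id sequence to the longest in the batch and build masks."""
    max_len = max(len(ids) for ids in batch_ids)
    padded, masks = [], []
    for ids in batch_ids:
        n_pad = max_len - len(ids)
        padded.append(ids + [pad_id] * n_pad)       # fill with the pad token id
        masks.append([1] * len(ids) + [0] * n_pad)  # 0 = ignore this position
    return {"input_ids": padded, "attention_mask": masks}
```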
# Back to molecules!
smirk.batch_decode(batch["input_ids"], skip_special_tokens=True)
['C[C@@H]1CCCCCCCCCCCCC(=O)C1',
 'O=C(O)C[C@H](N)C(=O)N[C@H](C(=O)OC)Cc1ccccc1',
 'CN(C)S[N][Re@OH18]([C][O])([C][O])([C][O])([C][O])[C][O]']
# By default, we don't add `[CLS]` and `[SEP]` tokens, but that's just a constructor argument
smirk_bert = SmirkTokenizerFast(template="[CLS] $0 [SEP]")
" ".join(smirk_bert.tokenize("CNCCC(c1ccccc1)Oc2ccc(cc2)C(F)(F)F", add_special_tokens=True))
'[CLS] C N C C C ( c 1 c c c c c 1 ) O c 2 c c c ( c c 2 ) C ( F ) ( F ) F [SEP]'
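The template string appears to follow the Tokenizers TemplateProcessing syntax: special tokens are emitted literally and `$0` stands for the tokenized sequence. A minimal sketch of that substitution (illustrative only, not smirk’s code):

```python
def apply_template(tokens, template="[CLS] $0 [SEP]"):
    """Expand a single-sequence template: "$0" is replaced by the sequence tokens."""
    out = []
    for piece in template.split():
        out.extend(tokens if piece == "$0" else [piece])
    return out
```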

What Makes Smirk Special?

By fully decomposing the input molecule, smirk ensures complete coverage of the OpenSMILES specification. Any valid OpenSMILES encoding can be tokenized by smirk without emitting unknown tokens. Moreover, for non-bracketed atoms, the smirk tokenization is the same as an Atomwise tokenizer used by current molecular foundation models such as MoLFormer.
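An Atomwise tokenizer is typically a single regular expression. The pattern below is a widely used atom-level SMILES regex (MoLFormer’s may differ in details): on bracket-free SMILES it yields the same atom-level tokens as smirk, but any bracketed atom matches as one opaque token:

```python
import re

# Widely used atom-level SMILES pattern; bracket atoms ("[...]") match whole
ATOMWISE = re.compile(
    r"(\[[^\]]+\]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\."
    r"|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)

print(ATOMWISE.findall("CNCCC(c1ccccc1)F"))            # one token per atom/symbol
print(ATOMWISE.findall("Cl[Pt@SP1](Cl)([NH3])[NH3]"))  # '[Pt@SP1]' is a single token
```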

from rdkit import Chem
from rdkit.Chem.Draw import MolsToGridImage, rdMolDraw2D
from IPython.display import SVG
from transformers import AutoTokenizer

# Tokenizers being evaluated; see the paper for a more comprehensive study (30 tokenizers!)
# Or try adding one of the other tokenizers evaluated in the paper
tokenizers = {
    "smirk": smirk,
    "molformer": AutoTokenizer.from_pretrained("ibm/MoLFormer-XL-both-10pct", trust_remote_code=True),
    "GPT-4o": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
}

smiles_list = [
    "Cl[Pt@SP1](Cl)([NH3])[NH3]", # Cisplatin
    "Cl[Pt@SP2](Cl)([NH3])[NH3]", # Transplatin
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # Caffeine
    "[O-][99Tc](=O)(=O)=O.[Na+]", # Sodium pertechnetate with radiotracer
    "[Ga+]$[As-]", # Gallium arsenide
    "[OH2]", # Water
]

def get_legend(smi:str, tokenizers:dict):
    """Helper function for creating legends"""
    entries = []
    for name, tok in tokenizers.items():
        entries.append(f"{name}: {' '.join(tok.tokenize(smi))}")
    return "\n".join(entries)

# Draw all molecules and tokenizations on a grid
drawOptions = rdMolDraw2D.MolDrawOptions()
drawOptions.fixedScale = 1
drawOptions.centreMoleculesBeforeDrawing = True
drawOptions.minFontSize = 6
drawOptions.legendFontSize = 24
drawOptions.legendFraction = 0.3
MolsToGridImage(
    [Chem.MolFromSmiles(smi) for smi in smiles_list],
    molsPerRow=2, subImgSize=(400,200),
    legends=[get_legend(smi, tokenizers) for smi in smiles_list],
    drawOptions=drawOptions,
)
(Figure: the six molecules drawn in a grid, with each tokenizer’s output shown in the legend.)

Smirk tokenized all molecules without a single unknown, whereas MoLFormer’s Atomwise tokenizer emitted the unknown token for both Cisplatin and Transplatin (first row). Specifically, the Atomwise tokenizer emitted unknown tokens for the following:

  • Platinum chiral centers: [Pt@SP1] and [Pt@SP2]

  • Ammonia & Water with explicit hydrogens: [NH3] and [OH2]

  • Gallium ion: [Ga+]

  • Quadruple bond: $

As data-driven methods, Atomwise tokenizers only know about the atoms seen during their training, fundamentally limiting their ability to generalize.
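That limitation is easy to reproduce: build an Atomwise vocabulary from some “training” SMILES, then tokenize a molecule containing bracket atoms that were never seen. A toy sketch (real tokenizers build their vocabulary the same way, just from far more data):

```python
import re

# Atom-level SMILES pattern; bracket atoms ("[...]") match as one token
ATOM = re.compile(
    r"(\[[^\]]+\]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\."
    r"|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)

# Vocabulary contains only the tokens seen in the "training" data
train = ["CC(=O)Nc1ccc(O)cc1", "ClCCCl"]
vocab = {tok for smi in train for tok in ATOM.findall(smi)}

# Cisplatin: '[Pt@SP1]' and '[NH3]' were never seen, so they map to [UNK]
tokens = [t if t in vocab else "[UNK]"
          for t in ATOM.findall("Cl[Pt@SP1](Cl)([NH3])[NH3]")]
```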

Zero to Molecular Foundation Model with Smirk!

Let’s train a small RoBERTa model on molecules from QM9 using Hugging Face and smirk.

!python -m pip install accelerate datasets torch

Dataset Preprocessing

from datasets import load_dataset

# MoleculeNet's QM9 dataset. Normally this would be a larger (and unlabeled)
# dataset, but for a demo it's perfect
dataset = load_dataset("csv", 
    data_files=["https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm9.csv"],
)["train"].select_columns("smiles").train_test_split(test_size=0.2)

# Tokenize the splits! For a larger dataset, this would be done on-the-fly
dataset = dataset.map(smirk, input_columns=["smiles"], desc="Tokenizing")

💡 huggingface/tokenizers may warn about the process being forked after the tokenizer has already been used (this isn’t a smirk issue). It’s harmless, but when actually training, it’s best to defer tokenization until after the fork to benefit from the Rust-level parallelism.
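To silence the warning, huggingface/tokenizers honors the TOKENIZERS_PARALLELISM environment variable; set it before any worker processes fork:

```python
import os

# Disable tokenizers' internal Rust parallelism so the fork warning goes away
os.environ["TOKENIZERS_PARALLELISM"] = "false"
```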

🎉 That’s it! We’ve tokenized all of QM9 using smirk!

dataset["train"].to_pandas().head()
smiles input_ids attention_mask
0 OC1C2CC(C=C2)C1=O [102, 45, 12, 45, 13, 45, 45, 4, 45, 22, 45, 1... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
1 N#CC#CC1CC2OC12 [93, 1, 45, 45, 1, 45, 45, 12, 45, 45, 13, 102... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
2 CCC1C(=O)CC11CN1 [45, 45, 45, 12, 45, 4, 22, 102, 5, 45, 45, 12... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
3 Cn1cc(oc1=O)O [45, 154, 12, 153, 153, 4, 155, 153, 12, 22, 1... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
4 CC1=CC(OC1)C1CO1 [45, 45, 12, 22, 45, 45, 4, 102, 45, 12, 5, 45... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

Training

Once we’ve tokenized the dataset, training the model is just a matter of configuration.

from accelerate import Accelerator
from transformers import Trainer, TrainingArguments, RobertaForMaskedLM, RobertaConfig, DataCollatorForLanguageModeling

# A very small model for demonstrating training a molecular foundation model with smirk 
config = RobertaConfig(
    vocab_size=len(smirk),
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=4,
    num_attention_heads=4,
)
model = RobertaForMaskedLM(config)

# Set up the trainer to use our dataset
trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=smirk,
    data_collator=DataCollatorForLanguageModeling(smirk), # The data collator needs to know about our tokenizer
)
trainer.train()
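For reference, DataCollatorForLanguageModeling implements the masked-language-modeling objective: it randomly selects a fraction of tokens (15% by default), and every unselected position gets the label -100 so the loss ignores it. A simplified sketch of the idea (the real collator also pads the batch, and sometimes keeps or randomizes the selected tokens instead of masking them):

```python
import random

def mlm_mask(input_ids, mask_id, mlm_prob=0.15, seed=0):
    """Toy masked-language-modeling collation for one sequence."""
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in input_ids:
        if rng.random() < mlm_prob:
            masked.append(mask_id)   # model must predict this token...
            labels.append(tok)       # ...and the label is the original id
        else:
            masked.append(tok)
            labels.append(-100)      # -100 = ignored by the cross-entropy loss
    return masked, labels
```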