DemoChat SFT

Chat Supervised Fine-Tuning

This tutorial demonstrates supervised fine-tuning (SFT) on chat-formatted data using MinT.

What You’ll Learn

  1. Load and process multi-turn conversations from HuggingFace
  2. Apply chat templates for model-specific formatting
  3. Implement loss masking to train only on assistant responses
  4. Run the SFT training loop
  5. Evaluate and sample from the fine-tuned model

Datasets

We support two chat datasets:

| Dataset | Source | Size | Train On | Use Case |
|---|---|---|---|---|
| no_robots | HuggingFaceH4/no_robots | 9.5K train + 500 test | All assistant messages | Quick experiments |
| tulu3 | allenai/tulu-3-sft-mixture | 939K samples | Last assistant message | Large-scale training |

SFT vs RL

| Method | Training Signal | When to Use |
|---|---|---|
| SFT | Expert demonstrations (input → output pairs) | When you have high-quality labeled data |
| RL | Reward signal (correct/incorrect) | When defining correctness programmatically is easier than providing examples |

SFT is simpler: no sampling required and no reward function to design. The model simply learns to predict the next token on the training data.
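Concretely, the SFT objective is the weighted negative log-likelihood of the target tokens, averaged over the tokens we train on. A toy sketch of that computation (illustrative numbers, not MinT code):

import numpy as np

# Log-probabilities the model assigns to the correct next tokens, and a 0/1
# weight marking which positions we train on (assistant tokens only).
logprobs = np.array([-2.3, -0.7, -0.1, -1.2])
weights = np.array([0.0, 1.0, 1.0, 1.0])  # first position belongs to a user turn

nll = -(logprobs * weights).sum() / weights.sum()
print(f"Weighted NLL: {nll:.4f}")  # ~0.6667, averaged over trainable tokens only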


Step 0: Setup

Install required packages:

pip install -q datasets transformers python-dotenv numpy matplotlib mint

Load your API key:

import os
from dotenv import load_dotenv
 
load_dotenv()
 
# MinT uses MINT_API_KEY for authentication
if os.environ.get('MINT_API_KEY'):
    print("API key loaded")
else:
    print("WARNING: MINT_API_KEY not found!")

Connect to MinT:

import mint
from mint import types
 
service_client = mint.ServiceClient()
print("Connected to MinT")

Step 1: Configuration

Configure your training run:

# ========== CONFIGURATION ==========
 
# Dataset: "no_robots" or "tulu3"
DATASET = "no_robots"
 
# Model
BASE_MODEL = "Qwen/Qwen3-0.6B"
LORA_RANK = 16
 
# Training
NUM_STEPS = 50 if DATASET == "no_robots" else 100
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
MAX_LENGTH = 2048
 
# Loss masking: which tokens to train on
# "all_assistant" = train on all assistant responses
# "last_assistant" = train only on the final assistant response
TRAIN_ON = "all_assistant" if DATASET == "no_robots" else "last_assistant"
 
print(f"Dataset: {DATASET}")
print(f"Model: {BASE_MODEL}")
print(f"Steps: {NUM_STEPS}, Batch: {BATCH_SIZE}, LR: {LEARNING_RATE}")
print(f"Max length: {MAX_LENGTH}, Train on: {TRAIN_ON}")

Parameter choices:

  • all_assistant: For smaller datasets like no_robots, train on every assistant turn to maximize data utilization
  • last_assistant: For large datasets like tulu3, training only on the final response is sufficient and faster; the sketch after this list shows the difference on a short conversation
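A minimal sketch of the two modes on a short conversation (plain Python; the real, token-level implementation comes in Step 2):

# Per-message loss weight under each masking mode. Tokens inherit the weight
# of the message they belong to.
conversation = [
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "4."},
    {"role": "user", "content": "And 3+3?"},
    {"role": "assistant", "content": "6."},
]

def message_weight(msg: dict, i: int, n: int, train_on: str) -> float:
    is_assistant = msg["role"] == "assistant"
    is_last = i == n - 1
    if train_on == "all_assistant":
        return 1.0 if is_assistant else 0.0
    return 1.0 if (is_assistant and is_last) else 0.0

for mode in ("all_assistant", "last_assistant"):
    ws = [message_weight(m, i, len(conversation), mode) for i, m in enumerate(conversation)]
    print(f"{mode}: {ws}")
# all_assistant:  [0.0, 1.0, 0.0, 1.0]
# last_assistant: [0.0, 0.0, 0.0, 1.0]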

Step 2: Tokenizer & Chat Template

Why Chat Templates Matter

Language models don’t understand “roles” natively. A chat template converts structured messages into a flat token sequence that the model can process:

[User message] → <|im_start|>user\nHello!<|im_end|>
[Assistant]    → <|im_start|>assistant\nHi there!<|im_end|>

Different models use different templates. We use HuggingFace’s apply_chat_template() to handle this automatically.

Load the Tokenizer

from transformers import AutoTokenizer
 
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
print(f"Vocab size: {tokenizer.vocab_size:,}")
print(f"EOS token: {tokenizer.eos_token!r} (id={tokenizer.eos_token_id})")

Tokenize with Loss Masking

The key insight: we want to compute loss only on assistant tokens. This means:

  • User messages: weight = 0 (don’t train on these)
  • Assistant messages: weight = 1 (train on these)

import numpy as np
 
def tokenize_conversation(
    messages: list[dict],
    tokenizer,
    max_length: int,
    train_on: str = "all_assistant",
) -> tuple[list[int], np.ndarray]:
    """
    Tokenize a conversation and compute loss weights.
 
    Returns:
        input_ids: Token IDs for the full conversation
        weights: Loss weights (1.0 for tokens to train on, 0.0 otherwise)
    """
    # Tokenize message by message to track boundaries
    all_tokens = []
    all_weights = []
 
    for i, msg in enumerate(messages):
        # Build partial conversation up to this message
        partial = messages[:i+1]
 
        # Apply chat template
        text = tokenizer.apply_chat_template(
            partial,
            tokenize=False,
            add_generation_prompt=False,
        )
        tokens = tokenizer.encode(text, add_special_tokens=False)
 
        # Find new tokens added by this message
        prev_len = len(all_tokens)
        new_tokens = tokens[prev_len:]
 
        # Determine weight for this message
        is_assistant = msg.get("role") == "assistant"
        is_last = (i == len(messages) - 1)
 
        if train_on == "all_assistant":
            weight = 1.0 if is_assistant else 0.0
        elif train_on == "last_assistant":
            weight = 1.0 if (is_assistant and is_last) else 0.0
        else:
            weight = 1.0  # train on all tokens
 
        all_tokens.extend(new_tokens)
        all_weights.extend([weight] * len(new_tokens))
 
    # Truncate to max_length
    if len(all_tokens) > max_length:
        all_tokens = all_tokens[:max_length]
        all_weights = all_weights[:max_length]
 
    return all_tokens, np.array(all_weights, dtype=np.float32)

Test the tokenization:

demo_messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there! How can I help you?"},
]
 
tokens, weights = tokenize_conversation(demo_messages, tokenizer, MAX_LENGTH, TRAIN_ON)
print(f"Tokens: {len(tokens)}")
print(f"Trainable tokens: {int(weights.sum())}")
print(f"\nDecoded: {tokenizer.decode(tokens)!r}")

Step 3: Create Training Datum

MinT expects data in Datum format. For next-token prediction:

  • model_input: tokens[:-1] (all tokens except the last)
  • target_tokens: tokens[1:] (all tokens except the first)
  • weights: weights[1:] (aligned with targets)

def conversation_to_datum(
    messages: list[dict],
    tokenizer,
    max_length: int,
    train_on: str,
) -> types.Datum:
    """
    Convert a conversation to a training Datum.
 
    The model predicts token[i+1] from token[0:i+1], so:
    - input = tokens[:-1]
    - target = tokens[1:]
    - weights = weights[1:] (shifted to align with targets)
    """
    tokens, weights = tokenize_conversation(messages, tokenizer, max_length, train_on)
 
    if len(tokens) < 2:
        raise ValueError("Conversation too short")
 
    # Next-token prediction format
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    target_weights = weights[1:]  # Shift weights to align with targets
 
    return types.Datum(
        model_input=types.ModelInput.from_ints(input_tokens),
        loss_fn_inputs={
            "target_tokens": list(target_tokens),
            "weights": target_weights.tolist(),
        },
    )

Test the conversion:

datum = conversation_to_datum(demo_messages, tokenizer, MAX_LENGTH, TRAIN_ON)
print(f"Input length: {datum.model_input.length}")
print(f"Target length: {len(datum.loss_fn_inputs['target_tokens'])}")
print(f"Trainable tokens: {sum(datum.loss_fn_inputs['weights']):.0f}")

Step 4: Load Dataset

We provide a simple dataset wrapper that handles batching and shuffling:

from datasets import load_dataset
from dataclasses import dataclass
import random
 
@dataclass
class ChatDataset:
    """Simple chat dataset with batching."""
    data: list[list[dict]]  # List of conversations (each is list of messages)
    index: int = 0
 
    def get_batch(self, batch_size: int) -> list[list[dict]]:
        """Get next batch of conversations."""
        batch = []
        for _ in range(batch_size):
            if self.index >= len(self.data):
                self.index = 0  # Wrap around
                random.shuffle(self.data)
            batch.append(self.data[self.index])
            self.index += 1
        return batch
 
    def __len__(self) -> int:
        return len(self.data)
 
 
def load_chat_dataset(dataset_name: str, seed: int = 42) -> tuple[ChatDataset, ChatDataset]:
    """Load train and test datasets."""
    random.seed(seed)
 
    if dataset_name == "no_robots":
        ds = load_dataset("HuggingFaceH4/no_robots")
        train_data = [row["messages"] for row in ds["train"]]
        test_data = [row["messages"] for row in ds["test"]]
 
    elif dataset_name == "tulu3":
        ds = load_dataset("allenai/tulu-3-sft-mixture", split="train")
        ds = ds.shuffle(seed=seed)
        all_data = [row["messages"] for row in ds]
        # Split: first 1024 for test, rest for train
        test_data = all_data[:1024]
        train_data = all_data[1024:]
 
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
 
    random.shuffle(train_data)
    return ChatDataset(train_data), ChatDataset(test_data)

Load and inspect:

train_dataset, test_dataset = load_chat_dataset(DATASET)
print(f"Train: {len(train_dataset)} conversations")
print(f"Test: {len(test_dataset)} conversations")
 
# Show sample
sample = train_dataset.get_batch(1)[0]
print(f"\nSample conversation ({len(sample)} messages):")
for msg in sample[:3]:  # Show first 3 messages
    content = (msg['content'][:80] + "...") if len(msg['content']) > 80 else msg['content']
    print(f"  [{msg['role']}]: {content}")

Step 5: Create Training Client

Create a LoRA training client. LoRA (Low-Rank Adaptation) allows efficient fine-tuning by training small adapter matrices instead of all model weights.

training_client = service_client.create_lora_training_client(
    base_model=BASE_MODEL,
    rank=LORA_RANK,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)
print(f"Training client created: {BASE_MODEL}")
print(f"LoRA rank: {LORA_RANK}")

LoRA Parameters:

  • rank: Size of the low-rank matrices. Higher = more capacity, slower training (see the parameter-count sketch after this list)
  • train_mlp: Train feed-forward (MLP) layers
  • train_attn: Train attention layers
  • train_unembed: Train output projection
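For a rough sense of why LoRA is cheap: each adapted weight matrix of shape d_out × d_in is replaced by two rank-r factors, so the trainable parameter count drops from d_out·d_in to r·(d_out + d_in). A back-of-the-envelope sketch with illustrative dimensions (not Qwen3-0.6B's actual layer sizes):

# Trainable parameters for a single weight matrix: full fine-tuning vs. LoRA.
d_out, d_in, r = 1024, 1024, 16   # illustrative layer dimensions and LoRA rank

full_params = d_out * d_in        # update the entire matrix
lora_params = r * (d_out + d_in)  # update only B (d_out x r) and A (r x d_in)

print(f"Full: {full_params:,}  LoRA: {lora_params:,}  "
      f"({full_params / lora_params:.0f}x fewer trainable params)")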

Step 6: Training Loop

The training loop follows a simple pattern:

for each step:
    1. Get batch of conversations
    2. Convert to Datums (tokenize + compute weights)
    3. forward_backward: compute loss and gradients
    4. optim_step: update model weights

metrics_history = []
 
print(f"Starting SFT training: {NUM_STEPS} steps")
print(f"Batch: {BATCH_SIZE}, LR: {LEARNING_RATE}")
print()
 
for step in range(NUM_STEPS):
    # Get batch of conversations
    batch = train_dataset.get_batch(BATCH_SIZE)
 
    # Convert to Datums
    datums = []
    for messages in batch:
        try:
            datum = conversation_to_datum(messages, tokenizer, MAX_LENGTH, TRAIN_ON)
            datums.append(datum)
        except Exception:
            continue  # Skip malformed conversations
 
    if not datums:
        continue
 
    # Forward-backward pass: compute loss and gradients
    fwdbwd_result = training_client.forward_backward(
        datums,
        loss_fn="cross_entropy",
    ).result()
 
    # Compute loss from logprobs
    total_loss = 0.0
    total_weight = 0.0
    for i, out in enumerate(fwdbwd_result.loss_fn_outputs):
        logprobs = out['logprobs']
        if hasattr(logprobs, 'tolist'):
            logprobs = logprobs.tolist()
        w = datums[i].loss_fn_inputs['weights']
        if hasattr(w, 'tolist'):
            w = w.tolist()
        for lp, wt in zip(logprobs, w):
            total_loss += -lp * wt
            total_weight += wt
 
    loss = total_loss / max(total_weight, 1)
 
    # Optimization step: update weights
    training_client.optim_step(
        types.AdamParams(learning_rate=LEARNING_RATE)
    ).result()
 
    metrics_history.append({"step": step, "loss": loss})
 
    if step % 10 == 0 or step == NUM_STEPS - 1:
        print(f"Step {step:3d}: loss={loss:.4f}")
 
print("\nTraining complete!")
print(f"Initial loss: {metrics_history[0]['loss']:.4f}")
print(f"Final loss: {metrics_history[-1]['loss']:.4f}")

Understanding the loss:

  • Loss is the negative log-likelihood (NLL) of correct tokens
  • Lower loss = model is more confident in predicting the right tokens
  • Expect loss to decrease over training; the perplexity conversion below gives the values an intuitive scale
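Perplexity, exp(loss), is an intuitive way to read these numbers: it is roughly how many tokens the model is "choosing between" at each position. For example:

import math

# Convert NLL to perplexity: lower NLL means the model hesitates between fewer tokens.
for nll in (3.0, 2.0, 1.0, 0.5):
    print(f"NLL {nll:.1f} -> perplexity {math.exp(nll):.1f}")
# NLL 3.0 -> perplexity 20.1
# NLL 2.0 -> perplexity 7.4
# NLL 1.0 -> perplexity 2.7
# NLL 0.5 -> perplexity 1.6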

Step 7: Evaluate

Compute NLL on the test set to check for overfitting:

# Evaluate on test set
test_batch = test_dataset.get_batch(min(32, len(test_dataset)))
 
test_datums = []
for messages in test_batch:
    try:
        datum = conversation_to_datum(messages, tokenizer, MAX_LENGTH, TRAIN_ON)
        test_datums.append(datum)
    except Exception:
        continue
 
if test_datums:
    # Forward pass only (no gradients)
    forward_result = training_client.forward(
        test_datums,
        loss_fn="cross_entropy",
    ).result()
 
    # Compute loss from logprobs
    total_loss = 0.0
    total_weight = 0.0
    for i, out in enumerate(forward_result.loss_fn_outputs):
        logprobs = out['logprobs']
        if hasattr(logprobs, 'tolist'):
            logprobs = logprobs.tolist()
        w = test_datums[i].loss_fn_inputs['weights']
        if hasattr(w, 'tolist'):
            w = w.tolist()
        for lp, wt in zip(logprobs, w):
            total_loss += -lp * wt
            total_weight += wt
 
    test_loss = total_loss / max(total_weight, 1)
    print(f"Test NLL: {test_loss:.4f}")
else:
    print("No valid test samples")

Interpreting results:

  • Test loss close to train loss = good generalization
  • Test loss much higher than train loss = overfitting (reduce training steps or increase data)

Step 8: Visualize

Plot the training curve:

import matplotlib.pyplot as plt
 
steps = [m['step'] for m in metrics_history]
losses = [m['loss'] for m in metrics_history]
 
plt.figure(figsize=(10, 5))
plt.plot(steps, losses, 'b-', linewidth=2)
plt.xlabel('Step')
plt.ylabel('Loss (NLL)')
plt.title(f'{DATASET.upper()} SFT Training')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f'{DATASET}_training.png', dpi=150)
plt.show()

Step 9: Generate Sample

Test the fine-tuned model by generating a response:

# Save weights and get sampling client
sampling_client = training_client.save_weights_and_get_sampling_client(
    name=f"{DATASET}-sft-demo"
)
 
# Create a test prompt
test_messages = [
    {"role": "user", "content": "Write a haiku about programming."}
]
 
# Apply chat template for generation
prompt_text = tokenizer.apply_chat_template(
    test_messages,
    tokenize=False,
    add_generation_prompt=True,
)
prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False)
 
# Sample from the model
sample_result = sampling_client.sample(
    prompt=types.ModelInput.from_ints(prompt_tokens),
    num_samples=1,
    sampling_params=types.SamplingParams(
        max_tokens=128,
        temperature=0.7,
        stop_token_ids=[tokenizer.eos_token_id],
    ),
).result()
 
response = tokenizer.decode(sample_result.sequences[0].tokens)
print("User: Write a haiku about programming.")
print(f"Assistant: {response}")

Sampling parameters:

  • temperature=0.7: Controls randomness. Lower = more deterministic, higher = more creative (compared in the sketch below)
  • stop_token_ids: Stop generating when EOS token is produced
  • max_tokens: Maximum tokens to generate
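To see the effect of temperature directly, you can rerun the same sampling call at different values; a minimal sketch reusing the sampling_client and prompt_tokens from above:

# Compare completions for the same prompt at a low and a high temperature.
for temp in (0.2, 1.0):
    result = sampling_client.sample(
        prompt=types.ModelInput.from_ints(prompt_tokens),
        num_samples=1,
        sampling_params=types.SamplingParams(
            max_tokens=64,
            temperature=temp,
            stop_token_ids=[tokenizer.eos_token_id],
        ),
    ).result()
    print(f"--- temperature={temp} ---")
    print(tokenizer.decode(result.sequences[0].tokens))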

Step 10: Save Checkpoint

Save the final checkpoint for later use or continued training:

checkpoint = training_client.save_state(name=f"{DATASET}-sft-final").result()
print(f"Checkpoint saved: {checkpoint.path}")

To resume training later:

resumed_client = service_client.create_training_client_from_state_with_optimizer(
    checkpoint.path
)
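A minimal sketch of picking up training from that checkpoint, assuming the resumed client exposes the same forward_backward / optim_step interface as the original training client:

# Continue training from the restored weights and optimizer state.
for _ in range(10):
    batch = train_dataset.get_batch(BATCH_SIZE)
    datums = [
        conversation_to_datum(messages, tokenizer, MAX_LENGTH, TRAIN_ON)
        for messages in batch
    ]
    resumed_client.forward_backward(datums, loss_fn="cross_entropy").result()
    resumed_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()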

Summary

| Component | Implementation |
|---|---|
| Dataset | no_robots (9.5K) or tulu3 (939K) from HuggingFace |
| Tokenization | HuggingFace tokenizer with chat template |
| Loss masking | Train on assistant messages only |
| Training | cross_entropy loss with LoRA |
| Checkpointing | save_state() for weights + optimizer |

Key API Methods

# Setup
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(base_model=...)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
 
# Training
training_client.forward_backward(datums, loss_fn="cross_entropy")
training_client.optim_step(types.AdamParams(learning_rate=...))
 
# Evaluation
training_client.forward(datums, loss_fn="cross_entropy")  # No gradients
 
# Inference
sampling_client = training_client.save_weights_and_get_sampling_client(name=...)
sampling_client.sample(prompt, num_samples, sampling_params)
 
# Checkpointing
checkpoint = training_client.save_state(name=...)