[Research Preview] Thinking Preference Optimization

  • authors: Wang Yang, Song Jiang, Hongye Jin, Xiaotian Han

TL;DR

  • We introduce a novel approach to enhance model reasoning by using long/short CoT as preferred/rejected examples in DPO.
  • Our method demonstrates that carefully curated long/short thinking data pairs can significantly boost mathematical reasoning capabilities.

This approach offers a practical and efficient way to improve a model’s ability to perform complex reasoning. Combining curated fast-thinking and slow-thinking data with DPO training yields performance gains even with a modest dataset size.

Overview

Developing models with robust reasoning capabilities remains a key challenge in AI. While supervised fine-tuning (SFT) on extensive slow-thinking datasets is effective, we discovered that a targeted approach using a smaller dataset (5,000 examples) contrasting fast and slow thinking can yield substantial improvements.

This blog presents our methodology, including dataset curation and training details, along with empirical results demonstrating the effectiveness of our approach.

Method

Data Curation

Our approach utilizes two distinct dataset types:

  1. SFT Data:
    • OpenO1-SFT:
      • Based on the O1-OPEN/OpenO1-SFT dataset, we remove special tokens (e.g., <Thought>) and use GPT-4o-mini to reformat final answers into the \boxed{} format. We also filter out non-English entries and sequences longer than 8,192 tokens. The final dataset contains approximately 80,000 examples.
    • NuminaMath-CoT:
      • We randomly select examples from AI-MO/NuminaMath-CoT to match the size of the OpenO1-SFT dataset; this set is used for comparison.
    • Sky-T1_data_17k:
      • Based on NovaSky-AI/Sky-T1_data_17k, we remove special characters and reformat the final answer into QwQ’s output style.
  2. DPO Data:
    • Rejected Answers: We randomly select 5,000 examples from AI-MO/NuminaMath-CoT and treat the dataset’s original answers as the “rejected” entries.
    • Chosen Answers: We generate a new answer for each problem using the Qwen/QwQ-32B-Preview model.
    • Filtering: Only entries where both the “chosen” and “rejected” answers were correct were included, ensuring high-quality data.
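
The pairing and filtering steps above can be sketched roughly as follows. This is a minimal illustration, not our actual pipeline: the record fields (`problem`, `answer`, `short_cot`, `long_cot`), the helper names, and the exact-match correctness check on the last \boxed{} answer are all assumptions made for the example.

```python
def extract_boxed(text):
    """Return the content of the last \\boxed{...} in text, or None if absent."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    i = start + len(marker)
    depth = 1          # we are inside the opening brace of \boxed{
    out = []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(out).strip()
        out.append(ch)
        i += 1
    return None        # unbalanced braces: treat as no answer


def build_dpo_pairs(records):
    """Keep only records where both the short and the long answer are correct.

    Assumed (hypothetical) record fields:
      problem   - the question text
      answer    - the reference final answer
      short_cot - the dataset's original short solution (rejected)
      long_cot  - the newly generated long solution (chosen)
    """
    pairs = []
    for rec in records:
        short_ok = extract_boxed(rec["short_cot"]) == rec["answer"]
        long_ok = extract_boxed(rec["long_cot"]) == rec["answer"]
        if short_ok and long_ok:
            pairs.append({
                "prompt": rec["problem"],
                "chosen": rec["long_cot"],     # slow thinking is preferred
                "rejected": rec["short_cot"],  # fast thinking is rejected
            })
    return pairs
```

Because both answers in a kept pair are correct, the preference signal comes from the style and length of the reasoning rather than from the final answer. A real pipeline would likely need a more tolerant answer comparison than exact string equality.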

Training

  1. SFT Training:
    • Each SFT dataset is used to train a LLaMA-3.1-8B model for one epoch with a learning rate of 3e-5 and a batch size of 32.
  2. DPO Training:
    • Using the 5,000 curated examples, we perform DPO training on the SFT-trained models.
    • As a baseline, we also run DPO training directly on the LLaMA-3.1-8B-Instruct model.
    • All DPO runs use a beta of 0.01 and a batch size of 32.
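
For readers unfamiliar with the role of beta, here is a small sketch of the standard per-example DPO loss. It is written in plain Python for illustration only and is not our training code; the function name and the argument names are assumptions, and each argument is taken to be the summed token log-probability of a full response under the policy or the frozen reference model.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.01):
    """Standard DPO loss for one preference pair.

    loss = -log sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r)))
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(x)) == log(1 + exp(-x)); branch to avoid overflow.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# When the policy equals the reference, the margin is 0 and the loss is log(2).
print(dpo_loss(-10.0, -10.0, -10.0, -10.0))
```

With a small beta such as 0.01, a given log-probability gap between the chosen and rejected responses moves the loss only slightly, so the policy is pushed less aggressively per step than it would be with a larger beta.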

Results

  • We report test-set results for models trained with SFT on each dataset and then further trained with DPO.

Performance on AIME, MATH500, and MATH Overall

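| SFT Dataset/Model Family | Variant | AIME | MATH500 | MATH (Overall) |
| --- | --- | --- | --- | --- |
| LLaMA 8B-Instruct | base model | 0.02 | 0.288 | 0.33 |
| | dpo(lr=3e-7) | 0.04 | 0.286 | 0.34 |
| NuminaMath-CoT | sft(lr=3e-5) | 0.0 | 0.162 | 0.25 |
| | dpo(lr=1e-7) | 0.0 | 0.176 | 0.26 |
| OpenO1-SFT | sft(lr=3e-5) | 0.0 | 0.23 | 0.33 |
| | dpo(lr=1e-7) | 0.01 | 0.234 | 0.35 |
| | dpo(lr=3e-7) | 0.01 | 0.284 | 0.40 |
| Sky-T1_data_17k | sft(lr=3e-5) | 0.01 | 0.236 | 0.33 |
| | dpo(lr=1e-7) | 0.0 | 0.252 | 0.36 |
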
  • Key Observations:
    • DPO training improves MATH500 and MATH (Overall) scores for every SFT-trained model; AIME scores remain near zero in all settings
    • OpenO1-SFT shows the largest gain: MATH500 rises from 0.23 after SFT to 0.284 after DPO with lr=3e-7, a relative improvement of 23.48%
    • Improvements hold at most difficulty levels, though not uniformly (see the per-level results in the next section)
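
The OpenO1-SFT gain quoted above is a relative improvement, not an absolute one. A short check of the arithmetic, using the MATH500 scores from the table (the function name is just for illustration):

```python
def relative_gain(before, after):
    """Relative change from `before` to `after`, as a percentage."""
    return (after - before) / before * 100.0


sft_math500 = 0.23    # OpenO1-SFT, sft(lr=3e-5)
dpo_math500 = 0.284   # OpenO1-SFT, dpo(lr=3e-7)

print(f"{relative_gain(sft_math500, dpo_math500):.2f}%")  # prints 23.48%
```

The absolute gain for the same pair of scores is 0.054, or 5.4 percentage points.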

Performance on MATH across different levels

  • To examine mathematical reasoning performance in more detail, we collect 100 problems from each difficulty level of the MATH dataset and evaluate each model on the resulting per-level test sets.
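
| SFT Dataset/Model Family | Variant | MATH1 | MATH2 | MATH3 | MATH4 | MATH5 |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 8B-Instruct | base model | 0.59 | 0.39 | 0.37 | 0.20 | 0.10 |
| | dpo(lr=3e-7) | 0.66 | 0.33 | 0.36 | 0.25 | 0.13 |
| NuminaMath-CoT | sft(lr=3e-5) | 0.45 | 0.31 | 0.19 | 0.09 | 0.05 |
| | dpo(lr=1e-7) | 0.51 | 0.27 | 0.24 | 0.11 | 0.01 |
| OpenO1-SFT | sft(lr=3e-5) | 0.58 | 0.40 | 0.31 | 0.20 | 0.10 |
| | dpo(lr=1e-7) | 0.67 | 0.38 | 0.33 | 0.20 | 0.09 |
| | dpo(lr=3e-7) | 0.73 | 0.45 | 0.40 | 0.28 | 0.09 |
| Sky-T1_data_17k | sft(lr=3e-5) | 0.61 | 0.36 | 0.33 | 0.22 | 0.08 |
| | dpo(lr=1e-7) | 0.68 | 0.40 | 0.26 | 0.22 | 0.14 |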
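
One way to build per-level test sets like these is a simple stratified sample. The sketch below is an illustration under assumptions, not our evaluation script: it assumes each problem is a dict with a `level` field and that the 100 problems per level are drawn uniformly at random with a fixed seed.

```python
import random
from collections import defaultdict


def sample_per_level(problems, n_per_level=100, seed=0):
    """Draw n_per_level problems from each difficulty level.

    `problems` is a list of dicts with a "level" key (assumed field name).
    Returns a dict mapping level -> list of sampled problems.
    """
    rng = random.Random(seed)  # fixed seed so the subsets are reproducible
    by_level = defaultdict(list)
    for problem in problems:
        by_level[problem["level"]].append(problem)

    subsets = {}
    for level in sorted(by_level):
        items = by_level[level]
        if len(items) < n_per_level:
            raise ValueError(f"level {level} has only {len(items)} problems")
        subsets[level] = rng.sample(items, n_per_level)  # without replacement
    return subsets


def accuracy(is_correct):
    """Fraction of True values in a list of per-problem correctness flags."""
    return sum(is_correct) / len(is_correct)
```

With 100 problems per level, each score in the table corresponds directly to a count of correct answers, so a difference of 0.01 is a single problem.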

Notes

The Kimi 1.5 model explored a similar approach, but it uses short CoT as the preferred examples to improve inference efficiency.