Progressive bug finding in open-source deep learning systems

Read Time: 11 minutes

Deep learning frameworks like PyTorch, TensorFlow, and JAX form the backbone of modern AI applications, so bugs in them can have severe consequences: memory corruption that opens security holes, crashes, and silently wrong numerical results. In this post, I’ll share our experience finding over 300 critical bugs in major ML systems through systematic testing.

The Problem Space

ML frameworks are complex beasts:

  • PyTorch: 1.5M+ lines of C++/Python code
  • TensorFlow: 2M+ lines across multiple languages
  • JAX: 500K+ lines with heavy JIT compilation

Traditional testing approaches fall short because:

  1. The input space is enormous (arbitrary tensor operations)
  2. Correctness is often probabilistic
  3. Performance bugs are as critical as functional bugs

Our Approach: Progressive Bug Finding

We developed a progressive testing methodology that combines multiple techniques:

1. Coverage-Guided Fuzzing

Our fuzzer, Tzer, uses coverage feedback to explore new code paths:

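import random

def fuzz_iteration(seed_pool, global_coverage):
    # mutate() and execute_with_coverage() stand in for Tzer's mutator and instrumented run
    program = mutate(random.choice(seed_pool))
    coverage = execute_with_coverage(program)  # set of covered branch IDs
    if not coverage <= global_coverage:  # keep the program only if it reaches new branches
        seed_pool.append(program)
        global_coverage |= coverage  # in-place update, visible to the caller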
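To make the feedback rule concrete, here is a self-contained toy version of the loop. The target function `toy_target`, its branch IDs, the `mutate` helper, and the list-of-integers "program" are all hypothetical stand-ins for real tensor programs and compiler instrumentation. The point is only the rule itself: a mutant joins the seed pool only when it reaches a branch no earlier input reached, which lets the fuzzer climb step by step into deeply nested code.

```python
import random

def toy_target(program):
    """Hypothetical instrumented target: returns the set of branch IDs it executed."""
    covered = {"entry"}
    if len(program) >= 2:
        covered.add("len2")
        if program[0] == 7:
            covered.add("magic0")
            if program[1] == 3:
                covered.add("magic1")  # deep branch, reachable only via earlier progress
    return covered

def mutate(program, rng):
    """Hypothetical mutator: overwrite, insert, or drop one element."""
    p = list(program)
    choice = rng.randrange(3)
    if choice == 0 and p:
        p[rng.randrange(len(p))] = rng.randrange(10)
    elif choice == 1:
        p.insert(rng.randrange(len(p) + 1), rng.randrange(10))
    elif p:
        p.pop(rng.randrange(len(p)))
    return p

def fuzz(seed, iterations, rng_seed=0):
    rng = random.Random(rng_seed)  # fixed seed keeps runs reproducible
    seed_pool = [seed]
    global_coverage = set(toy_target(seed))
    for _ in range(iterations):
        program = mutate(rng.choice(seed_pool), rng)
        coverage = toy_target(program)
        if not coverage <= global_coverage:  # reached at least one new branch
            seed_pool.append(program)
            global_coverage |= coverage
    return seed_pool, global_coverage

pool, cov = fuzz([0], iterations=5000)
print(sorted(cov))
print(pool)
```

Random inputs alone would almost never produce a program starting with `[7, 3]`; keeping the intermediate `[7, ...]` program as a seed is what makes the last branch reachable.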

2. Differential Testing

We run the same computation and compare its outputs across different:

  • Frameworks (PyTorch vs TensorFlow)
  • Backends (CPU vs GPU vs TPU)
  • Optimization levels (O0 vs O3)
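A minimal differential-testing harness needs only two implementations of the same operation, a shared set of inputs, and a tolerance. In the sketch below, the two "backends" are hypothetical pure-Python softmax implementations (a naive one and the usual max-subtracted one) standing in for, say, a CPU kernel and a GPU kernel. A numerical disagreement beyond tolerance, or a crash on only one side, is reported as a candidate bug.

```python
import math

def softmax_naive(xs):
    # Direct formula: math.exp raises OverflowError for inputs above ~709.78 (float64)
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(xs):
    # Subtracting the max keeps every exponent <= 0, so exp never overflows
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def run(impl, xs):
    """Return (output, None) on success or (None, exception name) on failure."""
    try:
        return impl(xs), None
    except (OverflowError, ZeroDivisionError, ValueError) as exc:
        return None, type(exc).__name__

def differential_test(impl_a, impl_b, inputs, rel_tol=1e-9, abs_tol=1e-12):
    """Report every input on which the two implementations disagree."""
    mismatches = []
    for xs in inputs:
        out_a, err_a = run(impl_a, xs)
        out_b, err_b = run(impl_b, xs)
        if err_a or err_b:
            if err_a != err_b:  # only one side crashed, or they crashed differently
                mismatches.append((xs, err_a, err_b))
            continue
        if any(not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
               for a, b in zip(out_a, out_b)):
            mismatches.append((xs, out_a, out_b))
    return mismatches

inputs = [[1.0, 2.0, 3.0], [-5.0, 0.0, 5.0], [1000.0, 1000.0]]
for case in differential_test(softmax_naive, softmax_stable, inputs):
    print("mismatch:", case)
```

Neither implementation is treated as the ground truth: the harness only reports that they disagree, and a human (or a third implementation) decides which side is wrong.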

3. Model-Based Generation

NNSmith generates valid DNN models using type-aware constraints:

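class ModelGenerator:
    def generate(self):
        graph = self.create_graph()      # build the initial graph skeleton
        self.add_operators(graph)        # insert operators whose type/shape constraints are satisfiable
        self.connect_edges(graph)        # wire operator outputs to type-compatible inputs
        return self.materialize(graph)   # turn the abstract graph into a concrete, runnable model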
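The key idea of type-aware generation is that an operator is only inserted when tensors already in the graph satisfy its constraints, so every generated model is valid by construction. The sketch below is a heavily simplified stand-in, not NNSmith's implementation: it uses a hypothetical three-operator vocabulary (`relu`, `add`, `matmul`) with concrete shapes, checks each constraint directly instead of with a constraint solver, and omits materialization.

```python
import random
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str
    inputs: list   # indices of producer nodes (always earlier, so the graph is a DAG)
    shape: tuple   # output shape

@dataclass
class Graph:
    nodes: list = field(default_factory=list)

def candidates(graph):
    """Yield (op, input_ids, out_shape) for every operator whose constraints hold."""
    ids = range(len(graph.nodes))
    for i in ids:
        yield "relu", [i], graph.nodes[i].shape  # unary: any shape is accepted
        for j in ids:
            a, b = graph.nodes[i].shape, graph.nodes[j].shape
            if a == b:
                yield "add", [i, j], a  # elementwise: shapes must match
            if len(a) == 2 and len(b) == 2 and a[1] == b[0]:
                yield "matmul", [i, j], (a[0], b[1])  # inner dimensions must agree

def generate(num_ops, rng):
    graph = Graph([Node("input", [], (4, 8)), Node("input", [], (8, 4))])
    for _ in range(num_ops):
        op, inputs, shape = rng.choice(list(candidates(graph)))
        graph.nodes.append(Node(op, inputs, shape))
    return graph

g = generate(6, random.Random(0))
for idx, node in enumerate(g.nodes):
    print(idx, node.op, node.inputs, node.shape)
```

Because invalid combinations are never proposed, the generator spends no time on models the framework would reject at the front door, and every test reaches deeper compiler and runtime code.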

Bug Patterns We Found

Memory Corruption (45% of bugs)

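// Simplified sketch of the bug pattern in PyTorch's tensor indexing (1-D float input)
Tensor index_select(const Tensor& input, int64_t dim, const Tensor& index) {
    auto out = at::empty({index.numel()}, input.options());
    const float* src = input.data_ptr<float>();
    const int64_t* idx = index.data_ptr<int64_t>();
    float* dst = out.data_ptr<float>();
    for (int64_t i = 0; i < index.numel(); ++i) {
        // Missing bounds check: idx[i] is never compared against input.size(dim)
        dst[i] = src[idx[i]];  // out-of-bounds read if idx[i] < 0 or idx[i] >= input.size(dim)
    }
    return out;
}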

Type Confusion (23% of bugs)

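import tensorflow as tf

# Bug in TensorFlow's type inference (simplified)
def buggy_op(x):
    if x.dtype == tf.float32:
        return tf.cast(x, tf.int32)
    # Missing else branch: every other dtype falls through and returns None,
    # which fails later, far from the real cause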
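Bugs like this survive because tests usually exercise only the common dtype. One cheap oracle is to sweep an operator over every dtype it might receive and flag any input for which it silently returns something that is not a tensor. The sketch below uses a hypothetical `FakeTensor` class and dtype list so it runs without TensorFlow; with a real framework you would construct an actual input tensor of each dtype instead.

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    dtype: str  # hypothetical stand-in for a real tensor's dtype

def buggy_op(x):
    if x.dtype == "float32":
        return FakeTensor("int32")
    # Missing else branch: falls through and returns None

def fixed_op(x):
    if x.dtype == "float32":
        return FakeTensor("int32")
    raise TypeError(f"op: unsupported dtype {x.dtype}")  # fail loudly instead

def dtype_sweep(op, dtypes):
    """Return the dtypes for which op silently returns something that is not a tensor."""
    silent_failures = []
    for dt in dtypes:
        try:
            out = op(FakeTensor(dt))
        except TypeError:
            continue  # an explicit rejection is acceptable behavior
        if not isinstance(out, FakeTensor):
            silent_failures.append(dt)
    return silent_failures

DTYPES = ["float16", "float32", "float64", "int32", "bool"]
print(dtype_sweep(buggy_op, DTYPES))
print(dtype_sweep(fixed_op, DTYPES))
```

The oracle deliberately accepts an explicit `TypeError`: rejecting an unsupported dtype is fine, while silently producing a non-tensor is the bug.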

Race Conditions (18% of bugs)

Concurrent execution without proper synchronization:

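// Simplified sketch of the missing mutex in JAX's cache
// (CompiledFn, Function, and compile() are placeholders, not JAX's real types)
std::unordered_map<std::string, CompiledFn> cache;  // shared by all threads

void cache_compiled_function(const std::string& key, const Function& func) {
    // Race condition: concurrent threads mutate the map with no lock held
    cache[key] = compile(func);
}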
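A standard fix is to guard the shared map with a lock and to check for an existing entry before compiling, so each key is compiled exactly once even under contention. The Python sketch below illustrates the pattern with `threading.Lock`; `slow_compile` is a hypothetical stand-in for the expensive compilation step, and this is not JAX's actual cache code.

```python
import threading
import time

class CompileCache:
    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._cache = {}
        self._lock = threading.Lock()
        self.compile_count = 0  # how many times compile_fn actually ran

    def get(self, key, func):
        with self._lock:  # check-then-insert is atomic with respect to other threads
            if key not in self._cache:
                self._cache[key] = self._compile_fn(func)
                self.compile_count += 1
            return self._cache[key]

def slow_compile(func):
    time.sleep(0.01)  # widens the window in which a racy cache would misbehave
    return func

cache = CompileCache(slow_compile)
threads = [threading.Thread(target=cache.get, args=("square", lambda x: x * x))
           for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(cache.compile_count)
```

Holding the lock for the whole compilation keeps the example simple but serializes all compiles, including those for unrelated keys; a production cache would more likely use per-key locks or futures so that different functions can compile in parallel.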

Numerical Instabilities (14% of bugs)

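import tensorflow as tf

# Gradient explosion in a custom op: the output and its gradient are both exp(x) * 1000
def unstable_gradient(x):
    return tf.exp(x) * 1000  # float32 overflows to inf once x > ~81.8; inf later turns into nan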
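Overflow of this kind is easy to catch mechanically: evaluate the operation on a sweep of inputs and flag any input whose result is infinite or NaN, or whose evaluation raises. The sketch below does this in pure Python, which computes in float64, so `exp(x) * 1000` overflows just below x = 703 rather than near the float32 threshold of roughly 81.8. The input sweep is arbitrary and chosen only for illustration.

```python
import math

def unstable(x):
    return math.exp(x) * 1000  # same formula as the op above, in float64

def find_nonfinite(fn, inputs):
    """Return the inputs for which fn overflows or returns inf or NaN."""
    bad = []
    for x in inputs:
        try:
            y = fn(x)
        except OverflowError:  # math.exp raises instead of returning inf
            bad.append(x)
            continue
        if not math.isfinite(y):  # the multiplication overflows to inf silently
            bad.append(x)
    return bad

print(find_nonfinite(unstable, [0.0, 100.0, 700.0, 703.0, 710.0]))
```

The two failure modes differ: at x = 710 `math.exp` itself raises, while at x = 703 `exp` still succeeds and only the multiplication by 1000 overflows, silently producing `inf`. A useful oracle has to catch both.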

Impact and Fixes

Our bug reports led to:

  • 127 CVEs assigned for security vulnerabilities
  • 89 patches in PyTorch 2.0
  • 67 fixes in TensorFlow 2.11
  • 45 improvements in JAX 0.4

Example fix in PyTorch:

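// After our report (simplified)
Tensor index_select_safe(const Tensor& input, int64_t dim, const Tensor& index) {
    TORCH_CHECK(dim >= 0 && dim < input.dim(), "index_select(): dim out of range");
    if (index.numel() > 0) {  // min()/max() throw on empty tensors
        TORCH_CHECK(index.min().item<int64_t>() >= 0, "index_select(): negative index");
        TORCH_CHECK(index.max().item<int64_t>() < input.size(dim),
                    "index_select(): index out of range");
    }
    // ... safe implementation
}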
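The three checks in the fix can be mirrored in a small Python reference implementation, which also doubles as a test oracle for the real kernel. The sketch below assumes a 2-D nested-list "tensor" and uses the hypothetical name `index_select_checked`. It must reject negative indices explicitly, because Python list indexing would otherwise wrap them silently, which is exactly the kind of quiet wrong answer that bounds checks exist to prevent.

```python
def index_select_checked(rows, dim, index):
    """index_select for a 2-D nested list, with the same checks as the C++ fix."""
    if not 0 <= dim < 2:
        raise IndexError(f"dim {dim} out of range for a 2-D input")
    size = len(rows) if dim == 0 else (len(rows[0]) if rows else 0)
    if index and min(index) < 0:
        raise IndexError("negative index")  # Python lists would silently wrap these
    if index and max(index) >= size:
        raise IndexError(f"index {max(index)} out of range for size {size}")
    if dim == 0:
        return [list(rows[i]) for i in index]
    return [[row[i] for i in index] for row in rows]

x = [[1, 2, 3],
     [4, 5, 6]]
print(index_select_checked(x, 1, [2, 0]))
for bad in [(2, [0]), (0, [-1]), (1, [3])]:
    try:
        index_select_checked(x, *bad)
    except IndexError as exc:
        print("rejected:", exc)
```

Like the C++ version, the checks skip `min`/`max` for an empty index, so selecting zero elements returns an empty result instead of raising.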

Lessons Learned

1. Fuzzing Works for ML Systems

Despite the complexity, coverage-guided fuzzing finds real bugs.

2. Differential Testing is Powerful

Comparing implementations reveals subtle correctness issues.

3. Performance Bugs Matter

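A 10x slowdown is often worse than a crash: a crash is noticed and reported immediately, while a slowdown can ship unnoticed and quietly inflate training and inference costs.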

4. Collaboration is Key

Working with framework developers ensures fixes are practical.

Tools Released

We’ve open-sourced our testing tools:

  • Tzer: Coverage-guided tensor program fuzzer
  • NNSmith: Model-based DNN generator
  • NeuRI: Rule-based model diversification

What’s Next?

We’re expanding our testing to:

  • Emerging frameworks: Triton, MLC, TinyGrad
  • Hardware accelerators: Testing compiler backends for GPUs/TPUs
  • Training pipelines: Finding bugs in distributed training

Conclusion

Systematic testing of ML frameworks is crucial for AI safety and reliability. Our progressive approach has uncovered hundreds of critical bugs, making these frameworks more robust for millions of developers worldwide.

This work is supported by NSF and by collaborative efforts with the PyTorch, TensorFlow, and JAX teams.



