FP Complete


This blog post is the first in a planned series I’m calling “Rust quickies.” In my training sessions, we often come up with quick examples to demonstrate some point. Instead of forgetting about them, I want to put short blog posts together focusing on these examples. Hopefully these will be helpful. Enjoy!

Short circuiting a for loop

Let’s say I’ve got an Iterator of u32s. I want to double each value and print it. Easy enough:

fn weird_function(iter: impl IntoIterator<Item=u32>) {
    for x in iter.into_iter().map(|x| x * 2) {
        println!("{}", x);
    }
}

fn main() {
    weird_function(1..10);
}

And now let’s say we hate the number 8, and want to stop when we hit it. That’s a simple one-line change:

fn weird_function(iter: impl IntoIterator<Item=u32>) {
    for x in iter.into_iter().map(|x| x * 2) {
        if x == 8 { return } // added this line
        println!("{}", x);
    }
}

Easy, done, end of story. And for this reason, I recommend using for loops when possible, even though, coming from a functional programming background, it feels overly imperative. However, some people out there want to be more functional, so let’s explore that.

for_each vs map

Let’s forget about the short-circuiting for a moment and go back to the original version of the program, but without using a for loop. That’s easy enough with the for_each method. It takes a closure, which it runs for each value in the Iterator. Let’s check it out:

fn weird_function(iter: impl IntoIterator<Item=u32>) {
    iter.into_iter().map(|x| x * 2).for_each(|x| {
        println!("{}", x);
    })
}

But why, exactly, do we need for_each? It seems awfully similar to map, which also applies a function over every value in an Iterator. Trying to make that change, however, demonstrates the problem. With this code:

fn weird_function(iter: impl IntoIterator<Item=u32>) {
    iter.into_iter().map(|x| x * 2).map(|x| {
        println!("{}", x);
    })
}

we get an error message:

error[E0308]: mismatched types
 --> src\main.rs:2:5
  |
2 | /     iter.into_iter().map(|x| x * 2).map(|x| {
3 | |         println!("{}", x);
4 | |     })
  | |______^ expected `()`, found struct `Map`

Undaunted, I fix this error by sticking a semicolon at the end of that expression. That generates a warning instead: unused `Map` that must be used. And sure enough, running this program produces no output.

The problem is that map doesn’t drain the Iterator. Said another way, map is lazy. It adapts one Iterator into a new Iterator. But unless something comes along and drains or forces the Iterator, no actions will occur. By contrast, for_each will always drain an Iterator.
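To make the laziness visible, here is a minimal sketch that counts how often the closure passed to map actually runs. The `closure_calls` function and the `Cell` counter are my own additions for illustration; they are not part of the original program.

```rust
use std::cell::Cell;

/// Returns how many times the `map` closure had run
/// (before draining the iterator, after draining it).
/// Hypothetical helper, used only for this illustration.
fn closure_calls() -> (u32, u32) {
    let calls = Cell::new(0u32);

    // Building the adapter does not run the closure.
    let doubled = (1..10u32).map(|x| {
        calls.set(calls.get() + 1);
        x * 2
    });
    let before = calls.get();

    // for_each drains the iterator, which finally runs the closure.
    doubled.for_each(|x| println!("{}", x));
    let after = calls.get();

    (before, after)
}

fn main() {
    let (before, after) = closure_calls();
    println!("before draining: {}, after draining: {}", before, after);
}
```

The counter stays at zero until for_each starts pulling values, and ends at nine because 1..10 yields nine values.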

One easy trick to force draining of an Iterator is the count() method. This performs some unnecessary work counting how many values are in the Iterator, but it’s not that expensive. Another approach would be to use collect. This one is a little trickier, since collect typically needs some type annotations. But thanks to a fun trick of how FromIterator is implemented for the unit type, we can collect a stream of ()s into a single () value. Meaning, this code works:

fn weird_function(iter: impl IntoIterator<Item=u32>) {
    iter.into_iter().map(|x| x * 2).map(|x| {
        println!("{}", x);
    }).collect()
}

Note the lack of a semicolon at the end there. What do you think will happen if we add in the semicolon?
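For comparison, the count() trick mentioned earlier might look something like the sketch below. Returning the count as a usize is my own choice, made only so the result is observable; you could just as well discard it.

```rust
// A sketch of draining with count() instead of collect().
// The usize return value is an assumption added for illustration.
fn weird_function(iter: impl IntoIterator<Item = u32>) -> usize {
    iter.into_iter()
        .map(|x| x * 2)
        .map(|x| {
            println!("{}", x);
        })
        // count() pulls every value through the adapters above,
        // so the println! closure runs once per element.
        .count()
}

fn main() {
    let n = weird_function(1..10);
    println!("drained {} values", n);
}
```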

Short circuiting

EDIT: Enough people have asked “why not use take_while?” that I thought I’d address it. Yes, below, take_while will work for “short circuiting.” It’s probably even a good idea. But the main goal in this post is to explore some funny implementation approaches, not recommend a best practice. And overall, despite some good arguments for take_while being a good choice here, I still stand by the overall recommendation to prefer for loops for simplicity.
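For reference, a take_while version could be sketched roughly like this. The returned Vec is not something the original program needs; I’ve added it purely so the values that got through can be inspected.

```rust
// A sketch of the take_while approach. The Vec return value is an
// assumption added for illustration; the original only prints.
fn weird_function(iter: impl IntoIterator<Item = u32>) -> Vec<u32> {
    let mut printed = Vec::new();
    iter.into_iter()
        .map(|x| x * 2)
        // Stop pulling values as soon as the predicate is false.
        .take_while(|x| *x != 8)
        .for_each(|x| {
            println!("{}", x);
            printed.push(x);
        });
    printed
}

fn main() {
    println!("{:?}", weird_function(1..10));
}
```

The rest of this section goes back to the map plus collect version.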

With the for loop approach, stopping at the first 8 was a trivial one-line addition. Let’s do the same thing here:

fn weird_function(iter: impl IntoIterator<Item=u32>) {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return }
        println!("{}", x);
    }).collect()
}

Take a guess at what the output will be. Ready? OK, here’s the real thing:

2
4
6
10
12
14
16
18

We skipped 8, but we didn’t stop. It’s the difference between a continue and a break inside the for loop. Why did this happen?

It’s important to think about the scope of a return. It will exit the current function. And in this case, the current function isn’t weird_function, but the closure inside the map call. This is what makes short-circuiting inside map so difficult.

The exact same comment applies to for_each. The only way to stop a for_each from continuing is to panic (or abort the program, if you want to get really aggressive).
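Here’s a small side-by-side sketch of that difference. The function names `with_for_each` and `with_for_loop` are made up for this illustration, and both return a Vec only so the behavior can be compared directly.

```rust
// `return` inside the closure ends only that one closure call,
// so it behaves like `continue`. (Hypothetical helper name.)
fn with_for_each(iter: impl IntoIterator<Item = u32>) -> Vec<u32> {
    let mut kept = Vec::new();
    iter.into_iter().map(|x| x * 2).for_each(|x| {
        if x == 8 {
            return; // leaves the closure, not with_for_each
        }
        kept.push(x);
    });
    kept
}

// In a for loop there is no closure in the way, so `break`
// (or `return`) really stops the iteration. (Hypothetical helper name.)
fn with_for_loop(iter: impl IntoIterator<Item = u32>) -> Vec<u32> {
    let mut kept = Vec::new();
    for x in iter.into_iter().map(|x| x * 2) {
        if x == 8 {
            break;
        }
        kept.push(x);
    }
    kept
}

fn main() {
    println!("{:?}", with_for_each(1..10));
    println!("{:?}", with_for_loop(1..10));
}
```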

But with map, we have some ingenious ways of working around this and short-circuiting. Let’s see it in action.

collect an Option

map needs some draining method to drive it. We’ve been using collect. I’ve previously discussed the intricacies of this method. One cool feature of collect is that, for Option and Result, it provides short-circuit capabilities. We can modify our program to take advantage of that:

fn weird_function(iter: impl IntoIterator<Item=u32>) -> Option<()> {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(()) // keep going!
    }).collect()
}

I put a return type on weird_function, though we could also use turbofish on collect and throw away the result. We just need some type annotation to say what we’re trying to collect. Since collecting the underlying () values doesn’t take up extra memory, this is even pretty efficient! The only cost is the extra Option. But that extra Option is (arguably) useful; it lets us know if we short-circuited or not.
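As a rough sketch of the turbofish variant, the type annotation moves from the function signature onto collect itself. The `stopped_early` name and its bool return value are my own inventions here; they just show one way to use the Option to learn whether we short-circuited.

```rust
// A sketch using turbofish on collect. `stopped_early` is a
// hypothetical name; it reports whether the iteration was cut short.
fn stopped_early(iter: impl IntoIterator<Item = u32>) -> bool {
    iter.into_iter()
        .map(|x| x * 2)
        .map(|x| {
            if x == 8 {
                return None; // short circuit!
            }
            println!("{}", x);
            Some(()) // keep going!
        })
        // The turbofish says what to collect into, so no return
        // type annotation is needed for inference.
        .collect::<Option<()>>()
        .is_none()
}

fn main() {
    println!("stopped early: {}", stopped_early(1..10));
    println!("stopped early: {}", stopped_early(1..4));
}
```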

But the story isn’t so rosy with other types. Let’s say our closure within map returns the x value. In other words, replace the last line with Some(x) instead of Some(()). Now we need to somehow collect up those u32s. Something like this would work:

fn weird_function(iter: impl IntoIterator<Item=u32>) -> Option<Vec<u32>> {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).collect()
}

But that incurs a heap allocation that we don’t want! And using count() from before is useless too, since it won’t even short circuit.

But we do have one other trick.

sum

It turns out there’s another draining method on Iterator that performs short circuiting: sum. This program works perfectly well:

fn weird_function(iter: impl IntoIterator<Item=u32>) -> Option<u32> {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).sum()
}
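To see what comes back from this version, here is the same function with a main that calls it twice. The two ranges are arbitrary choices of mine: one never produces an 8 after doubling, the other does.

```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) -> Option<u32> {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).sum()
}

fn main() {
    // 1..4 doubles to 2, 4, 6: no 8, so we get the total back.
    println!("{:?}", weird_function(1..4));
    // 1..10 hits 8 on the fourth value, so the whole sum is None.
    println!("{:?}", weird_function(1..10));
}
```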

The downside is that it’s unnecessarily summing up the values, and that could be a real problem if some kind of overflow occurs. Still, this mostly works. Is there some way we can stay functional, short circuit, and get no performance overhead? Sure!

Short

The final trick here is to create a new helper type for summing up an Iterator. But this thing won’t really sum. Instead, it will throw away all of the values, and stop as soon as it sees a None. Let’s see it in practice:

#[derive(Debug)]
enum Short {
    Stopped,
    Completed,
}

impl<T> std::iter::Sum<Option<T>> for Short {
    fn sum<I: Iterator<Item = Option<T>>>(iter: I) -> Self {
        for x in iter {
            if let None = x { return Short::Stopped }
        }
        Short::Completed
    }
}

fn weird_function(iter: impl IntoIterator<Item=u32>) -> Short {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).sum()
}

fn main() {
    println!("{:?}", weird_function(1..10));
}

And voila! We’re done!

Exercise: It’s pretty cheeky to use sum here. collect makes more sense. Replace sum with collect, and then change the Sum implementation into something else. Solution at the end.

Conclusion

That’s a lot of work to be functional. Rust has a great story around short circuiting. And it’s not just with return, break, and continue. It’s with the ? try operator, which forms the basis of error handling in Rust. There are times when you’ll want to use Iterator adapters, async streaming adapters, and functional-style code. But unless you have a pressing need, my recommendation is to stick to for loops.
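As one last sketch, here is how the ? operator can do the short-circuiting inside a plain for loop. The `reject_eight` helper is hypothetical; I’m only using it to have something that returns an Option.

```rust
// Hypothetical helper: rejects 8 and passes everything else through.
fn reject_eight(x: u32) -> Option<u32> {
    if x == 8 { None } else { Some(x) }
}

fn weird_function(iter: impl IntoIterator<Item = u32>) -> Option<()> {
    for x in iter.into_iter().map(|x| x * 2) {
        // `?` returns None from weird_function itself, because
        // there is no closure boundary in a for loop.
        let x = reject_eight(x)?;
        println!("{}", x);
    }
    Some(())
}

fn main() {
    println!("{:?}", weird_function(1..10));
    println!("{:?}", weird_function(1..4));
}
```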

If you liked this post, and would like to see more Rust quickies, let me know.

Solution

use std::iter::FromIterator;

#[derive(Debug)]
enum Short {
    Stopped,
    Completed,
}

impl<T> FromIterator<Option<T>> for Short {
    fn from_iter<I: IntoIterator<Item = Option<T>>>(iter: I) -> Self {
        for x in iter {
            if let None = x { return Short::Stopped }
        }
        Short::Completed
    }
}

fn weird_function(iter: impl IntoIterator<Item=u32>) -> Short {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).collect()
}

fn main() {
    println!("{:?}", weird_function(1..10));
}
