PNG compression involves two schemes — filtering and DEFLATE.
Filtering is a pre-processing step that operates row-by-row and is used to decrease entropy in the data. It works off the assumption that pixels and their neighbors are usually similar, but not necessarily the exact same. DEFLATE is a common lossless compression format combining LZ77 and Huffman coding.
The step that I'm interested in talking about right now is filtering. You can find a pretty good explanation of the algorithm in the PNG specification, but I'll walk through a quick summary of the parts that are relevant to this post.
I'll start by introducing two primitives: the pixel, and "bpp" or bits per pixel. PNGs support a number of different color formats, and those formats affect how we encode and decode pixels. There are two properties we care about — color type and bit depth.
Color type defines the channels or components that make up a pixel. For example, in the RGBA color type, pixels consist of 4 channels — red, green, blue, and alpha. PNGs support simple grayscale, grayscale with alpha, RGB, RGBA, and an "indexed" color type that lets you assign a single integer to each color.
Bit depth defines the number of bits per channel. Certain color types only permit certain bit depths. If you're curious, the list of permitted combinations can be found in the spec.
By combining the color type, which defines the number of channels, and the bit depth, which defines the number of bits per channel, we can find the number of bits per pixel. We refer to this value as "bpp." Although bpp typically refers to bits per pixel, for the rest of this post, the "b" in "bpp" will refer to "bytes."
Let's look at a simple example:
If we have an RGB color type with a bit depth of 8, our bits per pixel is 3 * 8, or 24, and our bytes per pixel is (bits per pixel) / 8 = 3.
When applying filters, the minimum bytes per pixel used is 1, even if the number of bits per pixel is less than a full byte.
Filters are applied for every byte, regardless of bit depth. This means that if the number of bits per channel is greater than a full byte, we operate on the bytes of that channel separately.
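As a quick illustration, here's roughly what that calculation looks like in code. This is a minimal sketch with hypothetical names, not taken from any particular decoder:
// Hypothetical helper: bytes per pixel from the number of channels and the bit depth.
// Filtering always uses at least 1 byte per pixel, even for sub-byte pixel sizes.
fn bytes_per_pixel(channels: u32, bit_depth: u32) -> u32 {
    let bits_per_pixel = channels * bit_depth;
    (bits_per_pixel / 8).max(1)
}
// RGB (3 channels) at a bit depth of 8: 3 * 8 = 24 bits, i.e. 3 bytes per pixel.
// assert_eq!(bytes_per_pixel(3, 8), 3);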
The PNG Filters
There are 5 filters — none, up, sub, average, and paeth. Each filter applies a certain operation to a row of bytes. We'll walk through a simple explanation of the first 3, but we won't talk about the average or paeth filters. They're pretty interesting, but the rest of this post will focus on the sub filter, so we don't need to worry about understanding how they work.
The none filter, as the name suggests, does not alter the bytes and just copies them as-is.
The up filter takes the pixel at position n and subtracts from it the pixel at position n in the row above it. For example, if we have two rows that look like this:
[1, 2, 3, 4, 5]
[1, 2, 3, 4, 5]
After applying the up filter, we get this filtered result:
[1, 2, 3, 4, 5]
[0, 0, 0, 0, 0]
We consider the row before the first row to contain only zeros, so the first row is unchanged.
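As a minimal scalar sketch of that encoding step (the names here are mine, and it assumes one byte per sample for simplicity):
// Sketch of encoding the up filter, one byte per sample.
// The row above the first row is treated as all zeros.
fn filter_up(prev_row: &[u8], current_row: &[u8], filtered: &mut [u8]) {
    for i in 0..current_row.len() {
        // subtraction is modulo 256, as required by the PNG spec
        filtered[i] = current_row[i].wrapping_sub(prev_row[i]);
    }
}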
The sub filter takes the pixel at position n and subtracts from it the pixel at position n - 1 in the same row. For example, if we have a row that looks like this:
[1, 2, 3, 4, 5]
If we apply the sub filter, we get this result:
# [1 - 0, 2 - 1, 3 - 2, 4 - 3, 5 - 4]
[1, 1, 1, 1, 1]
The sub filter operates on individual channels. That is, the red channel of pixel n - 1 is subtracted from the red channel of pixel n, the prior pixel's blue channel is subtracted from the blue channel, and so on. Finding the byte for the corresponding channel of the prior pixel involves the pixel's bpp.
If we look at the sub filter as operating on individual bytes, we can say that the algorithm is filtered[n] = unfiltered[n] - unfiltered[n - bpp], where bpp is calculated from the color type and bit depth.
If this sounds a bit confusing, it should make a lot more sense when we start looking at an implementation in code.
To decode any of these filters, you just have to add to the filtered value, rather than subtract from the raw value.
Implementing the sub Filter
This is all just required background reading to understand what we're really interested in: optimizing the sub filter for 8-bit RGBA pixels. Although we introduced the filters by discussing how they're encoded, for the rest of this post we'll only be talking about how they're decoded.
Before moving on, however, I do want to note that the performance characteristics of filters inside the context of PNG decoders and PNG encoders are very different. This is because PNG decoders only have to apply a filter once per row, while a good encoder will likely try all filters for all rows. This can make PNG encoding somewhat slow, and also makes optimizations to individual filters more impactful. If we optimize a filter for PNG decoding, we only see wins for images that use that filter heavily. This may be an area I explore in the future, but for now my focus is primarily on decoding, as that's the operation most commonly performed on PNG files.
As promised, let's look at a simple implementation of decoding the sub filter. All code examples going forward will be in Rust.
pub fn sub(raw_row: &[u8], decoded_row: &mut [u8]) {
for i in 0..BYTES_PER_PIXEL {
decoded_row[i] = raw_row[i];
}
for i in BYTES_PER_PIXEL..decoded_row.len() {
let left = decoded_row[i - BYTES_PER_PIXEL];
decoded_row[i] = raw_row[i].wrapping_add(left)
}
}
Bytes before the start of the row are 0, so we can just copy the first bpp bytes into the decoded row without doing any operations. For every byte after that, we add decoded_row[i - bpp] to the filtered byte.
Let's look at how LLVM does with this:
example::sub:
push rax
test rsi, rsi
je .LBB2_1
test rcx, rcx
je .LBB2_11
movzx eax, byte ptr [rdi]
mov byte ptr [rdx], al
cmp rsi, 1
je .LBB2_13
cmp rcx, 1
je .LBB2_15
movzx eax, byte ptr [rdi + 1]
mov byte ptr [rdx + 1], al
cmp rsi, 2
je .LBB2_17
cmp rcx, 2
je .LBB2_19
movzx eax, byte ptr [rdi + 2]
mov byte ptr [rdx + 2], al
cmp rsi, 3
je .LBB2_21
cmp rcx, 3
je .LBB2_23
movzx eax, byte ptr [rdi + 3]
mov byte ptr [rdx + 3], al
cmp rcx, 4
jbe .LBB2_27
lea r8, [rsi - 4]
cmp rcx, r8
cmovb r8, rcx
lea rax, [rcx - 5]
cmp r8, rax
cmovae r8, rax
inc r8
mov r10d, 4
cmp r8, 4
jbe .LBB2_4
mov r9d, r8d
and r9d, 3
mov eax, 4
cmovne rax, r9
mov r9, r8
sub r9, rax
neg rax
lea r10, [r8 + rax]
add r10, 4
xor r8d, r8d
.LBB2_9:
vmovd xmm0, dword ptr [rdi + r8 + 4]
vmovd xmm1, dword ptr [rdx + r8]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + r8 + 4], xmm0
add r8, 4
cmp r9, r8
jne .LBB2_9
.LBB2_4:
mov r8, rsi
neg r8
add r10, -4
mov r9, rcx
neg r9
.LBB2_5:
cmp rcx, r10
je .LBB2_28
lea rax, [r8 + r10]
cmp rax, -4
je .LBB2_7
movzx eax, byte ptr [rdx + r10]
add al, byte ptr [rdi + r10 + 4]
mov byte ptr [rdx + r10 + 4], al
lea rax, [r9 + r10]
inc rax
inc r10
cmp rax, -4
jne .LBB2_5
.LBB2_27:
pop rax
ret
; ... panic handling code omitted
It does surprisingly well — though not perfect. It looks like our first loop is performed in serial and has bounds checks on every iteration. This is because we don't actually know that our input or output slice has at least BYTES_PER_PIXEL elements. If this were in the context of a real PNG decoder and this function got inlined, LLVM may be able to do a better job eliding bounds checks. In a real PNG decoder, a row length of 0 would imply the image data is empty and defiltering can be skipped altogether. The row length is guaranteed to be a multiple of BYTES_PER_PIXEL, and if it isn't, due to the image being malformed or truncated, we'd expect the PNG decoder to have errored out by this point. Going forward, our implementations of the sub filter will rely on these two assumptions.
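To make those assumptions concrete, here's roughly what I'd expect a (hypothetical) call site to have already guaranteed before handing a row to the filter:
// Hypothetical call site; a real decoder would have validated these invariants
// while parsing the header and inflating the image data.
fn defilter_row(raw_row: &[u8], decoded_row: &mut [u8]) {
    debug_assert!(!raw_row.is_empty());
    debug_assert_eq!(raw_row.len() % BYTES_PER_PIXEL, 0);
    debug_assert_eq!(raw_row.len(), decoded_row.len());
    // a real decoder would dispatch on the row's filter byte; we only care about sub
    sub(raw_row, decoded_row);
}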
The second loop is pretty interesting. It looks like LLVM is doing some magic to be able to vectorize it and perform the loads and additions BYTES_PER_PIXEL elements at a time. As we'll see later, this actually comes surprisingly close to our handwritten SIMD implementation.
Though, it does look like there's still a lot of code dedicated to handling bounds checks. Let's see the impact of removing them. For simplicity, we'll drop into unsafe. This code is just for experimenting, so we won't include any debug assertions that you'd expect in proper code using unsafe in this way.
pub unsafe fn sub_no_bound_checks(raw_row: &[u8], decoded_row: &mut [u8]) {
for i in 0..BYTES_PER_PIXEL {
*decoded_row.get_unchecked_mut(i) = *raw_row.get_unchecked(i);
}
for i in BYTES_PER_PIXEL..decoded_row.len() {
let left = *decoded_row.get_unchecked(i - BYTES_PER_PIXEL);
*decoded_row.get_unchecked_mut(i) = raw_row.get_unchecked(i).wrapping_add(left)
}
}
example::sub_no_bound_checks:
push rbx
mov eax, dword ptr [rdi]
mov dword ptr [rdx], eax
cmp rcx, 5
jb .LBB2_12
lea r8, [rcx - 4]
mov ebx, 4
cmp r8, 4
jb .LBB2_11
mov rbx, r8
vmovd xmm0, dword ptr [rdx]
and rbx, -4
lea rsi, [rbx - 4]
mov r10, rsi
shr r10, 2
inc r10
mov r9d, r10d
and r9d, 7
cmp rsi, 28
jae .LBB2_4
xor esi, esi
jmp .LBB2_6
.LBB2_4:
and r10, -8
xor esi, esi
.LBB2_5:
vmovd xmm1, dword ptr [rdi + rsi + 4]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + rsi + 4], xmm0
vmovd xmm1, dword ptr [rdi + rsi + 8]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + rsi + 8], xmm0
vmovd xmm1, dword ptr [rdi + rsi + 12]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + rsi + 12], xmm0
vmovd xmm1, dword ptr [rdi + rsi + 16]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + rsi + 16], xmm0
vmovd xmm1, dword ptr [rdi + rsi + 20]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + rsi + 20], xmm0
vmovd xmm1, dword ptr [rdi + rsi + 24]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + rsi + 24], xmm0
vmovd xmm1, dword ptr [rdi + rsi + 28]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + rsi + 28], xmm0
vmovd xmm1, dword ptr [rdi + rsi + 32]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [rdx + rsi + 32], xmm0
add rsi, 32
add r10, -8
jne .LBB2_5
.LBB2_6:
test r9, r9
je .LBB2_9
lea r10, [rdx + rsi]
add r10, 4
lea r11, [rdi + rsi]
add r11, 4
shl r9, 2
xor esi, esi
.LBB2_8:
vmovd xmm1, dword ptr [r11 + rsi]
vpaddb xmm0, xmm1, xmm0
vmovd dword ptr [r10 + rsi], xmm0
add rsi, 4
cmp r9, rsi
jne .LBB2_8
.LBB2_9:
cmp r8, rbx
je .LBB2_12
add rbx, 4
.LBB2_11:
movzx eax, byte ptr [rdi + rbx]
add al, byte ptr [rdx + rbx - 4]
mov byte ptr [rdx + rbx], al
lea rax, [rbx + 1]
mov rbx, rax
cmp rcx, rax
jne .LBB2_11
.LBB2_12:
pop rbx
ret
This does look a lot nicer. LLVM was able to autovectorize even more and it unrolled the second loop. Let's set up some benchmarks to see how big of an impact we had here. To benchmark, we'll use Rust's stdlib benchmarking tools. We could use something like criterion, but that's not necessary for what we're doing here. We do want reasonably rigorous benchmarks as we start to iterate and explore new algorithms, but the stdlib solution is sufficient for that.
#![feature(test)]
extern crate test;
// ... our `sub` and `sub_no_bound_checks` implementations
#[cfg(test)]
mod tests {
use super::*;
use test::Bencher;
const BUFFER_SIZE: usize = 2_usize.pow(20);
#[bench]
fn bench_sub_naive_scalar(b: &mut Bencher) {
let raw_row = std::hint::black_box([10; BUFFER_SIZE]);
let mut decoded_row = std::hint::black_box([0; BUFFER_SIZE]);
b.iter(|| sub(&raw_row, &mut decoded_row));
std::hint::black_box(decoded_row);
}
#[bench]
fn bench_sub_no_bound_checks(b: &mut Bencher) {
let raw_row = std::hint::black_box([10; BUFFER_SIZE]);
let mut decoded_row = std::hint::black_box([0; BUFFER_SIZE]);
b.iter(|| unsafe { sub_no_bound_checks(&raw_row, &mut decoded_row) });
std::hint::black_box(decoded_row);
}
}
And now if we run our benchmarks using cargo bench, we get
running 2 tests
test tests::bench_sub_naive_scalar ... bench: 86,305 ns/iter (+/- 1,123)
test tests::bench_sub_no_bound_checks ... bench: 86,289 ns/iter (+/- 2,324)
Although removing the bounds checks made our codegen look a lot nicer, it doesn't seem to actually improve our performance above noise. That's unfortunate. Let's see if we can do better.
The first goal is to try to get an idea of how close to the optimal solution we are. We can try comparing our filters to memcpy. We want to see how much overhead our subtraction is adding to the copying of bytes from raw_row to decoded_row.
We can add a benchmark that looks like this,
pub unsafe fn baseline_memcpy(raw_row: &[u8], decoded_row: &mut [u8]) {
decoded_row
.get_unchecked_mut(0..raw_row.len())
.copy_from_slice(&*raw_row);
}
#[bench]
fn bench_baseline_memcpy(b: &mut Bencher) {
let raw_row = std::hint::black_box([10; BUFFER_SIZE]);
let mut decoded_row = std::hint::black_box([0; BUFFER_SIZE]);
b.iter(|| unsafe { baseline_memcpy(&raw_row, &mut decoded_row) });
std::hint::black_box(decoded_row);
}
And to double check that baseline_memcpy gets optimized how we expect,
example::baseline_memcpy:
mov rax, rdx
mov rdx, rsi
mov rsi, rdi
mov rdi, rax
jmp qword ptr [rip + memcpy@GOTPCREL]
Great. Now let's compare it to our sub implementations.
running 3 tests
test tests::bench_baseline_memcpy ... bench: 61,875 ns/iter (+/- 3,949)
test tests::bench_sub_naive_scalar ... bench: 86,731 ns/iter (+/- 1,833)
test tests::bench_sub_no_bound_checks ... bench: 86,584 ns/iter (+/- 1,374)
Our initial implementations don't actually seem to be that bad. But there's probably a lot of room for improvement here.
Current State of the Art
libpng has had optimized filter implementations using explicit SIMD for close to a decade. The optimization that they make is based around bpp.
It's difficult to decode the PNG filters in parallel. At first glance, it looks like there's a pretty strict data dependency on previous iterations. If we look at an example calculation, this becomes pretty clear. For simplicity in our example, we'll say bpp is 1. When we implement things in code, we'll usually default to a bpp of 4.
We'll start with the filtered array [1, 2, 3] and walk through defiltering it.
filtered = [1, 2, 3]
defiltered = [0, 0, 0]
# defiltered = [1, 0, 0]
defiltered[0] = filtered[0]
# defiltered = [1, 3, 0]
defiltered[1] = defiltered[0] + filtered[1]
# defiltered = [1, 3, 6]
defiltered[2] = defiltered[1] + filtered[2]
The last calculation depends on the results of the second-to-last calculation. How can we work around this?
libpng's optimization doesn't really try to — it still does a lot of things in serial. But when the bpp is greater than 1, it can operate on bpp bytes per iteration rather than going byte-by-byte. This ends up being pretty fast — for a bpp of 4, you're operating on 4x the number of bytes.
Let's look at how this implementation works in practice. We're going to port libpng's 4 bpp implementation to Rust:
unsafe fn load4(x: [u8; 4]) -> __m128i {
let tmp = i32::from_le_bytes(x);
_mm_cvtsi32_si128(tmp)
}
unsafe fn store4(x: &mut [u8], v: __m128i) {
let tmp = _mm_cvtsi128_si32(v);
x.get_unchecked_mut(..4).copy_from_slice(&tmp.to_le_bytes());
}
pub unsafe fn sub_sse2(raw_row: &[u8], decoded_row: &mut [u8]) {
let mut a: __m128i;
let mut d = _mm_setzero_si128();
let mut rb = raw_row.len() + 4;
let mut idx = 0;
while rb > 4 {
a = d;
d = load4([
*raw_row.get_unchecked(idx),
*raw_row.get_unchecked(idx + 1),
*raw_row.get_unchecked(idx + 2),
*raw_row.get_unchecked(idx + 3),
]);
d = _mm_add_epi8(d, a);
store4(decoded_row.get_unchecked_mut(idx..), d);
idx += 4;
rb -= 4;
}
}
So instead of loading, adding, and storing one byte at a time, we can operate on 4 bytes at a time. Let's benchmark this implementation to see how it performs.
#[bench]
fn bench_sub_sse2(b: &mut Bencher) {
let raw_row = std::hint::black_box([10; BUFFER_SIZE]);
let mut decoded_row = std::hint::black_box([0; BUFFER_SIZE]);
b.iter(|| unsafe { sub_sse2(&raw_row, &mut decoded_row) });
std::hint::black_box(decoded_row);
}
When we benchmark this time, we want to make use of SIMD intrinsics. To force LLVM to compile them optimally we'll have to configure target-cpu=native. If we don't do this, our intrinsics will be compiled suboptimally and we actually tend to get slower code than the scalar version, at least on my Linux machine.
RUSTFLAGS='-Ctarget-cpu=native' cargo bench
running 4 tests
test tests::bench_baseline_memcpy ... bench: 62,844 ns/iter (+/- 3,258)
test tests::bench_sub_naive_scalar ... bench: 86,798 ns/iter (+/- 1,502)
test tests::bench_sub_no_bound_checks ... bench: 86,719 ns/iter (+/- 2,057)
test tests::bench_sub_sse2 ... bench: 86,573 ns/iter (+/- 1,004)
Pretty much no improvement. We saw previously that LLVM was able to autovectorize our loop to operate on 4 bytes at a time already, so this is likely the explanation for why our explicit SIMD implementation isn't too much faster. It could be that we missed something in porting the C code, but that seems unlikely here. In general, it doesn't seem that we can get a massive win if we're stuck operating on bpp bytes at a time.
Trying a Different Algorithm
About a year ago, I had the idea to try solving the PNG filters using AVX and AVX2. AVX enables us to operate on 32 bytes at a time, compared to our current implementation that operates on at most 4 bytes at a time. If we're able to use AVX registers and instructions, we'd be able to operate on 8x the number of bytes as existing implementations of the filters.
After playing around with the problem for a while, I realized [1] that decoding the sub filter can be reduced to a well-studied problem called prefix sum. Prefix sum happens to be extremely easy to compute in parallel, which makes our problem a lot simpler.
The idea behind parallel prefix sum is that you can trivially subdivide the problem and then combine the results of the separate executions. Let's take a simple example:
Given an array [1, 2, 3, 4], the serial solution would be to just loop over the entire array. In a parallel version, we can split this array into [1, 2] and [3, 4] and compute the prefix sums separately. Then, we can take the last element of the first array and add it to each element in the second array. We'll call this the accumulate step.
So after executing the prefix sum step, we end up with the two arrays [1, 3] and [3, 7]. Then we apply the accumulate step and end up with two arrays [1, 3] and [6, 10]. Combining them, we get [1, 3, 6, 10], which is the correct prefix sum result we're looking for.
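Here's that example as a tiny scalar sketch in plain Rust, with no SIMD; the function name is mine, just to make the accumulate step concrete:
// Prefix sum of a 4-element array computed as two independent halves,
// followed by the accumulate step that fixes up the second half.
fn prefix_sum_in_two_halves(input: [i32; 4]) -> [i32; 4] {
    // prefix sum each half independently (these could run in parallel)
    let first = [input[0], input[0] + input[1]];
    let second = [input[2], input[2] + input[3]];
    // accumulate step: add the last element of the first half to the second half
    let carry = first[1];
    [first[0], first[1], second[0] + carry, second[1] + carry]
}
// prefix_sum_in_two_halves([1, 2, 3, 4]) == [1, 3, 6, 10]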
This sounds like we're doing more work — and we are. But these kinds of operations are really fast in SIMD, so the actual number of instructions per byte is significantly less than in the scalar solution.
Algorithmica has a pretty good explanation of vectorized prefix sum that goes quite a bit deeper than we need to for this problem, but is a great read if you're interested in learning more.
We'll actually end up with something that looks very similar to Algorithmica's first vectorized example. The only difference is that the algorithm presented by Algorithmica operates on 32-bit integers in sequence, while we want to operate on 8-bit bytes offset by bpp. When bpp is 4, this looks strikingly similar to just adding 32-bit integers.
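Before looking at the intrinsics, here's a rough scalar rendition of that idea, purely as an illustration with names of my own: repeatedly add to the chunk a copy of itself shifted by bpp, then 2 * bpp, then 4 * bpp bytes, and so on, which computes the prefix sum independently for each byte position within a pixel.
// Scalar rendition of the doubling trick: after each round, every byte holds the
// sum of itself and the byte `stride` positions earlier; after enough rounds the
// whole chunk is decoded (assuming the chunk starts at the beginning of the row,
// so the bytes to its left are zero). Iterating in reverse makes each round read
// the values from before the round, mimicking the parallel update.
fn decode_sub_by_doubling(chunk: &mut [u8], bpp: usize) {
    let mut stride = bpp;
    while stride < chunk.len() {
        for i in (stride..chunk.len()).rev() {
            chunk[i] = chunk[i].wrapping_add(chunk[i - stride]);
        }
        stride *= 2;
    }
}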
For the accumulate step, we're going to use _mm256_extract_epi32. I wasn't able to port their accumulate implementation very well, but I'd assume it compiles down to roughly the same thing.
Here's our implementation of the sub filter using the full width of AVX registers:
pub unsafe fn sub_avx(raw_row: &[u8], decoded_row: &mut [u8]) {
let mut last = 0;
let mut x: __m256i;
let len = raw_row.len();
let mut i = 0;
// we ensure the length in our SIMD loop is divisible by 32
let offset = len % 32;
if offset != 0 {
sub_sse2(raw_row.get_unchecked(..offset), decoded_row.get_unchecked_mut(..offset));
last = i32::from_be_bytes([
*decoded_row.get_unchecked(offset - 1),
*decoded_row.get_unchecked(offset - 2),
*decoded_row.get_unchecked(offset - 3),
*decoded_row.get_unchecked(offset - 4),
]);
i = offset;
}
while len != i {
// load 32 bytes from input array
x = _mm256_loadu_si256(raw_row.get_unchecked(i) as *const _ as *const __m256i);
// do prefix sum
x = _mm256_add_epi8(_mm256_slli_si256::<4>(x), x);
x = _mm256_add_epi8(_mm256_slli_si256::<{ 2 * 4 }>(x), x);
// accumulate for first 16 bytes
let b = _mm256_extract_epi32::<3>(x);
x = _mm256_add_epi8(_mm256_set_epi32(b, b, b, b, 0, 0, 0, 0), x);
// accumulate for previous chunk of 16 bytes
x = _mm256_add_epi8(_mm256_set1_epi32(last), x);
// extract last 4 bytes to be used in next iteration
last = _mm256_extract_epi32::<7>(x);
// write 32 bytes to out array
_mm256_storeu_si256(decoded_row.get_unchecked_mut(i) as *mut _ as *mut __m256i, x);
i += 32;
}
}
We can reuse our SSE implementation for our remainder loop at the start. How does this perform?
RUSTFLAGS='-Ctarget-cpu=native' cargo bench
running 5 tests
test tests::bench_baseline_memcpy ... bench: 61,624 ns/iter (+/- 3,384)
test tests::bench_sub_avx ... bench: 82,168 ns/iter (+/- 3,754)
test tests::bench_sub_naive_scalar ... bench: 86,713 ns/iter (+/- 3,411)
test tests::bench_sub_no_bound_checks ... bench: 86,567 ns/iter (+/- 1,432)
test tests::bench_sub_sse2 ... bench: 86,422 ns/iter (+/- 3,934)
We get something that's a bit faster. It isn't too much above noise, but in my testing it does appear to run consistently ~5% faster. That's not nothing, but it's definitely a much smaller win than we'd expect from operating on 8x the number of bytes at a time. It's likely that although we're now able to operate on a larger number of bytes at a time, our wins are being consumed by the extra processing we're now doing to perform prefix sum.
This is roughly where I left things for about a year. I came back to this problem every once in a while after being inspired by blog posts, reading code, or learning about interesting applications of x86 SIMD intrinsics, but in general I wasn't able to improve on this problem too much.
The particularly slow part is _mm256_extract_epi32, especially the second call with an index of 7. For indices above 3, this intrinsic will compile down to multiple expensive instructions. If we remove this intrinsic, we approach the speed of memcpy. However, we can't really remove this intrinsic, since it's necessary for the accumulate step.
Last week, as part of a larger blog post investigating the performance of PNG decoders, I revisited this problem. For the sake of completeness, I was interested to see how a similar algorithm would perform if we used SSE registers instead.
The initial implementation looks like this:
pub unsafe fn sub_sse_prefix_sum(raw_row: &[u8], decoded_row: &mut [u8]) {
let mut last = 0;
let mut x: __m128i;
let len = raw_row.len();
let mut i = 0;
let offset = len % 16;
if offset != 0 {
sub_sse2(raw_row.get_unchecked(..offset), decoded_row.get_unchecked_mut(..offset));
last = i32::from_be_bytes([
*decoded_row.get_unchecked(offset - 1),
*decoded_row.get_unchecked(offset - 2),
*decoded_row.get_unchecked(offset - 3),
*decoded_row.get_unchecked(offset - 4),
]);
i = offset;
}
while len != i {
// load 16 bytes from array
x = _mm_loadu_si128(raw_row.get_unchecked(i) as *const _ as *const __m128i);
// do prefix sum
x = _mm_add_epi8(_mm_slli_si128::<4>(x), x);
x = _mm_add_epi8(_mm_slli_si128::<{ 2 * 4 }>(x), x);
// accumulate for previous chunk of 16 bytes
x = _mm_add_epi8(x, _mm_set1_epi32(last));
last = _mm_extract_epi32::<3>(x);
// write 16 bytes to out array
_mm_storeu_si128(decoded_row.get_unchecked_mut(i) as *mut _ as *mut __m128i, x);
i += 16;
}
}
This is pretty much the same as our AVX implementation, except now we operate on 16 bytes at a time. Let's see how this performs:
running 6 tests
test tests::bench_baseline_memcpy ... bench: 61,482 ns/iter (+/- 10,144)
test tests::bench_sub_avx ... bench: 82,137 ns/iter (+/- 623)
test tests::bench_sub_naive_scalar ... bench: 86,471 ns/iter (+/- 2,308)
test tests::bench_sub_no_bound_checks ... bench: 86,571 ns/iter (+/- 9,310)
test tests::bench_sub_sse2 ... bench: 86,146 ns/iter (+/- 2,569)
test tests::bench_sub_sse_prefix_sum ... bench: 113,024 ns/iter (+/- 1,683)
It's quite a bit slower. I guess that's to be expected. The AVX implementation operates on 2x the number of bytes and only gets a 5% speedup. If we only use SSE registers, we don't see any gains.
But something interesting about using SSE registers is that we can actually avoid the extract by doing a byte shift and a broadcast. Let's look at an implementation using this,
pub unsafe fn sub_sse_prefix_sum_no_extract(raw_row: &[u8], decoded_row: &mut [u8]) {
let mut last = _mm_setzero_si128();
let mut x: __m128i;
let len = raw_row.len();
let mut i = 0;
let offset = len % 16;
if offset != 0 {
sub_sse2(raw_row.get_unchecked(..offset), decoded_row.get_unchecked_mut(..offset));
last = _mm_castps_si128(_mm_broadcast_ss(&*(decoded_row.get_unchecked(offset - 4) as *const _ as *const f32)));
i = offset;
}
while len != i {
// load 16 bytes from array
x = _mm_loadu_si128(raw_row.get_unchecked(i) as *const _ as *const __m128i);
// do prefix sum
x = _mm_add_epi8(_mm_slli_si128::<4>(x), x);
x = _mm_add_epi8(_mm_slli_si128::<{ 2 * 4 }>(x), x);
// accumulate for previous chunk of 16 bytes
x = _mm_add_epi8(x, last);
// shift right by 12 bytes and then broadcast the lower 4 bytes
// to the rest of the register
last = _mm_srli_si128::<12>(x);
last = _mm_broadcastd_epi32(last);
_mm_storeu_si128(decoded_row.get_unchecked_mut(i) as *mut _ as *mut __m128i, x);
i += 16;
}
}
Running the benchmarks:
running 7 tests
test tests::bench_baseline_memcpy ... bench: 62,683 ns/iter (+/- 18,025)
test tests::bench_sub_avx ... bench: 82,232 ns/iter (+/- 9,210)
test tests::bench_sub_naive_scalar ... bench: 86,680 ns/iter (+/- 2,137)
test tests::bench_sub_no_bound_checks ... bench: 86,756 ns/iter (+/- 1,310)
test tests::bench_sub_sse2 ... bench: 86,519 ns/iter (+/- 2,527)
test tests::bench_sub_sse_prefix_sum ... bench: 112,770 ns/iter (+/- 4,366)
test tests::bench_sub_sse_prefix_sum_no_extract ... bench: 69,864 ns/iter (+/- 1,696)
It's a lot faster! We're approaching the speed of memcpy. With this new algorithm, we can go ~25% faster than a naive approach operating on only 4 bytes at a time. It seems hard to improve on this further — at some point we'll be bound by memory. As a proof of concept for this algorithm, I think this works quite well. It may be possible to improve on this by making better use of AVX intrinsics, but for right now it's likely not worth the effort to optimize this further.
Impact of this research
The goal up to this point has largely been to demonstrate that this algorithm can improve the performance of PNG decoding. The work demonstrated here is a proof-of-concept and doesn't contain a production-ready implementation.
The sub filter is just a small part of PNG decoding. Although we managed to speed it up by ~25% for inputs of this size, this doesn't translate to a 25% improvement in PNG decoding. The exact improvement here is a bit hard to calculate as it depends heavily on the input — the dimensions of the PNG, DEFLATE's compression level, the distribution of filters, etc. In general I wouldn't expect this to be too large of a win in total decode time, but it may become meaningful if you're trying to write the fastest theoretical PNG decoder (this is foreshadowing).
I think it may be possible to apply similar ideas to the avg and paeth filters, but I haven't yet come up with a performant solution for them. One particularly painful issue with the avg filter is that the average is taken with 9 bits of precision, rather than 8, and then truncated — so it's not sufficient to just do a bitshift in order to divide by two. Future research here may involve the VPAVGB instruction.
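For reference, here's what decoding the avg filter looks like in scalar form. This is a minimal sketch with hypothetical names; the important part is that the sum is widened before the division so the ninth bit isn't lost:
// Minimal scalar sketch of decoding the average filter.
// `bpp` is bytes per pixel; the row above the first row is all zeros.
fn defilter_avg(raw_row: &[u8], prev_row: &[u8], decoded_row: &mut [u8], bpp: usize) {
    for i in 0..raw_row.len() {
        let left = if i >= bpp { decoded_row[i - bpp] } else { 0 };
        let up = prev_row[i];
        // the sum needs 9 bits, so widen to u16 before the floor division by 2
        let avg = ((left as u16 + up as u16) / 2) as u8;
        decoded_row[i] = raw_row[i].wrapping_add(avg);
    }
}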
I haven't yet investigated the paeth filter, so I'm not sure how difficult such a solution for this filter would be. At first glance it appears quite a bit more challenging than the other filters, but there may be a clever solution hiding somewhere.
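For completeness, here's the Paeth predictor itself as a scalar sketch (the function name is mine). Each byte is predicted from the left (a), above (b), and upper-left (c) neighbors, picking whichever is closest to a + b - c; it's this data-dependent selection that makes it awkward to vectorize.
// The Paeth predictor from the PNG spec: choose whichever of a (left),
// b (above), or c (upper-left) is closest to the initial estimate a + b - c.
fn paeth_predictor(a: u8, b: u8, c: u8) -> u8 {
    let p = a as i16 + b as i16 - c as i16;
    let pa = (p - a as i16).abs();
    let pb = (p - b as i16).abs();
    let pc = (p - c as i16).abs();
    if pa <= pb && pa <= pc {
        a
    } else if pb <= pc {
        b
    } else {
        c
    }
}
// Decoding is then decoded[i] = filtered[i].wrapping_add(paeth_predictor(left, up, upper_left)).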
[1] The intuition for this may be a bit challenging to come up with if you're not already familiar with prefix sum and its properties. I believe this particular idea came to me by just staring at an example execution for a bit.