Skip to content

Commit

Permalink
Remove some branches in tight PPU loops
Browse files Browse the repository at this point in the history
  • Loading branch information
aelred committed Oct 6, 2024
1 parent eac4474 commit 3f5f7ae
Showing 1 changed file with 161 additions and 81 deletions.
242 changes: 161 additions & 81 deletions src/ppu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct PPU<M = NESPPUMemory> {
cycle_count: u16,
tile_pattern: ShiftRegister,
palette_select: ShiftRegister,
active_sprites: [Sprite; ACTIVE_SPRITES],
active_sprites: [ActiveSprite; ACTIVE_SPRITES],
active_sprites_has_zero: bool,
control: Control,
status: Status,
Expand All @@ -61,7 +61,7 @@ impl<M: Memory> PPU<M> {
cycle_count: 0,
tile_pattern: ShiftRegister::default(),
palette_select: ShiftRegister::default(),
active_sprites: [Sprite::default(); ACTIVE_SPRITES],
active_sprites: [ActiveSprite::default(); ACTIVE_SPRITES],
active_sprites_has_zero: false,
control: Control::default(),
mask: Mask::default(),
Expand Down Expand Up @@ -128,6 +128,7 @@ impl<M: Memory> PPU<M> {
}

let sprite_size = self.control.sprite_size();
let table = self.control.sprite_pattern_table();

let all_sprites = self.object_attribute_memory.chunks_exact(4).map(|chunk| {
let attributes = SpriteAttributes::from_bits_truncate(chunk[2]);
Expand All @@ -137,108 +138,128 @@ impl<M: Memory> PPU<M> {
let scanline = self.scanline - 1;

let sprites_on_scanline = all_sprites.enumerate().filter(|(_, sprite)| {
let y = u16::from(sprite.y);
let y = sprite.y as u16;
scanline >= y && scanline < y + sprite_size.height() as u16
});

self.active_sprites = [Sprite::default(); ACTIVE_SPRITES];
self.active_sprites = [ActiveSprite::default(); ACTIVE_SPRITES];
self.active_sprites_has_zero = false;

for (dest, (i, src)) in self.active_sprites.iter_mut().zip(sprites_on_scanline) {
self.active_sprites_has_zero |= i == 0;
*dest = src;
*dest = ActiveSprite {
sprite: src,
..Default::default()
};
}

for i in 0..ACTIVE_SPRITES {
let sprite = self.active_sprites[i].sprite;
let attr = sprite.attributes;

// Use wrapping_sub and % to handle default zero'd sprites at y = 0 without branching
let y_in_sprite = scanline.wrapping_sub(sprite.y as u16) as u8 % sprite_size.height();
let y_in_sprite = attr.ver_flip(y_in_sprite, sprite_size);

let (sprite_table, index) = match sprite_size {
SpriteSize::_8x8 => (table, sprite.tile_index),
SpriteSize::_8x16 => (
PatternTable::from((sprite.tile_index & 0b1) == 1),
sprite.tile_index & 0b1111_1110,
),
};

let (pattern0, pattern1) = self.read_pattern_row(sprite_table, index, y_in_sprite);

self.active_sprites[i].pattern0 = pattern0;
self.active_sprites[i].pattern1 = pattern1;
}
}

fn next_color(&mut self) -> Option<Color> {
fn next_color(&mut self) -> Color {
let sprite = self.sprite_color();
let (background, background_opaque) = self.background_color();

let color = match sprite {
Some((sprite, priority, _)) if priority => sprite,
_ if background_opaque => background,
_ => sprite.map(|(sprite, _, _)| sprite).unwrap_or(background),
let color_address = if sprite.visible && sprite.priority {
sprite.color_address
} else if background_opaque {
background
} else if sprite.visible {
sprite.color_address
} else {
background
};

if let Some((_, _, index)) = sprite {
if self.active_sprites_has_zero && index == 0 && background_opaque {
self.status |= Status::SPRITE_ZERO_HIT;
}
if self.active_sprites_has_zero && sprite.index == 0 && background_opaque {
self.status |= Status::SPRITE_ZERO_HIT;
}

Some(color)
Color(self.memory.read(color_address))
}

fn background_color(&mut self) -> (Color, bool) {
let lower_bits = self.tile_pattern.get_bits(self.fine_x);
let higher_bits = self.palette_select.get_bits(self.fine_x);
fn background_color(&self) -> (Address, bool) {
let lower_bits = self.tile_pattern.get_bits(self.fine_x);
let higher_bits = self.palette_select.get_bits(self.fine_x);

let color_index = (lower_bits | (higher_bits << 2)) as u16;

let show_background = self.mask.contains(Mask::SHOW_BACKGROUND);
let opaque = show_background && lower_bits != 0;

// Use universal background colour when transparent
let address = BACKGROUND_PALETTES + color_index * opaque as u16;
let color_address = BACKGROUND_PALETTES + color_index * opaque as u16;

(Color(self.memory.read(address)), opaque)
(color_address, opaque)
}

fn sprite_color(&mut self) -> Option<(Color, bool, usize)> {
if !self.mask.contains(Mask::SHOW_SPRITES) || self.scanline == 0 {
return None;
}
fn sprite_color(&self) -> SelectedSprite {
let show_sprites = self.mask.contains(Mask::SHOW_SPRITES) && self.scanline > 0;

let cycle_count = self.cycle_count;
let scanline = self.scanline - 1;
// Bitflags for which sprites should be shown, to avoid branches
let mut show: u8 = 0b0000_0000;
// All 8 sprites, plus a 9th sprite that will always be 'None'
let mut results: [(SpriteAttributes, u8); 9] = Default::default();

let sprites = self.active_sprites;
for (index, active_sprite) in self.active_sprites.iter().enumerate() {
let x = active_sprite.sprite.x as u16;
let attr = active_sprite.sprite.attributes;

let sprite_size = self.control.sprite_size();
let table = self.control.sprite_pattern_table();
// Use wrapping_sub and % 8 to keep the column in 0..8 for every sprite (including the default sprite at x = 0) without branching
let x_in_sprite = attr.hor_flip(self.cycle_count.wrapping_sub(x) as u8 % 8);

for (i, sprite) in sprites.iter().enumerate() {
let x = u16::from(sprite.x);
let attr = sprite.attributes;
let bit0 = (active_sprite.pattern0 >> x_in_sprite) & 0b1;
let bit1 = (active_sprite.pattern1 >> x_in_sprite) & 0b1;

if cycle_count < x || cycle_count >= x + 8 || scanline < sprite.y as u16 {
continue;
}
let lower_index = (bit1 << 1) | bit0;

let x_in_sprite = attr.hor_flip((cycle_count - x) as u8);
let y_in_sprite = attr.ver_flip((scanline - sprite.y as u16) as u8, sprite_size);
let transparent = lower_index == 0;

let (sprite_table, index) = match sprite_size {
SpriteSize::_8x8 => (table, sprite.tile_index),
SpriteSize::_8x16 => (
PatternTable::from((sprite.tile_index & 0b1) == 1),
sprite.tile_index & 0b1111_1110,
),
};
let show_sprite = show_sprites
&& !transparent
&& self.cycle_count >= x
&& self.cycle_count < x + 8
&& self.scanline > active_sprite.sprite.y as u16;

let (pattern0, pattern1) = self.read_pattern_row(sprite_table, index, y_in_sprite);
show |= (show_sprite as u8) << index;

let bit0 = (pattern0 >> x_in_sprite) & 0b1;
let bit1 = (pattern1 >> x_in_sprite) & 0b1;
results[index] = (attr, lower_index);
}

let lower_index = (bit1 << 1) | bit0;
// Find the highest-priority sprite that should be shown using trailing zeros in the bit flag
let index = show.trailing_zeros() as usize;
let (attr, lower_index) = results[index];

let transparent = lower_index == 0;
if transparent {
continue;
}
let palette = (attr & SpriteAttributes::PALETTE).bits();
let color_index = (palette << 2) | lower_index;
let color_address = SPRITE_PALETTES + color_index.into();
let priority = !attr.contains(SpriteAttributes::PRIORITY);

let palette = (attr & SpriteAttributes::PALETTE).bits();
let color_index = (palette << 2) | lower_index;
let address = SPRITE_PALETTES + color_index.into();
return Some((
Color(self.memory.read(address)),
!attr.contains(SpriteAttributes::PRIORITY),
i,
));
SelectedSprite {
visible: index < ACTIVE_SPRITES,
color_address,
priority,
index,
}

None
}

fn read_pattern_row(
Expand Down Expand Up @@ -333,7 +354,7 @@ impl<M: Memory> PPU<M> {
self.increment_coarse_x();
}

let color = if in_bounds { self.next_color() } else { None };
let color = in_bounds.then(|| self.next_color());

// Don't shift registers in the last 4 bits, or everything goes out of alignment.
// Oddly, the cycle count in a scanline isn't divisible by 8.
Expand Down Expand Up @@ -416,6 +437,21 @@ impl ShiftRegister {
}
}

/// Result of choosing the sprite pixel for the current dot, as returned by `sprite_color`.
#[derive(Copy, Clone)]
struct SelectedSprite {
/// `true` when one of the active sprites was chosen; `sprite_color` sets this to `index < ACTIVE_SPRITES`.
visible: bool,
/// Palette address of the chosen pixel, computed as `SPRITE_PALETTES + color_index`.
color_address: Address,
/// `true` when the sprite's `SpriteAttributes::PRIORITY` flag is clear; `next_color` then picks a visible sprite ahead of the background.
priority: bool,
/// Slot of the chosen sprite in `active_sprites`, taken from the trailing zeros of the show bit mask (8 when no sprite is shown).
index: usize,
}

/// A sprite selected for the current scanline, stored together with its pre-fetched pattern row.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
struct ActiveSprite {
/// The sprite entry that `load_sprites` copied out of object attribute memory.
sprite: Sprite,
/// First value returned by `read_pattern_row` for this sprite's row; supplies bit 0 of the pixel's colour index.
pattern0: u8,
/// Second value returned by `read_pattern_row` for this sprite's row; supplies bit 1 of the pixel's colour index.
pattern1: u8,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct Sprite {
x: u8,
Expand Down Expand Up @@ -454,7 +490,7 @@ bitflags! {
impl SpriteAttributes {
/// Bit-twiddling to avoid a conditional, same as `x = if hor_flip { x } else { 7 - x }`
fn hor_flip(self, x: u8) -> u8 {
x ^ ((((self & Self::HORIZONTAL_FLIP) ^ Self::HORIZONTAL_FLIP).bits() >> 6) * 0b0000_0111)
x ^ (((!self & Self::HORIZONTAL_FLIP).bits() >> 6) * 0b0000_0111)
}

/// Bit-twiddling to avoid a conditional, same as `y = if ver_flip { height - 1 - y } else { y }`
Expand Down Expand Up @@ -1073,14 +1109,26 @@ mod tests {
ppu.load_sprites();

let expected = [
Sprite::new(1, 22, 1, SpriteAttributes::from_bits_truncate(1)),
Sprite::new(2, 23, 2, SpriteAttributes::from_bits_truncate(2)),
Sprite::new(3, 29, 3, SpriteAttributes::from_bits_truncate(3)),
Sprite::default(),
Sprite::default(),
Sprite::default(),
Sprite::default(),
Sprite::default(),
ActiveSprite {
sprite: Sprite::new(1, 22, 1, SpriteAttributes::from_bits_truncate(1)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(2, 23, 2, SpriteAttributes::from_bits_truncate(2)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(3, 29, 3, SpriteAttributes::from_bits_truncate(3)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite::default(),
ActiveSprite::default(),
ActiveSprite::default(),
ActiveSprite::default(),
ActiveSprite::default(),
];

assert_eq!(ppu.active_sprites, expected);
Expand All @@ -1103,14 +1151,46 @@ mod tests {
ppu.load_sprites();

let expected = [
Sprite::new(0, 23, 0, SpriteAttributes::from_bits_truncate(0)),
Sprite::new(1, 23, 1, SpriteAttributes::from_bits_truncate(1)),
Sprite::new(2, 24, 2, SpriteAttributes::from_bits_truncate(2)),
Sprite::new(3, 24, 3, SpriteAttributes::from_bits_truncate(3)),
Sprite::new(4, 25, 4, SpriteAttributes::from_bits_truncate(4)),
Sprite::new(5, 25, 5, SpriteAttributes::from_bits_truncate(5)),
Sprite::new(6, 26, 6, SpriteAttributes::from_bits_truncate(6)),
Sprite::new(7, 26, 7, SpriteAttributes::from_bits_truncate(7)),
ActiveSprite {
sprite: Sprite::new(0, 23, 0, SpriteAttributes::from_bits_truncate(0)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(1, 23, 1, SpriteAttributes::from_bits_truncate(1)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(2, 24, 2, SpriteAttributes::from_bits_truncate(2)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(3, 24, 3, SpriteAttributes::from_bits_truncate(3)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(4, 25, 4, SpriteAttributes::from_bits_truncate(4)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(5, 25, 5, SpriteAttributes::from_bits_truncate(5)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(6, 26, 6, SpriteAttributes::from_bits_truncate(6)),
pattern0: 0,
pattern1: 0,
},
ActiveSprite {
sprite: Sprite::new(7, 26, 7, SpriteAttributes::from_bits_truncate(7)),
pattern0: 0,
pattern1: 0,
},
];

assert_eq!(ppu.active_sprites, expected);
Expand All @@ -1126,7 +1206,7 @@ mod tests {
ppu.scanline = 30;
ppu.load_sprites();

let cleared = [Sprite::default(); 8];
let cleared = [ActiveSprite::default(); 8];

assert_ne!(ppu.active_sprites, cleared);

Expand Down

0 comments on commit 3f5f7ae

Please sign in to comment.