Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set & get channel oscillator state #186

Merged
merged 8 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Install Hatch
uses: pypa/hatch@a3c83ab3d481fbc2dc91dd0088628817488dd1d5
with:
version: 1.11.1
version: 1.12.0
- name: Test
run: hatch test --all
- name: Lint
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.lazy.lua
/target
generated

Expand Down
40 changes: 34 additions & 6 deletions bosing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ class Channel:
@property
def align_level(self) -> int: ...
@property
def iq_matrix(self) -> np.ndarray | None: ...
def iq_matrix(self) -> npt.NDArray[np.float64] | None: ...
@property
def offset(self) -> np.ndarray | None: ...
def offset(self) -> npt.NDArray[np.float64] | None: ...
@property
def iir(self) -> np.ndarray | None: ...
def iir(self) -> npt.NDArray[np.float64] | None: ...
@property
def fir(self) -> np.ndarray | None: ...
def fir(self) -> npt.NDArray[np.float64] | None: ...
@property
def filter_offset(self) -> bool: ...
@property
Expand All @@ -51,7 +51,9 @@ class Alignment:
Center: ClassVar[Alignment]
Stretch: ClassVar[Alignment]
@staticmethod
def convert(obj: Literal["end", "start", "center", "stretch"] | Alignment) -> Alignment: ...
def convert(
obj: Literal["end", "start", "center", "stretch"] | Alignment,
) -> Alignment: ...

class Shape: ...

Expand Down Expand Up @@ -368,6 +370,21 @@ class Grid(Element):
@property
def columns(self) -> Sequence[GridLength]: ...

@final
class OscState:
def __new__(
cls,
base_freq: float,
delta_freq: float,
phase: float,
) -> Self: ...
base_freq: float
delta_freq: float
phase: float
def total_freq(self) -> float: ...
def phase_at(self, time: float) -> float: ...
def with_time_shift(self, time: float) -> Self: ...

def generate_waveforms(
channels: Mapping[str, Channel],
shapes: Mapping[str, Shape],
Expand All @@ -377,4 +394,15 @@ def generate_waveforms(
amp_tolerance: float = ...,
allow_oversize: bool = ...,
crosstalk: tuple[npt.ArrayLike, Sequence[str]] | None = ...,
) -> dict[str, np.ndarray]: ...
) -> dict[str, npt.NDArray[np.float64]]: ...
def generate_waveforms_with_states(
channels: Mapping[str, Channel],
shapes: Mapping[str, Shape],
schedule: Element,
*,
time_tolerance: float = ...,
amp_tolerance: float = ...,
allow_oversize: bool = ...,
crosstalk: tuple[npt.ArrayLike, Sequence[str]] | None = ...,
states: Mapping[str, OscState] | None = ...,
) -> tuple[dict[str, npt.NDArray[np.float64]], dict[str, OscState]]: ...
84 changes: 66 additions & 18 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,18 @@ pub(crate) enum Error {
NotEnoughDuration { required: Time, available: Time },
}

#[derive(Debug, Clone, Copy)]
pub(crate) struct OscState {
base_freq: Frequency,
delta_freq: Frequency,
phase: Phase,
}

type Result<T> = std::result::Result<T, Error>;

#[derive(Debug, Clone)]
struct Channel {
base_freq: Frequency,
delta_freq: Frequency,
phase: Phase,
osc: OscState,
pulses: PulseListBuilder,
}

Expand Down Expand Up @@ -78,17 +83,27 @@ impl Executor {
}
}

pub(crate) fn add_channel(&mut self, name: ChannelId, base_freq: Frequency) {
pub(crate) fn add_channel(&mut self, name: ChannelId, osc: OscState) {
self.channels.insert(
name,
Channel::new(base_freq, self.amp_tolerance, self.time_tolerance),
Channel {
osc,
pulses: PulseListBuilder::new(self.amp_tolerance, self.time_tolerance),
},
);
}

pub(crate) fn add_shape(&mut self, name: ShapeId, shape: Shape) {
self.shapes.insert(name, shape);
}

pub(crate) fn states(&self) -> HashMap<ChannelId, OscState> {
self.channels
.iter()
.map(|(n, b)| (n.clone(), b.osc))
.collect()
}

pub(crate) fn into_result(self) -> HashMap<ChannelId, PulseList> {
self.channels
.into_iter()
Expand Down Expand Up @@ -168,28 +183,28 @@ impl Executor {
fn execute_shift_phase(&mut self, variant: &ShiftPhase) -> Result<()> {
let delta_phase = variant.phase();
let channel = self.get_mut_channel(variant.channel_id())?;
channel.shift_phase(delta_phase);
channel.osc.shift_phase(delta_phase);
Ok(())
}

fn execute_set_phase(&mut self, variant: &SetPhase, time: Time) -> Result<()> {
let phase = variant.phase();
let channel = self.get_mut_channel(variant.channel_id())?;
channel.set_phase(phase, time);
channel.osc.set_phase(phase, time);
Ok(())
}

fn execute_shift_freq(&mut self, variant: &ShiftFreq, time: Time) -> Result<()> {
let delta_freq = variant.frequency();
let channel = self.get_mut_channel(variant.channel_id())?;
channel.shift_freq(delta_freq, time);
channel.osc.shift_freq(delta_freq, time);
Ok(())
}

fn execute_set_freq(&mut self, variant: &SetFreq, time: Time) -> Result<()> {
let freq = variant.frequency();
let channel = self.get_mut_channel(variant.channel_id())?;
channel.set_freq(freq, time);
channel.osc.set_freq(freq, time);
Ok(())
}

Expand All @@ -203,7 +218,7 @@ impl Executor {
.channels
.get_many_mut([ch1, ch2])
.ok_or(Error::ChannelNotFound(vec![ch1.clone(), ch2.clone()]))?;
channel.swap_phase(other, time);
channel.osc.swap_phase(&mut other.osc, time);
Ok(())
}

Expand All @@ -214,13 +229,28 @@ impl Executor {
}
}

impl Channel {
fn new(base_freq: Frequency, amp_tolerance: Amplitude, time_tolerance: Time) -> Self {
impl OscState {
pub(crate) fn new(base_freq: Frequency) -> Self {
Self {
base_freq,
delta_freq: Frequency::ZERO,
phase: Phase::ZERO,
pulses: PulseListBuilder::new(amp_tolerance, time_tolerance),
}
}

pub(crate) fn total_freq(&self) -> Frequency {
self.base_freq + self.delta_freq
}

pub(crate) fn phase_at(&self, time: Time) -> Phase {
self.phase + self.total_freq() * time
}

pub(crate) fn with_time_shift(&self, time: Time) -> Self {
Self {
base_freq: self.base_freq,
delta_freq: self.delta_freq,
phase: self.phase_at(time),
}
}

Expand All @@ -245,18 +275,16 @@ impl Channel {
self.phase = phase - self.delta_freq * time;
}

fn total_freq(&self) -> Frequency {
self.base_freq + self.delta_freq
}

fn swap_phase(&mut self, other: &mut Self, time: Time) {
let delta_freq = self.total_freq() - other.total_freq();
let phase1 = self.phase;
let phase2 = other.phase;
self.phase = phase2 - delta_freq * time;
other.phase = phase1 + delta_freq * time;
}
}

impl Channel {
fn add_pulse(
&mut self,
AddPulseArgs {
Expand All @@ -271,7 +299,7 @@ impl Channel {
}: AddPulseArgs,
) {
let envelope = Envelope::new(shape, width, plateau);
let global_freq = self.total_freq();
let global_freq = self.osc.total_freq();
let local_freq = freq;
self.pulses.push(PushArgs {
envelope,
Expand All @@ -285,6 +313,26 @@ impl Channel {
}
}

impl From<crate::OscState> for OscState {
fn from(osc: crate::OscState) -> Self {
Self {
base_freq: osc.base_freq,
delta_freq: osc.delta_freq,
phase: osc.phase,
}
}
}

impl From<OscState> for crate::OscState {
fn from(osc: OscState) -> Self {
Self {
base_freq: osc.base_freq,
delta_freq: osc.delta_freq,
phase: osc.phase,
}
}
}

impl<S, A, G, R, T> Iterator for IterVariant<S, A, G, R>
where
S: Iterator<Item = T>,
Expand Down
Loading