Skip to content

Commit

Permalink
Change Vn from Option<V> to Option<&V>.
Browse files Browse the repository at this point in the history
Change Fn,R1, & R2 to struct tuple types.
impl Debug and PartialEq for fn types manually.
    see rust-lang/rust#45048
Adjust types to work with now-borrowed values.
Reduce clones when available.

In general, by borrowing the arguments for all builtin
functions, this lets us only clone when creating a new
value. This ends up being significant because previously
during "table" for example, each "call" applied to the
Array values cloned the value when calling the builtin
fn. Now it will only clone the value when returning a
Vs to the stack.
  • Loading branch information
cannadayr committed Jan 15, 2022
1 parent 9127bfb commit cd0a42e
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 83 deletions.
2 changes: 1 addition & 1 deletion rs_src/bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ fn main() {
let src = new_string("{×´1+↕𝕩}");
let prog = prog(compiler,src,runtime);
info!("func loaded");
let result = call(1,Some(run(prog)),Some(V::Scalar(10.0)),None);
let result = call(1,Some(&run(prog)),Some(&V::Scalar(10.0)),None);
info!("result = {}",&result);
}
22 changes: 11 additions & 11 deletions rs_src/ebqn.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::schema::{Env,V,Vs,Vr,Vn,Block,BlockInst,Code,Calleable,Body,A,Ar,Tr2,Tr3,Runtime,Compiler,Prog,set,ok,D2,D1,new_scalar,new_string};
use crate::schema::{Env,V,Vs,Vr,Vn,Block,BlockInst,Code,Calleable,Body,A,Ar,Tr2,Tr3,Runtime,Compiler,Prog,set,ok,D2,D1,Fn,new_scalar,new_string};
use crate::prim::{provide,decompose,prim_ind};
use crate::code::{r0,r1,c};
use crate::fmt::{dbg_stack_out,dbg_stack_in};
Expand Down Expand Up @@ -139,7 +139,7 @@ pub fn vm(env: &Env,code: &Cc<Code>,mut pos: usize,mut stack: Vec<Vs>) -> Vs {
let r =
match &x.as_v().unwrap() {
V::Nothing => x,
_ => call(1,Some(f.into_v().unwrap()),Some(x.into_v().unwrap()),None),
_ => call(1,Some(&f.into_v().unwrap()),Some(&x.into_v().unwrap()),None),
};
stack.push(r);
dbg_stack_out("FN1C",pos-1,&stack);
Expand All @@ -152,8 +152,8 @@ pub fn vm(env: &Env,code: &Cc<Code>,mut pos: usize,mut stack: Vec<Vs>) -> Vs {
let r =
match (&x.as_v().unwrap(),&w.as_v().unwrap()) {
(V::Nothing,_) => x,
(_,V::Nothing) => call(1,Some(f.into_v().unwrap()),Some(x.into_v().unwrap()),None),
_ => call(2,Some(f.into_v().unwrap()),Some(x.into_v().unwrap()),Some(w.into_v().unwrap()))
(_,V::Nothing) => call(1,Some(&f.into_v().unwrap()),Some(&x.into_v().unwrap()),None),
_ => call(2,Some(&f.into_v().unwrap()),Some(&x.into_v().unwrap()),Some(&w.into_v().unwrap()))
};
stack.push(r);
dbg_stack_out("FN2C",pos-1,&stack);
Expand Down Expand Up @@ -233,7 +233,7 @@ pub fn vm(env: &Env,code: &Cc<Code>,mut pos: usize,mut stack: Vec<Vs>) -> Vs {
let i = stack.pop().unwrap();
let f = stack.pop().unwrap();
let x = stack.pop().unwrap();
let v = call(2,Some(f.into_v().unwrap()),Some(x.into_v().unwrap()),Some(i.get()));
let v = call(2,Some(&f.into_v().unwrap()),Some(&x.into_v().unwrap()),Some(&i.get()));
let r = set(false,i,v);
stack.push(Vs::V(r));
dbg_stack_out("SETM",pos-1,&stack);
Expand All @@ -242,7 +242,7 @@ pub fn vm(env: &Env,code: &Cc<Code>,mut pos: usize,mut stack: Vec<Vs>) -> Vs {
dbg_stack_in("SETC",pos-1,"".to_string(),&stack);
let i = stack.pop().unwrap();
let f = stack.pop().unwrap();
let v = call(1,Some(f.into_v().unwrap()),Some(i.get()),None);
let v = call(1,Some(&f.into_v().unwrap()),Some(&i.get()),None);
let r = set(false,i,v);
stack.push(Vs::V(r));
dbg_stack_out("SETC",pos-1,&stack);
Expand Down Expand Up @@ -286,16 +286,16 @@ pub fn runtime() -> Cc<A> {
}
}
info!("runtime loaded");
let prim_fns = V::A(Cc::new(A::new(vec![V::Fn(decompose,None),V::Fn(prim_ind,None)],vec![2])));
let _ = call(1,Some(set_prims),Some(prim_fns),None);
let prim_fns = V::A(Cc::new(A::new(vec![V::Fn(Fn(decompose),None),V::Fn(Fn(prim_ind),None)],vec![2])));
let _ = call(1,Some(&set_prims),Some(&prim_fns),None);
prims
},
None => panic!("cant get mutable runtime"),
}
}

pub fn prog(compiler: V,src: V,runtime: Cc<A>) -> Cc<Code> {
let mut prog = call(2,Some(compiler),Some(src),Some(V::A(runtime))).into_v().unwrap().into_a().unwrap();
let mut prog = call(2,Some(&compiler),Some(&src),Some(&V::A(runtime))).into_v().unwrap().into_a().unwrap();
info!("prog count = {}",prog.strong_count());
match prog.get_mut() {
Some(p) => {
Expand Down Expand Up @@ -351,7 +351,7 @@ pub fn prog(compiler: V,src: V,runtime: Cc<A>) -> Cc<Code> {
b.r.iter().map(|e| match e.as_a().unwrap().r.iter().collect_tuple() {
Some((V::Scalar(pos),V::Scalar(local),_name_id,_export_mask)) =>
(usize::from_f64(*pos).unwrap(),usize::from_f64(*local).unwrap()),
x => panic!("couldn't load compiled body {:?}",x),
_x => panic!("couldn't load compiled body"),
}).collect::<Vec<(usize,usize)>>()
},
Err(_b) => panic!("cant get unique ref to program blocks"),
Expand Down Expand Up @@ -390,6 +390,6 @@ fn init_c(r: ResourceArc<Runtime>) -> NifResult<(Atom,ResourceArc<Compiler>)> {
//}
#[rustler::nif]
fn callp(p: ResourceArc<Prog>,n: f64) -> NifResult<(Atom,V)> {
let result = call(1,Some(run(p.0.clone())),Some(V::Scalar(n)),None);
let result = call(1,Some(&run(p.0.clone())),Some(&V::Scalar(n)),None);
Ok((ok(),result.into_v().unwrap()))
}
19 changes: 17 additions & 2 deletions rs_src/fmt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::fmt::{Display,Formatter,Result};
use crate::schema::{V,Vs};
use std::fmt::{Debug,Display,Formatter,Result};
use crate::schema::{V,Vs,Fn,R1,R2};
use log::{debug, trace, error, log_enabled, info, Level};

pub fn fmt_stack(stack: &Vec<Vs>) -> String {
Expand Down Expand Up @@ -48,3 +48,18 @@ impl Display for Vs {
}
}

impl Debug for Fn {
fn fmt(&self, f: &mut Formatter) -> Result {
write!(f, "{:?}", self)
}
}
impl Debug for R1 {
fn fmt(&self, f: &mut Formatter) -> Result {
write!(f, "{:?}", self)
}
}
impl Debug for R2 {
fn fmt(&self, f: &mut Formatter) -> Result {
write!(f, "{:?}", self)
}
}
114 changes: 57 additions & 57 deletions rs_src/prim.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::schema::{A,V,Vn,Vs,Decoder,D1,D2,Tr2,Tr3};
use crate::schema::{A,V,Vn,Vs,Decoder,D1,D2,Tr2,Tr3,Fn,R1,R2};
use crate::ebqn::{call};
use cc_mt::Cc;
use std::cmp::max;
Expand Down Expand Up @@ -50,7 +50,7 @@ fn typ(arity: usize, x: Vn, _w: Vn) -> Vs {
fn fill(arity: usize, x: Vn, _w: Vn) -> Vs {
match arity {
1 => Vs::V(V::Scalar(0.0)),
2 => Vs::V(x.unwrap()),
2 => Vs::V(x.unwrap().clone()),
_ => panic!("illegal fill arity"),
}
}
Expand Down Expand Up @@ -162,12 +162,12 @@ pub fn plus(arity:usize, x: Vn,w: Vn) -> Vs {
//dbg_args("plus",arity,&x,&w);
let r =
match arity {
1 => Vs::V(x.unwrap()),
1 => Vs::V(x.unwrap().clone()),
2 => match (x.unwrap(),w.unwrap()) {
(V::Char(xc),V::Scalar(ws)) if ws >= 0.0 => Vs::V(V::Char(char::from_u32(u32::from(xc) + u32::from_f64(ws).unwrap()).unwrap())),
(V::Scalar(xs),V::Char(wc)) if xs >= 0.0 => Vs::V(V::Char(char::from_u32(u32::from(wc) + u32::from_f64(xs).unwrap()).unwrap())),
(V::Char(xc),V::Scalar(ws)) if ws < 0.0 => Vs::V(V::Char(char::from_u32(u32::from(xc) - u32::from_f64(ws.abs()).unwrap()).unwrap())),
(V::Scalar(xs),V::Char(wc)) if xs < 0.0 => Vs::V(V::Char(char::from_u32(u32::from(wc) - u32::from_f64(xs.abs()).unwrap()).unwrap())),
(V::Char(xc),V::Scalar(ws)) if *ws >= 0.0 => Vs::V(V::Char(char::from_u32(u32::from(*xc) + u32::from_f64(*ws).unwrap()).unwrap())),
(V::Scalar(xs),V::Char(wc)) if *xs >= 0.0 => Vs::V(V::Char(char::from_u32(u32::from(*wc) + u32::from_f64(*xs).unwrap()).unwrap())),
(V::Char(xc),V::Scalar(ws)) if *ws < 0.0 => Vs::V(V::Char(char::from_u32(u32::from(*xc) - u32::from_f64(ws.abs()).unwrap()).unwrap())),
(V::Scalar(xs),V::Char(wc)) if *xs < 0.0 => Vs::V(V::Char(char::from_u32(u32::from(*wc) - u32::from_f64(xs.abs()).unwrap()).unwrap())),
(V::Scalar(xs),V::Scalar(ws)) => Vs::V(V::Scalar(xs + ws)),
_ => panic!("dyadic plus pattern not found"),
},
Expand All @@ -186,9 +186,9 @@ fn minus(arity: usize, x: Vn, w: Vn) -> Vs {
_ => panic!("monadic minus expected number"),
},
2 => match (x.unwrap(),w.unwrap()) {
(V::Scalar(xs),V::Char(wc)) => Vs::V(V::Char(char::from_u32(u32::from(wc) - u32::from_f64(xs).unwrap()).unwrap())),
(V::Char(xc),V::Char(wc)) if u32::from(xc) > u32::from(wc) => Vs::V(V::Scalar(-1.0*f64::from(u32::from(xc) - u32::from(wc)))),
(V::Char(xc),V::Char(wc)) => Vs::V(V::Scalar(f64::from(u32::from(wc) - u32::from(xc)))),
(V::Scalar(xs),V::Char(wc)) => Vs::V(V::Char(char::from_u32(u32::from(*wc) - u32::from_f64(*xs).unwrap()).unwrap())),
(V::Char(xc),V::Char(wc)) if u32::from(*xc) > u32::from(*wc) => Vs::V(V::Scalar(-1.0*f64::from(u32::from(*xc) - u32::from(*wc)))),
(V::Char(xc),V::Char(wc)) => Vs::V(V::Scalar(f64::from(u32::from(*wc) - u32::from(*xc)))),
(V::Scalar(xs),V::Scalar(ws)) => Vs::V(V::Scalar(ws - xs)),
_ => panic!("dyadic minus pattern not found"),
},
Expand Down Expand Up @@ -231,7 +231,7 @@ fn power(arity: usize, x: Vn, w: Vn) -> Vs {
_ => panic!("monadic power expected number"),
},
2 => match (x.unwrap(),w.unwrap()) {
(V::Scalar(xs),V::Scalar(ws)) => Vs::V(V::Scalar(ws.powf(xs))),
(V::Scalar(xs),V::Scalar(ws)) => Vs::V(V::Scalar(ws.powf(*xs))),
_ => panic!("dyadic power expected numbers"),
},
_ => panic!("illegal power arity"),
Expand Down Expand Up @@ -333,8 +333,8 @@ fn pick(arity: usize, x: Vn, w: Vn) -> Vs {
match arity {
2 => {
match (x.unwrap(),w.unwrap()) {
(V::A(a),V::Scalar(i)) if i >= 0.0 => Vs::V(a.r[i as i64 as usize].clone()),
(V::A(a),V::Scalar(i)) if i < 0.0 => Vs::V(a.r[((a.r.len() as f64) + i) as i64 as usize].clone()),
(V::A(a),V::Scalar(i)) if *i >= 0.0 => Vs::V(a.r[*i as i64 as usize].clone()),
(V::A(a),V::Scalar(i)) if *i < 0.0 => Vs::V(a.r[((a.r.len() as f64) + i) as i64 as usize].clone()),
_ => panic!("pick - can't index into non array"),
}
},
Expand All @@ -347,7 +347,7 @@ fn pick(arity: usize, x: Vn, w: Vn) -> Vs {
fn windows(arity: usize, x: Vn, _w: Vn) -> Vs {
match arity {
1 => match x.unwrap() {
V::Scalar(n) => Vs::V(V::A(Cc::new(A::new((0..n as i64).map(|v| V::Scalar(v as f64)).collect::<Vec<V>>(),vec![n as usize])))),
V::Scalar(n) => Vs::V(V::A(Cc::new(A::new((0..*n as i64).map(|v| V::Scalar(v as f64)).collect::<Vec<V>>(),vec![*n as usize])))),
_ => panic!("x is not a number"),
},
_ => panic!("illegal windows arity"),
Expand All @@ -359,7 +359,7 @@ fn table(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
match arity {
1 => match x.unwrap() {
V::A(xa) => {
let ravel = (*xa).r.iter().map(|e| call(arity,f.clone(),Some(e.clone()),None).into_v().unwrap() ).collect::<Vec<V>>();
let ravel = (*xa).r.iter().map(|e| call(arity,f,Some(e),None).into_v().unwrap() ).collect::<Vec<V>>();
let sh = (*xa).sh.clone();
Vs::V(V::A(Cc::new(A::new(ravel,sh))))
},
Expand All @@ -369,7 +369,7 @@ fn table(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
match (x.unwrap(),w.unwrap()) {
(V::A(xa),V::A(wa)) => {
let ravel = (*wa).r.iter().flat_map(|d| {
(*xa).r.iter().map(|e| call(arity,f.clone(),Some(e.clone()),Some(d.clone())).into_v().unwrap() ).collect::<Vec<V>>()
(*xa).r.iter().map(|e| call(arity,f,Some(e),Some(d)).into_v().unwrap() ).collect::<Vec<V>>()
}).collect::<Vec<V>>();
let sh = (*wa).sh.clone().into_iter().chain((*xa).sh.clone().into_iter()).collect();
Vs::V(V::A(Cc::new(A::new(ravel,sh))))
Expand Down Expand Up @@ -405,7 +405,7 @@ fn scan(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
i += 1;
}
while i < l {
r[i] = call(2,f.clone(),Some(a.r[i].clone()),Some(r[i-c].clone())).as_v().unwrap().clone();
r[i] = call(2,f,Some(&a.r[i]),Some(&r[i-c])).as_v().unwrap().clone();
i += 1;
}
};
Expand All @@ -416,9 +416,9 @@ fn scan(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
},
2 => {
let (wr,wa) = match w.unwrap() {
V::A(wa) => (wa.sh.len(),wa),
V::A(wa) => (wa.sh.len(),wa.clone()),
// TODO `wa` doesn't actually need to be a ref counted array
V::Scalar(ws) => (0,Cc::new(A::new(vec![V::Scalar(ws)],vec![1]))),
V::Scalar(ws) => (0,Cc::new(A::new(vec![V::Scalar(*ws)],vec![1]))),
_ => panic!("dyadic scan w is invalid type"),
};
match x.unwrap() {
Expand All @@ -442,11 +442,11 @@ fn scan(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
}
i = 0;
while i < c {
r[i] = call(2,f.clone(),Some(xa.r[i].clone()),Some(wa.r[i].clone())).as_v().unwrap().clone();
r[i] = call(2,f.clone(),Some(&xa.r[i]),Some(&wa.r[i])).as_v().unwrap().clone();
i += 1;
}
while i < l {
r[i] = call(2,f.clone(),Some(xa.r[i].clone()),Some(r[i-c].clone())).as_v().unwrap().clone();
r[i] = call(2,f.clone(),Some(&xa.r[i]),Some(&r[i-c])).as_v().unwrap().clone();
i += 1;
}
};
Expand Down Expand Up @@ -489,7 +489,7 @@ pub fn decompose(arity:usize, x: Vn,_w: Vn) -> Vs {
_ => false
}
{
Vs::V(V::A(Cc::new(A::new(vec![V::Scalar(-1.0),(&x).as_ref().unwrap().clone()],vec![2]))))
Vs::V(V::A(Cc::new(A::new(vec![V::Scalar(-1.0),x.unwrap().clone()],vec![2]))))
}
else if // primitives
match (&x).as_ref().unwrap() {
Expand All @@ -506,7 +506,7 @@ pub fn decompose(arity:usize, x: Vn,_w: Vn) -> Vs {
_ => false,
}
{
Vs::V(V::A(Cc::new(A::new(vec![V::Scalar(0.0),(&x).as_ref().unwrap().clone()],vec![2]))))
Vs::V(V::A(Cc::new(A::new(vec![V::Scalar(0.0),x.unwrap().clone()],vec![2]))))
}
else if // repr
match (&x).as_ref().unwrap() {
Expand Down Expand Up @@ -557,7 +557,7 @@ pub fn decompose(arity:usize, x: Vn,_w: Vn) -> Vs {
let Tr3(f,g,h) = (*tr3).deref();
Vs::V(V::A(Cc::new(A::new(vec![V::Scalar(3.0),f.clone(),g.clone(),h.clone()],vec![4]))))
},
_ => Vs::V(V::A(Cc::new(A::new(vec![V::Scalar(1.0),(&x).as_ref().unwrap().clone()],vec![2])))),
_ => Vs::V(V::A(Cc::new(A::new(vec![V::Scalar(1.0),x.unwrap().clone()],vec![2])))),
}
}
},
Expand All @@ -570,45 +570,45 @@ pub fn decompose(arity:usize, x: Vn,_w: Vn) -> Vs {
pub fn prim_ind(arity:usize, x: Vn,_w: Vn) -> Vs {
match arity {
1 => match x.unwrap() {
V::BlockInst(_b,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::UserMd1(_b,_a,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::UserMd2(_b,_a,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::Fn(_a,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::R1(_f,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::R2(_f,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::D1(_d1,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::D2(_d2,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::Tr2(_tr2,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::Tr3(_tr3,Some(prim)) => Vs::V(V::Scalar(prim as f64)),
V::BlockInst(_b,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::UserMd1(_b,_a,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::UserMd2(_b,_a,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::Fn(_a,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::R1(_f,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::R2(_f,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::D1(_d1,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::D2(_d2,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::Tr2(_tr2,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
V::Tr3(_tr3,Some(prim)) => Vs::V(V::Scalar(*prim as f64)),
_ => Vs::V(V::Scalar(64 as f64)),
},
_ => panic!("illegal plus arity"),
}
}

pub fn provide() -> A {
let fns = vec![V::Fn(typ,None),
V::Fn(fill,None),
V::Fn(log,None),
V::Fn(group_len,None),
V::Fn(group_ord,None),
V::Fn(assert_fn,None),
V::Fn(plus,None),
V::Fn(minus,None),
V::Fn(times,None),
V::Fn(divide,None),
V::Fn(power,None),
V::Fn(floor,None),
V::Fn(equals,None),
V::Fn(lesseq,None),
V::Fn(shape,None),
V::Fn(reshape,None),
V::Fn(pick,None),
V::Fn(windows,None),
V::R1(table,None),
V::R1(scan,None),
V::R2(fill_by,None),
V::R2(cases,None),
V::R2(catches,None)];
let fns = vec![V::Fn(Fn(typ),None),
V::Fn(Fn(fill),None),
V::Fn(Fn(log),None),
V::Fn(Fn(group_len),None),
V::Fn(Fn(group_ord),None),
V::Fn(Fn(assert_fn),None),
V::Fn(Fn(plus),None),
V::Fn(Fn(minus),None),
V::Fn(Fn(times),None),
V::Fn(Fn(divide),None),
V::Fn(Fn(power),None),
V::Fn(Fn(floor),None),
V::Fn(Fn(equals),None),
V::Fn(Fn(lesseq),None),
V::Fn(Fn(shape),None),
V::Fn(Fn(reshape),None),
V::Fn(Fn(pick),None),
V::Fn(Fn(windows),None),
V::R1(R1(table),None),
V::R1(R1(scan),None),
V::R2(R2(fill_by),None),
V::R2(R2(cases),None),
V::R2(R2(catches),None)];
A::new(fns,vec![23])
}
Loading

0 comments on commit cd0a42e

Please sign in to comment.