Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/hessian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ fn main() {
let y = Jet::<2>::constant(1.0);
let result_x = f(x, y);
println!("At (1, 1):");
println!(" f(1, 1) = {}", result_x.value()); // 0
println!(" df/dx = {}", result_x.dx()); // 0
println!(" d²f/dx² = {}", result_x.ddx()); // 42
println!(" f(1, 1) = {}", result_x.value()); // 0
println!(" df/dx = {}", result_x.dx()); // 0
println!(" d²f/dx² = {}", result_x.ddx()); // 42

// df/dy and d²f/dy² at (1, 1): seed y as variable, keep x constant.
let x = Jet::<2>::constant(1.0);
let y = Jet::<2>::var(1.0);
let result_y = f(x, y);
println!(" df/dy = {}", result_y.dx()); // 0
println!(" d²f/dy² = {}", result_y.ddx()); // 10
println!(" df/dy = {}", result_y.dx()); // 0
println!(" d²f/dy² = {}", result_y.ddx()); // 10
}

fn f(x: Jet<2>, y: Jet<2>) -> Jet<2> {
Expand Down
18 changes: 9 additions & 9 deletions examples/jet_ad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ fn main() {
let x = Jet::<1>::var(2.0);
let y = x.powi(3);
println!("f(x) = x^3 at x = 2");
println!(" f(2) = {}", y.value()); // 8
println!(" f'(2) = {}", y.dx()); // 12
println!(" f(2) = {}", y.value()); // 8
println!(" f'(2) = {}", y.dx()); // 12

// 2. Second-order: f(x) = sin(x) at x = pi/4
let x = Jet::<2>::var(std::f64::consts::FRAC_PI_4);
let y = x.sin();
println!("\nf(x) = sin(x) at x = pi/4");
println!(" f(pi/4) = {:.6}", y.value()); // 0.707107
println!(" f'(pi/4) = {:.6}", y.dx()); // 0.707107 (cos(pi/4))
println!(" f''(pi/4) = {:.6}", y.ddx()); // -0.707107 (-sin(pi/4))
println!(" f(pi/4) = {:.6}", y.value()); // 0.707107
println!(" f'(pi/4) = {:.6}", y.dx()); // 0.707107 (cos(pi/4))
println!(" f''(pi/4) = {:.6}", y.ddx()); // -0.707107 (-sin(pi/4))

// 3. Higher-order: f(x) = exp(x) at x = 0
// All derivatives of exp are 1 at 0, so f^(k)(0) = 1 for every k.
Expand Down Expand Up @@ -43,15 +43,15 @@ fn main() {
let x: Dual = Dual::var(1.0);
let y = x.ln();
println!("\nUsing Dual alias: f(x) = ln(x) at x = 1");
println!(" f(1) = {}", y.value()); // 0
println!(" f'(1) = {}", y.dx()); // 1
println!(" f(1) = {}", y.value()); // 0
println!(" f'(1) = {}", y.dx()); // 1

let x: HyperDual = HyperDual::var(1.0);
let y = x.ln();
println!("\nUsing HyperDual alias: f(x) = ln(x) at x = 1");
println!(" f(1) = {}", y.value()); // 0
println!(" f'(1) = {}", y.dx()); // 1
println!(" f''(1) = {}", y.ddx()); // -1
println!(" f'(1) = {}", y.dx()); // 1
println!(" f''(1) = {}", y.ddx()); // -1

// 6. Backward-compat constructors (AD0 / AD1 / AD2)
// AD is an alias for Jet<2>; AD1 and AD2 set derivatives directly.
Expand Down
33 changes: 25 additions & 8 deletions examples/real_trait_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,22 @@ fn main() {
// The Real trait is implemented for f64 and AD (= Jet<2>).
// Using the generic function with both types:
let x_f64 = 2f64;
let x_ad1 = AD1(2f64, 1f64); // Jet<2> with first derivative = 1, second = 0
let x_ad1 = AD1(2f64, 1f64); // Jet<2> with first derivative = 1, second = 0
let x_ad2 = AD2(2f64, 1f64, 0f64); // Jet<2> with explicit second derivative

println!("f(x) = x^2, evaluated with Real trait:");
println!(" f64: f(2) = {}", f(x_f64));
println!(" AD1: f(2) = {}, f'(2) = {}", f(x_ad1).value(), f(x_ad1).dx());
println!(" AD2: f(2) = {}, f'(2) = {}, f''(2) = {}",
f(x_ad2).value(), f(x_ad2).dx(), f(x_ad2).ddx());
println!(
" AD1: f(2) = {}, f'(2) = {}",
f(x_ad1).value(),
f(x_ad1).dx()
);
println!(
" AD2: f(2) = {}, f'(2) = {}, f''(2) = {}",
f(x_ad2).value(),
f(x_ad2).dx(),
f(x_ad2).ddx()
);

// Direct Jet<N> usage (without the Real trait — more explicit):
println!("\nf(x) = x^2 with explicit Jet<N> types:");
Expand All @@ -22,13 +30,22 @@ fn main() {

let x2 = Jet::<2>::var(2.0);
let y2 = x2.powi(2);
println!(" Jet<2>: f(2) = {}, f'(2) = {}, f''(2) = {}",
y2.value(), y2.dx(), y2.ddx());
println!(
" Jet<2>: f(2) = {}, f'(2) = {}, f''(2) = {}",
y2.value(),
y2.dx(),
y2.ddx()
);

let x3 = Jet::<3>::var(2.0);
let y3 = x3.powi(2);
println!(" Jet<3>: f(2) = {}, f'(2) = {}, f''(2) = {}, f'''(2) = {}",
y3.value(), y3.dx(), y3.ddx(), y3.derivative(3));
println!(
" Jet<3>: f(2) = {}, f'(2) = {}, f''(2) = {}, f'''(2) = {}",
y3.value(),
y3.dx(),
y3.ddx(),
y3.derivative(3)
);
}

// Generic function over the Real trait (works with f64 and AD = Jet<2>).
Expand Down
17 changes: 13 additions & 4 deletions src/structure/ad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,10 @@ pub type HyperDual = Jet<2>;
/// Create a `Jet<0>` constant (zero-order, value only).
#[inline]
pub fn ad0(x: f64) -> Jet<0> {
Jet { value: x, deriv: [] }
Jet {
value: x,
deriv: [],
}
}

/// Create a `Jet<1>` with value and first derivative.
Expand Down Expand Up @@ -510,7 +513,10 @@ impl<const N: usize> Index<usize> for Jet<N> {
} else if index <= N {
&self.deriv[index - 1]
} else {
panic!("Jet<{}> index {} out of bounds (max index = {})", N, index, N)
panic!(
"Jet<{}> index {} out of bounds (max index = {})",
N, index, N
)
}
}
}
Expand All @@ -522,7 +528,10 @@ impl<const N: usize> IndexMut<usize> for Jet<N> {
} else if index <= N {
&mut self.deriv[index - 1]
} else {
panic!("Jet<{}> index {} out of bounds (max index = {})", N, index, N)
panic!(
"Jet<{}> index {} out of bounds (max index = {})",
N, index, N
)
}
}
}
Expand Down Expand Up @@ -941,7 +950,7 @@ impl<const N: usize> Jet<N> {
cs += ka * s.coeff(n - k);
}
s.set_coeff(n, ss / (n as f64));
c.set_coeff(n, cs / (n as f64)); // NO negative for cosh
c.set_coeff(n, cs / (n as f64)); // NO negative for cosh
}
(s, c)
}
Expand Down
46 changes: 37 additions & 9 deletions src/structure/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,11 @@ use arrow::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
UInt64Type, UInt8Type,
};
#[cfg(feature = "parquet")]
use indexmap::IndexMap;
use std::cmp::{max, min};
#[cfg(feature = "csv")]
use std::collections::HashMap;
#[cfg(feature = "parquet")]
use indexmap::IndexMap;
#[cfg(any(feature = "csv", feature = "nc", feature = "parquet"))]
use std::error::Error;
use std::fmt;
Expand Down Expand Up @@ -1311,15 +1311,21 @@ impl Series {
pub fn var(&self) -> anyhow::Result<f64> {
use crate::statistics::stat::Statistics;
let v = self.to_f64_vec()?;
anyhow::ensure!(v.len() > 1, "Cannot compute variance of Series with fewer than 2 elements");
anyhow::ensure!(
v.len() > 1,
"Cannot compute variance of Series with fewer than 2 elements"
);
Ok(v.var())
}

/// Standard deviation of all elements (numeric types only)
pub fn sd(&self) -> anyhow::Result<f64> {
use crate::statistics::stat::Statistics;
let v = self.to_f64_vec()?;
anyhow::ensure!(v.len() > 1, "Cannot compute sd of Series with fewer than 2 elements");
anyhow::ensure!(
v.len() > 1,
"Cannot compute sd of Series with fewer than 2 elements"
);
Ok(v.sd())
}

Expand All @@ -1329,8 +1335,15 @@ impl Series {

macro_rules! typed_min {
($v:expr, $dtype:ident) => {{
let min_val = $v.iter().cloned().reduce(|a, b| if a <= b { a } else { b }).unwrap();
Ok(Scalar { value: DTypeValue::$dtype(min_val), dtype: DType::$dtype })
let min_val = $v
.iter()
.cloned()
.reduce(|a, b| if a <= b { a } else { b })
.unwrap();
Ok(Scalar {
value: DTypeValue::$dtype(min_val),
dtype: DType::$dtype,
})
}};
}

Expand Down Expand Up @@ -1359,8 +1372,15 @@ impl Series {

macro_rules! typed_max {
($v:expr, $dtype:ident) => {{
let max_val = $v.iter().cloned().reduce(|a, b| if a >= b { a } else { b }).unwrap();
Ok(Scalar { value: DTypeValue::$dtype(max_val), dtype: DType::$dtype })
let max_val = $v
.iter()
.cloned()
.reduce(|a, b| if a >= b { a } else { b })
.unwrap();
Ok(Scalar {
value: DTypeValue::$dtype(max_val),
dtype: DType::$dtype,
})
}};
}

Expand Down Expand Up @@ -1996,7 +2016,15 @@ impl DataFrame {

let stat_labels = vec!["count", "mean", "sd", "min", "max"];
let mut result = DataFrame::new(vec![]);
result.push("stat", Series::new(stat_labels.iter().map(|s| s.to_string()).collect::<Vec<String>>()));
result.push(
"stat",
Series::new(
stat_labels
.iter()
.map(|s| s.to_string())
.collect::<Vec<String>>(),
),
);

for (i, series) in self.data.iter().enumerate() {
if let Ok(v) = series.to_f64_vec() {
Expand Down
7 changes: 3 additions & 4 deletions src/util/plot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ use self::pyo3::types::{IntoPyDict, PyDictMethods};
use self::pyo3::{PyResult, Python};
pub use self::Grid::{Off, On};
use self::PlotOptions::{Domain, Images, Pairs, Path};
use std::collections::HashMap;
use std::fmt::Display;
use std::borrow::BorrowMut;
use std::collections::HashMap;
use std::ffi::CString;
use std::fmt::Display;

type Vector = Vec<f64>;

Expand Down Expand Up @@ -504,8 +504,7 @@ impl Plot for Plot2D {
let plot_type = self.plot_type.clone();

// Global variables to plot
let mut globals =
vec![("plt", py.import("matplotlib.pyplot")?)].into_py_dict(py)?;
let mut globals = vec![("plt", py.import("matplotlib.pyplot")?)].into_py_dict(py)?;
globals.borrow_mut().set_item("x", x)?;
globals.borrow_mut().set_item("y", ys)?;
globals.borrow_mut().set_item("pair", pairs)?;
Expand Down
6 changes: 5 additions & 1 deletion tests/dataframe/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,11 @@ fn test_series_max() {

#[test]
fn test_series_max_string() {
let a = Series::new(vec!["apple".to_string(), "banana".to_string(), "cherry".to_string()]);
let a = Series::new(vec![
"apple".to_string(),
"banana".to_string(),
"cherry".to_string(),
]);
let m = a.max().unwrap();
assert_eq!(m, Scalar::new("cherry".to_string()));
}
Expand Down
33 changes: 9 additions & 24 deletions tests/jet_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,11 +655,7 @@ fn test_sin_at_zero_jet10_derivative_cycle() {
let y = x.sin();
let expected = [0.0, 1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0, 1.0, 0.0];
for (k, &exp) in expected.iter().enumerate() {
assert_close_eps(
y.derivative(k),
exp,
1e-10,
);
assert_close_eps(y.derivative(k), exp, 1e-10);
}
}

Expand Down Expand Up @@ -693,11 +689,7 @@ fn test_cos_at_zero_jet10_derivative_cycle() {
let y = x.cos();
let expected = [1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0];
for (k, &exp) in expected.iter().enumerate() {
assert_close_eps(
y.derivative(k),
exp,
1e-10,
);
assert_close_eps(y.derivative(k), exp, 1e-10);
}
}

Expand Down Expand Up @@ -973,9 +965,9 @@ fn test_index_operator_jet1() {
#[test]
fn test_index_operator_jet2() {
let j = ad2(5.0, 3.0, 4.0);
assert_close(j[0], 5.0); // value
assert_close(j[1], 3.0); // deriv[0] = dx
assert_close(j[2], 2.0); // deriv[1] = ddx/2 = 4/2 = 2
assert_close(j[0], 5.0); // value
assert_close(j[1], 3.0); // deriv[0] = dx
assert_close(j[2], 2.0); // deriv[1] = ddx/2 = 4/2 = 2
}

#[test]
Expand Down Expand Up @@ -1106,26 +1098,19 @@ fn test_fpvector_fmap_jet1() {

#[test]
fn test_fpvector_sum_jet1() {
let v: Vec<Jet<1>> = vec![
ad1(1.0, 1.0),
ad1(2.0, 2.0),
ad1(3.0, 3.0),
];
let v: Vec<Jet<1>> = vec![ad1(1.0, 1.0), ad1(2.0, 2.0), ad1(3.0, 3.0)];
let s = v.sum();
// FPVector::sum uses reduce(self[0], +) which double-counts first element
assert_close(s.value(), 7.0); // 1 + (1+2+3)
assert_close(s.value(), 7.0); // 1 + (1+2+3)
assert_close(s.dx(), 7.0);
}

#[test]
fn test_fpvector_prod_jet1() {
let v: Vec<Jet<1>> = vec![
Jet::<1>::constant(2.0),
Jet::<1>::constant(3.0),
];
let v: Vec<Jet<1>> = vec![Jet::<1>::constant(2.0), Jet::<1>::constant(3.0)];
let p = v.prod();
// FPVector::prod uses reduce(self[0], *) which double-counts first element
assert_close(p.value(), 12.0); // 2 * (2*3)
assert_close(p.value(), 12.0); // 2 * (2*3)
}

// =============================================================================
Expand Down