mscore/algorithm/
utility.rs

1use std::collections::HashMap;
2use std::f64::consts::SQRT_2;
3use rayon::prelude::*;
4use rayon::ThreadPoolBuilder;
5
6use std::collections::VecDeque;
7
/// Applies the 15-point Gauss–Kronrod rule (with its embedded 7-point Gauss rule) on `[a, b]`.
///
/// Returns `(kronrod_estimate, |kronrod_estimate - gauss_estimate|)`; the second value serves
/// as the error estimate for the first.
fn gauss_kronrod(f: &dyn Fn(f64) -> f64, a: f64, b: f64) -> (f64, f64) {
    // Non-negative Kronrod abscissae on the reference interval [-1, 1] (the rule is symmetric).
    // Entries 0, 2, 4 and 6 are also the non-negative 7-point Gauss nodes.
    const NODES: [f64; 8] = [
        0.0, 0.20778495500789848, 0.40584515137739717, 0.58608723546769113,
        0.74153118559939444, 0.86486442335976907, 0.94910791234275852, 0.99145537112081264,
    ];
    // Gauss weights for NODES[0], NODES[2], NODES[4], NODES[6] (in that order).
    const WEIGHTS_GAUSS: [f64; 4] = [
        0.41795918367346939, 0.38183005050511894, 0.27970539148927667, 0.12948496616886969,
    ];
    // Kronrod weight for each entry of NODES.
    const WEIGHTS_KRONROD: [f64; 8] = [
        0.20948214108472783, 0.20443294007529889, 0.19035057806478541, 0.16900472663926790,
        0.14065325971552592, 0.10479001032225018, 0.06309209262997855, 0.02293532201052922,
    ];

    let half_width = 0.5 * (b - a);
    let center = 0.5 * (a + b);

    // The centre node is shared by both rules and must be counted exactly once.
    let f_center = f(center);
    let mut integral_gauss = WEIGHTS_GAUSS[0] * f_center;
    let mut integral_kronrod = WEIGHTS_KRONROD[0] * f_center;

    // Every other node appears as a symmetric pair (center - dx, center + dx).
    for (j, (&node, &w_kronrod)) in NODES.iter().zip(WEIGHTS_KRONROD.iter()).enumerate().skip(1) {
        let dx = half_width * node;
        let pair_sum = f(center - dx) + f(center + dx);
        integral_kronrod += w_kronrod * pair_sum;
        // Even-indexed Kronrod nodes coincide with the Gauss nodes.
        if j % 2 == 0 {
            integral_gauss += WEIGHTS_GAUSS[j / 2] * pair_sum;
        }
    }

    integral_gauss *= half_width;
    integral_kronrod *= half_width;

    (integral_kronrod, (integral_kronrod - integral_gauss).abs())
}

/// Adaptively integrates `f` over `[a, b]` by repeated bisection with the 15-point
/// Gauss–Kronrod rule.
///
/// A subinterval is accepted once its error estimate satisfies either criterion:
/// - it is within that subinterval's share of `epsabs` (scaled by its fraction of `|b - a|`,
///   so the accepted absolute errors sum to at most `epsabs`);
/// - it is within `epsrel * |subinterval integral|`.
///
/// Otherwise the subinterval is bisected.
///
/// Refinement is bounded. After `MAX_SUBDIVISIONS` bisections, or once an interval can no longer
/// be split in floating point, the remaining intervals are accepted as they are. The call
/// therefore always terminates, even for tolerances that cannot be met (e.g.
/// `epsabs == epsrel == 0`) or integrands that evaluate to NaN.
///
/// Returns `(integral, summed_error_estimate)`.
pub fn adaptive_integration(f: &dyn Fn(f64) -> f64, a: f64, b: f64, epsabs: f64, epsrel: f64) -> (f64, f64) {
    const MAX_SUBDIVISIONS: usize = 10_000;

    let total_width = (b - a).abs();

    let mut intervals = VecDeque::new();
    intervals.push_back((a, b));

    let mut result = 0.0;
    let mut total_error = 0.0;
    let mut subdivisions = 0usize;

    while let Some((lo, hi)) = intervals.pop_front() {
        let (integral, error) = gauss_kronrod(f, lo, hi);

        // Share of the absolute tolerance proportional to this subinterval's width.
        let abs_tol = if total_width > 0.0 {
            epsabs * ((hi - lo).abs() / total_width)
        } else {
            epsabs
        };
        let converged = error <= abs_tol || error <= epsrel * integral.abs();

        let mid = 0.5 * (lo + hi);
        // Stop splitting once the budget is spent or the endpoints are adjacent floats.
        let splittable = subdivisions < MAX_SUBDIVISIONS && mid != lo && mid != hi;

        if converged || !splittable {
            result += integral;
            total_error += error;
        } else {
            subdivisions += 1;
            intervals.push_back((lo, mid));
            intervals.push_back((mid, hi));
        }
    }

    (result, total_error)
}
64
65
66
67
// Numerical integration of `f` over [a, b] using the composite trapezoidal rule with `n`
// equal panels (the previous code was a left Riemann sum despite this description).
// `n == 0` is treated as a single panel so the result stays finite instead of NaN.
fn integrate<F>(f: F, a: f64, b: f64, n: usize) -> f64
where
    F: Fn(f64) -> f64,
{
    let panels = n.max(1);
    let dx = (b - a) / panels as f64;
    // Interior points are computed from `a` rather than accumulated, to avoid drift;
    // the two endpoints carry half weight.
    let interior: f64 = (1..panels).map(|i| f(a + i as f64 * dx)).sum();
    (0.5 * (f(a) + f(b)) + interior) * dx
}
81
// Chebyshev-fitted approximation of erfc(|x|) (the Numerical Recipes `erfcc` fit,
// fractional error below about 1.2e-7). Shared by `erfc` and `erf`.
fn erfc_of_abs(x: f64) -> f64 {
    let t = 1.0 / (1.0 + 0.5 * x.abs());
    t * (-x * x - 1.26551223 + t * (1.00002368 +
        t * (0.37409196 + t * (0.09678418 + t * (-0.18628806 +
            t * (0.27886807 + t * (-1.13520398 + t * (1.48851587 +
                t * (-0.82215223 + t * 0.17087277)))))))))
        .exp()
}

// Complementary error function (erfc).
// Evaluated directly from the fit rather than as `1 - erf(x)`: that subtraction cancels
// catastrophically for large positive x (e.g. erfc(6) ~ 2.2e-17 would round to 0).
fn erfc(x: f64) -> f64 {
    let tail = erfc_of_abs(x);
    if x >= 0.0 {
        tail
    } else {
        // erfc(-y) = 2 - erfc(y)
        2.0 - tail
    }
}

// Error function (erf), from erf(x) = 1 - erfc(x) and the odd symmetry erf(-x) = -erf(x).
fn erf(x: f64) -> f64 {
    let tail = erfc_of_abs(x);
    if x >= 0.0 {
        1.0 - tail
    } else {
        tail - 1.0
    }
}
101
// Exponentially modified Gaussian (EMG) density with Gaussian mean `mu`, standard deviation
// `sigma` and exponential rate `lambda`:
//
//     f(x) = λ/2 · exp(λ²σ²/2 − λ(x − μ)) · erfc(z),    z = (μ + λσ² − x) / (σ√2).
//
// Evaluated literally, the exponential overflows to +∞ while erfc(z) underflows to 0 once z is
// large (large λσ, or x far left of μ), giving ∞ · 0 = NaN. The identity
//     λ²σ²/2 − λ(x − μ) − z² = −(x − μ)²/(2σ²)
// lets the z ≥ 0 case be evaluated as λ/2 · exp(−(x − μ)²/(2σ²)) · exp(z²)·erfc(z), where the
// scaled tail exp(z²)·erfc(z) stays bounded. For z < 0 (x > μ + λσ²) the exponent is below
// −λ²σ²/2 when λ, σ > 0, so the textbook form is safe there.
fn emg(x: f64, mu: f64, sigma: f64, lambda: f64) -> f64 {
    // Coefficients of the Numerical Recipes `erfcc` fit:
    // erfc(y) ≈ t · exp(−y² + P(t)), t = 1 / (1 + y/2), for y ≥ 0.
    const ERFC_COEFFS: [f64; 10] = [
        -1.26551223, 1.00002368, 0.37409196, 0.09678418, -0.18628806,
        0.27886807, -1.13520398, 1.48851587, -0.82215223, 0.17087277,
    ];

    let z = (mu + lambda * sigma * sigma - x) / (sigma * SQRT_2);

    // exp(|z|²) · erfc(|z|), evaluated with Horner's scheme.
    let t = 1.0 / (1.0 + 0.5 * z.abs());
    let poly = ERFC_COEFFS.iter().rev().fold(0.0, |acc, &c| acc * t + c);
    let scaled_tail = t * poly.exp();

    if z >= 0.0 {
        let d = x - mu;
        lambda / 2.0 * (-(d * d) / (2.0 * sigma * sigma)).exp() * scaled_tail
    } else {
        // erfc(z) = 2 − erfc(−z) = 2 − exp(−z²)·scaled_tail for z < 0.
        let erfc_z = 2.0 - scaled_tail * (-z * z).exp();
        lambda / 2.0 * (-lambda * (x - mu) + lambda * lambda * sigma * sigma / 2.0).exp() * erfc_z
    }
}
108
109pub fn custom_cdf_normal(x: f64, mean: f64, std_dev: f64) -> f64 {
110    let z = (x - mean) / std_dev;
111    0.5 * (1.0 + erf(z / SQRT_2))
112}
113
114pub fn accumulated_intensity_cdf_normal(sample_start: f64, sample_end: f64, mean: f64, std_dev: f64) -> f64 {
115    let cdf_start = custom_cdf_normal(sample_start, mean, std_dev);
116    let cdf_end = custom_cdf_normal(sample_end, mean, std_dev);
117    cdf_end - cdf_start
118}
119
/// Symmetric interval `mean ± z_score * std`, returned as `(lower, upper)`.
pub fn calculate_bounds_normal(mean: f64, std: f64, z_score: f64) -> (f64, f64) {
    let half_width = z_score * std;
    let lower = mean - half_width;
    let upper = mean + half_width;
    (lower, upper)
}
123
124pub fn emg_function(x: f64, mu: f64, sigma: f64, lambda: f64) -> f64 {
125    let prefactor = lambda / 2.0 * ((lambda / 2.0) * (2.0 * mu + lambda * sigma.powi(2) - 2.0 * x)).exp();
126    let erfc_part = erfc((mu + lambda * sigma.powi(2) - x) / (SQRT_2 * sigma));
127    prefactor * erfc_part
128}
129
130pub fn emg_cdf_range(lower_limit: f64, upper_limit: f64, mu: f64, sigma: f64, lambda: f64, n_steps: Option<usize>) -> f64 {
131    let n_steps = n_steps.unwrap_or(1000);
132    integrate(|x| emg(x, mu, sigma, lambda), lower_limit, upper_limit, n_steps)
133}
134
135pub fn calculate_bounds_emg(mu: f64, sigma: f64, lambda: f64, step_size: f64, target: f64, lower_start: f64, upper_start: f64, n_steps: Option<usize>) -> (f64, f64) {
136    assert!(0.0 <= target && target <= 1.0, "target must be in [0, 1]");
137
138    let lower_initial = mu - lower_start * sigma - 2.0;
139    let upper_initial = mu + upper_start * sigma;
140
141    let steps = ((upper_initial - lower_initial) / step_size).round() as usize;
142    let search_space: Vec<f64> = (0..=steps).map(|i| lower_initial + i as f64 * step_size).collect();
143
144    let calc_cdf = |low: usize, high: usize| -> f64 {
145        emg_cdf_range(search_space[low], search_space[high], mu, sigma, lambda, n_steps)
146    };
147
148    // Binary search for cutoff values
149    let (mut low, mut high) = (0, steps);
150    while low < high {
151        let mid = low + (high - low) / 2;
152        if calc_cdf(0, mid) < target {
153            low = mid + 1;
154        } else {
155            high = mid;
156        }
157    }
158    let upper_cutoff_index = low;
159
160    low = 0;
161    high = upper_cutoff_index;
162    while low < high {
163        let mid = high - (high - low) / 2;
164        let prob_mid_to_upper = calc_cdf(mid, upper_cutoff_index);
165
166        if prob_mid_to_upper < target {
167            high = mid - 1;
168        } else {
169            low = mid;
170        }
171    }
172    let lower_cutoff_index = high;
173
174    (search_space[lower_cutoff_index], search_space[upper_cutoff_index])
175}
176
177pub fn calculate_frame_occurrence_emg(retention_times: &[f64], rt: f64, sigma: f64, lambda_: f64, target_p: f64, step_size: f64, n_steps: Option<usize>) -> Vec<i32> {
178    let (rt_min, rt_max) = calculate_bounds_emg(rt, sigma, lambda_, step_size, target_p, 20.0, 60.0, n_steps);
179
180    // Finding the frame closest to rt_min
181    let first_frame = retention_times.iter()
182        .enumerate()
183        .min_by(|(_, &a), (_, &b)| (a - rt_min).abs().partial_cmp(&(b - rt_min).abs()).unwrap())
184        .map(|(idx, _)| idx + 1) // Rust is zero-indexed, so +1 to match Python's 1-indexing
185        .unwrap_or(0); // Fallback in case of an empty slice
186
187    // Finding the frame closest to rt_max
188    let last_frame = retention_times.iter()
189        .enumerate()
190        .min_by(|(_, &a), (_, &b)| (a - rt_max).abs().partial_cmp(&(b - rt_max).abs()).unwrap())
191        .map(|(idx, _)| idx + 1) // Same adjustment for 1-indexing
192        .unwrap_or(0); // Fallback
193
194    // Generating the range of frames
195    (first_frame..=last_frame).map(|x| x as i32).collect()
196}
197
198pub fn calculate_frame_abundance_emg(time_map: &HashMap<i32, f64>, occurrences: &[i32], rt: f64, sigma: f64, lambda_: f64, rt_cycle_length: f64, n_steps: Option<usize>) -> Vec<f64> {
199    let mut frame_abundance = Vec::new();
200
201    for &occurrence in occurrences {
202        if let Some(&time) = time_map.get(&occurrence) {
203            let start = time - rt_cycle_length;
204            let i = emg_cdf_range(start, time, rt, sigma, lambda_, n_steps);
205            frame_abundance.push(i);
206        }
207    }
208
209    frame_abundance
210}
211
212// retention_times: &[f64], rt: f64, sigma: f64, lambda_: f64
213pub fn calculate_frame_occurrences_emg_par(retention_times: &[f64], rts: Vec<f64>, sigmas: Vec<f64>, lambdas: Vec<f64>, target_p: f64, step_size: f64, num_threads: usize, n_steps: Option<usize>) -> Vec<Vec<i32>> {
214    let thread_pool = ThreadPoolBuilder::new().num_threads(num_threads).build().unwrap();
215    let result = thread_pool.install(|| {
216        rts.into_par_iter().zip(sigmas.into_par_iter()).zip(lambdas.into_par_iter())
217            .map(|((rt, sigma), lambda)| {
218                calculate_frame_occurrence_emg(retention_times, rt, sigma, lambda, target_p, step_size, n_steps)
219            })
220            .collect()
221    });
222    result
223}
224
225pub fn calculate_frame_abundances_emg_par(time_map: &HashMap<i32, f64>, occurrences: Vec<Vec<i32>>, rts: Vec<f64>, sigmas: Vec<f64>, lambdas: Vec<f64>, rt_cycle_length: f64, num_threads: usize, n_steps: Option<usize>) -> Vec<Vec<f64>> {
226    let thread_pool = ThreadPoolBuilder::new().num_threads(num_threads).build().unwrap();
227    let result = thread_pool.install(|| {
228        occurrences.into_par_iter().zip(rts.into_par_iter()).zip(sigmas.into_par_iter()).zip(lambdas.into_par_iter())
229            .map(|(((occurrences, rt), sigma), lambda)| {
230                calculate_frame_abundance_emg(time_map, &occurrences, rt, sigma, lambda, rt_cycle_length, n_steps)
231            })
232            .collect()
233    });
234    result
235}
236
237/// Returns the CDF in the range [sample_start, sample_end] for a Normal(mean, std_dev).
238pub fn normal_cdf_range(lower_limit: f64, upper_limit: f64, mean: f64, std_dev: f64) -> f64 {
239    let cdf_start = custom_cdf_normal(lower_limit, mean, std_dev);
240    let cdf_end = custom_cdf_normal(upper_limit, mean, std_dev);
241    cdf_end - cdf_start
242}
243
244/// Calculate the bounding interval [lower, upper] around `mean` that captures `target` total probability
245/// using a binary search across a discretized search space. This mirrors `calculate_bounds_emg`.
246pub fn calculate_bounds_gaussian(
247    mean: f64,
248    sigma: f64,
249    step_size: f64,
250    target: f64,
251    lower_start: f64,
252    upper_start: f64
253) -> (f64, f64) {
254    assert!((0.0..=1.0).contains(&target), "target must be in [0, 1]");
255
256    let lower_initial = mean - lower_start * sigma;
257    let upper_initial = mean + upper_start * sigma;
258
259    let steps = ((upper_initial - lower_initial) / step_size).ceil() as usize;
260    let search_space: Vec<f64> = (0..=steps)
261        .map(|i| lower_initial + i as f64 * step_size)
262        .collect();
263
264    let calc_cdf = |low: usize, high: usize| -> f64 {
265        normal_cdf_range(search_space[low], search_space[high], mean, sigma)
266    };
267
268    // 1) Find upper cutoff
269    let (mut low, mut high) = (0, steps);
270    while low < high {
271        let mid = low + (high - low) / 2;
272        if calc_cdf(0, mid) < target {
273            low = mid + 1;
274        } else {
275            high = mid;
276        }
277    }
278    let upper_cutoff_index = low;
279
280    // 2) Find lower cutoff
281    low = 0;
282    high = upper_cutoff_index;
283    while low < high {
284        let mid = high - (high - low) / 2;
285        if calc_cdf(mid, upper_cutoff_index) < target {
286            high = mid - 1;
287        } else {
288            low = mid;
289        }
290    }
291    let lower_cutoff_index = high;
292
293    (search_space[lower_cutoff_index], search_space[upper_cutoff_index])
294}
295
296/// Returns all scan indices (0-based) that fall into the range where Normal(mean, sigma)
297/// has at least `target_p` coverage.
298///
299/// For timsTOF data, `inverse_ion_mobility` runs backward (highest to lowest values correspond to scans).
300///
301/// # Arguments
302///
303/// - `inverse_ion_mobility`: The inverse ion mobility values for all scans (descending order).
304/// - `mean`: The mean of the Gaussian distribution.
305/// - `sigma`: The standard deviation of the Gaussian distribution.
306/// - `target_p`: The target probability to capture.
307/// - `step_size`: Step size for searching bounds.
308/// - `n_lower_start`: Initial lower bound factor (relative to sigma).
309/// - `n_upper_start`: Initial upper bound factor (relative to sigma).
310///
311/// # Returns
312///
313/// A `Vec<usize>` containing all scan indices (0-based) within the computed range.
314///
315/// # Example
316///
317/// ```rust
318/// use mscore::algorithm::utility::calculate_scan_occurrence_gaussian;
319///
320/// let inverse_ion_mobility = vec![0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3];
321/// let scans = calculate_scan_occurrence_gaussian(
322///     &inverse_ion_mobility,
323///     1.1,  // mean
324///     0.001,  // sigma
325///     0.9999, // target probability
326///     0.01, // step size
327///     3.0,  // n_lower_start
328///     3.0   // n_upper_start
329/// );
330///
331/// assert_eq!(scans, vec![2]); // Scans corresponding to 1.3 ± 1σ
332/// ```
333pub fn calculate_scan_occurrence_gaussian(
334    inverse_ion_mobility: &[f64],
335    mean: f64,
336    sigma: f64,
337    target_p: f64,
338    step_size: f64,
339    n_lower_start: f64,
340    n_upper_start: f64,
341) -> Vec<i32> {
342    // Calculate bounds for the Gaussian
343    let (ims_lower, ims_upper) = calculate_bounds_gaussian(mean, sigma, step_size, target_p, n_lower_start, n_upper_start);
344
345    // Create a list of tuples (inverse_ion_mobility_value, index) in reverse order
346    let indexed_values: Vec<(f64, usize)> = inverse_ion_mobility
347        .iter()
348        .rev()
349        .enumerate()
350        .map(|(i, &val)| (val, i))
351        .collect();
352
353    // Find the closest index to ims_lower
354    let upper_idx = indexed_values
355        .iter()
356        .enumerate()
357        .min_by(|(_, (val_a, _)), (_, (val_b, _))| {
358            (val_a - ims_lower).abs().partial_cmp(&(val_b - ims_lower).abs()).unwrap()
359        })
360        .map(|(idx, _)| idx)
361        .unwrap_or(0);
362
363    // Find the closest index to ims_upper
364    let lower_idx = indexed_values
365        .iter()
366        .enumerate()
367        .min_by(|(_, (val_a, _)), (_, (val_b, _))| {
368            (val_a - ims_upper).abs().partial_cmp(&(val_b - ims_upper).abs()).unwrap()
369        })
370        .map(|(idx, _)| idx)
371        .unwrap_or(indexed_values.len() - 1);
372
373    // Extract the indices of the scans in the found range
374    if lower_idx <= upper_idx {
375        indexed_values[lower_idx..=upper_idx]
376            .iter()
377            .map(|&(_, idx)| idx as i32)
378            .collect()
379    } else {
380        Vec::new()
381    }
382}
383
384
385/// Compute the abundance in each occurrence frame by looking at
386/// the probability of Normal(mean, sigma) within `[time - rt_cycle_length, time]`.
387pub fn calculate_abundance_gaussian(
388    time_map: &HashMap<i32, f64>,
389    occurrences: &[i32],
390    mean: f64,
391    sigma: f64,
392    cycle_length: f64,
393) -> Vec<f64> {
394    let mut frame_abundance = Vec::new();
395
396    for &occurrence in occurrences {
397        if let Some(&time) = time_map.get(&occurrence) {
398            let start = time - cycle_length;
399            let val = normal_cdf_range(start, time, mean, sigma);
400            frame_abundance.push(val);
401        }
402    }
403
404    frame_abundance
405}
406
407pub fn calculate_scan_occurrences_gaussian_par(
408    times: &[f64],
409    means: Vec<f64>,
410    sigmas: Vec<f64>,
411    target_p: f64,
412    step_size: f64,
413    n_lower_start: f64,
414    n_upper_start: f64,
415    num_threads: usize
416) -> Vec<Vec<i32>> {
417    let thread_pool = ThreadPoolBuilder::new()
418        .num_threads(num_threads)
419        .build()
420        .unwrap();
421
422    thread_pool.install(|| {
423        means.into_par_iter()
424            .zip(sigmas.into_par_iter())
425            .map(|(m, s)| {
426                calculate_scan_occurrence_gaussian(
427                    times,
428                    m,
429                    s,
430                    target_p,
431                    step_size,
432                    n_lower_start,
433                    n_upper_start
434                )
435            })
436            .collect()
437    })
438}
439
440/// Parallel version for multiple (mean, sigma) pairs to get abundance
441pub fn calculate_scan_abundances_gaussian_par(
442    time_map: &HashMap<i32, f64>,
443    occurrences: Vec<Vec<i32>>,
444    means: Vec<f64>,
445    sigmas: Vec<f64>,
446    cycle_length: f64,
447    num_threads: usize
448) -> Vec<Vec<f64>> {
449    let thread_pool = ThreadPoolBuilder::new()
450        .num_threads(num_threads)
451        .build()
452        .unwrap();
453
454    thread_pool.install(|| {
455        occurrences.into_par_iter()
456            .zip(means.into_par_iter())
457            .zip(sigmas.into_par_iter())
458            .map(|((occ, m), s)| {
459                calculate_abundance_gaussian(
460                    time_map,
461                    &occ,
462                    m,
463                    s,
464                    cycle_length
465                )
466            })
467            .collect()
468    })
469}
470
// Unit tests for the Gaussian helpers: CDF range, bounds search, scan occurrence/abundance and
// their parallel wrappers.
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute-tolerance float comparison: true when |a - b| < epsilon.
    fn approx_eq(a: f64, b: f64, epsilon: f64) -> bool {
        (a - b).abs() < epsilon
    }

    #[test]
    fn test_normal_cdf_range() {
        let mean = 0.0;
        let std_dev = 1.0;

        // For a standard normal, nearly all probability is within ~[-10, 10].
        // So normal_cdf_range(-10, 10, 0, 1) should be ~1.0
        let prob_all = normal_cdf_range(-10.0, 10.0, mean, std_dev);
        assert!(approx_eq(prob_all, 1.0, 1e-6),
                "CDF range from -10 to 10 should capture nearly all probability, got {prob_all}");

        // Check an interval around mean ± 1σ -> about 68% of mass (exact value ≈ 0.6827)
        let prob_1sigma = normal_cdf_range(-1.0, 1.0, mean, std_dev);
        assert!(
            (prob_1sigma - 0.68).abs() < 0.02,
            "Expected ~0.68 within ±1σ, got {prob_1sigma}"
        );
    }

    #[test]
    fn test_calculate_bounds_gaussian() {
        let mean = 0.0;
        let sigma = 1.0;
        let target = 0.68;
        // Search grid with spacing 0.01, starting 5σ on each side of the mean
        let (low, high) = calculate_bounds_gaussian(mean, sigma, 0.01, target, 5.0, 5.0);

        // Check that the coverage is close to 0.68
        let coverage = normal_cdf_range(low, high, mean, sigma);
        assert!(
            (coverage - target).abs() < 0.1,
            "Expected coverage ~0.68, got {coverage} for interval [{low}, {high}]"
        );
    }

    #[test]
    fn test_calculate_frame_occurrence_gaussian() {
        // Suppose we have 10 frames of retention times from 0.0 to 9.0
        let retention_times: Vec<f64> = (0..10).map(|x| x as f64).collect();
        let mean = 5.0;   // Centered around the 5th second
        let sigma = 1.0;
        let target_p = 0.68;
        let step_size = 0.1;

        // This should capture frames near t=5.0, about ±1.0 in the "most probable" sense.
        // "lower_start" and "upper_start" here are up to you; let's do 5.0 each side
        let frames = calculate_scan_occurrence_gaussian(
            &retention_times,
            mean,
            sigma,
            target_p,
            step_size,
            5.0,
            5.0
        );

        // Expect frames near 4, 5, 6
        // Because those times (4.0, 5.0, 6.0) are the main chunk of ±1σ around 5.0
        // NOTE(review): calculate_scan_occurrence_gaussian returns positions counted from the END
        // of the slice (position p refers to retention_times[len - 1 - p]). So `5` below refers to
        // retention_times[4] = 4.0, not to the "1-based" frame 5 the assert message suggests.
        assert!(
            !frames.is_empty(),
            "We expect at least a few frames around 5.0"
        );
        assert!(
            frames.contains(&5),
            "We definitely expect the central frame (index=5 in 1-based indexing) to be included"
        );
    }

    #[test]
    fn test_calculate_frame_abundance_gaussian() {
        // Set up a mock time map: frame_index -> time
        // We'll pretend each frame index i runs from i-1 to i in real-time
        let mut time_map = HashMap::new();
        for i in 1..=5 {
            time_map.insert(i as i32, i as f64);
        }

        // Suppose we only have two frames to check
        let occurrences = vec![1, 3];
        let mean = 3.0;
        let sigma = 1.0;
        let im_cycle_length = 1.0;

        let abundances = calculate_abundance_gaussian(
            &time_map,
            &occurrences,
            mean,
            sigma,
            im_cycle_length
        );

        // We'll do a basic sanity check:
        // - For frame 1, it integrates from time=0 to time=1.
        // - For frame 3, from 2 to 3.
        assert_eq!(abundances.len(), 2, "We should have 2 abundance values");
        let (a1, a2) = (abundances[0], abundances[1]);

        // The second abundance (covering [2,3]) should be bigger,
        // because it's closer to mean=3.0
        assert!(
            a2 > a1,
            "Expected frame near t=3 to have higher abundance than t=1"
        );
    }

    #[test]
    fn test_parallel_functions() {
        // Just a quick sanity check: the parallel wrappers return one result per (mean, sigma) pair
        let retention_times: Vec<f64> = (0..10).map(|x| x as f64).collect();
        let means = vec![3.0, 5.0];
        let sigmas = vec![1.0, 1.5];

        let target_p = 0.68;
        let step_size = 0.1;
        let num_threads = 2;

        let res_occurrences = calculate_scan_occurrences_gaussian_par(
            &retention_times,
            means.clone(),
            sigmas.clone(),
            target_p,
            step_size,
            5.0,
            5.0,
            num_threads
        );
        assert_eq!(res_occurrences.len(), 2, "Should produce 2 sets of occurrences");

        // Mock time_map for abundances
        let mut time_map = HashMap::new();
        for i in 1..=10 {
            time_map.insert(i, i as f64);
        }

        let res_abundances = calculate_scan_abundances_gaussian_par(
            &time_map,
            res_occurrences,
            means,
            sigmas,
            1.0,          // rt_cycle_length
            num_threads
        );
        assert_eq!(res_abundances.len(), 2, "Should produce 2 sets of abundances");
    }
}