rustdf/cluster/
scoring.rs

1use std::cmp::Ordering;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use rayon::prelude::*;
6use crate::cluster::candidates::{CandidateOpts, PairFeatures, ScoreOpts};
7use crate::cluster::cluster::ClusterResult1D;
8use crate::cluster::feature::SimpleFeature;
9use crate::cluster::pseudo::cluster_mz_mu;
10use crate::cluster::utility::robust_noise_mad;
11use crate::data::dia::{DiaIndex, TimsDatasetDIA};
12
/// Selects which scorer ranks MS1 candidates for an MS2 cluster.
///
/// Comparable and hashable so it can be matched with `==` or used as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MatchScoreMode {
    /// Geometric ("geom") scoring built from `PairFeatures` + `ScoreOpts`.
    Geom,
    /// XIC-correlation scoring configured by `XicScoreOpts`.
    Xic,
}
20
/// Per-term breakdown of an XIC match score.
///
/// Each field is `None` when the corresponding term was disabled or could not
/// be computed for the pair.
#[derive(Debug, Clone)]
pub struct XicDetails {
    /// RT-trace similarity, `(r_rt + 1) / 2`, in [0, 1].
    pub s_rt: Option<f32>,
    /// IM-trace similarity, `(r_im + 1) / 2`, in [0, 1].
    pub s_im: Option<f32>,
    /// Intensity-ratio consistency term in (0, 1].
    pub s_intensity: Option<f32>,

    /// Raw Pearson correlation of the RT traces, in [-1, 1].
    pub r_rt: Option<f32>,
    /// Raw Pearson correlation of the IM traces, in [-1, 1].
    pub r_im: Option<f32>,
}
35
36#[derive(Clone, Debug)]
37pub struct ScoredHit {
38    /// Index into the MS2 slice (typically `self.ms2` in FragmentIndex).
39    pub frag_idx: usize,
40    /// Final scalar score (geom or XIC, depending on mode).
41    pub score: f32,
42    /// Geometric feature bundle (only set in `Geom` mode).
43    pub geom: Option<PairFeatures>,
44    /// XIC scoring details (only set in `Xic` mode).
45    pub xic: Option<XicDetails>,
46}
47
48#[derive(Clone, Copy)]
49pub enum PrecursorLike<'a> {
50    Cluster(&'a ClusterResult1D),
51    Feature(&'a SimpleFeature),
52}
53
/// Configuration for XIC-based MS1/MS2 scoring: term weights, the
/// intensity-ratio temperature, the acceptance cutoff and per-term switches.
#[derive(Debug, Clone)]
pub struct XicScoreOpts {
    /// Weight of the RT-trace correlation term.
    pub w_rt: f32,
    /// Weight of the IM-trace correlation term.
    pub w_im: f32,
    /// Weight of the MS1/MS2 intensity-ratio consistency term.
    pub w_intensity: f32,

    /// Temperature of the log-intensity-ratio penalty `exp(-|ln(I2/I1)| / tau)`.
    /// Larger values tolerate a bigger MS1/MS2 intensity mismatch.
    pub intensity_tau: f32,

    /// Minimum final score (the score lies in [0, 1]) for a pair to be accepted.
    pub min_total_score: f32,

    /// Enable the RT correlation term.
    pub use_rt: bool,
    /// Enable the IM correlation term.
    pub use_im: bool,
    /// Enable the intensity-ratio term.
    pub use_intensity: bool,
}

impl Default for XicScoreOpts {
    /// Shape terms dominate (RT and IM at 0.45 each) while intensity is a weak
    /// 0.10 prior; every term is enabled and no score cutoff is applied.
    fn default() -> Self {
        Self {
            use_rt: true,
            use_im: true,
            use_intensity: true,
            w_rt: 0.45,
            w_im: 0.45,
            w_intensity: 0.10,
            intensity_tau: 1.5,
            min_total_score: 0.0,
        }
    }
}
92
/// Population z-scores of `v`, computed in `f64`.
///
/// Returns `None` when `v` has fewer than 3 samples, when all samples are
/// equal (no shape to correlate), or when the variance is not finite (for
/// example NaN or infinite inputs).
#[inline]
fn zscore(v: &[f32]) -> Option<Vec<f32>> {
    let n = v.len();
    if n < 3 {
        return None;
    }
    // Detect a constant trace exactly: a floating-point variance can be left
    // slightly above zero by rounding and would then blow up the z-scores.
    let first = v[0];
    if v.iter().all(|&x| x == first) {
        return None;
    }
    // Two-pass mean/variance avoids the catastrophic cancellation of
    // E[x^2] - E[x]^2 when the trace has a large offset relative to its spread.
    let n_f = n as f64;
    let mean = v.iter().map(|&x| x as f64).sum::<f64>() / n_f;
    let var = v
        .iter()
        .map(|&x| {
            let d = x as f64 - mean;
            d * d
        })
        .sum::<f64>()
        / n_f;
    if !var.is_finite() || var <= 0.0 {
        return None;
    }
    let std = var.sqrt();
    Some(v.iter().map(|&x| ((x as f64 - mean) / std) as f32).collect())
}

/// Pearson correlation of two traces over their common prefix.
///
/// Both traces are cropped to the shorter length and z-scored with population
/// statistics, so `r = mean(z_a * z_b)`, clamped to [-1, 1]. Returns `None`
/// for fewer than 3 common samples or when either cropped trace is rejected
/// by [`zscore`].
fn pearson_corr_z(a: &[f32], b: &[f32]) -> Option<f32> {
    let n = a.len().min(b.len());
    if n < 3 {
        return None;
    }
    let az = zscore(&a[..n])?;
    let bz = zscore(&b[..n])?;

    let num: f64 = az
        .iter()
        .zip(&bz)
        .map(|(&x, &y)| (x as f64) * (y as f64))
        .sum();
    let r = (num / n as f64).clamp(-1.0, 1.0);
    Some(r as f32)
}
139
140pub fn xic_match_score(
141    ms1: &ClusterResult1D,
142    ms2: &ClusterResult1D,
143    opts: &XicScoreOpts,
144) -> Option<(XicDetails, f32)> {
145    let mut score   = 0.0f32;
146    let mut w_sum   = 0.0f32;
147
148    let mut s_rt: Option<f32>         = None;
149    let mut s_im: Option<f32>         = None;
150    let mut s_intensity: Option<f32>  = None;
151    let mut r_rt: Option<f32>         = None;
152    let mut r_im: Option<f32>         = None;
153
154    // ---- RT XIC: Pearson, mapped from [-1,1] to [0,1] ----
155    if opts.use_rt {
156        if let (Some(ref rt1), Some(ref rt2)) = (&ms1.rt_trace, &ms2.rt_trace) {
157            if let Some(r) = pearson_corr_z(rt1, rt2) {
158                if r.is_finite() {
159                    let s = 0.5 * (r + 1.0); // [-1,1] -> [0,1]
160                    r_rt  = Some(r);
161                    s_rt  = Some(s);
162                    score += opts.w_rt * s;
163                    w_sum += opts.w_rt;
164                }
165            }
166        }
167    }
168
169    // ---- IM XIC: Pearson, same mapping ----
170    if opts.use_im {
171        if let (Some(ref im1), Some(ref im2)) = (&ms1.im_trace, &ms2.im_trace) {
172            if let Some(r) = pearson_corr_z(im1, im2) {
173                if r.is_finite() {
174                    let s = 0.5 * (r + 1.0);
175                    r_im  = Some(r);
176                    s_im  = Some(s);
177                    score += opts.w_im * s;
178                    w_sum += opts.w_im;
179                }
180            }
181        }
182    }
183
184    // ---- Intensity ratio: weak, symmetric penalty on log ratio ----
185    if opts.use_intensity && opts.w_intensity > 0.0 && opts.intensity_tau > 0.0 {
186        // Use RT trace integrals as a proxy; fall back to raw_sum if needed.
187        let i1 = if let Some(ref rt1) = ms1.rt_trace {
188            rt1.iter().fold(0.0f32, |acc, x| acc + x.max(0.0))
189        } else {
190            ms1.raw_sum.max(0.0)
191        };
192
193        let i2 = if let Some(ref rt2) = ms2.rt_trace {
194            rt2.iter().fold(0.0f32, |acc, x| acc + x.max(0.0))
195        } else {
196            ms2.raw_sum.max(0.0)
197        };
198
199        if i1 > 0.0 && i2 > 0.0 {
200            let ratio = (i2 / i1).max(1e-6);
201            let d     = ratio.ln().abs(); // |log(I2/I1)|
202            let s     = (-d / opts.intensity_tau).exp(); // in (0,1]
203            if s.is_finite() {
204                s_intensity = Some(s);
205                score      += opts.w_intensity * s;
206                w_sum      += opts.w_intensity;
207            }
208        }
209    }
210
211    if w_sum <= 0.0 {
212        return None;
213    }
214
215    let final_score = score / w_sum;
216    if !final_score.is_finite() {
217        None
218    } else {
219        let details = XicDetails {
220            s_rt,
221            s_im,
222            s_intensity,
223            r_rt,
224            r_im,
225        };
226        Some((details, final_score.clamp(0.0, 1.0)))
227    }
228}
229
230pub fn assign_ms2_to_best_ms1_by_xic(
231    ms1: &[ClusterResult1D],
232    ms2: &[ClusterResult1D],
233    pairs: &[(usize, usize)],
234    opts: &XicScoreOpts,
235) -> Vec<(usize, usize, f32)> {
236    if ms1.is_empty() || ms2.is_empty() || pairs.is_empty() {
237        return Vec::new();
238    }
239
240    // Precompute RT-integrated intensities (or raw_sum fallback) once per cluster.
241    let ms1_int = precompute_intensities(ms1);
242    let ms2_int = precompute_intensities(ms2);
243
244    // Dense best-hit buffer: one slot per MS2, no hashing.
245    let mut best: Vec<Option<(usize, f32)>> = vec![None; ms2.len()];
246
247    for &(ms2_idx, ms1_idx) in pairs {
248        if ms1_idx >= ms1.len() || ms2_idx >= ms2.len() {
249            continue;
250        }
251
252        // Use the precomputed-intensity version of the scorer.
253        let s = match xic_match_score_precomputed(
254            ms1_idx,
255            ms2_idx,
256            ms1,
257            ms2,
258            &ms1_int,
259            &ms2_int,
260            opts,
261        ) {
262            Some(v) => v,
263            None => continue,
264        };
265
266        // You can drop this if you already enforce the cutoff in xic_match_score_precomputed.
267        if s < opts.min_total_score {
268            continue;
269        }
270
271        match &mut best[ms2_idx] {
272            Some((_best_i, best_s)) if s <= *best_s => {
273                // keep existing winner
274            }
275            slot => {
276                // new best (or first) hit for this MS2
277                *slot = Some((ms1_idx, s));
278            }
279        }
280    }
281
282    // Build result already sorted by ms2_idx.
283    let mut out = Vec::new();
284    out.reserve(pairs.len().min(ms2.len()));
285    for (ms2_idx, maybe) in best.into_iter().enumerate() {
286        if let Some((ms1_idx, s)) = maybe {
287            out.push((ms2_idx, ms1_idx, s));
288        }
289    }
290    out
291}
292
/// Jaccard index (intersection over union) of two closed time intervals given
/// in **absolute seconds**.
///
/// Yields `0.0` when any bound is non-finite, when the intervals are disjoint,
/// or when the union has zero length.
#[inline]
pub fn jaccard_time(a_lo: f64, a_hi: f64, b_lo: f64, b_hi: f64) -> f32 {
    let all_finite = [a_lo, a_hi, b_lo, b_hi].iter().all(|v| v.is_finite());
    if !all_finite || a_hi < b_lo || b_hi < a_lo {
        return 0.0;
    }
    let intersection = 0.0f64.max(a_hi.min(b_hi) - a_lo.max(b_lo));
    let span = 0.0f64.max(a_hi.max(b_hi) - a_lo.min(b_lo));
    if span > 0.0 {
        (intersection / span) as f32
    } else {
        0.0
    }
}
310
/// True when the closed float intervals `a` and `b` intersect (touching
/// endpoints count); any non-finite bound yields `false`.
#[inline]
fn overlap_f32(a: (f32, f32), b: (f32, f32)) -> bool {
    let bounds_finite = [a.0, a.1, b.0, b.1].iter().all(|x| x.is_finite());
    bounds_finite && !(a.1 < b.0 || b.1 < a.0)
}
322
/// True when the closed integer intervals `a` and `b` share at least one value.
#[inline]
fn overlap_u32(a: (u32, u32), b: (u32, u32)) -> bool {
    let (a_lo, a_hi) = a;
    let (b_lo, b_hi) = b;
    !(a_hi < b_lo || b_hi < a_lo)
}
327
/// Coarse RT bucketing over absolute time in seconds (closed intervals).
///
/// Each indexed interval is registered in every bucket it touches, so
/// [`RtBuckets::gather`] returns a superset of the intervals overlapping a
/// query window; callers apply an exact overlap test afterwards.
#[derive(Clone, Debug)]
pub struct RtBuckets {
    /// Left edge of bucket 0, in seconds.
    lo: f64,
    /// Reciprocal bucket width (1 / seconds).
    inv_bw: f64,
    /// bucket -> indices of the intervals touching that bucket.
    buckets: Vec<Vec<usize>>,
}

impl RtBuckets {
    /// Requested bucket widths below this many seconds are widened to it.
    const MIN_BUCKET_WIDTH: f64 = 0.5;

    /// Build buckets covering `[global_lo, global_hi]`.
    ///
    /// * `bucket_width` – seconds per bucket (at least 0.5 s).
    /// * `ms1_time_bounds` – per-interval `(t_lo, t_hi)` in seconds.
    /// * `ms1_keep` – optional mask; intervals marked `false` are not indexed.
    ///
    /// Intervals with non-finite bounds or `t_hi <= t_lo` are skipped. When the
    /// global bounds are not finite, a single bucket is used. `gather` then
    /// returns every indexed interval, which is slow but still a correct
    /// superset. Without this fallback the bucket count would be unbounded and
    /// the allocation would panic.
    ///
    /// # Panics
    /// If `ms1_keep` is shorter than `ms1_time_bounds`.
    pub fn build(
        global_lo: f64,
        global_hi: f64,
        bucket_width: f64,
        ms1_time_bounds: &[(f64, f64)],
        ms1_keep: Option<&[bool]>,
    ) -> Self {
        let bw = bucket_width.max(Self::MIN_BUCKET_WIDTH);
        let (lo, n) = if global_lo.is_finite() && global_hi.is_finite() {
            let lo = global_lo.floor();
            let hi = global_hi.ceil().max(lo + bw);
            (lo, (((hi - lo) / bw).ceil() as usize).max(1))
        } else {
            (0.0, 1)
        };
        let inv_bw = 1.0 / bw;
        let mut buckets = vec![Vec::<usize>::new(); n];

        for (i, &(t0, t1)) in ms1_time_bounds.iter().enumerate() {
            if ms1_keep.map_or(false, |keep| !keep[i]) {
                continue;
            }
            if !(t0.is_finite() && t1.is_finite()) || t1 <= t0 {
                continue;
            }
            let b0 = Self::bucket_index(t0, lo, inv_bw, n);
            let b1 = Self::bucket_index(t1, lo, inv_bw, n);
            for bucket in &mut buckets[b0..=b1] {
                bucket.push(i);
            }
        }
        Self { lo, inv_bw, buckets }
    }

    /// Bucket holding time `x`, clamped into `0..n`; `x <= lo` maps to bucket 0.
    #[inline]
    fn bucket_index(x: f64, lo: f64, inv_bw: f64, n: usize) -> usize {
        if x <= lo {
            0
        } else {
            // `as` casts from float saturate (NaN -> 0), and the clamp keeps
            // the result inside the bucket vector.
            (((x - lo) * inv_bw).floor() as isize).clamp(0, (n as isize) - 1) as usize
        }
    }

    /// Inclusive bucket range `(first, last)` touched by the window [t0, t1].
    #[inline]
    fn range(&self, t0: f64, t1: f64) -> (usize, usize) {
        let n = self.buckets.len();
        let a = Self::bucket_index(t0.min(t1), self.lo, self.inv_bw, n);
        let b = Self::bucket_index(t0.max(t1), self.lo, self.inv_bw, n);
        (a.min(b), a.max(b))
    }

    /// Append the indices of intervals that touch [t0, t1] (not deduped).
    #[inline]
    pub fn gather(&self, t0: f64, t1: f64, out: &mut Vec<usize>) {
        let (b0, b1) = self.range(t0, t1);
        for bucket in &self.buckets[b0..=b1] {
            out.extend_from_slice(bucket);
        }
    }
}
401
402/// A built search index over MS1 with per-group eligibility masks.
403#[derive(Clone, Debug)]
404pub struct PrecursorSearchIndex {
405    ms1_time_bounds: Vec<(f64, f64)>,
406    ms1_keep: Vec<bool>,
407    rt_buckets: RtBuckets,
408    /// group -> mask[i] (true if MS1[i] is eligible for that group by (mz ∩ isolation) AND (scan ∩ ranges))
409    ms1_group_ok: HashMap<u32, Vec<bool>>,
410    /// frame_id -> time (seconds)
411    frame_time: Arc<HashMap<u32, f64>>,
412    /// DIA program index (for tile-level checks).
413    dia_index: Arc<DiaIndex>,
414}
415
416impl PrecursorSearchIndex {
417    /// Build once per dataset / MS1 set.
418    pub fn build(ds: &TimsDatasetDIA, ms1: &[ClusterResult1D], opts: &CandidateOpts) -> Self {
419        let frame_time = Arc::new(ds.dia_index.frame_time.clone());
420        let dia_index = Arc::new(ds.dia_index.clone());
421
422        // 1) Absolute MS1 time bounds
423        let ms1_time_bounds: Vec<(f64, f64)> = ms1
424            .par_iter()
425            .map(|c| {
426                let mut t_lo = f64::INFINITY;
427                let mut t_hi = f64::NEG_INFINITY;
428
429                // Preferred: use stored frame_ids_used if available
430                if !c.frame_ids_used.is_empty() {
431                    for &fid in &c.frame_ids_used {
432                        if let Some(&t) = frame_time.get(&fid) {
433                            if t < t_lo { t_lo = t; }
434                            if t > t_hi { t_hi = t; }
435                        }
436                    }
437                }
438
439                // Fallback: infer from rt_window if frame_ids_used is empty
440                if !t_lo.is_finite() || !t_hi.is_finite() {
441                    let (rt_lo, rt_hi) = c.rt_window;
442                    if rt_hi >= rt_lo {
443                        for fid in rt_lo as u32..=rt_hi as u32 {
444                            if let Some(&t) = frame_time.get(&fid) {
445                                if t < t_lo { t_lo = t; }
446                                if t > t_hi { t_hi = t; }
447                            }
448                        }
449                    }
450                }
451
452                // If still invalid, caller will filter this out via ms1_keep.
453                (t_lo, t_hi)
454            })
455            .collect();
456
457        // 2) Compute adaptive threshold for min_raw_sum (from MS1 raw_sum distribution)
458        let ms1_raw_sums: Vec<f32> = ms1
459            .iter()
460            .filter(|c| c.ms_level == 1)
461            .map(|c| c.raw_sum)
462            .collect();
463        let ms1_noise = robust_noise_mad(&ms1_raw_sums);
464        let effective_min_raw_sum = opts.min_raw_sum.effective(ms1_noise);
465
466        // 3) Keep mask for MS1
467        let ms1_keep: Vec<bool> = ms1
468            .par_iter()
469            .enumerate()
470            .map(|(i, c)| {
471                if c.ms_level != 1 {
472                    return false;
473                }
474                if c.raw_sum < effective_min_raw_sum {
475                    return false;
476                }
477                let (t0, t1) = ms1_time_bounds[i];
478                if !(t0.is_finite() && t1.is_finite()) || t1 <= t0 {
479                    return false;
480                }
481                if let Some(max_rt) = opts.max_ms1_rt_span_sec {
482                    if (t1 - t0) > max_rt {
483                        return false;
484                    }
485                }
486                true
487            })
488            .collect();
489
490        // 3) RT buckets across MS1
491        let (mut rt_min, mut rt_max) = (f64::INFINITY, f64::NEG_INFINITY);
492        for &(a, b) in &ms1_time_bounds {
493            if a.is_finite() {
494                rt_min = rt_min.min(a);
495            }
496            if b.is_finite() {
497                rt_max = rt_max.max(b);
498            }
499        }
500        if !rt_min.is_finite() || !rt_max.is_finite() || rt_max <= rt_min {
501            rt_min = 0.0;
502            rt_max = 1.0;
503        }
504        let rt_buckets = RtBuckets::build(
505            rt_min,
506            rt_max,
507            opts.rt_bucket_width,
508            &ms1_time_bounds,
509            Some(&ms1_keep),
510        );
511
512        // 4) Per-group eligibility masks (mz ∩ isolation AND scans ∩ program), independent of RT.
513        let ms1_group_ok: HashMap<u32, Vec<bool>> = ds
514            .dia_index
515            .group_to_isolation
516            .par_iter()
517            .map(|(&g, mz_rows)| {
518                let scans = ds
519                    .dia_index
520                    .group_to_scan_ranges
521                    .get(&g)
522                    .cloned()
523                    .unwrap_or_default();
524                let mz_rows_f32: Vec<(f32, f32)> =
525                    mz_rows.iter().map(|&(a, b)| (a as f32, b as f32)).collect();
526                let scan_rows_u32: Vec<(u32, u32)> = scans.iter().copied().collect();
527
528                if mz_rows_f32.is_empty() || scan_rows_u32.is_empty() {
529                    return (g, vec![false; ms1.len()]);
530                }
531
532                let mask: Vec<bool> = ms1
533                    .par_iter()
534                    .map(|c| {
535                        if c.ms_level != 1 {
536                            return false;
537                        }
538                        let mz_ok = mz_rows_f32
539                            .iter()
540                            .any(|&w| overlap_f32(c.mz_window.unwrap(), w));
541                        if !mz_ok {
542                            return false;
543                        }
544                        let im_u32 = (c.im_window.0 as u32, c.im_window.1 as u32);
545                        scan_rows_u32.iter().any(|&s| overlap_u32(im_u32, s))
546                    })
547                    .collect();
548
549                (g, mask)
550            })
551            .collect();
552
553        Self {
554            ms1_time_bounds,
555            ms1_keep,
556            rt_buckets,
557            ms1_group_ok,
558            frame_time,
559            dia_index,
560        }
561    }
562
563    /// Enumerate physically plausible MS1–MS2 pairs.
564    ///
565    /// Conditions:
566    ///   - Same window group.
567    ///   - RT overlap (with optional Jaccard threshold).
568    ///   - IM window overlap (min_im_overlap_scans).
569    ///   - Apex deltas in RT/IM within user bounds.
570    ///   - NEW: there exists at least one tile (ProgramSlice) where:
571    ///       * precursor IM window overlaps tile scans AND precursor m/z is in tile isolation
572    ///       * fragment IM window overlaps tile scans (no m/z restriction).
573    pub fn enumerate_pairs(
574        &self,
575        ms1: &[ClusterResult1D],
576        ms2: &[ClusterResult1D],
577        opts: &CandidateOpts,
578    ) -> Vec<(usize, usize)> {
579        // Precompute MS2 absolute time bounds
580        let ms2_time_bounds: Vec<(f64, f64)> = ms2
581            .par_iter()
582            .map(|c| {
583                let mut t_lo = f64::INFINITY;
584                let mut t_hi = f64::NEG_INFINITY;
585
586                if !c.frame_ids_used.is_empty() {
587                    for &fid in &c.frame_ids_used {
588                        if let Some(&t) = self.frame_time.get(&fid) {
589                            if t < t_lo { t_lo = t; }
590                            if t > t_hi { t_hi = t; }
591                        }
592                    }
593                }
594
595                // Fallback: use rt_window if no frame_ids_used
596                if !t_lo.is_finite() || !t_hi.is_finite() {
597                    let (rt_lo, rt_hi) = c.rt_window;
598                    if rt_hi >= rt_lo {
599                        for fid in rt_lo as u32..=rt_hi as u32 {
600                            if let Some(&t) = self.frame_time.get(&fid) {
601                                if t < t_lo { t_lo = t; }
602                                if t > t_hi { t_hi = t; }
603                            }
604                        }
605                    }
606                }
607
608                (t_lo, t_hi)
609            })
610            .collect();
611
612        let ms2_time_bounds = Arc::new(ms2_time_bounds);
613
614        // Compute adaptive threshold for min_raw_sum (from MS2 raw_sum distribution)
615        let ms2_raw_sums: Vec<f32> = ms2
616            .iter()
617            .filter(|c| c.ms_level == 2)
618            .map(|c| c.raw_sum)
619            .collect();
620        let ms2_noise = robust_noise_mad(&ms2_raw_sums);
621        let effective_min_raw_sum = opts.min_raw_sum.effective(ms2_noise);
622
623        // Keep MS2s
624        let ms2_keep: Vec<bool> = ms2
625            .par_iter()
626            .enumerate()
627            .map(|(i, c)| {
628                if c.ms_level != 2 {
629                    return false;
630                }
631                if c.window_group.is_none() {
632                    return false;
633                }
634                if c.raw_sum < effective_min_raw_sum {
635                    return false;
636                }
637                let (mut t0, mut t1) = ms2_time_bounds[i];
638                if t0.is_finite() {
639                    t0 -= opts.ms2_rt_guard_sec;
640                }
641                if t1.is_finite() {
642                    t1 += opts.ms2_rt_guard_sec;
643                }
644                if !(t0.is_finite() && t1.is_finite() && t1 > t0) {
645                    return false;
646                }
647                if let Some(max_rt) = opts.max_ms2_rt_span_sec {
648                    if (t1 - t0) > max_rt {
649                        return false;
650                    }
651                }
652                true
653            })
654            .collect();
655
656        // Group MS2 by window_group
657        let mut by_group: HashMap<u32, Vec<usize>> = HashMap::new();
658        for (j, c2) in ms2.iter().enumerate() {
659            if !ms2_keep[j] {
660                continue;
661            }
662            if let Some(g) = c2.window_group {
663                by_group.entry(g).or_default().push(j);
664            }
665        }
666
667        let idx_arc = Arc::new(self.clone());
668        let ms2_tb = ms2_time_bounds.clone();
669
670        let mut out = by_group
671            .into_par_iter()
672            .flat_map(|(g, js)| {
673                // Eligibility mask for this group
674                let mask_vec: Vec<bool> = idx_arc
675                    .ms1_group_ok
676                    .get(&g)
677                    .cloned()
678                    .unwrap_or_else(|| vec![false; ms1.len()]);
679                let mask_arc = Arc::new(mask_vec);
680
681                let idx = idx_arc.clone();
682                let tb = ms2_tb.clone();
683                let dia_index = idx.dia_index.clone();
684
685                // NEW: program slices for this group, shared across MS2 in this group
686                let slices_vec = dia_index.program_slices_for_group(g);
687                let slices = Arc::new(slices_vec);
688
689                js.into_par_iter().flat_map(move |j| {
690                    let (mut t2_lo, mut t2_hi) = tb[j];
691                    if t2_lo.is_finite() {
692                        t2_lo -= opts.ms2_rt_guard_sec;
693                    }
694                    if t2_hi.is_finite() {
695                        t2_hi += opts.ms2_rt_guard_sec;
696                    }
697
698                    // RT prefilter via buckets
699                    let mut hits = Vec::<usize>::new();
700                    idx.rt_buckets.gather(t2_lo, t2_hi, &mut hits);
701                    hits.sort_unstable();
702                    hits.dedup();
703
704                    let mask = mask_arc.clone();
705                    let slices = slices.clone();
706
707                    let mut local = Vec::<(usize, usize)>::with_capacity(16);
708                    for i in hits {
709                        if !idx.ms1_keep[i] {
710                            continue;
711                        }
712                        if !mask[i] {
713                            continue;
714                        }
715
716                        let (t1_lo, t1_hi) = idx.ms1_time_bounds[i];
717                        if !(t1_lo.is_finite() && t1_hi.is_finite()) {
718                            continue;
719                        }
720
721                        // RT overlap + optional Jaccard
722                        if t1_hi < t2_lo || t2_hi < t1_lo {
723                            continue;
724                        }
725                        if opts.min_rt_jaccard > 0.0 {
726                            let jacc = jaccard_time(t1_lo, t1_hi, t2_lo, t2_hi);
727                            if jacc < opts.min_rt_jaccard {
728                                continue;
729                            }
730                        }
731
732                        let im1 = ms1[i].im_window;
733                        let im2 = ms2[j].im_window;
734
735                        // Basic IM overlap (before tile check)
736                        let im_overlap = {
737                            let lo = im1.0.max(im2.0);
738                            let hi = im1.1.min(im2.1);
739                            hi.saturating_sub(lo).saturating_add(1)
740                        };
741                        if im_overlap < opts.min_im_overlap_scans {
742                            continue;
743                        }
744
745                        // Apex deltas in IM
746                        if let Some(max_d) = opts.max_scan_apex_delta {
747                            let s1 = ms1[i].im_fit.mu;
748                            let s2 = ms2[j].im_fit.mu;
749                            if s1.is_finite() && s2.is_finite() {
750                                let d = (s1 - s2).abs() as f32;
751                                if d > max_d as f32 {
752                                    continue;
753                                }
754                            } else {
755                                continue;
756                            }
757                        }
758
759                        // Apex deltas in RT
760                        if let Some(max_dt) = opts.max_rt_apex_delta_sec {
761                            let r1 = ms1[i].rt_fit.mu;
762                            let r2 = ms2[j].rt_fit.mu;
763                            if r1.is_finite() && r2.is_finite() {
764                                if (r1 - r2).abs() > max_dt {
765                                    continue;
766                                }
767                            } else {
768                                continue;
769                            }
770                        }
771
772                        // ---- Tile-level physical check ----
773                        let prec_mz = match cluster_mz_mu(&ms1[i]) {
774                            Some(m) if m.is_finite() && m > 0.0 => m,
775                            _ => continue,
776                        };
777
778                        // Require the precursor IM apex to lie inside some tile
779                        let prec_im_apex = ms1[i].im_fit.mu;
780                        if !prec_im_apex.is_finite() {
781                            continue;
782                        }
783
784                        // Tiles where this precursor could have been selected in g (apex-based)
785                        let prec_tiles = dia_index.tiles_for_precursor_in_group(
786                            g,
787                            prec_mz,
788                            prec_im_apex,
789                        );
790                        if prec_tiles.is_empty() {
791                            continue;
792                        }
793
794                        // Tiles where this fragment cluster could appear in g (window-based)
795                        let frag_tiles = dia_index.tiles_for_fragment_in_group(g, im2);
796                        if frag_tiles.is_empty() {
797                            continue;
798                        }
799
800                        // At least one shared tile index (physical co-occurrence)
801                        let mut ok = false;
802                        for t in &prec_tiles {
803                            if frag_tiles.contains(t) {
804                                ok = true;
805                                break;
806                            }
807                        }
808                        if !ok {
809                            continue;
810                        }
811
812                        // NEW: reject fragments whose own selection lies in the same isolation tile
813                        // as the precursor (to avoid unfragmented precursor intensity in MS2).
814                        if opts.reject_frag_inside_precursor_tile {
815                            if let Some(frag_mz) = cluster_mz_mu(&ms2[j]) {
816                                if frag_mz.is_finite() && frag_mz > 0.0 {
817                                    let mut reject = false;
818
819                                    // only tiles in intersection prec_tiles ∩ frag_tiles
820                                    for &tile_idx in &prec_tiles {
821                                        if !frag_tiles.contains(&tile_idx) {
822                                            continue;
823                                        }
824                                        if tile_idx >= slices.len() {
825                                            continue;
826                                        }
827                                        let s = &slices[tile_idx]; // ProgramSlice { mz_lo, mz_hi, scan_lo, scan_hi }
828
829                                        if (frag_mz as f64) >= s.mz_lo
830                                            && (frag_mz as f64) <= s.mz_hi
831                                        {
832                                            // fragment looks like it's still inside the
833                                            // precursor's isolation tile -> drop this pair
834                                            reject = true;
835                                            break;
836                                        }
837                                    }
838
839                                    if reject {
840                                        continue;
841                                    }
842                                }
843                            }
844                        }
845
846                        // Survives all guards
847                        local.push((j, i));
848                    }
849
850                    local.into_par_iter()
851                })
852            })
853            .collect::<Vec<(usize, usize)>>();
854
855        out.sort_unstable();
856        out.dedup();
857        out
858    }
859}
860
861/// Convenience: build, enumerate, done.
862pub fn enumerate_ms2_ms1_pairs_simple(
863    ds: &TimsDatasetDIA,
864    ms1: &[ClusterResult1D],
865    ms2: &[ClusterResult1D],
866    opts: &CandidateOpts,
867) -> Vec<(usize, usize)> {
868    let idx = PrecursorSearchIndex::build(ds, ms1, opts);
869    idx.enumerate_pairs(ms1, ms2, opts)
870}
871
872// ---------------------------------------------------------------------------
// Scoring: per-pair features, scalar scores, and best-match selection
874// ---------------------------------------------------------------------------
875
876/// Single scalar score in [0, ∞), larger is better.
877/// Robust to missing fits (uses `shape_neutral` if shape is unavailable).
878#[inline]
879fn score_from_features(f: &PairFeatures, opts: &ScoreOpts) -> f32 {
880    let shape_term = if f.shape_ok { f.s_shape } else { opts.shape_neutral };
881
882    let rt_close = crate::cluster::candidates::exp_decay(f.rt_apex_delta_s, opts.rt_apex_scale_s);
883    let im_close = crate::cluster::candidates::exp_decay(f.im_apex_delta_scans, opts.im_apex_scale_scans);
884
885    let im_ratio = (f.im_overlap_scans as f32) / (f.im_union_scans as f32);
886
887    let ms1_int = crate::cluster::candidates::safe_log1p(f.ms1_raw_sum);
888
889    opts.w_jacc_rt * f.jacc_rt
890        + opts.w_shape * shape_term
891        + opts.w_rt_apex * rt_close
892        + opts.w_im_apex * im_close
893        + opts.w_im_overlap * im_ratio
894        + opts.w_ms1_intensity * ms1_int
895}
896
897/// Score all pairs (ms2_idx, ms1_idx).
898pub fn score_pairs(
899    ms1: &[ClusterResult1D],
900    ms2: &[ClusterResult1D],
901    pairs: &[(usize, usize)],
902    opts: &ScoreOpts,
903) -> Vec<(usize, usize, PairFeatures, f32)> {
904    pairs.par_iter().map(|&(j, i)| {
905        let f = crate::cluster::candidates::build_features(&ms1[i], &ms2[j], opts);
906        let s = score_from_features(&f, opts);
907        (j, i, f, s)
908    }).collect()
909}
910
911/// For each MS2, choose the best MS1 index (by score, then deterministic tie-breaks).
912/// Returns a Vec<Option<usize>> indexed by ms2_idx.
913pub fn best_ms1_for_each_ms2(
914    ms1: &[ClusterResult1D],
915    ms2: &[ClusterResult1D],
916    pairs: &[(usize, usize)],
917    opts: &ScoreOpts,
918) -> Vec<Option<usize>> {
919    let scored = score_pairs(ms1, ms2, pairs, opts);
920
921    // group by ms2_idx
922    let mut by_ms2: Vec<Vec<(usize, PairFeatures, f32)>> = vec![Vec::new(); ms2.len()];
923    for (j, i, f, s) in scored {
924        by_ms2[j].push((i, f, s));
925    }
926
927    by_ms2
928        .into_par_iter()
929        .map(|mut vec_i| {
930            if vec_i.is_empty() { return None; }
931            vec_i.sort_unstable_by(|a, b| {
932                // primary: score desc
933                match b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal) {
934                    Ordering::Equal => {
935                        // tie-breaks (deterministic):
936                        // 1) higher jaccard
937                        let ja = a.1.jacc_rt;
938                        let jb = b.1.jacc_rt;
939                        if (ja - jb).abs() > 1e-6 {
940                            return jb.partial_cmp(&ja).unwrap_or(Ordering::Equal);
941                        }
942                        // 2) if both have shape, higher s_shape
943                        let sa = if a.1.shape_ok { a.1.s_shape } else { 0.0 };
944                        let sb = if b.1.shape_ok { b.1.s_shape } else { 0.0 };
945                        if (sa - sb).abs() > 1e-6 {
946                            return sb.partial_cmp(&sa).unwrap_or(Ordering::Equal);
947                        }
948                        // 3) smaller RT apex delta
949                        let dra = a.1.rt_apex_delta_s;
950                        let drb = b.1.rt_apex_delta_s;
951                        if (dra - drb).abs() > 1e-6 {
952                            return dra.partial_cmp(&drb).unwrap_or(Ordering::Equal);
953                        }
954                        // 4) smaller IM apex delta
955                        let dia = a.1.im_apex_delta_scans;
956                        let dib = b.1.im_apex_delta_scans;
957                        if (dia - dib).abs() > 1e-6 {
958                            return dia.partial_cmp(&dib).unwrap_or(Ordering::Equal);
959                        }
960                        // 5) larger IM overlap
961                        let oa = a.1.im_overlap_scans;
962                        let ob = b.1.im_overlap_scans;
963                        if oa != ob {
964                            return ob.cmp(&oa);
965                        }
966                        // 6) higher MS1 intensity
967                        let ia = a.1.ms1_raw_sum;
968                        let ib = b.1.ms1_raw_sum;
969                        ib.partial_cmp(&ia).unwrap_or(Ordering::Equal)
970                    }
971                    ord => ord,
972                }
973            });
974            Some(vec_i[0].0)
975        })
976        .collect()
977}
978
/// Invert a winner list (`ms2_idx -> best ms1_idx`) into an MS1 → MS2 map.
///
/// The result has exactly `ms1_len` entries. Entry `i` lists, in ascending order,
/// every MS2 index whose winner is `i`. `None` winners, and winners `>= ms1_len`,
/// are ignored.
pub fn ms1_to_ms2_map(
    ms1_len: usize,
    ms2_to_best_ms1: &[Option<usize>],
) -> Vec<Vec<usize>> {
    let mut buckets: Vec<Vec<usize>> = (0..ms1_len).map(|_| Vec::new()).collect();

    let assignments = ms2_to_best_ms1
        .iter()
        .enumerate()
        .filter_map(|(ms2_idx, &winner)| winner.map(|ms1_idx| (ms2_idx, ms1_idx)));

    for (ms2_idx, ms1_idx) in assignments {
        // Out-of-range winners have no bucket and are skipped silently.
        if let Some(bucket) = buckets.get_mut(ms1_idx) {
            bucket.push(ms2_idx);
        }
    }

    buckets
}
995
996fn precompute_intensities(clusters: &[ClusterResult1D]) -> Vec<f32> {
997    clusters
998        .par_iter()
999        .map(|c| {
1000            if let Some(ref rt) = c.rt_trace {
1001                rt.iter().fold(0.0f32, |acc, x| acc + x.max(0.0))
1002            } else {
1003                c.raw_sum.max(0.0)
1004            }
1005        })
1006        .collect()
1007}
1008
1009#[inline]
1010fn top_intensity_member(feat: &SimpleFeature) -> Option<&ClusterResult1D> {
1011    if feat.member_clusters.is_empty() {
1012        return None;
1013    }
1014    feat.member_clusters
1015        .iter()
1016        .max_by(|a, b| a.raw_sum.partial_cmp(&b.raw_sum).unwrap_or(Ordering::Equal))
1017}
1018
1019/// Geometric score for a *single* precursor cluster vs a fragment.
1020///
1021/// Returns (PairFeatures, score).
1022#[inline]
1023pub fn geom_match_score_single_cluster(
1024    ms1: &ClusterResult1D,
1025    ms2: &ClusterResult1D,
1026    opts: &ScoreOpts,
1027) -> (PairFeatures, f32) {
1028    let f = crate::cluster::candidates::build_features(ms1, ms2, opts);
1029    let s = score_from_features(&f, opts);
1030    (f, s)
1031}
1032
1033pub fn xic_match_score_precomputed(
1034    ms1_idx: usize,
1035    ms2_idx: usize,
1036    ms1: &[ClusterResult1D],
1037    ms2: &[ClusterResult1D],
1038    ms1_int: &[f32],
1039    ms2_int: &[f32],
1040    opts: &XicScoreOpts,
1041) -> Option<f32> {
1042    let mut score = 0.0f32;
1043    let mut w_sum = 0.0f32;
1044
1045    let c1 = &ms1[ms1_idx];
1046    let c2 = &ms2[ms2_idx];
1047
1048    // RT XIC
1049    if opts.use_rt {
1050        if let (Some(ref rt1), Some(ref rt2)) = (&c1.rt_trace, &c2.rt_trace) {
1051            if let Some(r) = pearson_corr_z(rt1, rt2) {
1052                if r.is_finite() {
1053                    let s = 0.5 * (r + 1.0);
1054                    score += opts.w_rt * s;
1055                    w_sum += opts.w_rt;
1056                }
1057            }
1058        }
1059    }
1060
1061    // IM XIC
1062    if opts.use_im {
1063        if let (Some(ref im1), Some(ref im2)) = (&c1.im_trace, &c2.im_trace) {
1064            if let Some(r) = pearson_corr_z(im1, im2) {
1065                if r.is_finite() {
1066                    let s = 0.5 * (r + 1.0);
1067                    score += opts.w_im * s;
1068                    w_sum += opts.w_im;
1069                }
1070            }
1071        }
1072    }
1073
1074    // Intensity ratio using precomputed integrals
1075    if opts.use_intensity && opts.w_intensity > 0.0 && opts.intensity_tau > 0.0 {
1076        let i1 = ms1_int[ms1_idx];
1077        let i2 = ms2_int[ms2_idx];
1078
1079        if i1 > 0.0 && i2 > 0.0 {
1080            let ratio = (i2 / i1).max(1e-6);
1081            let d = ratio.ln().abs();
1082            let s = (-d / opts.intensity_tau).exp();
1083            if s.is_finite() {
1084                score += opts.w_intensity * s;
1085                w_sum += opts.w_intensity;
1086            }
1087        }
1088    }
1089
1090    if w_sum <= 0.0 {
1091        return None;
1092    }
1093    let final_score = score / w_sum;
1094    if !final_score.is_finite() {
1095        None
1096    } else {
1097        Some(final_score.clamp(0.0, 1.0))
1098    }
1099}
1100
1101pub fn xic_match_score_precursor(
1102    prec: PrecursorLike<'_>,
1103    frag: &ClusterResult1D,
1104    opts: &XicScoreOpts,
1105) -> Option<(XicDetails, f32)> {
1106    match prec {
1107        PrecursorLike::Cluster(c) => xic_match_score(c, frag, opts),
1108
1109        PrecursorLike::Feature(f) => {
1110            let top = top_intensity_member(f)?;
1111            xic_match_score(top, frag, opts)
1112        }
1113    }
1114}
1115
1116
1117pub fn geom_match_score_precursor(
1118    prec: PrecursorLike<'_>,
1119    frag: &ClusterResult1D,
1120    opts: &ScoreOpts,
1121) -> Option<(PairFeatures, f32)> {
1122    match prec {
1123        PrecursorLike::Cluster(c) => {
1124            let (f, s) = geom_match_score_single_cluster(c, frag, opts);
1125            Some((f, s))
1126        }
1127
1128        PrecursorLike::Feature(feat) => {
1129            let top = top_intensity_member(feat)?;
1130            let (f, s) = geom_match_score_single_cluster(top, frag, opts);
1131            Some((f, s))
1132        }
1133    }
1134}
1135
1136/// Score a single precursor (cluster or feature) against a set of MS2 candidates.
1137///
1138/// - `prec`          : the precursor-like object (ClusterResult1D or SimpleFeature)
1139/// - `ms2`           : full fragment cluster array (e.g. `self.ms2`)
1140/// - `candidate_ids` : indices into `ms2` that passed the physical filters
1141/// - `mode`          : which scoring to use (Geom vs XIC)
1142/// - `geom_opts`     : required for Geom mode
1143/// - `xic_opts`      : required for XIC mode
1144/// - `min_score`     : keep only hits with `score >= min_score`
1145///
1146/// Returns a Vec of (frag_idx, score) sorted by descending score.
1147pub fn query_precursor_scored(
1148    prec: PrecursorLike<'_>,
1149    ms2: &[ClusterResult1D],
1150    candidate_ids: &[usize],
1151    mode: MatchScoreMode,
1152    geom_opts: &ScoreOpts,
1153    xic_opts: &XicScoreOpts,
1154    min_score: f32,
1155) -> Vec<ScoredHit> {
1156    let mut out: Vec<ScoredHit> = Vec::with_capacity(candidate_ids.len());
1157
1158    for &j in candidate_ids {
1159        if j >= ms2.len() {
1160            continue;
1161        }
1162        let frag = &ms2[j];
1163
1164        match mode {
1165            MatchScoreMode::Geom => {
1166                if let Some((f_geom, s)) = geom_match_score_precursor(prec, frag, geom_opts) {
1167                    if s.is_finite() && s >= min_score {
1168                        out.push(ScoredHit {
1169                            frag_idx: j,
1170                            score:   s,
1171                            geom:    Some(f_geom),
1172                            xic:     None,
1173                        });
1174                    }
1175                }
1176            }
1177            MatchScoreMode::Xic => {
1178                if let Some((xic_det, s)) = xic_match_score_precursor(prec, frag, xic_opts) {
1179                    if s.is_finite() && s >= min_score {
1180                        out.push(ScoredHit {
1181                            frag_idx: j,
1182                            score:   s,
1183                            geom:    None,
1184                            xic:     Some(xic_det),
1185                        });
1186                    }
1187                }
1188            }
1189        }
1190    }
1191
1192    out.sort_unstable_by(|a, b| {
1193        b.score
1194            .partial_cmp(&a.score)
1195            .unwrap_or(Ordering::Equal)
1196    });
1197
1198    out
1199}
1200
1201/// Score many precursors in parallel.
1202///
1203/// - `precs`              : slice of precursor-like objects (Cluster or Feature)
1204/// - `ms2`                : full MS2 slice
1205/// - `candidates_per_prec`: for each precursor, the Vec<usize> of candidate MS2 indices
1206///
1207/// Returns a Vec of length `precs.len()`, where each entry is the
1208/// scored hits for that precursor (sorted by descending score).
1209pub fn query_precursors_scored_par(
1210    precs: &[PrecursorLike<'_>],
1211    ms2: &[ClusterResult1D],
1212    candidates_per_prec: &[Vec<usize>],
1213    mode: MatchScoreMode,
1214    geom_opts: &ScoreOpts,
1215    xic_opts: &XicScoreOpts,
1216    min_score: f32,
1217) -> Vec<Vec<ScoredHit>> {
1218    assert_eq!(
1219        precs.len(),
1220        candidates_per_prec.len(),
1221        "precs and candidates_per_prec must have same length"
1222    );
1223
1224    precs
1225        .par_iter()
1226        .zip(candidates_per_prec.par_iter())
1227        .map(|(&prec, cands)| {
1228            query_precursor_scored(
1229                prec,
1230                ms2,
1231                cands,
1232                mode,
1233                geom_opts,
1234                xic_opts,
1235                min_score,
1236            )
1237        })
1238        .collect()
1239}