1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use crate::index::*;
use crate::kernels::take::take_unchecked;
use crate::trusted_len::PushUnchecked;
use crate::utils::CustomIterTools;
use arrow::array::{ArrayRef, ListArray};
use arrow::buffer::Buffer;
/// Computes, for every sublist of `arr`, the position inside `arr.values()` of the
/// element selected by `index`.
///
/// `index` may be negative, in which case it is resolved relative to the end of each
/// sublist by `negative_to_usize`. The entry for a sublist is `None` when the sublist
/// is empty, when a non-negative `index` is at or past the sublist length, or when
/// `negative_to_usize` cannot resolve the index.
///
/// The produced positions are absolute positions in the child values array: each one
/// is based on the start offset of its own sublist, so the result stays correct when
/// the offsets buffer does not begin at zero (as can happen for a sliced list array).
///
/// Returns an empty array when `arr` has no offsets at all.
fn sublist_get_indexes(arr: &ListArray<i64>, index: i64) -> IdxArr {
    let mut iter = arr.offsets().iter();

    if let Some(mut previous) = iter.next().copied() {
        let a: IdxArr = iter
            .map(|&offset| {
                // `start..offset` is the range this sublist occupies in the values array.
                let start = previous;
                let len = offset - start;
                previous = offset;

                // Empty sublists have nothing to pick, and a non-negative index at or
                // past the end is out of bounds: both become a null entry instead of a
                // position that would later be read without bounds checks.
                if len == 0 || index >= len {
                    return None;
                }

                index
                    .negative_to_usize(len as usize)
                    .map(|idx| idx as IdxSize + start as IdxSize)
            })
            .collect_trusted();
        a
    } else {
        IdxArr::from_slice(&[])
    }
}
/// Returns an array holding the element at `index` of every sublist in `arr`.
///
/// The gather positions are computed by `sublist_get_indexes` and then taken from
/// the child values of `arr` with `take_unchecked`.
pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef {
    let positions = sublist_get_indexes(arr, index);
    // SAFETY: `take_unchecked` is an unsafe kernel (presumably it skips bounds checks);
    // this call relies on `sublist_get_indexes` only producing positions that lie
    // inside `arr.values()`.
    unsafe { take_unchecked(&**arr.values(), &positions) }
}
pub fn array_to_unit_list(array: ArrayRef) -> ListArray<i64> {
let len = array.len();
let mut offsets = Vec::with_capacity(len + 1);
unsafe {
offsets.push_unchecked(0i64);
for _ in 0..len {
offsets.push_unchecked(offsets.len() as i64)
}
};
let offsets: Buffer<i64> = offsets.into();
let dtype = ListArray::<i64>::default_datatype(array.data_type().clone());
ListArray::<i64>::from_data(dtype, offsets, array, None)
}
#[cfg(test)]
mod test {
    use super::*;
    use arrow::array::{Int32Array, PrimitiveArray};
    use arrow::buffer::Buffer;
    use arrow::datatypes::DataType;
    use std::sync::Arc;

    /// Builds the fixture `[[1, 2, 3], [4, 5], [6]]` as an `Int32` list array.
    fn get_array() -> ListArray<i64> {
        let dtype = ListArray::<i64>::default_datatype(DataType::Int32);
        let offsets = Buffer::from(vec![0i64, 3, 5, 6]);
        let values = Int32Array::from_slice(&[1, 2, 3, 4, 5, 6]);
        ListArray::<i64>::from_data(dtype, offsets, Arc::new(values), None)
    }

    /// Runs `sublist_get` and copies the resulting `i32` values into a `Vec`.
    fn get_i32(arr: &ListArray<i64>, index: i64) -> Vec<i32> {
        let taken = sublist_get(arr, index);
        let taken = taken
            .as_any()
            .downcast_ref::<PrimitiveArray<i32>>()
            .unwrap();
        taken.values().as_slice().to_vec()
    }

    #[test]
    fn test_sublist_get_indexes() {
        let arr = get_array();

        // First element of every sublist.
        let first = sublist_get_indexes(&arr, 0);
        assert_eq!(first.values().as_slice(), &[0, 3, 5]);

        // Last element of every sublist.
        let last = sublist_get_indexes(&arr, -1);
        assert_eq!(last.values().as_slice(), &[2, 4, 5]);
    }

    #[test]
    fn test_sublist_get() {
        let arr = get_array();

        assert_eq!(get_i32(&arr, 0), vec![1, 4, 6]);
        assert_eq!(get_i32(&arr, -1), vec![3, 5, 6]);
    }
}