diff --git a/CHANGELOG.md b/CHANGELOG.md
index 588db17..4e5271b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [Unreleased]
+
+### Added
+
+- distribute_unit function
+
 ## [0.1.0] - 2025-02-24
 
 ### Changed
diff --git a/src/lib.rs b/src/lib.rs
index 7e84fb8..d3fe851 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -124,6 +124,61 @@ impl Rearranging {
         Ok(ans)
     }
 
+    /// Re-expresses this scheme with a unit size picked from `candidates`.
+    ///
+    /// The first candidate that evenly divides the current unit is used. A
+    /// candidate of `0` can never divide the unit, so it is skipped rather
+    /// than causing a remainder-by-zero panic. Returns `None` when no
+    /// candidate divides the current unit.
+    ///
+    /// If the chosen unit equals the current one, the scheme is cloned as is.
+    /// Otherwise the result has one extra trailing dimension whose dst/src
+    /// stride is the new unit, and every existing entry of the index section
+    /// is scaled by `old_unit / new_unit`.
+    pub fn distribute_unit(&self, candidates: impl IntoIterator<Item = usize>) -> Option<Self> {
+        let unit = candidates.into_iter().find(|&n| n != 0 && self.unit() % n == 0)?;
+        if unit == self.unit() {
+            return Some(self.clone());
+        }
+
+        let ndim = self.ndim();
+        // New buffer: [unit, dst_offset, src_offset | idx: ndim + 2 | dst: ndim + 1 | src: ndim + 1].
+        let mut layout = vec![0isize; 4 + (ndim + 1) * 3].into_boxed_slice();
+        layout[0] = unit as _;
+        layout[1] = self.dst_offset();
+        layout[2] = self.src_offset();
+
+        let (_, tail) = layout.split_at_mut(3);
+        let (idx, tail) = tail.split_at_mut(ndim + 2);
+        let (dst, src) = tail.split_at_mut(ndim + 1);
+
+        // The existing buffer has the same sections, each one entry shorter.
+        let (_, tail) = self.0.split_at(3);
+        let (idx_, tail) = tail.split_at(ndim + 1);
+        let (dst_, src_) = tail.split_at(ndim);
+
+        // `idx` is one entry longer than `idx_`, so the zip below never reaches
+        // this trailing slot and it keeps the value 1 for the new dimension.
+        idx[ndim + 1] = 1;
+        let extra = (self.unit() / unit) as isize;
+        for (new, old) in izip!(idx, idx_) {
+            *new = *old * extra;
+        }
+
+        /// Copies `old` into the front of `new` and stores `unit` in the last slot.
+        fn copy_value(new: &mut [isize], old: &[isize], unit: usize) {
+            let [head @ .., tail] = new else {
+                unreachable!()
+            };
+            head.copy_from_slice(old);
+            *tail = unit as _;
+        }
+        copy_value(dst, dst_, unit);
+        copy_value(src, src_, unit);
+
+        Some(Self(layout))
+    }
+
     /// 执行方案维数。
     #[inline]
     pub fn ndim(&self) -> usize {
@@ -233,3 +288,68 @@ fn test_scheme() {
     assert_eq!(scheme.src_strides(), [96, 8, 16]);
     assert_eq!(scheme.shape().collect::<Vec<_>>(), [24, 2, 3]);
 }
+
+#[test]
+fn test_distribute_unit() {
+    // Build a rearranging scheme to test against.
+    let shape = [4, 3, 2];
+    let dst = [24, 8, 2];
+    let src = [48, 8, 16];
+    let dst = ArrayLayout::<3>::new(&shape, &dst, 0);
+    let src = ArrayLayout::<3>::new(&shape, &src, 0);
+    let scheme = Rearranging::new(&dst, &src, 2).unwrap();
+
+    // Case 1: the candidate equals the current unit size.
+    let candidates = vec![2];
+    let new_scheme = scheme.distribute_unit(candidates).unwrap();
+    assert_eq!(new_scheme.unit(), 2);
+    assert_eq!(new_scheme.count(), scheme.count());
+    assert_eq!(new_scheme.idx_strides(), scheme.idx_strides());
+    assert_eq!(new_scheme.dst_strides(), scheme.dst_strides());
+    assert_eq!(new_scheme.src_strides(), scheme.src_strides());
+
+    // Case 2: a smaller unit size.
+    let candidates = vec![1];
+    let new_scheme = scheme.distribute_unit(candidates).unwrap();
+    assert_eq!(new_scheme.unit(), 1);
+    assert_eq!(new_scheme.count(), scheme.count() * 2);
+    assert_eq!(
+        new_scheme
+            .idx_strides()
+            .iter()
+            .take(scheme.idx_strides().len())
+            .map(|&x| x / 2)
+            .collect::<Vec<_>>(),
+        scheme.idx_strides()
+    );
+    assert_eq!(
+        new_scheme
+            .dst_strides()
+            .iter()
+            .take(scheme.idx_strides().len())
+            .cloned()
+            .collect::<Vec<_>>(),
+        scheme.dst_strides()
+    );
+    assert_eq!(
+        new_scheme
+            .src_strides()
+            .iter()
+            .take(scheme.idx_strides().len())
+            .cloned()
+            .collect::<Vec<_>>(),
+        scheme.src_strides()
+    );
+
+    // Case 3: several candidates.
+    let candidates = vec![4, 2, 1];
+    let new_scheme = scheme.distribute_unit(candidates).unwrap();
+    assert_eq!(new_scheme.unit(), 2); // the first candidate that divides the unit wins
+
+    // Case 4: a zero candidate is skipped instead of causing a remainder-by-zero panic.
+    let new_scheme = scheme.distribute_unit(vec![0, 1]).unwrap();
+    assert_eq!(new_scheme.unit(), 1);
+
+    // Case 5: no candidate divides the unit.
+    assert!(scheme.distribute_unit(vec![0, 3]).is_none());
+}