Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

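- `Rearranging::distribute_unit` method, which re-expresses a scheme using a smaller unit size chosen from a list of candidates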

## [0.1.0] - 2025-02-24

### Changed
Expand Down
98 changes: 98 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,46 @@ impl Rearranging {
Ok(ans)
}

/// Re-expresses this scheme with a unit size picked from `candidates`.
///
/// The first candidate (in iteration order) that is non-zero and evenly divides the
/// current unit size becomes the new unit size. Candidates equal to zero are skipped,
/// since they can never be a valid unit and would make the divisibility test panic.
///
/// Returns `None` when no candidate qualifies. When the chosen value equals the
/// current unit size, an unchanged clone of `self` is returned. Otherwise the result
/// has one more dimension than `self`: the extra trailing dimension has `unit` as its
/// stride in both the destination and the source stride blocks.
pub fn distribute_unit(&self, candidates: impl IntoIterator<Item = usize>) -> Option<Self> {
    let old_unit = self.unit();
    // Reject zero before taking the remainder: `x % 0` panics.
    let unit = candidates
        .into_iter()
        .find(|&n| n != 0 && old_unit % n == 0)?;
    if unit == old_unit {
        return Some(self.clone());
    }

    let ndim = self.ndim();
    // Three header slots, an index block of `ndim + 2` entries, then destination and
    // source stride blocks of `ndim + 1` entries each, i.e. `4 + 3 * (ndim + 1)`.
    let mut layout = vec![0isize; 4 + (ndim + 1) * 3].into_boxed_slice();
    layout[0] = unit as _;
    layout[1] = self.dst_offset();
    layout[2] = self.src_offset();

    // Carve the new buffer into its index / destination / source blocks.
    let (_, tail) = layout.split_at_mut(3);
    let (idx, tail) = tail.split_at_mut(ndim + 2);
    let (dst, src) = tail.split_at_mut(ndim + 1);

    // Same carving for the existing buffer, which is one entry shorter per block.
    let (_, tail) = self.0.split_at(3);
    let (idx_, tail) = tail.split_at(ndim + 1);
    let (dst_, src_) = tail.split_at(ndim);

    // The appended index entry is 1; every existing entry is scaled by the number of
    // new units that fit into one old unit. The zip stops at the shorter old block,
    // so the appended entry is left untouched.
    idx[ndim + 1] = 1;
    let extra = (old_unit / unit) as isize;
    for (new, old) in izip!(idx, idx_) {
        *new = *old * extra;
    }

    /// Copies `old` into the front of `new` and stores `unit` in the last slot.
    fn copy_value(new: &mut [isize], old: &[isize], unit: usize) {
        let [head @ .., tail] = new else {
            // `new` always holds `ndim + 1 >= 1` entries, so it is never empty.
            unreachable!()
        };
        head.copy_from_slice(old);
        *tail = unit as _;
    }
    copy_value(dst, dst_, unit);
    copy_value(src, src_, unit);

    Some(Self(layout))
}

/// 执行方案维数。
#[inline]
pub fn ndim(&self) -> usize {
Expand Down Expand Up @@ -233,3 +273,61 @@ fn test_scheme() {
assert_eq!(scheme.src_strides(), [96, 8, 16]);
assert_eq!(scheme.shape().collect::<Vec<_>>(), [24, 2, 3]);
}

#[test]
fn test_distribute_unit() {
    // Fixture: a three-dimensional rearranging scheme built with unit size 2.
    let shape = [4, 3, 2];
    let dst_strides = [24, 8, 2];
    let src_strides = [48, 8, 16];
    let dst = ArrayLayout::<3>::new(&shape, &dst_strides, 0);
    let src = ArrayLayout::<3>::new(&shape, &src_strides, 0);
    let scheme = Rearranging::new(&dst, &src, 2).unwrap();

    // Case 1: the only candidate equals the current unit, so nothing changes.
    let same = scheme.distribute_unit([2]).unwrap();
    assert_eq!(same.unit(), 2);
    assert_eq!(same.count(), scheme.count());
    assert_eq!(same.idx_strides(), scheme.idx_strides());
    assert_eq!(same.dst_strides(), scheme.dst_strides());
    assert_eq!(same.src_strides(), scheme.src_strides());

    // Case 2: a smaller unit doubles the count and scales the index strides, while
    // the leading destination/source strides stay as they were.
    let finer = scheme.distribute_unit([1]).unwrap();
    assert_eq!(finer.unit(), 1);
    assert_eq!(finer.count(), scheme.count() * 2);

    // Only the dimensions that already existed are compared.
    let kept = scheme.idx_strides().len();

    let idx_halved: Vec<_> = finer
        .idx_strides()
        .iter()
        .take(kept)
        .map(|&stride| stride / 2)
        .collect();
    assert_eq!(idx_halved, scheme.idx_strides());

    let dst_prefix: Vec<_> = finer.dst_strides().iter().take(kept).copied().collect();
    assert_eq!(dst_prefix, scheme.dst_strides());

    let src_prefix: Vec<_> = finer.src_strides().iter().take(kept).copied().collect();
    assert_eq!(src_prefix, scheme.src_strides());

    // Case 3: with several candidates, the first one that divides the unit wins.
    let picked = scheme.distribute_unit([4, 2, 1]).unwrap();
    assert_eq!(picked.unit(), 2);
}