
Rust and Algorithm Fundamentals (6): Merge Sort (Part 2)

QuenTine... About 7 min read · Programming · Rust · Algorithms

Another approach besides forget

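At the end of the previous post we used the forget function to avoid freeing resources twice, but that approach means reaching into the temporary vectors and calling it on the elements one by one. There is another way, which marks the resources as not needing to be freed right from the start: ManuallyDrop.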

The example in the official documentation uses it roughly like this:

use std::mem::ManuallyDrop;
let mut x = ManuallyDrop::new(String::from("Hello World!"));

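That is, create a value and wrap it. We won't do it that way here, though: wrapping each element of the vector in its own ManuallyDrop and pushing it into the temporary vector would be no different from calling forget on each element one by one.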
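For reference, here is a minimal sketch of what the wrapper does, using a hypothetical `Noisy` type that counts how often its destructor runs: a plain value is dropped at the end of its scope, a value wrapped in `ManuallyDrop` is not, and a wrapped value can still be dropped on purpose through the unsafe `ManuallyDrop::drop`.

```rust
use std::cell::Cell;
use std::mem::ManuallyDrop;

// Hypothetical helper: bumps a shared counter every time its destructor runs.
struct Noisy<'a>(&'a Cell<u32>);

impl Drop for Noisy<'_> {
    fn drop(&mut self) {
        self.0.set(self.0.get() + 1);
    }
}

// One plain value and one wrapped value go out of scope; returns the drop count.
fn drops_after_scope() -> u32 {
    let drops = Cell::new(0);
    {
        let _plain = Noisy(&drops); // destructor runs at the end of this block
        let _wrapped = ManuallyDrop::new(Noisy(&drops)); // destructor suppressed
    }
    drops.get()
}

// A wrapped value can still be dropped deliberately, but only through an unsafe call.
fn drops_after_explicit_drop() -> u32 {
    let drops = Cell::new(0);
    let mut wrapped = ManuallyDrop::new(Noisy(&drops));
    // SAFETY: `wrapped` is not used again after being dropped here.
    unsafe { ManuallyDrop::drop(&mut wrapped) };
    drops.get()
}

fn main() {
    assert_eq!(drops_after_scope(), 1);
    assert_eq!(drops_after_explicit_drop(), 1);
    println!("ok");
}
```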

Looking at the official documentation, the very first thing it stresses is that ManuallyDrop is a "zero-cost (0-cost)" abstraction:

... This wrapper is 0-cost. ManuallyDrop<T> is guaranteed to have the same layout and bit validity as T ...

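In other words, as far as the data stored in memory is concerned, a ManuallyDrop<T> is exactly the same as a T; that is what zero-cost means here.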
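The layout guarantee can be spot-checked, though not proven, with `size_of` and `align_of`. The `same_layout` helper below is hypothetical, and equal size and alignment are only part of what the documentation promises: bit validity is not something a runtime check can show.

```rust
use std::mem::{align_of, size_of, ManuallyDrop};

// Hypothetical helper: compares size and alignment of ManuallyDrop<T> and T.
fn same_layout<T>() -> bool {
    size_of::<ManuallyDrop<T>>() == size_of::<T>()
        && align_of::<ManuallyDrop<T>>() == align_of::<T>()
}

fn main() {
    // A primitive, heap-owning types, and a tuple with internal padding.
    assert!(same_layout::<u8>());
    assert!(same_layout::<String>());
    assert!(same_layout::<Vec<Box<i32>>>());
    assert!(same_layout::<(u16, u64)>());
    println!("size and alignment match");
}
```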

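This means that, because the layouts are identical, we can change a pointer's type and treat data of one type "as" data of another type, an operation that is common in C/C++.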
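A small sketch of such a reinterpretation, assuming a `String` wrapped in `ManuallyDrop` (the helper name `len_through_cast` is hypothetical): the pointer is cast from `*const ManuallyDrop<String>` to `*const String` and the value is read through it. In safe code the cast is unnecessary, because `ManuallyDrop<T>` implements `Deref<Target = T>`.

```rust
use std::mem::ManuallyDrop;

// Hypothetical helper: reads a String through a pointer originally typed as
// *const ManuallyDrop<String>, after casting it to *const String.
fn len_through_cast(wrapped: &ManuallyDrop<String>) -> usize {
    let p: *const ManuallyDrop<String> = wrapped; // reference-to-pointer coercion
    let q: *const String = p.cast::<String>(); // same address, different pointee type
    // SAFETY: q points to the live String inside `wrapped`, which outlives this call,
    // and the two types have the same layout.
    unsafe { (*q).len() }
}

fn main() {
    let s = ManuallyDrop::new(String::from("Hello World!"));
    assert_eq!(len_through_cast(&s), 12);
    // Safe code does not need the cast: ManuallyDrop<T> implements Deref<Target = T>.
    assert_eq!(s.len(), 12);
    // Take the String back out so it is dropped normally instead of leaking.
    drop(ManuallyDrop::into_inner(s));
}
```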

In the merge function, change the temporary arrays, which were created like this,

let left_layout = Layout::array::<T>(left_length).expect("left allocation failed!");
let left_mem = alloc(left_layout).cast::<T>();
let mut left = Vec::from_raw_parts(left_mem, left_length, left_length);

let right_layout = Layout::array::<T>(right_length).expect("right allocation failed!");
let right_mem = alloc(right_layout).cast::<T>();
let mut right = Vec::from_raw_parts(right_mem, right_length, right_length);

into Vec<ManuallyDrop<T>>:

let left_layout = Layout::array::<T>(left_length).expect("left allocation failed!");
let left_mem = alloc(left_layout).cast::<ManuallyDrop<T>>();
let mut left = Vec::from_raw_parts(left_mem, left_length, left_length);

let right_layout = Layout::array::<T>(right_length).expect("right allocation failed!");
let right_mem = alloc(right_layout).cast::<ManuallyDrop<T>>();
let mut right = Vec::from_raw_parts(right_mem, right_length, right_length);

The only change is the pointer type that the allocation is cast to: left_mem (and right_mem) is now a *mut ManuallyDrop<T>. Nothing else was touched.

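left_layout is unchanged. Because the wrapper is 0-cost, as long as the length passed in (left_length/right_length) is the same, the total size requested (left_length * size_of::<T>()) is also the same.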
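The claim about the layout can be checked directly, since `Layout` implements `PartialEq`. A short sketch (the helper `layouts_match` is hypothetical):

```rust
use std::alloc::Layout;
use std::mem::{size_of, ManuallyDrop};

// Hypothetical helper: does an array of n ManuallyDrop<T> need the same allocation as n T?
fn layouts_match<T>(n: usize) -> bool {
    let plain = Layout::array::<T>(n).expect("layout overflow");
    let wrapped = Layout::array::<ManuallyDrop<T>>(n).expect("layout overflow");
    plain == wrapped // compares both size and alignment
}

fn main() {
    for n in [1, 2, 7, 1000] {
        assert!(layouts_match::<String>(n));
        assert!(layouts_match::<(u8, u32)>(n));
    }
    // An array layout is n * size_of::<T>() bytes, since size_of::<T>() is
    // always a multiple of T's alignment.
    let layout = Layout::array::<u64>(5).unwrap();
    assert_eq!(layout.size(), 5 * size_of::<u64>());
    assert_eq!(layout.size(), 40);
}
```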

Because the first argument to Vec::from_raw_parts is now of type *mut ManuallyDrop<T>, the left/right vectors are inferred to be of type Vec<ManuallyDrop<T>>.
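The sketch below follows the same steps with a hypothetical `Noisy` type that counts its drops: allocate with `Layout::array::<T>`, cast to `*mut ManuallyDrop<T>`, and let `Vec::from_raw_parts` infer `Vec<ManuallyDrop<T>>`. Dropping that vector frees the buffer but runs no element destructors, while an ordinary `Vec<Noisy>` drops every element. Unlike the article's code, it also checks the pointer returned by `alloc` for null; the `expect` on `Layout::array` only catches a size overflow, not an allocation failure.

```rust
use std::alloc::{alloc, handle_alloc_error, Layout};
use std::cell::Cell;
use std::mem::ManuallyDrop;
use std::ptr;

// Hypothetical helper: bumps a shared counter every time its destructor runs.
struct Noisy<'a>(&'a Cell<u32>);

impl Drop for Noisy<'_> {
    fn drop(&mut self) {
        self.0.set(self.0.get() + 1);
    }
}

// Builds a Vec<ManuallyDrop<Noisy>> over memory from `alloc`, drops it,
// and returns how many element destructors ran.
fn drops_for_manually_drop_vec(n: usize) -> u32 {
    assert!(n > 0, "alloc must not be called with a zero-sized layout");
    let drops = Cell::new(0);
    unsafe {
        // Layout of n T, as in the article; it equals the layout of n ManuallyDrop<T>.
        let layout = Layout::array::<Noisy>(n).expect("layout overflow");
        let raw = alloc(layout);
        if raw.is_null() {
            handle_alloc_error(layout); // alloc reports failure by returning null
        }
        let mem = raw.cast::<ManuallyDrop<Noisy>>();
        for i in 0..n {
            // ptr::write does not read or drop the uninitialized destination.
            ptr::write(mem.add(i), ManuallyDrop::new(Noisy(&drops)));
        }
        // The element type is inferred from `mem`: this is a Vec<ManuallyDrop<Noisy>>.
        let v = Vec::from_raw_parts(mem, n, n);
        drop(v); // frees the buffer, but ManuallyDrop skips each element's destructor
    }
    drops.get()
}

// For comparison: an ordinary Vec<Noisy> drops every element.
fn drops_for_plain_vec(n: usize) -> u32 {
    let drops = Cell::new(0);
    let v: Vec<Noisy> = (0..n).map(|_| Noisy(&drops)).collect();
    drop(v);
    drops.get()
}

fn main() {
    assert_eq!(drops_for_manually_drop_vec(3), 0);
    assert_eq!(drops_for_plain_vec(3), 3);
}
```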

With this change, the compiler reports argument type mismatch errors at the copy calls further down.

Take one of the lines it complains about, for example:

ptr::copy(&left[i], &mut vec[k], 1);

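copy expects raw pointers as its arguments, but we are passing references. A reference coerces implicitly to a raw pointer to the same type (the same conversion the as keyword performs explicitly), so the pointer types were already being converted when the arguments were passed in. The line is therefore equivalent to: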

ptr::copy((&left[i] as *const ManuallyDrop<T>), (&mut vec[k] as *mut T), 1);
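The implicit coercion can be seen in isolation with plain `i32` values (the helper names are hypothetical): passing references to `ptr::copy` and passing them with explicit `as` casts do the same thing. The coercion only turns a reference into a pointer; it never changes the pointee type, which is why a `&ManuallyDrop<T>` becomes a `*const ManuallyDrop<T>` rather than a `*const T`.

```rust
use std::ptr;

// References passed straight to ptr::copy: they coerce to raw pointers implicitly.
fn copy_with_references(src: &i32, dst: &mut i32) {
    // SAFETY: both pointers come from valid references to one i32 each.
    unsafe { ptr::copy(src, dst, 1) };
}

// The same call with the conversions written out using `as`.
fn copy_with_explicit_casts(src: &i32, dst: &mut i32) {
    // SAFETY: as above.
    unsafe { ptr::copy(src as *const i32, dst as *mut i32, 1) };
}

fn main() {
    let a = 42;
    let (mut b, mut c) = (0, 0);
    copy_with_references(&a, &mut b);
    copy_with_explicit_casts(&a, &mut c);
    assert_eq!((b, c), (42, 42));
}
```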

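After our change, the source pointer is a *const ManuallyDrop<T> while the destination is a *mut T, so their element types no longer agree. All we need to do is cast the source pointer explicitly: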

ptr::copy((&left[i] as *const ManuallyDrop<T>).cast::<T>(), &mut vec[k], 1);
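What this cast copy actually does is move ownership bit by bit. A sketch with a `String` (the helper names `move_out` and `demo` are hypothetical): the value is copied out of a `ManuallyDrop<String>` into a destination `String`, and because the source is never dropped, the destination ends up as the sole owner of the heap buffer. The destination here is `String::new()`, which owns no heap memory, so overwriting it without dropping leaks nothing; in `merge`, the overwritten slots of `vec` have already been copied out to the temporary buffers, which plays the same role.

```rust
use std::mem::ManuallyDrop;
use std::ptr;

// Hypothetical helper mirroring the article's
// ptr::copy((&left[i] as *const ManuallyDrop<T>).cast::<T>(), &mut vec[k], 1).
//
// SAFETY (caller): overwriting `*dst` without dropping it must be acceptable, and the
// value in `src` must never be dropped or unwrapped afterwards (ownership moves to `dst`).
unsafe fn move_out(src: &ManuallyDrop<String>, dst: &mut String) {
    let from = (src as *const ManuallyDrop<String>).cast::<String>();
    unsafe { ptr::copy(from, dst, 1) };
}

fn demo() -> String {
    let src = ManuallyDrop::new(String::from("hello"));
    // An empty String owns no heap memory, so overwriting it leaks nothing.
    let mut dst = String::new();
    // SAFETY: `src` is left alone and goes out of scope without being dropped.
    unsafe { move_out(&src, &mut dst) };
    dst // now the sole owner of the "hello" buffer
}

fn main() {
    assert_eq!(demo(), "hello");
}
```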

Make the same kind of conversion everywhere else; the code then compiles and runs correctly:

fn merge<T>(
    vec: &mut Vec<T>,
    compare: fn(prev: &T, next: &T) -> bool,
    p: usize,
    q: usize,
    r: usize
) {
    let left_length = q - p;
    let right_length = r - q;

    unsafe {
        let left_layout = Layout::array::<T>(left_length).expect("left allocation failed!");
        let left_mem = alloc(left_layout).cast::<ManuallyDrop<T>>();
        let mut left = Vec::from_raw_parts(left_mem, left_length, left_length);

        let right_layout = Layout::array::<T>(right_length).expect("right allocation failed!");
        let right_mem = alloc(right_layout).cast::<ManuallyDrop<T>>();
        let mut right = Vec::from_raw_parts(right_mem, right_length, right_length);

        ptr::copy(&vec[p], (&mut left[0] as *mut ManuallyDrop<T>).cast::<T>(), left_length);
        ptr::copy(&vec[q], (&mut right[0] as *mut ManuallyDrop<T>).cast::<T>(), right_length);

        let mut i = 0;
        let mut j = 0;
        let mut k = p;

        while i < left_length && j < right_length {
            if compare(&left[i], &right[j]) {
                ptr::copy((&left[i] as *const ManuallyDrop<T>).cast::<T>(), &mut vec[k], 1);
                i += 1;
            } else {
                ptr::copy((&right[j] as *const ManuallyDrop<T>).cast::<T>(), &mut vec[k], 1);
                j += 1;
            }
            k += 1;
        }

        if i < left_length {
            ptr::copy(
                (&left[i] as *const ManuallyDrop<T>).cast::<T>(),
                &mut vec[k],
                left_length - i
            );
        } else if j < right_length {
            ptr::copy(
                (&right[j] as *const ManuallyDrop<T>).cast::<T>(),
                &mut vec[k],
                right_length - j
            );
        }
    }
}

Now consider these two lines:

ptr::copy(&vec[p], (&mut left[0] as *mut ManuallyDrop<T>).cast::<T>(), left_length);
ptr::copy(&vec[q], (&mut right[0] as *mut ManuallyDrop<T>).cast::<T>(), right_length);

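The second argument is a pointer to the vector's first element, i.e. to the start of the newly allocated memory, which is exactly left_mem/right_mem. So we can substitute those directly: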

ptr::copy(&vec[p], left_mem.cast::<T>(), left_length);
ptr::copy(&vec[q], right_mem.cast::<T>(), right_length);

The compiler then warns that left/right no longer need to be declared mutable, so the mut can be removed:

let left_layout = Layout::array::<T>(left_length).expect("left allocation failed!");
let left_mem = alloc(left_layout).cast::<ManuallyDrop<T>>();
let left = Vec::from_raw_parts(left_mem, left_length, left_length);

let right_layout = Layout::array::<T>(right_length).expect("right allocation failed!");
let right_mem = alloc(right_layout).cast::<ManuallyDrop<T>>();
let right = Vec::from_raw_parts(right_mem, right_length, right_length);

ptr::copy(&vec[p], left_mem.cast::<T>(), left_length);
ptr::copy(&vec[q], right_mem.cast::<T>(), right_length);

We can even move these two copy calls in front of from_raw_parts, and, somewhat surprisingly, it still runs correctly.

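In other words, Vec::from_raw_parts simply takes a region of heap memory that has already been allocated and treats it "as" a vector of the corresponding type.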
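A sketch of this order of operations with `String` elements (the helper `build_vec_after_writing` is hypothetical): the buffer is filled with `ptr::write` first, and `Vec::from_raw_parts` is called only afterwards to take ownership of memory that is already initialized. It returns early for an empty input because `alloc` must not be called with a zero-sized layout.

```rust
use std::alloc::{alloc, handle_alloc_error, Layout};
use std::ptr;

// Hypothetical helper: fills raw memory first, then hands it to Vec::from_raw_parts.
fn build_vec_after_writing(items: &[&str]) -> Vec<String> {
    let n = items.len();
    if n == 0 {
        return Vec::new(); // alloc must not be called with a zero-sized layout
    }
    unsafe {
        let layout = Layout::array::<String>(n).expect("layout overflow");
        let raw = alloc(layout);
        if raw.is_null() {
            handle_alloc_error(layout);
        }
        let mem = raw.cast::<String>();
        for (i, s) in items.iter().enumerate() {
            // Initialize slot i without reading or dropping the uninitialized memory there.
            ptr::write(mem.add(i), s.to_string());
        }
        // All n slots are initialized, and the buffer came from the global allocator
        // with the layout of n Strings, as from_raw_parts requires.
        Vec::from_raw_parts(mem, n, n)
    }
}

fn main() {
    let v = build_vec_after_writing(&["a", "b", "c"]);
    assert_eq!(v, vec!["a", "b", "c"]);
    assert!(build_vec_after_writing(&[]).is_empty());
}
```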

The final code looks like this:

lib.rs
use std::{ alloc::{ alloc, Layout }, mem::ManuallyDrop, ptr };
use algorithms_prelude::Sorter;

pub struct MergeSorter<'a, Seq>(pub &'a mut Seq);

impl<'a, Elem> Sorter for MergeSorter<'a, Vec<Elem>> {
    type Element = Elem;

    fn sort_by(&mut self, compare: fn(prev: &Self::Element, next: &Self::Element) -> bool) {
        let vec = &mut self.0;

        if vec.len() < 2 {
            return;
        }

        merge_sort(vec, compare, 0, vec.len());
    }
}

fn merge_sort<T>(vec: &mut Vec<T>, compare: fn(prev: &T, next: &T) -> bool, p: usize, r: usize) {
    if p < r - 1 {
        let q = (p + 1 + r) >> 1;
        merge_sort(vec, compare, p, q);
        merge_sort(vec, compare, q, r);
        merge(vec, compare, p, q, r);
    }
}

fn merge<T>(
    vec: &mut Vec<T>,
    compare: fn(prev: &T, next: &T) -> bool,
    p: usize,
    q: usize,
    r: usize
) {
    let left_length = q - p;
    let right_length = r - q;

    unsafe {
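        // Because ManuallyDrop is a zero-cost abstraction, ManuallyDrop<T> has exactly the same in-memory layout as T, which is what allows writing through a cast pointer.
        let left_layout = Layout::array::<T>(left_length).expect("left allocation failed!");
        // Allocate memory of the size described by the layout above. alloc returns a raw *mut u8 that assumes nothing about the element type, so the pointer has to be cast by hand.
        let left_mem = alloc(left_layout).cast::<ManuallyDrop<T>>();
        // Copy the left run into the new buffer; the destination pointer is cast to *mut T to match the source.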
        ptr::copy(&vec[p], left_mem.cast::<T>(), left_length);
        let left = Vec::from_raw_parts(left_mem, left_length, left_length);

        let right_layout = Layout::array::<T>(right_length).expect("right allocation failed!");
        let right_mem = alloc(right_layout).cast::<ManuallyDrop<T>>();
        ptr::copy(&vec[q], right_mem.cast::<T>(), right_length);
        let right = Vec::from_raw_parts(right_mem, right_length, right_length);

        let mut i = 0;
        let mut j = 0;
        let mut k = p;

        while i < left_length && j < right_length {
            if compare(&left[i], &right[j]) {
                ptr::copy((&left[i] as *const ManuallyDrop<T>).cast::<T>(), &mut vec[k], 1);
                i += 1;
            } else {
                ptr::copy((&right[j] as *const ManuallyDrop<T>).cast::<T>(), &mut vec[k], 1);
                j += 1;
            }
            k += 1;
        }

        if i < left_length {
            ptr::copy(
                (&left[i] as *const ManuallyDrop<T>).cast::<T>(),
                &mut vec[k],
                left_length - i
            );
        } else if j < right_length {
            ptr::copy(
                (&right[j] as *const ManuallyDrop<T>).cast::<T>(),
                &mut vec[k],
                right_length - j
            );
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    #[test]
    fn it_sort_ascending() {
        let mut v = vec![22, 43, 145, 1, 9];
        MergeSorter(&mut v).sort_by(|prev, next| prev < next);
        assert_eq!(v, vec![1, 9, 22, 43, 145]);
    }

    #[test]
    fn it_sort_descending() {
        let mut v = vec![22, 43, 145, 1, 9];
        MergeSorter(&mut v).sort_by(|prev, next| prev > next);
        assert_eq!(v, vec![145, 43, 22, 9, 1]);
    }

    #[test]
    fn it_struct_sort_ascending() {
        #[derive(Debug, PartialEq)]
        struct Foo {
            id: u32,
            name: &'static str,
        }

        let mut v = vec![
            Foo {
                id: 22,
                name: "ZS",
            },
            Foo {
                id: 43,
                name: "LS",
            },
            Foo {
                id: 145,
                name: "WW",
            },
            Foo {
                id: 1,
                name: "ZL",
            },
            Foo {
                id: 9,
                name: "SQ",
            }
        ];

        MergeSorter(&mut v).sort_by(|prev, next| prev.id < next.id);
        assert_eq!(
            v,
            vec![
                Foo {
                    id: 1,
                    name: "ZL",
                },
                Foo {
                    id: 9,
                    name: "SQ",
                },
                Foo {
                    id: 22,
                    name: "ZS",
                },
                Foo {
                    id: 43,
                    name: "LS",
                },
                Foo {
                    id: 145,
                    name: "WW",
                }
            ]
        );
    }

    #[test]
    fn it_struct_sort_ascending_equal() {
        #[derive(Debug, PartialEq)]
        struct Foo {
            id: u32,
            name: &'static str,
        }

        let mut v = vec![
            Foo {
                id: 22,
                name: "ZS",
            },
            Foo {
                id: 43,
                name: "LS",
            },
            Foo {
                id: 145,
                name: "WW",
            },
            Foo {
                id: 1,
                name: "ZL",
            },
            Foo {
                id: 9,
                name: "SQ",
            },
            Foo {
                id: 43,
                name: "LS2",
            }
        ];

        MergeSorter(&mut v).sort_by(|prev, next| prev.id <= next.id);
        assert_eq!(
            v,
            vec![
                Foo {
                    id: 1,
                    name: "ZL",
                },
                Foo {
                    id: 9,
                    name: "SQ",
                },
                Foo {
                    id: 22,
                    name: "ZS",
                },
                Foo {
                    id: 43,
                    name: "LS",
                },
                Foo {
                    id: 43,
                    name: "LS2",
                },
                Foo {
                    id: 145,
                    name: "WW",
                }
            ]
        );
    }

    #[test]
    fn it_struct_sort_ascending_equal_box() {
        #[derive(Debug, PartialEq)]
        struct Foo {
            id: u32,
            name: &'static str,
        }

        let mut v = vec![
            Box::new(Foo {
                id: 22,
                name: "ZS",
            }),
            Box::new(Foo {
                id: 43,
                name: "LS",
            }),
            Box::new(Foo {
                id: 145,
                name: "WW",
            }),
            Box::new(Foo {
                id: 1,
                name: "ZL",
            }),
            Box::new(Foo {
                id: 9,
                name: "SQ",
            }),
            Box::new(Foo {
                id: 43,
                name: "LS2",
            })
        ];

        MergeSorter(&mut v).sort_by(|prev, next| prev.id <= next.id);
        assert_eq!(
            v,
            vec![
                Box::new(Foo {
                    id: 1,
                    name: "ZL",
                }),
                Box::new(Foo {
                    id: 9,
                    name: "SQ",
                }),
                Box::new(Foo {
                    id: 22,
                    name: "ZS",
                }),
                Box::new(Foo {
                    id: 43,
                    name: "LS",
                }),
                Box::new(Foo {
                    id: 43,
                    name: "LS2",
                }),
                Box::new(Foo {
                    id: 145,
                    name: "WW",
                })
            ]
        );
    }
}
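One caveat about the code above, offered with some hedging: pointers such as `&vec[p]`, `&left[i]` and `&mut vec[k]` are derived from references to a single element. Under Stacked Borrows, the aliasing model Miri checks by default, such a pointer is only valid for that one element, so copying `left_length` (or `left_length - i`) elements through it is likely to be reported as undefined behavior, even though the program runs correctly in practice. The sketch below is a standalone version of the same algorithm (the name `merge_sort_by` is hypothetical, and it leaves out the `algorithms_prelude::Sorter` trait so it compiles on its own). It derives all pointers from `as_ptr()`/`as_mut_ptr()`, uses `ptr::copy_nonoverlapping` because the temporary buffers never overlap `vec`, checks `alloc` for a null result, and returns early for zero-sized types, for which `alloc` must not be called. Like the original, it does not guard against `compare` panicking mid-merge, which could leave duplicated values in `vec`.

```rust
use std::alloc::{alloc, handle_alloc_error, Layout};
use std::mem::{size_of, ManuallyDrop};
use std::ptr;

// Hypothetical standalone entry point (the article uses the `Sorter` trait instead).
pub fn merge_sort_by<T>(vec: &mut Vec<T>, compare: fn(&T, &T) -> bool) {
    // Zero-sized values are indistinguishable, and alloc must not get a zero-sized layout.
    if vec.len() < 2 || size_of::<T>() == 0 {
        return;
    }
    let len = vec.len();
    merge_sort(vec, compare, 0, len);
}

fn merge_sort<T>(vec: &mut Vec<T>, compare: fn(&T, &T) -> bool, p: usize, r: usize) {
    if p < r - 1 {
        let q = (p + 1 + r) >> 1;
        merge_sort(vec, compare, p, q);
        merge_sort(vec, compare, q, r);
        merge(vec, compare, p, q, r);
    }
}

// Bitwise-copies `len` values starting at `src` into a new buffer typed as ManuallyDrop<T>.
// SAFETY (caller): `src` must be valid for reading `len` values, `len > 0`, T not zero-sized.
unsafe fn copy_to_buffer<T>(src: *const T, len: usize) -> Vec<ManuallyDrop<T>> {
    unsafe {
        let layout = Layout::array::<T>(len).expect("layout overflow");
        let mem = alloc(layout).cast::<ManuallyDrop<T>>();
        if mem.is_null() {
            handle_alloc_error(layout);
        }
        ptr::copy_nonoverlapping(src, mem.cast::<T>(), len);
        Vec::from_raw_parts(mem, len, len)
    }
}

fn merge<T>(vec: &mut Vec<T>, compare: fn(&T, &T) -> bool, p: usize, q: usize, r: usize) {
    let left_length = q - p;
    let right_length = r - q;
    unsafe {
        // Pointers derived from the whole buffer may cover every element they copy,
        // unlike a pointer taken from `&vec[p]`, which covers only vec[p].
        let left = copy_to_buffer(vec.as_ptr().add(p), left_length);
        let right = copy_to_buffer(vec.as_ptr().add(q), right_length);
        let dst = vec.as_mut_ptr();

        let (mut i, mut j, mut k) = (0, 0, p);
        while i < left_length && j < right_length {
            // &ManuallyDrop<T> deref-coerces to &T here.
            if compare(&left[i], &right[j]) {
                ptr::copy_nonoverlapping(left.as_ptr().add(i).cast::<T>(), dst.add(k), 1);
                i += 1;
            } else {
                ptr::copy_nonoverlapping(right.as_ptr().add(j).cast::<T>(), dst.add(k), 1);
                j += 1;
            }
            k += 1;
        }
        if i < left_length {
            ptr::copy_nonoverlapping(left.as_ptr().add(i).cast::<T>(), dst.add(k), left_length - i);
        } else if j < right_length {
            ptr::copy_nonoverlapping(right.as_ptr().add(j).cast::<T>(), dst.add(k), right_length - j);
        }
        // `left` and `right` are dropped here: their buffers are freed, but their
        // ManuallyDrop elements are not, since every value now lives in `vec` again.
    }
}

fn main() {
    let mut v = vec![22, 43, 145, 1, 9];
    merge_sort_by(&mut v, |a, b| a <= b);
    assert_eq!(v, vec![1, 9, 22, 43, 145]);

    // Heap-owning elements, where a double free or a leak would actually matter.
    let mut words: Vec<String> = ["pear", "fig", "apple", "kiwi"].iter().map(|s| s.to_string()).collect();
    merge_sort_by(&mut words, |a, b| a.len() <= b.len()); // stable: equal lengths keep their order
    assert_eq!(words, vec!["fig", "pear", "kiwi", "apple"]);
}
```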

Further optimization

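In this merge sort, every merge first copies the two slices being merged into a temporary area and then writes them back: a round trip that, level by level, covers the whole vector. There is clearly a scheme with fewer copies: merge the runs at the bottom of the recursion into a temporary array, keep merging from there, and only write into the original array at the end, so that instead of bouncing back and forth, the data makes a single trip around.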

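The original version costs 2n lg n copies; the optimized version brings this down to (n+1) lg n. How to do that is the subject of the next post.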
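As a rough illustration of the "ping-pong" idea only, and not necessarily how the next post will do it, here is a bottom-up variant for `Copy` elements (the name `ping_pong_sort` is hypothetical, and it is bottom-up rather than recursive). It alternates between the input slice and one scratch buffer, so each pass writes every element exactly once instead of copying it out and back, and it copies into the input at most once at the end. It returns the number of element writes; the initial `to_vec` that creates the scratch buffer is not counted.

```rust
// Hypothetical sketch of a "ping-pong" bottom-up merge sort for Copy elements.
// `le` is a "less than or equal" test; returns the number of element writes.
fn ping_pong_sort<T: Copy>(data: &mut [T], le: fn(&T, &T) -> bool) -> usize {
    let n = data.len();
    if n < 2 {
        return 0;
    }
    let mut buf = data.to_vec(); // scratch buffer; its contents get overwritten
    let mut writes = 0;
    let mut width = 1;
    let mut runs_in_data = true; // which buffer currently holds the sorted runs

    while width < n {
        // Merge pairs of runs from one buffer straight into the other.
        let (src, dst): (&[T], &mut [T]) = if runs_in_data {
            (&*data, &mut buf[..])
        } else {
            (&buf[..], &mut *data)
        };
        let mut start = 0;
        while start < n {
            let mid = (start + width).min(n);
            let end = (start + 2 * width).min(n);
            let (mut i, mut j) = (start, mid);
            for k in start..end {
                // Take from the left run if the right run is empty or not smaller (stable).
                if i < mid && (j >= end || le(&src[i], &src[j])) {
                    dst[k] = src[i];
                    i += 1;
                } else {
                    dst[k] = src[j];
                    j += 1;
                }
                writes += 1;
            }
            start = end;
        }
        runs_in_data = !runs_in_data;
        width *= 2;
    }

    // A final copy is needed only if the last pass ended in the scratch buffer.
    if !runs_in_data {
        data.copy_from_slice(&buf);
        writes += n;
    }
    writes
}

fn main() {
    let mut v = [5, 2, 9, 1, 7, 3, 8, 6];
    let writes = ping_pong_sort(&mut v, |a, b| a <= b);
    assert_eq!(v, [1, 2, 3, 5, 6, 7, 8, 9]);
    assert_eq!(writes, 32); // 3 passes of 8 writes, plus one copy-back of 8
}
```

For those 8 elements that is 3 passes of 8 writes plus a single copy-back of 8, 32 writes in total, whereas copying every merged range out and back costs 2 × 8 per level, 48 over the 3 levels.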

Contributor: qt911025