1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use impl_prelude::*;
use lapack::c::{sgesv, dgesv, cgesv, zgesv};
use super::types::SolveError;
pub trait SolveLinear: Sized + Clone {
fn compute_multi_into<D1, D2>(a: ArrayBase<D1, Ix2>,
b: ArrayBase<D2, Ix2>)
-> Result<ArrayBase<D2, Ix2>, SolveError>
where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>,
D2: DataMut<Elem = Self> + DataOwned<Elem = Self>;
fn compute_into<D1, D2>(a: ArrayBase<D1, Ix2>,
b: ArrayBase<D2, Ix1>)
-> Result<ArrayBase<D2, Ix1>, SolveError>
where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>,
D2: DataMut<Elem = Self> + DataOwned<Elem = Self>
{
let n = b.dim();
let b_mat = match b.into_shape((n, 1)) {
Ok(x) => x,
Err(_) => return Err(SolveError::BadLayout),
};
let res = try!(Self::compute_multi_into(a, b_mat));
Ok(res.into_shape(n).unwrap())
}
fn compute_multi<D1, D2>(a: &ArrayBase<D1, Ix2>,
b: &ArrayBase<D2, Ix2>)
-> Result<Array<Self, Ix2>, SolveError>
where D1: Data<Elem = Self>,
D2: Data<Elem = Self>
{
let a_copy = a.to_owned();
let b_copy = b.to_owned();
Self::compute_multi_into(a_copy, b_copy)
}
fn compute<D1, D2>(a: &ArrayBase<D1, Ix2>,
b: &ArrayBase<D2, Ix1>)
-> Result<Array<Self, Ix1>, SolveError>
where D1: Data<Elem = Self>,
D2: Data<Elem = Self>
{
let a_copy = a.to_owned();
let b_copy = b.to_owned();
Self::compute_into(a_copy, b_copy)
}
}
/// Implements `SolveLinear` for `$impl_type` on top of the LAPACKE `?gesv`
/// driver named by `$driver` (LU factorisation with partial pivoting).
///
/// The generated `compute_multi_into` does the following:
///
/// * Checks the shapes.
/// * Passes the buffers of `a` and `b` to the driver in matching memory
///   layouts.
/// * Returns `b`, into which the driver has written the solution.
///
/// `a`'s buffer is handed to the driver (which overwrites it with the LU
/// factors) and is dropped afterwards.
///
/// Errors produced:
/// * `NotSquare(rows, cols)`: `a` is not square.
/// * `InconsistentDimensions(a_rows, b_rows)`: `a` and `b` have different
///   row counts.
/// * `BadLayout`: `slice_and_layout_mut` could not expose `a` to LAPACK.
/// * `InconsistentLayout`: `b` could not be exposed in `a`'s layout.
/// * `IllegalValue(i)`: the driver returned `info = -i < 0`, meaning
///   argument `i` was rejected.
/// * `Singular(i)`: the driver returned `info = i > 0`. Per LAPACK, this
///   means `U(i,i)` is exactly zero.
macro_rules! impl_solve_linear {
    ($impl_type: ty, $driver: ident) => (
        impl SolveLinear for $impl_type {
            fn compute_multi_into<D1, D2>(mut a: ArrayBase<D1, Ix2>, mut b: ArrayBase<D2, Ix2>)
                                          -> Result<ArrayBase<D2, Ix2>, SolveError>
                where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>,
                      D2: DataMut<Elem = Self> + DataOwned<Elem = Self>
            {
                let (n, a_cols) = a.dim();
                let (b_rows, nrhs) = b.dim();
                if n != a_cols {
                    return Err(SolveError::NotSquare(n, a_cols));
                }
                if n != b_rows {
                    return Err(SolveError::InconsistentDimensions(n, b_rows));
                }
                // A 0x0 system has the empty (0 x nrhs) solution, which is `b`
                // itself. Return it without calling LAPACK, which requires
                // LDA/LDB >= max(1, N) and may reject degenerate strides.
                if n == 0 {
                    return Ok(b);
                }
                let (slice, layout, lda) = slice_and_layout_mut(&mut a)
                    .ok_or(SolveError::BadLayout)?;
                let info = {
                    // Scoped so the mutable borrow of `b` ends before `b` is
                    // returned below.
                    let (b_slice, ldb) = slice_and_layout_matching_mut(&mut b, layout)
                        .ok_or(SolveError::InconsistentLayout)?;
                    // Pivot indices written by the driver; not exposed to callers.
                    let mut pivots = vec![0i32; n];
                    // NOTE(review): `as i32` truncates silently for extents
                    // above i32::MAX; confirm callers never pass such arrays.
                    $driver(layout, n as i32, nrhs as i32,
                            slice, lda as i32,
                            &mut pivots,
                            b_slice, ldb as i32)
                };
                match info {
                    0 => Ok(b),
                    i if i < 0 => Err(SolveError::IllegalValue(-i)),
                    i => Err(SolveError::Singular(i)),
                }
            }
        }
    )
}
// One `SolveLinear` implementation per LAPACK scalar type, each backed by the
// matching `?gesv` driver (s = f32, d = f64, c = c32, z = c64).
impl_solve_linear! { f32, sgesv }
impl_solve_linear! { f64, dgesv }
impl_solve_linear! { c32, cgesv }
impl_solve_linear! { c64, zgesv }