算法分析入门系列(二) 分治算法

Strassen矩阵算法

矩阵分割

将$N \times N$的矩阵分割为四个$\frac{N}{2} \times \frac{N}{2}$的子矩阵,按分块方式相乘时我们发现一共需要四次子矩阵加法和八次子矩阵乘法。

而后就可以获得这个算法的递推公式:
$$
T(N) = 8*T(\frac{N}{2}) + Θ(N^2)
$$

源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
/**
 * StrassenMatrix: multiplies square integer matrices with Strassen's divide-and-conquer
 * algorithm, which needs 7 half-size products per recursion level instead of the 8 used by
 * naive block multiplication.
 *
 * <p>A {@link Matrix} created with its default constructor is backed by a 32x32 array, so the
 * order of an operand (after padding to a power of two) must not exceed 32.
 */
public class StrassenMatrix {

    /**
     * Prints the top-left {@code n x n} entries of a matrix, one row per line, each entry
     * followed by a single space.
     *
     * @param matrix matrix to print
     * @param n      number of rows and columns to print
     */
    public void printMatrix(Matrix matrix, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(matrix.matrix[i][j] + " ");
            }
            System.out.println();
        }
    }

    /**
     * Splits the top-left {@code 2n x 2n} part of {@code M} into four {@code n x n} quadrants.
     *
     * @param M   source matrix (order at least {@code 2n})
     * @param M11 receives the top-left quadrant
     * @param M12 receives the top-right quadrant
     * @param M21 receives the bottom-left quadrant
     * @param M22 receives the bottom-right quadrant
     * @param n   order of each quadrant, i.e. half the order of {@code M}
     */
    public void Divide(Matrix M, Matrix M11, Matrix M12, Matrix M21, Matrix M22, int n) {
        // Offsetting the row and/or column index by n selects a quadrant, so the four
        // copies never overlap.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                M11.matrix[i][j] = M.matrix[i][j];
                M12.matrix[i][j] = M.matrix[i][j + n];
                M21.matrix[i][j] = M.matrix[i + n][j];
                M22.matrix[i][j] = M.matrix[i + n][j + n];
            }
        }
    }

    /**
     * Assembles four {@code n x n} quadrants into a new {@code 2n x 2n} matrix (inverse of
     * {@link #Divide}).
     *
     * @param M11 top-left quadrant
     * @param M12 top-right quadrant
     * @param M21 bottom-left quadrant
     * @param M22 bottom-right quadrant
     * @param n   order of each quadrant
     * @return a new matrix whose top-left {@code 2n x 2n} part holds the quadrants
     */
    public Matrix MergeMatrix(Matrix M11, Matrix M12, Matrix M21, Matrix M22, int n) {
        Matrix merged = new Matrix();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                merged.matrix[i][j] = M11.matrix[i][j];
                merged.matrix[i][j + n] = M12.matrix[i][j];
                merged.matrix[i + n][j] = M21.matrix[i][j];
                merged.matrix[i + n][j + n] = M22.matrix[i][j];
            }
        }
        return merged;
    }

    /**
     * Multiplies the top-left 2x2 parts of {@code x} and {@code y} with Strassen's seven
     * scalar products (the recursion base case).
     *
     * @param x left operand
     * @param y right operand
     * @return a new matrix holding {@code x * y} in its top-left 2x2 entries
     */
    public Matrix MatrixMultiplication(Matrix x, Matrix y) {
        int a11 = x.matrix[0][0], a12 = x.matrix[0][1];
        int a21 = x.matrix[1][0], a22 = x.matrix[1][1];
        int b11 = y.matrix[0][0], b12 = y.matrix[0][1];
        int b21 = y.matrix[1][0], b22 = y.matrix[1][1];

        // Keep the factored form: expanding these products would reintroduce extra multiplications.
        int m1 = a11 * (b12 - b22);
        int m2 = (a11 + a12) * b22;
        int m3 = (a21 + a22) * b11;
        int m4 = a22 * (b21 - b11);
        int m5 = (a11 + a22) * (b11 + b22);
        int m6 = (a12 - a22) * (b21 + b22);
        int m7 = (a11 - a21) * (b11 + b12);

        Matrix result = new Matrix();
        result.matrix[0][0] = m5 + m4 - m2 + m6; // C11
        result.matrix[0][1] = m1 + m2;           // C12 (row 0, column 1)
        result.matrix[1][0] = m3 + m4;           // C21 (row 1, column 0)
        result.matrix[1][1] = m5 + m1 - m3 - m7; // C22
        return result;
    }

    /**
     * Multiplies two {@code n x n} matrices with Strassen's algorithm.
     *
     * <p>Orders that are not a power of two are zero-padded up to the next power of two; the
     * product then occupies the top-left {@code n x n} entries of the result and the padding
     * entries are zero.
     *
     * @param x left operand
     * @param y right operand
     * @param n order of both operands; the padded order must fit in the {@link Matrix} storage
     * @return a new matrix holding {@code x * y} in its top-left {@code n x n} entries
     * @throws IllegalArgumentException if {@code n < 1}
     */
    public Matrix MatrixMultiplication(Matrix x, Matrix y, int n) {
        if (n < 1) {
            throw new IllegalArgumentException("matrix order must be positive: " + n);
        }
        if (n == 1) {
            Matrix result = new Matrix();
            result.matrix[0][0] = x.matrix[0][0] * y.matrix[0][0];
            return result;
        }
        if (n == 2) {
            return MatrixMultiplication(x, y);
        }
        if (Integer.bitCount(n) != 1) {
            // Halving only works for powers of two; pad with zeros so no row/column is lost.
            int size = Integer.highestOneBit(n) << 1;
            return MatrixMultiplication(padded(x, n), padded(y, n), size);
        }

        int half = n / 2;
        Matrix a11 = new Matrix(), a12 = new Matrix(), a21 = new Matrix(), a22 = new Matrix();
        Matrix b11 = new Matrix(), b12 = new Matrix(), b21 = new Matrix(), b22 = new Matrix();
        Divide(x, a11, a12, a21, a22, half);
        Divide(y, b11, b12, b21, b22, half);

        Matrix m1 = MatrixMultiplication(a11, MatrixModified(b12, b22, half, false), half);
        Matrix m2 = MatrixMultiplication(MatrixModified(a11, a12, half, true), b22, half);
        Matrix m3 = MatrixMultiplication(MatrixModified(a21, a22, half, true), b11, half);
        Matrix m4 = MatrixMultiplication(a22, MatrixModified(b21, b11, half, false), half);
        Matrix m5 = MatrixMultiplication(MatrixModified(a11, a22, half, true),
                MatrixModified(b11, b22, half, true), half);
        Matrix m6 = MatrixMultiplication(MatrixModified(a12, a22, half, false),
                MatrixModified(b21, b22, half, true), half);
        Matrix m7 = MatrixMultiplication(MatrixModified(a11, a21, half, false),
                MatrixModified(b11, b12, half, true), half);

        // C11 = M5 + M4 - M2 + M6, written as (M5 + M4) - (M2 - M6).
        Matrix c11 = MatrixModified(MatrixModified(m5, m4, half, true),
                MatrixModified(m2, m6, half, false), half, false);
        Matrix c12 = MatrixModified(m1, m2, half, true);
        Matrix c21 = MatrixModified(m3, m4, half, true);
        // C22 = M5 + M1 - M3 - M7, written as (M5 + M1) - (M3 + M7).
        Matrix c22 = MatrixModified(MatrixModified(m5, m1, half, true),
                MatrixModified(m3, m7, half, true), half, false);
        return MergeMatrix(c11, c12, c21, c22, half);
    }

    /**
     * Copies the top-left {@code n x n} entries of {@code source} into a fresh matrix whose
     * remaining entries are zero; the source is not modified.
     */
    private Matrix padded(Matrix source, int n) {
        Matrix copy = new Matrix();
        for (int i = 0; i < n; i++) {
            System.arraycopy(source.matrix[i], 0, copy.matrix[i], 0, n);
        }
        return copy;
    }

    /**
     * Element-wise addition or subtraction of the top-left {@code n x n} parts of two matrices.
     *
     * @param x      left operand
     * @param y      right operand
     * @param n      order of the part to combine
     * @param isPlus {@code true} for {@code x + y}, {@code false} for {@code x - y}
     * @return a new matrix holding the result in its top-left {@code n x n} entries
     */
    public Matrix MatrixModified(Matrix x, Matrix y, int n, Boolean isPlus) {
        Matrix result = new Matrix();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (isPlus) {
                    result.matrix[i][j] = x.matrix[i][j] + y.matrix[i][j];
                } else {
                    result.matrix[i][j] = x.matrix[i][j] - y.matrix[i][j];
                }
            }
        }
        return result;
    }

    /** Demo: multiplies two 8x8 all-ones matrices and prints the 8x8 product. */
    public static void main(String[] args) {
        StrassenMatrix strassenMatrix = new StrassenMatrix();
        int[][] M_X = { { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 },
                { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 },
                { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 } };
        int[][] M_Y = { { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 },
                { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 },
                { 1, 1, 1, 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1 } };
        Matrix x = new Matrix();
        Matrix y = new Matrix();
        x.matrix = M_X;
        y.matrix = M_Y;
        Matrix result = strassenMatrix.MatrixMultiplication(x, y, 8);
        strassenMatrix.printMatrix(result, 8);
    }

}

/**
 * Square integer matrix holder. The default constructor allocates a 32x32 array; callers may
 * replace {@code matrix} with their own array. Only the top-left {@code n x n} entries are
 * meaningful for an order-{@code n} matrix.
 */
class Matrix {
    public int[][] matrix = new int[32][32];

    public Matrix() {
    }
}

实验数据

两个$8 \times 8$的全1矩阵(所有元素均为1,并非单位矩阵)相乘

1
2
3
4
5
6
7
8
9
10
// 8x8 test operand with every entry equal to 1 (an all-ones matrix, not the identity matrix).
int[][] M_X = { 
{ 1, 1, 1, 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1, 1, 1, 1 }
};

实验分析

原始算法中矩阵乘法的时间复杂度为$O(n^3)$,而在Strassen算法中降低到$O(n^{\log_2 7}) \approx O(n^{2.81})$。

因为在普通的分治矩阵乘法中,需要进行8次阶数减半的子矩阵递归相乘,再加上矩阵相加与合并的时间,就会使得简单的矩阵乘法变得极为缓慢,而最主要的时间是用在这8次子矩阵相乘上。
$$
T(N) = 8*T(\frac{N}{2}) + Θ(N^2)
$$
Strassen算法最主要的贡献就是将8次矩阵乘法减少到了7次,使得整个算法的复杂度有所降低。

实验结果

最近点对算法

问题描述

求出平面中所有点对里欧几里得距离最短的点对。

欧几里得距离:
$$
h = \sqrt{(x_1-x_2)^2+(y_1-y_2)^2}
$$

源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * NearestDots: finds the smallest Euclidean distance between any two dots in the plane,
 * either by brute force or by divide and conquer.
 */
public class NearestDots {

    /** Orders dots by ascending x coordinate. */
    private static final Comparator<Dot> BY_X = new Comparator<Dot>() {
        @Override
        public int compare(Dot d1, Dot d2) {
            return Double.compare(d1.getX(), d2.getX());
        }
    };

    /** Orders dots by ascending y coordinate. */
    private static final Comparator<Dot> BY_Y = new Comparator<Dot>() {
        @Override
        public int compare(Dot d1, Dot d2) {
            return Double.compare(d1.getY(), d2.getY());
        }
    };

    /**
     * Generates random dots with x in [0, num + 2) and y in [0, num + 5).
     *
     * @param num number of dots to generate
     * @return a new mutable list of {@code num} dots
     */
    public List<Dot> generateDots(int num) {
        List<Dot> dots = new ArrayList<>();
        for (int i = 0; i < num; i++) {
            dots.add(new Dot(Math.random() * (num + 2), Math.random() * (num + 5)));
        }
        return dots;
    }

    /**
     * Returns the Euclidean distance between two dots.
     *
     * @param dot1 first dot
     * @param dot2 second dot
     * @return sqrt((x1 - x2)^2 + (y1 - y2)^2), computed without intermediate overflow
     */
    public double getDistance(Dot dot1, Dot dot2) {
        return Math.hypot(dot1.getX() - dot2.getX(), dot1.getY() - dot2.getY());
    }

    /**
     * Returns a copy of the first half ({@code isLeft}) or the second half of a list; for an
     * odd size the extra element goes to the second half.
     *
     * @param dots   list to split (expected to be sorted by x by callers)
     * @param isLeft {@code true} for indices {@code [0, size/2)}, otherwise {@code [size/2, size)}
     * @return a new list holding the requested half
     */
    public List<Dot> getDividePart(List<Dot> dots, boolean isLeft) {
        int half = dots.size() / 2;
        List<Dot> part = isLeft ? dots.subList(0, half) : dots.subList(half, dots.size());
        return new ArrayList<>(part);
    }

    /**
     * Computes the minimum pairwise distance by checking every unordered pair.
     *
     * @param dots dots to examine
     * @return the minimum distance, or {@code Double.MAX_VALUE} if there are fewer than two dots
     */
    public double violentResolver(List<Dot> dots) {
        List<Dot> list = new ArrayList<>(dots);
        double minDistance = Double.MAX_VALUE;
        for (int i = 0; i < list.size(); i++) {
            Dot dot1 = list.get(i);
            for (int j = i + 1; j < list.size(); j++) {
                Dot dot2 = list.get(j);
                // The same Dot object listed twice is not a pair of distinct dots.
                if (dot1.equals(dot2)) {
                    continue;
                }
                minDistance = Math.min(minDistance, getDistance(dot1, dot2));
            }
        }
        return minDistance;
    }

    /**
     * Computes the minimum pairwise distance by divide and conquer in O(n log n).
     * The input list is not modified.
     *
     * @param dots dots to examine
     * @return the minimum distance, or {@code Double.MAX_VALUE} if there are fewer than two dots
     */
    public double divideResolver(List<Dot> dots) {
        if (dots == null || dots.size() < 2) {
            return Double.MAX_VALUE;
        }
        List<Dot> points = new ArrayList<>(dots);
        Collections.sort(points, BY_X);
        return closestPair(points, 0, points.size());
    }

    /**
     * Returns the minimum pairwise distance among {@code points[from, to)}. On entry that
     * range must be sorted by x; on return it is sorted by y, so the caller can merge two
     * halves in linear time instead of re-sorting.
     */
    private double closestPair(List<Dot> points, int from, int to) {
        if (to - from <= 3) {
            double best = bruteForce(points, from, to);
            Collections.sort(points.subList(from, to), BY_Y);
            return best;
        }
        int mid = from + (to - from) / 2;
        // Capture the dividing line before the recursive calls reorder the range by y.
        double midX = points.get(mid).getX();
        double best = Math.min(closestPair(points, from, mid), closestPair(points, mid, to));
        mergeByY(points, from, mid, to);

        // Any closer pair must cross the dividing line, and both of its dots lie within
        // `best` of that line.
        List<Dot> strip = new ArrayList<>();
        for (int i = from; i < to; i++) {
            Dot dot = points.get(i);
            if (Math.abs(dot.getX() - midX) < best) {
                strip.add(dot);
            }
        }
        // The strip is sorted by y, so each dot only needs to be compared with the following
        // dots whose y gap is below `best` (at most 7 of them).
        for (int i = 0; i < strip.size(); i++) {
            Dot lower = strip.get(i);
            for (int j = i + 1; j < strip.size() && strip.get(j).getY() - lower.getY() < best; j++) {
                best = Math.min(best, getDistance(lower, strip.get(j)));
            }
        }
        return best;
    }

    /** Checks every pair in {@code points[from, to)}; used for ranges of at most 3 dots. */
    private double bruteForce(List<Dot> points, int from, int to) {
        double best = Double.MAX_VALUE;
        for (int i = from; i < to; i++) {
            for (int j = i + 1; j < to; j++) {
                best = Math.min(best, getDistance(points.get(i), points.get(j)));
            }
        }
        return best;
    }

    /** Merges the y-sorted ranges {@code [from, mid)} and {@code [mid, to)} in place. */
    private void mergeByY(List<Dot> points, int from, int mid, int to) {
        List<Dot> merged = new ArrayList<>(to - from);
        int i = from;
        int j = mid;
        while (i < mid && j < to) {
            if (points.get(i).getY() <= points.get(j).getY()) {
                merged.add(points.get(i++));
            } else {
                merged.add(points.get(j++));
            }
        }
        while (i < mid) {
            merged.add(points.get(i++));
        }
        while (j < to) {
            merged.add(points.get(j++));
        }
        for (int k = 0; k < merged.size(); k++) {
            points.set(from + k, merged.get(k));
        }
    }

    /** Demo: prints the brute-force and divide-and-conquer results for 10 random dots. */
    public static void main(String[] args) {
        NearestDots nearestDots = new NearestDots();
        int num = 10;
        List<Dot> dots = nearestDots.generateDots(num);
        double min = nearestDots.violentResolver(dots);
        double min2 = nearestDots.divideResolver(dots);
        System.out.println(min);
        System.out.println(min2);
    }

    /**
     * A point in the plane with mutable x and y coordinates.
     */
    public class Dot {
        private double x;
        private double y;

        public Dot(double x, double y) {
            this.x = x;
            this.y = y;
        }

        public double getX() {
            return x;
        }

        public void setX(double x) {
            this.x = x;
        }

        public double getY() {
            return y;
        }

        public void setY(double y) {
            this.y = y;
        }
    }
}

实验分析

该算法主要是使用了分治递归的思想,难点在于处理两段分划合并时的情况。

合并的情况

此时左侧递归返回的是左半部分中距离最短的点对,右侧同理。接下来需要判断横跨分割线的点对中是否存在距离更短的点对;如果存在,那么这个点对就是该段合并后所有点中距离最短的点对!

合并中还会出现极端情况,也就是有点在分界线上,所以我们应该人为规定在分界线上的点应该归属为左半边还是右半边。

而后我们可以根据上面这个图得出:设左右两半的最小距离为$d$,只需考虑分界线$x=x_i$两侧宽度各为$d$的带状区域内的点;将这些点按$y$坐标排序后,每个点最多只需与其后的7个点比较距离,因此合并时只需比较这几个点之间的距离就可以了,合并步骤只需线性时间。

实验结果

思考题

  1. 分治法算法设计思想的三个基本步骤是什么?如何证明分治算法的正确性?
  • 问题划分
  • 递归求解
  • 合并子问题的解

使用数学归纳法来证明算法的正确性:对问题规模进行归纳,先证明递归基(最小规模)时算法正确,再证明若各子问题的解都正确,则合并得到的解也正确。

  2. 利用主方法求解 Strassen’s 矩阵乘法和最近点对算法效率的递归分析结果。

Strassen’s算法
$$
T(n) = 7T(\frac{n}{2}) + \Theta(n^2)
$$
根据主方法,$a=7$,$b=2$,$n^{\log_b a}=n^{\log_2 7}$,而$f(n)=\Theta(n^2)=O(n^{\log_2 7-\varepsilon})$(取$\varepsilon=\log_2 7-2>0$),属于主定理的情况一,所以其时间复杂度就是
$$
T(n) = \Theta(n^{\log_27})
$$
**最近点对算法**
$$
T(n)=2*T(\frac{n}{2})+\Theta(n)
$$
根据主方法,$a=2$,$b=2$,$n^{\log_b a}=n$,而$f(n)=\Theta(n)$(合并时带状区域内的点已按$y$有序),属于主定理的情况二,所以其时间复杂度就是
$$
T(n)=\Theta(n\log{n})
$$

  3. 解释怎样修改 Strassen’s 矩阵乘法算法,使得它也可以用于大小不必为 2 的幂的矩阵?

将矩阵用零填充到不小于$n$的最小的2的幂阶,按原算法对半分割递归计算,最后取结果左上角的$n \times n$部分即可;也可以在阶数为奇数时只补一行一列零再对半分割,只要能直接求解最小单元矩阵就能实现。

Author: TankNee
Link: https://www.tanknee.cn/2020/04/15/algorithmanalysis_1-1/
Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.