@@ -22,9 +22,143 @@ private MatrixMultiplicationVerifier() {
22
22
static int [] multiply (int [][] matrix , int [] vector ) {
23
23
int n = vector .length ;
24
24
int [] result = new int [n ];
25
- for (int i = 0 ; i < n ; i ++)
26
- for (int j = 0 ; j < n ; j ++)
25
+ for (int i = 0 ; i < n ; i ++) {
26
+ for (int j = 0 ; j < n ; j ++) {
27
27
result [i ] += matrix [i ][j ] * vector [j ];
28
+ }
29
+ }
30
+ return result ;
31
+ }
32
+
33
+ /*
34
+ Actual function that performs verification function
35
+ @params, all three input matrices of int type, number of iterations
36
+ */
37
+ public static boolean verify (int [][] matrixA , int [][] matrixB , int [][] matrixC , int iterations ) {
38
+ if (matrixA .length == 0 || matrixB .length == 0 || matrixC .length == 0
39
+ || matrixA [0 ].length == 0 || matrixB [0 ].length == 0 || matrixC [0 ].length == 0 ) {
40
+ return matrixA .length == matrixB [0 ].length
41
+ && matrixB .length == matrixC .length
42
+ && matrixC [0 ].length == matrixA [0 ].length ;
43
+ }
44
+
45
+ if (iterations <= 0 ) {
46
+ throw new IllegalArgumentException ("Number of iterations must be positive" );
47
+ }
48
+
49
+ int n = matrixA .length ;
50
+ if (iterations > 2 * n ) {
51
+ throw new IllegalArgumentException ("Number of iterations should not exceed 2 * n where n is the matrix size" );
52
+ }
53
+
54
+ Random rand = new Random ();
55
+ for (int t = 0 ; t < iterations ; t ++) {
56
+ int [] r = new int [n ];
57
+ for (int i = 0 ; i < n ; i ++) {
58
+ r [i ] = rand .nextInt (2 );
59
+ }
60
+
61
+ int [] matrixBtimesR = multiply (matrixB , r );
62
+ int [] matrixAtimesBtimesR = multiply (matrixA , matrixBtimesR );
63
+ int [] matrixCtimesR = multiply (matrixC , r );
64
+
65
+ for (int i = 0 ; i < n ; i ++) {
66
+ if (matrixAtimesBtimesR [i ] != matrixCtimesR [i ]) {
67
+ return false ;
68
+ }
69
+ }
70
+ }
71
+ return true ;
72
+ }
73
+
74
+ /*
75
+ It multiplies input matrix of double type with randomized vector.
76
+ @params matrix which is being multiplied currently with random vector.
77
+ @params random vector generated for every iteration.
78
+
79
+ This basically calculates dot product for every row, which is used to verify whether the product of matrices is valid or not.
80
+ */
81
+ static double [] multiply (double [][] matrix , double [] vector ) {
82
+ int n = vector .length ;
83
+ double [] result = new double [n ];
84
+ for (int i = 0 ; i < n ; i ++) {
85
+ for (int j = 0 ; j < n ; j ++) {
86
+ result [i ] += matrix [i ][j ] * vector [j ];
87
+ }
88
+ }
89
+ return result ;
90
+ }
91
+
92
+ /*
93
+ Actual function that performs the verification.
94
+ @params, all three input matrices of double type, number of iterations
95
+ */
96
+ public static boolean verify (double [][] matrixA , double [][] matrixB , double [][] matrixC , int iterations ) {
97
+ if (matrixA .length == 0 || matrixB .length == 0 || matrixC .length == 0
98
+ || matrixA [0 ].length == 0 || matrixB [0 ].length == 0 || matrixC [0 ].length == 0 ) {
99
+ return matrixA .length == matrixB [0 ].length
100
+ && matrixB .length == matrixC .length
101
+ && matrixC [0 ].length == matrixA [0 ].length ;
102
+ }
103
+
104
+ if (iterations <= 0 ) {
105
+ throw new IllegalArgumentException ("Number of iterations must be positive" );
106
+ }
107
+
108
+ int m = matrixA .length ;
109
+ if (iterations > 2 * m ) {
110
+ throw new IllegalArgumentException ("Number of iterations should not exceed 2 times m where m is the matrix size" );
111
+ }
112
+
113
+ Random rand = new Random ();
114
+ for (int t = 0 ; t < iterations ; t ++) {
115
+ double [] randomizedVector = new double [m ];
116
+ for (int i = 0 ; i < m ; i ++) {
117
+ randomizedVector [i ] = rand .nextInt (2 );
118
+ }
119
+
120
+ double [] matrixBtimesR = multiply (matrixB , randomizedVector );
121
+ double [] matrixAtimesBtimesR = multiply (matrixA , matrixBtimesR );
122
+ double [] matrixCtimesR = multiply (matrixC , randomizedVector );
123
+
124
+ for (int i = 0 ; i < m ; i ++) {
125
+ if (Math .abs (matrixAtimesBtimesR [i ] - matrixCtimesR [i ]) > 1e-9 ) {
126
+ return false ;
127
+ }
128
+ }
129
+ }
130
+ return true ;
131
+ }
132
+ }
133
+ import java .util .Random ;
134
+
135
+ /*
136
+ This class implements the Randomized Matrix Multiplication Verification.
137
+ It generates a random vector and performs verification using Freivalds' Algorithm.
138
+ @author Menil-dev
139
+ */
140
+ public final class MatrixMultiplicationVerifier {
141
+
142
+ private MatrixMultiplicationVerifier () {
143
+ throw new UnsupportedOperationException ("Utility class" );
144
+ }
145
+
146
+ /*
147
+ It multiplies input matrix with randomized vector.
148
+ @params matrix which is being multiplied currently with random vector
149
+ @params random vector generated for every iteration.
150
+
151
+ This basically calculates dot product for every row, which is used to verify whether the product of matrices is valid or not.
152
+ @returns matrix of calculated dot product.
153
+ */
154
+ static int [] multiply (int [][] matrix , int [] vector ) {
155
+ int n = vector .length ;
156
+ int [] result = new int [n ];
157
+ for (int i = 0 ; i < n ; i ++) {
158
+ for (int j = 0 ; j < n ; j ++) {
159
+ result [i ] += matrix [i ][j ] * vector [j ];
160
+ }
161
+ }
28
162
return result ;
29
163
}
30
164
@@ -79,9 +213,11 @@ public static boolean verify(int[][] matrixA, int[][] matrixB, int[][] matrixC,
79
213
static double [] multiply (double [][] matrix , double [] vector ) {
80
214
int n = vector .length ;
81
215
double [] result = new double [n ];
82
- for (int i = 0 ; i < n ; i ++)
83
- for (int j = 0 ; j < n ; j ++)
216
+ for (int i = 0 ; i < n ; i ++) {
217
+ for (int j = 0 ; j < n ; j ++) {
84
218
result [i ] += matrix [i ][j ] * vector [j ];
219
+ }
220
+ }
85
221
return result ;
86
222
}
87
223
0 commit comments