Skip to content

Commit 1767c7f

Browse files
authored
Update MatrixMultiplicationVerifier.java
1 parent 5d375d7 commit 1767c7f

File tree

1 file changed

+140
-4
lines changed

1 file changed

+140
-4
lines changed

src/main/java/com/thealgorithms/randomized/MatrixMultiplicationVerifier.java

Lines changed: 140 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,143 @@ private MatrixMultiplicationVerifier() {
2222
static int[] multiply(int[][] matrix, int[] vector) {
2323
int n = vector.length;
2424
int[] result = new int[n];
25-
for (int i = 0; i < n; i++)
26-
for (int j = 0; j < n; j++)
25+
for (int i = 0; i < n; i++) {
26+
for (int j = 0; j < n; j++) {
2727
result[i] += matrix[i][j] * vector[j];
28+
}
29+
}
30+
return result;
31+
}
32+
33+
/*
34+
Actual function that performs verification function
35+
@params, all three input matrices of int type, number of iterations
36+
*/
37+
public static boolean verify(int[][] matrixA, int[][] matrixB, int[][] matrixC, int iterations) {
38+
if (matrixA.length == 0 || matrixB.length == 0 || matrixC.length == 0
39+
|| matrixA[0].length == 0 || matrixB[0].length == 0 || matrixC[0].length == 0) {
40+
return matrixA.length == matrixB[0].length
41+
&& matrixB.length == matrixC.length
42+
&& matrixC[0].length == matrixA[0].length;
43+
}
44+
45+
if (iterations <= 0) {
46+
throw new IllegalArgumentException("Number of iterations must be positive");
47+
}
48+
49+
int n = matrixA.length;
50+
if (iterations > 2 * n) {
51+
throw new IllegalArgumentException("Number of iterations should not exceed 2 * n where n is the matrix size");
52+
}
53+
54+
Random rand = new Random();
55+
for (int t = 0; t < iterations; t++) {
56+
int[] r = new int[n];
57+
for (int i = 0; i < n; i++) {
58+
r[i] = rand.nextInt(2);
59+
}
60+
61+
int[] matrixBtimesR = multiply(matrixB, r);
62+
int[] matrixAtimesBtimesR = multiply(matrixA, matrixBtimesR);
63+
int[] matrixCtimesR = multiply(matrixC, r);
64+
65+
for (int i = 0; i < n; i++) {
66+
if (matrixAtimesBtimesR[i] != matrixCtimesR[i]) {
67+
return false;
68+
}
69+
}
70+
}
71+
return true;
72+
}
73+
74+
/*
75+
It multiplies input matrix of double type with randomized vector.
76+
@params matrix which is being multiplied currently with random vector.
77+
@params random vector generated for every iteration.
78+
79+
This basically calculates dot product for every row, which is used to verify whether the product of matrices is valid or not.
80+
*/
81+
static double[] multiply(double[][] matrix, double[] vector) {
82+
int n = vector.length;
83+
double[] result = new double[n];
84+
for (int i = 0; i < n; i++) {
85+
for (int j = 0; j < n; j++) {
86+
result[i] += matrix[i][j] * vector[j];
87+
}
88+
}
89+
return result;
90+
}
91+
92+
/*
93+
Actual function that performs the verification.
94+
@params, all three input matrices of double type, number of iterations
95+
*/
96+
public static boolean verify(double[][] matrixA, double[][] matrixB, double[][] matrixC, int iterations) {
97+
if (matrixA.length == 0 || matrixB.length == 0 || matrixC.length == 0
98+
|| matrixA[0].length == 0 || matrixB[0].length == 0 || matrixC[0].length == 0) {
99+
return matrixA.length == matrixB[0].length
100+
&& matrixB.length == matrixC.length
101+
&& matrixC[0].length == matrixA[0].length;
102+
}
103+
104+
if (iterations <= 0) {
105+
throw new IllegalArgumentException("Number of iterations must be positive");
106+
}
107+
108+
int m = matrixA.length;
109+
if (iterations > 2 * m) {
110+
throw new IllegalArgumentException("Number of iterations should not exceed 2 times m where m is the matrix size");
111+
}
112+
113+
Random rand = new Random();
114+
for (int t = 0; t < iterations; t++) {
115+
double[] randomizedVector = new double[m];
116+
for (int i = 0; i < m; i++) {
117+
randomizedVector[i] = rand.nextInt(2);
118+
}
119+
120+
double[] matrixBtimesR = multiply(matrixB, randomizedVector);
121+
double[] matrixAtimesBtimesR = multiply(matrixA, matrixBtimesR);
122+
double[] matrixCtimesR = multiply(matrixC, randomizedVector);
123+
124+
for (int i = 0; i < m; i++) {
125+
if (Math.abs(matrixAtimesBtimesR[i] - matrixCtimesR[i]) > 1e-9) {
126+
return false;
127+
}
128+
}
129+
}
130+
return true;
131+
}
132+
}
133+
import java.util.Random;
134+
135+
/*
136+
This class implements the Randomized Matrix Multiplication Verification.
137+
It generates a random vector and performs verification using Freivalds' Algorithm.
138+
@author Menil-dev
139+
*/
140+
public final class MatrixMultiplicationVerifier {
141+
142+
private MatrixMultiplicationVerifier() {
143+
throw new UnsupportedOperationException("Utility class");
144+
}
145+
146+
/*
147+
It multiplies input matrix with randomized vector.
148+
@params matrix which is being multiplied currently with random vector
149+
@params random vector generated for every iteration.
150+
151+
This basically calculates dot product for every row, which is used to verify whether the product of matrices is valid or not.
152+
@returns matrix of calculated dot product.
153+
*/
154+
static int[] multiply(int[][] matrix, int[] vector) {
155+
int n = vector.length;
156+
int[] result = new int[n];
157+
for (int i = 0; i < n; i++) {
158+
for (int j = 0; j < n; j++) {
159+
result[i] += matrix[i][j] * vector[j];
160+
}
161+
}
28162
return result;
29163
}
30164

@@ -79,9 +213,11 @@ public static boolean verify(int[][] matrixA, int[][] matrixB, int[][] matrixC,
79213
static double[] multiply(double[][] matrix, double[] vector) {
80214
int n = vector.length;
81215
double[] result = new double[n];
82-
for (int i = 0; i < n; i++)
83-
for (int j = 0; j < n; j++)
216+
for (int i = 0; i < n; i++) {
217+
for (int j = 0; j < n; j++) {
84218
result[i] += matrix[i][j] * vector[j];
219+
}
220+
}
85221
return result;
86222
}
87223

0 commit comments

Comments
 (0)