From eb44ff0c9f6ce1f7e794d4eaa185a57828c98883 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Tue, 2 Sep 2025 22:53:12 +0530 Subject: [PATCH] [SYSTEMDS-3910] OOC matrix-matrix multiplication This patch introduces the MatrixMatrix multiplication logic. It performs a shuffle-based matrix multiplication on two large matrix streams. Implementation Detail: Asynchronous Producer: The processInstruction method launches a background thread to perform the entire two-phase multiplication, but returns control to the main thread immediately. This non-blocking setup allows the compiler to build the downstream execution plan while the OOC operation prepares to run upon data request. Two-Phase Streaming Logic: The background thread implements a shuffle-based algorithm to handle two large inputs: * Phase 1 (Grouping/Shuffle): It first consumes both input streams entirely. Blocks from each stream (A_ik and B_kj) are partitioned into groups based on the output block index (C_ij) they contribute to. A HashMap stores these groups, effectively "shuffling" the data for parallel processing. * Phase 2 (Aggregation/Reduce): After grouping, it processes each group independently. Within a group, it pairs the corresponding blocks using their common index k, performs the block-level multiplication, and aggregates the results to produce a single, final output block which is then enqueued to the output stream. Robust Block Identification: A TaggedMatrixValue wrapper is used during the grouping phase to explicitly tag each block with its source matrix (A or B). This ensures correct and unambiguous identification during the aggregation phase, a critical requirement that cannot be met by relying on block dimensions alone. Integration: The new instruction is fully integrated into the OOC framework: * The OOCInstructionParser is updated to recognize the aggregate binary in OOC context. 
--- .../instructions/OOCInstructionParser.java | 4 +- .../ooc/MatrixMultiplyOOCInstruction.java | 319 ++++++++++++++++++ .../ooc/MatrixVectorBinaryOOCInstruction.java | 155 --------- .../MatrixMatrixBinaryMultiplicationTest.java | 121 +++++++ .../ooc/MatrixMatrixMultiplication.dml | 29 ++ 5 files changed, 471 insertions(+), 157 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixMultiplyOOCInstruction.java delete mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/MatrixMatrixBinaryMultiplicationTest.java create mode 100644 src/test/scripts/functions/ooc/MatrixMatrixMultiplication.dml diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 73b5ca02618..b5cf58a7463 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -25,10 +25,10 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.MatrixMultiplyOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; -import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction; public class OOCInstructionParser extends InstructionParser { @@ -60,7 +60,7 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str return 
BinaryOOCInstruction.parseInstruction(str); case AggregateBinary: case MAPMM: - return MatrixVectorBinaryOOCInstruction.parseInstruction(str); + return MatrixMultiplyOOCInstruction.parseInstruction(str); case Reorg: return TransposeOOCInstruction.parseInstruction(str); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixMultiplyOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixMultiplyOOCInstruction.java new file mode 100644 index 00000000000..e424ed40d90 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixMultiplyOOCInstruction.java @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.util.CommonThreadPool; + +public class MatrixMultiplyOOCInstruction extends ComputationOOCInstruction { + + + protected MatrixMultiplyOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { + super(type, op, in1, in2, out, opcode, istr); + } + + public static MatrixMultiplyOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 4); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); // the larget matrix (streamed) + CPOperand in2 = new 
CPOperand(parts[2]); // the small vector (in-memory) + CPOperand out = new CPOperand(parts[3]); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + + return new MatrixMultiplyOOCInstruction(OOCType.MAPMM, ba, in1, in2, out, opcode, str); + } + + @Override + public void processInstruction( ExecutionContext ec ) { + + if (ec.getMatrixObject(input2).getDataCharacteristics().getCols() == 1) { + _processMatrixVector(ec); + } else { + _processMatrixMatrix(ec); + } + } + + private void _processMatrixVector( ExecutionContext ec ) { + // 1. Identify the inputs + MatrixObject min = ec.getMatrixObject(input1); // big matrix + MatrixBlock vin = ec.getMatrixObject(input2) + .acquireReadAndRelease(); // in-memory vector + + // 2. Pre-partition the in-memory vector into a hashmap + HashMap partitionedVector = new HashMap<>(); + int blksize = vin.getDataCharacteristics().getBlocksize(); + if (blksize < 0) + blksize = ConfigurationManager.getBlocksize(); + for (int i = 0; i < vin.getNumRows(); i += blksize) { + long key = (long) (i / blksize) + 1; // the key starts at 1 + int end_row = Math.min(i + blksize, vin.getNumRows()); + MatrixBlock vectorSlice = vin.slice(i, end_row - 1); + partitionedVector.put(key, vectorSlice); + } + + LocalTaskQueue qIn = min.getStreamHandle(); + LocalTaskQueue qOut = new LocalTaskQueue<>(); + BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); + ec.getMatrixObject(output).setStreamHandle(qOut); + + ExecutorService pool = CommonThreadPool.get(); + try { + // Core logic: background thread + pool.submit(() -> { + IndexedMatrixValue tmp = null; + try { + HashMap partialResults = new HashMap<>(); + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); + long rowIndex = tmp.getIndexes().getRowIndex(); + long colIndex = 
tmp.getIndexes().getColumnIndex(); + MatrixBlock vectorSlice = partitionedVector.get(colIndex); + + // Now, call the operation with the correct, specific operator. + MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( + matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); + + // for single column block, no aggregation neeeded + if (min.getNumColumns() <= min.getBlocksize()) { + qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); + } else { + MatrixBlock currAgg = partialResults.get(rowIndex); + if (currAgg == null) + partialResults.put(rowIndex, partialResult); + else + currAgg.binaryOperationsInPlace(plus, partialResult); + } + } + + // emit aggregated blocks + if (min.getNumColumns() > min.getBlocksize()) { + for (Map.Entry entry : partialResults.entrySet()) { + MatrixIndexes outIndexes = new MatrixIndexes(entry.getKey(), 1L); + qOut.enqueueTask(new IndexedMatrixValue(outIndexes, entry.getValue())); + } + } + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + qOut.closeInput(); + } + }); + } catch (Exception e) { + throw new DMLRuntimeException(e); + } finally { + pool.shutdown(); + } + } + + private void _processMatrixMatrix( ExecutionContext ec ) { + // 1. 
Identify the inputs + MatrixObject min = ec.getMatrixObject(input1); // big matrix + MatrixObject min2 = ec.getMatrixObject(input2); + + LocalTaskQueue qIn1 = min.getStreamHandle(); + LocalTaskQueue qIn2 = min2.getStreamHandle(); + LocalTaskQueue qOut = new LocalTaskQueue<>(); + BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); + ec.getMatrixObject(output).setStreamHandle(qOut); + + // Result matrix rows, cols = rows of A, cols of B + long resultRowBlocks = min.getDataCharacteristics().getNumRowBlocks(); + long resultColBlocks = min2.getDataCharacteristics().getNumColBlocks(); + + ExecutorService pool = CommonThreadPool.get(); + try { + // Core logic: background thread + pool.submit(() -> { + IndexedMatrixValue tmpA = null; + IndexedMatrixValue tmpB = null; + try { + // Phase 1: grouping the output blocks by block Index (The Shuffle) + Map> groupedBlocks = new HashMap<>(); + HashMap partialResults = new HashMap<>(); + + // Process matrix A: each block A(i,k) contributes to C(i,j) for all j + while((tmpA = qIn1.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + long i = tmpA.getIndexes().getRowIndex() - 1; + long k = tmpA.getIndexes().getColumnIndex() - 1; + + for (int j=0; j new ArrayList<>()).add(taggedValue); + } + } + + // Process matrix B: each block B(k,j) contributes to C(i,j) for all i + while((tmpB = qIn2.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + long k = tmpB.getIndexes().getRowIndex() - 1; + long j = tmpB.getIndexes().getColumnIndex() - 1; + + for (int i=0; i new ArrayList<>()).add(taggedValue); + } + } + + + // Phase 2: Multiplication and Aggregation + Map resultBlocks = new HashMap<>(); + + // Process each output block separately + for (Map.Entry> entry : groupedBlocks.entrySet()) { + MatrixIndexes outIndex = entry.getKey(); + List outValues = entry.getValue(); + + // For this output block, collect left and right input blocks + Map leftBlocks = new HashMap<>(); + Map rightBlocks = new HashMap<>(); + + // 
Organize blocks by k-index + for (TaggedMatrixValue taggedValue : outValues) { + IndexedMatrixValue value = taggedValue.getValue(); + long kIndex = taggedValue.getkIndex(); + + if (taggedValue.isFirstInput()) { + leftBlocks.put(kIndex, (MatrixBlock)value.getValue()); + } else { + rightBlocks.put(kIndex, (MatrixBlock)value.getValue()); + } + } + + // Create result block for this (i,j) position + MatrixBlock resultBlock = null; + + // Find k-indices that exist in both left and right + Set commonKIndices = new HashSet<>(leftBlocks.keySet()); + commonKIndices.retainAll(rightBlocks.keySet()); + + // Multiply and aggregate matching blocks + for (Long k : commonKIndices) { + MatrixBlock leftBlock = leftBlocks.get(k); + MatrixBlock rightBlock = rightBlocks.get(k); + + // Multiply matching blocks + MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, + rightBlock, + new MatrixBlock(), + InstructionUtils.getMatMultOperator(1)); + + if (resultBlock == null) { + resultBlock = partialResult; + } else { + resultBlock = resultBlock.binaryOperationsInPlace(plus, partialResult); + } + } + + // Store the final result for this output block + if (resultBlock != null) { + resultBlocks.put(outIndex, resultBlock); + } + } + + // Enqueue all results after all multiplications are complete + for (Map.Entry entry : resultBlocks.entrySet()) { + MatrixIndexes outIdx0 = entry.getKey(); + MatrixBlock outBlock = entry.getValue(); + MatrixIndexes outIdx = new MatrixIndexes(outIdx0.getRowIndex() + 1, + outIdx0.getColumnIndex() + 1); + outBlock.checkSparseRows(); + qOut.enqueueTask(new IndexedMatrixValue(outIdx, outBlock)); + } + + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + finally { + qOut.closeInput(); + } + }); + } catch (Exception e) { + throw new DMLRuntimeException(e); + } + finally { + pool.shutdown(); + } + } + + /** + * Helper class to tag matrix block with their source and k-index + */ + private static class TaggedMatrixValue { + 
IndexedMatrixValue _value; + private long _kIndex; + private boolean _isFirstInput; + + public TaggedMatrixValue(IndexedMatrixValue value, boolean isFirstInput, long kIndex) { + this._value = value; + this._isFirstInput = isFirstInput; + this._kIndex = kIndex; + } + + public IndexedMatrixValue getValue() { + return _value; + } + + public boolean isFirstInput() { + return _isFirstInput; + } + + public long getkIndex() { + return _kIndex; + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java deleted file mode 100644 index 5e2d36d9df3..00000000000 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysds.runtime.instructions.ooc; - -import java.util.HashMap; -import java.util.concurrent.ExecutorService; - -import org.apache.sysds.common.Opcodes; -import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; -import org.apache.sysds.runtime.functionobjects.Multiply; -import org.apache.sysds.runtime.functionobjects.Plus; -import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.instructions.cp.CPOperand; -import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.data.MatrixIndexes; -import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.Operator; -import org.apache.sysds.runtime.util.CommonThreadPool; - -public class MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction { - - - protected MatrixVectorBinaryOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { - super(type, op, in1, in2, out, opcode, istr); - } - - public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - InstructionUtils.checkNumFields(parts, 4); - String opcode = parts[0]; - CPOperand in1 = new CPOperand(parts[1]); // the larget matrix (streamed) - CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory) - CPOperand out = new CPOperand(parts[3]); - - AggregateOperator agg = new 
AggregateOperator(0, Plus.getPlusFnObject()); - AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); - - return new MatrixVectorBinaryOOCInstruction(OOCType.MAPMM, ba, in1, in2, out, opcode, str); - } - - @Override - public void processInstruction( ExecutionContext ec ) { - // 1. Identify the inputs - MatrixObject min = ec.getMatrixObject(input1); // big matrix - MatrixBlock vin = ec.getMatrixObject(input2) - .acquireReadAndRelease(); // in-memory vector - - // 2. Pre-partition the in-memory vector into a hashmap - HashMap partitionedVector = new HashMap<>(); - int blksize = vin.getDataCharacteristics().getBlocksize(); - if (blksize < 0) - blksize = ConfigurationManager.getBlocksize(); - for (int i=0; i qIn = min.getStreamHandle(); - LocalTaskQueue qOut = new LocalTaskQueue<>(); - BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); - ec.getMatrixObject(output).setStreamHandle(qOut); - - ExecutorService pool = CommonThreadPool.get(); - try { - // Core logic: background thread - pool.submit(() -> { - IndexedMatrixValue tmp = null; - try { - HashMap partialResults = new HashMap<>(); - HashMap cnt = new HashMap<>(); - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); - long rowIndex = tmp.getIndexes().getRowIndex(); - long colIndex = tmp.getIndexes().getColumnIndex(); - MatrixBlock vectorSlice = partitionedVector.get(colIndex); - - // Now, call the operation with the correct, specific operator. 
- MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( - matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); - - // for single column block, no aggregation neeeded - if( min.getNumColumns() <= min.getBlocksize() ) { - qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); - } - else { - // aggregation - MatrixBlock currAgg = partialResults.get(rowIndex); - if (currAgg == null) { - partialResults.put(rowIndex, partialResult); - cnt.put(rowIndex, 1); - } - else { - currAgg.binaryOperationsInPlace(plus, partialResult); - int newCnt = cnt.get(rowIndex) + 1; - - if(newCnt == nBlocks){ - // early block output: emit aggregated block - MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); - MatrixBlock result = partialResults.get(rowIndex); - qOut.enqueueTask(new IndexedMatrixValue(idx, result)); - partialResults.remove(rowIndex); - cnt.remove(rowIndex); - } - else { - // maintain aggregation counts if not output-ready yet - cnt.replace(rowIndex, newCnt); - } - } - } - } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - finally { - qOut.closeInput(); - } - }); - } catch (Exception e) { - throw new DMLRuntimeException(e); - } - finally { - pool.shutdown(); - } - } -} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/MatrixMatrixBinaryMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/MatrixMatrixBinaryMultiplicationTest.java new file mode 100644 index 00000000000..1b37dbfa4b5 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/MatrixMatrixBinaryMultiplicationTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.test.functions.binary.matrix.MatrixMultiplicationTest; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class MatrixMatrixBinaryMultiplicationTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "MatrixMatrixMultiplication"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + MatrixMatrixBinaryMultiplicationTest.class.getSimpleName() + "/"; + private final static double 
eps = 1e-10; + private static final String INPUT_NAME = "X"; + private static final String INPUT_NAME2 = "Y"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1000; + private final static int cols_wide = 1000; + private final static int rows2 = 1000; + private final static int cols2 = 1000; + + private final static double sparsity1 = 0.7; + private final static double sparsity2 = 0.1; + private final int k = 1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testMMBinaryMultiplication1() { + runMatrixMatrixMultiplicationTest(cols_wide, false); + } + + private void runMatrixMatrixMultiplicationTest(int cols, boolean sparse ) + { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try + { + getAndLoadTestConfiguration(TEST_NAME1); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[]{"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 10); + double[][] B_data = getRandomMatrix(rows2, cols2, 0, 1, 1.0, 10); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + MatrixBlock B_mb = DataConverter.convertToMatrixBlock(B_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. 
Write matrix A to a binary SequenceFile + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + + // 5. Write vector B to a binary SequenceFile + writer.writeMatrixToHDFS(B_mb, input(INPUT_NAME2), rows2, cols2, 1000, B_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows2, cols2, 1000, B_mb.getNonZeros()), Types.FileFormat.BINARY); + + boolean exceptionExpected = false; + runTest(true, exceptionExpected, null, -1); + + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, cols, 1000, 1000); + MatrixBlock ret2 = LibMatrixMult.matrixMult(A_mb, B_mb, k); + TestUtils.compareMatrices(ret1, ret2, 1e-8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } + +} diff --git a/src/test/scripts/functions/ooc/MatrixMatrixMultiplication.dml b/src/test/scripts/functions/ooc/MatrixMatrixMultiplication.dml new file mode 100644 index 00000000000..89b7e9d1719 --- /dev/null +++ b/src/test/scripts/functions/ooc/MatrixMatrixMultiplication.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read input matrix and operator from command line args +X = read($1); +Y = read($2); + +# Operation under test +res = X %*% Y; + +write(res, $3, format="binary")