Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3729] Add roll reorg operations in CP, python script test #2103

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9b3d5c3
[SYSTEMDS-3729] Add roll reorg operations in CP, python script test
min-guk Sep 6, 2024
12b9073
[SYSTEMDS-3729] Code review feedback: improve the roll function (spar…
min-guk Sep 7, 2024
9a9a0d5
[SYSTEMDS-3729] Add pip install numpy to javaTests.yml
min-guk Sep 7, 2024
ad0254f
[SYSTEMDS-3729] Add pip install numpy to javaTests.yml
min-guk Sep 7, 2024
580b238
[SYSTEMDS-3729] Change javaTest list for quick check in github
min-guk Sep 7, 2024
e32fd2d
[SYSTEMDS-3729] fix javaTests.yml for gitHub actions workflow
min-guk Sep 7, 2024
9263af5
[SYSTEMDS-3729] fix javaTests.yml for gitHub actions workflow
min-guk Sep 7, 2024
28a1490
Merge remote-tracking branch 'origin/main' into main
min-guk Sep 7, 2024
17ed316
[SYSTEMDS-3729] fix testsysds.Dockerfile for gitHub actions workflow
min-guk Sep 7, 2024
3d88c1a
[SYSTEMDS-3729] fix FullRollTest for gitHub actions workflow
min-guk Sep 7, 2024
ece1a79
[SYSTEMDS-3729] fix FullRollTest for gitHub actions workflow
min-guk Sep 7, 2024
acc1f73
[SYSTEMDS-3729] rollback javaTests.yml
min-guk Sep 7, 2024
f96931f
[SYSTEMDS-3729] reformat codes
min-guk Sep 9, 2024
87bb38d
[SYSTEMDS-3729] Roll function optimization, removal of duplicate Java…
min-guk Sep 9, 2024
990e590
[SYSTEMDS-3729] Fix roll function
min-guk Sep 9, 2024
8663c23
[SYSTEMDS-3768] Python test of roll function
min-guk Sep 11, 2024
a419330
[SYSTEMDS-3768] Remove implementation: Python test script for the Jav…
min-guk Sep 11, 2024
b452b49
[SYSTEMDS-3768] Add python test code for sparse matrix
min-guk Sep 11, 2024
ec6b736
[SYSTEMDS-3729] fix precision of sparse matrix test
min-guk Sep 13, 2024
db7d600
[SYSTEMDS-3729] fix roll sparse function
min-guk Sep 14, 2024
5c23bbe
[SYSTEMDS-3729] Add java test for code coverage
min-guk Sep 16, 2024
2b2b74d
[SYSTEMDS-3729] fix java test for code coverage
min-guk Sep 19, 2024
023c591
[SYSTEMDS-3729] delete blank lines
min-guk Sep 19, 2024
db4c085
Merge branch 'main' of https://github.com/apache/systemds into main
min-guk Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ public enum Builtins {
RCM("rowClassMeet", "rcm", false, false, ReturnType.MULTI_RETURN),
REMOVE("remove", false, ReturnType.MULTI_RETURN),
REV("rev", false),
ROLL("roll", false),
ROUND("round", false),
ROW_COUNT_DISTINCT("rowCountDistinct",false),
ROWINDEXMAX("rowIndexMax", false),
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ public boolean isCellOp() {
/** Operations that perform internal reorganization of an allocation */
public enum ReOrgOp {
DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if sizes unknown
RESHAPE, REV, SORT, TRANS;
RESHAPE, REV, ROLL, SORT, TRANS;

@Override
public String toString() {
Expand Down
21 changes: 20 additions & 1 deletion src/main/java/org/apache/sysds/hops/ReorgOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ public void checkArity() {
case REV:
HopsException.check(sz == 1, this, "should have arity 1 for op %s but has arity %d", _op, sz);
break;
case ROLL:
HopsException.check(sz == 2, this, "should have arity 2 for op %s but has arity %d", _op, sz);
break;
case RESHAPE:
case SORT:
HopsException.check(sz == 5, this, "should have arity 5 for op %s but has arity %d", _op, sz);
Expand Down Expand Up @@ -125,6 +128,7 @@ public boolean isGPUEnabled() {
}
case DIAG:
case REV:
case ROLL:
case SORT:
return false;
default:
Expand Down Expand Up @@ -175,6 +179,19 @@ else if( getDim1()==1 && getDim2()==1 )
setLops(transform1);
break;
}
case ROLL:{
Lop[] linputs = new Lop[2]; //input, shift
for (int i = 0; i < 2; i++)
linputs[i] = getInput().get(i).constructLops();

Transform transform1 = new Transform(
linputs, _op, getDataType(), getValueType(), et, 1);

setOutputDimensions(transform1);
setLineNumbers(transform1);
setLops(transform1);
break;
}
case RESHAPE: {
Lop[] linputs = new Lop[5]; //main, rows, cols, dims, byrow
for (int i = 0; i < 5; i++)
Expand Down Expand Up @@ -279,7 +296,8 @@ protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
ret = new MatrixCharacteristics(dc.getCols(), dc.getRows(), -1, dc.getNonZeros());
break;
}
case REV: {
case REV:
case ROLL: {
// dims and nnz are exactly the same as in input
if( dc.dimsKnown() )
ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros());
Expand Down Expand Up @@ -397,6 +415,7 @@ public void refreshSizeInformation()
break;
}
case REV:
case ROLL:
{
// dims and nnz are exactly the same as in input
setDim1(input1.getDim1());
Expand Down
12 changes: 11 additions & 1 deletion src/main/java/org/apache/sysds/lops/Transform.java
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ private String getOpcode() {
case REV:
// Transpose a matrix
return "rev";


case ROLL:
// Transpose a matrix
Baunsgaard marked this conversation as resolved.
Show resolved Hide resolved
return "roll";

case DIAG:
// Transform a vector into a diagonal matrix
return "rdiag";
Expand All @@ -138,6 +142,12 @@ public String getInstructions(String input1, String output) {
return getInstructions(input1, 1, output);
}

@Override
public String getInstructions(String input1, String input2, String output) {
//opcodes: roll
return getInstructions(input1, 2, output);
}

@Override
public String getInstructions(String input1, String input2, String input3, String input4, String output) {
//opcodes: rsort
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,16 @@ else if( getOpCode() == Builtins.RBIND ) {
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;

case ROLL:
checkNumParameters(2);
checkMatrixParam(getFirstExpr());
checkScalarParam(getSecondExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;

case DIAG:
checkNumParameters(1);
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2481,6 +2481,14 @@ else if ( sop.equalsIgnoreCase("!=") )
target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), expr);
break;

case ROLL:
ArrayList<Hop> inputs = new ArrayList<>();
inputs.add(expr);
inputs.add(expr2);
currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX,
target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), inputs);
break;

case CBIND:
case RBIND:
OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.functionobjects;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
* This index function is NOT used for actual sorting but just as a reference
* in ReorgOperator in order to identify sort operations.
*
*/
public class RollIndex extends IndexFunction
{
private static final long serialVersionUID = -8446389232078905200L;

private final int _shift;

public RollIndex(int shift) {
_shift = shift;
}

public int getShift() {
return _shift;
}

@Override
public boolean computeDimension(int row, int col, CellIndex retDim) {
retDim.set(row, col);
return false;
}

@Override
public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) {
out.set(in.getRows(), in.getCols(), in.getBlocksize(), in.getNonZeros());
return false;
}

@Override
public void execute(MatrixIndexes in, MatrixIndexes out) {throw new NotImplementedException();
Baunsgaard marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
public void execute(CellIndex in, CellIndex out) {
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ public class CPInstructionParser extends InstructionParser {
// Reorg Instruction Opcodes (repositioning of existing values)
String2CPInstructionType.put( "r'" , CPType.Reorg);
String2CPInstructionType.put( "rev" , CPType.Reorg);
String2CPInstructionType.put( "roll" , CPType.Reorg);
String2CPInstructionType.put( "rdiag" , CPType.Reorg);
String2CPInstructionType.put( "rshape" , CPType.Reshape);
String2CPInstructionType.put( "rsort" , CPType.Reorg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
Expand All @@ -38,6 +39,7 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
private final CPOperand _col;
private final CPOperand _desc;
private final CPOperand _ixret;
private final CPOperand _shift;

/**
* for opcodes r' and rdiag
Expand Down Expand Up @@ -83,6 +85,31 @@ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand c
_col = col;
_desc = desc;
_ixret = ixret;
_shift = new CPOperand();
}

/**
* for opcode roll
*
* @param op
* operator
* @param in
* cp input operand
* @param shift
* ?
* @param out
* cp output operand
* @param opcode
* the opcode
* @param istr
* ?
*/
Baunsgaard marked this conversation as resolved.
Show resolved Hide resolved
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
super(CPType.Reorg, op, in, out, opcode, istr);
_col = new CPOperand();
_desc = new CPOperand();
_ixret = new CPOperand();
_shift = shift;
Baunsgaard marked this conversation as resolved.
Show resolved Hide resolved
}

public static ReorgCPInstruction parseInstruction ( String str ) {
Expand All @@ -103,6 +130,14 @@ else if ( opcode.equalsIgnoreCase("rev") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
}
else if ( opcode.equalsIgnoreCase("roll") ) {
InstructionUtils.checkNumFields(str, 3);
in.split(parts[1]);
out.split(parts[3]);
CPOperand shift = new CPOperand(parts[2]);
return new ReorgCPInstruction(new ReorgOperator(new RollIndex(0)),
in, out, shift, opcode, str);
}
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
Expand Down Expand Up @@ -136,7 +171,12 @@ public void processInstruction(ExecutionContext ec) {
boolean ixret = ec.getScalarInput(_ixret).getBooleanValue();
r_op = r_op.setFn(new SortIndex(cols, desc, ixret));
}


if (r_op.fn instanceof RollIndex) {
int shift = (int) ec.getScalarInput(_shift).getLongValue();
r_op = r_op.setFn(new RollIndex(shift));
}

//execute operation
MatrixBlock soresBlock = matBlock.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class LineageCacheConfig

// Relatively expensive instructions. Most include shuffles.
private static final String[] PERSIST_OPCODES1 = new String[] {
"cpmm", "rmm", "pmm", "zipmm", "rev", "rshape", "rsort", "-", "*", "+",
"cpmm", "rmm", "pmm", "zipmm", "rev", "roll", "rshape", "rsort", "-", "*", "+",
"/", "%%", "%/%", "1-*", "^", "^2", "*2", "==", "!=", "<", ">",
"<=", ">=", "&&", "||", "xor", "max", "min", "rmempty", "rappend",
"gappend", "galignedappend", "rbind", "cbind", "nmin", "nmax",
Expand Down
Loading
Loading