Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 09c71bf

Browse files
authored
Add Sparse NDArray support for Scala (#15378)
* add Sparse Support
* add imperative invoke sparse support
* add retain method and comments
* add getData method
* add Sparse NDIter test
* remove debug line
1 parent 1ae73de commit 09c71bf

File tree

10 files changed

+543
-28
lines changed

10 files changed

+543
-28
lines changed

scala-package/core/src/main/scala/org/apache/mxnet/DType.scala

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,17 @@ object DType extends Enumeration {
2424
val Float16 = Value(2, "float16")
2525
val UInt8 = Value(3, "uint8")
2626
val Int32 = Value(4, "int32")
27+
val Int8 = Value(5, "int8")
28+
val Int64 = Value(6, "int64")
2729
val Unknown = Value(-1, "unknown")
2830
private[mxnet] def numOfBytes(dtype: DType): Int = {
2931
dtype match {
30-
case DType.UInt8 => 1
32+
case DType.UInt8 | DType.Int8 => 1
3133
case DType.Int32 => 4
3234
case DType.Float16 => 2
3335
case DType.Float32 => 4
34-
case DType.Float64 => 8
36+
case DType.Float64 | DType.Int64 => 8
3537
case DType.Unknown => 0
3638
}
3739
}
38-
private[mxnet] def getType(dtypeStr: String): DType = {
39-
dtypeStr match {
40-
case "UInt8" => DType.UInt8
41-
case "Int32" => DType.Int32
42-
case "Float16" => DType.Float16
43-
case "Float32" => DType.Float32
44-
case "Float64" => DType.Float64
45-
case _ => throw new IllegalArgumentException(
46-
s"DType: $dtypeStr not found! please set it in DType.scala")
47-
}
48-
}
4940
}

scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,14 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
159159
private def getOutputs: Array[NDArray] = {
160160
val ndHandles = ArrayBuffer[NDArrayHandle]()
161161
checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
162-
ndHandles.toArray.map(new NDArray(_, addToCollector = false))
162+
ndHandles.toArray.map(ele => {
163+
val nd = new NDArray(ele, addToCollector = false)
164+
if (nd.isSparse) {
165+
nd.asInstanceOf[SparseNDArray]
166+
}
167+
nd
168+
}
169+
)
163170
}
164171

165172
/**

scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ private[mxnet] class LibInfo {
3131
@native def mxListAllOpNames(names: ListBuffer[String]): Int
3232
@native def nnGetOpHandle(opName: String, opHandle: RefLong): Int
3333
// NDArray
34-
@native def mxImperativeInvoke(creator: FunctionHandle,
34+
@native def mxImperativeInvokeEx(creator: FunctionHandle,
3535
inputs: Array[NDArrayHandle],
3636
outputsGiven: Array[NDArrayHandle],
3737
outputs: ArrayBuffer[NDArrayHandle],
3838
numParams: Int,
3939
paramKeys: Array[String],
40-
paramVals: Array[String]): Int
40+
paramVals: Array[String],
41+
outStype: ArrayBuffer[Int]): Int
4142
@native def mxNDArrayFree(handle: NDArrayHandle): Int
4243
@native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int
4344
@native def mxNDArrayCreateEx(shape: Array[Int],
@@ -47,6 +48,20 @@ private[mxnet] class LibInfo {
4748
delayAlloc: Int,
4849
dtype: Int,
4950
out: NDArrayHandleRef): Int
51+
// scalastyle:off parameterNum
52+
@native def mxNDArrayCreateSparseEx(storageType: Int,
53+
shape: Array[Int],
54+
ndim: Int,
55+
devType: Int,
56+
devId: Int,
57+
delayAlloc: Int,
58+
dtype: Int,
59+
numAux: Int,
60+
auxTypes: Array[Int],
61+
auxNdims: Array[Int],
62+
auxShapes: Array[Int],
63+
out: NDArrayHandleRef): Int
64+
// scalastyle:on parameterNum
5065
@native def mxNDArrayWaitAll(): Int
5166
@native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int
5267
@native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int
@@ -76,6 +91,9 @@ private[mxnet] class LibInfo {
7691
@native def mxNDArrayGetShape(handle: NDArrayHandle,
7792
ndim: MXUintRef,
7893
data: ArrayBuffer[Int]): Int
94+
@native def mxNDArraySyncCopyFromNDArray(handleDst: NDArrayHandle,
95+
handleSrc: NDArrayHandle,
96+
locator: Int): Int
7997
@native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle,
8098
data: Array[Byte],
8199
size: Int): Int
@@ -105,10 +123,15 @@ private[mxnet] class LibInfo {
105123
@native def mxNDArraySave(fname: String,
106124
handles: Array[NDArrayHandle],
107125
keys: Array[String]): Int
126+
@native def mxNDArrayGetDataNDArray(handle: NDArrayHandle, out: NDArrayHandleRef): Int
127+
@native def mxNDArrayGetAuxNDArray(handle: NDArrayHandle,
128+
location: Int,
129+
out: NDArrayHandleRef): Int
108130
@native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, devId: RefInt): Int
109131
@native def mxNDArraySaveRawBytes(handle: NDArrayHandle, buf: ArrayBuffer[Byte]): Int
110132
@native def mxNDArrayLoadFromRawBytes(bytes: Array[Byte], handle: NDArrayHandleRef): Int
111133
@native def mxNDArrayGetDType(handle: NDArrayHandle, dtype: RefInt): Int
134+
@native def mxNDArrayGetStorageType(handle: NDArrayHandle, stype: RefInt): Int
112135

113136
// KVStore Server
114137
@native def mxInitPSEnv(keys: Array[String], values: Array[String]): Int

scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder}
2121

2222
import org.apache.mxnet.Base._
2323
import org.apache.mxnet.DType.DType
24-
import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
24+
import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
25+
import org.apache.mxnet.SparseFormat.SparseFormat
2526
import org.slf4j.LoggerFactory
2627

2728
import scala.collection.mutable
@@ -113,10 +114,22 @@ object NDArray extends NDArrayBase {
113114
}
114115

115116
val outputs = ArrayBuffer.empty[NDArrayHandle]
116-
checkCall(_LIB.mxImperativeInvoke(function.handle, ndArgs.map(_.handle).toArray, outputVars,
117-
outputs, updatedKwargs.size, updatedKwargs.keys.toArray, updatedKwargs.values.toArray))
117+
val outStypes = ArrayBuffer.empty[Int]
118+
checkCall(_LIB.mxImperativeInvokeEx(function.handle,
119+
ndArgs.map(_.handle).toArray,
120+
outputVars,
121+
outputs,
122+
updatedKwargs.size,
123+
updatedKwargs.keys.toArray,
124+
updatedKwargs.values.toArray,
125+
outStypes))
118126
new NDArrayFuncReturn(Option(oriOutputs).getOrElse {
119-
val outputArrs = outputs.map(new NDArray(_)).toArray
127+
val outputArrs = (outputs zip outStypes).map(
128+
ele => ele._2 match {
129+
case 0 => new NDArray(ele._1)
130+
case _ => new SparseNDArray(ele._1)
131+
}
132+
).toArray
120133
addDependency(ndArgs.toArray, outputArrs)
121134
outputArrs
122135
})
@@ -943,6 +956,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
943956
DType(mxDtype.value)
944957
}
945958

959+
val sparseFormat: SparseFormat = {
960+
val mxSF = new RefInt
961+
checkCall(_LIB.mxNDArrayGetStorageType(handle, mxSF))
962+
SparseFormat(mxSF.value)
963+
}
964+
946965
/**
947966
* Return a copied numpy array of current array with specified type.
948967
* @param dtype Desired type of result array.
@@ -1309,6 +1328,30 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
13091328
if (this.context == context) this else this.copyTo(context)
13101329
}
13111330

1331+
/**
1332+
* Check whether this NDArray uses a sparse storage format, i.e. any format other than `default`.
1333+
* @return Boolean
1334+
*/
1335+
def isSparse: Boolean = {
1336+
this.sparseFormat.id != 0
1337+
}
1338+
1339+
/**
1340+
* Convert an NDArray to a SparseNDArray
1341+
*
1342+
* @param sfOption the target sparse format; row_sparse is used when None
1343+
* @return SparseNDArray
1344+
*/
1345+
def toSparse(sfOption : Option[SparseFormat] = None): SparseNDArray = {
1346+
val sf = sfOption.getOrElse(SparseFormat.ROW_SPARSE)
1347+
if (sf.id == 0) throw new IllegalArgumentException("Require Sparse")
1348+
if (isSparse && sfOption.isEmpty) {
1349+
this.asInstanceOf[SparseNDArray]
1350+
} else {
1351+
NDArray.api.cast_storage(this, sf.toString).head.asInstanceOf[SparseNDArray]
1352+
}
1353+
}
1354+
13121355
override def equals(o: Any): Boolean = o match {
13131356
case that: NDArray =>
13141357
that != null && that.shape == this.shape && that.toArray.sameElements(this.toArray)
@@ -1479,6 +1522,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
14791522
case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble)
14801523
case DType.Float64 => units.map(wrapBytes(_).getDouble)
14811524
case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble)
1525+
case DType.Int64 => units.map(wrapBytes(_).getLong.toDouble)
14821526
case DType.UInt8 => internal.map(_.toDouble)
14831527
}
14841528
}
@@ -1488,6 +1532,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
14881532
case DType.Float32 => units.map(wrapBytes(_).getFloat)
14891533
case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat)
14901534
case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat)
1535+
case DType.Int64 => units.map(wrapBytes(_).getLong.toFloat)
14911536
case DType.UInt8 => internal.map(_.toFloat)
14921537
}
14931538
}
@@ -1497,15 +1542,27 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
14971542
case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt)
14981543
case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt)
14991544
case DType.Int32 => units.map(wrapBytes(_).getInt)
1545+
case DType.Int64 => units.map(wrapBytes(_).getLong.toInt)
15001546
case DType.UInt8 => internal.map(_.toInt)
15011547
}
15021548
}
1549+
def toLongArray: Array[Long] = {
1550+
require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
1551+
dtype match {
1552+
case DType.Float32 => units.map(wrapBytes(_).getFloat.toLong)
1553+
case DType.Float64 => units.map(wrapBytes(_).getDouble.toLong)
1554+
case DType.Int32 => units.map(wrapBytes(_).getInt.toLong)
1555+
case DType.Int64 => units.map(wrapBytes(_).getLong)
1556+
case DType.UInt8 => internal.map(_.toLong)
1557+
}
1558+
}
15031559
def toByteArray: Array[Byte] = {
15041560
require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
15051561
dtype match {
15061562
case DType.Float16 | DType.Float32 => units.map(wrapBytes(_).getFloat.toByte)
15071563
case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte)
15081564
case DType.Int32 => units.map(wrapBytes(_).getInt.toByte)
1565+
case DType.Int64 => units.map(wrapBytes(_).getLong.toByte)
15091566
case DType.UInt8 => internal.clone()
15101567
}
15111568
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.mxnet
19+
20+
object SparseFormat extends Enumeration {
21+
type SparseFormat = Value
22+
val DEFAULT = Value(0, "default")
23+
val ROW_SPARSE = Value(1, "row_sparse")
24+
val CSR = Value(2, "csr")
25+
}

0 commit comments

Comments
 (0)