@@ -21,7 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder}
21
21
22
22
import org .apache .mxnet .Base ._
23
23
import org .apache .mxnet .DType .DType
24
- import org .apache .mxnet .MX_PRIMITIVES .{MX_PRIMITIVE_TYPE }
24
+ import org .apache .mxnet .MX_PRIMITIVES .MX_PRIMITIVE_TYPE
25
+ import org .apache .mxnet .SparseFormat .SparseFormat
25
26
import org .slf4j .LoggerFactory
26
27
27
28
import scala .collection .mutable
@@ -113,10 +114,22 @@ object NDArray extends NDArrayBase {
113
114
}
114
115
115
116
val outputs = ArrayBuffer .empty[NDArrayHandle ]
116
- checkCall(_LIB.mxImperativeInvoke(function.handle, ndArgs.map(_.handle).toArray, outputVars,
117
- outputs, updatedKwargs.size, updatedKwargs.keys.toArray, updatedKwargs.values.toArray))
117
+ val outStypes = ArrayBuffer .empty[Int ]
118
+ checkCall(_LIB.mxImperativeInvokeEx(function.handle,
119
+ ndArgs.map(_.handle).toArray,
120
+ outputVars,
121
+ outputs,
122
+ updatedKwargs.size,
123
+ updatedKwargs.keys.toArray,
124
+ updatedKwargs.values.toArray,
125
+ outStypes))
118
126
new NDArrayFuncReturn (Option (oriOutputs).getOrElse {
119
- val outputArrs = outputs.map(new NDArray (_)).toArray
127
+ val outputArrs = (outputs zip outStypes).map(
128
+ ele => ele._2 match {
129
+ case 0 => new NDArray (ele._1)
130
+ case _ => new SparseNDArray (ele._1)
131
+ }
132
+ ).toArray
120
133
addDependency(ndArgs.toArray, outputArrs)
121
134
outputArrs
122
135
})
@@ -943,6 +956,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
943
956
DType (mxDtype.value)
944
957
}
945
958
959
+ val sparseFormat : SparseFormat = {
960
+ val mxSF = new RefInt
961
+ checkCall(_LIB.mxNDArrayGetStorageType(handle, mxSF))
962
+ SparseFormat (mxSF.value)
963
+ }
964
+
946
965
/**
947
966
* Return a copied numpy array of current array with specified type.
948
967
* @param dtype Desired type of result array.
@@ -1309,6 +1328,30 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
1309
1328
if (this .context == context) this else this .copyTo(context)
1310
1329
}
1311
1330
1331
+ /**
1332
+ * check if NDArray is SparseNDArray
1333
+ * @return Boolean
1334
+ */
1335
+ def isSparse : Boolean = {
1336
+ this .sparseFormat.id != 0
1337
+ }
1338
+
1339
+ /**
1340
+ * Convert a NDArray to SparseNDArray
1341
+ *
1342
+ * @param sfOption the target sparse type
1343
+ * @return SparseNDArray
1344
+ */
1345
+ def toSparse (sfOption : Option [SparseFormat ] = None ): SparseNDArray = {
1346
+ val sf = sfOption.getOrElse(SparseFormat .ROW_SPARSE )
1347
+ if (sf.id == 0 ) throw new IllegalArgumentException (" Require Sparse" )
1348
+ if (isSparse && sfOption.isEmpty) {
1349
+ this .asInstanceOf [SparseNDArray ]
1350
+ } else {
1351
+ NDArray .api.cast_storage(this , sf.toString).head.asInstanceOf [SparseNDArray ]
1352
+ }
1353
+ }
1354
+
1312
1355
override def equals (o : Any ): Boolean = o match {
1313
1356
case that : NDArray =>
1314
1357
that != null && that.shape == this .shape && that.toArray.sameElements(this .toArray)
@@ -1479,6 +1522,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
1479
1522
case DType .Float32 => units.map(wrapBytes(_).getFloat.toDouble)
1480
1523
case DType .Float64 => units.map(wrapBytes(_).getDouble)
1481
1524
case DType .Int32 => units.map(wrapBytes(_).getInt.toDouble)
1525
+ case DType .Int64 => units.map(wrapBytes(_).getLong.toDouble)
1482
1526
case DType .UInt8 => internal.map(_.toDouble)
1483
1527
}
1484
1528
}
@@ -1488,6 +1532,7 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
1488
1532
case DType .Float32 => units.map(wrapBytes(_).getFloat)
1489
1533
case DType .Float64 => units.map(wrapBytes(_).getDouble.toFloat)
1490
1534
case DType .Int32 => units.map(wrapBytes(_).getInt.toFloat)
1535
+ case DType .Int64 => units.map(wrapBytes(_).getLong.toFloat)
1491
1536
case DType .UInt8 => internal.map(_.toFloat)
1492
1537
}
1493
1538
}
@@ -1497,15 +1542,27 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
1497
1542
case DType .Float32 => units.map(wrapBytes(_).getFloat.toInt)
1498
1543
case DType .Float64 => units.map(wrapBytes(_).getDouble.toInt)
1499
1544
case DType .Int32 => units.map(wrapBytes(_).getInt)
1545
+ case DType .Int64 => units.map(wrapBytes(_).getLong.toInt)
1500
1546
case DType .UInt8 => internal.map(_.toInt)
1501
1547
}
1502
1548
}
1549
+ def toLongArray : Array [Long ] = {
1550
+ require(dtype != DType .Float16 , " Currently cannot convert float16 to native numerical types" )
1551
+ dtype match {
1552
+ case DType .Float32 => units.map(wrapBytes(_).getFloat.toLong)
1553
+ case DType .Float64 => units.map(wrapBytes(_).getDouble.toLong)
1554
+ case DType .Int32 => units.map(wrapBytes(_).getInt.toLong)
1555
+ case DType .Int64 => units.map(wrapBytes(_).getLong)
1556
+ case DType .UInt8 => internal.map(_.toLong)
1557
+ }
1558
+ }
1503
1559
def toByteArray : Array [Byte ] = {
1504
1560
require(dtype != DType .Float16 , " Currently cannot convert float16 to native numerical types" )
1505
1561
dtype match {
1506
1562
case DType .Float16 | DType .Float32 => units.map(wrapBytes(_).getFloat.toByte)
1507
1563
case DType .Float64 => units.map(wrapBytes(_).getDouble.toByte)
1508
1564
case DType .Int32 => units.map(wrapBytes(_).getInt.toByte)
1565
+ case DType .Int64 => units.map(wrapBytes(_).getLong.toByte)
1509
1566
case DType .UInt8 => internal.clone()
1510
1567
}
1511
1568
}
0 commit comments