Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ed00cc1

Browse files
committed
add getData method
1 parent 40cf36e commit ed00cc1

File tree

5 files changed

+37
-8
lines changed

5 files changed

+37
-8
lines changed

scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ private[mxnet] class LibInfo {
123123
@native def mxNDArraySave(fname: String,
124124
handles: Array[NDArrayHandle],
125125
keys: Array[String]): Int
126+
@native def mxNDArrayGetDataNDArray(handle: NDArrayHandle, out: NDArrayHandleRef): Int
126127
@native def mxNDArrayGetAuxNDArray(handle: NDArrayHandle,
127128
location: Int,
128129
out: NDArrayHandleRef): Int

scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,19 @@ class SparseNDArray private[mxnet] (override private[mxnet] val handle: NDArrayH
151151
dense.at(idx)
152152
}
153153

154+
/**
155+
* Get the Data portion from the row Sparse NDArray
156+
* @return NDArray
157+
*/
158+
def getData: NDArray = {
159+
val handle = new NDArrayHandleRef
160+
if (this.sparseFormat == SparseFormat.CSR) {
161+
throw new UnsupportedOperationException("Not Supported for CSR")
162+
}
163+
_LIB.mxNDArrayGetDataNDArray(this.handle, handle)
164+
new NDArray(handle.value, false)
165+
}
166+
154167
/**
155168
* Get the indptr Array
156169
* @return NDArray

scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class SparseNDArraySuite extends FunSuite {
6565
printf(rspIn.toString)
6666
val toRetain = Array(0f, 3f)
6767
val rspOut = SparseNDArray.retain(rspIn, toRetain)
68-
assert(rspOut.at(0).toArray sameElements Array(1f, 2f))
68+
assert(rspOut.getData.toArray sameElements Array(1f, 2f, 5f, 6f))
6969
assert(rspOut.getIndices.toArray sameElements Array(0f, 3f))
7070
}
7171

scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFro
474474
return ret;
475475
}
476476

477+
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray
478+
(JNIEnv *env, jobject obj, jlong arrayPtr, jobject ndArrayHandle) {
479+
NDArrayHandle out;
480+
int ret = MXNDArrayGetDataNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
481+
&out);
482+
SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
483+
return ret;
484+
}
485+
477486
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray
478487
(JNIEnv *env, jobject obj, jlong arrayPtr, jint location, jobject ndArrayHandle) {
479488
NDArrayHandle out;

scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h

Lines changed: 13 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)