diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt index 499bd12c0b9..01ee263fd2c 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt @@ -86,6 +86,17 @@ internal class StreamingSnapshotTests { } } + @Test + fun `unknown enum in finish reason`() = + goldenStreamingFile("failure-unknown-finish-enum.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val exception = shouldThrow { responses.collect() } + exception.response.candidates.first().finishReason shouldBe FinishReason.UNKNOWN + } + } + @Test fun `quotes escaped`() = goldenStreamingFile("success-quotes-escaped.txt") { @@ -184,4 +195,20 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { shouldThrow { responses.collect() } } } + + @Test + fun `invalid json`() = + goldenStreamingFile("failure-invalid-json.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { shouldThrow { responses.collect() } } + } + + @Test + fun `malformed content`() = + goldenStreamingFile("failure-malformed-content.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { shouldThrow { responses.collect() } } + } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt index f55f71ee34e..41127bd0530 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt @@ -38,6 +38,7 @@ import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.should import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.string.shouldContain import io.kotest.matchers.string.shouldNotBeEmpty import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.http.HttpStatusCode @@ -84,6 +85,56 @@ internal class UnarySnapshotTests { response.candidates.isEmpty() shouldBe false val candidate = response.candidates.first() candidate.safetyRatings.any { it.category == HarmCategory.UNKNOWN } shouldBe true + response.promptFeedback?.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } shouldBe + true + } + } + + @Test + fun `unknown enum in finish reason`() = + goldenUnaryFile("failure-unknown-enum-finish-reason.json") { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } should + { + it.response.candidates.first().finishReason shouldBe FinishReason.UNKNOWN + } + } + } + + @Test + fun `unknown enum in block reason`() = + goldenUnaryFile("failure-unknown-enum-prompt-blocked.json") { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } should + { + it.response.promptFeedback?.blockReason shouldBe BlockReason.UNKNOWN + } + } + } + + @Test + fun `quotes escaped`() = + goldenUnaryFile("success-quote-reply.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().content.parts.isEmpty() shouldBe false + val part = response.candidates.first().content.parts.first() as TextPart + part.text shouldContain "\"" + } + } + + @Test + fun `safetyRatings missing`() = + goldenUnaryFile("success-missing-safety-ratings.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().content.parts.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe true + response.promptFeedback?.safetyRatings?.isEmpty() shouldBe true } } @@ -147,6 +198,15 @@ internal class UnarySnapshotTests { } } + @Test + fun `stopped for safety with no content`() = + goldenUnaryFile("failure-finish-reason-safety-no-content.json") { + withTimeout(testTimeout) { + val exception = shouldThrow { model.generateContent("prompt") } + exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY + } + } + @Test fun `citation returns correctly`() = goldenUnaryFile("success-citations.json") { @@ -292,4 +352,79 @@ internal class UnarySnapshotTests { callPart.args["current"] shouldBe "true" } } + + @Test + fun `function call contains no arguments`() = + goldenUnaryFile("success-function-call-no-arguments.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val callPart = response.functionCalls.shouldNotBeEmpty().first() + + callPart.name shouldBe "current_time" + callPart.args.isEmpty() shouldBe true + } + } + + @Test + fun `function call contains arguments`() = + goldenUnaryFile("success-function-call-with-arguments.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val callPart = response.functionCalls.shouldNotBeEmpty().first() + + callPart.name shouldBe "sum" + callPart.args["x"] shouldBe "4" + callPart.args["y"] shouldBe "5" + } + } + + @Test + fun `function call with parallel calls`() = + goldenUnaryFile("success-function-call-parallel-calls.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val callList = response.functionCalls + + callList.size shouldBe 3 + callList.forEach { + it.name shouldBe "sum" + it.args.size shouldBe 2 + } + } + } + + @Test + fun `function call with mixed content`() = + goldenUnaryFile("success-function-call-mixed-content.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val callList = response.functionCalls + + response.text shouldBe "The sum of [1, 2, 3] is" + callList.size shouldBe 2 + callList.forEach { it.args.size shouldBe 2 } + } + } + + @Test + fun `countTokens succeeds`() = + goldenUnaryFile("success-total-tokens.json") { + withTimeout(testTimeout) { + val response = model.countTokens("prompt") + + response.totalTokens shouldBe 6 + response.totalBillableCharacters shouldBe 16 + } + } + + @Test + fun `countTokens succeeds with no billable characters`() = + goldenUnaryFile("success-no-billable-characters.json") { + withTimeout(testTimeout) { + val response = model.countTokens("prompt") + + response.totalTokens shouldBe 258 + response.totalBillableCharacters shouldBe 0 + } + } }