Skip to content

Add more VertexAI unit tests #6104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ internal class StreamingSnapshotTests {
}
}

@Test
fun `unknown enum in finish reason`() =
goldenStreamingFile("failure-unknown-finish-enum.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val exception = shouldThrow<ResponseStoppedException> { responses.collect() }
exception.response.candidates.first().finishReason shouldBe FinishReason.UNKNOWN
}
}

@Test
fun `quotes escaped`() =
goldenStreamingFile("success-quotes-escaped.txt") {
Expand Down Expand Up @@ -184,4 +195,20 @@ internal class StreamingSnapshotTests {

withTimeout(testTimeout) { shouldThrow<InvalidAPIKeyException> { responses.collect() } }
}

@Test
fun `invalid json`() =
goldenStreamingFile("failure-invalid-json.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
}

@Test
fun `malformed content`() =
goldenStreamingFile("failure-malformed-content.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.should
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.string.shouldNotBeEmpty
import io.kotest.matchers.types.shouldBeInstanceOf
import io.ktor.http.HttpStatusCode
Expand Down Expand Up @@ -84,6 +85,56 @@ internal class UnarySnapshotTests {
response.candidates.isEmpty() shouldBe false
val candidate = response.candidates.first()
candidate.safetyRatings.any { it.category == HarmCategory.UNKNOWN } shouldBe true
response.promptFeedback?.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } shouldBe
true
}
}

@Test
fun `unknown enum in finish reason`() =
goldenUnaryFile("failure-unknown-enum-finish-reason.json") {
withTimeout(testTimeout) {
shouldThrow<ResponseStoppedException> { model.generateContent("prompt") } should
{
it.response.candidates.first().finishReason shouldBe FinishReason.UNKNOWN
}
}
}

@Test
fun `unknown enum in block reason`() =
goldenUnaryFile("failure-unknown-enum-prompt-blocked.json") {
withTimeout(testTimeout) {
shouldThrow<PromptBlockedException> { model.generateContent("prompt") } should
{
it.response.promptFeedback?.blockReason shouldBe BlockReason.UNKNOWN
}
}
}

@Test
fun `quotes escaped`() =
goldenUnaryFile("success-quote-reply.json") {
withTimeout(testTimeout) {
val response = model.generateContent("prompt")

response.candidates.isEmpty() shouldBe false
response.candidates.first().content.parts.isEmpty() shouldBe false
val part = response.candidates.first().content.parts.first() as TextPart
part.text shouldContain "\""
}
}

@Test
fun `safetyRatings missing`() =
goldenUnaryFile("success-missing-safety-ratings.json") {
withTimeout(testTimeout) {
val response = model.generateContent("prompt")

response.candidates.isEmpty() shouldBe false
response.candidates.first().content.parts.isEmpty() shouldBe false
response.candidates.first().safetyRatings.isEmpty() shouldBe true
response.promptFeedback?.safetyRatings?.isEmpty() shouldBe true
}
}

Expand Down Expand Up @@ -147,6 +198,15 @@ internal class UnarySnapshotTests {
}
}

@Test
fun `stopped for safety with no content`() =
goldenUnaryFile("failure-finish-reason-safety-no-content.json") {
withTimeout(testTimeout) {
val exception = shouldThrow<ResponseStoppedException> { model.generateContent("prompt") }
exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY
}
}

@Test
fun `citation returns correctly`() =
goldenUnaryFile("success-citations.json") {
Expand Down Expand Up @@ -292,4 +352,79 @@ internal class UnarySnapshotTests {
callPart.args["current"] shouldBe "true"
}
}

@Test
fun `function call contains no arguments`() =
goldenUnaryFile("success-function-call-no-arguments.json") {
withTimeout(testTimeout) {
val response = model.generateContent("prompt")
val callPart = response.functionCalls.shouldNotBeEmpty().first()

callPart.name shouldBe "current_time"
callPart.args.isEmpty() shouldBe true
}
}

@Test
fun `function call contains arguments`() =
goldenUnaryFile("success-function-call-with-arguments.json") {
withTimeout(testTimeout) {
val response = model.generateContent("prompt")
val callPart = response.functionCalls.shouldNotBeEmpty().first()

callPart.name shouldBe "sum"
callPart.args["x"] shouldBe "4"
callPart.args["y"] shouldBe "5"
}
}

@Test
fun `function call with parallel calls`() =
goldenUnaryFile("success-function-call-parallel-calls.json") {
withTimeout(testTimeout) {
val response = model.generateContent("prompt")
val callList = response.functionCalls

callList.size shouldBe 3
callList.forEach {
it.name shouldBe "sum"
it.args.size shouldBe 2
}
}
}

@Test
fun `function call with mixed content`() =
goldenUnaryFile("success-function-call-mixed-content.json") {
withTimeout(testTimeout) {
val response = model.generateContent("prompt")
val callList = response.functionCalls

response.text shouldBe "The sum of [1, 2, 3] is"
callList.size shouldBe 2
callList.forEach { it.args.size shouldBe 2 }
}
}

@Test
fun `countTokens succeeds`() =
goldenUnaryFile("success-total-tokens.json") {
withTimeout(testTimeout) {
val response = model.countTokens("prompt")

response.totalTokens shouldBe 6
response.totalBillableCharacters shouldBe 16
}
}

@Test
fun `countTokens succeeds with no billable characters`() =
goldenUnaryFile("success-no-billable-characters.json") {
withTimeout(testTimeout) {
val response = model.countTokens("prompt")

response.totalTokens shouldBe 258
response.totalBillableCharacters shouldBe 0
}
}
}
Loading