Skip to content

Commit

Permalink
Optimize prepareStatementExecution request freq (#828)
Browse files Browse the repository at this point in the history
* Optimize prepareStatementExecution request freq

Signed-off-by: Louis Chu <[email protected]>

* Add UT

Signed-off-by: Louis Chu <[email protected]>

---------

Signed-off-by: Louis Chu <[email protected]>
  • Loading branch information
noCharger authored Oct 30, 2024
1 parent d2213c5 commit a2a9838
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
val queryResultWriter = instantiateQueryResultWriter(spark, commandContext)
var futurePrepareQueryExecution: Future[Either[String, Unit]] = null

val statementsExecutionManager =
instantiateStatementExecutionManager(commandContext)

var futurePrepareQueryExecution: Future[Either[String, Unit]] = Future {
statementsExecutionManager.prepareStatementExecution()
}
try {
logInfo(s"""Executing session with sessionId: ${sessionId}""")

Expand All @@ -324,12 +330,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
var lastCanPickCheckTime = 0L
while (currentTimeProvider
.currentEpochMillis() - lastActivityTime <= commandContext.inactivityLimitMillis && canPickUpNextStatement) {
val statementsExecutionManager =
instantiateStatementExecutionManager(commandContext)

futurePrepareQueryExecution = Future {
statementsExecutionManager.prepareStatementExecution()
}

try {
val commandState = CommandState(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement
import org.opensearch.flint.core.storage.OpenSearchUpdater
import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
import org.opensearch.search.sort.SortOrder

import org.apache.spark.internal.Logging
Expand All @@ -29,8 +29,8 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
context("flintSessionIndexUpdater").asInstanceOf[OpenSearchUpdater]

// Using one reader client within same session will cause concurrency issue.
// To resolve this move the reader creation and getNextStatement method to mirco-batch level
private val flintReader = createOpenSearchQueryReader()
// To resolve this move the reader creation to getNextStatement method at mirco-batch level
private var currentReader: Option[FlintReader] = None

override def prepareStatementExecution(): Either[String, Unit] = {
checkAndCreateIndex(osClient, resultIndex)
Expand All @@ -39,12 +39,17 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
flintSessionIndexUpdater.update(statement.statementId, FlintStatement.serialize(statement))
}
override def terminateStatementExecution(): Unit = {
flintReader.close()
currentReader.foreach(_.close())
currentReader = None
}

override def getNextStatement(): Option[FlintStatement] = {
if (flintReader.hasNext) {
val rawStatement = flintReader.next()
if (currentReader.isEmpty) {
currentReader = Some(createOpenSearchQueryReader())
}

if (currentReader.get.hasNext) {
val rawStatement = currentReader.get.next()
val flintStatement = FlintStatement.deserialize(rawStatement)
logInfo(s"Next statement to execute: $flintStatement")
Some(flintStatement)
Expand Down Expand Up @@ -100,7 +105,6 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
| ]
| }
|}""".stripMargin
val flintReader = osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC)
flintReader
osClient.createQueryReader(sessionIndex, dsl, "submitTime", SortOrder.ASC)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,8 @@ class FlintREPLTest

val expectedCalls =
Math.ceil(inactivityLimit.toDouble / DEFAULT_QUERY_LOOP_EXECUTION_FREQUENCY).toInt
verify(mockOSClient, Mockito.atMost(expectedCalls)).getIndexMetadata(*)
verify(mockOSClient, times(1)).getIndexMetadata(*)
verify(mockOSClient, Mockito.atMost(expectedCalls)).createQueryReader(*, *, *, *)
}

val testCases = Table(
Expand Down

0 comments on commit a2a9838

Please sign in to comment.