Skip to content

Commit

Permalink
set default namespace in CatalogParser
Browse files Browse the repository at this point in the history
  • Loading branch information
edgao committed May 17, 2024
1 parent d95a944 commit 96f4b23
Show file tree
Hide file tree
Showing 16 changed files with 123 additions and 185 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.integrations.destination.async

import com.google.common.base.Preconditions
import com.google.common.base.Strings
import io.airbyte.cdk.integrations.base.SerializedAirbyteMessageConsumer
import io.airbyte.cdk.integrations.destination.StreamSyncSummary
import io.airbyte.cdk.integrations.destination.async.buffers.BufferEnqueue
Expand All @@ -28,8 +27,6 @@ import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong
import java.util.function.Consumer
import java.util.stream.Collectors
import kotlin.jvm.optionals.getOrNull
import org.jetbrains.annotations.VisibleForTesting

private val logger = KotlinLogging.logger {}
Expand All @@ -52,7 +49,7 @@ constructor(
onFlush: DestinationFlushFunction,
private val catalog: ConfiguredAirbyteCatalog,
private val bufferManager: BufferManager,
private val defaultNamespace: Optional<String>,
private val defaultNamespace: String,
private val flushFailure: FlushFailure = FlushFailure(),
workerPool: ExecutorService = Executors.newFixedThreadPool(5),
private val airbyteMessageDeserializer: AirbyteMessageDeserializer =
Expand Down Expand Up @@ -80,28 +77,6 @@ constructor(
private var hasClosed = false
private var hasFailed = false

internal constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
flushFailure: FlushFailure,
defaultNamespace: Optional<String>,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
defaultNamespace,
flushFailure,
Executors.newFixedThreadPool(5),
AirbyteMessageDeserializer(),
)

@Throws(Exception::class)
override fun start() {
Preconditions.checkState(!hasStarted, "Consumer has already been started.")
Expand Down Expand Up @@ -130,9 +105,6 @@ constructor(
message,
)
if (AirbyteMessage.Type.RECORD == partialAirbyteMessage.type) {
if (Strings.isNullOrEmpty(partialAirbyteMessage.record?.namespace)) {
partialAirbyteMessage.record?.namespace = defaultNamespace.getOrNull()
}
validateRecord(partialAirbyteMessage)

partialAirbyteMessage.record?.streamDescriptor?.let {
Expand All @@ -142,7 +114,6 @@ constructor(
bufferEnqueue.addRecord(
partialAirbyteMessage,
sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES,
defaultNamespace,
)
}

Expand All @@ -160,18 +131,15 @@ constructor(
bufferManager.close()

val streamSyncSummaries =
streamNames
.stream()
.collect(
Collectors.toMap(
{ streamDescriptor: StreamDescriptor -> streamDescriptor },
{ streamDescriptor: StreamDescriptor ->
StreamSyncSummary(
Optional.of(getRecordCounter(streamDescriptor).get()),
)
},
),
)
streamNames.associate { streamDescriptor ->
StreamDescriptorUtils.withDefaultNamespace(
streamDescriptor,
defaultNamespace,
) to
StreamSyncSummary(
Optional.of(getRecordCounter(streamDescriptor).get()),
)
}
onClose.accept(hasFailed, streamSyncSummaries)

// as this throws an exception, we need to be after all other close functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,11 @@ object StreamDescriptorUtils {

return pairs
}

fun withDefaultNamespace(sd: StreamDescriptor, defaultNamespace: String) =
if (sd.namespace.isNullOrEmpty()) {
StreamDescriptor().withName(sd.name).withNamespace(defaultNamespace)
} else {
sd
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ package io.airbyte.cdk.integrations.destination.async.buffers
import io.airbyte.cdk.integrations.destination.async.GlobalMemoryManager
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.cdk.integrations.destination.async.state.GlobalAsyncStateManager
import io.airbyte.commons.json.Jsons
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.StreamDescriptor
import java.util.Optional
import java.util.concurrent.ConcurrentMap

/**
Expand All @@ -20,6 +20,7 @@ class BufferEnqueue(
private val memoryManager: GlobalMemoryManager,
private val buffers: ConcurrentMap<StreamDescriptor, StreamAwareQueue>,
private val stateManager: GlobalAsyncStateManager,
private val defaultNamespace: String,
) {
/**
* Buffer a record. Contains memory management logic to dynamically adjust queue size based via
Expand All @@ -31,12 +32,11 @@ class BufferEnqueue(
fun addRecord(
message: PartialAirbyteMessage,
sizeInBytes: Int,
defaultNamespace: Optional<String>,
) {
if (message.type == AirbyteMessage.Type.RECORD) {
handleRecord(message, sizeInBytes)
} else if (message.type == AirbyteMessage.Type.STATE) {
stateManager.trackState(message, sizeInBytes.toLong(), defaultNamespace.orElse(""))
stateManager.trackState(message, sizeInBytes.toLong())
}
}

Expand All @@ -53,15 +53,28 @@ class BufferEnqueue(
}
val stateId = stateManager.getStateIdAndIncrementCounter(streamDescriptor)

var addedToQueue = queue.offer(message, sizeInBytes.toLong(), stateId)
// We don't set the default namespace until after putting this message into the state
// manager/etc.
// All our internal handling is on the true (null) namespace,
// we just set the default namespace when handing off to destination-specific code.
val mangledMessage =
if (message.record!!.namespace.isNullOrEmpty()) {
val clone = Jsons.clone(message)
clone.record!!.namespace = defaultNamespace
clone
} else {
message
}

var addedToQueue = queue.offer(mangledMessage, sizeInBytes.toLong(), stateId)

var i = 0
while (!addedToQueue) {
val newlyAllocatedMemory = memoryManager.requestMemory()
if (newlyAllocatedMemory > 0) {
queue.addMaxMemory(newlyAllocatedMemory)
}
addedToQueue = queue.offer(message, sizeInBytes.toLong(), stateId)
addedToQueue = queue.offer(mangledMessage, sizeInBytes.toLong(), stateId)
i++
if (i > 5) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ private val logger = KotlinLogging.logger {}
class BufferManager
@JvmOverloads
constructor(
defaultNamespace: String,
maxMemory: Long = (Runtime.getRuntime().maxMemory() * MEMORY_LIMIT_RATIO).toLong(),
) {
@get:VisibleForTesting val buffers: ConcurrentMap<StreamDescriptor, StreamAwareQueue>
Expand All @@ -46,7 +47,7 @@ constructor(
memoryManager = GlobalMemoryManager(maxMemory)
this.stateManager = GlobalAsyncStateManager(memoryManager)
buffers = ConcurrentHashMap()
bufferEnqueue = BufferEnqueue(memoryManager, buffers, stateManager)
bufferEnqueue = BufferEnqueue(memoryManager, buffers, stateManager, defaultNamespace)
bufferDequeue = BufferDequeue(memoryManager, buffers, stateManager)
debugLoop = Executors.newSingleThreadScheduledExecutor()
debugLoop.scheduleAtFixedRate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.integrations.destination.async.state

import com.google.common.base.Preconditions
import com.google.common.base.Strings
import io.airbyte.cdk.integrations.destination.async.GlobalMemoryManager
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.commons.json.Jsons
Expand Down Expand Up @@ -104,7 +103,6 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
fun trackState(
message: PartialAirbyteMessage,
sizeInBytes: Long,
defaultNamespace: String,
) {
if (preState) {
convertToGlobalIfNeeded(message)
Expand All @@ -113,7 +111,7 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
// stateType should not change after a conversion.
Preconditions.checkArgument(stateType == extractStateType(message))

closeState(message, sizeInBytes, defaultNamespace)
closeState(message, sizeInBytes)
}

/**
Expand Down Expand Up @@ -333,10 +331,9 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
private fun closeState(
message: PartialAirbyteMessage,
sizeInBytes: Long,
defaultNamespace: String,
) {
val resolvedDescriptor: StreamDescriptor =
extractStream(message, defaultNamespace)
extractStream(message)
.orElse(
SENTINEL_GLOBAL_DESC,
)
Expand Down Expand Up @@ -434,38 +431,14 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
UUID.randomUUID().toString(),
)

/**
* If the user has selected the Destination Namespace as the Destination default while
* setting up the connector, the platform sets the namespace as null in the StreamDescriptor
* in the AirbyteMessages (both record and state messages). The destination checks that if
* the namespace is empty or null, if yes then re-populates it with the defaultNamespace.
* See [io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer.accept] But
* destination only does this for the record messages. So when state messages arrive without
* a namespace and since the destination doesn't repopulate it with the default namespace,
* there is a mismatch between the StreamDescriptor from record messages and state messages.
* That breaks the logic of the state management class as [descToStateIdQ] needs to have
* consistent StreamDescriptor. This is why while trying to extract the StreamDescriptor
* from state messages, we check if the namespace is null, if yes then replace it with
* defaultNamespace to keep it consistent with the record messages.
*/
private fun extractStream(
message: PartialAirbyteMessage,
defaultNamespace: String,
): Optional<StreamDescriptor> {
if (
message.state?.type != null &&
message.state?.type == AirbyteStateMessage.AirbyteStateType.STREAM
) {
val streamDescriptor: StreamDescriptor? = message.state?.stream?.streamDescriptor
if (Strings.isNullOrEmpty(streamDescriptor?.namespace)) {
return Optional.of(
StreamDescriptor()
.withName(
streamDescriptor?.name,
)
.withNamespace(defaultNamespace),
)
}
return streamDescriptor?.let { Optional.of(it) } ?: Optional.empty()
}
return Optional.empty()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import io.airbyte.protocol.models.v0.StreamDescriptor
import java.io.IOException
import java.math.BigDecimal
import java.time.Instant
import java.util.Optional
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
Expand Down Expand Up @@ -61,7 +60,7 @@ class AsyncStreamConsumerTest {
private val CATALOG: ConfiguredAirbyteCatalog =
ConfiguredAirbyteCatalog()
.withStreams(
java.util.List.of(
listOf(
CatalogHelpers.createConfiguredAirbyteStream(
STREAM_NAME,
SCHEMA_NAME,
Expand Down Expand Up @@ -146,9 +145,9 @@ class AsyncStreamConsumerTest {
onClose = onClose,
onFlush = flushFunction,
catalog = CATALOG,
bufferManager = BufferManager(),
bufferManager = BufferManager("default_ns"),
flushFailure = flushFailure,
defaultNamespace = Optional.of("default_ns"),
defaultNamespace = "default_ns",
airbyteMessageDeserializer = airbyteMessageDeserializer,
workerPool = Executors.newFixedThreadPool(5),
)
Expand Down Expand Up @@ -265,9 +264,9 @@ class AsyncStreamConsumerTest {
Mockito.mock(OnCloseFunction::class.java),
flushFunction,
CATALOG,
BufferManager((1024 * 10).toLong()),
BufferManager("default_ns", (1024 * 10).toLong()),
"default_ns",
flushFailure,
Optional.of("default_ns"),
)
Mockito.`when`(flushFunction.optimalBatchSizeBytes).thenReturn(0L)

Expand Down

0 comments on commit 96f4b23

Please sign in to comment.