diff --git a/Makefile b/Makefile index 64572ad..8283fbe 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,10 @@ test-lmdb: @echo "⚡ Running LMDB tests..." $(PYTHON) pytest tests/ -m "lmdb" -v --log-cli-level=ERROR +test-kafka: + @echo "Running Kafka tests..." + $(PYTHON) pytest tests/ -m "kafka" -v --log-cli-level=ERROR + # Parallel streaming integration tests test-parallel-streaming: @echo "⚡ Running parallel streaming integration tests..." @@ -132,6 +136,7 @@ help: @echo " make test-postgresql - Run PostgreSQL tests" @echo " make test-redis - Run Redis tests" @echo " make test-snowflake - Run Snowflake tests" + @echo " make test-kafka - Run Kafka tests" @echo " make test-performance - Run performance tests" @echo " make lint - Lint code with ruff" @echo " make format - Format code with ruff" diff --git a/apps/test_kafka_query.py b/apps/test_kafka_query.py index 8bc7422..d05e9f7 100644 --- a/apps/test_kafka_query.py +++ b/apps/test_kafka_query.py @@ -13,7 +13,7 @@ from amp.loaders.types import LabelJoinConfig # Connect to Amp server -server_url = os.getenv('AMP_SERVER_URL', 'grpc://34.27.238.174:80') +server_url = os.getenv('AMP_SERVER_URL', 'grpc://127.0.0.1:1602') print(f'Connecting to {server_url}...') client = Client(server_url) print('✅ Connected!') diff --git a/src/amp/loaders/base.py b/src/amp/loaders/base.py index a2775eb..82eb75c 100644 --- a/src/amp/loaders/base.py +++ b/src/amp/loaders/base.py @@ -840,8 +840,21 @@ def _rewind_to_watermark(self, table_name: str, connection_name: Optional[str] = for range_obj in resume_pos.ranges: from_block = range_obj.end + 1 + # Check if there are actually uncommitted batches beyond the watermark + uncommitted = self.state_store.invalidate_from_block( + connection_name, table_name, range_obj.network, from_block + ) + + if not uncommitted: + self.logger.debug( + f'No uncommitted batches for {range_obj.network} beyond block {from_block}, ' + f'skipping crash recovery cleanup' + ) + continue + self.logger.info( - f'Crash recovery: Cleaning up {table_name} data for {range_obj.network} from block {from_block} onwards' + f'Crash recovery: Cleaning up {len(uncommitted)} uncommitted batches ' + f'for {range_obj.network} from block {from_block} onwards in {table_name}' ) invalidation_ranges = [ @@ -859,20 +872,13 @@ def _rewind_to_watermark(self, table_name: str, connection_name: Optional[str] = self.logger.info(f'Crash recovery completed for {range_obj.network} in {table_name}') except NotImplementedError: - invalidated = self.state_store.invalidate_from_block( - connection_name, table_name, range_obj.network, from_block + self.logger.warning( + f'Crash recovery: Cleared {len(uncommitted)} batches from state ' + f'for {range_obj.network} but cannot delete data from {table_name}. ' + f'{self.__class__.__name__} does not support data deletion. ' + f'Duplicates may occur on resume.' ) - if invalidated: - self.logger.warning( - f'Crash recovery: Cleared {len(invalidated)} batches from state ' - f'for {range_obj.network} but cannot delete data from {table_name}. ' - f'{self.__class__.__name__} does not support data deletion. ' - f'Duplicates may occur on resume.' - ) - else: - self.logger.debug(f'No uncommitted batches found for {range_obj.network}') - def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRange]) -> pa.RecordBatch: """ Add metadata columns for streaming data with compact batch identification. diff --git a/src/amp/loaders/implementations/kafka_loader.py b/src/amp/loaders/implementations/kafka_loader.py index b86f346..76c0ae0 100644 --- a/src/amp/loaders/implementations/kafka_loader.py +++ b/src/amp/loaders/implementations/kafka_loader.py @@ -25,7 +25,7 @@ class KafkaConfig: class KafkaLoader(DataLoader[KafkaConfig]): SUPPORTED_MODES = {LoadMode.APPEND} REQUIRES_SCHEMA_MATCH = False - SUPPORTS_TRANSACTIONS = True + SUPPORTS_TRANSACTIONS = False def __init__(self, config: Dict[str, Any], label_manager=None) -> None: self._extra_producer_config = { @@ -34,6 +34,14 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None: super().__init__(config, label_manager) self._producer = None + # Replace in-memory state store with LMDB if configured (before connect, consistent with other loaders) + if self.state_enabled and self.state_storage == 'lmdb': + self.state_store = LMDBStreamStateStore( + connection_name=self.config.client_id, + data_dir=self.state_data_dir, + ) + self.logger.info(f'Initialized LMDB state store at {self.state_store.data_dir}') + def _get_required_config_fields(self) -> list[str]: return ['bootstrap_servers'] @@ -57,13 +65,6 @@ def connect(self) -> None: self.logger.info(f'Connected to Kafka at {self.config.bootstrap_servers}') self.logger.info(f'Client ID: {self.config.client_id}') - if self.state_enabled and self.state_storage == 'lmdb': - self.state_store = LMDBStreamStateStore( - connection_name=self.config.client_id, - data_dir=self.state_data_dir, - ) - self.logger.info(f'Initialized LMDB state store at {self.state_store.data_dir}') - self._is_connected = True except Exception as e: @@ -84,9 +85,27 @@ def disconnect(self) -> None: self._is_connected = False self.logger.info('Disconnected from Kafka') + def health_check(self) -> Dict[str, Any]: + """Check Kafka broker connectivity.""" + base = { + 'healthy': False, + 'loader_type': 'kafka', + 'bootstrap_servers': self.config.bootstrap_servers, + 'client_id': self.config.client_id, + } + if not self._is_connected or not self._producer: + base['error'] = 'Not connected' + return base + try: + healthy = hasattr(self._producer, '_sender') and self._producer._sender.is_alive() + base['healthy'] = healthy + return base + except Exception as e: + base['error'] = str(e) + return base + def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None: self.logger.info(f'Kafka topic {table_name} will be auto-created on first message send') - pass def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int: if not self._producer: @@ -108,6 +127,7 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> self._producer.send(topic=table_name, key=key, value=row) + self._producer.flush() self._producer.commit_transaction() self.logger.debug(f'Committed transaction with {num_rows} messages to topic {table_name}') @@ -169,6 +189,7 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, f'{invalidation_range.network} blocks {invalidation_range.start}-{invalidation_range.end}' ) + self._producer.flush() self._producer.commit_transaction() self.logger.info(f'Committed {len(invalidation_ranges)} reorg events to {reorg_topic}') diff --git a/tests/conftest.py b/tests/conftest.py index a984cb4..c76b7f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -228,7 +228,9 @@ def kafka_container(): container.with_env('KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR', '1') container.start() - time.sleep(10) + # KafkaContainer.start() already waits for "[KafkaServer id=N] started" log. + # Brief additional wait for transaction coordinator to be fully ready. + time.sleep(5) yield container diff --git a/tests/integration/test_kafka_loader.py b/tests/integration/loaders/backends/test_kafka.py similarity index 96% rename from tests/integration/test_kafka_loader.py rename to tests/integration/loaders/backends/test_kafka.py index 092d13c..52420db 100644 --- a/tests/integration/test_kafka_loader.py +++ b/tests/integration/loaders/backends/test_kafka.py @@ -27,6 +27,18 @@ def test_loader_connection(self, kafka_test_config): assert loader._is_connected == False assert loader._producer is None + def test_health_check(self, kafka_test_config): + loader = KafkaLoader(kafka_test_config) + + health = loader.health_check() + assert health['healthy'] == False + assert 'error' in health + + with loader: + health = loader.health_check() + assert health['healthy'] == True + assert health['bootstrap_servers'] == kafka_test_config['bootstrap_servers'] + def test_context_manager(self, kafka_test_config): loader = KafkaLoader(kafka_test_config) diff --git a/tests/unit/test_crash_recovery.py b/tests/unit/test_crash_recovery.py index 76a4ae9..b84ded8 100644 --- a/tests/unit/test_crash_recovery.py +++ b/tests/unit/test_crash_recovery.py @@ -48,6 +48,7 @@ def test_rewind_calls_handle_reorg(self, mock_loader): """Should call _handle_reorg with correct invalidation ranges""" watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader.state_store.invalidate_from_block = Mock(return_value=['batch1']) mock_loader._handle_reorg = Mock() mock_loader._rewind_to_watermark('test_table', 'test_conn') @@ -61,12 +62,23 @@ def test_rewind_calls_handle_reorg(self, mock_loader): assert call_args[0][1] == 'test_table' assert call_args[0][2] == 'test_conn' + def test_rewind_skips_reorg_when_no_uncommitted_batches(self, mock_loader): + """Should skip _handle_reorg when there are no uncommitted batches (e.g., clean restart)""" + watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) + mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader.state_store.invalidate_from_block = Mock(return_value=[]) + mock_loader._handle_reorg = Mock() + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader._handle_reorg.assert_not_called() + def test_rewind_handles_not_implemented(self, mock_loader): """Should gracefully handle loaders without _handle_reorg""" watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) mock_loader.state_store.get_resume_position = Mock(return_value=watermark) mock_loader._handle_reorg = Mock(side_effect=NotImplementedError()) - mock_loader.state_store.invalidate_from_block = Mock(return_value=[]) + mock_loader.state_store.invalidate_from_block = Mock(return_value=['batch1']) mock_loader._rewind_to_watermark('test_table', 'test_conn') @@ -83,6 +95,7 @@ def test_rewind_with_multiple_networks(self, mock_loader): ] ) mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader.state_store.invalidate_from_block = Mock(return_value=['batch1']) mock_loader._handle_reorg = Mock() mock_loader._rewind_to_watermark('test_table', 'test_conn') @@ -101,6 +114,7 @@ def test_rewind_uses_default_connection_name(self, mock_loader): """Should use default connection name from loader class""" watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader.state_store.invalidate_from_block = Mock(return_value=['batch1']) mock_loader._handle_reorg = Mock() mock_loader._rewind_to_watermark('test_table', connection_name=None)