"""Test suite for startup strategies."""
from unittest.mock import Mock, patch, MagicMock
from django.test import TestCase
from management.models import KeyStorageConfig
from management.util.startup_strategies import (
DekCacheState,
StartupContext,
WizardState,
DatabaseNotInitializedStrategy,
DatabaseInitializedNoVersionStrategy,
VersionMatchStrategy,
VersionUpgradeStrategy,
RestoreSoftwareWizardCompletedStrategy,
RestoreSoftwareWizardIncompleteStrategy,
StartupStrategySelector,
RestoreSoftHsmWizardCompletedDekCachedStrategy,
RestoreSoftHsmNewKekWizardCompletedStrategy,
)
from packaging.version import Version
from setup_wizard import SetupWizardState
class WizardStateTest(TestCase):
    """Verify the string values that back each WizardState member."""

    def test_completed_value(self):
        """The COMPLETED member must carry the literal value 'COMPLETED'."""
        expected = 'COMPLETED'
        self.assertEqual(WizardState.COMPLETED.value, expected)

    def test_incomplete_value(self):
        """The INCOMPLETE member must carry the literal value 'INCOMPLETE'."""
        member = WizardState.INCOMPLETE
        self.assertEqual(member.value, 'INCOMPLETE')
class DekCacheStateTest(TestCase):
    """Verify the string values that back each DekCacheState member."""

    def test_cached_value(self):
        """The CACHED member must carry the literal value 'CACHED'."""
        expected = 'CACHED'
        self.assertEqual(DekCacheState.CACHED.value, expected)

    def test_not_cached_value(self):
        """The NOT_CACHED member must carry the literal value 'NOT_CACHED'."""
        member = DekCacheState.NOT_CACHED
        self.assertEqual(member.value, 'NOT_CACHED')
class StartupContextTest(TestCase):
    """Test suite for the StartupContext dataclass and its derived properties."""

    def setUp(self):
        """Set up the shared fixtures used to build contexts."""
        self.mock_output = Mock()
        self.current_version = Version('1.0.0')
        self.db_version = Version('1.0.0')

    def _make_context(self, **overrides):
        """Build a StartupContext from shared defaults plus per-test overrides.

        The defaults describe the following state:
        - an initialized database whose version equals the current version;
        - a completed wizard (``WizardState.COMPLETED``);
        - no raw wizard state, storage type or DEK cache state.

        ``has_kek`` and ``has_backup_encrypted_dek`` are only forwarded when a
        test passes them explicitly. Tests that omit them therefore exercise
        the dataclass defaults.
        """
        kwargs = {
            'db_initialized': True,
            'db_version': self.db_version,
            'current_version': self.current_version,
            'wizard_state_enum': WizardState.COMPLETED,
            'wizard_state_raw': None,
            'storage_type': None,
            'dek_cache_state': None,
            'output': self.mock_output,
        }
        kwargs.update(overrides)
        return StartupContext(**kwargs)

    def test_initialization(self):
        """Test StartupContext stores every given field and defaults the key flags."""
        context = self._make_context(
            wizard_state_raw=SetupWizardState.WIZARD_COMPLETED,
            storage_type=KeyStorageConfig.StorageType.SOFTWARE,
        )
        self.assertTrue(context.db_initialized)
        self.assertEqual(context.db_version, self.db_version)
        self.assertEqual(context.current_version, self.current_version)
        self.assertEqual(context.wizard_state_enum, WizardState.COMPLETED)
        self.assertEqual(context.wizard_state_raw, SetupWizardState.WIZARD_COMPLETED)
        self.assertEqual(context.storage_type, KeyStorageConfig.StorageType.SOFTWARE)
        self.assertIsNone(context.dek_cache_state)
        self.assertIs(context.output, self.mock_output)
        # has_kek / has_backup_encrypted_dek were not passed: these check the defaults.
        self.assertFalse(context.has_kek)
        self.assertFalse(context.has_backup_encrypted_dek)

    def test_is_wizard_completed_true(self):
        """Test is_wizard_completed property when completed."""
        context = self._make_context(wizard_state_enum=WizardState.COMPLETED)
        self.assertTrue(context.is_wizard_completed)

    def test_is_wizard_completed_false(self):
        """Test is_wizard_completed property when incomplete."""
        context = self._make_context(wizard_state_enum=WizardState.INCOMPLETE)
        self.assertFalse(context.is_wizard_completed)

    def test_is_software_storage(self):
        """Test is_software_storage property."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTWARE,
        )
        self.assertTrue(context.is_software_storage)
        self.assertFalse(context.is_softhsm_storage)
        self.assertFalse(context.is_physical_hsm_storage)
        self.assertFalse(context.is_hsm_storage)

    def test_is_softhsm_storage(self):
        """Test is_softhsm_storage property."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.CACHED,
        )
        self.assertFalse(context.is_software_storage)
        self.assertTrue(context.is_softhsm_storage)
        self.assertFalse(context.is_physical_hsm_storage)
        self.assertTrue(context.is_hsm_storage)

    def test_is_physical_hsm_storage(self):
        """Test is_physical_hsm_storage property."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.PHYSICAL_HSM,
            dek_cache_state=DekCacheState.CACHED,
        )
        self.assertFalse(context.is_software_storage)
        self.assertFalse(context.is_softhsm_storage)
        self.assertTrue(context.is_physical_hsm_storage)
        self.assertTrue(context.is_hsm_storage)

    def test_is_dek_cached_true(self):
        """Test is_dek_cached property when cached."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.CACHED,
        )
        self.assertTrue(context.is_dek_cached)

    def test_is_dek_cached_false(self):
        """Test is_dek_cached property when not cached."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.NOT_CACHED,
        )
        self.assertFalse(context.is_dek_cached)

    def test_is_dek_cached_raises_for_software_storage(self):
        """Test is_dek_cached raises ValueError for software storage."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTWARE,
        )
        with self.assertRaises(ValueError) as cm:
            _ = context.is_dek_cached
        # Must sit outside the `with`: code after the raising line never runs.
        self.assertIn('only applicable for HSM storage', str(cm.exception))

    def test_is_new_kek_scenario_true(self):
        """Test is_new_kek_scenario when conditions are met."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.NOT_CACHED,
            has_kek=False,
            has_backup_encrypted_dek=True,
        )
        self.assertTrue(context.is_new_kek_scenario)

    def test_is_new_kek_scenario_false_software_storage(self):
        """Test is_new_kek_scenario is False for software storage."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTWARE,
            has_kek=False,
            has_backup_encrypted_dek=True,
        )
        self.assertFalse(context.is_new_kek_scenario)

    def test_is_new_kek_scenario_false_dek_cached(self):
        """Test is_new_kek_scenario is False when DEK is cached."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.CACHED,
            has_kek=False,
            has_backup_encrypted_dek=True,
        )
        self.assertFalse(context.is_new_kek_scenario)

    def test_is_new_kek_scenario_false_has_kek(self):
        """Test is_new_kek_scenario is False when KEK exists."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.NOT_CACHED,
            has_kek=True,
            has_backup_encrypted_dek=True,
        )
        self.assertFalse(context.is_new_kek_scenario)

    def test_is_new_kek_scenario_false_no_backup(self):
        """Test is_new_kek_scenario is False without backup encrypted DEK."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.NOT_CACHED,
            has_kek=False,
            has_backup_encrypted_dek=False,
        )
        self.assertFalse(context.is_new_kek_scenario)
class DatabaseNotInitializedStrategyTest(TestCase):
    """Tests for the startup path taken when the database is not initialized."""

    def setUp(self):
        """Wire the strategy to a mock init strategy and build a fresh-DB context."""
        self.mock_output = Mock()
        self.mock_init_strategy = Mock()
        self.strategy = DatabaseNotInitializedStrategy(
            init_strategy=self.mock_init_strategy,
        )
        self.context = StartupContext(
            output=self.mock_output,
            db_initialized=False,
            db_version=None,
            current_version=Version('1.0.0'),
            wizard_state_enum=WizardState.INCOMPLETE,
            wizard_state_raw=None,
            storage_type=None,
            dek_cache_state=None,
        )

    def test_get_description(self):
        """The description should mention that the database is not initialized."""
        text = self.strategy.get_description().lower()
        self.assertIn('not initialized', text)

    def test_execute_calls_init_strategy(self):
        """execute() should delegate to the init strategy with TLS and emit output."""
        self.strategy.execute(self.context)
        initialize = self.mock_init_strategy.initialize
        initialize.assert_called_once_with(self.context, with_tls=True)
        self.mock_output.write.assert_called()
class DatabaseInitializedNoVersionStrategyTest(TestCase):
    """Tests for the startup path taken when the DB exists but has no stored version."""

    def setUp(self):
        """Wire the strategy to a mock init strategy and build a versionless context."""
        self.mock_output = Mock()
        self.mock_init_strategy = Mock()
        self.strategy = DatabaseInitializedNoVersionStrategy(
            init_strategy=self.mock_init_strategy,
        )
        self.context = StartupContext(
            output=self.mock_output,
            db_initialized=True,
            db_version=None,
            current_version=Version('1.0.0'),
            wizard_state_enum=WizardState.INCOMPLETE,
            wizard_state_raw=None,
            storage_type=None,
            dek_cache_state=None,
        )

    def test_get_description(self):
        """The description should mention the missing version."""
        text = self.strategy.get_description().lower()
        self.assertIn('no version', text)

    def test_execute_calls_init_strategy(self):
        """execute() should delegate to the init strategy with TLS and emit output."""
        self.strategy.execute(self.context)
        initialize = self.mock_init_strategy.initialize
        initialize.assert_called_once_with(self.context, with_tls=True)
        self.mock_output.write.assert_called()
class VersionMatchStrategyTest(TestCase):
    """Tests for the startup path taken when stored and running versions match."""

    def setUp(self):
        """Wire the strategy to mock init/restore strategies and a 1.0.0 context."""
        self.mock_output = Mock()
        self.mock_restore_strategy = Mock()
        self.mock_init_strategy = Mock()
        self.strategy = VersionMatchStrategy(
            init_strategy=self.mock_init_strategy,
            restore_strategy=self.mock_restore_strategy,
        )
        same_version = '1.0.0'
        self.context = StartupContext(
            output=self.mock_output,
            db_initialized=True,
            db_version=Version(same_version),
            current_version=Version(same_version),
            wizard_state_enum=WizardState.COMPLETED,
            wizard_state_raw=SetupWizardState.WIZARD_COMPLETED,
            storage_type=KeyStorageConfig.StorageType.SOFTWARE,
            dek_cache_state=None,
        )

    def test_get_description(self):
        """The description should mention that the versions match."""
        text = self.strategy.get_description().lower()
        self.assertIn('match', text)

    def test_execute_initializes_and_restores(self):
        """execute() should initialize without TLS, run the restore strategy, and emit output."""
        ctx = self.context
        self.strategy.execute(ctx)
        self.mock_init_strategy.initialize.assert_called_once_with(ctx, with_tls=False)
        self.mock_restore_strategy.execute.assert_called_once_with(ctx)
        self.mock_output.write.assert_called()
class VersionUpgradeStrategyTest(TestCase):
    """Tests for the startup path that upgrades an older stored version."""

    def setUp(self):
        """Wire the strategy to mocks and build a 0.9.0 -> 1.0.0 upgrade context."""
        self.mock_output = Mock()
        self.mock_restore_strategy = Mock()
        self.mock_init_strategy = Mock()
        app_version = Mock()
        app_version.version = '0.9.0'
        self.mock_app_version = app_version
        self.strategy = VersionUpgradeStrategy(
            restore_strategy=self.mock_restore_strategy,
            app_version=app_version,
            init_strategy=self.mock_init_strategy,
        )
        self.context = StartupContext(
            output=self.mock_output,
            db_initialized=True,
            db_version=Version('0.9.0'),
            current_version=Version('1.0.0'),
            wizard_state_enum=WizardState.COMPLETED,
            wizard_state_raw=SetupWizardState.WIZARD_COMPLETED,
            storage_type=KeyStorageConfig.StorageType.SOFTWARE,
            dek_cache_state=None,
        )

    def test_get_description(self):
        """The description should mention the upgrade."""
        text = self.strategy.get_description().lower()
        self.assertIn('upgrade', text)

    def test_execute_upgrades_version(self):
        """execute() should initialize, restore, then bump and persist the stored version."""
        ctx = self.context
        app_version = self.mock_app_version
        self.strategy.execute(ctx)
        self.mock_init_strategy.initialize.assert_called_once_with(ctx, with_tls=False)
        self.mock_restore_strategy.execute.assert_called_once_with(ctx)
        # The stored version record is expected to hold the running version and be saved once.
        self.assertEqual(app_version.version, '1.0.0')
        app_version.save.assert_called_once()
        self.mock_output.write.assert_called()
class StartupStrategySelectorTest(TestCase):
    """Test suite for StartupStrategySelector."""

    def setUp(self):
        """Set up the shared output mock."""
        self.mock_output = Mock()

    def _make_context(self, **overrides):
        """Build a StartupContext from shared defaults plus per-test overrides.

        The defaults describe the following state:
        - an initialized database at version 1.0.0, matching the current version;
        - a completed wizard;
        - SOFTWARE key storage with no raw wizard state and no DEK cache state.

        Extra keywords such as ``has_kek`` are only forwarded when given.
        """
        kwargs = {
            'db_initialized': True,
            'db_version': Version('1.0.0'),
            'current_version': Version('1.0.0'),
            'wizard_state_enum': WizardState.COMPLETED,
            'wizard_state_raw': None,
            'storage_type': KeyStorageConfig.StorageType.SOFTWARE,
            'dek_cache_state': None,
            'output': self.mock_output,
        }
        kwargs.update(overrides)
        return StartupContext(**kwargs)

    def test_select_startup_strategy_db_not_initialized(self):
        """Test select_startup_strategy when DB not initialized."""
        strategy = StartupStrategySelector.select_startup_strategy(
            db_initialized=False,
            has_version=False,
        )
        self.assertIsInstance(strategy, DatabaseNotInitializedStrategy)

    def test_select_startup_strategy_db_initialized_no_version(self):
        """Test select_startup_strategy when DB has no version."""
        strategy = StartupStrategySelector.select_startup_strategy(
            db_initialized=True,
            has_version=False,
        )
        self.assertIsInstance(strategy, DatabaseInitializedNoVersionStrategy)

    def test_select_startup_strategy_requires_context_and_version(self):
        """Test select_startup_strategy raises error without context/version."""
        with self.assertRaises(ValueError) as cm:
            StartupStrategySelector.select_startup_strategy(
                db_initialized=True,
                has_version=True,
            )
        # Must sit outside the `with`: code after the raising call never runs.
        self.assertIn('required', str(cm.exception).lower())

    def test_select_restore_strategy_software_wizard_completed(self):
        """Test select_restore_strategy for software storage with completed wizard."""
        context = self._make_context()
        strategy = StartupStrategySelector.select_restore_strategy(context)
        self.assertIsInstance(strategy, RestoreSoftwareWizardCompletedStrategy)

    def test_select_restore_strategy_software_wizard_incomplete(self):
        """Test select_restore_strategy for software storage with incomplete wizard."""
        context = self._make_context(wizard_state_enum=WizardState.INCOMPLETE)
        strategy = StartupStrategySelector.select_restore_strategy(context)
        self.assertIsInstance(strategy, RestoreSoftwareWizardIncompleteStrategy)

    def test_select_restore_strategy_softhsm_wizard_completed_dek_cached(self):
        """Test select_restore_strategy for SoftHSM with completed wizard and cached DEK."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.CACHED,
        )
        strategy = StartupStrategySelector.select_restore_strategy(context)
        self.assertIsInstance(strategy, RestoreSoftHsmWizardCompletedDekCachedStrategy)

    def test_select_restore_strategy_softhsm_new_kek(self):
        """Test select_restore_strategy for SoftHSM with new KEK (old KEK lost)."""
        context = self._make_context(
            storage_type=KeyStorageConfig.StorageType.SOFTHSM,
            dek_cache_state=DekCacheState.NOT_CACHED,
            has_kek=False,
            has_backup_encrypted_dek=True,
        )
        strategy = StartupStrategySelector.select_restore_strategy(context)
        self.assertIsInstance(strategy, RestoreSoftHsmNewKekWizardCompletedStrategy)

    def test_select_restore_strategy_unsupported_storage_raises_error(self):
        """Test select_restore_strategy raises error for unsupported storage."""
        context = self._make_context(storage_type=None)
        with self.assertRaises(ValueError) as cm:
            StartupStrategySelector.select_restore_strategy(context)
        self.assertIn('unexpected', str(cm.exception).lower())

    @patch('management.util.startup_strategies.AppVersion')
    def test_select_version_strategy_version_match(self, mock_app_version):
        """Test select_version_strategy when versions match."""
        context = self._make_context()
        strategy = StartupStrategySelector.select_version_strategy(context, mock_app_version)
        self.assertIsInstance(strategy, VersionMatchStrategy)

    @patch('management.util.startup_strategies.AppVersion')
    def test_select_version_strategy_version_upgrade(self, mock_app_version):
        """Test select_version_strategy when upgrade needed."""
        context = self._make_context(db_version=Version('0.9.0'))
        strategy = StartupStrategySelector.select_version_strategy(context, mock_app_version)
        self.assertIsInstance(strategy, VersionUpgradeStrategy)

    @patch('management.util.startup_strategies.AppVersion')
    def test_select_version_strategy_version_downgrade_raises_error(self, mock_app_version):
        """Test select_version_strategy raises error for downgrade."""
        context = self._make_context(db_version=Version('2.0.0'))
        with self.assertRaises(RuntimeError) as cm:
            StartupStrategySelector.select_version_strategy(context, mock_app_version)
        self.assertIn('not supported', str(cm.exception).lower())