diff --git a/tests/test_config_api.py b/tests/test_config_api.py index 27036ee..a448d07 100644 --- a/tests/test_config_api.py +++ b/tests/test_config_api.py @@ -371,6 +371,8 @@ async def test_invalid_option(): PortOption('a', '', allow_zero=True, allow_wellknown=False, allow_registred=True, allow_private=True) with pytest.raises(ValueError): PortOption('a', '', allow_zero=False, allow_wellknown=False, allow_registred=False, allow_private=False) + with pytest.raises(ValueError): + PortOption('a', '', 'tcp:80') NetworkOption('a', '') with pytest.raises(ValueError): NetworkOption('a', '', 'string') diff --git a/tests/test_config_ip.py b/tests/test_config_ip.py index 947b5dc..e1e541a 100644 --- a/tests/test_config_ip.py +++ b/tests/test_config_ip.py @@ -295,6 +295,16 @@ async def test_port(config_type): assert not await list_sessions() +@pytest.mark.asyncio +async def test_port_protocol(config_type): + a = PortOption('a', '', allow_protocol=True) + od = OptionDescription('od', '', [a]) + async with await Config(od) as cfg: + await cfg.option('a').value.set('80') + await cfg.option('a').value.set('tcp:80') + assert not await list_sessions() + + @pytest.mark.asyncio async def test_port_range(config_type): a = PortOption('a', '', allow_range=True) diff --git a/tiramisu/option/portoption.py b/tiramisu/option/portoption.py index 59db844..c277344 100644 --- a/tiramisu/option/portoption.py +++ b/tiramisu/option/portoption.py @@ -48,12 +48,15 @@ class PortOption(StrOption): allow_zero: bool=False, allow_wellknown: bool=True, allow_registred: bool=True, + allow_protocol: bool=False, allow_private: bool=False, **kwargs) -> None: extra = {'_allow_range': allow_range, + '_allow_protocol': allow_protocol, '_min_value': None, - '_max_value': None} + '_max_value': None, + } ports_min = [0, 1, 1024, 49152] ports_max = [0, 1023, 49151, 65535] is_finally = False @@ -81,7 +84,9 @@ class PortOption(StrOption): def validate(self, value: str) -> None: super().validate(value) - if self.impl_get_extra('_allow_range') and ":" in str(value): + if self.impl_get_extra('_allow_protocol') and (value.startswith('tcp:') or value.startswith('udp:')): + value = [value[4:]] + elif self.impl_get_extra('_allow_range') and ":" in str(value): value = value.split(':') if len(value) != 2: raise ValueError(_('range must have two values only')) @@ -98,7 +103,13 @@ class PortOption(StrOption): def second_level_validation(self, value: str, warnings_only: bool) -> None: - for val in value.split(':'): + if self.impl_get_extra('_allow_protocol') and (value.startswith('tcp:') or value.startswith('udp:')): + value = [value[4:]] + elif ':' in value: + value = value.split(':') + else: + value = [value] + for val in value: val = int(val) if not self.impl_get_extra('_min_value') <= val <= self.impl_get_extra('_max_value'): if warnings_only: