|
58 | 58 | ] |
59 | 59 | SSLType = typing.Union[_ParsedSSLType, SSLStringValues, bool] |
60 | 60 | HostType = typing.Union[typing.List[str], str] |
61 | | -PortType = typing.Union[typing.List[int], int] |
| 61 | +PortListType = typing.Union[ |
| 62 | + typing.List[typing.Union[int, str]], |
| 63 | + typing.List[int], |
| 64 | + typing.List[str], |
| 65 | +] |
| 66 | +PortType = typing.Union[ |
| 67 | + PortListType, |
| 68 | + int, |
| 69 | + str |
| 70 | +] |
62 | 71 |
|
63 | 72 |
|
64 | 73 | class SSLMode(enum.IntEnum): |
@@ -192,26 +201,42 @@ def _read_password_from_pgpass( |
192 | 201 | return None |
193 | 202 |
|
194 | 203 |
|
195 | | -def _validate_port_spec(hosts: typing.List[str], |
196 | | - port: PortType) \ |
197 | | - -> typing.List[int]: |
| 204 | +@typing.overload |
| 205 | +def _validate_port_spec( |
| 206 | + hosts: typing.List[str], |
| 207 | + port: PortListType |
| 208 | +) -> typing.List[int]: |
| 209 | + ... |
| 210 | + |
| 211 | + |
| 212 | +@typing.overload |
| 213 | +def _validate_port_spec( |
| 214 | + hosts: typing.List[str], |
| 215 | + port: typing.Union[int, str] |
| 216 | +) -> typing.List[int]: |
| 217 | + ... |
| 218 | + |
| 219 | + |
| 220 | +def _validate_port_spec( |
| 221 | + hosts: typing.List[str], |
| 222 | + port: PortType |
| 223 | +) -> typing.List[int]: |
198 | 224 | if isinstance(port, list): |
199 | 225 | # If there is a list of ports, its length must |
200 | 226 | # match that of the host list. |
201 | 227 | if len(port) != len(hosts): |
202 | 228 | raise exceptions.InterfaceError( |
203 | 229 | 'could not match {} port numbers to {} hosts'.format( |
204 | 230 | len(port), len(hosts))) |
| 231 | + return [int(p) for p in port] |
205 | 232 | else: |
206 | | - port = [port for _ in range(len(hosts))] |
207 | | - |
208 | | - return port |
| 233 | + return [int(port) for _ in range(len(hosts))] |
209 | 234 |
|
210 | 235 |
|
211 | 236 | def _parse_hostlist(hostlist: str, |
212 | 237 | port: typing.Optional[PortType], |
213 | 238 | *, unquote: bool = False) \ |
214 | | - -> typing.Tuple[typing.List[str], typing.List[int]]: |
| 239 | + -> typing.Tuple[typing.List[str], PortListType]: |
215 | 240 | if ',' in hostlist: |
216 | 241 | # A comma-separated list of host addresses. |
217 | 242 | hostspecs = hostlist.split(',') |
@@ -242,7 +267,7 @@ def _parse_hostlist(hostlist: str, |
242 | 267 | if hostspec[0] == '/': |
243 | 268 | # Unix socket |
244 | 269 | addr = hostspec |
245 | | - hostspec_port = '' |
| 270 | + hostspec_port: str = '' |
246 | 271 | elif hostspec[0] == '[': |
247 | 272 | # IPv6 address |
248 | 273 | m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec) |
@@ -470,13 +495,10 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str], |
470 | 495 | else: |
471 | 496 | port = 5432 |
472 | 497 |
|
473 | | - elif isinstance(port, (list, tuple)): |
474 | | - port = [int(p) for p in port] |
475 | | - |
476 | | - else: |
| 498 | + elif not isinstance(port, (list, tuple)): |
477 | 499 | port = int(port) |
478 | 500 |
|
479 | | - port = _validate_port_spec(host, port) |
| 501 | + validated_ports = _validate_port_spec(host, port) |
480 | 502 |
|
481 | 503 | if user is None: |
482 | 504 | user = os.getenv('PGUSER') |
@@ -517,13 +539,13 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str], |
517 | 539 |
|
518 | 540 | if passfile_path is not None: |
519 | 541 | password = _read_password_from_pgpass( |
520 | | - hosts=auth_hosts, ports=port, |
| 542 | + hosts=auth_hosts, ports=validated_ports, |
521 | 543 | database=database, user=user, |
522 | 544 | passfile=passfile_path) |
523 | 545 |
|
524 | 546 | addrs: typing.List[AddrType] = [] |
525 | 547 | have_tcp_addrs = False |
526 | | - for h, p in zip(host, port): |
| 548 | + for h, p in zip(host, validated_ports): |
527 | 549 | if h.startswith('/'): |
528 | 550 | # UNIX socket name |
529 | 551 | if '.s.PGSQL.' not in h: |
|
0 commit comments