Skip to content

Commit ee9f0dc

Browse files
committed
STS: When persisting STS keys, use the actual port instead of the one from the policy
'Servers MAY send this key to securely connected clients, but it will be ignored.' -- https://ircv3.net/specs/extensions/sts\#the-port-key
1 parent 74073b2 commit ee9f0dc

File tree

7 files changed

+69
-41
lines changed

7 files changed

+69
-41
lines changed

src/drivers/__init__.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def _getNextServer(self):
9797

9898
def _applyStsPolicy(self, server):
9999
network = ircdb.networks.getNetwork(self.networkName)
100-
policy = network.stsPolicies.get(server.hostname)
100+
(policy_port, policy) = network.stsPolicies.get(
101+
server.hostname, (None, None))
101102
lastDisconnect = network.lastDisconnectTimes.get(server.hostname)
102103

103104
if policy is None or lastDisconnect is None:
@@ -107,22 +108,22 @@ def _applyStsPolicy(self, server):
107108

108109
# The policy was stored, which means it was received on a secure
109110
# connection.
110-
policy = ircutils.parseStsPolicy(log, policy, parseDuration=True)
111+
policy = ircutils.parseStsPolicy(log, policy, secure_connection=True)
111112

112113
if lastDisconnect + policy['duration'] < time.time():
113114
log.info('STS policy expired, removing.')
114115
network.expireStsPolicy(server.hostname)
115116
return server
116117

117-
if server.port == policy['port']:
118+
if server.port == policy_port:
118119
log.info('Using STS policy, port %s', server.port)
119120
else:
120121
log.info('Using STS policy: changing port from %s to %s.',
121-
server.port, policy['port'])
122+
server.port, policy_port)
122123

123124
# Change the port, and force TLS verification, as required by the STS
124125
# specification.
125-
return Server(server.hostname, policy['port'], server.attempt,
126+
return Server(server.hostname, policy_port, server.attempt,
126127
force_tls_verification=True)
127128

128129
def die(self):

src/ircdb.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,10 @@ def __repr__(self):
509509
(self.__class__.__name__, self.stsPolicies,
510510
self.lastDisconnectTimes)
511511

512-
def addStsPolicy(self, server, stsPolicy):
513-
assert isinstance(stsPolicy, str)
514-
self.stsPolicies[server] = stsPolicy
512+
def addStsPolicy(self, server, port, stsPolicy):
513+
assert isinstance(port, int), repr(port)
514+
assert isinstance(stsPolicy, str), repr(stsPolicy)
515+
self.stsPolicies[server] = (port, stsPolicy)
515516

516517
def expireStsPolicy(self, server):
517518
if server in self.stsPolicies:
@@ -526,8 +527,10 @@ def write(s):
526527
fd.write(s)
527528
fd.write(os.linesep)
528529

529-
for (server, stsPolicy) in sorted(self.stsPolicies.items()):
530-
write('stsPolicy %s %s' % (server, stsPolicy))
530+
for (server, (port, stsPolicy)) in sorted(self.stsPolicies.items()):
531+
assert isinstance(port, int), repr(port)
532+
assert isinstance(stsPolicy, str), repr(stsPolicy)
533+
write('stsPolicy %s %s %s' % (server, port, stsPolicy))
531534

532535
for (server, disconnectTime) in \
533536
sorted(self.lastDisconnectTimes.items()):
@@ -667,8 +670,12 @@ def network(self, rest, lineno):
667670
IrcNetworkCreator.name = rest
668671

669672
def stspolicy(self, rest, lineno):
670-
(server, stsPolicy) = rest.split()
671-
self.net.addStsPolicy(server, stsPolicy)
673+
L = rest.split()
674+
if len(L) == 2:
675+
# Old policy missing a port. Discard it
676+
return
677+
(server, policyPort, stsPolicy) = L
678+
self.net.addStsPolicy(server, int(policyPort), stsPolicy)
672679

673680
def lastdisconnecttime(self, rest, lineno):
674681
(server, when) = rest.split()

src/irclib.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -2050,7 +2050,7 @@ def _onCapSts(self, policy, msg):
20502050
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
20512051

20522052
parsed_policy = ircutils.parseStsPolicy(
2053-
log, policy, parseDuration=secure_connection)
2053+
log, policy, secure_connection=secure_connection)
20542054
if parsed_policy is None:
20552055
# There was an error (and it was logged). Ignore it and proceed
20562056
# with the connection.
@@ -2065,9 +2065,14 @@ def _onCapSts(self, policy, msg):
20652065
# For future-proofing (because we don't want to write an invalid
20662066
# value), we write the raw policy received from the server instead
20672067
# of the parsed one.
2068-
log.debug('Storing STS policy: %s', policy)
2068+
log.debug('Storing STS policy for %s (TLS port %s): %s',
2069+
self.driver.currentServer.hostname,
2070+
self.driver.currentServer.port,
2071+
policy)
20692072
ircdb.networks.getNetwork(self.network).addStsPolicy(
2070-
self.driver.currentServer.hostname, policy)
2073+
self.driver.currentServer.hostname,
2074+
self.driver.currentServer.port,
2075+
policy)
20712076
else:
20722077
hostname = self.driver.currentServer.hostname
20732078
attempt = self.driver.currentServer.attempt

src/ircutils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1073,11 +1073,15 @@ def parseCapabilityKeyValue(s):
10731073
return d
10741074

10751075

1076-
def parseStsPolicy(logger, policy, parseDuration):
1076+
def parseStsPolicy(logger, policy, secure_connection):
10771077
parsed_policy = parseCapabilityKeyValue(policy)
10781078

10791079
for key in ('port', 'duration'):
1080-
if key == 'duration' and not parseDuration:
1080+
if key == 'duration' and not secure_connection:
1081+
if key in parsed_policy:
1082+
del parsed_policy[key]
1083+
continue
1084+
elif key == 'port' and secure_connection:
10811085
if key in parsed_policy:
10821086
del parsed_policy[key]
10831087
continue

test/test_drivers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def tearDown(self):
3939
def testValidStsPolicy(self):
4040
irc = irclib.Irc('test')
4141
net = ircdb.networks.getNetwork('test')
42-
net.addStsPolicy('example.com', 'duration=10,port=6697')
42+
net.addStsPolicy('example.com', 6697, 'duration=10,port=12345')
4343
net.addDisconnection('example.com')
4444

4545
with conf.supybot.networks.test.servers.context(
@@ -64,7 +64,7 @@ def testValidStsPolicy(self):
6464
def testExpiredStsPolicy(self):
6565
irc = irclib.Irc('test')
6666
net = ircdb.networks.getNetwork('test')
67-
net.addStsPolicy('example.com', 'duration=10,port=6697')
67+
net.addStsPolicy('example.com', 6697, 'duration=10')
6868
net.addDisconnection('example.com')
6969

7070
timeFastForward(16)
@@ -81,7 +81,7 @@ def testExpiredStsPolicy(self):
8181
def testRescheduledStsPolicy(self):
8282
irc = irclib.Irc('test')
8383
net = ircdb.networks.getNetwork('test')
84-
net.addStsPolicy('example.com', 'duration=10,port=6697')
84+
net.addStsPolicy('example.com', 6697, 'duration=10')
8585
net.addDisconnection('example.com')
8686

8787
with conf.supybot.networks.test.servers.context(

test/test_ircdb.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -358,11 +358,11 @@ def testDefaults(self):
358358

359359
def testStsPolicy(self):
360360
n = ircdb.IrcNetwork()
361-
n.addStsPolicy('foo', 'bar')
362-
n.addStsPolicy('baz', 'qux')
361+
n.addStsPolicy('foo', 123, 'bar')
362+
n.addStsPolicy('baz', 456, 'qux')
363363
self.assertEqual(n.stsPolicies, {
364-
'foo': 'bar',
365-
'baz': 'qux',
364+
'foo': (123, 'bar'),
365+
'baz': (456, 'qux'),
366366
})
367367

368368
def testAddDisconnection(self):
@@ -374,8 +374,8 @@ def testAddDisconnection(self):
374374

375375
def testPreserve(self):
376376
n = ircdb.IrcNetwork()
377-
n.addStsPolicy('foo', 'sts1')
378-
n.addStsPolicy('bar', 'sts2')
377+
n.addStsPolicy('foo', 123, 'sts1')
378+
n.addStsPolicy('bar', 456,'sts2')
379379
n.addDisconnection('foo')
380380
n.addDisconnection('baz')
381381
disconnect_time_foo = n.lastDisconnectTimes['foo']
@@ -384,8 +384,8 @@ def testPreserve(self):
384384
n.preserve(fd, indent=' ')
385385
fd.seek(0)
386386
self.assertCountEqual(fd.read().split('\n'), [
387-
' stsPolicy foo sts1',
388-
' stsPolicy bar sts2',
387+
' stsPolicy foo 123 sts1',
388+
' stsPolicy bar 456 sts2',
389389
' lastDisconnectTime foo %d' % disconnect_time_foo,
390390
' lastDisconnectTime baz %d' % disconnect_time_baz,
391391
'',
@@ -467,8 +467,8 @@ def testGetSetNetwork(self):
467467

468468
def testPreserveOne(self):
469469
n = ircdb.IrcNetwork()
470-
n.addStsPolicy('foo', 'sts1')
471-
n.addStsPolicy('bar', 'sts2')
470+
n.addStsPolicy('foo', 123, 'sts1')
471+
n.addStsPolicy('bar', 456, 'sts2')
472472
n.addDisconnection('foo')
473473
n.addDisconnection('baz')
474474
disconnect_time_foo = n.lastDisconnectTimes['foo']
@@ -486,8 +486,8 @@ def testPreserveOne(self):
486486
lines = fd.getvalue().split('\n')
487487
self.assertEqual(lines.pop(0), 'network foonet')
488488
self.assertCountEqual(lines, [
489-
' stsPolicy foo sts1',
490-
' stsPolicy bar sts2',
489+
' stsPolicy foo 123 sts1',
490+
' stsPolicy bar 456 sts2',
491491
' lastDisconnectTime foo %d' % disconnect_time_foo,
492492
' lastDisconnectTime baz %d' % disconnect_time_baz,
493493
'',
@@ -496,15 +496,15 @@ def testPreserveOne(self):
496496

497497
def testPreserveThree(self):
498498
n = ircdb.IrcNetwork()
499-
n.addStsPolicy('foo', 'sts1')
499+
n.addStsPolicy('foo', 123, 'sts1')
500500
self.networks.setNetwork('foonet', n)
501501

502502
n = ircdb.IrcNetwork()
503-
n.addStsPolicy('bar', 'sts2')
503+
n.addStsPolicy('bar', 456, 'sts2')
504504
self.networks.setNetwork('barnet', n)
505505

506506
n = ircdb.IrcNetwork()
507-
n.addStsPolicy('baz', 'sts3')
507+
n.addStsPolicy('baz', 789, 'sts3')
508508
self.networks.setNetwork('baznet', n)
509509

510510
fd = io.StringIO()
@@ -518,13 +518,13 @@ def testPreserveThree(self):
518518
fd.seek(0)
519519
self.assertEqual(fd.getvalue(),
520520
'network barnet\n'
521-
' stsPolicy bar sts2\n'
521+
' stsPolicy bar 456 sts2\n'
522522
'\n'
523523
'network baznet\n'
524-
' stsPolicy baz sts3\n'
524+
' stsPolicy baz 789 sts3\n'
525525
'\n'
526526
'network foonet\n'
527-
' stsPolicy foo sts1\n'
527+
' stsPolicy foo 123 sts1\n'
528528
'\n'
529529
)
530530

test/test_irclib.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -759,16 +759,27 @@ def testStsInSecureConnection(self):
759759
self.irc.driver.ssl = True
760760
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
761761
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
762-
args=('*', 'LS', 'sts=duration=42,port=6697')))
762+
args=('*', 'LS', 'sts=duration=42,port=12345')))
763763

764764
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
765-
'irc.test': 'duration=42,port=6697'})
765+
'irc.test': (6697, 'duration=42,port=12345')})
766+
self.irc.driver.reconnect.assert_not_called()
767+
768+
def testStsInSecureConnectionNoPort(self):
769+
self.irc.driver.anyCertValidationEnabled.return_value = True
770+
self.irc.driver.ssl = True
771+
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
772+
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
773+
args=('*', 'LS', 'sts=duration=42')))
774+
775+
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
776+
'irc.test': (6697, 'duration=42')})
766777
self.irc.driver.reconnect.assert_not_called()
767778

768779
def testStsInInsecureTlsConnection(self):
769780
self.irc.driver.anyCertValidationEnabled.return_value = False
770781
self.irc.driver.ssl = True
771-
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
782+
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
772783
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
773784
args=('*', 'LS', 'sts=duration=42,port=6697')))
774785

0 commit comments

Comments
 (0)