sync controller database to disk and load it on start
This commit is contained in:
parent
1e23b1205a
commit
e1b1ec195f
2
.gitignore
vendored
2
.gitignore
vendored
@ -5,3 +5,5 @@ __pycache__
|
||||
venv
|
||||
capport.yaml
|
||||
custom
|
||||
capport.state
|
||||
capport.state.new-*
|
||||
|
@ -28,7 +28,7 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||
app.my_nc = mync
|
||||
_logger.info("Running hub for API")
|
||||
myapp = ApiHubApp()
|
||||
myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp)
|
||||
myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp, is_controller=False)
|
||||
app.my_hub = myhub
|
||||
await myhub.run(task_status=task_status)
|
||||
finally:
|
||||
|
@ -56,28 +56,28 @@ async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cp
|
||||
|
||||
async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None:
    """Log the client with hardware address `mac` in and broadcast the change.

    The login is applied through `Database.make_changes()`. If the database
    raises `NotReadyYet`, the request is aborted with HTTP status 500 and the
    exception text as description.

    `address` is only used for the debug log message.
    """
    assert app.my_hub  # for mypy
    async with app.my_hub.database.make_changes() as pu:
        try:
            pu.login(mac, app.my_config.session_timeout)
        except capport.database.NotReadyYet as e:
            quart.abort(500, str(e))

    # `pu.serialized` may only be read after the context manager has
    # finished (closed) the update set, so this happens outside the block
    if pu:
        _logger.debug(f'User {mac} (with IP {address}) logged in')
        for msg in pu.serialized:
            await app.my_hub.broadcast(msg)
|
||||
|
||||
|
||||
async def user_logout(mac: cptypes.MacAddress) -> None:
    """Log the client with hardware address `mac` out and broadcast the change.

    The logout is applied through `Database.make_changes()`. If the database
    raises `NotReadyYet`, the request is aborted with HTTP status 500 and the
    exception text as description.
    """
    assert app.my_hub  # for mypy
    async with app.my_hub.database.make_changes() as pu:
        try:
            pu.logout(mac)
        except capport.database.NotReadyYet as e:
            quart.abort(500, str(e))

    # `pu.serialized` may only be read after the context manager has
    # finished (closed) the update set, so this happens outside the block
    if pu:
        _logger.debug(f'User {mac} logged out')
        for msg in pu.serialized:
            await app.my_hub.broadcast(msg)
|
||||
|
||||
|
||||
|
@ -265,9 +265,6 @@ class ControllerConn:
|
||||
|
||||
|
||||
class HubApplication:
|
||||
def is_controller(self) -> bool:
|
||||
return False
|
||||
|
||||
async def new_peer(self, *, peer_id: uuid.UUID) -> None:
|
||||
_logger.info(f"New peer {peer_id}")
|
||||
|
||||
@ -287,13 +284,18 @@ class HubApplication:
|
||||
|
||||
|
||||
class Hub:
|
||||
def __init__(self, config: Config, app: HubApplication) -> None:
|
||||
def __init__(self, config: Config, app: HubApplication, *, is_controller: bool) -> None:
|
||||
self._config = config
|
||||
self._instance_id = uuid.uuid4()
|
||||
self._hostname = socket.getfqdn()
|
||||
self.database = capport.database.Database()
|
||||
self._app = app
|
||||
self._is_controller = bool(app.is_controller())
|
||||
self._is_controller = is_controller
|
||||
state_filename: typing.Optional[str]
|
||||
if is_controller:
|
||||
state_filename = 'capport.state'
|
||||
else:
|
||||
state_filename = None
|
||||
self.database = capport.database.Database(state_filename=state_filename)
|
||||
self._anon_context = ssl.SSLContext()
|
||||
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
|
||||
self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
@ -324,6 +326,7 @@ class Hub:
|
||||
|
||||
async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(self.database.run)
|
||||
if self._is_controller:
|
||||
await nursery.start(self._listen)
|
||||
|
||||
@ -412,12 +415,12 @@ class Hub:
|
||||
pass
|
||||
elif isinstance(variant, capport.comm.message.MacStates):
|
||||
await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
|
||||
pu = capport.database.PendingUpdates()
|
||||
for state in variant.states:
|
||||
self.database.received_mac_state(state, pending_updates=pu)
|
||||
if pu.macs:
|
||||
async with self.database.make_changes() as pu:
|
||||
for state in variant.states:
|
||||
pu.received_mac_state(state)
|
||||
if pu:
|
||||
# re-broadcast all received updates to all peers
|
||||
await self.broadcast(*pu.serialize(), exclude=peer_id)
|
||||
await self.broadcast(*pu.serialized, exclude=peer_id)
|
||||
await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
|
||||
else:
|
||||
await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)
|
||||
|
@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing
|
||||
import uuid
|
||||
|
||||
import capport.database
|
||||
import capport.comm.hub
|
||||
import capport.comm.message
|
||||
import capport.config
|
||||
import capport.database
|
||||
import capport.utils.cli
|
||||
import capport.utils.nft_set
|
||||
import trio
|
||||
@ -22,15 +23,15 @@ class ControlApp(capport.comm.hub.HubApplication):
|
||||
super().__init__()
|
||||
self.nft_set = capport.utils.nft_set.NftSet()
|
||||
|
||||
def is_controller(self) -> bool:
|
||||
return True
|
||||
|
||||
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
    """Hand the (mac, entry) pairs of a finished update set to `apply_db_entries`.

    `from_peer_id` is accepted for the hub callback interface but not used here.
    """
    changed_entries = pending_updates.changes()
    self.apply_db_entries(changed_entries)
|
||||
|
||||
def apply_db_entries(self, entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]]) -> None:
|
||||
# deploy changes to netfilter set
|
||||
inserts = []
|
||||
removals = []
|
||||
now = cptypes.Timestamp.now()
|
||||
for mac, state in pending_updates.macs.items():
|
||||
for mac, state in entries:
|
||||
rem = state.allowed_remaining(now)
|
||||
if rem > 0:
|
||||
inserts.append((mac, rem))
|
||||
@ -42,9 +43,12 @@ class ControlApp(capport.comm.hub.HubApplication):
|
||||
|
||||
async def amain(config: capport.config.Config) -> None:
    """Run the controller: start the hub, then apply the loaded database entries.

    The hub is created with `is_controller=True`, which makes it keep its
    database in a statefile.
    """
    app = ControlApp()
    hub = capport.comm.hub.Hub(config=config, app=app, is_controller=True)
    app.hub = hub
    async with trio.open_nursery() as nursery:
        # hub.run loads the statefile from disk before signalling it was "started"
        await nursery.start(hub.run)
        # push everything restored from disk through the same path as live updates
        app.apply_db_entries(hub.database.entries())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
@ -1,11 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import typing
|
||||
|
||||
import google.protobuf.message
|
||||
import trio
|
||||
|
||||
import capport.comm.message
|
||||
from capport import cptypes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MacEntry:
|
||||
@ -113,30 +122,82 @@ def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacE
|
||||
return [s.to_message() for s in _serialize_mac_states(macs)]
|
||||
|
||||
|
||||
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
    """Encode `states` as one statefile chunk.

    A chunk is the deterministically serialized message, prefixed with its
    byte length as a 4-byte unsigned integer in network byte order.
    """
    payload = states.SerializeToString(deterministic=True)
    return struct.pack('!I', len(payload)) + payload
|
||||
|
||||
|
||||
class NotReadyYet(Exception):
    """Raised when a login or logout cannot be applied yet.

    `wait` holds the number of seconds the caller should wait before retrying.
    """

    def __init__(self, msg: str, wait: int):
        super().__init__(msg)
        # seconds to wait
        self.wait = wait
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Database:
|
||||
_macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
|
||||
def __init__(self, state_filename: typing.Optional[str]=None):
    """Create an empty database.

    `state_filename` is the path of the statefile; when it is `None`
    (the default), `run` neither loads nor writes a statefile.
    """
    self._state_filename = state_filename
    # all known entries, keyed by MAC address
    self._macs: typing.Dict[cptypes.MacAddress, MacEntry] = {}
    # sending side of the channel feeding the statefile writer; stays
    # `None` until `_run_statefile` has created the channel
    self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
        capport.comm.message.MacStates,
        typing.List[capport.comm.message.MacStates],
    ]]] = None
    # set whenever entries were changed or dropped; reset by the cleanup loop in `run`
    self._changed_since_last_cleanup = False
|
||||
|
||||
def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
|
||||
(addr, new_entry) = MacEntry.parse_state(state)
|
||||
old_entry = self._macs.get(addr)
|
||||
if not old_entry:
|
||||
# only redistribute if not outdated
|
||||
if not new_entry.outdated():
|
||||
self._macs[addr] = new_entry
|
||||
pending_updates.macs[addr] = new_entry
|
||||
elif old_entry.merge(new_entry):
|
||||
if old_entry.outdated():
|
||||
# remove local entry, but still redistribute
|
||||
self._macs.pop(addr)
|
||||
pending_updates.macs[addr] = old_entry
|
||||
@contextlib.asynccontextmanager
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
    """Async context manager yielding an open `PendingUpdates` for this database.

    When the `async with` body finishes without raising, the update set is
    finished (closed and serialized). If it recorded any change, the database
    is marked as changed and the serialized states are sent to the statefile
    writer, provided one is running.
    """
    updates = PendingUpdates(self)
    updates._closed = False
    yield updates
    updates._finish()
    if not updates:
        # nothing was changed - nothing to persist
        return
    self._changed_since_last_cleanup = True
    if not self._send_changes:
        return
    for serialized_state in updates.serialized_states:
        await self._send_changes.send(serialized_state)
|
||||
|
||||
def _drop_outdated(self) -> None:
    """Delete every entry that reports itself as outdated.

    Entries are collected and deleted in batches of at most 1024 addresses;
    after a full batch the scan starts over with a fresh timestamp. Any
    deletion marks the database as changed.
    """
    batch_limit = 1024
    while True:
        expired: typing.Set[cptypes.MacAddress] = set()
        now = cptypes.Timestamp.now()
        batch_full = False
        for mac, entry in self._macs.items():
            if not entry.outdated(now):
                continue
            expired.add(mac)
            if len(expired) >= batch_limit:
                # clear entries found so far, then try again
                batch_full = True
                break
        if expired:
            self._changed_since_last_cleanup = True
            # deletion happens after the scan: the dict must not change while iterating
            for mac in expired:
                del self._macs[mac]
        if not batch_full:
            return
|
||||
|
||||
async def run(self, task_status=trio.TASK_STATUS_IGNORED, *, cleanup_interval: float = 10):
    """Run the database background tasks until cancelled.

    If a statefile is configured it is loaded first; only then is
    `task_status.started()` signalled. Afterwards the statefile writer is
    started (again only with a configured statefile) and outdated entries are
    dropped every `cleanup_interval` seconds. If anything changed since the
    previous cleanup and a statefile writer is running, the complete
    serialized database is sent to it, which triggers a resync of the file.

    `cleanup_interval` defaults to 10 seconds.
    NOTE(review): an earlier comment on the sleep call said "15 minutes";
    confirm which interval is intended for production.
    """
    if self._state_filename:
        await self._load_statefile()
    task_status.started()
    async with trio.open_nursery() as nursery:
        if self._state_filename:
            nursery.start_soon(self._run_statefile)
        while True:
            await trio.sleep(cleanup_interval)
            _logger.debug("Running database cleanup")
            self._drop_outdated()
            if self._changed_since_last_cleanup:
                self._changed_since_last_cleanup = False
                if self._send_changes:
                    states = _serialize_mac_states(self._macs)
                    # trigger a resync
                    await self._send_changes.send(states)
|
||||
|
||||
# for initial handling of all data
|
||||
def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
    """Return a new list with all current (mac, entry) pairs."""
    return [*self._macs.items()]
|
||||
|
||||
# for initial sync with new peer
|
||||
def serialize(self) -> typing.List[capport.comm.message.Message]:
    """Return all current entries serialized as a list of messages."""
    messages = _serialize_mac_states_as_messages(self._macs)
    return messages
|
||||
|
||||
@ -146,6 +207,87 @@ class Database:
|
||||
for addr, entry in self._macs.items()
|
||||
}
|
||||
|
||||
async def _run_statefile(self) -> None:
    """Keep the statefile up to date until cancelled.

    Creates the memory channel behind `self._send_changes` and then consumes
    it: a single `MacStates` update is appended to the statefile as one
    chunk, while a list of `MacStates` is a resync request and replaces the
    whole file. On cancellation the complete database is written once more
    before the cancellation is propagated.
    """
    rx: trio.MemoryReceiveChannel[typing.Union[
        capport.comm.message.MacStates,
        typing.List[capport.comm.message.MacStates],
    ]]
    tx: trio.MemorySendChannel[typing.Union[
        capport.comm.message.MacStates,
        typing.List[capport.comm.message.MacStates],
    ]]
    tx, rx = trio.open_memory_channel(64)
    self._send_changes = tx

    assert self._state_filename
    filename: str = self._state_filename
    tmp_filename = f'{filename}.new-{os.getpid()}'

    async def resync(all_states: typing.List[capport.comm.message.MacStates]):
        # write a complete new file next to the statefile, then rename it
        # over the statefile
        try:
            async with await trio.open_file(tmp_filename, 'xb') as tf:
                for states in all_states:
                    await tf.write(_states_to_chunk(states))
            os.rename(tmp_filename, filename)
        finally:
            # the temporary file only still exists if writing or renaming failed
            if os.path.exists(tmp_filename):
                _logger.warning(f'Removing (failed) state export file {tmp_filename}')
                os.unlink(tmp_filename)

    try:
        while True:
            async with await trio.open_file(filename, 'ab') as sf:
                while True:
                    update = await rx.receive()
                    if isinstance(update, list):
                        break
                    await sf.write(_states_to_chunk(update))
            # got a "list" update - i.e. a resync
            with trio.CancelScope(shield=True):
                await resync(update)
            # now reopen normal statefile and continue appending updates
    except trio.Cancelled:
        _logger.info('Final sync to disk')
        # shielded so the final export can complete although we are cancelled
        with trio.CancelScope(shield=True):
            await resync(_serialize_mac_states(self._macs))
        _logger.info('Final sync to disk done')
        # trio expects a caught Cancelled to be re-raised once cleanup is done
        raise
|
||||
|
||||
async def _load_statefile(self) -> None:
    """Populate the database from the statefile, if that file exists.

    The file is read as a sequence of chunks as produced by
    `_states_to_chunk`: a 4-byte length in network byte order followed by
    that many bytes of a serialized `MacStates` message. Reading stops at
    end of file, at a truncated chunk or on an I/O error; a chunk that
    fails to decode is skipped. Failures while applying single states are
    logged (only the first few) and otherwise ignored.
    """
    if not os.path.exists(self._state_filename):
        return
    _logger.info("Loading statefile")
    # we're going to ignore changes from loading the file
    pu = PendingUpdates(self)
    pu._closed = False
    # counted across the whole file so the log output stays capped
    errors = 0
    async with await trio.open_file(self._state_filename, 'rb') as sf:
        while True:
            try:
                len_bytes = await sf.read(4)
                if not len_bytes:
                    # clean end of file
                    return
                if len(len_bytes) < 4:
                    _logger.error("Failed to read next chunk from statefile (unexpected EOF)")
                    return
                chunk_size, = struct.unpack('!I', len_bytes)
                chunk = await sf.read(chunk_size)
            except IOError as e:
                _logger.error(f"Failed to read next chunk from statefile: {e}")
                return
            if len(chunk) < chunk_size:
                # the chunk body is cut short; nothing usable can follow
                _logger.error("Failed to read next chunk from statefile (unexpected EOF)")
                return
            try:
                states = capport.comm.message.MacStates()
                states.ParseFromString(chunk)
            except google.protobuf.message.DecodeError as e:
                _logger.error(f"Failed to decode chunk from statefile, trying next one: {e}")
                continue
            # the individual states live in the repeated `states` field
            for state in states.states:
                try:
                    pu.received_mac_state(state)
                except Exception as e:
                    errors += 1
                    if errors < 5:
                        _logger.error(f'Failed to handle state: {e}')
|
||||
|
||||
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
|
||||
entry = self._macs.get(mac)
|
||||
if entry:
|
||||
@ -158,43 +300,87 @@ class Database:
|
||||
allowed_remaining=allowed_remaining,
|
||||
)
|
||||
|
||||
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8):
|
||||
|
||||
class PendingUpdates:
    """Records the changes made to a `Database` within one update set.

    Instances start out closed. `Database.make_changes()` opens one, yields
    it and finishes it afterwards. The mutating methods (`received_mac_state`,
    `login`, `logout`) require an open instance; the `serialized*` properties
    require a closed one.
    """

    def __init__(self, database: Database):
        # entries changed through this update set, keyed by MAC address
        self._changes: typing.Dict[cptypes.MacAddress, MacEntry] = {}
        self._database = database
        self._closed = True
        # both lists are filled by `_finish`
        self._serialized_states: typing.List[capport.comm.message.MacStates] = []
        self._serialized: typing.List[capport.comm.message.Message] = []

    def __bool__(self) -> bool:
        """Return whether any change was recorded."""
        return bool(self._changes)

    def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
        """Return the recorded (mac, entry) pairs."""
        return self._changes.items()

    @property
    def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
        """Recorded changes as `MacStates` objects; only valid once closed."""
        assert self._closed
        return self._serialized_states

    @property
    def serialized(self) -> typing.List[capport.comm.message.Message]:
        """Recorded changes as messages; only valid once closed."""
        assert self._closed
        return self._serialized

    def _ensure_open(self) -> None:
        # shared guard for everything that modifies the update set
        if self._closed:
            raise Exception("Can't change closed PendingUpdates")

    def _finish(self) -> None:
        """Close the update set and serialize the recorded changes."""
        self._ensure_open()
        self._closed = True
        self._serialized_states = _serialize_mac_states(self._changes)
        self._serialized = [s.to_message() for s in self._serialized_states]

    def received_mac_state(self, state: capport.comm.message.MacState) -> None:
        """Merge a state received from elsewhere into the database."""
        self._ensure_open()
        (addr, new_entry) = MacEntry.parse_state(state)
        old_entry = self._database._macs.get(addr)
        if not old_entry:
            # only redistribute if not outdated
            if not new_entry.outdated():
                self._database._macs[addr] = new_entry
                self._changes[addr] = new_entry
        elif old_entry.merge(new_entry):
            if old_entry.outdated():
                # remove local entry, but still redistribute
                self._database._macs.pop(addr)
            self._changes[addr] = old_entry

    def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float=0.8) -> None:
        """Allow `mac` for `session_timeout` seconds from now.

        An existing session is not renewed while its remaining time exceeds
        `renew_maximum * session_timeout`.

        Raises `NotReadyYet` if the existing entry could not be updated and
        has no allowed time remaining.
        """
        self._ensure_open()
        now = cptypes.Timestamp.now()
        allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
        new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)

        entry = self._database._macs.get(mac)
        if not entry:
            self._database._macs[mac] = new_entry
            self._changes[mac] = new_entry
        elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
            # too much time left on clock, not renewing session
            return
        elif entry.merge(new_entry):
            self._changes[mac] = entry
        elif not entry.allowed_remaining() > 0:
            # entry should have been updated - can only fail due to `now < entry.last_change`
            # i.e. out of sync clocks
            wait = entry.last_change.epoch - now.epoch
            raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)

    def logout(self, mac: cptypes.MacAddress) -> None:
        """Revoke access for `mac`; unknown addresses are ignored.

        Raises `NotReadyYet` if the existing entry could not be updated and
        still has allowed time remaining.
        """
        self._ensure_open()
        now = cptypes.Timestamp.now()
        new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
        entry = self._database._macs.get(mac)
        if entry:
            if entry.merge(new_entry):
                self._changes[mac] = entry
            elif entry.allowed_remaining() > 0:
                # still logged in. can only happen with `now <= entry.last_change`
                # clocks not necessarily out of sync, but you can't logout in the same second you logged in
                wait = entry.last_change.epoch - now.epoch + 1
                raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PendingUpdates:
|
||||
macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
|
||||
|
||||
def serialize(self) -> typing.List[capport.comm.message.Message]:
|
||||
return _serialize_mac_states_as_messages(self.macs)
|
||||
|
Loading…
Reference in New Issue
Block a user