Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use local previously downloaded files when available. #92

Merged
merged 5 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 63 additions & 41 deletions actual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@
import warnings
import zipfile
from os import PathLike
from typing import IO, List, Union
from typing import IO, List, Optional, Union

from sqlmodel import MetaData, Session, create_engine, select
from sqlmodel import MetaData, Session, create_engine

from actual.api import ActualServer
from actual.api.models import BankSyncErrorDTO, RemoteFileListDTO
from actual.crypto import create_key_buffer, decrypt_from_meta, encrypt, make_salt
from actual.database import (
Accounts,
MessagesClock,
Transactions,
apply_change,
get_attribute_from_reflected_table_name,
Expand All @@ -30,6 +29,7 @@
)
from actual.exceptions import (
ActualBankSyncError,
ActualDecryptionError,
ActualError,
InvalidZipFile,
UnknownFileId,
Expand All @@ -39,6 +39,7 @@
from actual.queries import (
get_account,
get_accounts,
get_or_create_clock,
get_ruleset,
get_transactions,
reconcile_transaction,
Expand Down Expand Up @@ -72,7 +73,8 @@ def __init__(
:param file: the name or id of the file to be set
:param encryption_password: password used to configure encryption, if existing
:param data_dir: where to store the downloaded files from the server. If not specified, a temporary folder will
be created instead.
be created instead. If database files are already present on the path, the library will try to
reuse them by re-computing the sync request.
:param cert: if a custom certificate should be used (i.e. self-signed certificate), it's path can be provided
as a string. Set to `False` for no certificate check.
:param bootstrap: if the server is not bootstrapped, bootstrap it with the password.
Expand Down Expand Up @@ -194,8 +196,11 @@ def create_budget(self, budget_name: str):
self.run_migrations(migration_files[1:])
if self._in_context:
self._session = strong_reference_session(Session(self.engine, **self._sa_kwargs))
# create a clock
self.load_clock()
# create a clock. Since the clock entry is not tracked, we use a separate session
with Session(self.engine) as session:
self._client = HULC_Client()
get_or_create_clock(session, self._client)
session.commit()

def rename_budget(self, budget_name: str):
"""Renames the budget with the given name."""
Expand Down Expand Up @@ -254,7 +259,7 @@ def encrypt(self, encryption_password: str):

def upload_budget(self):
"""Uploads the current file to the Actual server. If attempting to upload your first budget, make sure you use
[actual.Actual.create_budget][] first.
[Actual.create_budget][actual.Actual.create_budget] first.
"""
if not self._data_dir:
raise UnknownFileId("No current file loaded.")
Expand All @@ -276,10 +281,11 @@ def upload_budget(self):
self.encrypt(self._encryption_password)

def reupload_budget(self):
"""Similar to the reset sync option from the frontend, resets the user file on the backend and reuploads the
"""Similar to the reset sync option from the frontend, resets the user file on the backend and re-uploads the
current copy instead. **This operation can be destructive**, so make sure you generate a copy before
attempting to reupload your budget."""
attempting to re-upload your budget."""
self.reset_user_file(self._file.file_id)
self.update_metadata({"groupId": None}) # since we don't know what the new group id will be
self.upload_budget()

def apply_changes(self, messages: List[Message]):
Expand Down Expand Up @@ -341,19 +347,34 @@ def download_budget(self, encryption_password: str = None):

If the budget is password protected, the password needs to be present to download the budget, otherwise it will
fail.

When a `data_dir` was provided, the method will try to use the local downloaded copy by first checking if the
sync id (named group id) remains the same. If it does, then the sync is executed using the stored files.
Otherwise, the file is re-downloaded.
"""
file_bytes = self.download_user_file(self._file.file_id)
# check if file has an encryption key and retrieve it
encryption_password = encryption_password or self._encryption_password

if self._file.encrypt_key_id and encryption_password is None:
raise ActualError("File is encrypted but no encryption password provided.")
if encryption_password is not None and self._file.encrypt_key_id:
file_info = self.get_user_file_info(self._file.file_id)
key_info = self.user_get_key(self._file.file_id)
self._master_key = create_key_buffer(encryption_password, key_info.data.salt)
# decrypt file bytes
file_bytes = decrypt_from_meta(self._master_key, file_bytes, file_info.data.encrypt_meta)
self.import_zip(io.BytesIO(file_bytes))
self.download_master_encryption_key(encryption_password)
# then download user file if the data_dir is set and both files are present
if self._data_dir and all((self._data_dir / path).is_file() for path in ["db.sqlite", "metadata.json"]):
group_id = self.get_metadata().get("groupId")
# handle the case where a new group id exists and the file needs to be re-downloaded
if self._file.group_id != group_id:
warnings.warn("Sync id has been reset on remote database, re-downloading the budget.")
(self._data_dir / "db.sqlite").unlink()
(self._data_dir / "metadata.json").unlink()
return self.download_budget(encryption_password)
# resume budget
self.create_engine()
else:
file_bytes = self.download_user_file(self._file.file_id)
if encryption_password is not None and self._file.encrypt_key_id:
file_info = self.get_user_file_info(self._file.file_id)
# decrypt file bytes
file_bytes = decrypt_from_meta(self._master_key, file_bytes, file_info.data.encrypt_meta)
self.import_zip(io.BytesIO(file_bytes))
# sometimes downloaded budgets will not have the groupId
self.update_metadata({"groupId": self._file.group_id})
# actual js always calls validation
self.validate()
# run migrations if needed
Expand All @@ -364,6 +385,17 @@ def download_budget(self, encryption_password: str = None):
if self._in_context and not self._session:
self._session = strong_reference_session(Session(self.engine, **self._sa_kwargs))

def download_master_encryption_key(self, encryption_password: str) -> Optional[bytes]:
"""Downloads and assembles the key for decrypting the budget based on the provided encryption password.
If the user file is not encryption, no key will be returned. If the file was encrypted, the key is assembled
using the key salt and the password with the PBKDF2HMAC algorithm."""
if self._file.encrypt_key_id and encryption_password is None:
raise ActualDecryptionError("File is encrypted but no encryption password was provided.")
if encryption_password is not None and self._file.encrypt_key_id:
key_info = self.user_get_key(self._file.file_id)
self._master_key = create_key_buffer(encryption_password, key_info.data.salt)
return self._master_key

def import_zip(self, file_bytes: str | PathLike[str] | IO[bytes]):
"""Imports a zip file as the current database, as well as generating the local reflected session. Enables you
to inspect backups by loading them directly, instead of unzipping the contents."""
Expand All @@ -375,10 +407,15 @@ def import_zip(self, file_bytes: str | PathLike[str] | IO[bytes]):
self._data_dir = pathlib.Path(tempfile.mkdtemp())
# this should extract 'db.sqlite' and 'metadata.json' to the folder
zip_file.extractall(self._data_dir)
self.create_engine()

def create_engine(self):
self.engine = create_engine(f"sqlite:///{self._data_dir}/db.sqlite")
self._meta = reflect_model(self.engine)
# load the client id
self.load_clock()
with Session(self.engine) as session:
clock = get_or_create_clock(session)
self._client = clock.get_timestamp()

def sync(self):
"""Does a sync request and applies all changes that are stored on the server on the local copy of the database.
Expand All @@ -393,30 +430,15 @@ def sync(self):
"keyId": self._file.encrypt_key_id,
}
)
request.set_null_timestamp(client_id=self._client.client_id) # using 0 timestamp to retrieve all changes
request.set_timestamp(client_id=self._client.client_id, now=self._client.ts)
changes = self.sync_sync(request)
self.apply_changes(changes.get_messages(self._master_key))
# after receiving changes, update the client clock with the latest value
if changes.messages:
self._client = HULC_Client.from_timestamp(changes.messages[-1].timestamp)

def load_clock(self) -> MessagesClock:
"""Loads the HULC Clock from the database. This clock tells the server from when the messages should be
retrieved. See the [original implementation.](
https://github.com/actualbudget/actual/blob/5bcfc71be67c6e7b7c8b444e4c4f60da9ea9fdaa/packages/loot-core/src/server/db/index.ts#L81-L98)
"""
with Session(self.engine) as session:
clock = session.exec(select(MessagesClock)).one_or_none()
if not clock:
clock_message = {
"timestamp": HULC_Client().timestamp(now=datetime.datetime(1970, 1, 1, 0, 0, 0, 0)),
"merkle": {},
}
clock = MessagesClock(id=1, clock=json.dumps(clock_message, separators=(",", ":")))
session.add(clock)
session.commit()
# add clock id to client id
self._client = HULC_Client.from_timestamp(json.loads(clock.clock)["timestamp"])
return clock
# store timestamp also inside database. Session might not be available here, so we create one
with Session(self.engine) as session:
get_or_create_clock(session, self._client)

def commit(self):
"""Adds all pending entries to the local database, and sends a sync request to the remote server to synchronize
Expand Down
2 changes: 2 additions & 0 deletions actual/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def save(self):
"""Saves the current configuration to a file."""
config_path = default_config_path()
os.makedirs(config_path.parent, exist_ok=True)
os.makedirs(config_path.parent / "cache", exist_ok=True)
with open(config_path, "w") as file:
yaml.dump(self.model_dump(by_alias=True), file)

Expand Down Expand Up @@ -72,4 +73,5 @@ def actual(self) -> Actual:
password=budget_config.password,
file=budget_config.file_id,
encryption_password=budget_config.encryption_password,
data_dir=default_config_path().parent / "cache" / context,
)
2 changes: 1 addition & 1 deletion actual/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def transactions():
transactions_data.append(
{
"date": transaction.get_date().isoformat(),
"payee": transaction.payee.name,
"payee": transaction.payee.name if transaction.payee else None,
"notes": transaction.notes or "",
"category": (transaction.category.name if transaction.category else None),
"amount": round(float(transaction.get_amount()), 2),
Expand Down
29 changes: 25 additions & 4 deletions actual/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
sqlacodegen --generator sqlmodels sqlite:///db.sqlite
```

and patch the necessary models by merging the results. The [actual.database.BaseModel][] defines all models that can
be updated from the user, and must contain a unique `id`. Those models can then be converted automatically into a
protobuf change message using [actual.database.BaseModel.convert][].
and patch the necessary models by merging the results. The [BaseModel][actual.database.BaseModel] defines all models
that can be updated from the user, and must contain a unique `id`. Those models can then be converted automatically
into a protobuf change message using [BaseModel.convert][actual.database.BaseModel.convert].
"""

import datetime
import decimal
import json
from typing import Dict, List, Optional, Tuple, Union

from sqlalchemy import MetaData, Table, engine, event, inspect
Expand All @@ -37,7 +38,7 @@
)

from actual.exceptions import ActualInvalidOperationError
from actual.protobuf_models import Message
from actual.protobuf_models import HULC_Client, Message

"""
This variable contains the internal model mappings for all databases. It solves a couple of issues, namely having the
Expand Down Expand Up @@ -420,6 +421,26 @@ class MessagesClock(SQLModel, table=True):
id: Optional[int] = Field(default=None, sa_column=Column("id", Integer, primary_key=True))
clock: Optional[str] = Field(default=None, sa_column=Column("clock", Text))

def get_clock(self) -> dict:
"""Gets the clock from JSON text to a dictionary with fields `timestamp` and `merkle`."""
return json.loads(self.clock)

def set_clock(self, value: dict):
"""Sets the clock from a dictionary and stores it in the correct format."""
self.clock = json.dumps(value, separators=(",", ":"))

def get_timestamp(self) -> HULC_Client:
"""Gets the timestamp from the clock value directly as a [HULC_Client][actual.protobuf_models.HULC_Client]."""
clock = self.get_clock()
return HULC_Client.from_timestamp(clock["timestamp"])

def set_timestamp(self, client: HULC_Client) -> None:
"""Sets the timestamp on the clock value based on the [HULC_Client][actual.protobuf_models.HULC_Client]
provided."""
clock_message = self.get_clock()
clock_message["timestamp"] = str(client)
self.set_clock(clock_message)


class MessagesCrdt(SQLModel, table=True):
__tablename__ = "messages_crdt"
Expand Down
22 changes: 14 additions & 8 deletions actual/protobuf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@


class HULC_Client:
def __init__(self, client_id: str = None, initial_count: int = 0):
self.client_id = client_id or self.get_client_id()
def __init__(self, client_id: str = None, initial_count: int = 0, ts: datetime.datetime = None):
self.client_id = client_id or self.random_client_id()
self.initial_count = initial_count
self.ts = ts or datetime.datetime(1970, 1, 1, 0, 0, 0)

@classmethod
def from_timestamp(cls, ts: str) -> HULC_Client:
segments = ts.split("-")
return cls(segments[-1], int(segments[-2], 16))
ts_string, _, rest = ts.partition("Z")
segments = rest.split("-")
ts = datetime.datetime.fromisoformat(ts_string)
return cls(segments[-1], int(segments[-2], 16), ts)

def __str__(self):
count = f"{self.initial_count:0>4X}"
return f"{self.ts.isoformat(timespec='milliseconds')}Z-{count}-{self.client_id}"

def timestamp(self, now: datetime.datetime = None) -> str:
"""Actual uses Hybrid Unique Logical Clock (HULC) timestamp generator.
Expand All @@ -47,13 +54,12 @@ def timestamp(self, now: datetime.datetime = None) -> str:
self.initial_count += 1
return f"{now.isoformat(timespec='milliseconds')}Z-{count}-{self.client_id}"

def get_client_id(self):
@staticmethod
def random_client_id():
"""Creates a client id for the HULC request. Implementation copied [from the source code](
https://github.com/actualbudget/actual/blob/a9362cc6f9b974140a760ad05816cac51c849769/packages/crdt/src/crdt/timestamp.ts#L80)
"""
return (
self.client_id if getattr(self, "client_id", None) is not None else str(uuid.uuid4()).replace("-", "")[-16:]
)
return str(uuid.uuid4()).replace("-", "")[-16:]


class EncryptedData(proto.Message):
Expand Down
29 changes: 29 additions & 0 deletions actual/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Categories,
CategoryGroups,
CategoryMapping,
MessagesClock,
PayeeMapping,
Payees,
Rules,
Expand All @@ -27,6 +28,7 @@
ZeroBudgets,
)
from actual.exceptions import ActualError
from actual.protobuf_models import HULC_Client
from actual.rules import Action, Condition, Rule, RuleSet
from actual.utils.title import title

Expand Down Expand Up @@ -738,3 +740,30 @@ def get_schedules(s: Session, name: str = None, include_deleted: bool = False) -
"""
query = base_query(Schedules, name, include_deleted)
return s.exec(query).all()


def get_or_create_clock(s: Session, client: HULC_Client = None) -> MessagesClock:
"""Loads the HULC Clock from the database. This clock tells the server from when the messages should be
retrieved. See the [original implementation.](
https://github.com/actualbudget/actual/blob/5bcfc71be67c6e7b7c8b444e4c4f60da9ea9fdaa/packages/loot-core/src/server/db/index.ts#L81-L98)

If the clock is not existing, it will be created based on the passed client. If the client is missing, an empty
client is created. If the clock was already existing, the timestamp will only be overwritten if a client is
provided, otherwise the original value will be returned.

:param s: session from Actual local database.
:param client: HULC Client object.
:return: The message clock object.
"""
clock = s.exec(select(MessagesClock)).one_or_none()
if not clock:
clock = MessagesClock(id=1)
if not client:
client = HULC_Client() # create a default client
clock.set_clock({"timestamp": str(client), "merkle": {}})
s.add(clock)
else:
# update the clock only if the client was provided
if client:
clock.set_timestamp(client)
return clock
8 changes: 8 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import decimal
import json
from datetime import date, timedelta
Expand All @@ -16,6 +17,7 @@
get_accounts,
get_budgets,
get_or_create_category,
get_or_create_clock,
get_or_create_payee,
get_ruleset,
get_transactions,
Expand Down Expand Up @@ -292,3 +294,9 @@ def test_apply_changes(session, mocker):
assert transactions[0].notes == transaction.notes
assert transactions[0].get_date() == transaction.get_date()
assert transactions[0].get_amount() == transaction.get_amount()


def test_get_or_create_clock(session):
clock = get_or_create_clock(session)
assert clock.get_timestamp().ts == datetime.datetime(1970, 1, 1, 0, 0, 0)
assert clock.get_timestamp().initial_count == 0
Loading
Loading