Skip to content

Commit

Permalink
allow passing index to create_variant
Browse files Browse the repository at this point in the history
  • Loading branch information
jthurner committed Apr 10, 2024
1 parent 3ba43db commit cfd1453
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions pandahub/lib/PandaHub.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,10 @@ def _get_project_document(self, filter_dict: dict) -> Optional[dict]:
else:
return project_doc

def _get_project_database(self) -> pymongo.mongo_client:
def _get_project_database(self) -> MongoClient:
return self.mongo_client[str(self.active_project["_id"])]

def _get_global_database(self):
def _get_global_database(self) -> MongoClient:
if (
self.mongo_client_global_db is None
and SETTINGS.MONGODB_GLOBAL_DATABASE_URL is not None
Expand Down Expand Up @@ -1652,23 +1652,23 @@ def _create_mongodb_indexes(
# Variants
# -------------------------

def create_variant(self, data):
def create_variant(self, data, index: Optional[int] = None):
db = self._get_project_database()
net_id = int(data["net_id"])
max_index = list(
db["variant"]
.find({"net_id": net_id}, projection={"_id": 0, "index": 1})
.sort("index", -1)
.limit(1)
)
if not max_index:
index = 1
for coll in self._get_net_collections(db):
update = {"$set": {"var_type": "base", "not_in_var": []}}
db[coll].update_many({}, update)

else:
index = int(max_index[0]["index"]) + 1
if index is None:
max_index = list(
db["variant"]
.find({"net_id": net_id}, projection={"_id": 0, "index": 1})
.sort("index", -1)
.limit(1)
)
if not max_index:
index = 1
for coll in self._get_net_collections(db):
update = {"$set": {"var_type": "base", "not_in_var": []}}
db[coll].update_many({}, update)
else:
index = int(max_index[0]["index"]) + 1

data["index"] = index

Expand Down Expand Up @@ -1704,7 +1704,7 @@ def update_variant(self, net_id, index, data):
db = self._get_project_database()
db["variant"].update_one({"net_id": net_id, "index": index}, {"$set": data})

def get_variant_filter(self, variants):
def get_variant_filter(self, variants: Optional[Union[list[int], int]]) -> dict:
"""
Creates a mongodb query filter to retrieve pandapower elements for the given variant(s).
Expand All @@ -1718,11 +1718,9 @@ def get_variant_filter(self, variants):
dict
mongodb query filter for the given variant(s)
"""
if type(variants) is list and variants:
if isinstance(variants, list):
if len(variants) > 1:
variants = [
int(var) for var in variants
] # make sure variants are of type int
variants = [int(var) for var in variants] # make sure variants are of type int
return {
"$or": [
{"var_type": "base", "not_in_var": {"$nin": variants}},
Expand Down

0 comments on commit cfd1453

Please sign in to comment.