Skip to content

Commit

Permalink
Merge pull request #25 from IFRCGo/feature/filters
Browse files Browse the repository at this point in the history
Feature/filters
  • Loading branch information
thenav56 authored Apr 16, 2024
2 parents 2ff1195 + 3c656bf commit 7d7510f
Show file tree
Hide file tree
Showing 19 changed files with 560 additions and 162 deletions.
5 changes: 4 additions & 1 deletion .github/dependabot.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ updates:
- package-ecosystem: "pip" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
interval: "weekly"
ignore:
- dependency-name: "*"
update-types: ["version-update:semver-major"]
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ exclude: |
repos:
- repo: https://github.com/psf/black
rev: 24.2.0
rev: 24.3.0
hooks:
- id: black
# args: ["--check"]
Expand All @@ -31,6 +31,6 @@ repos:
- id: flake8

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.354
rev: v1.1.355
hooks:
- id: pyright
47 changes: 47 additions & 0 deletions apps/cap_feed/data_injector/feed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import json
import os

from apps.cap_feed.models import Country, Feed, LanguageInfo

module_dir = os.path.dirname(__file__) # get current directory


# inject feed configurations if not already present
def inject_feeds():
file_path = os.path.join(
os.path.dirname(module_dir),
'feeds.json',
)
with open(file_path, encoding='utf-8') as file:
feeds = json.load(file)
print('Injecting feeds...')
unique_countries = set()
feed_counter = 0
for feed_entry in feeds:
try:
feed = Feed()
feed.url = feed_entry['capAlertFeed']
feed.country = Country.objects.get(iso3=feed_entry['iso3'])
feed_counter += 1
unique_countries.add(feed_entry['iso3'])
if Feed.objects.filter(url=feed.url).first():
continue
feed.format = feed_entry['format']
feed.polling_interval = 60
feed.enable_polling = True
feed.enable_rebroadcast = True
feed.official = True
feed.save()

language_info = LanguageInfo()
language_info.feed = feed
language_info.name = feed_entry['name']
language_info.language = feed_entry['language']
language_info.logo = feed_entry['picUrl']
language_info.save()

except Exception as e:
print(feed_entry['name'])
print(f'Error injecting feed: {e}')

print(f'Injected {feed_counter} feeds for {len(unique_countries)} unique countries')
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import requests

from .models import Admin1, Continent, Country, Feed, LanguageInfo, Region
from apps.cap_feed.models import Admin1, Continent, Country, Region

module_dir = os.path.dirname(__file__) # get current directory

Expand Down Expand Up @@ -43,7 +43,7 @@ def process_continents():
continent_data = json.loads(response.content)
process_continents()
else:
file_path = os.path.join(module_dir, 'geographical/continents.json')
file_path = os.path.join(os.path.dirname(module_dir), 'geographical/continents.json')
with open(file_path) as file:
continent_data = json.load(file)
process_continents()
Expand All @@ -68,7 +68,7 @@ def process_regions():
region_data = json.loads(response.content)
process_regions()
else:
file_path = os.path.join(module_dir, 'geographical/ifrc-regions.json')
file_path = os.path.join(os.path.dirname(module_dir), 'geographical/ifrc-regions.json')
with open(file_path) as file:
region_data = json.load(file)
process_regions()
Expand Down Expand Up @@ -125,7 +125,7 @@ def process_countries_opendatasoft():
region_data = json.loads(response.content)
process_regions()
else:
file_path = os.path.join(module_dir, 'geographical/ifrc-regions.json')
file_path = os.path.join(os.path.dirname(module_dir), 'geographical/ifrc-regions.json')
with open(file_path) as file:
region_data = json.load(file)
process_regions()
Expand All @@ -137,7 +137,7 @@ def process_countries_opendatasoft():
country_data = json.loads(response.content)
process_countries_ifrc()
else:
file_path = os.path.join(module_dir, 'geographical/ifrc-countries-and-territories.json')
file_path = os.path.join(os.path.dirname(module_dir), 'geographical/ifrc-countries-and-territories.json')
with open(file_path) as file:
country_data = json.load(file)
process_countries_ifrc()
Expand All @@ -149,7 +149,7 @@ def process_countries_opendatasoft():
country_data = json.loads(response.content)
process_countries_opendatasoft()
else:
file_path = os.path.join(module_dir, 'geographical/opendatasoft-countries-and-territories.geojson')
file_path = os.path.join(os.path.dirname(module_dir), 'geographical/opendatasoft-countries-and-territories.geojson')
with open(file_path) as file:
country_data = json.load(file)
process_countries_opendatasoft()
Expand Down Expand Up @@ -184,45 +184,7 @@ def process_admin1s():
admin1_data = json.loads(response.content)
process_admin1s()
else:
file_path = os.path.join(module_dir, 'geographical/geoBoundariesCGAZ_ADM1.geojson')
file_path = os.path.join(os.path.dirname(module_dir), 'geographical/geoBoundariesCGAZ_ADM1.geojson')
with open(file_path, encoding='utf-8') as f:
admin1_data = json.load(f)
process_admin1s()


# inject feed configurations if not already present
def inject_feeds():
file_path = os.path.join(module_dir, 'feeds.json')
with open(file_path, encoding='utf-8') as file:
feeds = json.load(file)
print('Injecting feeds...')
unique_countries = set()
feed_counter = 0
for feed_entry in feeds:
try:
feed = Feed()
feed.url = feed_entry['capAlertFeed']
feed.country = Country.objects.get(iso3=feed_entry['iso3'])
feed_counter += 1
unique_countries.add(feed_entry['iso3'])
if Feed.objects.filter(url=feed.url).first():
continue
feed.format = feed_entry['format']
feed.polling_interval = 60
feed.enable_polling = True
feed.enable_rebroadcast = True
feed.official = True
feed.save()

language_info = LanguageInfo()
language_info.feed = feed
language_info.name = feed_entry['name']
language_info.language = feed_entry['language']
language_info.logo = feed_entry['picUrl']
language_info.save()

except Exception as e:
print(feed_entry['name'])
print(f'Error injecting feed: {e}')

print(f'Injected {feed_counter} feeds for {len(unique_countries)} unique countries')
64 changes: 64 additions & 0 deletions apps/cap_feed/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .models import (
Admin1,
Alert,
AlertAdmin1,
AlertInfo,
AlertInfoArea,
Expand Down Expand Up @@ -89,6 +90,22 @@ def load_admin1s_by_country(keys: list[int]) -> list[list['Admin1Type']]:
return [_map[key] for key in keys]


def load_info_by_alert(keys: list[int]) -> list[typing.Union['AlertInfoType', None]]:
qs = (
AlertInfo.objects.filter(alert__in=keys)
# TODO: Is this order good enough?
.order_by('alert_id', 'id')
.distinct('alert_id')
.all()
)

_map: dict[int, 'AlertInfoType'] = { # type: ignore[reportGeneralTypeIssues]
alert_info.alert_id: alert_info for alert_info in qs
}

return [_map.get(key) for key in keys]


def load_infos_by_alert(keys: list[int]) -> list[list['AlertInfoType']]:
qs = AlertInfo.objects.filter(alert__in=keys).all()

Expand Down Expand Up @@ -159,6 +176,41 @@ def load_language_info_by_feed(keys: list[int]) -> list[list['LanguageInfoType']
return [_map[key] for key in keys]


def load_alert_count_by_country(keys: list[int]) -> list[int]:
qs = (
Alert.get_queryset()
.filter(country__in=keys)
.order_by()
.values('country_id')
.annotate(
count=models.Count('id'),
)
.values_list('country_id', 'count')
)

_map = {country_id: count for country_id, count in qs}

return [_map.get(key, 0) for key in keys]


def load_alert_count_by_admin1(keys: list[int]) -> list[int]:
qs = (
Alert.objects
# TODO: Add is_expired=False filter
.filter(admin1s__in=keys)
.order_by()
.values('admin1s')
.annotate(
count=models.Count('id'),
)
.values_list('admin1s', 'count')
)

_map = {admin1_id: count for admin1_id, count in qs}

return [_map.get(key, 0) for key in keys]


class CapFeedDataloader:

@cached_property
Expand All @@ -185,6 +237,10 @@ def load_admin1s_by_alert(self):
def load_admin1s_by_country(self):
return DataLoader(load_fn=sync_to_async(load_admin1s_by_country))

@cached_property
def load_info_by_alert(self):
return DataLoader(load_fn=sync_to_async(load_info_by_alert))

@cached_property
def load_infos_by_alert(self):
return DataLoader(load_fn=sync_to_async(load_infos_by_alert))
Expand Down Expand Up @@ -212,3 +268,11 @@ def load_info_area_geocodes_by_info_area(self):
@cached_property
def load_language_info_by_feed(self):
return DataLoader(load_fn=sync_to_async(load_language_info_by_feed))

@cached_property
def load_alert_count_by_country(self):
return DataLoader(load_fn=sync_to_async(load_alert_count_by_country))

@cached_property
def load_alert_count_by_admin1(self):
return DataLoader(load_fn=sync_to_async(load_alert_count_by_admin1))
88 changes: 85 additions & 3 deletions apps/cap_feed/filters.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,86 @@
import strawberry
import strawberry_django
from django.contrib.postgres.aggregates.general import ArrayAgg
from django.db import models

from .enums import (
AlertInfoCategoryEnum,
AlertInfoCertaintyEnum,
AlertInfoSeverityEnum,
AlertInfoUrgencyEnum,
)
from .models import Admin1, Alert, AlertInfo, Country, Feed, Region


@strawberry_django.filters.filter(Alert, lookups=True)
class AlertFilter:
id: strawberry.auto
url: strawberry.auto
sender: strawberry.auto
admin1s: strawberry.auto
country: strawberry.auto
sent: strawberry.auto

@strawberry_django.filter_field
def region(
self,
queryset: models.QuerySet,
value: strawberry.ID,
prefix: str,
) -> tuple[models.QuerySet, models.Q]:
return queryset, models.Q(**{f"{prefix}country__region": value})

@strawberry_django.filter_field
def admin1(
self,
queryset: models.QuerySet,
value: strawberry.ID,
prefix: str,
) -> tuple[models.QuerySet, models.Q]:
return queryset, models.Q(**{f"{prefix}admin1s": value})

def _info_enum_fields(self, field, queryset, value, prefix) -> tuple[models.QuerySet, models.Q]:
alias_field = f"_infos_{field}_list"
queryset = queryset.alias(
**{
# NOTE: To avoid duplicate alerts when joining infos
alias_field: ArrayAgg(f"{prefix}infos__{field}"),
}
)
return queryset, models.Q(**{f"{prefix}{alias_field}__overlap": value})

@strawberry_django.filter_field
def urgency(
self,
queryset: models.QuerySet,
value: list[AlertInfoUrgencyEnum], # type: ignore[reportInvalidTypeForm]
prefix: str,
) -> tuple[models.QuerySet, models.Q]:
return self._info_enum_fields("urgency", queryset, value, prefix)

@strawberry_django.filter_field
def severity(
self,
queryset: models.QuerySet,
value: list[AlertInfoSeverityEnum], # type: ignore[reportInvalidTypeForm]
prefix: str,
) -> tuple[models.QuerySet, models.Q]:
return self._info_enum_fields("severity", queryset, value, prefix)

@strawberry_django.filter_field
def certainty(
self,
queryset: models.QuerySet,
value: list[AlertInfoCertaintyEnum], # type: ignore[reportInvalidTypeForm]
prefix: str,
) -> tuple[models.QuerySet, models.Q]:
return self._info_enum_fields("certainty", queryset, value, prefix)

@strawberry_django.filter_field
def category(
self,
queryset: models.QuerySet,
value: list[AlertInfoCategoryEnum], # type: ignore[reportInvalidTypeForm]
prefix: str,
) -> tuple[models.QuerySet, models.Q]:
return self._info_enum_fields("category", queryset, value, prefix)


@strawberry_django.filters.filter(AlertInfo, lookups=True)
Expand All @@ -31,6 +102,17 @@ class CountryFilter:
class Admin1Filter:
id: strawberry.auto

@strawberry_django.filter_field
def unknown(
self,
queryset: models.QuerySet,
value: bool,
prefix: str,
) -> tuple[models.QuerySet, models.Q]:
if value:
return queryset, models.Q(**{f"{prefix}id__lt": 0})
return queryset, models.Q(**{f"{prefix}id__gte": 0})


@strawberry_django.filters.filter(Region, lookups=True)
class RegionFilter:
Expand Down
5 changes: 5 additions & 0 deletions apps/cap_feed/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ def __init__(self, *args, **kwargs):
def __str__(self):
return self.url

@classmethod
def get_queryset(cls) -> models.QuerySet:
# TODO: Add is_expired=False filter
return cls.objects.all()

def info_has_been_added(self):
self.__all_info_added = True

Expand Down
Loading

0 comments on commit 7d7510f

Please sign in to comment.