From 8acd15ed7f685a4412a7a5bd8da3037fbfd562b7 Mon Sep 17 00:00:00 2001 From: Harald Mack <39521902+MaHaWo@users.noreply.github.com> Date: Mon, 13 Jan 2025 15:10:51 +0100 Subject: [PATCH] Add `included_in_statistics` flag to answersession (#215) * add flags to answer * add default args and work on tests * include datetime fix for tests * fix error in statistics database query * adjust existing tests * add test for setting the included_in_milestonestatistics flag * add test for usage of used answers in group statistics * update openapi.json & openapi-ts client * change point where answers are updated * incorporate suggestions, fix test * update openapi.json & openapi-ts client --------- Co-authored-by: github-actions[bot] --- .../src/mondey_backend/models/milestones.py | 2 + .../src/mondey_backend/routers/statistics.py | 37 +++++---- mondey_backend/tests/conftest.py | 19 ++++- mondey_backend/tests/routers/test_users.py | 49 ++++++++++-- mondey_backend/tests/utils/test_statistics.py | 77 +++++++++++++++++++ 5 files changed, 157 insertions(+), 27 deletions(-) diff --git a/mondey_backend/src/mondey_backend/models/milestones.py b/mondey_backend/src/mondey_backend/models/milestones.py index c72e6377..a08c72b1 100644 --- a/mondey_backend/src/mondey_backend/models/milestones.py +++ b/mondey_backend/src/mondey_backend/models/milestones.py @@ -158,6 +158,8 @@ class MilestoneAnswer(SQLModel, table=True): ) milestone_group_id: int = Field(default=None, foreign_key="milestonegroup.id") answer: int + included_in_milestone_statistics: bool = False + included_in_milestonegroup_statistics: bool = False class MilestoneAnswerSession(SQLModel, table=True): diff --git a/mondey_backend/src/mondey_backend/routers/statistics.py b/mondey_backend/src/mondey_backend/routers/statistics.py index b5a5501d..1be97ffa 100644 --- a/mondey_backend/src/mondey_backend/routers/statistics.py +++ b/mondey_backend/src/mondey_backend/routers/statistics.py @@ -4,7 +4,6 @@ from collections.abc import Sequence 
import numpy as np -from sqlalchemy import and_ from sqlmodel import col from sqlmodel import select @@ -195,8 +194,8 @@ def calculate_milestone_statistics_by_age( MilestoneAgeScoreCollection object which contains a list of MilestoneAgeScore objects, one for each month, or None if there are no answers for the milestoneg and no previous statistics. """ - # TODO: when the answersession eventually has an expired flag, this can go again. session_expired_days: int = 7 + # TODO: when the answersession eventually has an expired flag, this can go again. # get the newest statistics for the milestone last_statistics = session.get(MilestoneAgeScoreCollection, milestone_id) @@ -221,6 +220,7 @@ def calculate_milestone_statistics_by_age( col(MilestoneAnswer.answer_session_id) == MilestoneAnswerSession.id, ) .where(MilestoneAnswer.milestone_id == milestone_id) + .where(~col(MilestoneAnswer.included_in_milestone_statistics)) .where(MilestoneAnswerSession.created_at < expiration_date) ) else: @@ -239,12 +239,8 @@ def calculate_milestone_statistics_by_age( col(MilestoneAnswer.answer_session_id) == MilestoneAnswerSession.id, ) .where(MilestoneAnswer.milestone_id == milestone_id) - .where( - and_( - col(MilestoneAnswerSession.created_at) > last_statistics.created_at, - col(MilestoneAnswerSession.created_at) <= expiration_date, - ) # expired session only which are not in the last statistics - ) + .where(~col(MilestoneAnswer.included_in_milestone_statistics)) + .where(col(MilestoneAnswerSession.created_at) <= expiration_date) ) answers = session.exec(answers_query).all() @@ -259,6 +255,11 @@ def calculate_milestone_statistics_by_age( expected_age = _get_expected_age_from_scores(avg_scores) + for answer in answers: + answer.included_in_milestone_statistics = True + session.merge(answer) + session.commit() + # overwrite last_statistics with updated stuff --> set primary keys explicitly return MilestoneAgeScoreCollection( milestone_id=milestone_id, @@ -302,7 +303,6 @@ def 
calculate_milestonegroup_statistics_by_age( one for each month, or None if there are no answers for the milestonegroup and no previous statistics. """ - # TODO: when the answersession eventually has an 'expired' flag, this can go again. session_expired_days: int = 7 # get the newest statistics for the milestonegroup @@ -326,9 +326,10 @@ def calculate_milestonegroup_statistics_by_age( col(MilestoneAnswer.answer_session_id) == MilestoneAnswerSession.id, ) .where(MilestoneAnswer.milestone_group_id == milestonegroup_id) + .where(~col(MilestoneAnswer.included_in_milestonegroup_statistics)) .where( MilestoneAnswerSession.created_at - < expiration_date # expired session only + <= expiration_date # expired session only ) ) else: @@ -349,15 +350,11 @@ def calculate_milestonegroup_statistics_by_age( select(MilestoneAnswer) .join( MilestoneAnswerSession, - MilestoneAnswer.answer_session_id == MilestoneAnswerSession.id, # type: ignore + col(MilestoneAnswer.answer_session_id) == MilestoneAnswerSession.id, ) .where(MilestoneAnswer.milestone_group_id == milestonegroup_id) - .where( - and_( - MilestoneAnswerSession.created_at > last_statistics.created_at, # type: ignore - MilestoneAnswerSession.created_at <= expiration_date, # type: ignore - ) - ) # expired session only which are not in the last statistics + .where(~col(MilestoneAnswer.included_in_milestonegroup_statistics)) + .where(MilestoneAnswerSession.created_at <= expiration_date) ) answers = session.exec(answer_query).all() @@ -371,6 +368,12 @@ def calculate_milestonegroup_statistics_by_age( answers, child_ages, count=count, avg=avg_scores, stddev=stddev_scores ) + # update answer.included_in_milestonegroup_statistics to True + for answer in answers: + answer.included_in_milestonegroup_statistics = True + session.merge(answer) + session.commit() + return MilestoneGroupAgeScoreCollection( milestone_group_id=milestonegroup_id, scores=[ diff --git a/mondey_backend/tests/conftest.py b/mondey_backend/tests/conftest.py index 
36101b1b..f9cc3a83 100644 --- a/mondey_backend/tests/conftest.py +++ b/mondey_backend/tests/conftest.py @@ -249,12 +249,22 @@ def session(children: list[dict], monkeypatch: pytest.MonkeyPatch): ) session.add( MilestoneAnswer( - answer_session_id=1, milestone_id=1, milestone_group_id=1, answer=1 + answer_session_id=1, + milestone_id=1, + milestone_group_id=1, + answer=1, + included_in_milestone_statistics=True, + included_in_milestonegroup_statistics=True, ) ) session.add( MilestoneAnswer( - answer_session_id=1, milestone_id=2, milestone_group_id=1, answer=0 + answer_session_id=1, + milestone_id=2, + milestone_group_id=1, + answer=0, + included_in_milestone_statistics=True, + included_in_milestonegroup_statistics=True, ) ) # add another (current) milestone answer session for child 1 / user (id 3) with 2 answers to the same questions @@ -280,7 +290,10 @@ def session(children: list[dict], monkeypatch: pytest.MonkeyPatch): ) session.add( MilestoneAnswer( - answer_session_id=3, milestone_id=7, milestone_group_id=2, answer=2 + answer_session_id=3, + milestone_id=7, + milestone_group_id=2, + answer=2, ) ) # add a research group (that user with id 3 is part of, and researcher with id 2 has access to) diff --git a/mondey_backend/tests/routers/test_users.py b/mondey_backend/tests/routers/test_users.py index 7f9827fe..8011e88c 100644 --- a/mondey_backend/tests/routers/test_users.py +++ b/mondey_backend/tests/routers/test_users.py @@ -2,6 +2,9 @@ import pathlib from fastapi.testclient import TestClient +from sqlmodel import select + +from mondey_backend.models.milestones import MilestoneAnswer def _is_approx_now(iso_date_string: str, delta=datetime.timedelta(hours=1)) -> bool: @@ -181,8 +184,14 @@ def test_get_milestone_answers_child1_current_answer_session(user_client: TestCl assert response.json()["id"] == 2 assert response.json()["child_id"] == 1 assert response.json()["answers"] == { - "1": {"milestone_id": 1, "answer": 3}, - "2": {"milestone_id": 2, "answer": 2}, + "1": { 
+ "milestone_id": 1, + "answer": 3, + }, + "2": { + "milestone_id": 2, + "answer": 2, + }, } assert _is_approx_now(response.json()["created_at"]) @@ -195,7 +204,10 @@ def test_update_milestone_answer_no_current_answer_session( # child 2 is 20 months old, so milestones 4 assert current_answer_session["answers"]["4"]["answer"] == -1 - new_answer = {"milestone_id": 4, "answer": 2} + new_answer = { + "milestone_id": 4, + "answer": 2, + } response = user_client.put( f"/users/milestone-answers/{current_answer_session['id']}", json=new_answer ) @@ -207,8 +219,14 @@ def test_update_milestone_answer_no_current_answer_session( def test_update_milestone_answer_update_existing_answer(user_client: TestClient): current_answer_session = user_client.get("/users/milestone-answers/1").json() - assert current_answer_session["answers"]["1"] == {"milestone_id": 1, "answer": 3} - new_answer = {"milestone_id": 1, "answer": 2} + assert current_answer_session["answers"]["1"] == { + "milestone_id": 1, + "answer": 3, + } + new_answer = { + "milestone_id": 1, + "answer": 2, + } response = user_client.put( f"/users/milestone-answers/{current_answer_session['id']}", json=new_answer ) @@ -356,7 +374,16 @@ def test_update_current_child_answers_no_prexisting( assert response.status_code == 404 -def test_get_summary_feedback_for_session(user_client: TestClient): +def test_get_summary_feedback_for_session(user_client: TestClient, session): + answers = session.exec( + select(MilestoneAnswer).where(MilestoneAnswer.answer_session_id == 1) + ).all() + for answer in answers: + answer.included_in_milestone_statistics = False + answer.included_in_milestonegroup_statistics = False + session.merge(answer) + session.commit() + response = user_client.get("/users/feedback/answersession=1/summary") assert response.status_code == 200 assert response.json() == {"1": 1} @@ -367,7 +394,15 @@ def test_get_summary_feedback_for_session_invalid(user_client: TestClient): assert response.status_code == 404 -def 
test_get_detailed_feedback_for_session(user_client: TestClient): +def test_get_detailed_feedback_for_session(user_client: TestClient, session): + answers = session.exec( + select(MilestoneAnswer).where(MilestoneAnswer.answer_session_id == 1) + ).all() + for answer in answers: + answer.included_in_milestone_statistics = False + answer.included_in_milestonegroup_statistics = False + session.merge(answer) + session.commit() response = user_client.get("/users/feedback/answersession=1/detailed") assert response.status_code == 200 assert response.json() == {"1": {"1": 1, "2": 1}} diff --git a/mondey_backend/tests/utils/test_statistics.py b/mondey_backend/tests/utils/test_statistics.py index a0ca079e..890dafdf 100644 --- a/mondey_backend/tests/utils/test_statistics.py +++ b/mondey_backend/tests/utils/test_statistics.py @@ -2,10 +2,14 @@ import numpy as np import pytest +from sqlmodel import col from sqlmodel import select +from mondey_backend.models.milestones import MilestoneAgeScoreCollection from mondey_backend.models.milestones import MilestoneAnswer +from mondey_backend.models.milestones import MilestoneAnswerSession from mondey_backend.models.milestones import MilestoneGroup +from mondey_backend.models.milestones import MilestoneGroupAgeScoreCollection from mondey_backend.routers.statistics import _add_sample from mondey_backend.routers.statistics import _finalize_statistics from mondey_backend.routers.statistics import _get_statistics_by_age @@ -200,6 +204,23 @@ def test_get_score_statistics_by_age_no_data(statistics_session): def test_calculate_milestone_statistics_by_age(statistics_session): + expiration_date = datetime.datetime.now() - datetime.timedelta(days=7) + answers_query = ( + select(MilestoneAnswer) + .join( + MilestoneAnswerSession, + col(MilestoneAnswer.answer_session_id) == MilestoneAnswerSession.id, + ) + .where(MilestoneAnswer.milestone_id == 1) + .where(~col(MilestoneAnswer.included_in_milestone_statistics)) + 
.where(col(MilestoneAnswerSession.created_at) <= expiration_date) + ) + + # originally, the relevant answers have not been integrated into the statistics yet + all_answers = statistics_session.exec(answers_query).all() + for answer in all_answers: + assert answer.included_in_milestone_statistics is False + # calculate_milestone_statistics_by_age mscore = calculate_milestone_statistics_by_age(statistics_session, 1) @@ -224,8 +245,46 @@ def test_calculate_milestone_statistics_by_age(statistics_session): else: assert mscore.scores[age].expected_score == 4 + # all answers for milestone 1 are now flagged as included in the milestone statistics + # if they come from expired milestone answer sessions + + all_answers = statistics_session.exec(answers_query).all() + for answer in all_answers: + assert answer.included_in_milestone_statistics is True + + # the new result is not written into the database, so in order to check + # that data is not taken into account twice, we need to check against the + # old result, not the new one.
+ old = statistics_session.get(MilestoneAgeScoreCollection, 1) + + mscore2 = calculate_milestone_statistics_by_age(statistics_session, 1) + for s1, s2 in zip(mscore2.scores, old.scores, strict=True): + assert s1.age == s2.age + assert s1.count == s2.count + assert np.isclose(s1.avg_score, s2.avg_score) + assert np.isclose(s1.stddev_score, s2.stddev_score) + assert np.isclose(s1.expected_score, s2.expected_score) + def test_calculate_milestonegroup_statistics(statistics_session): + expiration_date = datetime.datetime.now() - datetime.timedelta(days=7) + + answer_query = ( + select(MilestoneAnswer) + .join( + MilestoneAnswerSession, + col(MilestoneAnswer.answer_session_id) == MilestoneAnswerSession.id, + ) + .where(MilestoneAnswer.milestone_group_id == 1) + .where(~col(MilestoneAnswer.included_in_milestonegroup_statistics)) + .where(MilestoneAnswerSession.created_at <= expiration_date) + ) + + all_answers = statistics_session.exec(answer_query).all() + for answer in all_answers: + print(answer) + assert answer.included_in_milestonegroup_statistics is False + milestone_group = statistics_session.exec( select(MilestoneGroup).where(MilestoneGroup.id == 1) ).first() @@ -273,3 +332,21 @@ def test_calculate_milestonegroup_statistics(statistics_session): assert score.scores[age].count == 0 if age > 12: assert np.isclose(score.scores[age].avg_score, 3.0) + + # check that calling the statistics anew with already integrated answers doesn't change anything.
+ # we need to check against the old result, not the new one because this is not written into the database + all_answers = statistics_session.exec(answer_query).all() + for answer in all_answers: + assert answer.included_in_milestonegroup_statistics is True + + old_stats = statistics_session.get(MilestoneGroupAgeScoreCollection, 1) + new_stats = calculate_milestonegroup_statistics_by_age( + statistics_session, + milestone_group.id, + ) + for new_score, old_score in zip(new_stats.scores, old_stats.scores, strict=True): + assert new_score.age == old_score.age + assert new_score.count == old_score.count + assert np.isclose(new_score.avg_score, old_score.avg_score) + assert np.isclose(new_score.stddev_score, old_score.stddev_score) + assert new_score.milestone_group_id == old_score.milestone_group_id