Skip to content

Commit

Permalink
Merge pull request #103 from singnet/improvement/issue-102-update-get…
Browse files Browse the repository at this point in the history
…-matched-methods

[#102] Update get_matched_links(), get_matched_type_template() and get_matched_type()
  • Loading branch information
marcocapozzoli authored Feb 26, 2024
2 parents 16ba831 + 6063122 commit 6989383
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 67 deletions.
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
[#96] Adjust retrieve methods for sets in Redis - incoming_set, templates, patterns
[#98] Refactor get_incoming_links() in all AtomDB adapters
[#100] Add support to MeTTa loader nodes/links mapping
[#102] Update get_matched_links(), get_matched_type_template() and get_matched_type()
41 changes: 15 additions & 26 deletions hyperon_das_atomdb/adapters/ram_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,12 @@ def _delete_link_and_update_index(self, link_handle: str) -> None:

def _filter_non_toplevel(self, matches: list) -> list:
matches_toplevel_only = []
for match in matches:
link_handle = match[0]
links = self.db.link.get_table(len(match[-1]))
if links[link_handle]['is_toplevel']:
matches_toplevel_only.append(match)
if len(matches) > 0:
for match in matches:
link_handle = match[0]
links = self.db.link.get_table(len(match[-1]))
if links[link_handle]['is_toplevel']:
matches_toplevel_only.append(match)
return matches_toplevel_only

def _build_targets_list(self, link: Dict[str, Any]):
Expand Down Expand Up @@ -341,12 +342,7 @@ def is_ordered(self, link_handle: str) -> bool:
details=f'link_handle: {link_handle}',
)

def get_matched_links(
self,
link_type: str,
target_handles: List[str],
**kwargs: Optional[Dict[str, Any]],
) -> list:
def get_matched_links(self, link_type: str, target_handles: List[str], **kwargs) -> list:
if link_type != WILDCARD and WILDCARD not in target_handles:
link_handle = self.get_link_handle(link_type, target_handles)
return [link_handle]
Expand All @@ -370,9 +366,8 @@ def get_matched_links(

patterns_matched = self.db.patterns.get(pattern_hash, [])

if len(patterns_matched) > 0:
if kwargs.get('toplevel_only'):
return self._filter_non_toplevel(patterns_matched)
if kwargs.get('toplevel_only'):
return self._filter_non_toplevel(patterns_matched)

return patterns_matched

Expand All @@ -383,25 +378,19 @@ def get_incoming_links(self, atom_handle: str, **kwargs) -> List[IncomingLinksT]
else:
return [self.get_atom(handle, **kwargs) for handle in links]

def get_matched_type_template(
self,
template: List[Any],
**kwargs: Optional[Dict[str, Any]],
) -> List[str]:
def get_matched_type_template(self, template: List[Any], **kwargs) -> list:
template = self._build_named_type_hash_template(template)
template_hash = ExpressionHasher.composite_hash(template)
templates_matched = self.db.templates.get(template_hash, [])
if len(templates_matched) > 0:
if kwargs.get('toplevel_only'):
return self._filter_non_toplevel(templates_matched)
if kwargs.get('toplevel_only'):
return self._filter_non_toplevel(templates_matched)
return templates_matched

def get_matched_type(self, link_type: str, **kwargs: Optional[Dict[str, Any]]) -> List[str]:
def get_matched_type(self, link_type: str, **kwargs) -> list:
link_type_hash = ExpressionHasher.named_type_hash(link_type)
templates_matched = self.db.templates.get(link_type_hash, [])
if len(templates_matched) > 0:
if kwargs.get('toplevel_only'):
return self._filter_non_toplevel(templates_matched)
if kwargs.get('toplevel_only'):
return self._filter_non_toplevel(templates_matched)
return templates_matched

def get_atom(
Expand Down
81 changes: 43 additions & 38 deletions hyperon_das_atomdb/adapters/redis_mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,15 @@ def _get_mongo_document_keys(self, document: Dict) -> List[str]:
index += 1

def _filter_non_toplevel(self, matches: list) -> list:
if isinstance(matches[0], list):
matches = matches[0]
matches_toplevel_only = []
for match in matches:
link_handle = match[0]
link = self._retrieve_mongo_document(link_handle, len(match[-1]))
if link['is_toplevel']:
matches_toplevel_only.append(match)
if len(matches) > 0:
if isinstance(matches[0], list):
matches = matches[0]
for match in matches:
link_handle = match[0]
link = self._retrieve_mongo_document(link_handle, len(match[-1]))
if link['is_toplevel']:
matches_toplevel_only.append(match)
return matches_toplevel_only

def get_node_handle(self, node_type: str, node_name: str) -> str:
Expand Down Expand Up @@ -440,17 +441,17 @@ def is_ordered(self, link_handle: str) -> bool:
return True

def get_matched_links(
self,
link_type: str,
target_handles: List[str],
**kwargs: Optional[Dict[str, Any]],
):
self, link_type: str, target_handles: List[str], **kwargs
) -> Union[tuple, list]:
if link_type != WILDCARD and WILDCARD not in target_handles:
try:
link_handle = self.get_link_handle(link_type, target_handles)
document = self._retrieve_mongo_document(link_handle, len(target_handles))
return [link_handle] if document else []
except ValueError:
if kwargs.get('cursor') is not None:
return None, [link_handle]
return [link_handle]
except LinkDoesNotExist:
if kwargs.get('cursor') is not None:
return None, []
return []

if link_type == WILDCARD:
Expand All @@ -459,18 +460,17 @@ def get_matched_links(
link_type_hash = self._get_atom_type_hash(link_type)

if link_type_hash is None:
if kwargs.get('cursor') is not None:
return None, []
return []

if link_type in UNORDERED_LINK_TYPES:
target_handles = sorted(target_handles)

pattern_hash = ExpressionHasher.composite_hash([link_type_hash, *target_handles])
_, patterns_matched = self._retrieve_pattern(pattern_hash, **kwargs)
if len(patterns_matched) > 0:
if kwargs.get("toplevel_only"):
return self._filter_non_toplevel(patterns_matched)

return patterns_matched
cursor, patterns_matched = self._retrieve_pattern(pattern_hash, **kwargs)
toplevel_only = kwargs.get('toplevel_only', False)
return self._process_matched_results(patterns_matched, cursor, toplevel_only)

def get_incoming_links(
self, atom_handle: str, **kwargs
Expand All @@ -488,30 +488,22 @@ def get_incoming_links(
else:
return [self.get_atom(handle, **kwargs) for handle in links]

def get_matched_type_template(
self,
template: List[Any],
**kwargs: Optional[Dict[str, Any]],
) -> List[str]:
def get_matched_type_template(self, template: List[Any], **kwargs) -> Union[tuple, list]:
try:
template = self._build_named_type_hash_template(template)
template_hash = ExpressionHasher.composite_hash(template)
_, templates_matched = self._retrieve_template(template_hash, **kwargs)
if len(templates_matched) > 0:
if kwargs.get("toplevel_only"):
return self._filter_non_toplevel(templates_matched)
return templates_matched
cursor, templates_matched = self._retrieve_template(template_hash, **kwargs)
toplevel_only = kwargs.get('toplevel_only', False)
return self._process_matched_results(templates_matched, cursor, toplevel_only)
except Exception as exception:
logger().error(f'Failed to get matched type template - Details: {str(exception)}')
raise ValueError(str(exception))

def get_matched_type(self, link_type: str, **kwargs: Optional[Dict[str, Any]]) -> List[str]:
def get_matched_type(self, link_type: str, **kwargs) -> Union[tuple, list]:
named_type_hash = self._get_atom_type_hash(link_type)
_, templates_matched = self._retrieve_template(named_type_hash, **kwargs)
if len(templates_matched) > 0:
if kwargs.get("toplevel_only"):
return self._filter_non_toplevel(templates_matched)
return templates_matched
cursor, templates_matched = self._retrieve_template(named_type_hash, **kwargs)
toplevel_only = kwargs.get('toplevel_only', False)
return self._process_matched_results(templates_matched, cursor, toplevel_only)

def get_link_type(self, link_handle: str) -> str:
document = self.get_atom(link_handle)
Expand Down Expand Up @@ -698,7 +690,7 @@ def _get_redis_members(self, key, **kwargs) -> Tuple[int, list]:
chunk_size = kwargs.get('chunk_size', 1000)
cursor, members = self.redis.sscan(name=key, cursor=cursor, count=chunk_size)
else:
cursor = 0
cursor = None
members = self.redis.smembers(key)

return cursor, members
Expand Down Expand Up @@ -771,6 +763,19 @@ def _update_link_index(self, documents: Iterable[Dict[str, any]], **kwargs) -> N
key = _build_redis_key(KeyPrefix.INCOMING_SET, handle)
self.redis.sadd(key, *incoming_buffer[handle])

def _process_matched_results(
    self, matched: list, cursor: int = None, toplevel_only: bool = False
) -> Union[tuple, list]:
    """Shape the result of a pattern/template lookup for the caller.

    Args:
        matched: Matches retrieved for a pattern or template.
        cursor: Cursor value to hand back to the caller. ``None`` means no
            cursor is reported.
        toplevel_only: When true, ``matched`` is passed through
            ``_filter_non_toplevel`` before being returned.

    Returns:
        ``(cursor, results)`` when ``cursor`` is not ``None``; otherwise just
        ``results``.
    """
    results = self._filter_non_toplevel(matched) if toplevel_only else matched

    # Without a cursor the caller gets the bare list; with one, a
    # (cursor, list) pair.
    if cursor is None:
        return results
    return cursor, results

def reindex(self, pattern_index_templates: Optional[Dict[str, Dict[str, Any]]] = None):
if pattern_index_templates is not None:
self.pattern_index_templates = deepcopy(pattern_index_templates)
Expand Down
8 changes: 5 additions & 3 deletions hyperon_das_atomdb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ def is_ordered(self, link_handle: str) -> bool:
... # pragma no cover

@abstractmethod
def get_matched_links(self, link_type: str, target_handles: List[str]):
def get_matched_links(
self, link_type: str, target_handles: List[str], **kwargs
) -> Union[tuple, list]:
"""
Get links that match the specified type and targets.
Expand All @@ -339,7 +341,7 @@ def get_incoming_links(self, atom_handle: str, **kwargs) -> List[Any]:
... # pragma no cover

@abstractmethod
def get_matched_type_template(self, template: List[Any]) -> List[str]:
def get_matched_type_template(self, template: List[Any], **kwargs) -> Union[tuple, list]:
"""
Get nodes that match a specified template.
Expand All @@ -352,7 +354,7 @@ def get_matched_type_template(self, template: List[Any]) -> List[str]:
... # pragma no cover

@abstractmethod
def get_matched_type(self, link_type: str):
def get_matched_type(self, link_type: str, **kwargs) -> Union[tuple, list]:
"""
Get links that match a specified link type.
Expand Down
144 changes: 144 additions & 0 deletions tests/integration/adapters/test_redis_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,3 +799,147 @@ def _asserts_patterns():
_asserts_patterns()

_db_down()

def test_get_matched_with_pagination(self):
    """Check the return shape of the matched-* queries when ``cursor`` is given.

    With ``cursor=0`` every query is expected to return a two-item
    ``(cursor, matches)`` result: the fully specified link lookups report
    ``None`` as cursor, while the wildcard, template and type lookups
    report ``0``.
    """
    # Expected matches for get_matched_links('Similarity', [human, '*']).
    similarity_matches = [
        (
            '16f7e407087bfa0b35b13d13a1aadcae',
            ('af12f10f9ae2002a1607ba0b47ba8407', '4e8e26e3276af8a5c2ac2cc2dc95c6d2'),
        ),
        (
            'b5459e299a5c5e8662c427f7e01b3bf1',
            ('af12f10f9ae2002a1607ba0b47ba8407', '5b34c54bee150c04f9fa584b899dc030'),
        ),
        (
            'bad7472f41a0e7d601ca294eb4607c3a',
            ('af12f10f9ae2002a1607ba0b47ba8407', '1cdffc6b0b89ff41d68bec237481d1e1'),
        ),
    ]

    # The template query and the type query are expected to yield the same
    # twelve Inheritance links, so the expectation is written once.
    inheritance_matches = [
        (
            '116df61c01859c710d178ba14a483509',
            ('c1db9b517073e51eb7ef6fed608ec204', 'b99ae727c787f1b13b452fd4c9ce1b9a'),
        ),
        (
            '1c3bf151ea200b2d9e088a1178d060cb',
            ('bdfe4e7a431f73386f37c6448afe5840', '0a32b476852eeb954979b87f5f6cb7af'),
        ),
        (
            '4120e428ab0fa162a04328e5217912ff',
            ('bb34ce95f161a6b37ff54b3d4c817857', '0a32b476852eeb954979b87f5f6cb7af'),
        ),
        (
            '75756335011dcedb71a0d9a7bd2da9e8',
            ('5b34c54bee150c04f9fa584b899dc030', 'bdfe4e7a431f73386f37c6448afe5840'),
        ),
        (
            '906fa505ae3bc6336d80a5f9aaa47b3b',
            ('d03e59654221c1e8fcda404fd5c8d6cb', '08126b066d32ee37743e255a2558cccd'),
        ),
        (
            '959924e3aab197af80a84c1ab261fd65',
            ('08126b066d32ee37743e255a2558cccd', 'b99ae727c787f1b13b452fd4c9ce1b9a'),
        ),
        (
            'b0f428929706d1d991e4d712ad08f9ab',
            ('b99ae727c787f1b13b452fd4c9ce1b9a', '0a32b476852eeb954979b87f5f6cb7af'),
        ),
        (
            'c93e1e758c53912638438e2a7d7f7b7f',
            ('af12f10f9ae2002a1607ba0b47ba8407', 'bdfe4e7a431f73386f37c6448afe5840'),
        ),
        (
            'e4685d56969398253b6f77efd21dc347',
            ('b94941d8cd1c0ee4ad3dd3dcab52b964', '80aff30094874e75028033a38ce677bb'),
        ),
        (
            'ee1c03e6d1f104ccd811cfbba018451a',
            ('4e8e26e3276af8a5c2ac2cc2dc95c6d2', '80aff30094874e75028033a38ce677bb'),
        ),
        (
            'f31dfe97db782e8cec26de18dddf8965',
            ('1cdffc6b0b89ff41d68bec237481d1e1', 'bdfe4e7a431f73386f37c6448afe5840'),
        ),
        (
            'fbf03d17d6a40feff828a3f2c6e86f05',
            ('99d18c702e813b07260baf577c60c455', 'bdfe4e7a431f73386f37c6448afe5840'),
        ),
    ]

    _db_up()
    db = self._connect_db()
    self._add_atoms(db)
    db.commit()

    # Fully specified link: cursor slot is None, list holds the single handle.
    exact = db.get_matched_links('Similarity', [human, monkey], cursor=0)
    assert exact == (None, [AtomDB.link_handle('Similarity', [human, monkey])])

    # Link type that is not in the database: cursor slot is None, list is empty.
    missing = db.get_matched_links('Fake', [human, monkey], cursor=0)
    assert missing == (None, [])

    # Wildcard target: matches are sorted before comparing.
    by_pattern = db.get_matched_links('Similarity', [human, '*'], cursor=0)
    assert (by_pattern[0], sorted(by_pattern[1])) == (0, similarity_matches)

    template = ['Inheritance', 'Concept', 'Concept']

    by_template = db.get_matched_type_template(template, cursor=0)
    assert (by_template[0], sorted(by_template[1])) == (0, inheritance_matches)

    by_type = db.get_matched_type('Inheritance', cursor=0)
    assert (by_type[0], sorted(by_type[1])) == (0, inheritance_matches)
    _db_down()

0 comments on commit 6989383

Please sign in to comment.