Skip to content

Commit

Permalink
fix: use correct primary field name in Hits (#2561)
Browse files Browse the repository at this point in the history
issue: #2558

Signed-off-by: zhenshan.cao <[email protected]>
  • Loading branch information
czs007 authored Jan 13, 2025
1 parent 1cc5ee8 commit 3a2abe0
Show file tree
Hide file tree
Showing 8 changed files with 373 additions and 294 deletions.
62 changes: 62 additions & 0 deletions examples/customize_schema_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import time
import numpy as np
from pymilvus import (
MilvusClient,
DataType
)

# Example: search a collection whose primary key is NOT called "id".
#
# The schema below uses "uid" as the primary key and, separately, declares an
# ordinary VARCHAR field that happens to be named "id". Printing the hits at
# the end shows how the primary key and the "id" output field are reported.
fmt = "\n=== {:30} ===\n"
dim = 8
collection_name = "hello_milvus"
milvus_client = MilvusClient("http://localhost:19530")

# Start from a clean slate so the example can be re-run.
has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
    milvus_client.drop_collection(collection_name)

# "uid" is the primary key; "id" is a plain scalar field. Dynamic fields are
# enabled so the extra per-row keys ("a".."f") inserted below are accepted.
schema = milvus_client.create_schema(enable_dynamic_field=True)
schema.add_field("uid", DataType.INT64, is_primary=True)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
schema.add_field("title", DataType.VARCHAR, max_length=64)
schema.add_field("id", DataType.VARCHAR, max_length=64)

index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name="embeddings", metric_type="L2")
milvus_client.create_collection(
    collection_name,
    schema=schema,
    index_params=index_params,
    consistency_level="Strong",
)

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

# Fixed seed keeps the inserted vectors reproducible between runs.
rng = np.random.default_rng(seed=19530)
rows = [
    {"uid": 1, "embeddings": rng.random((1, dim))[0], "a": 100, "title": "t1", "id": "u1"},
    {"uid": 2, "embeddings": rng.random((1, dim))[0], "b": 200, "title": "t2", "id": "u2"},
    {"uid": 3, "embeddings": rng.random((1, dim))[0], "c": 300, "title": "t3", "id": "u3"},
    {"uid": 4, "embeddings": rng.random((1, dim))[0], "d": 400, "title": "t4", "id": "u4"},
    {"uid": 5, "embeddings": rng.random((1, dim))[0], "e": 500, "title": "t5", "id": "u5"},
    {"uid": 6, "embeddings": rng.random((1, dim))[0], "f": 600, "title": "t6", "id": "u6"},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows)
print(fmt.format("Inserting entities done"))
print(insert_result)

print(fmt.format("Start load collection "))
milvus_client.load_collection(collection_name)

# Re-seeding with the same seed replays the same random stream, so the query
# vector is identical to the first inserted embedding (uid=1).
rng = np.random.default_rng(seed=19530)
vectors_to_search = rng.random((1, dim))

print(fmt.format("Start search with retrieve several fields."))
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["id"])
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")

# Clean up the demo collection.
milvus_client.drop_collection(collection_name)
21 changes: 16 additions & 5 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def __init__(
self._nq = res.num_queries
all_topks = res.topks
self.recalls = res.recalls
self._pk_name = res.primary_field_name or "id"

self.cost = int(status.extra_info["report_value"] if status and status.extra_info else "0")

Expand All @@ -434,7 +435,14 @@ def __init__(
start, end = nq_thres, nq_thres + topk
nq_th_fields = self.get_fields_by_range(start, end, fields_data)
data.append(
Hits(topk, all_pks[start:end], all_scores[start:end], nq_th_fields, output_fields)
Hits(
topk,
all_pks[start:end],
all_scores[start:end],
nq_th_fields,
output_fields,
self._pk_name,
)
)
nq_thres += topk

Expand Down Expand Up @@ -565,6 +573,7 @@ def __init__(
distances: List[float],
fields: Dict[str, Tuple[List[Any], schema_pb2.FieldData]],
output_fields: List[str],
pk_name: str,
):
"""
Args:
Expand All @@ -573,6 +582,7 @@ def __init__(
"""
self.ids = pks
self.distances = distances
self._pk_name = pk_name

all_fields = list(fields.keys())
dynamic_fields = list(set(output_fields) - set(all_fields))
Expand Down Expand Up @@ -611,7 +621,7 @@ def __init__(
# sparse float vector and other fields
curr_field[fname] = data[i]

hits.append(Hit(pks[i], distances[i], curr_field))
hits.append(Hit(pks[i], distances[i], curr_field, self._pk_name))

super().__init__(hits)

Expand All @@ -631,10 +641,11 @@ class Hit:
distance: float
fields: Dict[str, Any]

def __init__(self, pk: Union[int, str], distance: float, fields: Dict[str, Any]):
def __init__(self, pk: Union[int, str], distance: float, fields: Dict[str, Any], pk_name: str):
self.id = pk
self.distance = distance
self.fields = fields
self._pk_name = pk_name

def __getattr__(self, item: str):
if item not in self.fields:
Expand All @@ -657,13 +668,13 @@ def get(self, field_name: str) -> Any:
return self.fields.get(field_name)

def __str__(self) -> str:
return f"id: {self.id}, distance: {self.distance}, entity: {self.fields}"
return f"{self._pk_name}: {self.id}, distance: {self.distance}, entity: {self.fields}"

__repr__ = __str__

def to_dict(self):
return {
"id": self.id,
self._pk_name: self.id,
"distance": self.distance,
"entity": self.fields,
}
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/grpc_gen/milvus-proto
524 changes: 262 additions & 262 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -990,16 +990,20 @@ class QueryRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., expr: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., partition_names: _Optional[_Iterable[str]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., query_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., not_return_all_meta: bool = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ..., expr_template_values: _Optional[_Mapping[str, _schema_pb2.TemplateValue]] = ...) -> None: ...

class QueryResults(_message.Message):
__slots__ = ("status", "fields_data", "collection_name", "output_fields")
__slots__ = ("status", "fields_data", "collection_name", "output_fields", "session_ts", "primary_field_name")
STATUS_FIELD_NUMBER: _ClassVar[int]
FIELDS_DATA_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
OUTPUT_FIELDS_FIELD_NUMBER: _ClassVar[int]
SESSION_TS_FIELD_NUMBER: _ClassVar[int]
PRIMARY_FIELD_NAME_FIELD_NUMBER: _ClassVar[int]
status: _common_pb2.Status
fields_data: _containers.RepeatedCompositeFieldContainer[_schema_pb2.FieldData]
collection_name: str
output_fields: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., collection_name: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ...) -> None: ...
session_ts: int
primary_field_name: str
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., collection_name: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., session_ts: _Optional[int] = ..., primary_field_name: _Optional[str] = ...) -> None: ...

class VectorIDs(_message.Message):
__slots__ = ("collection_name", "field_name", "id_array", "partition_names")
Expand Down
Loading

0 comments on commit 3a2abe0

Please sign in to comment.