Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Algo to give k ranked miners for store #45

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion db/migrations/20241212075345_validator_db.sql
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ CREATE TABLE piece (
chunk_idx INTEGER, -- Index of the chunk in the file
piece_idx INTEGER, -- Index of the piece in the chunk
piece_type INTEGER CHECK (piece_type IN (0, 1)), -- Type of the piece (0: data, 1: parity)
tag TEXT, -- APDP Tag of the piece
signature TEXT -- Signature of the DHT entry by the miner storing the piece
);

Expand Down
1 change: 1 addition & 0 deletions settings.toml.example
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ store_dir = "object_store"

[validator]
synthetic = false
top_miner_ratio = 0.7

[validator.neuron]
num_concurrent_forwards = 1
Expand Down
7 changes: 7 additions & 0 deletions storb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,10 @@ def add_validator_args(self):
help="Query timeout",
default=self.settings.validator.query.timeout,
)

self._parser.add_argument(
"--top_miner_ratio",
type=float,
help="Top miner ratio",
default=self.settings.validator.top_miner_ratio,
)
152 changes: 152 additions & 0 deletions storb/util/uids.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import random

import numpy as np
from fiber.chain.metagraph import Metagraph

from storb.neuron import Neuron
from storb.util.logging import get_logger

logger = get_logger(__name__)


def check_hotkey_availability(metagraph: Metagraph, hotkey: str) -> bool:
Expand Down Expand Up @@ -80,3 +84,151 @@ def get_random_hotkeys(self: Neuron, k: int, exclude: list[int] = None) -> list[
)
hotkeys = random.sample(available_hotkeys, k)
return hotkeys


def get_ranked_hotkeys(
    self: Neuron,
    k: int,
    exclude: list[str] = None,
    top_fraction: float = 0.7,
    low_fraction: float = 0.3,
) -> list[str]:
    """Return up to ``k`` miner hotkeys mixing top-ranked, low-ranked, and
    score-weighted random picks, based on the EMA scores in ``self.scores``.

    Parameters
    ----------
    k : int
        Total number of hotkeys to return.
    exclude : list[str], optional
        Hotkeys to exclude from the selection. (The code compares entries
        against hotkey strings, so the element type is ``str``.)
    top_fraction : float, optional
        Fraction of k to fill with top-ranked miners. Must be between 0 and 1.
        Default is 0.7.
    low_fraction : float, optional
        Fraction of k to fill with low-ranked miners. Must be between 0 and 1.
        Default is 0.3.

    Returns
    -------
    mixed_hotkeys : list[str]
        A list of hotkeys, ordered as [top_hotkeys + low_hotkeys +
        random_hotkeys (+ filler)]. The total length will be <= k (it may
        be shorter if there aren't enough available, non-excluded hotkeys).

    Raises
    ------
    ValueError
        If ``top_fraction``/``low_fraction`` are out of range, their sum
        exceeds 1, or ``k`` is not positive.

    Notes
    -----
    - `self.scores` is a numpy array of shape [num_nodes].
    - The node's ID (index in `self.scores`) is
      `self.metagraph.nodes[hotkey].node_id`.
    - Fix over the previous version: every selection step (including the
      final top-up) now draws only from non-excluded hotkeys, so `exclude`
      is always honoured.
    """

    # Validate inputs
    if not 0 <= top_fraction <= 1:
        raise ValueError("`top_fraction` must be between 0 and 1.")
    if not 0 <= low_fraction <= 1:
        raise ValueError("`low_fraction` must be between 0 and 1.")
    if top_fraction + low_fraction > 1:
        raise ValueError("top_fraction + low_fraction must not exceed 1.")
    if k <= 0:
        raise ValueError("`k` must be a positive integer.")

    exclude_set = set(exclude) if exclude else set()

    # Available AND not excluded -- the only pool we ever select from.
    candidate_hotkeys = [
        hotkey
        for hotkey in self.metagraph.nodes
        if check_hotkey_availability(self.metagraph, hotkey)
        and hotkey not in exclude_set
    ]

    # Cap k at the candidate pool size. (Previously k was capped at *all*
    # available hotkeys, which let the fallback below return excluded ones.)
    k = min(k, len(candidate_hotkeys))
    if k == 0:
        return []

    # Sort candidate hotkeys based on EMA scores in descending order
    # (higher score = higher rank)
    sorted_hotkeys_desc = sorted(
        candidate_hotkeys,
        key=lambda hk: self.scores[self.metagraph.nodes[hk].node_id],
        reverse=True,
    )

    # Calculate how many top and low we need
    num_top = int(k * top_fraction)
    num_low = int(k * low_fraction)

    # Select top-ranked hotkeys
    top_hotkeys = sorted_hotkeys_desc[:num_top]
    top_set = set(top_hotkeys)

    # Take the lowest-ranked tail, dropping any overlap with the top slice
    # (possible when num_top + num_low > len(candidates)).
    low_candidates = sorted_hotkeys_desc[-num_low:] if num_low > 0 else []
    low_hotkeys = [hk for hk in low_candidates if hk not in top_set]
    # Recompute `num_low` in case of overlap
    num_low = len(low_hotkeys)

    # Calculate how many hotkeys remain to be filled randomly
    remaining = k - (num_top + num_low)
    random_selected = []

    if remaining > 0:
        # The middle set excludes both top and low
        middle_hotkeys = list(
            set(candidate_hotkeys) - top_set - set(low_hotkeys)
        )

        if middle_hotkeys:
            # Weights proportional to each node's score among the middle.
            # Normalising by the sum alone suffices (the former divide-by-max
            # step was a no-op once the sum normalisation ran).
            scores_middle = np.asarray(
                [
                    self.scores[self.metagraph.nodes[hk].node_id]
                    for hk in middle_hotkeys
                ],
                dtype=float,
            )
            total_weight = scores_middle.sum()

            if total_weight > 0:
                normalized_weights = scores_middle / total_weight
            else:
                # If all are zero (or degenerate), fall back to uniform
                normalized_weights = np.full(
                    len(middle_hotkeys), 1 / len(middle_hotkeys)
                )

            # Draw without replacement using those weights
            actual_num_random = min(remaining, len(middle_hotkeys))
            try:
                # str() strips numpy's np.str_ wrapper from the choices.
                random_selected = [
                    str(hk)
                    for hk in np.random.choice(
                        middle_hotkeys,
                        size=actual_num_random,
                        replace=False,
                        p=normalized_weights,
                    )
                ]
            except ValueError as e:
                # e.g. probabilities failing numpy's sum-to-1 tolerance,
                # or negative scores producing negative weights.
                logger.error(f"Random choice failed: {e}")
                random_selected = []

    # Combine all selected
    mixed_hotkeys = top_hotkeys + low_hotkeys + random_selected

    # If we still don't have k (e.g. the weighted draw failed), top up from
    # the remaining *candidates* -- never from excluded hotkeys.
    if len(mixed_hotkeys) < k:
        needed = k - len(mixed_hotkeys)
        not_selected = list(set(candidate_hotkeys) - set(mixed_hotkeys))
        if not_selected:
            mixed_hotkeys.extend(
                random.sample(not_selected, min(needed, len(not_selected)))
            )

    # Final trim in case we overshoot (unlikely, but just to be safe)
    return mixed_hotkeys[:k]
82 changes: 75 additions & 7 deletions storb/validator/piece_processing.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import asyncio
import json
import math
from datetime import datetime

import numpy as np

from storb import db, protocol
from storb.challenge import APDPTag
from storb.constants import QUERY_TIMEOUT
from storb.constants import NUM_UIDS_QUERY, QUERY_TIMEOUT
from storb.dht.piece_dht import PieceDHTValue
from storb.protocol import PieceChallenge
from storb.util.logging import get_logger
from storb.util.message_signing import PieceMessage, sign_message
from storb.util.piece import piece_hash
from storb.util.piece import Piece, PieceType, piece_hash
from storb.util.query import Payload
from storb.validator.types import PieceTask, ProcessedPieceResponse

Expand Down Expand Up @@ -89,7 +90,7 @@ def consume_piece_queue(self):

logger.info("Exiting consume_piece_queue thread.")

async def process_pieces(self, pieces, hotkeys):
async def process_pieces(self, pieces: list[Piece], hotkeys: list[str]):
"""
Process each piece: compute piece hash, store them on the miners,
update DB stats, update latencies, etc.
Expand All @@ -100,7 +101,6 @@ async def process_pieces(self, pieces, hotkeys):
piece_hashes = []
processed_pieces = []

# Basic example of how you'd do batch queries:
to_query = []
curr_batch_size = 0
uids = []
Expand All @@ -113,9 +113,10 @@ async def process_pieces(self, pieces, hotkeys):
async def handle_batch_requests():
nonlocal to_query, latencies
batch_responses = await asyncio.gather(*(t for _, t in to_query))

logger.info(f"Batch responses: {batch_responses}")
for i, batch_result in enumerate(batch_responses):
real_idx = to_query[i][0]
successful_miners = []
# batch_result is the list of (uid, payload_response)
for uid, payload_resp in batch_result:
miner_stats[uid]["store_attempts"] += 1
Expand All @@ -127,7 +128,13 @@ async def handle_batch_requests():
)
miner_stats[uid]["store_successes"] += 1
miner_stats[uid]["total_successes"] += 1
successful_miners.append(uid)

if len(successful_miners) == 0:
logger.error(f"No successful miners for piece {real_idx}")
return

logger.info(f"Successful miners: {successful_miners}")
# Put piece info in the queue for DHT storing
self.piece_miners.setdefault(piece_hashes[real_idx], []).extend(
[uid for (uid, p) in batch_result if p and p.data]
Expand Down Expand Up @@ -169,8 +176,69 @@ async def handle_batch_requests():
logger.info(f"Curr piece queue size: {self.piece_queue.qsize()}")
to_query = []

# Separate pieces into PARITY and DATA
data_pieces = [p for p in pieces if p.piece_type == PieceType.Data]
parity_pieces = [p for p in pieces if p.piece_type == PieceType.Parity]

total_pieces = len(pieces)
top_miner_pieces_count = math.floor(
total_pieces * self.settings.validator.top_miner_ratio
)

top_miner_pieces = parity_pieces[:]
Shr1ftyy marked this conversation as resolved.
Show resolved Hide resolved

if len(top_miner_pieces) < top_miner_pieces_count:
top_miner_pieces.extend(
data_pieces[: top_miner_pieces_count - len(top_miner_pieces)]
)

low_miner_pieces = data_pieces[top_miner_pieces_count:]

top_miner_count = int(NUM_UIDS_QUERY * self.settings.validator.top_miner_ratio)
top_miners = hotkeys[:top_miner_count]
low_miners = hotkeys[top_miner_count:]

logger.debug(f"Top miners: {top_miners}, Low miners: {low_miners}")

# Loop pieces
for idx, piece_info in enumerate(pieces):
for idx, piece_info in enumerate(top_miner_pieces):
p_hash = piece_hash(piece_info.data)
piece_hashes.append(p_hash)
processed_pieces.append(
protocol.ProcessedPieceInfo(
chunk_idx=piece_info.chunk_idx,
piece_type=piece_info.piece_type,
piece_idx=piece_info.piece_idx,
data=piece_info.data,
piece_id=p_hash,
)
)
# Create a store request
payload = Payload(
data=protocol.Store(
chunk_idx=piece_info.chunk_idx,
piece_type=piece_info.piece_type,
piece_idx=piece_info.piece_idx,
),
file=piece_info.data,
)
task = asyncio.create_task(
self.query_multiple_miners(
miner_hotkeys=top_miners,
endpoint="/store",
payload=payload,
method="POST",
)
)
to_query.append((idx, task))
curr_batch_size += 1

# If batch is full, send
if curr_batch_size >= self.settings.validator.query.batch_size:
await handle_batch_requests()
curr_batch_size = 0

for idx, piece_info in enumerate(low_miner_pieces):
p_hash = piece_hash(piece_info.data)
piece_hashes.append(p_hash)
processed_pieces.append(
Expand All @@ -193,7 +261,7 @@ async def handle_batch_requests():
)
task = asyncio.create_task(
self.query_multiple_miners(
miner_hotkeys=hotkeys,
miner_hotkeys=low_miners,
endpoint="/store",
payload=payload,
method="POST",
Expand Down
10 changes: 7 additions & 3 deletions storb/validator/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
reconstruct_data_stream,
)
from storb.util.query import Payload
from storb.util.uids import get_random_hotkeys
from storb.util.uids import get_ranked_hotkeys

logger = get_logger(__name__)

Expand Down Expand Up @@ -120,8 +120,12 @@ async def upload_file(self, file: UploadFile = File(...)) -> protocol.StoreRespo

timestamp = str(datetime.now(UTC).timestamp())

# TODO: Consider miner scores for selection, and not just their availability
hotkeys = get_random_hotkeys(self, NUM_UIDS_QUERY)
hotkeys = get_ranked_hotkeys(
self,
k=NUM_UIDS_QUERY,
top_fraction=self.settings.validator.top_miner_ratio,
low_fraction=(1 - self.settings.validator.top_miner_ratio),
)

chunk_hashes = []
piece_hashes = set()
Expand Down
Loading