
Fix Errors in trainer_sgdmf.py and movielens.py #779

Open
czzhangheng opened this issue Aug 2, 2024 · 1 comment

Comments

@czzhangheng

I tried to run this example following the docs at https://federatedscope.io/docs/recommendation/ with the command:
python federatedscope/main.py --cfg federatedscope/mf/baseline/hfl-sgdmf_fedavg_standalone_on_movielens1m.yaml
However, it did not run and reported several errors. The errors occur in ./federatedscope/mf/trainer/trainer_sgdmf.py and seem to come from how torch's Embedding module is accessed: ctx.model.embed_user.grad is wrong, while ctx.model.embed_user.weight.grad is correct. There are also other errors, such as an "add(sparse, dense)" failure when a dense tensor is added to a sparse gradient.
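
To show what I mean, here is a tiny standalone sketch (my own illustration, not code from FederatedScope): an nn.Embedding keeps its parameters in .weight, so the gradient lives on .weight.grad, and a dense noise tensor cannot be added directly to a sparse gradient.

import torch
import torch.nn as nn

embed_user = nn.Embedding(10, 4, sparse=True)
loss = embed_user(torch.tensor([1, 2, 3])).sum()
loss.backward()

# embed_user.grad                          # AttributeError: the module itself has no .grad
print(embed_user.weight.grad.is_sparse)    # True: the gradient lives on .weight and is sparse here

noise = torch.randn_like(embed_user.weight)
# embed_user.weight.grad += noise          # RuntimeError: add(sparse, dense) is not supported
dense_grad = embed_user.weight.grad.to_dense()
dense_grad += noise                        # adding dense noise works on the dense copy
embed_user.weight.grad = dense_grad.to_sparse()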

I used ChatGPT to help fix the code, and the example now runs. From my Git history, these are the changes I made:

In federatedscope/mf/dataset/movielens.py, lines 160-161:

row = [mapping_user[mid] for _, mid in data["userId"].items()]
col = [mapping_item[mid] for _, mid in data["movieId"].items()]
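
For context, a tiny standalone sketch (my own illustration with made-up toy data, not code from the repo) of what these two lines do: they translate the raw userId/movieId values into contiguous indices via the mapping dicts, iterating over the pandas Series with .items().

import pandas as pd

# Hypothetical toy data and mappings, just to show the pattern
data = pd.DataFrame({"userId": [10, 42, 10], "movieId": [7, 7, 99]})
mapping_user = {10: 0, 42: 1}
mapping_item = {7: 0, 99: 1}

row = [mapping_user[mid] for _, mid in data["userId"].items()]
col = [mapping_item[mid] for _, mid in data["movieId"].items()]
print(row, col)  # [0, 1, 0] [0, 0, 1]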

In federatedscope/mf/trainer/trainer_sgdmf.py, line 70, replace the entire function def hook_on_batch_backward(ctx):

def hook_on_batch_backward(ctx):
    """Private local updates in SGDMF."""
    ctx.optimizer.zero_grad()
    ctx.loss_task.backward()

    # The embedding gradients may be sparse; densify them so the dense
    # noise tensors below can be added.
    if ctx.model.embed_user.weight.grad.is_sparse:
        dense_user_grad = ctx.model.embed_user.weight.grad.to_dense()
    else:
        dense_user_grad = ctx.model.embed_user.weight.grad

    if ctx.model.embed_item.weight.grad.is_sparse:
        dense_item_grad = ctx.model.embed_item.weight.grad.to_dense()
    else:
        dense_item_grad = ctx.model.embed_item.weight.grad

    # Inject Gaussian noise into the gradients
    dense_user_grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_user.weight.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_user.weight.device)
    dense_item_grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_item.weight.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_item.weight.device)

    # Write the noised gradients back in sparse form before the update
    ctx.model.embed_user.weight.grad = dense_user_grad.to_sparse()
    ctx.model.embed_item.weight.grad = dense_item_grad.to_sparse()
    ctx.optimizer.step()

    # Embedding clipping
    with torch.no_grad():
        embedding_clip(ctx.model.embed_user.weight, ctx.sgdmf_R)
        embedding_clip(ctx.model.embed_item.weight, ctx.sgdmf_R)

The code now runs, but I'm not sure whether there are other issues.
I rarely use GitHub, so I may need to learn how to open a pull request later.

Environment:
Python 3.9
torch 1.10.1
CUDA 11.3

Thank you for your work. Have a good day. :)

@czzhangheng (Author)

Also, I cannot scan the QR code for the DingTalk group on the official website; it appears to be out of date.
