[Bug] loadPyTorchDump does not load in models with sequential layers correctly #10

Open
ShreyasKhandekar opened this issue Oct 1, 2024 · 2 comments

Comments

@ShreyasKhandekar
Contributor

The model.loadPyTorchDump() function does not handle modules nested within Sequential layers: it fails to track the nested layer names, so it cannot find the dumped files to import.

For example, if we have this PyTorch model:

import torch.nn as nn

class Dummy(nn.Module):
    def __init__(self):
        super(Dummy, self).__init__()

        # A single Conv2d layer nested inside a Sequential container
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
        )

and we dump it using chai_dump:

dummy = Dummy()
dummy.chai_dump('models/dummy', 'dummy')

The files we get are:

models/dummy
| - model.0.weight.chdata
| - model.0.weight.json
| - specification.json

The loadPyTorchDump function does not handle these file names correctly, because it does not keep track of the fact that the Conv2d layer is nested inside the Sequential layer.
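To make the naming pattern concrete, here is a minimal Python sketch of how a dumped file name could be split into a module path and a parameter name. The helper `split_param_file` is hypothetical and is not part of ChAI or PyTorch; it only assumes the `<dotted module path>.<parameter>.<extension>` pattern visible in the listing above.

```python
from pathlib import PurePosixPath


def split_param_file(file_name):
    """Split a dumped file name into (module_path, param_name, extension).

    Hypothetical helper: assumes names look like
    '<dotted module path>.<parameter>.<extension>'.
    """
    path = PurePosixPath(file_name)
    extension = path.suffix.lstrip(".")   # e.g. 'chdata' or 'json'
    parts = path.stem.split(".")          # e.g. ['model', '0', 'weight']
    if len(parts) < 2:
        # Files such as specification.json carry no module path.
        raise ValueError(f"not a parameter file: {file_name!r}")
    return parts[:-1], parts[-1], extension


module_path, param_name, extension = split_param_file("model.0.weight.chdata")
print(module_path, param_name, extension)
```

Here the module path `['model', '0']` records that the weight belongs to child `0` of the attribute `model`, which is the nesting information a loader would need to keep.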

A more complicated example that expands on the above:

class DummyTwo(nn.Module):
    def __init__(self, input_model):
        super(DummyTwo, self).__init__()
        self.model = nn.Sequential(
            input_model,
        )

dummy_two = DummyTwo(dummy)
dummy_two.chai_dump('models/dummy_two', 'dummy_two')

This gives us the files:

models/dummy_two
| - model.0.model.0.weight.chdata
| - model.0.model.0.weight.json
| - specification.json

Again, loadPyTorchDump cannot handle this nesting of models within models.
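One way to picture what a loader needs is a tree keyed by each component of the dotted name. The following Python sketch groups the files from the listing above into such a tree. It is an illustration only: `build_module_tree` is a hypothetical name, and this is not how loadPyTorchDump is implemented.

```python
def build_module_tree(file_names):
    """Group dotted parameter file names into a nested dict of modules.

    Hypothetical sketch: each leaf maps a parameter name to its data file.
    """
    tree = {}
    for name in file_names:
        if not name.endswith(".chdata"):
            continue  # skip the .json files, including specification.json
        *module_path, param = name[: -len(".chdata")].split(".")
        node = tree
        for part in module_path:
            # Descend one nesting level, creating it on first use.
            node = node.setdefault(part, {})
        node[param] = name
    return tree


files = [
    "model.0.model.0.weight.chdata",
    "model.0.model.0.weight.json",
    "specification.json",
]
tree = build_module_tree(files)
print(tree)
```

Each level of the resulting dict corresponds to one level of module nesting, so arbitrarily deep Sequential or submodule nesting is handled by the same loop.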

@Iainmon
Owner

Iainmon commented Oct 3, 2024

I think the loading feature should be rebuilt, along with a discussion about how to engineer the module system to be more like PyTorch's. I didn't have enough time to fix this during the internship.

@ShreyasKhandekar
Contributor Author

I have a fix for this that does not re-engineer the module system, but works within it to support loading nested modules.

Iainmon pushed a commit that referenced this issue Jan 6, 2025
2 participants