Welcome to chebai’s documentation!

How to define a model

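import torch

from chebai.models.base import JCIBaseNet


class MyModel(JCIBaseNet):
    def __init__(self, dims):
        super().__init__()
        # A single linear layer mapping `dims` input features to `dims + 1` outputs.
        self.lin = torch.nn.Linear(dims, dims + 1)

    def forward(self, x):
        # x has shape (batch_size, dims); the result has shape (batch_size, dims + 1).
        return self.lin(x)


model = MyModel(5)
inp = torch.rand((1, 5))  # one random sample with 5 features
result = model(inp)
print(result.shape)
# Output: torch.Size([1, 6])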
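
The output shape follows from the linear layer alone: the batch dimension passes through unchanged and the feature dimension grows from dims to dims + 1. The sketch below illustrates this with a batch of four samples. So that it runs without chebai installed, it uses a hypothetical stand-in class, ShapeDemo, that subclasses torch.nn.Module directly. This is an assumption made for illustration only and is not how chebai models should be defined.

```python
import torch


class ShapeDemo(torch.nn.Module):
    """Hypothetical stand-in for MyModel; subclasses torch.nn.Module, not JCIBaseNet."""

    def __init__(self, dims):
        super().__init__()
        # Same layer as in MyModel: dims input features, dims + 1 output features.
        self.lin = torch.nn.Linear(dims, dims + 1)

    def forward(self, x):
        return self.lin(x)


demo = ShapeDemo(5)
batch = torch.rand((4, 5))  # four samples, five features each
out = demo(batch)
print(out.shape)
```

With four input samples, the printed shape is torch.Size([4, 6]): one row per sample, six output features per row.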

How to run experiments
