统一的深度学习框架 ivy

whiteIn1年前 ⋅ 2449 阅读

一个支持 JAX、TensorFlow、PyTorch 和 Numpy 的统一 ML 框架。你可以快速用 Ivy 来训练一个神经网络,就像这样,用上你喜欢的框架:

https://github.com/unifyai/ivy

import ivy

class MyModel(ivy.Module):
    def __init__(self):
        self.linear0 = ivy.Linear(3, 64)
        self.linear1 = ivy.Linear(64, 1)
        ivy.Module.__init__(self)

    def _forward(self, x):
        x = ivy.relu(self.linear0(x))
        return ivy.sigmoid(self.linear1(x))

ivy.set_backend('torch')  # change to any backend!
model = MyModel()
optimizer = ivy.Adam(1e-4)
x_in = ivy.array([1., 2., 3.])
target = ivy.array([0.])

def loss_fn(v):
    out = model(x_in, v=v)
    return ivy.mean((out - target)**2)

for step in range(100):
    loss, grads = ivy.execute_with_gradients(loss_fn, model.v)
    model.v = optimizer.step(model.v, grads)
    print('step {} loss {}'.format(step, ivy.to_numpy(loss).item()))

print('Finished training!')


全部评论: 0

    相关推荐