热线电话:13121318867

登录
2018-12-19 阅读量: 791
如何用Logistic回归识别手写数字?(2)

现在,我们将定义我们的超参数

# Hyper Parameters

input_size = 784

num_classes = 10

num_epochs = 5

batch_size = 100

learning_rate = 0.001

在我们的数据集中,图像大小为28 * 28。因此,我们的输入大小是784.此外,这里有10位数字,因此,我们可以有10个不同的输出。因此,我们将num_classes设置为10.此外,我们将在整个数据集上训练五次。最后,我们将分别训练小批量的100张图像,以防止因内存溢出而导致程序崩溃。

在此之后,我们将定义我们的模型如下。在这里,我们将我们的模型初始化为torch.nn.Module的子类,然后定义前向传递。在我们编写的代码中,softmax在每次正向传递期间内部计算,因此我们不需要在forward()函数内指定它。

class LogisticRegression(nn.Module):

def __init__(self, input_size, num_classes):

super(LogisticRegression, self).__init__()

self.linear = nn.Linear(input_size, num_classes)

def forward(self, x):

out = self.linear(x)

return out

定义了我们的类之后,现在我们实例化了一个对象

model = LogisticRegression(input_size, num_classes)

接下来,我们设置损失函数和优化器。在这里,我们将使用交叉熵损失,对于优化器,我们将使用随机梯度下降算法,其学习率为0.001,如上面的超参数中所定义。

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

0.0000
1
关注作者
收藏
评论(0)

发表评论

暂无数据
推荐帖子