-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.cpp
42 lines (34 loc) · 1.53 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include "Projects/CIFAR/CPP/Dataset/CIFARTorchDataset.h"
#include "Projects/CIFAR/CPP/Trainer/Trainer.hpp"
#include "TorchScriptUtilities/LoadJIT/LoadJIT.h"
int main() {
std::string trainDatasetPath = "/Users/kartikrajeshwaran/CodeSupport/CPP/Datasets/CIFAR-10-images/train";
std::string evalDatasetPath = "/Users/kartikrajeshwaran/CodeSupport/CPP/Datasets/CIFAR-10-images/test";
std::string modelPath = "/Users/kartikrajeshwaran/CodeSupport/CPP/Models/LibtorchPlayground/BlockConvNet/BlockConvNetJIT.pt";
auto *trainDataset = new CIFARTorchDataset(trainDatasetPath);
auto *evalDataset = new CIFARTorchDataset(evalDatasetPath);
auto loader = new LoadJIT(modelPath);
auto model = loader->get_model_ptr();
std::vector<torch::Tensor> params;
for(auto i = model->parameters().begin(); i != model->parameters().end(); i++){
params.push_back(*i);
}
torch::optim::AdamWOptions adamOptions(5e-3);
torch::optim::AdamW adamOptimizer(params, adamOptions);
torch::nn::CrossEntropyLoss *loss = new torch::nn::CrossEntropyLoss();
auto *trainer = new Trainer<torch::jit::Module *,
CIFARTorchDataset,
torch::data::samplers::DistributedRandomSampler,
torch::optim::AdamW,
torch::nn::CrossEntropyLoss,
std::vector<torch::jit::IValue>>(
model,
*trainDataset,
*evalDataset,
32,
adamOptimizer,
*loss,
4);
trainer->fit_parallel(16, 0);
return 0;
}