Project: Hailu, Dawit / DeepInverse

Commit 75b954bc, authored 4 years ago by Kartheek Akella
Parent: 6c2f52b5

Add data module for handling dataset specific tasks.

Showing 3 changed files with 121 additions and 24 deletions:

    src/research_mnist/mnist.py              +23 −22
    src/research_mnist/mnist_data_module.py  +92 −0 (new file)
    src/research_mnist/mnist_trainer.py      +6 −2
src/research_mnist/mnist.py (+23 −22)
@@ -29,52 +29,54 @@ class CoolSystem(pl.LightningModule):
         y_hat = self.forward(x)
         loss = F.cross_entropy(y_hat, y)
-        tensorboard_logs = {'train_loss': loss}
+        result = pl.TrainResult(minimize=loss)
+        result.log('train_loss', loss, prog_bar=True)
-        return {'loss': loss, 'log': tensorboard_logs}
+        return result

     def validation_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'val_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
+        result.valid_batch_loss = loss
+        result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
+        return result

     def validation_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True)
-        tensorboard_logs = {'avg_val_loss': avg_loss}
-        return {'val_loss': avg_loss, 'log': tensorboard_logs}
+        return result

     def test_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'test_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
+        result.test_batch_loss = loss
+        result.log('test_loss', loss, on_epoch=True)
+        return result

     def test_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
+        avg_loss = outputs.test_batch_loss.mean()
-        tensorboard_logs = {'test_val_loss': avg_loss}
-        return {'test_loss': avg_loss, 'log': tensorboard_logs}
+        result = pl.EvalResult()
+        result.log('test_loss', avg_loss, on_epoch=True)
+        return result

     def configure_optimizers(self):
         # REQUIRED
         # can return multiple optimizers and learning_rate schedulers
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
-    def train_dataloader(self):
-        # REQUIRED
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
-    def val_dataloader(self):
-        # OPTIONAL
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
-    def test_dataloader(self):
-        # OPTIONAL
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)

     @staticmethod
     def add_model_specific_args(parent_parser):
@@ -84,7 +86,6 @@ class CoolSystem(pl.LightningModule):
         # MODEL specific
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         parser.add_argument('--learning_rate', default=0.02, type=float)
-        parser.add_argument('--batch_size', default=32, type=int)

         # training specific (for this model)
         parser.add_argument('--max_nb_epochs', default=2, type=int)
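For context only, not part of this commit: with the dataloader hooks deleted from CoolSystem, the module now has to be fed data from outside, either through a LightningDataModule or by passing a loader to Trainer.fit directly. A minimal smoke-test sketch of the latter, assuming the hyperparameter names registered by add_model_specific_args and the Lightning 0.9-era API this code targets:

    from argparse import Namespace

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST

    from src.research_mnist.mnist import CoolSystem

    # Hypothetical hparams; the names mirror the add_model_specific_args defaults.
    hparams = Namespace(learning_rate=0.02, batch_size=32)
    model = CoolSystem(hparams=hparams)

    # The loader is now supplied externally instead of by the module itself.
    train_loader = DataLoader(
        MNIST("./", train=True, download=True, transform=transforms.ToTensor()),
        batch_size=hparams.batch_size,
    )

    # fast_dev_run pushes a single batch through the loop as a sanity check.
    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(model, train_dataloader=train_loader)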
src/research_mnist/mnist_data_module.py (new file, mode 100644, +92 −0)
from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.data_dir = self.hparams.data_dir
        self.batch_size = self.hparams.batch_size

        # We hardcode dataset specific stuff here.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose([transforms.ToTensor(),])

        # Basic test that parameters passed are sensible.
        assert (
            self.hparams.train_size + self.hparams.valid_size == 60_000
        ), "Invalid Train and Valid Split, make sure they add up to 60,000"

    def prepare_data(self):
        # download the dataset
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [self.hparams.train_size, self.hparams.valid_size]
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.hparams.workers,
        )

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.hparams.workers)

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            num_workers=self.hparams.workers,
        )

    @staticmethod
    def add_data_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # Dataset specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--batch_size", default=32, type=int)
        parser.add_argument("--data_dir", default="./", type=str)

        # training specific
        parser.add_argument("--train_size", default=55_000, type=int)
        parser.add_argument("--valid_size", default=5_000, type=int)
        parser.add_argument("--workers", default=8, type=int)

        return parser
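Not part of the diff: a short sketch of driving the new MNISTDataModule on its own, with a Namespace standing in for the parsed arguments; the field values simply mirror the defaults registered in add_data_specific_args:

    from argparse import Namespace

    from src.research_mnist.mnist_data_module import MNISTDataModule

    # train_size + valid_size must sum to 60,000 or __init__'s assert fires.
    hparams = Namespace(
        data_dir="./",
        batch_size=32,
        train_size=55_000,
        valid_size=5_000,
        workers=8,
    )

    dm = MNISTDataModule(hparams=hparams)
    dm.prepare_data()       # downloads the MNIST train/test splits to data_dir
    dm.setup(stage="fit")   # builds mnist_train / mnist_val via random_split

    images, labels = next(iter(dm.train_dataloader()))
    print(images.shape)     # torch.Size([32, 1, 28, 28]), matching self.dims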
src/research_mnist/mnist_trainer.py (+6 −2)
@@ -4,18 +4,20 @@ This file runs the main training/val loop, etc... using Lightning Trainer
 from pytorch_lightning import Trainer, seed_everything
 from argparse import ArgumentParser
 from src.research_mnist.mnist import CoolSystem
+from src.research_mnist.mnist_data_module import MNISTDataModule

 # sets seeds for numpy, torch, etc...
 # must do for DDP to work well
 seed_everything(123)


 def main(args):
-    # init module
+    # init modules
+    dm = MNISTDataModule(hparams=args)
     model = CoolSystem(hparams=args)

     # most basic trainer, uses good defaults
     trainer = Trainer.from_argparse_args(args)
-    trainer.fit(model)
+    trainer.fit(model, dm)
     trainer.test()
@@ -29,6 +31,8 @@ if __name__ == '__main__':
     # give the module a chance to add own params
     # good practice to define LightningModule specific params in the module
     parser = CoolSystem.add_model_specific_args(parser)
+    # same goes for data modules
+    parser = MNISTDataModule.add_data_specific_args(parser)

     # parse params
     args = parser.parse_args()
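Also not part of the commit: a programmatic sketch mirroring main() in mnist_trainer.py, handy for a quick sanity check without the CLI. The Namespace is the union of the model-specific and data-specific argparse defaults declared by the two add_*_args helpers, and fast_dev_run is a stock Trainer flag that runs a single batch end to end instead of full training:

    from argparse import Namespace

    from pytorch_lightning import Trainer, seed_everything

    from src.research_mnist.mnist import CoolSystem
    from src.research_mnist.mnist_data_module import MNISTDataModule

    seed_everything(123)

    # Union of the model-specific and data-specific argparse defaults.
    hparams = Namespace(
        learning_rate=0.02,
        batch_size=32,
        max_nb_epochs=2,
        data_dir="./",
        train_size=55_000,
        valid_size=5_000,
        workers=8,
    )

    dm = MNISTDataModule(hparams=hparams)
    model = CoolSystem(hparams=hparams)

    # Single-batch dry run; swap in Trainer.from_argparse_args(args) for real training.
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, dm)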