nimxla/train


The train module has common functions for training neural networks. It depends on the nn and data modules in addition to the core nimxla modules.

Types

Trainer = object
  client*: Client
  model*: Module
  optim*: Optimizer
  sched*: Scheduler
  trainer*: Executable
  tester*: Executable
  trainAcc*: Executable
  testAcc*: Executable
  stats*: Table[string, seq[float]]
  predict*: Tensor[float32]
  heatmap*: Tensor[int32]
  epoch*: int
Trainer object holds the state for a training run.

Procs

proc accuracyFunc(c: Client; batch, nout: int; outType = F32; labelType = I32): Executable {.
    ...raises: [Exception, ValueError, BuilderError, XLAError], tags: [RootEffect],
    forbids: [].}
Helper function to calculate the accuracy from a set of predictions. The callable has two input parameters:
  • model output array of shape <outType>[batch, nout]
  • target labels of shape <labelType>[batch]

It returns a tuple with two outputs:

  • labels array of predicted class for each sample, and
  • accuracy F32 scalar in range 0-1 from comparison with labels
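The computation described above can be sketched in plain Nim. This is an illustration only: it does not use the nimxla API or build an XLA executable, and the `argmax` and `accuracy` helpers are hypothetical names introduced here. It assumes the predicted class is the index of the largest output value in each row.

```nim
# Plain-Nim illustration; the real accuracyFunc compiles an XLA executable.
# argmax and accuracy are hypothetical helpers, not part of nimxla.

proc argmax(row: seq[float32]): int =
  ## Index of the largest value in row (the first one wins on ties).
  result = 0
  for i in 1 ..< row.len:
    if row[i] > row[result]:
      result = i

proc accuracy(output: seq[seq[float32]]; labels: seq[int32]): (seq[int32], float32) =
  ## Returns the predicted class per sample and the fraction matching labels.
  assert output.len == labels.len and labels.len > 0
  var predicted = newSeq[int32](output.len)
  var correct = 0
  for i, row in output:
    predicted[i] = int32(argmax(row))
    if predicted[i] == labels[i]:
      inc correct
  result = (predicted, float32(correct) / float32(labels.len))

# Model output of shape [batch = 3, nout = 3] and target labels of shape [3].
let output = @[
  @[0.1'f32, 0.7'f32, 0.2'f32],
  @[0.8'f32, 0.1'f32, 0.1'f32],
  @[0.3'f32, 0.3'f32, 0.4'f32]
]
let labels = @[1'i32, 0'i32, 1'i32]
let (pred, acc) = accuracy(output, labels)
echo pred, " ", acc
```

In this example the first two predictions match their labels and the third does not, so the accuracy is two thirds.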
proc getAccuracy[T: ElemType](t: var Trainer; loader: var DataLoader): (float,
    bool)
Calculate the accuracy from the test data set. T should be the type of data returned from the loader.
proc getLayout(epochs: int): JsonNode {....raises: [], tags: [], forbids: [].}
proc loadCheckpoint(t: var Trainer; filename: string) {....raises: [OSError,
    ValueError, IOError, JsonParsingError, Exception, KeyError, JsonKindError,
    XLAError], tags: [ReadDirEffect, WriteDirEffect, ReadIOEffect,
                      WriteIOEffect, RootEffect], forbids: [].}
Read back a checkpoint from a zip file. The Trainer should already have been initialised.
proc readCheckpointFile(archiveFile, name: string): Stream {.
    ...raises: [OSError, ValueError, Exception, IOError], tags: [ReadDirEffect,
    WriteDirEffect, RootEffect, ReadIOEffect, WriteIOEffect], forbids: [].}
Read the named file from the checkpoint archive.
proc saveCheckpoint(t: Trainer; basename: string) {.
    ...raises: [ValueError, OSError, IOError, XLAError, Exception],
    tags: [ReadDirEffect, WriteDirEffect, WriteIOEffect, RootEffect],
    forbids: [].}
Save a checkpoint with model weights and optimizer state to file.
proc statsPlots(t: Trainer; classes: seq[string]): JsonNode {.
    ...raises: [KeyError, ValueError], tags: [], forbids: [].}
Convert stats to the format used by the plots package.
proc testFunc(c: Client; model: Module; dtype: DataType; shape: seq[int]): Executable {.
    ...raises: [Exception, ValueError, BuilderError, XLAError], tags: [RootEffect],
    forbids: [].}
Compile the test function with the given input data shape. The softmax function is applied to the output.
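For reference, the softmax step mentioned above turns a row of raw model outputs into probabilities that sum to 1. The following is a minimal plain-Nim sketch, not the nimxla implementation: the `softmax` proc here is a hypothetical helper, and subtracting the maximum before exponentiating is a common stability technique assumed for illustration, not something the source confirms.

```nim
import std/math

# Plain-Nim illustration; testFunc applies softmax inside the compiled XLA graph.
# This softmax proc is a hypothetical helper, not part of nimxla.

proc softmax(logits: seq[float32]): seq[float32] =
  ## Softmax over one row, subtracting the max first to avoid overflow in exp.
  assert logits.len > 0
  let m = max(logits)
  result = newSeq[float32](logits.len)
  var total = 0.0'f32
  for i, x in logits:
    result[i] = exp(x - m)
    total += result[i]
  for i in 0 ..< result.len:
    result[i] = result[i] / total

let probs = softmax(@[1.0'f32, 2.0'f32, 3.0'f32])
echo probs        # one probability per class, largest for the largest logit
echo sum(probs)   # approximately 1.0
```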
proc trainEpoch[T: ElemType](t: var Trainer; loader: var DataLoader): (float,
    float, bool)
Train on one epoch of batches of data from the training set. Returns the average loss and accuracy on the training data. T should be the type of data returned from the loader. Output loss should be a float32 scalar.
proc trainFunc(c: Client; model: Module; dtype: DataType; shape: seq[int];
               lossFn: proc (yp, y: Node): Node): Executable {.
    ...raises: [Exception, ValueError, BuilderError, KeyError, XLAError],
    tags: [RootEffect], forbids: [].}
Compile the training function with the given input data shape and a loss function, which is applied to the output.
proc trainNetwork[T: ElemType](t: var Trainer; train, test: var DataLoader;
                               epochs: int; plot = false; checkpoint = "";
                               saveEvery = 20)
Training run for the given number of epochs. If transform is set, it will be applied to each batch of training data. If checkpoint is set, a checkpoint file is written using this prefix.