The train module has common functions for training neural networks. It depends on the nn and data modules in addition to the core nimxla modules.
Types
Trainer = object
  client*: Client
  model*: Module
  optim*: Optimizer
  sched*: Scheduler
  trainer*: Executable
  tester*: Executable
  trainAcc*: Executable
  testAcc*: Executable
  stats*: Table[string, seq[float]]
  predict*: Tensor[float32]
  heatmap*: Tensor[int32]
  epoch*: int
- Trainer object holds the state for a training run.
Procs
proc accuracyFunc(c: Client; batch, nout: int; outType = F32; labelType = I32): Executable {.raises: [Exception, ValueError, BuilderError, XLAError], tags: [RootEffect], forbids: [].}
- Helper function to calculate the accuracy from a set of predictions. The returned executable has two input parameters:
- model output array of shape <outType>[batch, nout]
- target labels of shape <labelType>[batch]
It returns a tuple with two outputs:
- labels: array of the predicted class for each sample
- accuracy: F32 scalar in the range 0-1, from comparison with the target labels
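As a plain-Nim illustration of what the accuracy executable computes, the sketch below works on ordinary seqs rather than building an XLA graph, so it is not how accuracyFunc is implemented. It assumes the predicted class is the index of the largest model output (argmax), that the [batch, nout] output is stored row-major in a flat seq, and uses a hypothetical helper name, `accuracySketch`.

```nim
# Hypothetical helper, not part of nimxla: argmax-based accuracy over a
# row-major [batch, nout] array of model outputs stored in a flat seq.
proc accuracySketch(output: seq[float32]; labels: seq[int32];
                    batch, nout: int): (seq[int32], float32) =
  doAssert output.len == batch * nout
  doAssert labels.len == batch
  var predicted = newSeq[int32](batch)
  var correct = 0
  for i in 0 ..< batch:
    # find the class with the largest output for sample i
    var best = 0
    for j in 1 ..< nout:
      if output[i*nout + j] > output[i*nout + best]:
        best = j
    predicted[i] = int32(best)
    # count samples whose prediction matches the target label
    if predicted[i] == labels[i]:
      inc correct
  result = (predicted, float32(correct) / float32(batch))

when isMainModule:
  let output = @[0.1'f32, 0.9, 0.8, 0.2, 0.3, 0.7]  # batch = 3, nout = 2
  let labels = @[1'i32, 0, 0]
  let (pred, acc) = accuracySketch(output, labels, 3, 2)
  echo pred, " ", acc
```

Here two of the three predictions match their labels, so the accuracy is 2/3.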
proc getAccuracy[T: ElemType](t: var Trainer; loader: var DataLoader): (float, bool)
- Calculate the accuracy on the test data set. T should be the type of data returned from the loader.
proc loadCheckpoint(t: var Trainer; filename: string) {.raises: [OSError, ValueError, IOError, JsonParsingError, Exception, KeyError, JsonKindError, XLAError], tags: [ReadDirEffect, WriteDirEffect, ReadIOEffect, WriteIOEffect, RootEffect], forbids: [].}
- Read back a checkpoint from a zip file. The Trainer should already have been initialised.
proc readCheckpointFile(archiveFile, name: string): Stream {.raises: [OSError, ValueError, Exception, IOError], tags: [ReadDirEffect, WriteDirEffect, RootEffect, ReadIOEffect, WriteIOEffect], forbids: [].}
- Read the named file from the checkpoint archive and return it as a Stream.
proc saveCheckpoint(t: Trainer; basename: string) {.raises: [ValueError, OSError, IOError, XLAError, Exception], tags: [ReadDirEffect, WriteDirEffect, WriteIOEffect, RootEffect], forbids: [].}
- Save a checkpoint with the model weights and optimizer state to a file.
proc statsPlots(t: Trainer; classes: seq[string]): JsonNode {.raises: [KeyError, ValueError], tags: [], forbids: [].}
- Convert stats to the format used by the plots package.
proc trainEpoch[T: ElemType](t: var Trainer; loader: var DataLoader): (float, float, bool)
- Train on one epoch of batches of data from the training set and return the average loss and accuracy on the training data. T should be the type of data returned from the loader. The loss output should be a float32 scalar.
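The averaging step can be sketched in plain Nim as below. This is a hedged illustration, not nimxla's implementation: it assumes an unweighted mean over per-batch values and records the results in a table with the same type as Trainer.stats. The proc name `recordEpoch` and the keys `"train_loss"` and `"train_accuracy"` are hypothetical.

```nim
import std/tables

# Hypothetical sketch, not nimxla's implementation: average per-batch loss
# and accuracy over one epoch and record them in a stats table shaped like
# Trainer.stats (Table[string, seq[float]]). Key names are assumptions.
proc recordEpoch(stats: var Table[string, seq[float]];
                 losses, accs: openArray[float]): (float, float) =
  doAssert losses.len == accs.len and losses.len > 0
  var sumLoss, sumAcc = 0.0
  for i in 0 ..< losses.len:
    sumLoss += losses[i]
    sumAcc += accs[i]
  let n = float(losses.len)
  result = (sumLoss / n, sumAcc / n)
  # mgetOrPut inserts an empty seq the first time a key is seen, then appends
  stats.mgetOrPut("train_loss", newSeq[float]()).add(result[0])
  stats.mgetOrPut("train_accuracy", newSeq[float]()).add(result[1])

when isMainModule:
  var stats = initTable[string, seq[float]]()
  let (loss, acc) = recordEpoch(stats, [0.9, 0.7], [0.6, 0.8])
  echo loss, " ", acc, " ", stats
```

Keeping one seq per metric makes it straightforward to plot each metric against the epoch number afterwards.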
proc trainFunc(c: Client; model: Module; dtype: DataType; shape: seq[int]; lossFn: proc (yp, y: Node): Node): Executable {.raises: [Exception, ValueError, BuilderError, KeyError, XLAError], tags: [RootEffect], forbids: [].}
- Compile the training function for the given input data shape and loss function, which is applied to the model output.
proc trainNetwork[T: ElemType](t: var Trainer; train, test: var DataLoader; epochs: int; plot = false; checkpoint = ""; saveEvery = 20)
- Training run for the given number of epochs. If transform is set, it will be applied to each batch of training data. If checkpoint is set, a checkpoint file is written using this prefix.
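The interaction between checkpoint and saveEvery is not spelled out above. The sketch below shows one plausible schedule under stated assumptions: a checkpoint is written every `saveEvery` epochs and after the final epoch, and only when a prefix is given. The proc name `checkpointEpochs` is hypothetical, and it only reports which epochs would trigger a save rather than training anything.

```nim
# Hypothetical sketch of periodic checkpointing in an epoch loop. How
# nimxla's trainNetwork actually uses saveEvery is an assumption here:
# save every `saveEvery` epochs and after the final epoch.
proc checkpointEpochs(epochs: int; checkpoint = ""; saveEvery = 20): seq[int] =
  if checkpoint == "" or saveEvery <= 0:
    return  # no prefix given: never write checkpoints
  for epoch in 1 .. epochs:
    # ... train one epoch and evaluate on the test set here ...
    if epoch mod saveEvery == 0 or epoch == epochs:
      result.add epoch

when isMainModule:
  echo checkpointEpochs(50, "mnist")
```

Saving on the last epoch as well means the final weights are kept even when the epoch count is not a multiple of saveEvery.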