author     Louis Burda <quent.burda@gmail.com>  2022-08-09 02:00:04 +0200
committer  Louis Burda <quent.burda@gmail.com>  2022-08-09 15:53:23 +0200
commit     2e126f4c2a8ac88bb6ab7e4d7a6baf7d571773bd (patch)
tree       bb9f48702eaeb03f69a51998c942aed40324cc4c
download   mnist-c-2e126f4c2a8ac88bb6ab7e4d7a6baf7d571773bd.tar.gz
           mnist-c-2e126f4c2a8ac88bb6ab7e4d7a6baf7d571773bd.zip
Basic setup, working inference
-rw-r--r--  .gitignore |   3
-rw-r--r--  Makefile   |  11
-rw-r--r--  main.c     | 879
3 files changed, 893 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..99abf72
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+main
+.gdb_history
+.nn*
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..c325fd1
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,11 @@
+CFLAGS = -g
+LDLIBS = -lm -lncurses
+
+.PHONY: all clean
+
+all: main
+
+clean:
+ rm -f main
+
+main: main.c
diff --git a/main.c b/main.c
new file mode 100644
index 0000000..60c911b
--- /dev/null
+++ b/main.c
@@ -0,0 +1,879 @@
+#include <ncurses.h>
+
+#include <sys/random.h>
+#include <math.h>
+#include <err.h>
+#include <signal.h>
+#include <assert.h>
+#include <endian.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <string.h>
+#include <stdio.h>
+#include <stdint.h>
+
+#define ARRLEN(x) (sizeof(x)/sizeof((x)[0]))
+
+enum {
+ U8 = 0x08,
+ I8 = 0x09,
+ I16 = 0x0B,
+ I32 = 0x0C,
+ F32 = 0x0D,
+ F64 = 0x0E
+};
+
+enum {
+ IDENTITY,
+ SIGMOID,
+ SOFTMAX
+};
+
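+/*
+ * In-memory view of an IDX file (the format used by the MNIST dataset):
+ * a flat data buffer plus the element type and the size of each dimension.
+ */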
+struct idx {
+ void *data;
+ uint32_t *dim;
+ uint8_t dims;
+ uint8_t dtype;
+};
+
+struct layer_spec {
+ int activation;
+ bool has_bias;
+ size_t len;
+};
+
+struct layer {
+ int activation;
+ bool has_bias;
+ size_t len, nodes;
+ double *activity;
+ double *input;
+ double *derivs;
+};
+
+struct nn {
+ char *filepath;
+ struct layer *layer;
+ size_t layers;
+ struct layer *input, *output;
+ /* 3d matrix, [layer][source][target] */
+ double ***weights;
+ double ***deltas;
+};
+
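+/* Network shape: a 28x28 input layer (plus a bias node) connected directly
+ * to 10 sigmoid output nodes, one per digit. */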
+static const struct layer_spec layers[] = {
+ { IDENTITY, true, 28 * 28 },
+ { SIGMOID, false, 10 },
+ //{ IDENTITY, false, 10 },
+};
+
+static const uint8_t idx_dtype_size[0x100] = {
+ [U8] = 1,
+ [I8] = 1,
+ [I16] = 2,
+ [I32] = 4,
+ [F32] = 4,
+ [F64] = 8
+};
+
+static volatile sig_atomic_t quit = false;
+
+void
+sigint(int sig)
+{
+ quit = true;
+ printf("QUIT\n");
+}
+
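+/* Weights are stored big-endian on disk; convert doubles through their
+ * 64-bit representation. */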
+double
+dbl_be64toh(double d)
+{
+ uint64_t tmp;
+
+	memcpy(&tmp, &d, sizeof(tmp));
+	tmp = be64toh(tmp);
+	memcpy(&d, &tmp, sizeof(d));
+	return d;
+}
+
+double
+dbl_htobe64(double d)
+{
+ uint64_t tmp;
+
+	memcpy(&tmp, &d, sizeof(tmp));
+	tmp = htobe64(tmp);
+	memcpy(&d, &tmp, sizeof(d));
+	return d;
+}
+
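+/*
+ * Parse an IDX file: two zero magic bytes, a data-type byte, a byte giving
+ * the number of dimensions, one big-endian uint32 per dimension, then the
+ * raw data section.
+ */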
+void
+idx_load(struct idx *idx, FILE *file, const char *path)
+{
+ uint8_t header[2];
+ uint32_t count;
+ uint32_t *counts;
+ size_t size;
+ int i;
+
+ if (fread(header, 1, 2, file) != 2)
+ errx(1, "Missing idx header (%s)", path);
+
+ if (header[0] || header[1])
+ errx(1, "Invalid idx header (%s)", path);
+
+ if (fread(&idx->dtype, 1, 1, file) != 1)
+ errx(1, "Missing idx data type (%s)", path);
+
+ if (fread(&idx->dims, 1, 1, file) != 1)
+ errx(1, "Missing idx dims (%s)", path);
+
+ if (!idx_dtype_size[idx->dtype])
+ errx(1, "Invalid idx data type (%s)", path);
+
+ idx->dim = malloc(idx->dims * sizeof(uint32_t));
+ if (!idx->dim) err(1, "malloc");
+
+ size = 1;
+ for (i = 0; i < idx->dims; i++) {
+ if (fread(&count, 4, 1, file) != 1)
+ errx(1, "Missing %i. dimension size (%s)", i + 1, path);
+ idx->dim[i] = be32toh(count);
+ size *= idx->dim[i];
+ }
+
+	idx->data = malloc(size * idx_dtype_size[idx->dtype]);
+	if (!idx->data) err(1, "malloc");
+
+ if (fread(idx->data, idx_dtype_size[idx->dtype], size, file) != size)
+ errx(1, "Incomplete data section (%s)", path);
+}
+
+void
+idx_free(struct idx *idx)
+{
+ free(idx->data);
+ idx->data = NULL;
+ free(idx->dim);
+ idx->dim = NULL;
+}
+
+void
+idx_load_single(struct idx *idx, const char *path)
+{
+ FILE *file;
+
+ file = fopen(path, "r");
+ if (!file) err(1, "fopen (%s)", path);
+ idx_load(idx, file, path);
+ fclose(file);
+}
+
+void
+idx_load_images(struct idx *idx, const char *path)
+{
+ idx_load_single(idx, path);
+ assert(idx->dims == 3);
+ assert(idx->dim[1] == 28 && idx->dim[2] == 28);
+ assert(idx->dtype == U8);
+}
+
+void
+idx_load_labels(struct idx *idx, const char *path)
+{
+ idx_load_single(idx, path);
+ assert(idx->dims == 1);
+ assert(idx->dtype == U8);
+}
+
+void
+idx_save(struct idx *idx, FILE *file, const char *path)
+{
+ uint8_t header[2];
+ uint32_t count;
+ size_t size;
+ int i;
+
+ memset(header, 0, 2);
+ if (fwrite(&header, 1, 2, file) != 2)
+ err(1, "fwrite (%s)", path);
+
+ if (fwrite(&idx->dtype, 1, 1, file) != 1)
+ err(1, "fwrite (%s)", path);
+
+ if (fwrite(&idx->dims, 1, 1, file) != 1)
+ err(1, "fwrite (%s)", path);
+
+ size = 1;
+ for (i = 0; i < idx->dims; i++) {
+ count = htobe32(idx->dim[i]);
+ if (fwrite(&count, 4, 1, file) != 1)
+ err(1, "fwrite (%s)", path);
+ size *= idx->dim[i];
+ }
+
+ if (fwrite(idx->data, idx_dtype_size[idx->dtype], size, file) != size)
+ err(1, "fwrite (%s)", path);
+}
+
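+/*
+ * Allocate the per-layer activity/input/derivative buffers and the weight
+ * and delta matrices; weights[l][s][t] connects node s of layer l (bias
+ * node included among the sources) to node t of layer l+1. Only the len
+ * non-bias nodes of layer l+1 receive incoming weights.
+ */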
+void
+nn_init(struct nn *nn, const struct layer_spec *spec, size_t layers)
+{
+ int l, k;
+
+ nn->filepath = NULL;
+ nn->layers = layers;
+	nn->layer = malloc(sizeof(struct layer) * nn->layers);
+	if (!nn->layer) err(1, "malloc");
+
+ for (l = 0; l < nn->layers; l++) {
+ nn->layer[l].len = spec[l].len;
+ nn->layer[l].has_bias = spec[l].has_bias;
+ nn->layer[l].activation = spec[l].activation;
+
+ nn->layer[l].nodes = spec[l].len + spec[l].has_bias;
+
+ nn->layer[l].input = calloc(nn->layer[l].nodes, sizeof(double));
+ if (!nn->layer[l].input) err(1, "malloc");
+
+ nn->layer[l].activity = calloc(nn->layer[l].nodes, sizeof(double));
+ if (!nn->layer[l].activity) err(1, "malloc");
+
+ nn->layer[l].derivs = calloc(nn->layer[l].nodes, sizeof(double));
+ if (!nn->layer[l].derivs) err(1, "malloc");
+ }
+
+ nn->input = &nn->layer[0];
+ nn->output = &nn->layer[nn->layers - 1];
+
+ nn->deltas = malloc((nn->layers - 1) * sizeof(double *));
+ if (!nn->deltas) err(1, "malloc");
+
+ nn->weights = malloc((nn->layers - 1) * sizeof(double *));
+ if (!nn->weights) err(1, "malloc");
+
+ for (l = 0; l < nn->layers - 1; l++) {
+ nn->deltas[l] = malloc(nn->layer[l].nodes * sizeof(double *));
+ if (!nn->deltas[l]) err(1, "malloc");
+ for (k = 0; k < nn->layer[l].nodes; k++) {
+ nn->deltas[l][k] = calloc(nn->layer[l+1].len,
+ sizeof(double));
+ if (!nn->deltas[l][k]) err(1, "malloc");
+ }
+
+ nn->weights[l] = malloc(nn->layer[l].nodes * sizeof(double *));
+ if (!nn->weights[l]) err(1, "malloc");
+ for (k = 0; k < nn->layer[l].nodes; k++) {
+ nn->weights[l][k] = calloc(nn->layer[l+1].len,
+ sizeof(double));
+ if (!nn->weights[l][k]) err(1, "malloc");
+ }
+ }
+}
+
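+/*
+ * Randomize the initial weights: uniform in [-0.5, 0.5], scaled down by the
+ * fan-in (number of source nodes, bias included), using getrandom() as the
+ * entropy source.
+ */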
+void
+nn_gen(struct nn *nn)
+{
+ int l, s, t;
+ uint32_t val;
+
+ /* initial weights */
+ for (l = 0; l < nn->layers - 1; l++) {
+ for (s = 0; s < nn->layer[l].nodes; s++) {
+			for (t = 0; t < nn->layer[l+1].len; t++) {
+ if (getrandom(&val, 4, 0) != 4)
+ err(1, "getrandom");
+ nn->weights[l][s][t] =
+ ((val / (double) 0xFFFFFFFF) - 0.5)
+ / nn->layer[l].nodes;
+ }
+ }
+ }
+}
+
+void
+nn_load(struct nn *nn, const char *path)
+{
+ FILE *file;
+ struct idx idx;
+ double weight;
+ int l, s, t;
+ int snodes;
+
+ nn->filepath = strdup(path);
+ if (!nn->filepath) err(1, "strdup");
+
+ file = fopen(path, "r");
+ if (!file) err(1, "fopen (%s)", path);
+
+ /* load weights */
+ for (l = 0; l < nn->layers - 1; l++) {
+ idx_load(&idx, file, path);
+ assert(idx.dtype == F64);
+ assert(idx.dims == 2);
+ assert(idx.dim[0] == nn->layer[l].nodes);
+		assert(idx.dim[1] == nn->layer[l+1].len);
+ snodes = nn->layer[l].nodes;
+ for (s = 0; s < nn->layer[l].nodes; s++) {
+			for (t = 0; t < nn->layer[l+1].len; t++) {
+ weight = ((double*)idx.data)[t * snodes + s];
+ nn->weights[l][s][t] = dbl_be64toh(weight);
+ }
+ }
+ idx_free(&idx);
+ }
+
+ fclose(file);
+}
+
+void
+nn_save(struct nn *nn, const char *path)
+{
+ FILE *file;
+ struct idx idx;
+ double weight;
+ int l, s, t;
+ int snodes;
+
+ file = fopen(path, "w+");
+ if (!file) err(1, "fopen (%s)", path);
+
+ idx.dims = 2;
+ idx.dim = malloc(idx.dims * sizeof(uint32_t));
+ if (!idx.dim) err(1, "malloc");
+ idx.dtype = F64;
+
+ /* save weights */
+ for (l = 0; l < nn->layers - 1; l++) {
+		idx.data = malloc(nn->layer[l].nodes
+			* nn->layer[l+1].len * sizeof(double));
+ if (!idx.data) err(1, "malloc");
+ snodes = nn->layer[l].nodes;
+ for (s = 0; s < nn->layer[l].nodes; s++) {
+			for (t = 0; t < nn->layer[l+1].len; t++) {
+ weight = dbl_htobe64(nn->weights[l][s][t]);
+ ((double *)idx.data)[t * snodes + s] = weight;
+ }
+ }
+ idx.dim[0] = nn->layer[l].nodes;
+		idx.dim[1] = nn->layer[l+1].len;
+ idx_save(&idx, file, path);
+ free(idx.data);
+ }
+
+ free(idx.dim);
+ fclose(file);
+}
+
+void
+nn_free(struct nn *nn)
+{
+ int l, k;
+
+ free(nn->filepath);
+
+ for (l = 0; l < nn->layers; l++) {
+ free(nn->layer[l].derivs);
+ free(nn->layer[l].activity);
+ free(nn->layer[l].input);
+ if (l < nn->layers - 1) {
+ for (k = 0; k < nn->layer[l].nodes; k++) {
+ free(nn->deltas[l][k]);
+ free(nn->weights[l][k]);
+ }
+ free(nn->deltas[l]);
+ free(nn->weights[l]);
+ }
+ }
+
+	free(nn->weights);
+	free(nn->deltas);
+	free(nn->layer);
+}
+
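+/*
+ * Forward one layer: z_t = sum_s a_s * w[l][s][t] over all source nodes
+ * (including the bias), then apply the target layer's activation to get
+ * a_t. Softmax subtracts the maximum input first for numerical stability.
+ */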
+void
+nn_fwdprop_layer(struct nn *nn, int l)
+{
+ struct layer *sl, *tl;
+ double expsum, weight, max;
+ int s, t;
+
+ sl = &nn->layer[l];
+ tl = &nn->layer[l+1];
+
+ if (tl->has_bias)
+ tl->activity[tl->len] = 1.0;
+ for (t = 0; t < tl->len; t++) {
+ tl->input[t] = 0;
+ for (s = 0; s < sl->nodes; s++) {
+ tl->input[t] += sl->activity[s]
+ * nn->weights[l][s][t];
+ }
+ }
+
+ switch (tl->activation) {
+ case IDENTITY:
+ for (t = 0; t < tl->len; t++)
+ tl->activity[t] = tl->input[t];
+ break;
+ case SIGMOID:
+ for (t = 0; t < tl->len; t++)
+ tl->activity[t] = 1 / (1 + exp(-tl->input[t]));
+ break;
+ case SOFTMAX:
+ max = tl->input[0];
+ for (t = 0; t < tl->len; t++)
+ max = tl->input[t] > max ? tl->input[t] : max;
+ expsum = 0;
+ for (t = 0; t < tl->len; t++)
+ expsum += exp(tl->input[t] - max);
+ for (t = 0; t < tl->len; t++)
+ tl->activity[t] = exp(tl->input[t] - max) / expsum;
+ break;
+ default:
+ errx(1, "Unknown activation function (%i)", tl->activation);
+ };
+}
+
+
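+/* Binarize the image (any nonzero pixel becomes 1.0) into the input layer
+ * and run every layer forward. */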
+void
+nn_fwdprop(struct nn *nn, uint8_t *image)
+{
+ int i, l;
+
+	if (nn->layer[0].has_bias)
+		nn->layer[0].activity[nn->layer[0].len] = 1.0;
+ for (i = 0; i < nn->layer[0].len; i++)
+ nn->layer[0].activity[i] = image[i] ? 1.0 : 0.0;
+
+ for (l = 0; l < nn->layers - 1; l++)
+ nn_fwdprop_layer(nn, l);
+}
+
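+/*
+ * Propagate error derivatives back one layer: apply the derivative of
+ * layer l's activation to the stored dE/da values and accumulate the
+ * resulting contributions into the previous layer's derivs through the
+ * connecting weights.
+ */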
+void
+nn_backprop_layer(struct nn *nn, int l)
+{
+ struct layer *sl, *tl;
+ int s, t, i;
+ double sum;
+
+ sl = &nn->layer[l-1];
+ tl = &nn->layer[l];
+
+ for (s = 0; s < sl->nodes; s++)
+ sl->derivs[s] = 0;
+
+ switch (nn->layer[l].activation) {
+ case IDENTITY:
+ for (t = 0; t < tl->len; t++) {
+ for (s = 0; s < sl->nodes; s++) {
+ sl->derivs[s] += tl->derivs[t]
+ * nn->weights[l-1][s][t];
+ }
+ }
+ break;
+ case SIGMOID:
+ for (t = 0; t < tl->len; t++) {
+ /* derivative of activation function */
+ tl->derivs[t] *= tl->activity[t] * (1 - tl->activity[t]);
+ for (s = 0; s < sl->nodes; s++) {
+ sl->derivs[s] += tl->derivs[t]
+ * nn->weights[l-1][s][t];
+ }
+ }
+ break;
+ case SOFTMAX:
+ /* derivative of softmax function
+ * (each input i influences activity t) */
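+		/* d a_i / d z_t = a_i * ((i == t) - a_t) */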
+		for (t = 0; t < tl->len; t++) {
+			sum = 0;
+			for (i = 0; i < tl->len; i++) {
+ sum += tl->derivs[i] * tl->activity[i]
+ * ((t == i) - tl->activity[t]);
+ }
+ for (s = 0; s < sl->nodes; s++)
+ sl->derivs[s] += sum * nn->weights[l-1][s][t];
+ }
+ break;
+ }
+}
+
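+/*
+ * Seed the output layer's derivs with dE/da for the squared-error loss,
+ * then run the layer-wise backward pass down to the layer above the input.
+ */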
+void
+nn_backprop(struct nn *nn, uint8_t label)
+{
+ int i, l;
+
+ l = nn->layers - 1;
+ for (i = 0; i < nn->layer[l].len; i++) {
+		/* derivative of the error 1/2 * (out - label)^2
+		 * with respect to out: out - label */
+ nn->layer[l].derivs[i] = nn->layer[l].activity[i]
+ - (label == i ? 1.0 : 0.0);
+ }
+
+ /* generate derivs of err / z_i per node */
+ for (l = nn->layers - 1; l >= 1; l--)
+ nn_backprop_layer(nn, l);
+}
+
+void
+nn_debug(struct nn *nn)
+{
+ int l, s, t;
+
+ printf("WEIGHTS:\n");
+ for (l = 0; l < nn->layers - 1; l++) {
+ printf("LAYER %i\n", l);
+ for (s = 0; s < nn->layer[l].nodes; s++) {
+ for (t = 0; t < nn->layer[l+1].len; t++) {
+ printf("%0.3F ", nn->weights[l][s][t]);
+ }
+ printf("\n");
+ }
+ printf("\n");
+ }
+
+ printf("DELTAS:\n");
+ for (l = 0; l < nn->layers - 1; l++) {
+ printf("LAYER %i\n", l);
+ for (s = 0; s < nn->layer[l].nodes; s++) {
+ for (t = 0; t < nn->layer[l+1].len; t++) {
+ printf("%0.3F ", nn->deltas[l][s][t]);
+ }
+ printf("\n");
+ }
+ printf("\n");
+ }
+}
+
+void
+nn_debug_prediction(struct nn *nn, uint8_t label)
+{
+ int k;
+
+ printf("%i : ", label);
+ for (k = 0; k < nn->output->len; k++)
+ printf("%2.0F ", 100 * nn->output->activity[k]);
+ printf("\n");
+}
+
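+/* Map a weight to a green (positive) or red (negative) 256-color shade,
+ * brighter for larger magnitudes; near-zero weights print as blanks. */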
+void
+print_weight_pix(double weight)
+{
+ int color;
+
+ if (weight >= 0) {
+ if (weight < 0.01) {
+ color = 22;
+ } else if (weight < 0.1) {
+ color = 28;
+ } else if (weight < 1) {
+ color = 34;
+ } else if (weight < 10) {
+ color = 40;
+ } else {
+ color = 46;
+ }
+ } else {
+ if (weight > -0.01) {
+ color = 52;
+ } else if (weight > -0.1) {
+ color = 88;
+ } else if (weight > -1) {
+ color = 124;
+ } else if (weight > -10) {
+ color = 160;
+ } else {
+ color = 196;
+ }
+ }
+	printf("\x1b[38;5;%im", color);
+ printf("%s", fabs(weight) >= 0.0001 ? "▮" : " ");
+ printf("\x1b[0m");
+}
+
+void
+nn_dump(struct nn *nn)
+{
+ int l, s, t, x, y;
+ double weight;
+
+ printf("\n");
+ for (t = 0; t < nn->layer[1].len; t++) {
+ printf("INPUT -> HIDDEN %i\n", t);
+ for (y = 0; y < 28; y++) {
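+		/* the extra cell on the last row shows the weight
+		 * from the input layer's bias node */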
+ for (x = 0; x < 28 + (y == 27); x++) {
+ weight = nn->weights[0][y * 28 + x][t];
+ print_weight_pix(weight);
+ }
+ printf("\n");
+ }
+ printf("\n");
+ }
+
+ //printf("HIDDEN -> OUTPUT\n");
+ //for (t = 0; t < nn->layer[2].len; t++) {
+ // for (s = 0; s < nn->layer[1].nodes; s++) {
+ // weight = nn->weights[1][s][t];
+ // print_weight_pix(weight);
+ // }
+ // printf("\n");
+ //}
+
+}
+
+void
+nn_check_error(struct nn *nn, uint8_t *image, uint8_t *label)
+{
+ int i;
+
+ printf("ERROR:\n");
+ nn_fwdprop(nn, image);
+ for (i = 0; i < nn->output->len; i++) {
+ printf("OUT %i: %F %F\n", i,
+ nn->output->activity[i],
+ fabs(nn->output->activity[i]
+ - (*label == i ? 1.0 : 0.0)));
+ }
+}
+
+void
+nn_reset_deltas(struct nn *nn)
+{
+ int l, s, t;
+
+ for (l = 0; l < nn->layers - 1; l++) {
+ for (s = 0; s < nn->layer[l].nodes; s++) {
+ for (t = 0; t < nn->layer[l+1].len; t++)
+ nn->deltas[l][s][t] = 0;
+ }
+ }
+}
+
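+/*
+ * Accumulate the negative gradient of the error for each weight, scaled by
+ * the learning rate: dE/dw[s][t] = dE/dz_t * a_s of the previous layer.
+ */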
+void
+nn_update_deltas(struct nn *nn, double learning_rate)
+{
+ int l, s, t;
+ double gradw;
+
+ /* generate deltas for weights from err / z_i */
+ for (l = nn->layers - 1; l >= 1; l--) {
+ for (t = 0; t < nn->layer[l].len; t++) {
+ for (s = 0; s < nn->layer[l-1].nodes; s++) {
+ gradw = - nn->layer[l].derivs[t]
+ * nn->layer[l-1].activity[s];
+ nn->deltas[l-1][s][t] += gradw * learning_rate;
+ }
+ }
+ }
+}
+
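+/* Add the accumulated deltas to the weights, averaged over the batch. */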
+void
+nn_apply_deltas(struct nn *nn, size_t size)
+{
+ int l, s, t;
+
+ for (l = nn->layers - 1; l >= 1; l--) {
+ for (t = 0; t < nn->layer[l].len; t++) {
+ for (s = 0; s < nn->layer[l-1].nodes; s++) {
+ nn->weights[l-1][s][t] +=
+ nn->deltas[l-1][s][t] / size;
+ assert(!isnan(nn->weights[l-1][s][t]));
+ }
+ }
+ }
+}
+
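+/*
+ * Run the network over the test set and report the fraction of images
+ * whose highest-activity output node matches the label.
+ */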
+double
+nn_test(struct nn *nn)
+{
+ struct idx images;
+ struct idx labels;
+ size_t hits, total;
+ int i, k, maxi;
+ double max;
+
+ idx_load_images(&images, "data/test-images.idx");
+ idx_load_labels(&labels, "data/test-labels.idx");
+
+ total = hits = 0;
+ for (i = 0; i < images.dim[0]; i++) {
+ nn_fwdprop(nn, images.data + i * nn->input->len);
+ maxi = -1;
+ for (k = 0; k < nn->output->len; k++) {
+ if (maxi < 0 || nn->output->activity[k] > max) {
+ max = nn->output->activity[k];
+ maxi = k;
+ }
+ }
+ if (maxi == *(uint8_t*)(labels.data + i))
+ hits++;
+ total++;
+ }
+
+ idx_free(&images);
+ idx_free(&labels);
+
+	return (double) hits / total;
+}
+
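+/*
+ * One mini-batch of stochastic gradient descent: draw batch_size random
+ * samples (with replacement), accumulate the weight deltas from each,
+ * apply them once, and return the average squared error over the batch.
+ */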
+double
+nn_batch(struct nn *nn, struct idx *images, struct idx *labels,
+ size_t batch_size, double learning_rate)
+{
+ double lerror, error;
+ uint32_t idx;
+ uint8_t label;
+ size_t i, k;
+
+ nn_reset_deltas(nn);
+
+ error = 0;
+ for (i = 0; i < batch_size; i++) {
+ if (getrandom(&idx, 4, 0) != 4)
+ err(1, "getrandom");
+ idx = idx % images->dim[0];
+ nn_fwdprop(nn, images->data + idx * nn->layer[0].len);
+ label = *(uint8_t *)(labels->data + idx);
+		for (k = 0; k < nn->output->len; k++) {
+ lerror = nn->output->activity[k]
+ - (k == label ? 1.0 : 0.0);
+ error += 0.5 * lerror * lerror;
+ }
+ nn_debug_prediction(nn, label);
+ nn_backprop(nn, label);
+ nn_update_deltas(nn, learning_rate);
+ }
+
+ nn_apply_deltas(nn, batch_size);
+
+ return error / batch_size;
+}
+
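+/* Train for the given number of epochs on the MNIST training set, saving
+ * the weights after every batch and bailing out on SIGINT. */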
+void
+nn_train(struct nn *nn, size_t epochs,
+ size_t batch_size, double learning_rate)
+{
+ struct idx images;
+ struct idx labels;
+ double error;
+ int epoch, i;
+
+ idx_load_images(&images, "data/train-images.idx");
+ idx_load_labels(&labels, "data/train-labels.idx");
+
+ for (epoch = 0; epoch < epochs; epoch++) {
+ /* TODO: ensure using all images in one epoch */
+ for (i = 0; i < images.dim[0] / batch_size; i++) {
+ error = nn_batch(nn, &images, &labels,
+ batch_size, learning_rate);
+ if (i % 1 == 0) {
+ nn_debug(nn);
+ // nn_check_error(nn, images.data, labels.data);
+ //nn_dump(nn);
+				printf("Batch %i / %zu => %2.5F\n", i + 1,
+ images.dim[0] / batch_size, error);
+ }
+ nn_save(nn, nn->filepath);
+ if (quit) exit(1);
+ }
+ }
+
+ idx_free(&images);
+ idx_free(&labels);
+}
+
+void
+nn_trainvis(struct nn *nn, size_t epochs,
+ size_t batch_size, double learning_rate)
+{
+ /* display weights visually after each batch
+ * and adjust batch frequency via UP / DOWN */
+
+}
+
+void
+nn_predict(struct nn *nn)
+{
+ struct idx images, labels;
+
+ /* gui interface to draw input and show prediction */
+
+
+
+
+}
+
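+/* Print one training or test image as ASCII art together with its label. */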
+void
+dump_sample(const char *set, size_t index)
+{
+ struct idx images, labels;
+ uint8_t pix, *image;
+ int x, y;
+
+ if (!strcmp(set, "train")) {
+ idx_load_images(&images, "data/train-images.idx");
+ idx_load_labels(&labels, "data/train-labels.idx");
+ } else if (!strcmp(set, "test")) {
+ idx_load_images(&images, "data/test-images.idx");
+ idx_load_labels(&labels, "data/test-labels.idx");
+ } else {
+ errx(1, "Unknown dataset (%s)", set);
+ }
+
+ image = images.data + 28 * 28 * index;
+ assert(index < images.dim[0]);
+ for (y = 0; y < 28; y++) {
+ for (x = 0; x < 28; x++) {
+ pix = image[y * 28 + x];
+ printf("%s", pix ? "▮" : " ");
+ }
+ printf("\n");
+ }
+
+ printf("Label: %i\n", *(uint8_t *)(labels.data + index));
+
+	idx_free(&images);
+	idx_free(&labels);
+}
+
+int
+main(int argc, const char **argv)
+{
+ struct nn nn;
+
+ signal(SIGINT, sigint);
+
+ if (argc == 2 && !strcmp(argv[1], "gen")) {
+ nn_init(&nn, layers, ARRLEN(layers));
+ nn_gen(&nn);
+ nn_save(&nn, ".nn");
+ nn_free(&nn);
+ } else if (argc == 2 && !strcmp(argv[1], "train")) {
+ nn_init(&nn, layers, ARRLEN(layers));
+ nn_load(&nn, ".nn");
+ nn_train(&nn, 1, 10, 0.01);
+ nn_save(&nn, ".nn");
+ nn_free(&nn);
+ } else if (argc == 2 && !strcmp(argv[1], "trainvis")) {
+ nn_init(&nn, layers, ARRLEN(layers));
+ nn_load(&nn, ".nn");
+ nn_trainvis(&nn, 1, 10, 0.02);
+ nn_save(&nn, ".nn");
+ nn_free(&nn);
+ } else if (argc == 2 && !strcmp(argv[1], "predict")) {
+ nn_init(&nn, layers, ARRLEN(layers));
+ nn_load(&nn, ".nn");
+ nn_predict(&nn);
+ nn_save(&nn, ".nn");
+ nn_free(&nn);
+ } else if (argc == 2 && !strcmp(argv[1], "test")) {
+ nn_init(&nn, layers, ARRLEN(layers));
+ nn_load(&nn, ".nn");
+ printf("Accuracy: %F\n", nn_test(&nn));
+ nn_free(&nn);
+ } else if (argc == 2 && !strcmp(argv[1], "dump")) {
+ nn_init(&nn, layers, ARRLEN(layers));
+ nn_load(&nn, ".nn");
+ nn_dump(&nn);
+ nn_free(&nn);
+ } else if (argc == 4 && !strcmp(argv[1], "sample")) {
+ dump_sample(argv[2], atoi(argv[3]));
+ } else {
+		printf("USAGE: main (gen|train|trainvis|predict|test|dump|sample) [ARGS..]\n");
+ }
+}
+