mnist-c

MNIST digit recognition neural network in C
git clone https://git.sinitax.com/sinitax/mnist-c
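
Usage (inferred from main() at the bottom of this file): gen creates and saves
randomly initialized weights to .nn, train and trainvis fit them on the MNIST
files under data/, test prints accuracy on the test set, predict opens an
interactive ncurses pad for drawing digits, dump prints the weight maps, and
sample <train|test> <index> prints a single dataset image.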

main.c (21233B)


      1#include <ncurses.h>
      2
      3#include <sys/random.h>
      4#include <math.h>
      5#include <err.h>
      6#include <signal.h>
      7#include <assert.h>
      8#include <endian.h>
      9#include <stdlib.h>
     10#include <stdbool.h>
     11#include <string.h>
     12#include <stdio.h>
     13#include <stdint.h>
     14
     15#define ARRLEN(x) (sizeof(x)/sizeof((x)[0]))
     16
     17enum {
     18	U8 = 0x08,
     19	I8 = 0x09,
     20	I16 = 0x0B,
     21	I32 = 0x0C,
     22	F32 = 0x0D,
     23	F64 = 0x0E
     24};
     25
     26enum {
     27	IDENTITY,
     28	SIGMOID,
     29	SOFTMAX
     30};
     31
     32struct idx {
     33	void *data;
     34	uint32_t *dim;
     35	uint8_t dims;
     36	uint8_t dtype;
     37};
     38
     39struct layer_spec {
     40	int activation;
     41	bool has_bias;
     42	size_t len;
     43};
     44
     45struct layer {
     46	int activation;
     47	bool has_bias;
     48	size_t len, nodes;
     49	double *activity;
     50	double *input;
     51	double *derivs;
     52};
     53
     54struct nn {
     55	char *filepath;
     56	struct layer *layer;
     57	size_t layers;
     58	struct layer *input, *output;
     59	/* 3d matrix, [layer][source][target] */
     60	double ***weights;
     61	double ***deltas;
     62};
     63
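        /*
         * Network shape: 28x28 = 784 input nodes plus a bias node, one hidden
         * layer of 20 sigmoid nodes plus a bias node, and 10 softmax output
         * nodes, one per digit class.
         */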
     64static const struct layer_spec layers[] = {
     65	{ IDENTITY, true, 28 * 28 },
     66	{ SIGMOID, true, 20 },
     67	{ SOFTMAX, false, 10 },
     68};
     69
     70static const uint8_t idx_dtype_size[0x100] = {
     71	[U8] = 1,
     72	[I8] = 1,
     73	[I16] = 2,
     74	[I32] = 4,
     75	[F32] = 4,
     76	[F64] = 8
     77};
     78
      79static volatile sig_atomic_t quit = false;
     80
     81void
     82train_stop(int sig)
     83{
     84	quit = true;
     85	printf("QUIT\n");
     86}
     87
     88double
     89dbl_be64toh(double d)
     90{
      91	/* union punning avoids the strict-aliasing issues of pointer casts */
      92	union { uint64_t u; double d; } v = { .d = d };
      93
      94	v.u = be64toh(v.u);
      95	return v.d;
     96}
     97
     98double
     99dbl_htobe64(double d)
    100{
     101	union { uint64_t u; double d; } v = { .d = d };
     102
     103	v.u = htobe64(v.u);
     104	return v.d;
    106}
    107
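        /*
         * IDX file layout (the format used by the MNIST distribution): two
         * zero magic bytes, a data-type byte, a byte with the number of
         * dimensions, one big-endian uint32 size per dimension, then the raw
         * data in row-major order.
         */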
    108void
    109idx_load(struct idx *idx, FILE *file, const char *path)
    110{
    111	uint8_t header[2];
    112	uint32_t count;
    113	uint32_t *counts;
    114	size_t size;
    115	int i;
    116
    117	if (fread(header, 1, 2, file) != 2)
    118		errx(1, "Missing idx header (%s)", path);
    119
    120	if (header[0] || header[1])
    121		errx(1, "Invalid idx header (%s)", path);
    122
    123	if (fread(&idx->dtype, 1, 1, file) != 1)
    124		errx(1, "Missing idx data type (%s)", path);
    125
    126	if (fread(&idx->dims, 1, 1, file) != 1)
    127		errx(1, "Missing idx dims (%s)", path);
    128
    129	if (!idx_dtype_size[idx->dtype])
    130		errx(1, "Invalid idx data type (%s)", path);
    131
    132	idx->dim = malloc(idx->dims * sizeof(uint32_t));
    133	if (!idx->dim) err(1, "malloc");
    134
    135	size = 1;
    136	for (i = 0; i < idx->dims; i++) {
    137		if (fread(&count, 4, 1, file) != 1)
    138			errx(1, "Missing %i. dimension size (%s)", i + 1, path);
    139		idx->dim[i] = be32toh(count);
    140		size *= idx->dim[i];
    141	}
    142
     143	idx->data = malloc(size * idx_dtype_size[idx->dtype]);
     144	if (!idx->data) err(1, "malloc");
    145
    146	if (fread(idx->data, idx_dtype_size[idx->dtype], size, file) != size)
    147		errx(1, "Incomplete data section (%s)", path);
    148}
    149
    150void
    151idx_free(struct idx *idx)
    152{
    153	free(idx->data);
    154	idx->data = NULL;
    155	free(idx->dim);
    156	idx->dim = NULL;
    157}
    158
    159void
    160idx_load_single(struct idx *idx, const char *path)
    161{
    162	FILE *file;
    163
    164	file = fopen(path, "r");
    165	if (!file) err(1, "fopen (%s)", path);
    166	idx_load(idx, file, path);
    167	fclose(file);
    168}
    169
    170void
    171idx_load_images(struct idx *idx, const char *path)
    172{
    173	idx_load_single(idx, path);
    174	assert(idx->dims == 3);
    175	assert(idx->dim[1] == 28 && idx->dim[2] == 28);
    176	assert(idx->dtype == U8);
    177}
    178
    179void
    180idx_load_labels(struct idx *idx, const char *path)
    181{
    182	idx_load_single(idx, path);
    183	assert(idx->dims == 1);
    184	assert(idx->dtype == U8);
    185}
    186
    187void
    188idx_save(struct idx *idx, FILE *file, const char *path)
    189{
    190	uint8_t header[2];
    191	uint32_t count;
    192	size_t size;
    193	int i;
    194
    195	memset(header, 0, 2);
    196	if (fwrite(&header, 1, 2, file) != 2)
    197		err(1, "fwrite (%s)", path);
    198
    199	if (fwrite(&idx->dtype, 1, 1, file) != 1)
    200		err(1, "fwrite (%s)", path);
    201
    202	if (fwrite(&idx->dims, 1, 1, file) != 1)
    203		err(1, "fwrite (%s)", path);
    204
    205	size = 1;
    206	for (i = 0; i < idx->dims; i++) {
    207		count = htobe32(idx->dim[i]);
    208		if (fwrite(&count, 4, 1, file) != 1)
    209			err(1, "fwrite (%s)", path);
    210		size *= idx->dim[i];
    211	}
    212
    213	if (fwrite(idx->data, idx_dtype_size[idx->dtype], size, file) != size)
    214		err(1, "fwrite (%s)", path);
    215}
    216
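        /*
         * Per layer, nodes = len + has_bias: layers with a bias get one extra
         * node whose activity is pinned to 1.0 during forward propagation.
         * weights[l] and deltas[l] have one row per node of layer l and one
         * column per (non-bias) node of layer l+1.
         */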
    217void
    218nn_init(struct nn *nn, const struct layer_spec *spec, size_t layers)
    219{
    220	int l, k;
    221
    222	nn->filepath = NULL;
    223	nn->layers = layers;
     224	nn->layer = malloc(sizeof(struct layer) * nn->layers);
        	if (!nn->layer) err(1, "malloc");
    225
    226	for (l = 0; l < nn->layers; l++) {
    227		nn->layer[l].len = spec[l].len;
    228		nn->layer[l].has_bias = spec[l].has_bias;
    229		nn->layer[l].activation = spec[l].activation;
    230
    231		nn->layer[l].nodes = spec[l].len + spec[l].has_bias;
    232
    233		nn->layer[l].input = calloc(nn->layer[l].nodes, sizeof(double));
    234		if (!nn->layer[l].input) err(1, "malloc");
    235
    236		nn->layer[l].activity = calloc(nn->layer[l].nodes, sizeof(double));
    237		if (!nn->layer[l].activity) err(1, "malloc");
    238
    239		nn->layer[l].derivs = calloc(nn->layer[l].nodes, sizeof(double));
    240		if (!nn->layer[l].derivs) err(1, "malloc");
    241	}
    242
    243	nn->input = &nn->layer[0];
    244	nn->output = &nn->layer[nn->layers - 1];
    245
    246	nn->deltas = malloc((nn->layers - 1) * sizeof(double *));
    247	if (!nn->deltas) err(1, "malloc");
    248
    249	nn->weights = malloc((nn->layers - 1) * sizeof(double *));
    250	if (!nn->weights) err(1, "malloc");
    251
    252	for (l = 0; l < nn->layers - 1; l++) {
    253		nn->deltas[l] = malloc(nn->layer[l].nodes * sizeof(double *));
    254		if (!nn->deltas[l]) err(1, "malloc");
    255		for (k = 0; k < nn->layer[l].nodes; k++) {
    256			nn->deltas[l][k] = calloc(nn->layer[l+1].len,
    257				sizeof(double));
    258			if (!nn->deltas[l][k]) err(1, "malloc");
    259		}
    260
    261		nn->weights[l] = malloc(nn->layer[l].nodes * sizeof(double *));
    262		if (!nn->weights[l]) err(1, "malloc");
    263		for (k = 0; k < nn->layer[l].nodes; k++) {
    264			nn->weights[l][k] = calloc(nn->layer[l+1].len,
    265				sizeof(double));
    266			if (!nn->weights[l][k]) err(1, "malloc");
    267		}
    268	}
    269}
    270
    271void
    272nn_gen(struct nn *nn)
    273{
    274	int l, s, t;
    275	uint32_t val;
    276
     277	/* initial weights: uniform in [-0.5, 0.5], scaled by the source node count (fan-in) */
    278	for (l = 0; l < nn->layers - 1; l++) {
    279		for (s = 0; s < nn->layer[l].nodes; s++) {
    280			for (t = 0; t < nn->layer[l+1].len; t++) {
    281				if (getrandom(&val, 4, 0) != 4)
    282					err(1, "getrandom");
    283				nn->weights[l][s][t] =
    284					((val / (double) 0xFFFFFFFF) - 0.5)
    285					/ nn->layer[l].nodes;
    286			}
    287		}
    288	}
    289}
    290
    291void
    292nn_load(struct nn *nn, const char *path)
    293{
    294	FILE *file;
    295	struct idx idx;
    296	double weight;
    297	int l, s, t;
    298	int snodes;
    299
    300	nn->filepath = strdup(path);
    301	if (!nn->filepath) err(1, "strdup");
    302
    303	file = fopen(path, "r");
    304	if (!file) err(1, "fopen (%s)", path);
    305
    306	/* load weights */
    307	for (l = 0; l < nn->layers - 1; l++) {
    308		idx_load(&idx, file, path);
    309		assert(idx.dtype == F64);
    310		assert(idx.dims == 2);
    311		assert(idx.dim[0] == nn->layer[l].nodes);
    312		assert(idx.dim[1] == nn->layer[l+1].len);
    313		snodes = nn->layer[l].nodes;
    314		for (s = 0; s < nn->layer[l].nodes; s++) {
    315			for (t = 0; t < nn->layer[l+1].len; t++) {
    316				weight = ((double*)idx.data)[t * snodes + s];
    317				nn->weights[l][s][t] = dbl_be64toh(weight);
    318			}
    319		}
    320		idx_free(&idx);
    321	}
    322
    323	fclose(file);
    324}
    325
    326void
    327nn_save(struct nn *nn, const char *path)
    328{
    329	FILE *file;
    330	struct idx idx;
    331	double weight;
    332	int l, s, t;
    333	int snodes;
    334
    335	file = fopen(path, "w+");
    336	if (!file) err(1, "fopen (%s)", path);
    337
    338	idx.dims = 2;
    339	idx.dim = malloc(idx.dims * sizeof(uint32_t));
    340	if (!idx.dim) err(1, "malloc");
    341	idx.dtype = F64;
    342
    343	/* save weights */
    344	for (l = 0; l < nn->layers - 1; l++) {
    345		idx.data = malloc(nn->layer[l].nodes
    346			* nn->layer[l+1].len * sizeof(double));
    347		if (!idx.data) err(1, "malloc");
    348		snodes = nn->layer[l].nodes;
    349		for (s = 0; s < nn->layer[l].nodes; s++) {
    350			for (t = 0; t < nn->layer[l+1].len; t++) {
    351				weight = dbl_htobe64(nn->weights[l][s][t]);
    352				((double *)idx.data)[t * snodes + s] = weight;
    353			}
    354		}
    355		idx.dim[0] = nn->layer[l].nodes;
    356		idx.dim[1] = nn->layer[l+1].len;
    357		idx_save(&idx, file, path);
    358		free(idx.data);
    359	}
    360
    361	free(idx.dim);
    362	fclose(file);
    363}
    364
    365void
    366nn_free(struct nn *nn)
    367{
    368	int l, k;
    369
    370	free(nn->filepath);
    371
    372	for (l = 0; l < nn->layers; l++) {
    373		free(nn->layer[l].derivs);
    374		free(nn->layer[l].activity);
    375		free(nn->layer[l].input);
    376		if (l < nn->layers - 1) {
    377			for (k = 0; k < nn->layer[l].nodes; k++) {
    378				free(nn->deltas[l][k]);
    379				free(nn->weights[l][k]);
    380			}
    381			free(nn->deltas[l]);
    382			free(nn->weights[l]);
    383		}
    384	}
    385
    386	free(nn->layer);
    387	free(nn->weights);
    388	free(nn->deltas);
    389}
    390
    391void
    392nn_fwdprop_layer(struct nn *nn, int l)
    393{
    394	struct layer *sl, *tl;
    395	double expsum, weight, max;
    396	int s, t;
    397
    398	sl = &nn->layer[l];
    399	tl = &nn->layer[l+1];
    400
    401	if (tl->has_bias)
    402		tl->activity[tl->len] = 1.0;
    403	for (t = 0; t < tl->len; t++) {
    404		tl->input[t] = 0;
    405		for (s = 0; s < sl->nodes; s++) {
    406			tl->input[t] += sl->activity[s]
    407				* nn->weights[l][s][t];
    408		}
    409	}
    410
    411	switch (tl->activation) {
    412	case IDENTITY:
    413		for (t = 0; t < tl->len; t++)
    414			tl->activity[t] = tl->input[t];
    415		break;
    416	case SIGMOID:
    417		for (t = 0; t < tl->len; t++)
    418			tl->activity[t] = 1 / (1 + exp(-tl->input[t]));
    419		break;
    420	case SOFTMAX:
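        		/* softmax(z)_t = exp(z_t) / sum_j exp(z_j); subtracting the
        		 * largest input first leaves the result unchanged but keeps
        		 * exp() from overflowing */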
    421		max = tl->input[0];
    422		for (t = 0; t < tl->len; t++)
    423			max = tl->input[t] > max ? tl->input[t] : max;
    424		expsum = 0;
    425		for (t = 0; t < tl->len; t++)
    426			expsum += exp(tl->input[t] - max);
    427		for (t = 0; t < tl->len; t++)
    428			tl->activity[t] = exp(tl->input[t] - max) / expsum;
    429		break;
    430	default:
    431		errx(1, "Unknown activation function (%i)", tl->activation);
     432	}
    433}
    434
    435
    436void
    437nn_fwdprop(struct nn *nn, uint8_t *image)
    438{
    439	int i, l;
    440
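        	/* the input layer's bias activity is pinned to 1.0; pixels are
        	 * binarized to 0/1 rather than scaled into [0,1] */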
    441	nn->layer[0].activity[nn->layer[0].len] = 1.0;
    442	for (i = 0; i < nn->layer[0].len; i++)
    443		nn->layer[0].activity[i] = image[i] ? 1.0 : 0.0;
    444
    445	for (l = 0; l < nn->layers - 1; l++)
    446		nn_fwdprop_layer(nn, l);
    447}
    448
    449void
    450nn_backprop_layer(struct nn *nn, int l)
    451{
    452	struct layer *sl, *tl;
    453	int s, t, i;
    454	double sum;
    455
    456	sl = &nn->layer[l-1];
    457	tl = &nn->layer[l];
    458
    459	for (s = 0; s < sl->nodes; s++)
    460		sl->derivs[s] = 0;
    461
    462	switch (nn->layer[l].activation) {
    463	case IDENTITY:
    464		for (t = 0; t < tl->len; t++) {
    465			for (s = 0; s < sl->nodes; s++) {
    466				sl->derivs[s] += tl->derivs[t]
    467					* nn->weights[l-1][s][t];
    468			}
    469		}
    470		break;
    471	case SIGMOID:
    472		for (t = 0; t < tl->len; t++) {
     473			/* sigmoid derivative da/dz = a * (1 - a) turns dE/da into dE/dz */
    474			tl->derivs[t] *= tl->activity[t] * (1 - tl->activity[t]);
    475			for (s = 0; s < sl->nodes; s++) {
    476				sl->derivs[s] += tl->derivs[t]
    477					* nn->weights[l-1][s][t];
    478			}
    479		}
    480		break;
    481	case SOFTMAX:
    482		/* derivative of softmax function
    483		 * (each input i influences activity t) */
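        		/* da_i/dz_t = a_i * ((t == i) - a_t), so
        		 * dE/dz_t = sum_i dE/da_i * a_i * ((t == i) - a_t) */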
    484		for (t = 0; t < tl->nodes; t++) {
    485			sum = 0;
    486			for (i = 0; i < tl->nodes; i++) {
    487				sum += tl->derivs[i] * tl->activity[i]
    488					* ((t == i) - tl->activity[t]);
    489			}
    490			for (s = 0; s < sl->nodes; s++)
    491				sl->derivs[s] += sum * nn->weights[l-1][s][t];
    492		}
    493		break;
    494	}
    495}
    496
    497void
    498nn_backprop(struct nn *nn, uint8_t label)
    499{
    500	int i, l;
    501
    502	l = nn->layers - 1;
    503	for (i = 0; i < nn->layer[l].len; i++) {
     504		/* error E = 1/2 (out - target)^2, so dE/d(out) = out - target */
    505		nn->layer[l].derivs[i] = nn->layer[l].activity[i]
    506			- (label == i ? 1.0 : 0.0);
    507	}
    508
    509	/* generate derivs of err / z_i per node */
    510	for (l = nn->layers - 1; l >= 1; l--)
    511		nn_backprop_layer(nn, l);
    512}
    513
    514void
    515nn_debug(struct nn *nn)
    516{
    517	int l, s, t;
    518
    519	printf("WEIGHTS:\n");
    520	for (l = 0; l < nn->layers - 1; l++) {
    521		printf("LAYER %i\n", l);
    522		for (s = 0; s < nn->layer[l].nodes; s++) {
    523			for (t = 0; t < nn->layer[l+1].len; t++) {
    524				printf("%0.3F ", nn->weights[l][s][t]);
    525			}
    526			printf("\n");
    527		}
    528		printf("\n");
    529	}
    530
    531	printf("DELTAS:\n");
    532	for (l = 0; l < nn->layers - 1; l++) {
    533		printf("LAYER %i\n", l);
    534		for (s = 0; s < nn->layer[l].nodes; s++) {
    535			for (t = 0; t < nn->layer[l+1].len; t++) {
    536				printf("%0.3F ", nn->deltas[l][s][t]);
    537			}
    538			printf("\n");
    539		}
    540		printf("\n");
    541	}
    542}
    543
    544void
    545nn_debug_prediction(struct nn *nn, uint8_t label)
    546{
    547	int k;
    548
    549	printf("%i : ", label);
    550	for (k = 0; k < nn->output->len; k++)
    551		printf("%2.0F ", 100 * nn->output->activity[k]);
    552	printf("\n");
    553}
    554
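        /* map a weight to an xterm-256 color: shades of green for positive
         * weights, shades of red for negative ones, brighter with magnitude */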
    555int
    556weight_color(double weight)
    557{
    558	int color;
    559
    560	if (weight >= 0) {
    561		if (weight < 0.01) {
    562			color = 22;
    563		} else if (weight < 0.1) {
    564			color = 28;
    565		} else if (weight < 1) {
    566			color = 34;
    567		} else if (weight < 10) {
    568			color = 40;
    569		} else {
    570			color = 46;
    571		}
    572	} else {
    573		if (weight > -0.01) {
    574			color = 52;
    575		} else if (weight > -0.1) {
    576			color = 88;
    577		} else if (weight > -1) {
    578			color = 124;
    579		} else if (weight > -10) {
    580			color = 160;
    581		} else {
    582			color = 196;
    583		}
    584	}
    585
    586	return color;
    587}
    588
    589void
    590nn_dump(struct nn *nn)
    591{
    592	int l, s, t, x, y;
    593	double weight;
    594
    595	printf("\n");
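        	/* one 28x28 grid per hidden node showing its incoming weights; the
        	 * extra cell after the last row is that node's bias weight */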
    596	for (t = 0; t < nn->layer[1].len; t++) {
    597		printf("INPUT -> HIDDEN %i\n", t);
    598		for (y = 0; y < 28; y++) {
    599			for (x = 0; x < 28 + (y == 27); x++) {
    600				weight = nn->weights[0][y * 28 + x][t];
    601				printf("\x1b[38:5:%im%s\x1b[0m",
    602					weight_color(weight),
    603					fabs(weight) >= 0.0001 ? "▮" : " ");
    604			}
    605			printf("\n");
    606		}
    607		printf("\n");
    608	}
    609
    610	if (nn->layers > 2) {
    611		printf("HIDDEN -> OUTPUT\n");
    612		for (t = 0; t < nn->layer[2].len; t++) {
    613			for (s = 0; s < nn->layer[1].nodes; s++) {
    614				weight = nn->weights[1][s][t];
    615				printf("\x1b[38:5:%im%s\x1b[0m",
    616					weight_color(weight),
    617					fabs(weight) >= 0.0001 ? "▮" : " ");
    618			}
    619			printf("\n");
    620		}
    621	}
    622
    623}
    624
    625void
    626nn_reset_deltas(struct nn *nn)
    627{
    628	int l, s, t;
    629
    630	for (l = 0; l < nn->layers - 1; l++) {
    631		for (s = 0; s < nn->layer[l].nodes; s++) {
    632			for (t = 0; t < nn->layer[l+1].len; t++)
    633				nn->deltas[l][s][t] = 0;
    634		}
    635	}
    636}
    637
    638void
    639nn_update_deltas(struct nn *nn, double learning_rate)
    640{
    641	int l, s, t;
    642	double gradw;
    643
    644	/* generate deltas for weights from err / z_i */
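        	/* dE/dw[l-1][s][t] = dE/dz_t (layer l) * activity_s (layer l-1);
        	 * the delta accumulates -learning_rate times this gradient */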
    645	for (l = nn->layers - 1; l >= 1; l--) {
    646		for (t = 0; t < nn->layer[l].len; t++) {
    647			for (s = 0; s < nn->layer[l-1].nodes; s++) {
    648				gradw = - nn->layer[l].derivs[t]
    649					* nn->layer[l-1].activity[s];
    650				nn->deltas[l-1][s][t] += gradw * learning_rate;
    651			}
    652		}
    653	}
    654}
    655
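        /* add the deltas accumulated over a batch, averaged over its size, to
         * the weights: one mini-batch gradient-descent step */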
    656void
    657nn_apply_deltas(struct nn *nn, size_t size)
    658{
    659	int l, s, t;
    660
    661	for (l = nn->layers - 1; l >= 1; l--) {
    662		for (t = 0; t < nn->layer[l].len; t++) {
    663			for (s = 0; s < nn->layer[l-1].nodes; s++) {
    664				nn->weights[l-1][s][t] +=
    665					nn->deltas[l-1][s][t] / size;
    666				assert(!isnan(nn->weights[l-1][s][t]));
    667			}
    668		}
    669	}
    670}
    671
    672int
    673nn_result(struct nn *nn)
    674{
    675	double max;
    676	int k, maxi;
    677
    678	maxi = -1;
    679	for (k = 0; k < nn->output->len; k++) {
    680		if (maxi < 0 || nn->output->activity[k] > max) {
    681			max = nn->output->activity[k];
    682			maxi = k;
    683		}
    684	}
    685
    686	return maxi;
    687}
    688
    689double
    690nn_test(struct nn *nn)
    691{
    692	struct idx images;
    693	struct idx labels;
    694	size_t hits, total;
    695	int i, k, res;
    696	uint8_t label;
    697	double max;
    698
    699	idx_load_images(&images, "data/test-images.idx");
    700	idx_load_labels(&labels, "data/test-labels.idx");
    701
    702	total = hits = 0;
    703	for (i = 0; i < images.dim[0]; i++) {
    704		nn_fwdprop(nn, images.data + i * nn->input->len);
    705		label = *(uint8_t*)(labels.data + i);
    706		nn_debug_prediction(nn, label);
    707		res = nn_result(nn);
    708		if (res == label) hits++;
    709		total++;
    710	}
    711
    712	idx_free(&images);
    713	idx_free(&labels);
    714
     715	return (double)hits / total;
    716}
    717
    718double
    719nn_batch(struct nn *nn, struct idx *images, struct idx *labels,
    720	size_t batch_size, double learning_rate)
    721{
    722	double lerror, error;
    723	uint32_t idx;
    724	uint8_t label;
    725	size_t i, k;
    726
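        	/* batches draw images uniformly at random with replacement, so one
        	 * "epoch" is not guaranteed to visit every image */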
    727	nn_reset_deltas(nn);
    728
    729	error = 0;
    730	for (i = 0; i < batch_size; i++) {
    731		if (getrandom(&idx, 4, 0) != 4)
    732			err(1, "getrandom");
    733		idx = idx % images->dim[0];
    734		nn_fwdprop(nn, images->data + idx * nn->layer[0].len);
    735		label = *(uint8_t *)(labels->data + idx);
    736		for (k = 0; k < nn->output->nodes; k++) {
    737			lerror = nn->output->activity[k]
    738				- (k == label ? 1.0 : 0.0);
    739			error += 0.5 * lerror * lerror;
    740		}
    741		//nn_debug_prediction(nn, label);
    742		nn_backprop(nn, label);
    743		nn_update_deltas(nn, learning_rate);
    744	}
    745
    746	nn_apply_deltas(nn, batch_size);
    747
    748	return error / batch_size;
    749}
    750
    751void
    752nn_train(struct nn *nn, size_t epochs,
    753	size_t batch_size, double learning_rate)
    754{
    755	struct idx images;
    756	struct idx labels;
    757	double error;
    758	int epoch, i;
    759
    760	signal(SIGINT, train_stop);
    761
    762	idx_load_images(&images, "data/train-images.idx");
    763	idx_load_labels(&labels, "data/train-labels.idx");
    764
    765	for (epoch = 0; epoch < epochs; epoch++) {
    766		/* TODO: ensure using all images in one epoch */
    767		for (i = 0; i < images.dim[0] / batch_size; i++) {
    768			error = nn_batch(nn, &images, &labels,
    769				batch_size, learning_rate);
    770			if (i % 100 == 0) {
    771				//nn_debug(nn);
    772				//nn_dump(nn);
    773				printf("Batch %i / %lu => %2.5F\n", i + 1,
    774					images.dim[0] / batch_size, error);
    775			}
    776			if (quit) {
    777				nn_save(nn, nn->filepath);
    778				exit(1);
    779			}
    780		}
    781	}
    782
    783	idx_free(&images);
    784	idx_free(&labels);
    785}
    786
    787void
    788nn_trainvis(struct nn *nn, size_t batch_size, double learning_rate)
    789{
    790	struct idx images;
    791	struct idx labels;
    792	double error, weight;
    793	int epoch, i;
    794	int t, x, y;
    795	int sx, sy;
    796	bool show;
    797
     798	/* display the input->hidden weights visually after each batch;
     799	 * SIGINT (Ctrl-C) saves the weights and stops training */
    800	signal(SIGINT, train_stop);
    801
    802	idx_load_images(&images, "data/train-images.idx");
    803	idx_load_labels(&labels, "data/train-labels.idx");
    804
    805	printf("\x1b[?25l"); /* hide cursor */
    806	printf("\x1b[2J"); /* clear screen */
    807
    808	while (!quit) {
    809		error = nn_batch(nn, &images, &labels,
    810			batch_size, learning_rate);
    811		if (quit) {
    812			nn_save(nn, nn->filepath);
    813			break;
    814		}
    815
    816		printf("\x1b[%i;%iHTraining error: %F", 2, 0, error);
    817
    818		assert(nn->layers > 1);
    819		for (t = 0; t < nn->layer[1].len; t++) {
    820			sy = (t >= nn->layer[1].len / 2) ? 35 : 5;
    821			sx = 2 + 30 * (t % (nn->layer[1].len / 2));
    822			for (y = 0; y < 28; y++) {
    823				for (x = 0; x < 28 + (y == 27); x++) {
    824					weight = nn->weights[0][y * 28 + x][t];
    825					show = fabs(weight) >= 0.0001;
    826					printf("\x1b[%i;%iH", sy + y, sx + x);
    827					printf("\x1b[38:5:%im%s\x1b[0m",
    828						weight_color(weight),
    829						show ? "▮" : " ");
    830				}
    831			}
    832		}
    833	}
    834
    835	printf("\x1b[?25h"); /* show cursor */
    836	printf("\x1b[2J"); /* clear screen */ 
    837
    838	idx_free(&images);
    839	idx_free(&labels);
    840}
    841
    842void
    843nn_predict(struct nn *nn)
    844{
    845	uint8_t image[28*28];
    846	WINDOW *win;
    847	MEVENT event;
    848	int width, height;
    849	int startx, starty;
    850	int x, y, c, i, label;
    851	bool evaluate;
    852
    853	/* gui interface to draw input and show prediction */
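        	/* draw with the mouse (each event paints a small cross of pixels),
        	 * press c to clear the canvas, q to quit; the predicted digit is
        	 * underlined in the list above the canvas */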
    854
    855	/* TODO: 256 color support, is this portable? */
    856	setenv("TERM", "xterm-1002", 1);
    857
    858	initscr();
    859	keypad(stdscr, true);
    860	noecho();
    861	cbreak();
    862	curs_set(0);
    863
    864	mousemask(ALL_MOUSE_EVENTS | REPORT_MOUSE_POSITION, NULL);
    865
    866	win = NULL;
    867	label = -1;
    868	evaluate = true;
    869	memset(image, 0, sizeof(image));
    870	while (!quit) {
    871		width = getmaxx(stdscr);
    872		height = getmaxy(stdscr);
    873		assert(width >= 30 && height >= 31);
    874
    875		startx = (width - 30) / 2;
    876		starty = (height - 30) / 2;
    877
    878		if (evaluate) {
    879			nn_fwdprop(nn, image);
    880			label = nn_result(nn);
    881			evaluate = false;
    882		}
    883
    884		if (!win) {
    885			win = newwin(30, 30, starty, startx);
    886			if (!win) err(1, "newwin");
    887		} else {
    888			mvwin(win, starty, startx);
    889		}
    890
    891		clear();
    892
    893		mvprintw(starty - 1, startx - 1, "Predictions: ");
    894		for (i = 0; i < nn->output->len; i++) {
    895			if (i == label) attron(A_UNDERLINE);
    896			if (nn->output->activity[i] >= 0.2)
    897				attron(A_BOLD);
    898			printw("%i", i);
    899			attroff(A_BOLD);
    900			attroff(A_UNDERLINE);
    901			printw(" ");
    902		}
    903		refresh();
    904
    905		box(win, 0, 0);
    906		for (y = 0; y < 28; y++) {
    907			for (x = 0; x < 28; x++) {
    908				if (image[y * 28 + x])
    909					mvwaddch(win, 1 + y, 1 + x, ACS_BLOCK);
    910				else
    911					mvwaddch(win, 1 + y, 1 + x, ' ');
    912			}
    913		}
    914		wrefresh(win);
    915
    916		switch ((c = getch())) {
    917		case KEY_MOUSE:
    918			if (getmouse(&event) != OK)
    919				err(1, "getmouse");
    920			x = event.x - (startx + 1);
    921			y = event.y - (starty + 1);
    922			if (x < 0 || x >= 28) continue;
    923			if (y < 0 || y >= 28) continue;
    924			image[y * 28 + x] = 1;
    925			if (y > 0) image[(y-1) * 28 + x] = 1;
    926			if (y < 27) image[(y+1) * 28 + x] = 1;
    927			if (x > 0) image[y * 28 + x - 1] = 1;
    928			if (x < 27) image[y * 28 + x + 1] = 1;
    929			if (event.bstate & BUTTON1_RELEASED ||
    930					event.bstate & BUTTON2_RELEASED)
    931				evaluate = true;
    932			break;
    933		case 'c':
    934			memset(image, 0, sizeof(image));
    935			break;
    936		case 'q':
    937			quit = true;
    938			break;
    939		}
    940	}
    941
    942	delwin(win);
    943	endwin();
    944}
    945
    946void
    947dump_sample(const char *set, size_t index)
    948{
    949	struct idx images, labels;
    950	uint8_t pix, *image;
    951	int x, y;
    952
    953	if (!strcmp(set, "train")) {
    954		idx_load_images(&images, "data/train-images.idx");
    955		idx_load_labels(&labels, "data/train-labels.idx");
    956	} else if (!strcmp(set, "test")) {
    957		idx_load_images(&images, "data/test-images.idx");
    958		idx_load_labels(&labels, "data/test-labels.idx");
    959	} else {
    960		errx(1, "Unknown dataset (%s)", set);
    961	}
    962
    963	image = images.data + 28 * 28 * index;
    964	assert(index < images.dim[0]);
    965	for (y = 0; y < 28; y++) {
    966		for (x = 0; x < 28; x++) {
    967			pix = image[y * 28 + x];
    968			printf("%s", pix ? "▮" : " ");
    969		}
    970		printf("\n");
    971	}
    972
    973	printf("Label: %i\n", *(uint8_t *)(labels.data + index));
    974
     975	idx_free(&images);
        	idx_free(&labels);
    976}
    977
    978int
    979main(int argc, const char **argv)
    980{
    981	struct nn nn;
    982
    983	if (argc == 2 && !strcmp(argv[1], "gen")) {
    984		nn_init(&nn, layers, ARRLEN(layers));
    985		nn_gen(&nn);
    986		nn_save(&nn, ".nn");
    987		nn_free(&nn);
    988	} else if (argc == 2 && !strcmp(argv[1], "train")) {
    989		nn_init(&nn, layers, ARRLEN(layers));
    990		nn_load(&nn, ".nn");
    991		nn_train(&nn, 10, 5, 0.005);
    992		nn_save(&nn, ".nn");
    993		nn_free(&nn);
    994	} else if (argc == 2 && !strcmp(argv[1], "trainvis")) {
    995		nn_init(&nn, layers, ARRLEN(layers));
    996		nn_load(&nn, ".nn");
    997		nn_trainvis(&nn, 5, 0.005);
    998		nn_save(&nn, ".nn");
    999		nn_free(&nn);
   1000	} else if (argc == 2 && !strcmp(argv[1], "predict")) {
   1001		nn_init(&nn, layers, ARRLEN(layers));
   1002		nn_load(&nn, ".nn");
   1003		nn_predict(&nn);
   1004		nn_free(&nn);
   1005	} else if (argc == 2 && !strcmp(argv[1], "test")) {
   1006		nn_init(&nn, layers, ARRLEN(layers));
   1007		nn_load(&nn, ".nn");
   1008		printf("Accuracy: %F\n", nn_test(&nn));
   1009		nn_free(&nn);
   1010	} else if (argc == 2 && !strcmp(argv[1], "dump")) {
   1011		nn_init(&nn, layers, ARRLEN(layers));
   1012		nn_load(&nn, ".nn");
   1013		nn_dump(&nn);
   1014		nn_free(&nn);
   1015	} else if (argc == 4 && !strcmp(argv[1], "sample")) {
   1016		dump_sample(argv[2], atoi(argv[3]));
   1017	} else {
   1018		printf("Commands: gen train trainvis predict test dump sample\n");
   1019	}
   1020}
   1021