/* main.c */
1#include <ncurses.h> 2 3#include <sys/random.h> 4#include <math.h> 5#include <err.h> 6#include <signal.h> 7#include <assert.h> 8#include <endian.h> 9#include <stdlib.h> 10#include <stdbool.h> 11#include <string.h> 12#include <stdio.h> 13#include <stdint.h> 14 15#define ARRLEN(x) (sizeof(x)/sizeof((x)[0])) 16 17enum { 18 U8 = 0x08, 19 I8 = 0x09, 20 I16 = 0x0B, 21 I32 = 0x0C, 22 F32 = 0x0D, 23 F64 = 0x0E 24}; 25 26enum { 27 IDENTITY, 28 SIGMOID, 29 SOFTMAX 30}; 31 32struct idx { 33 void *data; 34 uint32_t *dim; 35 uint8_t dims; 36 uint8_t dtype; 37}; 38 39struct layer_spec { 40 int activation; 41 bool has_bias; 42 size_t len; 43}; 44 45struct layer { 46 int activation; 47 bool has_bias; 48 size_t len, nodes; 49 double *activity; 50 double *input; 51 double *derivs; 52}; 53 54struct nn { 55 char *filepath; 56 struct layer *layer; 57 size_t layers; 58 struct layer *input, *output; 59 /* 3d matrix, [layer][source][target] */ 60 double ***weights; 61 double ***deltas; 62}; 63 64static const struct layer_spec layers[] = { 65 { IDENTITY, true, 28 * 28 }, 66 { SIGMOID, true, 20 }, 67 { SOFTMAX, false, 10 }, 68}; 69 70static const uint8_t idx_dtype_size[0x100] = { 71 [U8] = 1, 72 [I8] = 1, 73 [I16] = 2, 74 [I32] = 4, 75 [F32] = 4, 76 [F64] = 8 77}; 78 79static bool quit = false; 80 81void 82train_stop(int sig) 83{ 84 quit = true; 85 printf("QUIT\n"); 86} 87 88double 89dbl_be64toh(double d) 90{ 91 uint64_t tmp; 92 93 tmp = *(uint64_t*)&d; 94 tmp = be64toh(tmp); 95 return *(double*)&tmp; 96} 97 98double 99dbl_htobe64(double d) 100{ 101 uint64_t tmp; 102 103 tmp = *(uint64_t*)&d; 104 tmp = htobe64(tmp); 105 return *(double*)&tmp; 106} 107 108void 109idx_load(struct idx *idx, FILE *file, const char *path) 110{ 111 uint8_t header[2]; 112 uint32_t count; 113 uint32_t *counts; 114 size_t size; 115 int i; 116 117 if (fread(header, 1, 2, file) != 2) 118 errx(1, "Missing idx header (%s)", path); 119 120 if (header[0] || header[1]) 121 errx(1, "Invalid idx header (%s)", path); 122 123 
if (fread(&idx->dtype, 1, 1, file) != 1) 124 errx(1, "Missing idx data type (%s)", path); 125 126 if (fread(&idx->dims, 1, 1, file) != 1) 127 errx(1, "Missing idx dims (%s)", path); 128 129 if (!idx_dtype_size[idx->dtype]) 130 errx(1, "Invalid idx data type (%s)", path); 131 132 idx->dim = malloc(idx->dims * sizeof(uint32_t)); 133 if (!idx->dim) err(1, "malloc"); 134 135 size = 1; 136 for (i = 0; i < idx->dims; i++) { 137 if (fread(&count, 4, 1, file) != 1) 138 errx(1, "Missing %i. dimension size (%s)", i + 1, path); 139 idx->dim[i] = be32toh(count); 140 size *= idx->dim[i]; 141 } 142 143 idx->data = malloc(size * idx_dtype_size[idx->dtype]);; 144 if (!idx->dtype) err(1, "malloc"); 145 146 if (fread(idx->data, idx_dtype_size[idx->dtype], size, file) != size) 147 errx(1, "Incomplete data section (%s)", path); 148} 149 150void 151idx_free(struct idx *idx) 152{ 153 free(idx->data); 154 idx->data = NULL; 155 free(idx->dim); 156 idx->dim = NULL; 157} 158 159void 160idx_load_single(struct idx *idx, const char *path) 161{ 162 FILE *file; 163 164 file = fopen(path, "r"); 165 if (!file) err(1, "fopen (%s)", path); 166 idx_load(idx, file, path); 167 fclose(file); 168} 169 170void 171idx_load_images(struct idx *idx, const char *path) 172{ 173 idx_load_single(idx, path); 174 assert(idx->dims == 3); 175 assert(idx->dim[1] == 28 && idx->dim[2] == 28); 176 assert(idx->dtype == U8); 177} 178 179void 180idx_load_labels(struct idx *idx, const char *path) 181{ 182 idx_load_single(idx, path); 183 assert(idx->dims == 1); 184 assert(idx->dtype == U8); 185} 186 187void 188idx_save(struct idx *idx, FILE *file, const char *path) 189{ 190 uint8_t header[2]; 191 uint32_t count; 192 size_t size; 193 int i; 194 195 memset(header, 0, 2); 196 if (fwrite(&header, 1, 2, file) != 2) 197 err(1, "fwrite (%s)", path); 198 199 if (fwrite(&idx->dtype, 1, 1, file) != 1) 200 err(1, "fwrite (%s)", path); 201 202 if (fwrite(&idx->dims, 1, 1, file) != 1) 203 err(1, "fwrite (%s)", path); 204 205 size = 1; 206 
for (i = 0; i < idx->dims; i++) { 207 count = htobe32(idx->dim[i]); 208 if (fwrite(&count, 4, 1, file) != 1) 209 err(1, "fwrite (%s)", path); 210 size *= idx->dim[i]; 211 } 212 213 if (fwrite(idx->data, idx_dtype_size[idx->dtype], size, file) != size) 214 err(1, "fwrite (%s)", path); 215} 216 217void 218nn_init(struct nn *nn, const struct layer_spec *spec, size_t layers) 219{ 220 int l, k; 221 222 nn->filepath = NULL; 223 nn->layers = layers; 224 nn->layer = malloc(sizeof(struct layer) * nn->layers); 225 226 for (l = 0; l < nn->layers; l++) { 227 nn->layer[l].len = spec[l].len; 228 nn->layer[l].has_bias = spec[l].has_bias; 229 nn->layer[l].activation = spec[l].activation; 230 231 nn->layer[l].nodes = spec[l].len + spec[l].has_bias; 232 233 nn->layer[l].input = calloc(nn->layer[l].nodes, sizeof(double)); 234 if (!nn->layer[l].input) err(1, "malloc"); 235 236 nn->layer[l].activity = calloc(nn->layer[l].nodes, sizeof(double)); 237 if (!nn->layer[l].activity) err(1, "malloc"); 238 239 nn->layer[l].derivs = calloc(nn->layer[l].nodes, sizeof(double)); 240 if (!nn->layer[l].derivs) err(1, "malloc"); 241 } 242 243 nn->input = &nn->layer[0]; 244 nn->output = &nn->layer[nn->layers - 1]; 245 246 nn->deltas = malloc((nn->layers - 1) * sizeof(double *)); 247 if (!nn->deltas) err(1, "malloc"); 248 249 nn->weights = malloc((nn->layers - 1) * sizeof(double *)); 250 if (!nn->weights) err(1, "malloc"); 251 252 for (l = 0; l < nn->layers - 1; l++) { 253 nn->deltas[l] = malloc(nn->layer[l].nodes * sizeof(double *)); 254 if (!nn->deltas[l]) err(1, "malloc"); 255 for (k = 0; k < nn->layer[l].nodes; k++) { 256 nn->deltas[l][k] = calloc(nn->layer[l+1].len, 257 sizeof(double)); 258 if (!nn->deltas[l][k]) err(1, "malloc"); 259 } 260 261 nn->weights[l] = malloc(nn->layer[l].nodes * sizeof(double *)); 262 if (!nn->weights[l]) err(1, "malloc"); 263 for (k = 0; k < nn->layer[l].nodes; k++) { 264 nn->weights[l][k] = calloc(nn->layer[l+1].len, 265 sizeof(double)); 266 if (!nn->weights[l][k]) 
err(1, "malloc"); 267 } 268 } 269} 270 271void 272nn_gen(struct nn *nn) 273{ 274 int l, s, t; 275 uint32_t val; 276 277 /* initial weights */ 278 for (l = 0; l < nn->layers - 1; l++) { 279 for (s = 0; s < nn->layer[l].nodes; s++) { 280 for (t = 0; t < nn->layer[l+1].len; t++) { 281 if (getrandom(&val, 4, 0) != 4) 282 err(1, "getrandom"); 283 nn->weights[l][s][t] = 284 ((val / (double) 0xFFFFFFFF) - 0.5) 285 / nn->layer[l].nodes; 286 } 287 } 288 } 289} 290 291void 292nn_load(struct nn *nn, const char *path) 293{ 294 FILE *file; 295 struct idx idx; 296 double weight; 297 int l, s, t; 298 int snodes; 299 300 nn->filepath = strdup(path); 301 if (!nn->filepath) err(1, "strdup"); 302 303 file = fopen(path, "r"); 304 if (!file) err(1, "fopen (%s)", path); 305 306 /* load weights */ 307 for (l = 0; l < nn->layers - 1; l++) { 308 idx_load(&idx, file, path); 309 assert(idx.dtype == F64); 310 assert(idx.dims == 2); 311 assert(idx.dim[0] == nn->layer[l].nodes); 312 assert(idx.dim[1] == nn->layer[l+1].len); 313 snodes = nn->layer[l].nodes; 314 for (s = 0; s < nn->layer[l].nodes; s++) { 315 for (t = 0; t < nn->layer[l+1].len; t++) { 316 weight = ((double*)idx.data)[t * snodes + s]; 317 nn->weights[l][s][t] = dbl_be64toh(weight); 318 } 319 } 320 idx_free(&idx); 321 } 322 323 fclose(file); 324} 325 326void 327nn_save(struct nn *nn, const char *path) 328{ 329 FILE *file; 330 struct idx idx; 331 double weight; 332 int l, s, t; 333 int snodes; 334 335 file = fopen(path, "w+"); 336 if (!file) err(1, "fopen (%s)", path); 337 338 idx.dims = 2; 339 idx.dim = malloc(idx.dims * sizeof(uint32_t)); 340 if (!idx.dim) err(1, "malloc"); 341 idx.dtype = F64; 342 343 /* save weights */ 344 for (l = 0; l < nn->layers - 1; l++) { 345 idx.data = malloc(nn->layer[l].nodes 346 * nn->layer[l+1].len * sizeof(double)); 347 if (!idx.data) err(1, "malloc"); 348 snodes = nn->layer[l].nodes; 349 for (s = 0; s < nn->layer[l].nodes; s++) { 350 for (t = 0; t < nn->layer[l+1].len; t++) { 351 weight = 
dbl_htobe64(nn->weights[l][s][t]); 352 ((double *)idx.data)[t * snodes + s] = weight; 353 } 354 } 355 idx.dim[0] = nn->layer[l].nodes; 356 idx.dim[1] = nn->layer[l+1].len; 357 idx_save(&idx, file, path); 358 free(idx.data); 359 } 360 361 free(idx.dim); 362 fclose(file); 363} 364 365void 366nn_free(struct nn *nn) 367{ 368 int l, k; 369 370 free(nn->filepath); 371 372 for (l = 0; l < nn->layers; l++) { 373 free(nn->layer[l].derivs); 374 free(nn->layer[l].activity); 375 free(nn->layer[l].input); 376 if (l < nn->layers - 1) { 377 for (k = 0; k < nn->layer[l].nodes; k++) { 378 free(nn->deltas[l][k]); 379 free(nn->weights[l][k]); 380 } 381 free(nn->deltas[l]); 382 free(nn->weights[l]); 383 } 384 } 385 386 free(nn->layer); 387 free(nn->weights); 388 free(nn->deltas); 389} 390 391void 392nn_fwdprop_layer(struct nn *nn, int l) 393{ 394 struct layer *sl, *tl; 395 double expsum, weight, max; 396 int s, t; 397 398 sl = &nn->layer[l]; 399 tl = &nn->layer[l+1]; 400 401 if (tl->has_bias) 402 tl->activity[tl->len] = 1.0; 403 for (t = 0; t < tl->len; t++) { 404 tl->input[t] = 0; 405 for (s = 0; s < sl->nodes; s++) { 406 tl->input[t] += sl->activity[s] 407 * nn->weights[l][s][t]; 408 } 409 } 410 411 switch (tl->activation) { 412 case IDENTITY: 413 for (t = 0; t < tl->len; t++) 414 tl->activity[t] = tl->input[t]; 415 break; 416 case SIGMOID: 417 for (t = 0; t < tl->len; t++) 418 tl->activity[t] = 1 / (1 + exp(-tl->input[t])); 419 break; 420 case SOFTMAX: 421 max = tl->input[0]; 422 for (t = 0; t < tl->len; t++) 423 max = tl->input[t] > max ? 
tl->input[t] : max; 424 expsum = 0; 425 for (t = 0; t < tl->len; t++) 426 expsum += exp(tl->input[t] - max); 427 for (t = 0; t < tl->len; t++) 428 tl->activity[t] = exp(tl->input[t] - max) / expsum; 429 break; 430 default: 431 errx(1, "Unknown activation function (%i)", tl->activation); 432 }; 433} 434 435 436void 437nn_fwdprop(struct nn *nn, uint8_t *image) 438{ 439 int i, l; 440 441 nn->layer[0].activity[nn->layer[0].len] = 1.0; 442 for (i = 0; i < nn->layer[0].len; i++) 443 nn->layer[0].activity[i] = image[i] ? 1.0 : 0.0; 444 445 for (l = 0; l < nn->layers - 1; l++) 446 nn_fwdprop_layer(nn, l); 447} 448 449void 450nn_backprop_layer(struct nn *nn, int l) 451{ 452 struct layer *sl, *tl; 453 int s, t, i; 454 double sum; 455 456 sl = &nn->layer[l-1]; 457 tl = &nn->layer[l]; 458 459 for (s = 0; s < sl->nodes; s++) 460 sl->derivs[s] = 0; 461 462 switch (nn->layer[l].activation) { 463 case IDENTITY: 464 for (t = 0; t < tl->len; t++) { 465 for (s = 0; s < sl->nodes; s++) { 466 sl->derivs[s] += tl->derivs[t] 467 * nn->weights[l-1][s][t]; 468 } 469 } 470 break; 471 case SIGMOID: 472 for (t = 0; t < tl->len; t++) { 473 /* derivative of activation function */ 474 tl->derivs[t] *= tl->activity[t] * (1 - tl->activity[t]); 475 for (s = 0; s < sl->nodes; s++) { 476 sl->derivs[s] += tl->derivs[t] 477 * nn->weights[l-1][s][t]; 478 } 479 } 480 break; 481 case SOFTMAX: 482 /* derivative of softmax function 483 * (each input i influences activity t) */ 484 for (t = 0; t < tl->nodes; t++) { 485 sum = 0; 486 for (i = 0; i < tl->nodes; i++) { 487 sum += tl->derivs[i] * tl->activity[i] 488 * ((t == i) - tl->activity[t]); 489 } 490 for (s = 0; s < sl->nodes; s++) 491 sl->derivs[s] += sum * nn->weights[l-1][s][t]; 492 } 493 break; 494 } 495} 496 497void 498nn_backprop(struct nn *nn, uint8_t label) 499{ 500 int i, l; 501 502 l = nn->layers - 1; 503 for (i = 0; i < nn->layer[l].len; i++) { 504 /* derivative of error: 1/2 (label - out)^2 */ 505 nn->layer[l].derivs[i] = 
nn->layer[l].activity[i] 506 - (label == i ? 1.0 : 0.0); 507 } 508 509 /* generate derivs of err / z_i per node */ 510 for (l = nn->layers - 1; l >= 1; l--) 511 nn_backprop_layer(nn, l); 512} 513 514void 515nn_debug(struct nn *nn) 516{ 517 int l, s, t; 518 519 printf("WEIGHTS:\n"); 520 for (l = 0; l < nn->layers - 1; l++) { 521 printf("LAYER %i\n", l); 522 for (s = 0; s < nn->layer[l].nodes; s++) { 523 for (t = 0; t < nn->layer[l+1].len; t++) { 524 printf("%0.3F ", nn->weights[l][s][t]); 525 } 526 printf("\n"); 527 } 528 printf("\n"); 529 } 530 531 printf("DELTAS:\n"); 532 for (l = 0; l < nn->layers - 1; l++) { 533 printf("LAYER %i\n", l); 534 for (s = 0; s < nn->layer[l].nodes; s++) { 535 for (t = 0; t < nn->layer[l+1].len; t++) { 536 printf("%0.3F ", nn->deltas[l][s][t]); 537 } 538 printf("\n"); 539 } 540 printf("\n"); 541 } 542} 543 544void 545nn_debug_prediction(struct nn *nn, uint8_t label) 546{ 547 int k; 548 549 printf("%i : ", label); 550 for (k = 0; k < nn->output->len; k++) 551 printf("%2.0F ", 100 * nn->output->activity[k]); 552 printf("\n"); 553} 554 555int 556weight_color(double weight) 557{ 558 int color; 559 560 if (weight >= 0) { 561 if (weight < 0.01) { 562 color = 22; 563 } else if (weight < 0.1) { 564 color = 28; 565 } else if (weight < 1) { 566 color = 34; 567 } else if (weight < 10) { 568 color = 40; 569 } else { 570 color = 46; 571 } 572 } else { 573 if (weight > -0.01) { 574 color = 52; 575 } else if (weight > -0.1) { 576 color = 88; 577 } else if (weight > -1) { 578 color = 124; 579 } else if (weight > -10) { 580 color = 160; 581 } else { 582 color = 196; 583 } 584 } 585 586 return color; 587} 588 589void 590nn_dump(struct nn *nn) 591{ 592 int l, s, t, x, y; 593 double weight; 594 595 printf("\n"); 596 for (t = 0; t < nn->layer[1].len; t++) { 597 printf("INPUT -> HIDDEN %i\n", t); 598 for (y = 0; y < 28; y++) { 599 for (x = 0; x < 28 + (y == 27); x++) { 600 weight = nn->weights[0][y * 28 + x][t]; 601 printf("\x1b[38:5:%im%s\x1b[0m", 602 
weight_color(weight), 603 fabs(weight) >= 0.0001 ? "▮" : " "); 604 } 605 printf("\n"); 606 } 607 printf("\n"); 608 } 609 610 if (nn->layers > 2) { 611 printf("HIDDEN -> OUTPUT\n"); 612 for (t = 0; t < nn->layer[2].len; t++) { 613 for (s = 0; s < nn->layer[1].nodes; s++) { 614 weight = nn->weights[1][s][t]; 615 printf("\x1b[38:5:%im%s\x1b[0m", 616 weight_color(weight), 617 fabs(weight) >= 0.0001 ? "▮" : " "); 618 } 619 printf("\n"); 620 } 621 } 622 623} 624 625void 626nn_reset_deltas(struct nn *nn) 627{ 628 int l, s, t; 629 630 for (l = 0; l < nn->layers - 1; l++) { 631 for (s = 0; s < nn->layer[l].nodes; s++) { 632 for (t = 0; t < nn->layer[l+1].len; t++) 633 nn->deltas[l][s][t] = 0; 634 } 635 } 636} 637 638void 639nn_update_deltas(struct nn *nn, double learning_rate) 640{ 641 int l, s, t; 642 double gradw; 643 644 /* generate deltas for weights from err / z_i */ 645 for (l = nn->layers - 1; l >= 1; l--) { 646 for (t = 0; t < nn->layer[l].len; t++) { 647 for (s = 0; s < nn->layer[l-1].nodes; s++) { 648 gradw = - nn->layer[l].derivs[t] 649 * nn->layer[l-1].activity[s]; 650 nn->deltas[l-1][s][t] += gradw * learning_rate; 651 } 652 } 653 } 654} 655 656void 657nn_apply_deltas(struct nn *nn, size_t size) 658{ 659 int l, s, t; 660 661 for (l = nn->layers - 1; l >= 1; l--) { 662 for (t = 0; t < nn->layer[l].len; t++) { 663 for (s = 0; s < nn->layer[l-1].nodes; s++) { 664 nn->weights[l-1][s][t] += 665 nn->deltas[l-1][s][t] / size; 666 assert(!isnan(nn->weights[l-1][s][t])); 667 } 668 } 669 } 670} 671 672int 673nn_result(struct nn *nn) 674{ 675 double max; 676 int k, maxi; 677 678 maxi = -1; 679 for (k = 0; k < nn->output->len; k++) { 680 if (maxi < 0 || nn->output->activity[k] > max) { 681 max = nn->output->activity[k]; 682 maxi = k; 683 } 684 } 685 686 return maxi; 687} 688 689double 690nn_test(struct nn *nn) 691{ 692 struct idx images; 693 struct idx labels; 694 size_t hits, total; 695 int i, k, res; 696 uint8_t label; 697 double max; 698 699 idx_load_images(&images, 
"data/test-images.idx"); 700 idx_load_labels(&labels, "data/test-labels.idx"); 701 702 total = hits = 0; 703 for (i = 0; i < images.dim[0]; i++) { 704 nn_fwdprop(nn, images.data + i * nn->input->len); 705 label = *(uint8_t*)(labels.data + i); 706 nn_debug_prediction(nn, label); 707 res = nn_result(nn); 708 if (res == label) hits++; 709 total++; 710 } 711 712 idx_free(&images); 713 idx_free(&labels); 714 715 return 1.F * hits / total; 716} 717 718double 719nn_batch(struct nn *nn, struct idx *images, struct idx *labels, 720 size_t batch_size, double learning_rate) 721{ 722 double lerror, error; 723 uint32_t idx; 724 uint8_t label; 725 size_t i, k; 726 727 nn_reset_deltas(nn); 728 729 error = 0; 730 for (i = 0; i < batch_size; i++) { 731 if (getrandom(&idx, 4, 0) != 4) 732 err(1, "getrandom"); 733 idx = idx % images->dim[0]; 734 nn_fwdprop(nn, images->data + idx * nn->layer[0].len); 735 label = *(uint8_t *)(labels->data + idx); 736 for (k = 0; k < nn->output->nodes; k++) { 737 lerror = nn->output->activity[k] 738 - (k == label ? 
1.0 : 0.0); 739 error += 0.5 * lerror * lerror; 740 } 741 //nn_debug_prediction(nn, label); 742 nn_backprop(nn, label); 743 nn_update_deltas(nn, learning_rate); 744 } 745 746 nn_apply_deltas(nn, batch_size); 747 748 return error / batch_size; 749} 750 751void 752nn_train(struct nn *nn, size_t epochs, 753 size_t batch_size, double learning_rate) 754{ 755 struct idx images; 756 struct idx labels; 757 double error; 758 int epoch, i; 759 760 signal(SIGINT, train_stop); 761 762 idx_load_images(&images, "data/train-images.idx"); 763 idx_load_labels(&labels, "data/train-labels.idx"); 764 765 for (epoch = 0; epoch < epochs; epoch++) { 766 /* TODO: ensure using all images in one epoch */ 767 for (i = 0; i < images.dim[0] / batch_size; i++) { 768 error = nn_batch(nn, &images, &labels, 769 batch_size, learning_rate); 770 if (i % 100 == 0) { 771 //nn_debug(nn); 772 //nn_dump(nn); 773 printf("Batch %i / %lu => %2.5F\n", i + 1, 774 images.dim[0] / batch_size, error); 775 } 776 if (quit) { 777 nn_save(nn, nn->filepath); 778 exit(1); 779 } 780 } 781 } 782 783 idx_free(&images); 784 idx_free(&labels); 785} 786 787void 788nn_trainvis(struct nn *nn, size_t batch_size, double learning_rate) 789{ 790 struct idx images; 791 struct idx labels; 792 double error, weight; 793 int epoch, i; 794 int t, x, y; 795 int sx, sy; 796 bool show; 797 798 /* display weights visually after each batch 799 * and adjust batch frequency via UP / DOWN */ 800 signal(SIGINT, train_stop); 801 802 idx_load_images(&images, "data/train-images.idx"); 803 idx_load_labels(&labels, "data/train-labels.idx"); 804 805 printf("\x1b[?25l"); /* hide cursor */ 806 printf("\x1b[2J"); /* clear screen */ 807 808 while (!quit) { 809 error = nn_batch(nn, &images, &labels, 810 batch_size, learning_rate); 811 if (quit) { 812 nn_save(nn, nn->filepath); 813 break; 814 } 815 816 printf("\x1b[%i;%iHTraining error: %F", 2, 0, error); 817 818 assert(nn->layers > 1); 819 for (t = 0; t < nn->layer[1].len; t++) { 820 sy = (t >= 
nn->layer[1].len / 2) ? 35 : 5; 821 sx = 2 + 30 * (t % (nn->layer[1].len / 2)); 822 for (y = 0; y < 28; y++) { 823 for (x = 0; x < 28 + (y == 27); x++) { 824 weight = nn->weights[0][y * 28 + x][t]; 825 show = fabs(weight) >= 0.0001; 826 printf("\x1b[%i;%iH", sy + y, sx + x); 827 printf("\x1b[38:5:%im%s\x1b[0m", 828 weight_color(weight), 829 show ? "▮" : " "); 830 } 831 } 832 } 833 } 834 835 printf("\x1b[?25h"); /* show cursor */ 836 printf("\x1b[2J"); /* clear screen */ 837 838 idx_free(&images); 839 idx_free(&labels); 840} 841 842void 843nn_predict(struct nn *nn) 844{ 845 uint8_t image[28*28]; 846 WINDOW *win; 847 MEVENT event; 848 int width, height; 849 int startx, starty; 850 int x, y, c, i, label; 851 bool evaluate; 852 853 /* gui interface to draw input and show prediction */ 854 855 /* TODO: 256 color support, is this portable? */ 856 setenv("TERM", "xterm-1002", 1); 857 858 initscr(); 859 keypad(stdscr, true); 860 noecho(); 861 cbreak(); 862 curs_set(0); 863 864 mousemask(ALL_MOUSE_EVENTS | REPORT_MOUSE_POSITION, NULL); 865 866 win = NULL; 867 label = -1; 868 evaluate = true; 869 memset(image, 0, sizeof(image)); 870 while (!quit) { 871 width = getmaxx(stdscr); 872 height = getmaxy(stdscr); 873 assert(width >= 30 && height >= 31); 874 875 startx = (width - 30) / 2; 876 starty = (height - 30) / 2; 877 878 if (evaluate) { 879 nn_fwdprop(nn, image); 880 label = nn_result(nn); 881 evaluate = false; 882 } 883 884 if (!win) { 885 win = newwin(30, 30, starty, startx); 886 if (!win) err(1, "newwin"); 887 } else { 888 mvwin(win, starty, startx); 889 } 890 891 clear(); 892 893 mvprintw(starty - 1, startx - 1, "Predictions: "); 894 for (i = 0; i < nn->output->len; i++) { 895 if (i == label) attron(A_UNDERLINE); 896 if (nn->output->activity[i] >= 0.2) 897 attron(A_BOLD); 898 printw("%i", i); 899 attroff(A_BOLD); 900 attroff(A_UNDERLINE); 901 printw(" "); 902 } 903 refresh(); 904 905 box(win, 0, 0); 906 for (y = 0; y < 28; y++) { 907 for (x = 0; x < 28; x++) { 908 if 
(image[y * 28 + x]) 909 mvwaddch(win, 1 + y, 1 + x, ACS_BLOCK); 910 else 911 mvwaddch(win, 1 + y, 1 + x, ' '); 912 } 913 } 914 wrefresh(win); 915 916 switch ((c = getch())) { 917 case KEY_MOUSE: 918 if (getmouse(&event) != OK) 919 err(1, "getmouse"); 920 x = event.x - (startx + 1); 921 y = event.y - (starty + 1); 922 if (x < 0 || x >= 28) continue; 923 if (y < 0 || y >= 28) continue; 924 image[y * 28 + x] = 1; 925 if (y > 0) image[(y-1) * 28 + x] = 1; 926 if (y < 27) image[(y+1) * 28 + x] = 1; 927 if (x > 0) image[y * 28 + x - 1] = 1; 928 if (x < 27) image[y * 28 + x + 1] = 1; 929 if (event.bstate & BUTTON1_RELEASED || 930 event.bstate & BUTTON2_RELEASED) 931 evaluate = true; 932 break; 933 case 'c': 934 memset(image, 0, sizeof(image)); 935 break; 936 case 'q': 937 quit = true; 938 break; 939 } 940 } 941 942 delwin(win); 943 endwin(); 944} 945 946void 947dump_sample(const char *set, size_t index) 948{ 949 struct idx images, labels; 950 uint8_t pix, *image; 951 int x, y; 952 953 if (!strcmp(set, "train")) { 954 idx_load_images(&images, "data/train-images.idx"); 955 idx_load_labels(&labels, "data/train-labels.idx"); 956 } else if (!strcmp(set, "test")) { 957 idx_load_images(&images, "data/test-images.idx"); 958 idx_load_labels(&labels, "data/test-labels.idx"); 959 } else { 960 errx(1, "Unknown dataset (%s)", set); 961 } 962 963 image = images.data + 28 * 28 * index; 964 assert(index < images.dim[0]); 965 for (y = 0; y < 28; y++) { 966 for (x = 0; x < 28; x++) { 967 pix = image[y * 28 + x]; 968 printf("%s", pix ? 
"▮" : " "); 969 } 970 printf("\n"); 971 } 972 973 printf("Label: %i\n", *(uint8_t *)(labels.data + index)); 974 975 idx_free(&images); 976} 977 978int 979main(int argc, const char **argv) 980{ 981 struct nn nn; 982 983 if (argc == 2 && !strcmp(argv[1], "gen")) { 984 nn_init(&nn, layers, ARRLEN(layers)); 985 nn_gen(&nn); 986 nn_save(&nn, ".nn"); 987 nn_free(&nn); 988 } else if (argc == 2 && !strcmp(argv[1], "train")) { 989 nn_init(&nn, layers, ARRLEN(layers)); 990 nn_load(&nn, ".nn"); 991 nn_train(&nn, 10, 5, 0.005); 992 nn_save(&nn, ".nn"); 993 nn_free(&nn); 994 } else if (argc == 2 && !strcmp(argv[1], "trainvis")) { 995 nn_init(&nn, layers, ARRLEN(layers)); 996 nn_load(&nn, ".nn"); 997 nn_trainvis(&nn, 5, 0.005); 998 nn_save(&nn, ".nn"); 999 nn_free(&nn); 1000 } else if (argc == 2 && !strcmp(argv[1], "predict")) { 1001 nn_init(&nn, layers, ARRLEN(layers)); 1002 nn_load(&nn, ".nn"); 1003 nn_predict(&nn); 1004 nn_free(&nn); 1005 } else if (argc == 2 && !strcmp(argv[1], "test")) { 1006 nn_init(&nn, layers, ARRLEN(layers)); 1007 nn_load(&nn, ".nn"); 1008 printf("Accuracy: %F\n", nn_test(&nn)); 1009 nn_free(&nn); 1010 } else if (argc == 2 && !strcmp(argv[1], "dump")) { 1011 nn_init(&nn, layers, ARRLEN(layers)); 1012 nn_load(&nn, ".nn"); 1013 nn_dump(&nn); 1014 nn_free(&nn); 1015 } else if (argc == 4 && !strcmp(argv[1], "sample")) { 1016 dump_sample(argv[2], atoi(argv[3])); 1017 } else { 1018 printf("Commands: gen train trainvis predict test dump sample\n"); 1019 } 1020} 1021