/*
 * Did someone say *Minimalism*
 * David Fouhey
 * Thanks to Even Shelhamer for the name and Richard Zhang for the
 * encouragement
 *
 * miniml: logistic regression in freestanding x86-64 Linux assembly
 * (GNU as, Intel syntax, raw syscalls, no libc; entry point _start).
 *
 *   miniml R N F C X.txt y.txt w.txt
 *       train: read X (N*F floats) and y (N floats), run FIT_EPOCHS
 *       epochs of full-batch gradient steps with trade-off C, write w
 *       (F floats) to w.txt
 *   miniml E N F X.txt w.txt p.txt
 *       evaluate: read X and w, write the N values 1/(1+exp(-X[i].w))
 *       to p.txt
 *
 * Input numbers are decimal with an optional leading '-' and an optional
 * fraction (no exponents); any byte other than '-', '.' or a digit
 * separates numbers. Output has one number per line with six digits
 * after the decimal point.
 *
 * Internal routines take arguments in rdi, rsi, rdx, rcx, r8, r9 and
 * xmm0 and return in rax / xmm0. Nothing here calls libc, so the stack
 * is not kept 16-byte aligned at call sites.
 */
        .intel_syntax noprefix

        /* Linux x86-64 syscall numbers */
        .equ SYS_READ,       0
        .equ SYS_WRITE,      1
        .equ SYS_OPEN,       2
        .equ SYS_CLOSE,      3
        .equ SYS_LSEEK,      8
        .equ SYS_BRK,        12
        .equ SYS_EXIT,       60

        .equ STDERR,         2
        .equ O_RDONLY,       0
        .equ O_WRITE_NEW,    0x241      /* O_WRONLY | O_CREAT | O_TRUNC */
        .equ NEW_FILE_MODE,  0666       /* octal; the umask still applies */
        .equ SEEK_CUR,       1

        .equ READBUF_LEN,    255        /* bytes fetched per read() in readmat */
        .equ READBUF_FRAME,  264        /* LEN + NUL, padded to keep rsp 16-aligned */
        .equ FIT_EPOCHS,     10000

        .equ CH_NL,          0x0a
        .equ CH_MINUS,       0x2d
        .equ CH_DOT,         0x2e
        .equ CH_ZERO,        0x30
        .equ CH_E,           0x45
        .equ CH_R,           0x52

        .section .rodata
errstr: .ascii " miniml R N F C X.txt y.txt w.txt\n\tor\n miniml E N F X.txt w.txt p.txt\n"
        .equ errstr_len, . - errstr
ioerr:  .ascii "miniml: allocation, open or read failure (too few numbers?)\n"
        .equ ioerr_len, . - ioerr

        .balign 4
v0:     .float 0
v0p1:   .float 0.1
v0p5:   .float 0.5
v1:     .float 1.0
v3:     .float 3.0
v10:    .float 10.0
vn1:    .float -1.0
vlr:    .float 1E-5                     /* learning rate used by fitLR */

/* chrtype[c] == 1 iff byte c may appear inside a number: '-' (0x2d),
   '.' (0x2e) or '0'..'9' (0x30-0x39). Every other byte, NUL included,
   is a separator. */
chrtype:
        .fill   45, 1, 0                /* 0x00 - 0x2c */
        .byte   1, 1, 0                 /* '-', '.', '/' */
        .fill   10, 1, 1                /* '0' - '9' */
        .fill   198, 1, 0               /* 0x3a - 0xff */

        .section .text
        .globl _start

/* Process entry: argc is at [rsp], argv starts at [rsp+8]. Runs main and
   exits with status 0 when it returns. */
_start:
        mov     rdi, [rsp]
        lea     rsi, [rsp + 8]
        call    main
        mov     eax, SYS_EXIT
        xor     edi, edi
        syscall

/* Write the usage text to stderr and exit with status 1. */
badargs:
        mov     eax, SYS_WRITE
        mov     edi, STDERR
        lea     rsi, [rip + errstr]
        mov     edx, errstr_len
        syscall
        mov     eax, SYS_EXIT
        mov     edi, 1
        syscall

/* Write ioerr to stderr and exit with status 1. Reached by jmp from
   alloc, readmat and writevec on failure; never returns. */
die:
        mov     eax, SYS_WRITE
        mov     edi, STDERR
        lea     rsi, [rip + ioerr]
        mov     edx, ioerr_len
        syscall
        mov     eax, SYS_EXIT
        mov     edi, 1
        syscall

/* main(rdi = argc, rsi = argv): dispatch on the first character of
   argv[1]. 'E' with argc >= 7 evaluates; 'R' with argc == 8 trains;
   anything else prints usage and exits. */
main:
        cmp     rdi, 7
        jl      badargs
        mov     rax, [rsi + 8]
        movzx   eax, byte ptr [rax]
        cmp     al, CH_E
        je      main_te
        cmp     al, CH_R
        jne     badargs
        cmp     rdi, 8
        je      main_tr
        jmp     badargs

/* Evaluate: miniml E N F X.txt w.txt p.txt
     r12 = argv, r13 = N, r14 = F, r15 = heap base, rbx = N*F
   Heap layout in floats (N*F + F + N in total):
     [0, N*F)              X
     [N*F, N*F+F)          w
     [N*F+F, N*F+F+N)      predictions
   Callee-saved registers are not restored: _start only exits afterwards. */
main_te:
        mov     r12, rsi
        mov     rdi, [r12 + 16]         /* N */
        call    sint
        test    rax, rax
        jle     badargs                 /* N must be positive */
        mov     r13, rax
        mov     rdi, [r12 + 24]         /* F */
        call    sint
        test    rax, rax
        jle     badargs                 /* F must be positive */
        mov     r14, rax
        imul    rax, r13
        mov     rbx, rax

        /* allocate N*F + F + N floats */
        lea     rdi, [rbx + r14]
        add     rdi, r13
        shl     rdi, 2
        call    alloc
        mov     r15, rax

        /* X <- argv[4] */
        mov     rdi, [r12 + 32]
        mov     rsi, r15
        mov     rdx, rbx
        call    readmat

        /* w <- argv[5], stored right after X */
        mov     rdi, [r12 + 40]
        lea     rsi, [r15 + 4*rbx]
        mov     rdx, r14
        call    readmat

        /* computeLR(N, F, X, pred, w) */
        mov     rdi, r13
        mov     rsi, r14
        mov     rdx, r15
        lea     r8, [r15 + 4*rbx]       /* w */
        lea     rcx, [r8 + 4*r14]       /* predictions follow w */
        call    computeLR

        /* argv[6] <- predictions */
        mov     rdi, [r12 + 48]
        lea     rsi, [r15 + 4*rbx]
        lea     rsi, [rsi + 4*r14]
        mov     rdx, r13
        call    writevec
        ret

/* Train: miniml R N F C X.txt y.txt w.txt
     r12 = argv, r13 = N, r14 = F, r15 = heap base, rbx = N*F
   Heap layout in floats (N*F + N + 2F in total):
     [0, N*F)                  X
     [N*F, N*F+N)              y
     [N*F+N, N*F+N+F)          w (starts zeroed; see alloc)
     [N*F+N+F, N*F+N+2F)       gradient scratch
   Callee-saved registers are not restored: _start only exits afterwards. */
main_tr:
        mov     r12, rsi
        mov     rdi, [r12 + 16]         /* N */
        call    sint
        test    rax, rax
        jle     badargs
        mov     r13, rax
        mov     rdi, [r12 + 24]         /* F */
        call    sint
        test    rax, rax
        jle     badargs
        mov     r14, rax
        imul    rax, r13
        mov     rbx, rax

        /* allocate N*F + N + 2F floats */
        lea     rdi, [rbx + r13]
        lea     rdi, [rdi + 2*r14]
        shl     rdi, 2
        call    alloc
        mov     r15, rax

        /* X <- argv[5] */
        mov     rdi, [r12 + 40]
        mov     rsi, r15
        mov     rdx, rbx
        call    readmat

        /* y <- argv[6], stored right after X */
        mov     rdi, [r12 + 48]
        lea     rsi, [r15 + 4*rbx]
        mov     rdx, r13
        call    readmat

        /* C <- argv[4]; parsed last so it can stay in xmm0 for fitLR */
        mov     rdi, [r12 + 32]
        call    sfloat

        /* fitLR(N, F, X, y, w, scratch, C) */
        mov     rdi, r13
        mov     rsi, r14
        mov     rdx, r15
        lea     rcx, [rdx + 4*rbx]      /* y */
        lea     r8, [rcx + 4*r13]       /* w */
        lea     r9, [r8 + 4*r14]        /* scratch */
        call    fitLR

        /* argv[7] <- w */
        mov     rdi, [r12 + 56]
        lea     rsi, [r15 + 4*rbx]
        lea     rsi, [rsi + 4*r13]
        mov     rdx, r14
        call    writevec
        ret

/* alloc(rdi = bytes) -> rax = start of the new region.
   Grows the program break with brk. The raw brk syscall returns the new
   break on success and the unchanged break on failure, so a result below
   the requested end means failure and we die. Freshly obtained brk pages
   are zero-filled by the kernel; main_tr relies on this for w = 0.
   Clobbers rcx, rdi, r8, r9, r11. */
alloc:
        mov     r8, rdi
        mov     eax, SYS_BRK
        xor     edi, edi
        syscall                         /* rax = current break */
        mov     r9, rax
        lea     rdi, [rax + r8]
        mov     eax, SYS_BRK
        syscall
        cmp     rax, rdi
        jb      die
        mov     rax, r9
        ret

/* printchar(rdi = fd, sil = byte): write one byte to fd.
   Preserves rdi; clobbers rax, rcx, rdx, rsi, r11. */
printchar:
        push    rbp
        mov     rbp, rsp
        sub     rsp, 8
        mov     byte ptr [rbp - 8], sil
        mov     eax, SYS_WRITE
        lea     rsi, [rbp - 8]
        mov     edx, 1
        syscall
        leave
        ret

/* printnl(rdi = fd): write a newline (tail-calls printchar). */
printnl:
        mov     sil, CH_NL
        jmp     printchar

/* printFloat(rdi = fd, xmm0 = x): write x as an optional '-', the integer
   digits (a single '0' when |x| < 1), '.', and six fractional digits.
   Digits come from repeated single-precision division, so the last ones
   can be off by float rounding.
   MXCSR is switched to round-toward-negative-infinity so cvtss2si floors;
   the caller's MXCSR is restored on exit.
   Preserves rdi, rbx, r12; clobbers rax, rcx, rdx, rsi, r11, xmm0-xmm3. */
printFloat:
        push    rbp
        push    rbx
        push    r12
        mov     rbp, rsp
        sub     rsp, 16
        stmxcsr dword ptr [rbp - 8]     /* caller's MXCSR */
        stmxcsr dword ptr [rbp - 16]
        mov     eax, [rbp - 16]
        and     eax, 0xFFFF9FFF         /* clear RC (bits 13-14) */
        or      eax, 0x2000             /* RC = 01: round down */
        mov     [rbp - 16], eax
        ldmxcsr dword ptr [rbp - 16]

        /* sign; zero of either sign is printed without '-' */
        comiss  xmm0, [rip + v0]
        jae     .LpF_pos
        mulss   xmm0, [rip + vn1]
        mov     sil, CH_MINUS
        call    printchar
.LpF_pos:
        movss   xmm1, xmm0              /* keep |x| */
        movss   xmm2, [rip + v0p1]      /* place value, ends as 10^(rbx-1) */
        xor     ebx, ebx                /* rbx = digits left of the point */
.LpF_count:
        comiss  xmm0, [rip + v1]
        jb      .LpF_counted            /* loop while x >= 1 */
        inc     rbx
        divss   xmm0, [rip + v10]
        mulss   xmm2, [rip + v10]
        jmp     .LpF_count
.LpF_counted:
        test    rbx, rbx
        jnz     .LpF_digits
        mov     sil, CH_ZERO            /* |x| < 1: emit "0." up front */
        call    printchar
        mov     sil, CH_DOT
        call    printchar
.LpF_digits:
        mov     r12, rbx                /* digits remaining before the '.' */
        add     rbx, 6                  /* total digits to emit */
        movss   xmm0, xmm1
.LpF_loop:
        movss   xmm3, xmm0
        divss   xmm3, xmm2
        cvtss2si rax, xmm3              /* digit = floor(x / place) */
        cvtsi2ss xmm3, rax
        mov     esi, eax
        add     esi, CH_ZERO
        call    printchar
        mulss   xmm3, xmm2
        subss   xmm0, xmm3              /* x -= digit * place */
        divss   xmm2, [rip + v10]
        dec     r12                     /* never hits 0 again once negative */
        jnz     .LpF_next
        mov     sil, CH_DOT
        call    printchar
.LpF_next:
        dec     rbx
        jnz     .LpF_loop

        ldmxcsr dword ptr [rbp - 8]
        mov     rsp, rbp
        pop     r12
        pop     rbx
        pop     rbp
        ret

/* writevec(rdi = path, rsi = float *v, rdx = count): create or truncate
   path and write the floats one per line. Dies if the open fails.
   Clobbers rax, rcx, rdx, rsi, rdi, r11, xmm0-xmm3. */
writevec:
        push    r13
        push    r14
        push    r15
        mov     r13, rsi                /* r13 = next float */
        mov     r14, rdx                /* r14 = floats left */

        mov     eax, SYS_OPEN           /* rdi already holds the path */
        mov     esi, O_WRITE_NEW
        mov     edx, NEW_FILE_MODE
        syscall
        test    rax, rax
        js      die
        mov     r15, rax                /* r15 = fd */

        test    r14, r14
        jz      .Lwv_close
.Lwv_loop:
        mov     rdi, r15
        movss   xmm0, [r13]
        call    printFloat              /* preserves rdi */
        call    printnl
        add     r13, 4
        dec     r14
        jnz     .Lwv_loop
.Lwv_close:
        mov     eax, SYS_CLOSE
        mov     rdi, r15
        syscall
        pop     r15
        pop     r14
        pop     r13
        ret

/* readmat(rdi = path, rsi = float *M, rdx = count): parse exactly count
   numbers from the file into M. A number is a maximal run of chrtype
   bytes; anything else separates numbers.
   Each step reads up to READBUF_LEN bytes at the current offset, seeks
   back, then seeks forward past exactly what was consumed, so the file
   offset always sits just after the last parsed number. A number longer
   than READBUF_LEN bytes is split in two.
   Dies if the file cannot be opened, a read fails, or it ends before
   count numbers were found.
     r12 = path, later bytes consumed   r13 = next output slot
     r14 = numbers left                 r15 = fd
     rbx = buffer (stack)               r10 = one past the bytes read
   Clobbers rax, rcx, rdx, rsi, rdi, r8-r11, xmm0-xmm4. */
readmat:
        push    rbp
        push    r12
        push    r13
        push    r14
        push    r15
        push    rbx
        mov     rbp, rsp
        sub     rsp, READBUF_FRAME
        mov     rbx, rsp
        mov     r12, rdi
        mov     r13, rsi
        mov     r14, rdx
        test    r14, r14
        jz      .Lrm_ret

        mov     eax, SYS_OPEN
        mov     rdi, r12
        mov     esi, O_RDONLY
        xor     edx, edx
        syscall
        test    rax, rax
        js      die
        mov     r15, rax

.Lrm_next:
        /* fetch a chunk starting at the current offset */
        mov     eax, SYS_READ
        mov     rdi, r15
        mov     rsi, rbx
        mov     edx, READBUF_LEN
        syscall
        test    rax, rax
        jle     die                     /* read error, or EOF too early */
        mov     r8, rax

        /* rewind; we seek forward once we know how much was consumed */
        mov     eax, SYS_LSEEK
        mov     rdi, r15
        mov     rsi, r8
        neg     rsi
        mov     edx, SEEK_CUR
        syscall

        lea     r10, [rbx + r8]
        mov     byte ptr [r10], 0       /* NUL right after the data stops the number scan */
        lea     r9, [rip + chrtype]
        mov     rdi, rbx
.Lrm_skipsep:
        cmp     rdi, r10
        jae     .Lrm_advance            /* chunk is all separators */
        movzx   eax, byte ptr [rdi]
        cmp     byte ptr [r9 + rax], 1
        je      .Lrm_number
        inc     rdi
        jmp     .Lrm_skipsep
.Lrm_number:
        mov     rsi, rdi                /* rsi = first byte of the number */
.Lrm_scan:
        movzx   eax, byte ptr [rdi]
        cmp     byte ptr [r9 + rax], 1
        jne     .Lrm_scanned
        inc     rdi
        jmp     .Lrm_scan
.Lrm_scanned:
        /* A number running into the end of the data may have been cut
           off by the chunk size; unless it already starts the buffer,
           skip the separators before it and re-read from its first byte. */
        cmp     rdi, r10
        jb      .Lrm_parse
        cmp     rsi, rbx
        je      .Lrm_parse
        mov     rdi, rsi
        jmp     .Lrm_advance
.Lrm_parse:
        mov     byte ptr [rdi], 0       /* terminate the number for sfloat */
        sub     rdi, rbx
        mov     r12, rdi                /* bytes consumed; path no longer needed */
        mov     rdi, rsi
        call    sfloat
        movss   [r13], xmm0
        add     r13, 4
        mov     eax, SYS_LSEEK
        mov     rdi, r15
        mov     rsi, r12
        mov     edx, SEEK_CUR
        syscall
        dec     r14
        jnz     .Lrm_next

        mov     eax, SYS_CLOSE
        mov     rdi, r15
        syscall
.Lrm_ret:
        mov     rsp, rbp
        pop     rbx
        pop     r15
        pop     r14
        pop     r13
        pop     r12
        pop     rbp
        ret

.Lrm_advance:
        /* skip rdi - rbx bytes (always > 0 here) and fetch again */
        sub     rdi, rbx
        mov     rsi, rdi
        mov     eax, SYS_LSEEK
        mov     rdi, r15
        mov     edx, SEEK_CUR
        syscall
        jmp     .Lrm_next

/* sint(rdi = NUL-terminated string) -> rax: parse an optional '-' followed
   by decimal digits. No validation: other bytes are folded in as
   (byte - '0'). Clobbers rcx, rdi, r8. */
sint:
        xor     eax, eax
        xor     r8d, r8d                /* r8 = 1 if negative */
        cmp     byte ptr [rdi], CH_MINUS
        jne     .Lsint_loop
        inc     rdi
        mov     r8d, 1
.Lsint_loop:
        movzx   ecx, byte ptr [rdi]
        test    ecx, ecx
        jz      .Lsint_done
        imul    rax, rax, 10
        sub     rcx, CH_ZERO
        add     rax, rcx
        inc     rdi
        jmp     .Lsint_loop
.Lsint_done:
        test    r8d, r8d
        jz      .Lsint_ret
        neg     rax
.Lsint_ret:
        ret

/* sfloat(rdi = NUL-terminated string) -> xmm0: parse an optional '-',
   integer digits, and an optional '.' with fraction digits (no exponent,
   no validation). Integer and fraction parts are accumulated separately
   and added at the end.
   xmm0 = integer part, xmm1 = fraction, xmm2 = current fraction place,
   xmm3 = 10, r8 = 1 if negative. Clobbers rcx, rdi, r8, xmm1-xmm4. */
sfloat:
        pxor    xmm0, xmm0
        pxor    xmm1, xmm1
        movss   xmm2, [rip + v0p1]
        movss   xmm3, [rip + v10]
        xor     r8d, r8d
        cmp     byte ptr [rdi], CH_MINUS
        jne     .Lsf_int
        mov     r8d, 1
        inc     rdi
.Lsf_int:
        movzx   ecx, byte ptr [rdi]
        test    ecx, ecx
        jz      .Lsf_done
        cmp     ecx, CH_DOT
        je      .Lsf_dot
        mulss   xmm0, xmm3
        sub     rcx, CH_ZERO
        cvtsi2ss xmm4, ecx
        addss   xmm0, xmm4
        inc     rdi
        jmp     .Lsf_int
.Lsf_dot:
        inc     rdi
.Lsf_frac:
        movzx   ecx, byte ptr [rdi]
        test    ecx, ecx
        jz      .Lsf_done
        sub     rcx, CH_ZERO
        cvtsi2ss xmm4, ecx
        mulss   xmm4, xmm2
        addss   xmm1, xmm4
        divss   xmm2, xmm3
        inc     rdi
        jmp     .Lsf_frac
.Lsf_done:
        addss   xmm0, xmm1
        test    r8d, r8d
        jz      .Lsf_ret
        mulss   xmm0, [rip + vn1]
.Lsf_ret:
        ret

/* computeLR(rdi = N, rsi = F, rdx = float *X, rcx = float *out,
             r8 = float *w)
   out[i] = sigmoid(X[i][:] . w) for each of the N row-major rows of X.
     r8 = rows left, r9 = F, r10 = w,
     rsi = next X element (rows are contiguous), rdi = next output
   Clobbers rax, rcx, rsi, rdi, r8, r9, r10, xmm0-xmm2. */
computeLR:
        mov     r10, r8
        mov     r9, rsi
        mov     r8, rdi
        mov     rdi, rcx
        mov     rsi, rdx
.Lclr_row:
        mov     rax, r10                /* rax walks w */
        mov     rcx, r9
        pxor    xmm0, xmm0
.Lclr_dot:
        movss   xmm1, [rax]
        mulss   xmm1, [rsi]
        addss   xmm0, xmm1
        add     rax, 4
        add     rsi, 4
        dec     rcx
        jnz     .Lclr_dot
        call    sigmoid                 /* keeps rsi, rdi, r8-r10 */
        movss   [rdi], xmm0
        add     rdi, 4
        dec     r8
        jnz     .Lclr_row
        ret

/* fitLR(rdi = N, rsi = F, rdx = float *X, rcx = float *y, r8 = float *w,
         r9 = float *g (F floats of scratch), xmm0 = C)
   FIT_EPOCHS epochs of full-batch updates, starting from whatever w holds:
     g = sum_i (y[i] - sigmoid(X[i] . w)) * X[i]
     w = w + lr * (C*g - w)              (lr = vlr)
   Saved copies: r11 = N, r12 = F, r13 = X, r14 = y, r15 = w, rbx = g,
   xmm10 = C, xmm11 = lr. r11 and xmm10/xmm11 survive because the only
   callee, sigmoid/paexp, touches just rax, rcx and xmm0-xmm2.
   Loop counters: r8 = epoch, r9 = instance i.
   rep stosd assumes the direction flag is clear.
   Preserves rbx, r12-r15; clobbers rax, rcx, rsi, rdi, r8, r9, r11,
   xmm0-xmm2, xmm10, xmm11. */
fitLR:
        push    r12
        push    r13
        push    r14
        push    r15
        push    rbx
        mov     r11, rdi
        mov     r12, rsi
        mov     r13, rdx
        mov     r14, rcx
        mov     r15, r8
        mov     rbx, r9
        movss   xmm10, xmm0
        movss   xmm11, [rip + vlr]

        xor     r8d, r8d
.LfLR_epoch:
        /* g = 0 */
        mov     rdi, rbx
        mov     rcx, r12
        xor     eax, eax
        rep stosd
        xor     r9d, r9d
.LfLR_inst:
        /* rsi = &X[i][0] */
        mov     rax, r9
        imul    rax, r12
        lea     rsi, [r13 + 4*rax]

        /* xmm0 = X[i] . w */
        pxor    xmm0, xmm0
        xor     eax, eax
        mov     rcx, r12
.LfLR_dot:
        movss   xmm1, [r15 + rax]
        mulss   xmm1, [rsi + rax]
        addss   xmm0, xmm1
        add     rax, 4
        dec     rcx
        jnz     .LfLR_dot

        call    sigmoid                 /* xmm0 = p_i */
        movss   xmm1, [r14 + 4*r9]
        subss   xmm1, xmm0              /* xmm1 = y[i] - p_i */

        /* g += xmm1 * X[i] */
        mov     rdi, rbx
        mov     rcx, r12
.LfLR_grad:
        movss   xmm0, [rsi]
        mulss   xmm0, xmm1
        addss   xmm0, [rdi]
        movss   [rdi], xmm0
        add     rsi, 4
        add     rdi, 4
        dec     rcx
        jnz     .LfLR_grad

        inc     r9
        cmp     r9, r11
        jl      .LfLR_inst

        /* w = w + lr * (C*g - w) */
        mov     rsi, rbx
        mov     rdi, r15
        mov     rcx, r12
.LfLR_update:
        movss   xmm0, [rsi]
        movss   xmm2, [rdi]
        mulss   xmm0, xmm10
        subss   xmm0, xmm2
        mulss   xmm0, xmm11
        addss   xmm0, xmm2
        movss   [rdi], xmm0
        add     rsi, 4
        add     rdi, 4
        dec     rcx
        jnz     .LfLR_update

        inc     r8
        cmp     r8, FIT_EPOCHS
        jl      .LfLR_epoch

        pop     rbx
        pop     r15
        pop     r14
        pop     r13
        pop     r12
        ret

/* sigmoid(xmm0 = z) -> xmm0 = 1 / (1 + exp(-z)).
   Uses an exact divss rather than the ~12-bit rcpss approximation, so the
   six printed decimals are meaningful. Clobbers rax, rcx, xmm1, xmm2. */
sigmoid:
        mulss   xmm0, [rip + vn1]
        call    paexp
        movss   xmm1, [rip + v1]
        addss   xmm1, xmm0
        movss   xmm0, [rip + v1]
        divss   xmm0, xmm1
        ret

/* paexp(xmm0 = x) -> xmm0 ~= exp(x), after
   https://math.stackexchange.com/questions/55830/how-to-calculate-ex-with-a-standard-calculator
   1. Work on |x|; al = 1 if x was not positive.
   2. Halve until |x| <= 0.1, counting halvings in rcx.
   3. Approximate exp on [0, 0.1] by ((x+3)^2 + 3) / ((x-3)^2 + 3).
   4. Square rcx times, since exp(x) = exp(x/2)^2.
   5. If al is set, return 1 / result (exact divss).
   Clobbers rax, rcx, xmm1, xmm2. */
paexp:
        xor     ecx, ecx
        xor     eax, eax
        comiss  xmm0, [rip + v0]
        ja      .Lexp_pos
        mulss   xmm0, [rip + vn1]
        mov     al, 1
.Lexp_pos:
        comiss  xmm0, [rip + v0p1]
        jbe     .Lexp_pade              /* unsigned-style flags from comiss */
        movss   xmm1, [rip + v0p1]
        movss   xmm2, [rip + v0p5]
.Lexp_halve:
        inc     rcx
        mulss   xmm0, xmm2
        comiss  xmm0, xmm1
        ja      .Lexp_halve
.Lexp_pade:
        /* numerator in xmm0, denominator in xmm1 */
        movss   xmm2, [rip + v3]
        movss   xmm1, xmm0
        addss   xmm0, xmm2
        mulss   xmm0, xmm0
        addss   xmm0, xmm2
        subss   xmm1, xmm2
        mulss   xmm1, xmm1
        addss   xmm1, xmm2
        divss   xmm0, xmm1
        test    rcx, rcx
        jz      .Lexp_squared
.Lexp_square:
        mulss   xmm0, xmm0
        dec     rcx
        jnz     .Lexp_square
.Lexp_squared:
        test    al, al
        jz      .Lexp_ret
        movss   xmm1, [rip + v1]
        divss   xmm1, xmm0
        movss   xmm0, xmm1
.Lexp_ret:
        ret