#include "u.h"
#include "lib.h"
#include "mem.h"
#include "dat.h"
#include "fns.h"
#include "io.h"
#include "ip.h"

extern int debugload;

/* all-ones Ethernet address; bootpopen uses it as the BOOTP destination */
uchar broadcast[Eaddrlen] = {
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
};

static ushort tftpport = 5000;	/* local UDP port for the next TFTP transfer */
static int Id = 1;		/* IP id field of the next packet sent */
static Netaddr myaddr;		/* our address; filled in by bootpopen */
static Netaddr server;		/* peer we are currently talking to */

/* one TFTP packet: 4 header bytes (opcode, block/error code) plus payload */
typedef struct {
	uchar	header[4];
	uchar	data[Segsize];
} Tftp;
static Tftp tftpb;		/* shared receive buffer for tftpopen/tftpread */

/* store a 16-bit value at ptr, most significant byte first */
static void
hnputs(uchar *ptr, ushort val)
{
	ptr[0] = val>>8;
	ptr[1] = val;
}

/* store a 32-bit value at ptr, most significant byte first */
static void
hnputl(uchar *ptr, ulong val)
{
	ptr[0] = val>>24;
	ptr[1] = val>>16;
	ptr[2] = val>>8;
	ptr[3] = val;
}

/*
 * fetch a 32-bit value stored most significant byte first.
 * the top byte is widened before shifting so that a value with
 * the high bit set is never shifted into the sign bit of an int.
 */
static ulong
nhgetl(uchar *ptr)
{
	return (((ulong)ptr[0]<<24) | (ptr[1]<<16) | (ptr[2]<<8) | ptr[3]);
}

/* fetch a 16-bit value stored most significant byte first */
static ushort
nhgets(uchar *ptr)
{
	return ((ptr[0]<<8) | ptr[1]);
}

/* run-time byte-order probe: the first byte of a short holding 1 */
static	short	endian	= 1;
static	char*	aendian	= (char*)&endian;
#define	LITTLE	*aendian

/*
 * ones-complement checksum over len bytes starting at a.
 * an odd starting address is handled by consuming one byte first
 * and remembering (in x) that the remaining 16-bit loads are
 * byte-swapped relative to the data; the bulk is summed with native
 * ushort loads and assigned to the high or low partial sum depending
 * on LITTLE and x.  returns the complement of the folded 16-bit sum.
 */
static ushort
ptcl_csum(void *a, int len)
{
	uchar *addr;
	ulong t1, t2;
	ulong losum, hisum, mdsum, x;

	addr = a;
	losum = 0;
	hisum = 0;
	mdsum = 0;

	x = 0;
	if((ulong)addr & 1) {
		if(len) {
			hisum += addr[0];
			len--;
			addr++;
		}
		x = 1;
	}
	/* unrolled: eight 16-bit words per iteration */
	while(len >= 16) {
		t1 = *(ushort*)(addr+0);
		t2 = *(ushort*)(addr+2);	mdsum += t1;
		t1 = *(ushort*)(addr+4);	mdsum += t2;
		t2 = *(ushort*)(addr+6);	mdsum += t1;
		t1 = *(ushort*)(addr+8);	mdsum += t2;
		t2 = *(ushort*)(addr+10);	mdsum += t1;
		t1 = *(ushort*)(addr+12);	mdsum += t2;
		t2 = *(ushort*)(addr+14);	mdsum += t1;
		mdsum += t2;
		len -= 16;
		addr += 16;
	}
	while(len >= 2) {
		mdsum += *(ushort*)addr;
		len -= 2;
		addr += 2;
	}
	if(x) {
		if(len)
			losum += addr[0];
		if(LITTLE)
			losum += mdsum;
		else
			hisum += mdsum;
	} else {
		if(len)
			hisum += addr[0];
		if(LITTLE)
			hisum += mdsum;
		else
			losum += mdsum;
	}

	/* hisum holds high-order bytes: swap it into place, then fold carries */
	losum += hisum >> 8;
	losum += (hisum & 0xff) << 8;
	while(hisum = losum>>16)
		losum = hisum + (losum & 0xffff);

	return ~losum;
}

/*
 * checksum of the IP header starting at addr (the version/header-length
 * byte).  the header length is taken from the low nibble of that byte.
 * returns the complemented 16-bit sum, so a header that already carries
 * a correct checksum yields 0.
 */
static ushort
ip_csum(uchar *addr)
{
	int len;
	ulong sum = 0;

	len = (addr[0]&0xf)<<2;

	while(len > 0) {
		sum += addr[0]<<8 | addr[1] ;
		len -= 2;
		addr += 2;
	}

	sum = (sum & 0xffff) + (sum >> 16);
	sum = (sum & 0xffff) + (sum >> 16);
	return (sum^0xffff);
}

enum {
	/* this is only true of IPv4, but we're not doing v6 yet */
	Min_udp_payload = ETHERMINTU - ETHERHDRSIZE - UDP_HDRSIZE,
};

/*
 * build an Ethernet/IP/UDP packet carrying dlen bytes of data from
 * myaddr to *a and transmit it on controller ctlrno.
 * the destination MAC, IP address and port all come from *a.
 */
static void
udpsend(int ctlrno, Netaddr *a, void *data, int dlen)
{
	char payload[ETHERMAXTU];
	Udphdr *uh;
	Etherhdr *ip;
	Etherpkt pkt;
	int len, ptcllen;

	/*
	 * if packet is too short, make it longer rather than relying
	 * on ethernet interface or lower layers to pad it.
	 * the pad bytes are zeroed so that no stale stack contents
	 * are sent after the caller's data.
	 */
	if (dlen < Min_udp_payload) {
		memset(payload, 0, sizeof(payload));
		memmove(payload, data, dlen);
		data = payload;
		dlen = Min_udp_payload;
	}

	uh = (Udphdr*)&pkt;

	memset(uh, 0, sizeof(Etherpkt));
	/* payload goes immediately after the last UDP header field */
	memmove(uh->udpcksum+sizeof(uh->udpcksum), data, dlen);

	/*
	 * UDP portion
	 * the fields from ttl through udpplen are filled in as the
	 * pseudo-header that the UDP checksum is computed over.
	 */
	ptcllen = dlen + (UDP_HDRSIZE-UDP_PHDRSIZE);
	uh->ttl = 0;
	uh->udpproto = IP_UDPPROTO;
	uh->frag[0] = 0;
	uh->frag[1] = 0;
	hnputs(uh->udpplen, ptcllen);
	hnputl(uh->udpsrc, myaddr.ip);
	hnputs(uh->udpsport, myaddr.port);
	hnputl(uh->udpdst, a->ip);
	hnputs(uh->udpdport, a->port);
	hnputs(uh->udplen, ptcllen);
	uh->udpcksum[0] = 0;
	uh->udpcksum[1] = 0;
	dlen = (dlen+1)&~1;	/* round up to an even number of bytes */
	hnputs(uh->udpcksum, ptcl_csum(&uh->ttl, dlen+UDP_HDRSIZE));

	/*
	 * IP portion
	 * overwrites the pseudo-header fields (ttl, length) with real values.
	 */
	ip = (Etherhdr*)&pkt;
	len = UDP_EHSIZE+UDP_HDRSIZE+dlen;		/* non-descriptive names */
	ip->vihl = IP_VER|IP_HLEN;
	ip->tos = 0;
	ip->ttl = 255;
	hnputs(ip->length, len-ETHER_HDR);
	hnputs(ip->id, Id++);
	ip->frag[0] = 0;
	ip->frag[1] = 0;
	ip->cksum[0] = 0;
	ip->cksum[1] = 0;
	hnputs(ip->cksum, ip_csum(&ip->vihl));

	/*
	 * Ethernet MAC portion
	 */
	hnputs(ip->type, ET_IP);
	memmove(ip->d, a->ea, sizeof(ip->d));

	if(debug) {
		print("udpsend ");
	}
	ethertxpkt(ctlrno, &pkt, len, Timeout);
}

/*
 * send a TFTP ERROR packet with the given code and message to *a;
 * if report is set, also print the error locally.
 * msg is truncated if it does not fit the 128-byte packet buffer.
 */
static void
nak(int ctlrno, Netaddr *a, int code, char *msg, int report)
{
	int n;
	char buf[128];

	buf[0] = 0;
	buf[1] = Tftp_ERROR;
	buf[2] = 0;
	buf[3] = code;
	/* bounded copy: leave room for the header and the terminating NUL */
	strncpy(buf+4, msg, sizeof(buf)-4-1);
	buf[sizeof(buf)-1] = 0;
	n = strlen(buf+4) + 4 + 1;
	udpsend(ctlrno, a, buf, n);
	if(report)
		print("\ntftp: error(%d): %s\n", code, msg);
}

/*
 * wait for a UDP packet on controller ctlrno and copy its payload
 * (at most dlen bytes) to data.
 * packets are accepted only if a->port is 0 or matches the source
 * port, and a->ip is Bcastip or matches the source address.
 * on success *a is updated with the sender's ip, port and MAC and
 * the payload length is returned.  returns 0 on timeout or when a
 * packet fails the UDP checksum.
 * the first wait uses a timeout of 1000, later ones use Timeout
 * (presumably milliseconds, as they are compared with TK2MS(m->ticks)).
 */
static int
udprecv(int ctlrno, Netaddr *a, void *data, int dlen)
{
	int n, len;
	ushort csm;
	Udphdr *h;
	ulong addr, timo;
	Etherpkt pkt;
	static int rxactive;

	if(rxactive == 0)
		timo = 1000;
	else
		timo = Timeout;
	timo += TK2MS(m->ticks);
	while(timo > TK2MS(m->ticks)){
		n = etherrxpkt(ctlrno, &pkt, timo-TK2MS(m->ticks));
		if(n <= 0)
			continue;

		h = (Udphdr*)&pkt;
		if(debug)
			print("udprecv %E to %E...\n", h->s, h->d);

		if(nhgets(h->type) != ET_IP) {
			if(debug)
				print("not ip...");
			continue;
		}

		if(ip_csum(&h->vihl)) {
			print("ip chksum error\n");
			continue;
		}
		if(h->vihl != (IP_VER|IP_HLEN)) {
			print("ip bad vers/hlen\n");
			continue;
		}

		if(h->udpproto != IP_UDPPROTO) {
			if(debug)
				print("not udp (%d)...", h->udpproto);
			continue;
		}

		if(debug)
			print("okay udp...");

		h->ttl = 0;
		len = nhgets(h->udplen);
		/*
		 * the length comes off the wire: refuse anything shorter
		 * than a UDP header (the payload length below would go
		 * negative) or extending past the receive buffer (the
		 * checksum below would read beyond pkt).
		 */
		if(len < UDP_HDRSIZE-UDP_PHDRSIZE
		|| UDP_EHSIZE+UDP_PHDRSIZE+len > (int)sizeof(pkt)) {
			print("udp: bad length %d\n", len);
			continue;
		}
		/* rebuild the pseudo-header in place for the checksum */
		hnputs(h->udpplen, len);

		/* a zero checksum field means the sender did not compute one */
		if(nhgets(h->udpcksum)) {
			csm = ptcl_csum(&h->ttl, len+UDP_PHDRSIZE);
			if(csm != 0) {
				print("udp chksum error csum #%4ux len %d\n",
					csm, n);
				break;
			}
		}

		if(a->port != 0 && nhgets(h->udpsport) != a->port) {
			if(debug)
				print("udpport %ux not %ux\n",
					nhgets(h->udpsport), a->port);
			continue;
		}

		addr = nhgetl(h->udpsrc);
		if(a->ip != Bcastip && a->ip != addr) {
			if(debug)
				print("bad ip %lux not %lux\n", addr, a->ip);
			continue;
		}

		len -= UDP_HDRSIZE-UDP_PHDRSIZE;	/* payload bytes only */
		if(len > dlen) {
			print("udp: packet too big: %d > %d; from addr %E\n",
				len, dlen, h->udpsrc);
			continue;
		}

		memmove(data, h->udpcksum+sizeof(h->udpcksum), len);
		a->ip = addr;
		a->port = nhgets(h->udpsport);
		memmove(a->ea, pkt.s, sizeof(a->ea));

		rxactive = 1;
		return len;
	}

	return 0;
}

static int tftpblockno;		/* number of the last DATA block accepted */

/*
 * format of a request packet, from the RFC:
 *
            2 bytes     string    1 byte     string   1 byte
            ------------------------------------------------
           | Opcode |  Filename  |   0  |    Mode    |   0  |
            ------------------------------------------------
 */

/*
 * send a TFTP read request for name to *a (up to 5 attempts) and wait
 * for the first DATA block, which is left in *tftp.
 * the reply may come from any port; udprecv records it in a->port.
 * returns the number of data bytes in the first block, or -1 on a
 * TFTP error, a wrong block number, or no usable reply.
 */
static int
tftpopen(int ctlrno, Netaddr *a, char *name, Tftp *tftp)
{
	int i, len, rlen, oport;
	char buf[Segsize+2];

	buf[0] = 0;
	buf[1] = Tftp_READ;
	len = 2 + sprint(buf+2, "%s", name) + 1;
	len += sprint(buf+len, "octet") + 1;

	oport = a->port;
	for(i = 0; i < 5; i++){
		a->port = oport;	/* request goes to the original port */
		udpsend(ctlrno, a, buf, len);
		a->port = 0;		/* accept the reply from any port */
		if((rlen = udprecv(ctlrno, a, tftp, sizeof(Tftp))) < sizeof(tftp->header))
			continue;

		switch((tftp->header[0]<<8)|tftp->header[1]){

		case Tftp_ERROR:
			print("tftpopen: error (%d): %s\n",
				(tftp->header[2]<<8)|tftp->header[3], (char*)tftp->data);
			return -1;

		case Tftp_DATA:
			tftpblockno = 1;
			len = (tftp->header[2]<<8)|tftp->header[3];
			if(len != tftpblockno){
				print("tftpopen: block error: %d\n", len);
				nak(ctlrno, a, 1, "block error", 0);
				return -1;
			}
			rlen -= sizeof(tftp->header);
			if(rlen < Segsize){
				/* ACK now, in case we don't later */
				buf[0] = 0;
				buf[1] = Tftp_ACK;
				buf[2] = tftpblockno>>8;
				buf[3] = tftpblockno;
				udpsend(ctlrno, a, buf, sizeof(tftp->header));
			}
			return rlen;
		}
	}

	print("tftpopen: failed to connect to server\n");
	return -1;
}

/*
 * acknowledge block tftpblockno and read the next DATA block into
 * *tftp (up to 10 attempts).  dlen is the size of the data area.
 * a block shorter than dlen is the last one and is acknowledged
 * immediately.
 * returns the number of data bytes received, or -1 on failure.
 */
static int
tftpread(int ctlrno, Netaddr *a, Tftp *tftp, int dlen)
{
	uchar buf[4];
	int try, blockno, len;

	dlen += sizeof(tftp->header);

	for(try = 0; try < 10; try++) {
		buf[0] = 0;
		buf[1] = Tftp_ACK;
		buf[2] = tftpblockno>>8;
		buf[3] = tftpblockno;

		udpsend(ctlrno, a, buf, sizeof(buf));
		len = udprecv(ctlrno, a, tftp, dlen);
		if(len <= sizeof(tftp->header)){
			if(debug)
				print("tftpread: too short %d <= %d\n",
					len, sizeof(tftp->header));
			continue;
		}
		blockno = (tftp->header[2]<<8)|tftp->header[3];
		if(blockno <= tftpblockno){
			/* duplicate of a block we already have; ack again */
			if(debug)
				print("tftpread: blkno %d <= %d\n",
					blockno, tftpblockno);
			continue;
		}

		if(blockno == tftpblockno+1) {
			tftpblockno++;
			if(len < dlen) {	/* last packet; send final ack */
				/*
				 * the ACK carries the number of the block
				 * just received, as in tftpopen.
				 */
				buf[0] = 0;
				buf[1] = Tftp_ACK;
				buf[2] = tftpblockno>>8;
				buf[3] = tftpblockno;
				udpsend(ctlrno, a, buf, sizeof(buf));
			}
			return len-sizeof(tftp->header);
		}
		print("tftpread: block error: %d, expected %d\n",
			blockno, tftpblockno+1);
	}

	return -1;
}

/*
 * broadcast a BOOTP request on controller ctlrno (up to 10 attempts)
 * and store the matching reply in *rep.
 * file may be nil, "filename", or "sysname!filename"; when a sysname
 * is given only replies whose sname matches are accepted.
 * file is copied into a 128-byte buffer and truncated if longer.
 * if dotftpopen is set, also start a TFTP transfer of the file (the
 * one named in the reply when none was requested) into tftpb.
 * returns -1 on failure; otherwise 0 when dotftpopen is clear, or the
 * size of the first TFTP data block.
 */
static int
bootpopen(int ctlrno, char *file, Bootp *rep, int dotftpopen)
{
	Bootp req;
	int i, n;
	uchar *ea;
	char name[128], *filename, *sysname;

	if (debugload)
		print("bootpopen: ether%d!%s...", ctlrno, file);
	if((ea = etheraddr(ctlrno)) == 0){
		print("invalid ctlrno %d\n", ctlrno);
		return -1;
	}

	filename = 0;
	sysname = 0;
	if(file && *file){
		/* bounded copy: never write past the end of name */
		strncpy(name, file, sizeof(name)-1);
		name[sizeof(name)-1] = 0;
		if(filename = strchr(name, '!')){
			sysname = name;
			*filename++ = 0;
		}
		else
			filename = name;
	}

	memset(&req, 0, sizeof(req));
	req.op = Bootrequest;
	req.htype = 1;			/* ethernet */
	req.hlen = Eaddrlen;		/* ethernet */
	memmove(req.chaddr, ea, Eaddrlen);
	if(filename != nil)
		strncpy(req.file, filename, sizeof(req.file));
	if(sysname != nil)
		strncpy(req.sname, sysname, sizeof(req.sname));

	myaddr.ip = 0;
	myaddr.port = BPportsrc;
	memmove(myaddr.ea, ea, Eaddrlen);

	etherrxflush(ctlrno);
	for(i = 0; i < 10; i++) {
		/* udprecv overwrites server, so reset it for every attempt */
		server.ip = Bcastip;
		server.port = BPportdst;
		memmove(server.ea, broadcast, sizeof(server.ea));
		udpsend(ctlrno, &server, &req, sizeof(req));
		if(udprecv(ctlrno, &server, rep, sizeof(*rep)) <= 0)
			continue;
		if(memcmp(req.chaddr, rep->chaddr, Eaddrlen))
			continue;
		if(rep->htype != 1 || rep->hlen != Eaddrlen)
			continue;
		if(sysname == 0 || strcmp(sysname, rep->sname) == 0)
			break;
	}
	if(i >= 10) {
		print("bootp timed out\n");
		return -1;
	}

	if(!dotftpopen)
		return 0;

	if(filename == 0 || *filename == 0){
		/* refuse to use the reply's file when it is "/386/9pxeload" */
		if(strcmp(rep->file, "/386/9pxeload") == 0)
			return -1;
		filename = rep->file;
	}

	if(rep->sname[0] != '\0')
		print("%s ", rep->sname);
	print("(%d.%d.%d.%d!%d): %s\n",
		rep->siaddr[0],
		rep->siaddr[1],
		rep->siaddr[2],
		rep->siaddr[3],
		server.port,
		filename);

	myaddr.ip = nhgetl(rep->yiaddr);
	myaddr.port = tftpport++;
	server.ip = nhgetl(rep->siaddr);
	server.port = TFTPport;

	if((n = tftpopen(ctlrno, &server, filename, &tftpb)) < 0)
		return -1;

	return n;
}

/*
 * fetch file via BOOTP/TFTP on controller ctlrno, feeding each block
 * to bootpass, then ask bootpass to boot.
 * always returns -1: the value is only seen by the caller if the
 * final bootpass call returns (presumably meaning the boot did not
 * happen; confirm against bootpass).
 */
int
bootpboot(int ctlrno, char *file, Boot *b)
{
	int n;
	Bootp rep;

	if((n = bootpopen(ctlrno, file, &rep, 1)) < 0)
		return -1;

	/*
	 * sizes are compared as ints so that the -1 error return of
	 * tftpread ends the loop whether or not sizeof is unsigned.
	 */
	while(bootpass(b, tftpb.data, n) == MORE){
		n = tftpread(ctlrno, &server, &tftpb, sizeof(tftpb.data));
		if(n < (int)sizeof(tftpb.data))
			break;
	}

	if(0 < n && n < (int)sizeof(tftpb.data))	/* got to end of file */
		bootpass(b, tftpb.data, n);
	else
		nak(ctlrno, &server, 3, "ok", 0);	/* tftpclose to abort transfer */
	bootpass(b, nil, 0);	/* boot if possible */
	return -1;
}

#include "fs.h"
#define INIPATHLEN 64 static struct { Fs fs; char ini[INIPATHLEN]; } pxether[MaxEther]; static vlong pxediskseek(Fs*, vlong) { return -1LL; } static long pxediskread(Fs*, void*, long) { return -1; } static long pxeread(File* f, void* va, long len) { int n; Bootp rep; char *p, *v; if((n = bootpopen(f->fs->dev, pxether[f->fs->dev].ini, &rep, 1)) < 0) return -1; p = v = va; while(n > 0) { if((p-v)+n > len) n = len - (p-v); memmove(p, tftpb.data, n); p += n; if(n != Segsize) break; if((n = tftpread(f->fs->dev, &server, &tftpb, sizeof(tftpb.data))) < 0) return -1; } return p-v; } static int pxewalk(File* f, char* name) { Bootp rep; char *ini; switch(f->walked){ default: return -1; case 0: if(strcmp(name, "cfg") == 0){ f->walked = 1; return 1; } break; case 1: if(strcmp(name, "pxe") == 0){ f->walked = 2; return 1; } break; case 2: if(strcmp(name, "%E") != 0) break; f->walked = 3; if(bootpopen(f->fs->dev, nil, &rep, 0) < 0) return 0; ini = pxether[f->fs->dev].ini; /* use our mac address instead of relying on a bootp answer */ snprint(ini, INIPATHLEN, "/cfg/pxe/%E", (uchar *)myaddr.ea); f->path = ini; return 1; } return 0; } void* pxegetfspart(int ctlrno, char* part, int) { if(!pxe) return nil; if(strcmp(part, "*") != 0) return nil; if(ctlrno >= MaxEther) return nil; if(iniread && getconf("*pxeini") != nil) return nil; pxether[ctlrno].fs.dev = ctlrno; pxether[ctlrno].fs.diskread = pxediskread; pxether[ctlrno].fs.diskseek = pxediskseek; pxether[ctlrno].fs.read = pxeread; pxether[ctlrno].fs.walk = pxewalk; pxether[ctlrno].fs.root.fs = &pxether[ctlrno].fs; pxether[ctlrno].fs.root.walked = 0; return &pxether[ctlrno].fs; }