// Malloc verification
//
// Traces calls to malloc and free in a debugged process and reports the
// blocks that were allocated but are neither freed nor referenced.
// Typical use: go() to run the program, then refs() (or leak()) to report.

// The code below reads registers R1 and R31 and offsets breakpoints by one
// 4-byte instruction, so loading is refused unless objtype is "mips".
if objtype != "mips" then
	error("malloc checker for MIPS\n");

// go: start a new process and run it under malloc/free tracing.
//
// Breakpoints are set 4 bytes past the entry of malloc, free and _exits.
// Each time the process stops, the stop address selects the action:
// record an allocation, record a free, or -- for any other stop, such as
// the _exits breakpoint -- stop the process and return an empty list.
//
// Sets the globals xmalloc and xfree, and resets the globals blocklist,
// framelist, stacks and nstack, so earlier tracing results are discarded.
defn go()
{
	local pc;		// stop address; local so a user's global pc is not clobbered

	new();			// Make a new process

	xmalloc = malloc + 4;	// Delay slot filler may pull in first
	xfree = free + 4;	// instruction

	bpset(xmalloc);
	bpset(xfree);
	bpset(_exits+4);

	blocklist = {};		// List of allocated blocks
	framelist = {};		// Frames they were allocated from
	stacks = {};		// The stacks
	nstack = 0;

	while 1 do {
		cont();
		pc = *PC;
		if pc == xmalloc then
			gotmalloc();
		else if pc == xfree then
			gotfree();
		else {
			// Neither malloc nor free: stop tracing here.
			pstop(pid);
			return {};
		}
	}
}

// stkcmp: compare two stack traces frame by frame.
//   s1, s2: lists of frames; only elements [0] and [1] of each frame are
//           compared, any further elements are ignored.
//   returns: 1 if both traces have the same length and every pair of
//            frames agrees on [0] and [1], otherwise 0.
defn stkcmp(s1, s2)
{
	local fa, fb;

	while s1 do {
		fa = head s1;
		fb = head s2;

		// s2 ran out before s1 did
		if fb == {} then
			return 0;
		if fa[0] != fb[0] then
			return 0;
		if fa[1] != fb[1] then
			return 0;

		s1 = tail s1;
		s2 = tail s2;
	}

	// s1 is exhausted; any frames left in s2 make the traces differ
	if s2 != {} then
		return 0;

	return 1;
}

// stkindex: map a stack trace to its position in the global list stacks.
//   s: stack trace to look up.
//   returns: the index of the first entry of stacks that stkcmp() reports
//            equal to s; if there is none, s is appended to stacks, nstack
//            is incremented, and the index of the new entry is returned.
defn stkindex(s)
{
	local cursor, idx;

	idx = 0;
	cursor = stacks;
	while cursor do {
		if stkcmp(head cursor, s) then
			return idx;

		idx = idx+1;
		cursor = tail cursor;
	}

	// Not seen before: remember it under the next free index.
	idx = nstack;
	stacks = append stacks, s;
	nstack = nstack+1;
	return idx;
}

// gotmalloc: handle a stop at the malloc breakpoint.
//
// Captures the stack trace at the call, runs the process until it reaches
// the address held in R31 at entry (the caller's return address), and then
// records the value found in R1 together with the index of the allocating
// stack. A zero result (failed allocation) is not recorded: there is no
// block to leak, and pentry() would otherwise read memory relative to
// address 0. This matches gotfree(), which also ignores a nil pointer.
defn gotmalloc()
{
	local raddr, stk, newblk;

	raddr = *R31;
	bpset(raddr);
	stk = strace(*PC, *SP, *R31);	// Save a copy of the stack up to the malloc
	cont();				// Continue until malloc returns
	bpdel(raddr);

	newblk = *R1;
	if newblk != 0 then {
		// blocklist and framelist are kept in step: entry i of one
		// describes entry i of the other.
		blocklist = append blocklist, newblk;
		framelist = append framelist, stkindex(stk);
	}
}

// pentry: total the sizes of the tracked blocks allocated from stack s.
//   s: stack index as stored in framelist (see stkindex()).
//   returns: the sum of the sizes of all blocks in blocklist whose
//            framelist entry equals s; 0 if there are none.
defn pentry(s)
{
	// flst and addr are declared here so they no longer leak into (and
	// overwrite) global variables of the same name.
	local sno, blst, flst, addr, gotone, size;

	gotone = 0;
	blst = blocklist;
	flst = framelist;	// walked in step with blst

	while blst do {
		sno = head flst;
		if sno == s then {
			addr = fmt(head blst, 'X');
			// The word 12 bytes before the block is used as a
			// power-of-two exponent for the block size.
			// NOTE(review): presumably the allocator's size-class
			// header -- confirm against the target's malloc.
			size = 1<<*(addr-12);
//			print(addr, " of ", fmt(size, 'D'), "bytes\n");
			gotone = gotone + size;
		}
		blst = tail blst;
		flst = tail flst;
	}
	return gotone;
}

// leak: report every recorded allocation stack that still owns blocks.
//
// For each stack index below nstack, asks pentry() for the total size
// still attributed to it and, when that is non-zero, prints the total
// followed by the stack trace via mstack().
defn leak()
{
	local idx, lost;

	idx = 0;
	while idx < nstack do {
		lost = pentry(idx);
		if lost then {
			print("Lost a total of ", fmt(lost, 'D'), " bytes from:\n");
			mstack(stacks[idx]);
		}
		idx = idx+1;
	}
}

// scanmem: stop tracking every block that is referenced from a memory range.
//   start: first address to examine.
//   end:   address at which to stop (not examined).
// Each value read from the range is looked up in blocklist; on a hit the
// matching entries are removed from both blocklist and framelist.
defn scanmem(start, end)
{
	local cursor, hit;

	// Give the address the 'X' format before stepping through memory with ++.
	cursor = fmt(start, 'X');
	while cursor < end do {
		hit = match(*cursor, blocklist);
		if hit >= 0 then {
			// Remove the same position from both lists so they stay in step.
			blocklist = delete blocklist, hit;
			framelist = delete framelist, hit;
		}
		cursor++;
	}
}

// refs: drop every tracked block that is still referenced from memory, then
// report what is left.
//
// Three ranges are scanned with scanmem(), each announced by a progress
// label: bdata..edata, edata..*bloc and *SP..0x80000000. leak() is then
// called to print the blocks that no scanned word pointed to.
defn refs()			// Delete blocks with references
{
	print("data...");
	scanmem(bdata, edata);

	// NOTE(review): *bloc is presumably the current end of the data
	// segment, so this range would cover the heap as well -- confirm.
	print("bss...");
	scanmem(edata, *bloc);

	// NOTE(review): 0x80000000 looks like the upper bound of the user
	// stack on this target -- confirm.
	print("stack...");
	scanmem(*SP, 0x80000000);

	print("\n");
	leak();	
}

// mstack: print a saved stack trace, one frame per entry of stk.
//   stk: list of frames; frame[0] is printed as the function's symbolic
//        address with its source file and line, frame[1] as the symbolic
//        address it was called from, followed by pfl(frame[1]).
defn mstack(stk)
{
	local frame;		// local so the global namespace is left untouched

	while stk do {
		frame = head stk;
		print("\t", fmt(frame[0], 'a'), "() ");
		print(pcfile(frame[0]), ":", pcline(frame[0]));
		// Leading space keeps "file:line" from running into "called from".
		print(" called from ", fmt(frame[1], 'a'), " ");
		pfl(frame[1]);
		stk = tail stk;
	}
}

// gotfree: handle a stop at the free breakpoint.
//
// Reads the pointer being freed from R1. A nil pointer is ignored. A
// pointer found in blocklist is removed from blocklist and framelist;
// any other pointer is reported together with the current stack trace.
defn gotfree()
{
	local ptr, slot;

	ptr = *R1;			// We know the first parameter is register R1
	if ptr == 0 then
		return {};

	slot = match(ptr, blocklist);
	if slot >= 0 then {
		// Known block: forget it and the stack it was allocated from.
		blocklist = delete blocklist, slot;
		framelist = delete framelist, slot;
	} else {
		print("free(", ptr, ") called with bad pointer:\n");
		stk();
		return {};
	}
}

// stopped: hook that takes a process id, does nothing, and returns an
// empty list.
// NOTE(review): presumably this replaces acid's default stopped() so that
// the many breakpoint stops during go() do not each print a status
// report -- confirm against the acid library.
defn stopped(pid)
{
	return {};
}

// Top-level statement: prints this library's name when the file is loaded.
print("/sys/lib/acid/malloc");
