// 
// mp3scalpel 0.2a by Vassil Roussev
// DFRWS'07 Challenge Review Version (not ready for public release yet)
// 
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>

// Values stored in BOOL fields.
#define TRUE  1
#define FALSE 0

// Types
// BOOL holds TRUE/FALSE; UCHAR is the element type of the mapped image
// buffer; UINT is used for byte offsets, sizes and block numbers, so image
// offsets are limited to the range of unsigned int.
typedef unsigned char BOOL;
typedef unsigned char UCHAR;
typedef unsigned int  UINT;

// Decoded fields of one MPEG audio frame header, filled in by parse_mp3frame().
typedef struct {
  UINT  length;      // bytes from this header to the next frame's header
  UCHAR version;     // MPEG version: 1 or 2 (MPEG-2.5 is not recognized)
  UCHAR layer;       // layer 1, 2 or 3
  int   crc16;       // 16-bit CRC following the header, or -1 when unprotected
  UINT  bit_rate;    // bits per second (table value * 1000)
  UINT  sample_rate; // Hz
  UCHAR pad;         // 1 when the padding bit is set, else 0
  UCHAR eof;         // NOTE(review): not read or written by any code visible here
} mp3header;

// One carved run of back-to-back matching frames, built by get_mp3chunk().
// Block numbers are in units of the global bl_size.
typedef struct {
  UINT id;          // 1-based sequence number assigned in main()
  UINT bit_rate;    // bit rate of the run's first frame, bits per second
  UINT start_block; // block holding the first byte of the run (or of a prepended ID3v2 tag)
  UINT head;        // offset of the first frame within start_block; 0 when an ID3v2 tag was prepended
  UINT end_block;   // exclusive upper block bound used when carving; the block after the run
  UINT tail;        // (end offset of the run, incl. any ID3v1 tag) % bl_size
  UINT frame_rem;   // expected 'head' of the continuation fragment; matched in make_file()
  UINT frame_len;   // length in bytes of the run's last frame
  BOOL start_flag;  // run looks like the start of a file (block-aligned or ID3v2-prefixed)
  BOOL end_flag;    // run is followed by an ID3v1 "TAG", i.e. it ends a file
  BOOL tag_flag;    // the 128-byte ID3v1 tag was appended to the run
} mp3chunk;


// Constants
// Upper bound on the number of chunks main() collects (size of chunks[]).
#define MAX_CHUNKS 16384
// BIT_MASKS[k] has the low k+1 bits set; BITS[k] is the single bit 1<<k.
// NOTE(review): neither table is referenced by the code visible in this file.
const UINT BIT_MASKS[] = { 
  0x01, 0x03, 0x07, 0x0F, 0x1F, 0x3F, 0x7F, 0xFF, 
  0x01FF, 0x03FF, 0x07FF, 0x0FFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF,
  0x01FFFF, 0x03FFFF, 0x07FFFF, 0x0FFFFF, 0x1FFFFF, 0x3FFFFF, 0x7FFFFF, 0xFFFFFF, 
  0x01FFFFFF, 0x03FFFFFF, 0x07FFFFFF, 0x0FFFFFFF, 0x1FFFFFFF, 0x3FFFFFFF, 0x7FFFFFFF, 0xFFFFFFFF
};
const UCHAR BITS[] = { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};
// -------

// Function prototypes
// getID3size() is called from main() before its definition; without this
// prototype the call is an implicit declaration (invalid since C99).
int findID3v2( UINT curr_bl_pos, UINT file_size);
int getID3size( UINT id3_pos);
// NOTE(review): no definition of parseID3v2() is visible in this file.
int parseID3v2( UINT curr_pos);
int parse_mp3frame( UINT curr_pos, mp3header *mp3h);
mp3chunk *get_mp3chunk( UINT curr_pos, UINT file_size, UINT min_frames, 
			UCHAR version, UCHAR layer, int b_rate, int s_rate);
void make_all_files( UINT chunk_cnt, UINT max_use_cnt);
void make_file( UINT start, UINT chunk_cnt, UINT max_use_cnt, UINT *trace, UINT level);
void carve_trace( char *fname, UINT *trace, UINT trace_len);
void carve_chunk( char *fname, mp3chunk *chunk);
void glue_chunks( char *prefix, UINT *trace, UINT trace_len);
// -------

// Globals
// Read-only mmap() of the image file; main() unmaps it before carving output.
UCHAR *buffer;
// Block size in bytes, from argv[1].
UINT bl_size;
// Chunks found by main(), indexed 0..chunk_cnt-1.
mp3chunk *chunks[MAX_CHUNKS];
// Per-chunk use counter for make_all_files()/make_file().
UINT chunk_stat[MAX_CHUNKS];
// -------

int main( int argc, char **argv) {
  if( argc < 8) {
    printf( "Use: %s <block_size> <min_frames> <ver> <layer> <BR> <SR> <image_file>\n", argv[0]);
    exit(-1);
  }

  FILE *in;
  struct stat file_stat;
  UINT id3_cnt = 0;
  bl_size = atoi( argv[1]);
  UINT min_frames = atoi( argv[2]) % bl_size;
  UINT version = atoi( argv[3]);
  UINT layer = atoi( argv[4]);
  UINT bit_rate = atoi( argv[5]);
  UINT sample_rate = atoi( argv[6]);
  char *fname = argv[7];
  UINT  file_size;
  int i;

  memset( chunks, 0, sizeof( chunks));
  memset( chunk_stat, 0, sizeof( chunk_stat));

  UINT curr_pos, curr_bl_pos, piece_pos;
  
  // --- open image file ---
  if( !( in = fopen( fname,"r"))) {
    fprintf( stderr, "WARNING: Could not open file %s--skipping\n", fname);
    return -1;
  }
  if( fstat( fileno( in), &file_stat)) {
    fprintf( stderr, "WARNING: Could not stat file %s--skipping.\n", fname);
    return -1;
  }
  file_size = file_stat.st_size;
  buffer = mmap( 0, file_size, PROT_READ, MAP_FILE | MAP_PRIVATE, fileno( in), 0);
  if( buffer == MAP_FAILED) {
    fprintf( stderr, "PANIC: mmap() failed %d--bailing out.\n", errno);
    return -1;
  }
  // ------
  int match = 0, chunk_size = 0, chunk_cnt = 0;
  curr_pos = 0;
  mp3header mp3h;

  int id3pos = findID3v2( 0, file_size);

  mp3chunk *chunk = get_mp3chunk( 0, file_size, min_frames, version, layer, bit_rate, sample_rate);
  while( chunk && chunk_cnt < MAX_CHUNKS) {
    chunk->id = chunk_cnt+1;
    // Prepend id3v2 header (if available)
    UINT id3size = 0;
    curr_pos = chunk->start_block*bl_size + chunk->head;
    while( id3pos < curr_pos && id3pos > -1) {
      UINT id3size = getID3size( id3pos);
      if( id3pos+id3size == curr_pos) {
	chunk->start_block = id3pos/bl_size;
	chunk->head = 0;
	chunk->start_flag = TRUE;
	id3pos = findID3v2( chunk->end_block*bl_size, file_size);
      } else {
	id3size = 0;
	id3pos = findID3v2( id3pos + (id3size/bl_size+1)*bl_size, file_size); 
      }
    }
    // Secondary SOF rule
    if( chunk->head == 0)
      chunk->start_flag = TRUE;
    /*
    printf( "dd if=%s bs=%-4d skip=%-8d count=%-6d > BR_%03d_%03d-A_%08X_%08X-H_%03x-R_%03x-st_%d-end_%d.mp3\n",
	    fname, bl_size, chunk->start_block, chunk->end_block - chunk->start_block, bit_rate/1000, chunk_cnt+1,
	    chunk->start_block*bl_size, chunk->end_block*bl_size, chunk->head, chunk->frame_rem,
	    chunk->start_flag, chunk->end_flag);
    */  
    chunks[chunk_cnt] = chunk;
    chunk_cnt++;
    chunk = get_mp3chunk( chunk->end_block*bl_size, file_size, min_frames, version, layer, bit_rate, sample_rate);
  }
  // --- close & cleanup ---
  munmap( buffer, file_size);
  fclose( in);
  // -------
  for( i=0; i<chunk_cnt; i++)
    carve_chunk( fname, chunks[i]);
  make_all_files( chunk_cnt, 4);
} 

// Emit glue commands for all candidate reconstructions: first every path that
// begins at a chunk flagged as a file start, then paths starting from any other
// chunk that has not yet been used max_use_cnt times.  chunk_stat[] (reset
// here) counts how often make_file() has used each chunk.
void make_all_files( UINT chunk_cnt, UINT max_use_cnt) {
  UINT i;
  UINT *trace;

  if( chunk_cnt == 0 || max_use_cnt == 0)
    return;   // nothing can be emitted

  // make_file() may revisit a chunk up to max_use_cnt times along a single
  // path, so trace[] must hold chunk_cnt*max_use_cnt entries, not chunk_cnt.
  if( max_use_cnt > ((size_t)-1 / sizeof( UINT)) / chunk_cnt) {
    fprintf( stderr, "PANIC: trace buffer too large--bailing out.\n");
    exit(-1);
  }
  trace = calloc( (size_t)chunk_cnt * max_use_cnt, sizeof( UINT));
  if( !trace) {
    fprintf( stderr, "PANIC: calloc() failed--bailing out.\n");
    exit(-1);
  }
  memset( chunk_stat, 0, sizeof( chunk_stat));

  for( i=0; i<chunk_cnt; i++) {
    if( chunks[i]->start_flag)
      make_file( i, chunk_cnt, max_use_cnt, trace, 0);
  }
  for( i=0; i<chunk_cnt; i++) {
    if( !chunks[i]->start_flag && chunk_stat[i] < max_use_cnt)
      make_file( i, chunk_cnt, max_use_cnt, trace, 0);
  }
  free( trace);
}
// Depth-first search over fragment successions.  Records chunks[start] at
// trace[level]; unless that chunk ends a file, recurses into every other
// non-start chunk whose head offset equals this chunk's frame_rem.  Afterwards
// the path trace[0..level] is handed to glue_chunks() (a file-ending chunk
// reached at level 0 emits nothing).  Each chunk is used at most max_use_cnt
// times, as counted in chunk_stat[].
void make_file( UINT start, UINT chunk_cnt, UINT max_use_cnt, UINT *trace, UINT level) {
  char prefix[64];
  UINT cand;

  if( chunk_stat[start] >= max_use_cnt)
    return;
  chunk_stat[start] += 1;
  trace[level] = start;

  if( !chunks[start]->end_flag) {
    for( cand = 0; cand < chunk_cnt; cand++) {
      if( cand == start || chunks[cand]->start_flag)
        continue;
      if( chunk_stat[cand] >= max_use_cnt)
        continue;
      if( chunks[start]->frame_rem == chunks[cand]->head)
        make_file( cand, chunk_cnt, max_use_cnt, trace, level+1);
    }
  } else if( level == 0) {
    return;
  }

  snprintf( prefix, sizeof( prefix), "chunk-%03d_%03d",
            chunks[trace[0]]->bit_rate/1000, chunks[trace[0]]->id);
  glue_chunks( prefix, trace, level+1);
}

// Print the shell commands that carve one chunk out of image `fname` into
// "chunk-BBB_III" (BBB = kbit/s, III = chunk id), copy it to a name reflecting
// its role (file-/head-/tail-/chunk-BBB_III.mp3), and dump the single blocks
// just before and just after it as block-XXXXXXXX for later gluing.
// The chunk's block range is also logged to stderr.
void carve_chunk( char *fname, mp3chunk *chunk) {
  UINT kbps = chunk->bit_rate/1000;
  UINT id = chunk->id;
  UINT first = chunk->start_block;
  UINT past = chunk->end_block;     // block just past the chunk
  const char *role;

  if( !chunk->tag_flag) {
    printf( "dd if=%s bs=%-4d skip=%-8d count=%-6d > chunk-%03d_%03d\n",
            fname, bl_size, first, past - first, kbps, id);
  } else {
    // ID3v1-terminated chunk: whole blocks, then the 'tail' bytes of the last block.
    printf( "dd if=%s bs=%-4d skip=%-8d count=%-6d > chunk-%03d_%03d\n",
            fname, bl_size, first, past-first-1, kbps, id);
    printf( "dd if=%s bs=1 skip=%-8d count=%-6d > rem-%03d-%03d\n",
            fname, (past-1)*bl_size, chunk->tail, kbps, id);
    printf( "cat rem-%03d-%03d >> chunk-%03d_%03d\n", kbps, id, kbps, id);
  }

  fprintf( stderr, "chunk-%03d_%03d: %d-%d (%08X-%08X)\n", kbps, id,
           first, past-1, first, past-1);

  if( chunk->start_flag)
    role = chunk->end_flag ? "file" : "head";
  else
    role = chunk->end_flag ? "tail" : "chunk";
  printf( "cp chunk-%03d_%03d %s-%03d_%03d.mp3\n", kbps, id, role, kbps, id);

  // Preceding block: candidate for the missing piece of a previous fragment.
  if( first > 0 && !chunk->start_flag)
    printf( "dd if=%s bs=%-4d skip=%-8d count=1 > block-%08X\n",
            fname, bl_size, first-1, first-1);

  // Following block.  NOTE(review): for a chunk at the very end of the image
  // this block lies past EOF (the original author flagged this case).
  if( !chunk->end_flag)
    printf( "dd if=%s bs=%-4d skip=%-8d count=1 > block-%08X\n",
            fname, bl_size, past, past);
}
// Print shell commands that append the chunks trace[1..trace_len-1] to the
// partially reconstructed file `prefix`, one chunk per recursion level; the
// final step writes "<name>.mp3".
//
// How the next chunk is attached depends on how many whole blocks the first
// chunk's last frame overruns past the end of its final block:
//   0 -> cat the next chunk directly;
//   1 -> try two candidates for the missing block: the block after the first
//        chunk ("--10--" name) and the block before the next chunk ("--01--"
//        name), and keep gluing BOTH candidates;
//   more -> unsupported, reported on stderr.
// block-XXXXXXXX files are the ones emitted by carve_chunk().
void glue_chunks( char *prefix, UINT *trace, UINT trace_len) {
  char name0[1024], name1[1024];
  int n0, n1;

  if( trace_len < 2)
    return;

  UINT first_end  = chunks[trace[0]]->end_block;
  UINT next_br    = chunks[trace[1]]->bit_rate/1000;
  UINT next_id    = chunks[trace[1]]->id;
  UINT next_start = chunks[trace[1]]->start_block;

  // Signed arithmetic: a last frame that fits in its block yields a negative
  // spill instead of relying on an unsigned wrap converted to int.
  long long spill = (long long)chunks[trace[0]]->frame_len
                    - ((long long)bl_size - (long long)chunks[trace[0]]->tail);
  long long rem_blocks = (spill < 0) ? 0 : spill / (long long)bl_size;

  if( rem_blocks == 0) {
    n0 = snprintf( name0, sizeof( name0), "%s-%03d_%03d", prefix, next_br, next_id);
    if( n0 < 0 || (size_t)n0 >= sizeof( name0)) {
      fprintf( stderr, "glue: file name too long after %s--skipping\n", prefix);
      return;
    }
    if( trace_len > 2) {
      printf( "cat %s chunk-%03d_%03d > %s\n", prefix, next_br, next_id, name0);
      glue_chunks( name0, trace+1, trace_len-1);
    } else {
      printf( "cat %s chunk-%03d_%03d > %s.mp3\n", prefix, next_br, next_id, name0);
    }
  } else if( rem_blocks == 1) {
    n0 = snprintf( name0, sizeof( name0), "%s--10--%03d-%03d", prefix, next_br, next_id);
    n1 = snprintf( name1, sizeof( name1), "%s--01--%03d-%03d", prefix, next_br, next_id);
    if( n0 < 0 || (size_t)n0 >= sizeof( name0) || n1 < 0 || (size_t)n1 >= sizeof( name1)) {
      fprintf( stderr, "glue: file name too long after %s--skipping\n", prefix);
      return;
    }
    // The last step produces the final .mp3 files; intermediate steps do not.
    const char *ext = (trace_len > 2) ? "" : ".mp3";
    printf( "cp %s %s\n", prefix, name0);
    printf( "cp %s %s\n", prefix, name1);
    printf( "cat %s block-%08X chunk-%03d_%03d > %s%s\n",
            prefix, first_end, next_br, next_id, name0, ext);
    printf( "cat %s block-%08X chunk-%03d_%03d > %s%s\n",
            prefix, next_start-1, next_br, next_id, name1, ext);
    if( trace_len > 2) {
      glue_chunks( name0, trace+1, trace_len-1);
      glue_chunks( name1, trace+1, trace_len-1);   // was name0 twice: "--01--" variant was dropped
    }
  } else
    fprintf( stderr, "rem_blocks: %lld\n", rem_blocks);
}

// Returns the next mp3 chunk description
mp3chunk *get_mp3chunk( UINT pos, UINT file_size, UINT min_frames, 
			UCHAR version, UCHAR layer, int b_rate, int s_rate) {
  int sync = -1, chunk_size = -1;
  UCHAR match = 0;
  UINT piece_pos, br;
  
  UINT frame_cnt, frame_len;
  mp3chunk *chunk = NULL;
  while( 1) {
    //    printf( "seek: %08X %5.2f sync: %d match: %d\n", pos, pos/(1024.0*1024), sync, match);
    mp3header mp3h;
    while( !match && pos < file_size) {
      sync = parse_mp3frame( pos, &mp3h);
      while( sync < 0 && pos < file_size) {
	pos++;
	sync = parse_mp3frame( pos, &mp3h);
      }
      if( pos >= file_size)
	break;
      match = (mp3h.bit_rate == b_rate) && (mp3h.sample_rate == s_rate) && (mp3h.version == version) && (mp3h.layer == layer);
      pos = (match) ? pos : pos+1;
    }   
    if( pos >= file_size)
      return NULL;

    piece_pos = pos;
    frame_cnt = -1;
    frame_len = mp3h.length;
    br = mp3h.bit_rate;
    while( sync > 0 && match && (pos < file_size-10)) {
      frame_cnt++;
      frame_len = mp3h.length;
      pos += mp3h.length;
      sync = parse_mp3frame( pos, &mp3h);
      match = (mp3h.bit_rate == b_rate) && (mp3h.sample_rate == s_rate) && (mp3h.version == version) && (mp3h.layer == layer);
    }
    // Is chunk long enough?
    if( frame_cnt >= min_frames)
      break;
    pos = piece_pos+1;
    match = 0;
  } // while(1)

  chunk = (mp3chunk *)calloc( 1, sizeof( mp3chunk));
  // Append id3v1 header (if available)
  if( buffer[pos] == 'T' && buffer[pos+1] == 'A' && buffer[pos+2] == 'G') {
    chunk->tag_flag = TRUE;
    chunk->end_flag = TRUE;
    pos += 128;
  } 
  if( pos > file_size) 
    pos = file_size;
  chunk->bit_rate = br;
  chunk->start_block = piece_pos/bl_size;
  chunk->head = piece_pos%bl_size;
  chunk->end_block = pos/bl_size;
  chunk->tail = pos % bl_size;
  chunk->frame_len = frame_len;
  if( !chunk->end_flag) {
    if( pos/bl_size > (pos-frame_len)/bl_size) {
      chunk->frame_rem = chunk->tail;
      chunk->end_block -= (pos/bl_size - (pos-frame_len)/bl_size)-1;
    } else {
      chunk->frame_rem = frame_len-(bl_size-chunk->tail);
      chunk->end_block++;
      //      printf( "end++ >> pos: %08x bl: %08X   prev: %08X bl: %08X\n", 
      //	      pos, (pos/bl_size)*bl_size, pos-frame_len, ((pos-frame_len)/bl_size)*bl_size);
    }
  } else
    chunk->end_block++;
  return chunk;
} 

// Decode the MPEG audio frame header at buffer[curr_pos] into *mp3h.
//
// Returns 1 for a usable MPEG-1/MPEG-2 Layer I-III header and -1 otherwise
// (no 12-bit sync, reserved layer, free-format/bad bit-rate index, reserved
// sampling-rate index).  MPEG-2.5 never matches because the sync test requires
// the top version bit.  Exits on internal inconsistencies.
//
// NOTE: reads up to buffer[curr_pos+5] (header plus optional CRC) without a
// bounds check; callers must make sure those bytes are inside the mapping.
int parse_mp3frame( UINT curr_pos, mp3header *mp3h) {
  static const UCHAR LAYERS[] = { 0xFF, 3, 2, 1};
  // MPEG-1 rates; MPEG-2 uses half.  Index 3 is reserved and rejected below.
  static const UINT  FREQ[] = { 44100, 48000, 32000, 0xFFFF};
  // Bit-rate tables in kbit/s indexed by the 4-bit bit-rate field.
  static const UINT  V1L1[] = { 0xFFFF, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 0xFFFF};
  static const UINT  V1L2[] = { 0xFFFF, 32, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384, 0xFFFF};
  static const UINT  V1L3[] = { 0xFFFF, 32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 0xFFFF};
  static const UINT  V2L1[] = { 0xFFFF, 32, 48, 56, 64, 80, 96, 112, 128, 144, 160, 176, 192, 224, 256, 0xFFFF};
  static const UINT  V2L23[] = { 0xFFFF, 8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160, 0xFFFF};
  const UCHAR *hdr = buffer + curr_pos;
  UINT field;

  // frame sync: 0xFF followed by a byte whose top nibble is 0xF
  if( hdr[0] != 0xFF || (hdr[1] & 0xF0) != 0xF0)
    return -1;

  memset( mp3h, 0, sizeof( mp3header));

  // mpeg version (ignore 2.5): bit 3 of byte 1, 1 = MPEG-1, 0 = MPEG-2
  mp3h->version = (hdr[1] & 0x08) ? 1 : 2;

  // layer: bits 2-1 of byte 1 (00 is reserved)
  field = (hdr[1] & 0x06) >> 1;
  mp3h->layer = LAYERS[field];
  if( mp3h->layer > 3)
    return -1; 

  // CRC16 follows the header only when the protection bit is clear
  mp3h->crc16 = (hdr[1] & 0x01) ? -1 : ((hdr[4] << 8) | hdr[5]);

  // bit rate: bits 7-4 of byte 2 (0 = free format, 15 = bad; both rejected)
  field = (hdr[2] & 0xF0) >> 4;
  if( field == 0 || field == 0xF)
    return -1;
  switch( mp3h->version) {
  case 1:
    switch( mp3h->layer) {
    case 1:
      mp3h->bit_rate = V1L1[field]*1000;
      break;
    case 2:
      mp3h->bit_rate = V1L2[field]*1000;
      break;
    case 3:
      mp3h->bit_rate = V1L3[field]*1000;
      break;
    default:
      fprintf( stderr, "Internal error: Unknown layer: %d\n", mp3h->layer);
      exit(-1);
    }
    break;
  case 2:
    switch( mp3h->layer) {
    case 1:
      mp3h->bit_rate = V2L1[field]*1000;
      break;
    case 2:
    case 3:
      mp3h->bit_rate = V2L23[field]*1000;
      break;
    default:
      fprintf( stderr, "Internal error: Unknown layer: %d\n", mp3h->layer);
      exit(-1);
    }
    break;
  default:
    fprintf( stderr, "Internal error: Unknown mpeg version: %d\n", mp3h->version);
    exit(-1);
  }

  // sampling rate: bits 3-2 of byte 2 (mask 0x0C; index 3 is reserved)
  field = (hdr[2] & 0x0C) >> 2;
  if( field == 3)
    return -1;
  if( mp3h->version == 1)
    mp3h->sample_rate = FREQ[field];
  else
    mp3h->sample_rate = FREQ[field]/2;

  // pad: bit 1 of byte 2 adds one slot to the frame
  mp3h->pad = (hdr[2] & 0x02) ? 1 : 0;

  // frame length in bytes, header included
  switch( mp3h->layer) {
  case 1:   // 384 samples per frame, 4-byte slots
    mp3h->length = (12 * mp3h->bit_rate / mp3h->sample_rate + mp3h->pad) * 4;
    break;
  case 2:   // 1152 samples per frame in MPEG-1 and MPEG-2
    mp3h->length = 144 * mp3h->bit_rate / mp3h->sample_rate + mp3h->pad;
    break;
  default:  // layer 3: 1152 samples (MPEG-1) or 576 samples (MPEG-2) per frame
    mp3h->length = ((mp3h->version == 1) ? 144 : 72) * mp3h->bit_rate / mp3h->sample_rate
                   + mp3h->pad;
    break;
  }
  return 1;
}

// Return the total on-disk size in bytes of the ID3v2 tag at buffer[id3_pos].
// The size field (bytes 6-9) is a 28-bit "synchsafe" integer, 7 bits per byte,
// that excludes the 10-byte header and, in ID3v2.4, the optional 10-byte
// footer (flag bit 0x10 of byte 5), so both are added back here.
// Assumes the caller verified the header (e.g. via findID3v2()), so that the
// 10 header bytes are inside the mapping.
int getID3size( UINT id3_pos) {
  const unsigned char *hdr = buffer + id3_pos;
  UINT size = ((UINT)(hdr[6] & 0x7F) << 21) |
              ((UINT)(hdr[7] & 0x7F) << 14) |
              ((UINT)(hdr[8] & 0x7F) << 7)  |
               (UINT)(hdr[9] & 0x7F);
  UINT total = size + 10;              // header
  if( hdr[3] >= 4 && (hdr[5] & 0x10))  // ID3v2.4 footer present
    total += 10;
  return (int)total;
}

// Look for an ID3v2 tag header at curr_bl_pos, curr_bl_pos + bl_size,
// curr_bl_pos + 2*bl_size, ... (block-aligned when curr_bl_pos is).  A match
// follows the ID3v2 detection pattern "49 44 33 yy yy xx zz zz zz zz"
// (yy < 0xFF, zz < 0x80), with the undefined low nibble of the flags byte xx
// required to be clear.  Returns the offset of the first match, or -1 if none
// is found.  Only positions with all 10 header bytes inside the image are tested,
// so nothing past buffer[file_size-1] is read.
// NOTE(review): offsets above INT_MAX do not fit the int result.
int findID3v2( UINT curr_bl_pos, UINT file_size) {
  UINT pos = curr_bl_pos;

  while( pos < file_size && file_size - pos >= 10) {
    const unsigned char *h = buffer + pos;
    if( h[0] == 0x49 && h[1] == 0x44 && h[2] == 0x33 &&   // "ID3"
        h[3] != 0xFF && h[4] != 0xFF &&                    // version bytes
        (h[5] & 0x0F) == 0 &&                              // flags: low nibble unused
        h[6] < 0x80 && h[7] < 0x80 && h[8] < 0x80 && h[9] < 0x80)  // synchsafe size
      return (int)pos;
    // Stop if the next candidate would be at/after EOF (also avoids UINT wrap
    // and an endless loop when bl_size is 0).
    if( bl_size == 0 || file_size - pos <= bl_size)
      break;
    pos += bl_size;
  }
  return -1;
}

