// KMB 2006 Dec 11,12
// ACM Trans Math Soft 13, 58 (1987)
// pick n out of N objects uniformly at random
// gcc try_vitter_fast_sampling.c -lm && ./a.out
// gcc -Wall -O3 try_vitter_fast_sampling.c -lm && ./a.out | histogram | bars | p
// mcopy -o -v try_vitter_fast_sampling.c a:

#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#define NDEBUG
#include <assert.h>

/* Uniform random deviate strictly inside the open interval (0,1).
 * rand() yields integers in [0,RAND_MAX]; shifting by one half and dividing
 * by RAND_MAX+1 keeps both endpoints out.  An exact 1.0 would make
 * floor(N*UNI01) equal N, i.e. select a record past the end of the
 * population. */
#define UNI01 ((rand()+0.5)/((double)RAND_MAX+1.0))

/*
 * Method A, page 65.  NB: method D is faster.
 *
 * Draws n of N records sequentially and prints one selected index per line
 * on stdout, in increasing order.  Indices are 0-based within the N records
 * and `offset` is added to every printed value, so the output lies in
 * [offset, offset+N-1].  A trace line with the arguments goes to stderr.
 *
 * Requires n <= N.  For n == 0 or n > N nothing is printed on stdout.
 */
void Method_A_long(unsigned long n, unsigned long N, unsigned long offset) {
  double v,quot,top,Nr=N;
  unsigned long s,k=-1; /* wraps to ULONG_MAX so the first k+=1+s lands on s */
  fprintf(stderr,"Method_A: n=%lu N=%lu offset=%lu\n",n,N,offset);
  if (n==0 || n>N) return; /* nothing to draw, or N-n would wrap around */
  top=N-n;
  while (n>1) {
    v=UNI01;
    s=0;
    quot=top/Nr;
    /* walk forward one record at a time until the skip probability drops to v */
    while (quot>v) {
      s++;
      top--;
      Nr--;
      quot*=top/Nr;
    }
    k+=1+s; // skip s records, select the next one
    printf("%lu\n",k+offset);
    Nr--;
    n--;
  }
  /* special case n==1: the last sample is uniform over the Nr records left */
  k+=1+(unsigned long)floor(Nr*UNI01);
  printf("%lu\n",k+offset); /* offset belongs inside the printed value */
}

/*
 * Method D, page 66.
 *
 * Draws n of N records sequentially and prints the selected positions on
 * stdout, one per line, in increasing order.  Positions are 1-based
 * (k starts at 0 and every selection does k+=1+s), so they lie in [1,N].
 *
 * The skip-generating loop runs only while 13*n < N for the records still
 * remaining; once that fails and more than one sample is still needed, the
 * rest is handed to Method_A_long.
 *
 * Requires n <= N.  For n == 0 or n > N nothing is printed.
 */
void Method_D_long(unsigned long n, unsigned long N) {
  double nr,Nr,ninv,nmin1inv,u,x,vprime,y1,y2,top,bot,negsr,qu1r;
  unsigned long k=0; /* position of the latest selection; 0 = none yet */
  long s,t,limit,qu1,threshold;
  const long negalphainv=-13;
  if (n==0 || n>N) return; /* nothing to draw, or N-n+1 would wrap around */
  nr=n;
  Nr=N;
  ninv=1.0/nr;
  qu1r=Nr-nr+1.0;
  qu1=N-n+1;
  threshold=-negalphainv*(long)n;
  vprime=pow(UNI01,ninv);
  while (n>1 && (unsigned long)threshold<N) {
    nmin1inv=1.0/(nr-1.0);
    while (1) {
      /* step D2: propose a skip s=floor(x), redrawing until s < qu1 */
      while (1) {
        x=Nr*(1.0-vprime);
        s=(long)floor(x);
        assert(s>=0); /* a skip of zero is valid */
        if (s<qu1) break;
        vprime=pow(UNI01,ninv);
      }
      u=UNI01;
      negsr=-s;
      /* step D3: cheap acceptance test; on success vprime is reused next round */
      y1=pow(u*Nr/qu1r,nmin1inv);
      vprime=y1*(1.0-x/Nr)*(qu1r/(negsr+qu1r));
      if (vprime<=1.0) break;
      /* step D4: exact acceptance test, y2 is built up as a running product */
      y2=1.0;
      top=Nr-1.0;
      if (n-1>(unsigned long)s) {
        bot=Nr-nr;
        limit=(long)(N-s);
      } else {
        bot=Nr+negsr-1.0;
        limit=qu1;
      }
      for (t=(long)(N-1); t>=limit; t--) {
        y2*=top/bot;
        top--;
        bot--;
      }
      if (Nr/(Nr-x)>=y1*pow(y2,nmin1inv)) {
        vprime=pow(UNI01,nmin1inv); /* accepted: fresh vprime for n-1 */
        break;
      }
      vprime=pow(UNI01,ninv); /* rejected: try again with the same n */
    }
    k+=1+s; // skip s records, select the next one
    printf("%lu\n",k);
    N-=s+1;
    Nr+=negsr-1.0;
    n--;
    nr--;
    ninv=nmin1inv;
    qu1-=s;
    qu1r+=negsr;
    threshold+=negalphainv;
  }
  if (n>1) {
    /* Outside the threshold: finish with Method A.  N already counts only
     * the records that remain, so it is passed unchanged.  Method A emits
     * offset plus a 0-based index, and k records are already consumed, so
     * offset k+1 continues the 1-based numbering without repeating k. */
    Method_A_long(n,N,k+1);
  } else {
    /* special case n==1: vprime is already the right deviate */
    s=(long)floor(N*vprime);
    if (s>=(long)N) s=(long)N-1; /* vprime==1.0 would step past the last record */
    assert(s>=0);
    k+=1+s; // skip
    printf("%lu\n",k);
  }
}

/* Output callback: writes the given value to stdout on its own line,
 * formatted with no fractional digits. */
void f(double x) {
  fprintf(stdout,"%.0f\n",x);
}

/*
 * Method A, page 65.  For n/N large.
 *
 * Draws nr of Nr records sequentially and passes each selected index to the
 * callback f, in increasing order.  Indices are 0-based within the Nr
 * records and `offset` is added to each one before the call.
 * Everything is done in double to get 2^53 range.
 *
 * Requires 1 <= nr <= Nr; otherwise f is never called.
 */
void Method_A_double(double nr, double Nr, double offset, void f(double)) {
  double v,quot,top,s,k=-1.0; /* k=-1 so the first k+=1+s lands on s */
  if (nr<1.0 || nr>Nr) return;
  top=Nr-nr;
  while (nr>1) {
    v=UNI01;
    s=0.0;
    quot=top/Nr;
    /* walk forward one record at a time until the skip probability drops to v */
    while (quot>v) {
      s++;
      top--;
      Nr--;
      quot*=top/Nr;
    }
    k+=1+s; // skip s records, select the next one
    f(k+offset);
    Nr--;
    nr--;
  }
  /* special case nr==1: skip floor(Nr*U) records and take the following one.
   * The 1+ moves past the previous selection; without it the last sample
   * could repeat it. */
  k+=1+floor(Nr*UNI01);
  f(k+offset);
}

/*
 * Method D, page 66.  For n/N small.
 *
 * Draws nr of Nr records sequentially and passes each selected position to
 * the callback f, in increasing order.  Positions are 1-based (k starts at 0
 * and every selection does k+=1+s), so they lie in [1,Nr].
 * Everything is done in double to get 2^53 range.
 *
 * The skip-generating loop runs only while 13*nr < Nr for the records still
 * remaining; once that fails and more than one sample is still needed, the
 * rest is handed to Method_A_double.
 *
 * Requires 1 <= nr <= Nr; otherwise f is never called.
 */
void Method_D_double(double nr, double Nr, void f(double)) {
  double ninv,nmin1inv,u,x,vprime,y1,y2,top,bot,negsr,qu1r;
  double k=0.0; /* position of the latest selection; 0 = none yet */
  double s,t,limit,threshold;
  const double negalphainv=-13.0;
  if (nr<1.0 || nr>Nr) return;
  ninv=1.0/nr;
  qu1r=Nr-nr+1.0;
  threshold=-negalphainv*nr;
  vprime=pow(UNI01,ninv);
  while (nr>1 && threshold<Nr) {
    nmin1inv=1.0/(nr-1.0);
    while (1) {
      /* step D2: propose a skip s=floor(x), redrawing until s < qu1r */
      while (1) {
        x=Nr*(1.0-vprime);
        s=floor(x);
        if (s<qu1r) break;
        vprime=pow(UNI01,ninv);
      }
      u=UNI01;
      negsr=-s;
      /* step D3: cheap acceptance test; on success vprime is reused next round */
      y1=pow(u*Nr/qu1r,nmin1inv);
      vprime=y1*(1.0-x/Nr)*(qu1r/(negsr+qu1r));
      if (vprime<=1.0) break;
      /* step D4: exact acceptance test, y2 is built up as a running product */
      y2=1.0;
      top=Nr-1.0;
      if (nr-1>s) {
        bot=Nr-nr;
        limit=Nr-s;
      } else {
        bot=Nr+negsr-1.0;
        limit=qu1r;
      }
      for (t=Nr-1; t>=limit; t--) {
        y2*=top/bot;
        top--;
        bot--;
      }
      if (Nr/(Nr-x)>=y1*pow(y2,nmin1inv)) {
        vprime=pow(UNI01,nmin1inv); /* accepted: fresh vprime for nr-1 */
        break;
      }
      vprime=pow(UNI01,ninv); /* rejected: try again with the same nr */
    }
    k+=1+s; // skip s records, select the next one
    f(k);
    Nr+=negsr-1.0;
    nr--;
    ninv=nmin1inv;
    qu1r+=negsr;
    threshold+=negalphainv;
  }
  if (nr>1) {
    /* Outside the threshold: finish with Method A.  Nr already counts only
     * the records that remain, so it is passed unchanged.  Method A emits
     * offset plus a 0-based index, and k records are already consumed, so
     * offset k+1 continues the 1-based numbering without repeating k. */
    Method_A_double(nr,Nr,k+1.0,f);
  } else {
    /* special case nr==1: vprime is already the right deviate */
    s=floor(Nr*vprime);
    if (s>=Nr) s=Nr-1.0; /* vprime==1.0 would step past the last record */
    k+=1+s; // skip
    f(k);
  }
}

int main() {
  /* Demo run: draw 100000 of 10^12 records with Method D, printing via f. */
  const double sample_size=100000;
  const double population=(double)1000000000000UL;
  Method_D_double(sample_size,population,f);
  return 0;
}
