#!/usr/bin/env python3
# 04 Fabian Ying 2016-11-13 Added discrete example
# 03 Keith Briggs 2016-10-25
# 02: Added ccdf option to plot function
#     Switched order for spans, the variable for CI (now: [Lower bound, Best estimate, Upper bound])
#     Added a new plot function with shading
# 01: Added best estimate point for each quantile (3rd column of spans). Added plot cdf tool
# Keith Briggs 2016-08-08 was quantile_estimation_02.py
# see slide 25 in ~/KTP_TVWS_Project/tvws_ktn_london_slides-0.2/Spectrum_sharing_deliverable_2013-10-11.pdf

import numpy as np
import matplotlib.pyplot as plt
from sys import exit,stderr,stdout
from math import sqrt
from random import random
from scipy.stats import binom

class Quantile_estimator:
  '''
  Estimate quantiles of a distribution by repeated sampling, together with
  distribution-free (binomial order-statistic) confidence intervals for each
  quantile, and plot the resulting cdf or ccdf.
  '''

  def __init__(self,f,**f_params):
    '''
    Instantiate a Quantile_estimator: f_params are extra keyword arguments
    passed to the sampling function f, which must also take a size
    keyword argument and return that many samples.
    '''
    self.f=f
    self.f_params=f_params

  def _predict_indices_for_confidence_interval(self,sample_size,quantile,conf=0.95):
    '''
    Internal use only. Return the (lower, upper) 0-based indices into the
    sorted sample whose values bracket `quantile` at confidence about `conf`.
    The count of samples below the true quantile is
    Binomial(sample_size, quantile), so its central `conf` interval gives
    the indices; they are clipped to [0, sample_size-1].
    '''
    # cf. ~/KTP_TVWS_software/quantsim/src/binom.cc
    r,s=binom(sample_size,quantile).interval(conf)
    return max(0,int(r)),min(sample_size-1,int(s))

  def estimate(self,quantiles=(0.5,),conf=0.95,eps_abs=1e-3,eps_rel=1e-2,batch_size=1000,max_n_batches=1000):
    '''
    Take samples from callable object f until all quantiles are
    sufficiently accurately determined, and return the confidence intervals
    as rows of an np.array, and also the number of samples needed.
    The default quantiles=(0.5,) finds the median.

    quantiles: a single quantile (any real scalar, including numpy scalars)
      or a sequence of quantiles, each strictly between 0 and 1.
    conf: confidence level of each interval.
    eps_abs, eps_rel: a quantile has converged when the width of its
      confidence interval is below eps_abs+eps_rel*(sample max - sample min).
    batch_size: number of samples drawn initially and in each extra batch.
    max_n_batches: number of convergence checks before giving up.

    Returns (spans,sample_size), where spans has one row
    [lower bound, best estimate, upper bound] per quantile. Returns None
    (after printing a message) if a quantile is out of range, and calls
    exit(1) if the estimates do not converge within max_n_batches.
    '''
    # A bare scalar is one quantile. np.ndim also catches numpy scalars
    # (np.float64, np.float32) and ints, which a type(...) is float test
    # misses.
    if np.ndim(quantiles)==0:
      quantiles=[quantiles,]
    for quantile in quantiles:
      if not 0.0<quantile<1.0:
        print('Please specify quantiles between 0 and 1')
        return None
    sample_size=batch_size
    nq=len(quantiles)
    x=np.asarray(self.f(size=sample_size,**self.f_params))
    for _ in range(max_n_batches):
      # FIXME only need to check CIs which have not converged yet...
      x.sort()
      span=x[-1]-x[0]  # sample range: sets the scale of the relative tolerance
      ok=True
      spans=np.empty(shape=(nq,3))
      for j,qj in enumerate(quantiles):
        r,s=self._predict_indices_for_confidence_interval(sample_size,qj,conf)
        ok=x[s]-x[r]<eps_abs+eps_rel*span
        if not ok: break  # need more samples; no point checking the rest
        spans[j]=(x[r],x[int(qj*sample_size)],x[s]) # [LB, best_estimate, UB]
      if ok:
        return spans,sample_size
      sample_size+=batch_size
      x=np.append(x,self.f(size=batch_size,**self.f_params))
    print('estimate failed to converge, quitting!',file=stderr)
    print('estimate failed to converge, quitting!')
    exit(1)

  def plot_cdf(self,spans,quantiles,log_yscale=False,xlabel='',ylabel='',title='',img_fn_base='',ccdf=False,fontsize=20):
    '''
    Plot the cdf (or the ccdf if ccdf=True) from a quantile estimation,
    with horizontal errorbars from the confidence intervals in spans.

    spans: array with rows [lower bound, best estimate, upper bound], as
      returned by estimate().
    quantiles: the quantiles corresponding to the rows of spans.
    img_fn_base: if non-empty, save the figure as img_fn_base+'.pdf' and
      img_fn_base+'.png'; otherwise show it with plt.show().
    Returns (fig,ax).
    '''
    quantiles=np.asarray(quantiles)
    x=spans[:,1]
    # errorbar expects distances below and above each best estimate
    x_err=(x-spans[:,0],spans[:,2]-x)
    y=1.0-quantiles if ccdf else quantiles
    # plotting
    fig,ax=plt.subplots(figsize=(8,6))
    ax.grid()
    if xlabel: ax.set_xlabel(xlabel)
    if ylabel: ax.set_ylabel(ylabel)
    if title:  ax.set_title(title)
    ax.errorbar(x,y,xerr=x_err,fmt='o:',ms=5,elinewidth=2,capthick=2)
    if log_yscale:
      #FIXME: Log scale not visualized correctly
      ax.set_yscale('log')  # pass (for debug)
      ax.set_ylim(top=1)
    else:
      ax.set_ylim([0,1])
    for item in ([ax.title,ax.xaxis.label,ax.yaxis.label]+ax.get_xticklabels()+ax.get_yticklabels()):
      item.set_fontsize(fontsize)
    if img_fn_base:
      fig.savefig(img_fn_base+'.pdf')
      fig.savefig(img_fn_base+'.png')
    else: plt.show()
    return fig,ax

  def plot_cdf_alt(self,spans,quantiles,exponential_quantiles=False,savefig=False,img_filename=None,ccdf=False):
    '''
    Plot the cdf (or the ccdf if ccdf=True) from a quantile estimation,
    showing the confidence intervals as a shaded region instead of
    errorbars.

    spans: array with rows [lower bound, best estimate, upper bound].
    quantiles: the quantiles corresponding to the rows of spans.
    exponential_quantiles: not implemented yet; only prints a message.
    savefig: if True, save the figure to img_filename; otherwise show it.
    Returns (fig,ax).
    '''
    quantiles=np.asarray(quantiles)
    x=spans[:, 1]
    lower_bound=spans[:, 0]
    upper_bound=spans[:, 2]
    y=1.0 - quantiles if ccdf else quantiles
    if exponential_quantiles:
      #FIXME: Add exponential_quantiles capabilities
      print('No exponential quantiles implemented yet!!!')
    # Plotting
    fig, ax=plt.subplots(figsize=(12, 8))
    ax.plot(x,y,  'b')
    ax.fill_betweenx(y, lower_bound, upper_bound, facecolor='blue', alpha=0.2, linewidth=0.0)
    ax.set_ylim([0, 1])
    if savefig: fig.savefig(img_filename)
    else: plt.show()
    return fig, ax

if __name__=='__main__':

  def test_02(save_path='quantile_plots/'):
    '''
    Demonstrate Quantile_estimator on several continuous and discrete
    distributions. For each one, print the confidence intervals and the
    sample size needed, then save a cdf (or ccdf) plot under save_path as
    PDF and PNG files.
    '''
    import os
    np.set_printoptions(precision=4)
    conf=0.95
    scale=2.0
    a=10; p=0.3; lam1=2; lam2=10
    eps_rel=1e-2
    ccdf=True
    kind=('cdf','ccdf')[ccdf]
    ylabel='Prob[$X>x$]' if ccdf else r'Prob[$X\leq x$]'
    fmt='Estimating %.1f%% confidence intervals for quantiles %s of %s:'
    # uniformly spaced quantiles...
    quantiles=np.arange(0.05,0.99,0.1)
    # savefig fails if the output directory does not exist yet
    os.makedirs(save_path,exist_ok=True)

    def _run(name,qe,file_tag):
      # Estimate, report and plot the quantiles of one distribution.
      # NOTE: for discrete distributions an absolute tolerance
      # (e.g. eps_abs=2.0) may be more appropriate than eps_rel.
      print(fmt%(100.0*conf,str(quantiles),name,))
      spans,sample_size=qe.estimate(quantiles,conf,eps_rel=eps_rel)
      print(spans,'\nsample size needed was %d \n'%sample_size)
      img_fn_base=os.path.join(save_path,'qe_%s_%s'%(file_tag,kind))
      # Keyword arguments: passed positionally, the file name landed in
      # `title` and ccdf (a bool) in `img_fn_base`, which made saving crash.
      fig,ax=qe.plot_cdf(spans,quantiles,log_yscale=False,xlabel='$x$',
                         ylabel=ylabel,img_fn_base=img_fn_base,ccdf=ccdf)
      plt.close(fig)  # don't accumulate open figures across runs

    _run('Uniform(0,1)',Quantile_estimator(np.random.random),'uniform')

    # Continuous distributions
    continuous_tests=(
      ('Gaussian',                 np.random.normal),
      ('Exp_scale=%.0f'%scale,     np.random.exponential),
      ('Logistic_scale=%.0f'%scale,np.random.logistic),
      ('Rayleigh_scale=%.0f'%scale,np.random.rayleigh),
      ('Gumbel_scale=%.0f'%scale,  np.random.gumbel),
    )
    for name,fun in continuous_tests:
      _run(name,Quantile_estimator(fun,scale=scale),name)

    # Discrete distributions
    _run('Uniform (discrete)',Quantile_estimator(np.random.choice,a=a),'uniform_discrete')
    _run('Geometric',Quantile_estimator(np.random.geometric,p=p),'geometric_p=%g'%p)
    for lam in (lam1,lam2):
      _run('Poisson lam=%d'%lam,Quantile_estimator(np.random.poisson,lam=lam),'poisson_lam=%d'%lam)
    # non-uniformly-spaced quantiles (not yet exercised)...
    # quantiles_exponential=1.0-np.power(10.0,np.arange(-30.0,-5.0,5.0)/10.0)
    # quantiles_exponential=np.power(10.0,np.arange(-30.0,-5.0,2.0)/10.0)
  test_02()
