package com.fairsearch.fair;

import com.fairsearch.fair.lib.FairTopK;
import com.fairsearch.fair.lib.MTableGenerator;
import com.fairsearch.fair.lib.RecursiveNumericFailProbabilityCalculator;
import com.fairsearch.fair.utils.FairScoreDoc;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import jdk.nashorn.internal.runtime.regexp.joni.exception.ValueException;
import org.apache.lucene.search.TopDocs;

/* loaded from: input_file:com/fairsearch/fair/Fair.class */
/**
 * Entry point for fair top-k ranking: builds m-tables, computes failure probabilities, checks
 * rankings against an m-table and re-ranks {@link TopDocs} via {@link FairTopK}.
 *
 * <p>All parameters are validated once in the constructor; instances are not modified afterwards.
 *
 * <p>NOTE(review): validation failures are reported with
 * {@code jdk.nashorn.internal.runtime.regexp.joni.exception.ValueException}, an internal JDK
 * (Nashorn) class. It is unchecked (no {@code throws} clauses are needed), but Nashorn was removed
 * in JDK 15, so this file will not compile there. The type is kept only so existing callers that
 * catch it keep working; consider migrating to {@link IllegalArgumentException}.
 */
public class Fair {
    private static final Logger LOGGER = Logger.getLogger(Fair.class.getName());

    /** Logged when a parameter is accepted but lies outside the range the library was tested on. */
    private static final String UNTESTED_RANGE_WARNING =
            "Library has not been tested with values outside this range";

    /** Number of elements in the ranking. */
    private final int k;
    /** Proportion of protected candidates. */
    private final double p;
    /** Significance level. */
    private final double alpha;
    private final FairTopK fairTopK;

    /**
     * Creates a configured instance after validating the parameters.
     *
     * @param k     number of elements; values below 2 are rejected, values outside [10, 400] only
     *              produce a warning
     * @param p     proportion of protected candidates; values outside [0, 1] (or NaN) are rejected,
     *              values outside [0.02, 0.98] only produce a warning
     * @param alpha significance level; values outside [0.001, 0.5] (or NaN) are rejected, values
     *              outside [0.01, 0.15] only produce a warning
     * @throws ValueException if any parameter is outside its accepted range
     */
    public Fair(int k, double p, double alpha) {
        validateBasicParameters(k, p, alpha);
        this.k = k;
        this.p = p;
        this.alpha = alpha;
        this.fairTopK = new FairTopK();
    }

    /**
     * Builds the m-table for the configured {@code alpha} without adjustment.
     *
     * @return the m-table with the generator's index-0 entry removed
     */
    public int[] createUnadjustedMTable() {
        return createMTable(this.alpha, false);
    }

    /**
     * Builds the m-table for the configured {@code alpha} with adjustment enabled.
     *
     * @return the m-table with the generator's index-0 entry removed
     */
    public int[] createAdjustedMTable() {
        return createMTable(this.alpha, true);
    }

    /**
     * Generates an m-table and strips its leading entry.
     *
     * @param alpha    significance level to use (re-validated, so out-of-tested-range values warn)
     * @param adjusted passed straight through to {@link MTableGenerator}
     * @return a copy of the generated table without index 0
     */
    private int[] createMTable(double alpha, boolean adjusted) {
        validateAlpha(alpha);
        int[] generated = new MTableGenerator(this.k, this.p, alpha, adjusted).getMTable();
        // Index 0 is not part of the public table; computeFailureProbability re-inserts a
        // leading slot before handing a table back to the calculator.
        return Arrays.copyOfRange(generated, 1, generated.length);
    }

    /**
     * Computes an adjusted significance level for the configured {@code k}, {@code p} and
     * {@code alpha} using {@link RecursiveNumericFailProbabilityCalculator}.
     *
     * @return the alpha reported by the calculator's adjusted result
     */
    public double adjustAlpha() {
        return new RecursiveNumericFailProbabilityCalculator(this.k, this.p, this.alpha)
                .adjustAlpha()
                .getAlpha();
    }

    /**
     * Computes the failure probability of the given m-table.
     *
     * @param mTable m-table in the public layout (one entry per position, length {@code k}), e.g.
     *               as returned by {@link #createAdjustedMTable()}
     * @return the failure probability reported by the calculator
     * @throws ValueException if {@code mTable.length != k}
     */
    public double computeFailureProbability(int[] mTable) {
        if (mTable.length != this.k) {
            throw new ValueException("Number of elements k and (int[]) mtable length must be equal!");
        }
        // The calculator is given the table with a leading slot at index 0 (the layout
        // createMTable strips off), so shift the caller's entries right by one.
        int[] padded = new int[mTable.length + 1];
        System.arraycopy(mTable, 0, padded, 1, mTable.length);
        return new RecursiveNumericFailProbabilityCalculator(this.k, this.p, this.alpha)
                .calculateFailProbability(padded);
    }

    /**
     * Checks a ranking against an m-table.
     *
     * <p>Walks the ranking from the top, counting protected documents; the ranking fails as soon
     * as the count within the first {@code i + 1} documents is below {@code mTable[i]}.
     *
     * @param topDocs ranking to check; every entry must be a {@link FairScoreDoc}
     * @param mTable  m-table with one entry per document
     * @return {@code true} if no prefix falls below its m-table entry
     * @throws ValueException if the number of documents differs from the m-table length
     */
    public static boolean checkRankingMTable(TopDocs topDocs, int[] mTable) {
        if (topDocs.scoreDocs.length != mTable.length) {
            throw new ValueException("Number of documents in (TopDocs) docs and (int[]) mtable length are not the same!");
        }
        int protectedSoFar = 0;
        for (int position = 0; position < mTable.length; position++) {
            if (((FairScoreDoc) topDocs.scoreDocs[position]).isProtected) {
                protectedSoFar++;
            }
            if (protectedSoFar < mTable[position]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks the ranking against the adjusted m-table for this instance's parameters.
     *
     * @param topDocs ranking to check; every entry must be a {@link FairScoreDoc}
     * @return {@code true} if the ranking satisfies the adjusted m-table
     * @throws ValueException if the ranking length differs from the adjusted m-table length
     */
    public boolean isFair(TopDocs topDocs) {
        return checkRankingMTable(topDocs, createAdjustedMTable());
    }

    /**
     * Re-ranks documents with {@link FairTopK}.
     *
     * <p>Documents are split by {@link FairScoreDoc#isProtected}, preserving their relative order
     * within each group, and passed to {@code fairTopK} as (non-protected, protected).
     *
     * @param topDocs input ranking; every entry must be a {@link FairScoreDoc}
     * @return the ranking produced by {@link FairTopK}
     */
    public TopDocs reRank(TopDocs topDocs) {
        ArrayList<FairScoreDoc> protectedDocs = new ArrayList<>();
        ArrayList<FairScoreDoc> nonProtectedDocs = new ArrayList<>();
        for (int i = 0; i < topDocs.scoreDocs.length; i++) {
            FairScoreDoc doc = (FairScoreDoc) topDocs.scoreDocs[i];
            if (doc.isProtected) {
                protectedDocs.add(doc);
            } else {
                nonProtectedDocs.add(doc);
            }
        }
        return this.fairTopK.fairTopK(nonProtectedDocs, protectedDocs, this.k, this.p, this.alpha);
    }

    /**
     * Validates {@code k}, {@code p} and {@code alpha}, warning for untested-but-accepted values.
     *
     * @throws ValueException if a value is outside its accepted range (NaN included)
     */
    private static void validateBasicParameters(int k, double p, double alpha) {
        if (k < 10 || k > 400) {
            if (k < 2) {
                throw new ValueException("Total number of elements `k` should be between 10 and 400");
            }
            LOGGER.warning(UNTESTED_RANGE_WARNING);
        }
        if (outsideRange(p, 0.02d, 0.98d)) {
            if (outsideRange(p, 0.0d, 1.0d)) {
                throw new ValueException("The proportion of protected candidates `p` in the top-k ranking should be between 0.02 and 0.98");
            }
            LOGGER.warning(UNTESTED_RANGE_WARNING);
        }
        validateAlpha(alpha);
    }

    /**
     * Validates {@code alpha}: rejects values outside [0.001, 0.5] (or NaN) and warns for values
     * outside [0.01, 0.15].
     *
     * @throws ValueException if {@code alpha} is outside its accepted range
     */
    private static void validateAlpha(double alpha) {
        if (outsideRange(alpha, 0.01d, 0.15d)) {
            if (outsideRange(alpha, 0.001d, 0.5d)) {
                throw new ValueException("The significance level `alpha` must be between 0.01 and 0.15");
            }
            LOGGER.warning(UNTESTED_RANGE_WARNING);
        }
    }

    /**
     * Returns {@code true} if {@code value} is not within the closed interval [low, high].
     * Written as a negated inclusion test so that NaN (for which every comparison is false)
     * counts as outside instead of silently passing validation.
     */
    private static boolean outsideRange(double value, double low, double high) {
        return !(value >= low && value <= high);
    }
}
