package com.ibm.gpu;

import com.ibm.cuda.CudaDevice;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Locale;

/* loaded from: input_file:com/ibm/gpu/PtxKernelGenerator.class */
/**
 * Generates a PTX module of compare-and-swap kernels that place elements in ascending order.
 *
 * <p>The element type is selected by a JVM type-descriptor character: {@code 'D'} ({@code .f64}),
 * {@code 'F'} ({@code .f32}), {@code 'I'} ({@code .s32}) or {@code 'J'} ({@code .s64}). The module
 * contains the entry points {@code other1}..{@code other4}, {@code first4} and {@code phase9}.
 *
 * <p>Instances are single-use; {@link #writeTo} creates one per generated module.
 */
final class PtxKernelGenerator {

    /** Largest {@link #stepCount} for which a kernel is generated. */
    private static final int MAX_STEP_COUNT = 4;

    /** log2 of the number of elements {@code phase9} stages in shared memory per thread block. */
    private static final int PHASE_LOG2_ELEMENTS = 9;

    /**
     * PTX literal loaded into slots whose index is {@code >= length}. It is integer MAX_VALUE,
     * or a NaN for floating point, so {@link #compare} never orders it before a real element.
     */
    private final String maxValue;

    /** Number of compare levels in the kernel being emitted; it handles 1 << stepCount slots. */
    private int stepCount;

    /** True for the integer types ({@code .s32}/{@code .s64}), false for floating point. */
    private final boolean typeIsScalar;

    /** PTX type suffix of one element, e.g. {@code ".f32"}. */
    private final String typeName;

    /** Size of one element in bytes. */
    private final int typeSize;

    /** Destination of the generated text. It is flushed but never closed. */
    private final OutputStreamWriter writer;

    /**
     * Writes the complete PTX module for one element type to {@code outputStream}.
     *
     * @param i selects the {@code .target}: values below 3 emit {@code sm_20}, others
     *     {@code sm_30} (presumably a compute-capability major version; confirm with callers)
     * @param c element type: {@code 'D'}, {@code 'F'}, {@code 'I'} or {@code 'J'}
     * @param outputStream destination; flushed on success but not closed
     * @throws IllegalArgumentException if {@code c} is not a supported type character
     * @throws IOException if writing fails
     */
    public static void writeTo(int i, char c, OutputStream outputStream) throws IOException {
        new PtxKernelGenerator(outputStream, c).generate(i);
    }

    private PtxKernelGenerator(OutputStream output, char type) {
        switch (type) {
            case 'D':
                this.typeIsScalar = false;
                this.typeName = ".f64";
                this.typeSize = 8;
                this.maxValue = "0dFFF8000000000000"; // a NaN bit pattern
                break;
            case 'F':
                this.typeIsScalar = false;
                this.typeName = ".f32";
                this.typeSize = 4;
                this.maxValue = "0f7FFFFFFF"; // a NaN bit pattern
                break;
            case 'I':
                this.typeIsScalar = true;
                this.typeName = ".s32";
                this.typeSize = 4;
                this.maxValue = "0x" + Integer.toHexString(Integer.MAX_VALUE);
                break;
            case 'J':
                this.typeIsScalar = true;
                this.typeName = ".s64";
                this.typeSize = 8;
                this.maxValue = "0x" + Long.toHexString(Long.MAX_VALUE);
                break;
            default:
                throw new IllegalArgumentException(String.valueOf(type));
        }
        this.stepCount = 0;
        this.writer = new OutputStreamWriter(output, StandardCharsets.US_ASCII);
    }

    /** Writes {@code line} verbatim followed by a newline; '%' characters are not interpreted. */
    private void append(String line) throws IOException {
        writer.append(line).append('\n');
    }

    /**
     * Emits code that sets predicate {@code p0} when {@code vl<first>} and {@code vl<second>}
     * must be swapped to be in ascending order.
     *
     * <p>Integers use a signed greater-than. For floating point, NaN is treated as larger than
     * every number. When the values compare equal, the raw bits are compared as signed integers,
     * which puts {@code -0.0} before {@code +0.0}. Clobbers {@code p1} and, for floating point,
     * {@code vs0}/{@code vs1}.
     */
    private void compare(int first, int second) throws IOException {
        if (typeIsScalar) {
            format("setp.gt%s p0,vl%d,vl%d;", typeName, first, second);
            return;
        }
        int bits = typeSize * 8;
        format("testp.number%s p1,vl%d;", typeName, second);
        format("setp.gtu.and%s p0,vl%d,vl%d,p1;", typeName, first, second);
        format("@!p0 mov.b%d vs0,vl%d;", bits, first);
        format("@!p0 mov.b%d vs1,vl%d;", bits, second);
        format("@!p0 setp.eq%s p1,vl%d,vl%d;", typeName, first, second);
        format("@!p0 setp.gt.and.s%d p0,vs0,vs1,p1;", bits);
    }

    /**
     * Emits an ordering compare of two slot registers followed by a predicated swap. Both slots
     * are recorded in the {@code moved} bit mask so that {@link #scatterData} writes them back.
     */
    private void compareAndSwap(int first, int second) throws IOException {
        compare(first, second);
        format("@p0 mov%s tmp,vl%d;", typeName, first);
        format("@p0 mov%s vl%d,vl%d;", typeName, first, second);
        format("@p0 mov%s vl%d,tmp;", typeName, second);
        format("@p0 or.b32 moved,moved,%s;", constant((1 << first) | (1 << second)));
    }

    /**
     * Emits code computing the global index {@code ix<k>} of each of this thread's
     * {@code 1 << stepCount} slots. The linear thread id is
     * {@code (nctaid.x * ctaid.y + ctaid.x) * ntid.x + tid.x}. Consecutive slots are
     * {@code step = stride >> (stepCount - 1)} apart.
     *
     * @param first if true, the upper half of the slots is derived from the last lower-half slot
     *     by {@code xor (2 * stride - 1)}. That mirrors it within an aligned window of
     *     {@code 2 * stride} elements, assuming stride is a power of two (not checked here).
     */
    private void computeIndices(boolean first) throws IOException {
        assert 1 <= stepCount && stepCount <= 5 : stepCount;
        int count = 1 << stepCount;
        // With a single level the slot distance is the stride itself; no extra register needed.
        String distance = "stride";
        if (stepCount != 1) {
            distance = "step";
            format("shr.u32 step,stride,%d;", stepCount - 1);
        }
        format("sub.s32 mask,%s,1;", distance);
        append("mov.u32 rt0,%nctaid.x;");
        append("mov.u32 rt1,%ctaid.y;");
        append("mov.u32 rt2,%ctaid.x;");
        append("mad.lo.u32 threadId,rt0,rt1,rt2;");
        append("mov.u32 rt0,%ntid.x;");
        append("mov.u32 rt1,%tid.x;");
        append("mad.lo.u32 threadId,threadId,rt0,rt1;");
        append("not.b32 rt0,mask;");
        append("and.b32 rt0,rt0,threadId;");
        append("and.b32 rt1,threadId,mask;");
        format("mad.lo.u32 ix0,rt0,%d,rt1;", count);
        if (first) {
            int half = count >> 1;
            for (int k = 1; k < half; k++) {
                format("add.u32 ix%d,ix%d,%s;", k, k - 1, distance);
            }
            append("mad.lo.u32 rt0,stride,2,-1;");
            format("xor.b32 ix%d,ix%d,rt0;", half, half - 1);
            for (int k = half + 1; k < count; k++) {
                format("add.u32 ix%d,ix%d,%s;", k, k - 1, distance);
            }
        } else {
            for (int k = 1; k < count; k++) {
                format("add.u32 ix%d,ix%d,%s;", k, k - 1, distance);
            }
        }
    }

    /**
     * Formats an immediate operand. Hex formatting gets wider as {@link #stepCount} (and so the
     * width of the slot bit masks) grows. Every form denotes the same value.
     */
    private String constant(int value) {
        String pattern;
        switch (stepCount) {
            case 1:
                pattern = "%d";
                break;
            case 2:
                pattern = "0x%x";
                break;
            case 3:
                pattern = "0x%02x";
                break;
            default:
                pattern = "0x%04x";
                break;
        }
        return String.format(Locale.ROOT, pattern, value);
    }

    /** Emits the register declarations used by the kernels from {@link #emitKernel}. */
    private void declareLocals() throws IOException {
        int count = 1 << stepCount;
        append(".reg .u64 data;");
        append(".reg .u32 length;");
        append(".reg .u32 stride;");
        append(".reg .u32 threadId;");
        append(".reg .u32 mask;");
        if (stepCount != 1) {
            append(".reg .u32 step;");
        }
        format(".reg %s tmp;", typeName);
        append(".reg .b32 moved;");
        append(".reg .b32 bit;");
        append(".reg .u32 rt<3>;");
        format(".reg .u32 ix<%d>;", count);
        format(".reg %s vl<%d>;", typeName, count);
        append(".reg .pred p<2>;");
        format(".reg .s%d vs<2>;", typeSize * 8);
        append(".reg .u64 ptr;");
    }

    /**
     * Emits the {@code phase9} entry point. Each thread block copies 512 elements, starting at
     * {@code block id << 9}, into shared memory, padding past {@code length} with
     * {@link #maxValue}. It then runs the 9 levels of compare-and-swap passes over them and
     * writes back the elements below {@code length}.
     *
     * <p>Immediates go through {@link #constant}, so their formatting follows the current
     * {@link #stepCount} ({@value #MAX_STEP_COUNT} when called from {@link #generate}).
     */
    private void emitFirstPhases() throws IOException {
        int elements = 1 << PHASE_LOG2_ELEMENTS; // 512 elements per thread block
        int pairs = elements >> 1; // 256 compare-and-swap pairs per pass
        append(".visible .entry");
        format("phase%d(.param .u64 _data,.param .u32 _length)", PHASE_LOG2_ELEMENTS);
        append(".maxntid 256,1,1");
        append("{");
        format(".shared .align %d %s _sharedData[%d];", typeSize, typeName, elements);
        append(".reg .u64 data;");
        append(".reg .u32 length;");
        append(".reg .u64 sharedData;");
        append(".reg .u64 dataPtr;");
        append(".reg .u64 sharedPtr<2>;");
        append(".reg .u32 baseIndex;");
        append(".reg .u32 blockDimX;");
        append(".reg .u32 globalIndex;");
        append(".reg .u32 workId;");
        append(".reg .pred p<2>;");
        format(".reg .s%d vs<2>;", typeSize * 8);
        append(".reg .u32 ix<2>;");
        append(".reg .u32 rt<3>;");
        format(".reg %s vl<2>;", typeName);
        append("ld.param.u64 data,[_data];");
        append("cvta.to.global.u64 data,data;");
        append("ld.param.u32 length,[_length];");
        append("mov.u64 sharedData,_sharedData;");
        append("mov.u32 blockDimX,%ntid.x;");
        append("mov.u32 rt0,%nctaid.x;");
        append("mov.u32 rt1,%ctaid.y;");
        append("mov.u32 rt2,%ctaid.x;");
        append("mad.lo.u32 baseIndex,rt0,rt1,rt2;");
        format("shl.b32 baseIndex,baseIndex,%d;", PHASE_LOG2_ELEMENTS);

        // Load: global -> shared, padding slots at or beyond `length`.
        append("mov.u32 workId,%tid.x;");
        append("bra loadTest;");
        append("loadLoop:");
        append("add.u32 globalIndex,baseIndex,workId;");
        format("mov%s vl0,%s;", typeName, maxValue);
        append("setp.lt.u32 p0,globalIndex,length;");
        format("@p0 mad.wide.u32 dataPtr,globalIndex,%d,data;", typeSize);
        format("@p0 ld.global%s vl0,[dataPtr];", typeName);
        format("mad.wide.u32 sharedPtr0,workId,%d,sharedData;", typeSize);
        format("st.shared%s [sharedPtr0],vl0;", typeName);
        append("add.u32 workId,workId,blockDimX;");
        append("loadTest:");
        format("setp.lt.u32 p0,workId,%d;", elements);
        append("@p0 bra loadLoop;");

        // Compare-and-swap passes: for each level, distances 2^level down to 1.
        for (int level = 0; level < PHASE_LOG2_ELEMENTS; level++) {
            for (int pass = 0; pass <= level; pass++) {
                int shift = level - pass; // log2 of this pass's compare distance
                String loopLabel =
                        String.format(Locale.ROOT, "workLoop_%d_%d", level + 1, pass + 1);
                String testLabel =
                        String.format(Locale.ROOT, "workTest_%d_%d", level + 1, pass + 1);
                append("bar.sync 0;");
                append("mov.u32 workId,%tid.x;");
                format("bra %s;", testLabel);
                format("%s:", loopLabel);
                // ix0 = 2 * workId - (workId mod distance); that term is 0 when distance is 1.
                append("shl.b32 ix0,workId,1;");
                if (pass != level) {
                    format("and.b32 rt0,workId,%s;", constant((1 << shift) - 1));
                    append("sub.u32 ix0,ix0,rt0;");
                }
                if (pass == 0 && level != 0) {
                    // First pass of a level pairs each element with its mirror image.
                    format("xor.b32 ix1,ix0,%s;", constant((2 << level) - 1));
                } else {
                    format("add.u32 ix1,ix0,%s;", constant(1 << shift));
                }
                format("mad.wide.u32 sharedPtr0,ix0,%d,sharedData;", typeSize);
                format("ld.shared%s vl0,[sharedPtr0];", typeName);
                format("mad.wide.u32 sharedPtr1,ix1,%d,sharedData;", typeSize);
                format("ld.shared%s vl1,[sharedPtr1];", typeName);
                compare(0, 1);
                format("@p0 st.shared%s [sharedPtr0],vl1;", typeName);
                format("@p0 st.shared%s [sharedPtr1],vl0;", typeName);
                append("add.u32 workId,workId,blockDimX;");
                format("%s:", testLabel);
                format("setp.lt.u32 p0,workId,%d;", pairs);
                format("@p0 bra %s;", loopLabel);
            }
        }

        // Store: shared -> global, only for indices below `length`.
        append("bar.sync 0;");
        append("mov.u32 workId,%tid.x;");
        append("bra storeTest;");
        append("storeLoop:");
        append("{");
        append("add.u32 globalIndex,baseIndex,workId;");
        append("setp.lt.u32 p0,globalIndex,length;");
        format("@p0 mad.wide.u32 sharedPtr0,workId,%d,sharedData;", typeSize);
        format("@p0 ld.shared%s vl0,[sharedPtr0];", typeName);
        format("@p0 mad.wide.u32 dataPtr,globalIndex,%d,data;", typeSize);
        format("@p0 st.global%s [dataPtr],vl0;", typeName);
        append("add.u32 workId,workId,blockDimX;");
        append("}");
        append("storeTest:");
        format("setp.lt.u32 p0,workId,%d;", elements);
        append("@p0 bra storeLoop;");
        append("}");
    }

    /**
     * Emits one {@code first<N>}/{@code other<N>} entry point (N = {@link #stepCount}). Each
     * thread gathers {@code 1 << N} elements, orders them with compare-and-swap steps, and
     * writes back the slots that moved.
     */
    private void emitKernel(boolean first) throws IOException {
        append(".visible .entry");
        format("%s%d(.param .u64 _data,.param .u32 _length,.param .u32 _stride)",
                first ? "first" : "other", stepCount);
        append(".maxntid 256,1,1");
        append("{");
        declareLocals();
        append("ld.param.u64 data,[_data];");
        append("cvta.to.global.u64 data,data;");
        append("ld.param.u32 length,[_length];");
        append("ld.param.u32 stride,[_stride];");
        computeIndices(first);
        gatherData();
        sortLocally(first);
        scatterData();
        append("}");
    }

    /** Emits the module header; {@code capability < 3} selects {@code sm_20}, else {@code sm_30}. */
    private void emitPreamble(int capability) throws IOException {
        append(".version 3.2");
        format(".target sm_%d", capability < 3 ? 20 : 30);
        append(".address_size 64");
    }

    /** Formats with {@link Locale#ROOT} so the digits are ASCII regardless of the default locale. */
    private void format(String pattern, Object... args) throws IOException {
        append(String.format(Locale.ROOT, pattern, args));
    }

    /** Loads each slot from {@code data[ix<k>]}, or {@link #maxValue} when out of range. */
    private void gatherData() throws IOException {
        int count = 1 << stepCount;
        for (int k = 0; k < count; k++) {
            format("mov%s vl%d,%s;", typeName, k, maxValue);
            format("setp.lt.u32 p0,ix%d,length;", k);
            format("@p0 mad.wide.u32 ptr,ix%d,%d,data;", k, typeSize);
            format("@p0 ld.global%s vl%d,[ptr];", typeName, k);
        }
    }

    /**
     * Emits the whole module. The {@code other} kernels are emitted for every step count up to
     * {@value #MAX_STEP_COUNT}, then {@code first} at that step count, then {@code phase9}.
     */
    private void generate(int capability) throws IOException {
        emitPreamble(capability);
        for (int steps = 1; steps <= MAX_STEP_COUNT; steps++) {
            stepCount = steps;
            emitKernel(false);
        }
        // stepCount stays at MAX_STEP_COUNT for the remaining kernels.
        emitKernel(true);
        emitFirstPhases();
        writer.flush();
    }

    /** Stores back to {@code data[ix<k>]} every slot whose bit is set in {@code moved}. */
    private void scatterData() throws IOException {
        int count = 1 << stepCount;
        for (int k = 0; k < count; k++) {
            format("and.b32 bit,moved,%s;", constant(1 << k));
            append("setp.ne.b32 p0,bit,0;");
            format("@p0 mad.wide.u32 ptr,ix%d,%d,data;", k, typeSize);
            format("@p0 st.global%s [ptr],vl%d;", typeName, k);
        }
    }

    /**
     * Emits the in-register compare-and-swap network over the thread's {@code 1 << stepCount}
     * slots.
     *
     * @param first if true, first compares each slot {@code k} with its mirror
     *     {@code count - 1 - k}, and that replaces the widest-distance level
     */
    private void sortLocally(boolean first) throws IOException {
        int count = 1 << stepCount;
        int level = 0;
        append("mov.b32 moved,0;");
        if (first) {
            for (int k = 0; k < count / 2; k++) {
                compareAndSwap(k, k ^ (count - 1));
            }
            level = 1;
        }
        for (; level < stepCount; level++) {
            int distance = count >> (level + 1);
            int span = distance << 1;
            // Each aligned run of `span` slots compares its lower half against its upper half.
            for (int base = 0; base < count; base += span) {
                for (int k = 0; k < distance; k++) {
                    compareAndSwap(base + k, base + k + distance);
                }
            }
        }
    }
}
