diff --git a/Compiler/library.py b/Compiler/library.py index 2d45d1201..46c4c10a2 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -24,19 +24,59 @@ def get_block(): return get_program().curr_block def vectorize(function): + def mask_output(value, active: regint): + if value is None: + return None + if isinstance(value, tuple): + return tuple(mask_output(x, active) for x in value) + if isinstance(value, list): + return [mask_output(x, active) for x in value] + + try: + size = value.size + except AttributeError: + return value + + if size == 1: + return value + + return value * active + + def get_vector_size(call_args, call_kwargs): + if len(call_args) > 0 and 'size' in dir(call_args[0]): + return call_args[0].size + elif 'size' in call_kwargs: + return call_kwargs['size'] + else: + return None + def vectorized_function(*args, **kwargs): - if len(args) > 0 and 'size' in dir(args[0]): - instructions_base.set_global_vector_size(args[0].size) - res = function(*args, **kwargs) - instructions_base.reset_global_vector_size() - elif 'size' in kwargs: - instructions_base.set_global_vector_size(kwargs['size']) - del kwargs['size'] - res = function(*args, **kwargs) + active_vector_size = regint.conv(kwargs.pop('active_length', None)) + + size = get_vector_size(args, kwargs) + if size is not None: + if 'size' in kwargs: + del kwargs['size'] + instructions_base.set_global_vector_size(size) + + set_active_vector_size = active_vector_size is not None and size is not None and size > 1 + context_saved_arg = None + if set_active_vector_size: + context_saved_arg = get_arg() + starg(-active_vector_size) + + res = function(*args, **kwargs) + + if set_active_vector_size: + starg(context_saved_arg) + active = regint.inc(size) < active_vector_size + res = mask_output(res, active) + + if size is not None: instructions_base.reset_global_vector_size() - else: - res = function(*args, **kwargs) + return res + vectorized_function.__name__ = function.__name__ copy_doc(vectorized_function, function) return vectorized_function diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 53fd20261..7129b2992 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -167,24 +167,50 @@ template void ShareThread::and_(Processor& processor, const vector& args, bool repeat) { + vector active_args; + int active_limit = -1; + long prefix = processor.get_arg().get(); + if (prefix < 0) + active_limit = int(-prefix); + + if (active_limit >= 0) + { + active_args.reserve(args.size()); + for (auto it = args.begin(); it < args.end(); it += 4) + { + int active_bits = min(*it, active_limit); + if (active_bits > 0) + { + active_args.push_back(active_bits); + active_args.push_back(*(it + 1)); + active_args.push_back(*(it + 2)); + active_args.push_back(*(it + 3)); + } + } + } + else + { + active_args = args; + } + auto& protocol = this->protocol; auto& S = processor.S; - processor.check_args(args, 4); + processor.check_args(active_args, 4); protocol->init_mul(); T x_ext, y_ext; size_t total_bits = 0; - for (auto it = args.begin(); it < args.end(); it += 4) + for (auto it = active_args.begin(); it < active_args.end(); it += 4) total_bits += *it; // accept 10 % waste - bool fast_mode = 0.1 * total_bits > args.size() / 4 * T::default_length; + bool fast_mode = 0.1 * total_bits > active_args.size() / 4 * T::default_length; if (fast_mode) { protocol->set_fast_mode(true); } - ArgList> infos(args); + ArgList> infos(active_args); if (repeat) for (auto info : infos) @@ -268,11 +294,43 @@ void ShareThread::andrsvec(Processor& processor, const vector& args) int N_BITS = T::default_length; auto& protocol = this->protocol; assert(protocol); + + vector active_args; + int active_limit = -1; + long prefix = processor.get_arg().get(); + if (prefix < 0) + active_limit = int(-prefix); + + if (active_limit >= 0) + { + active_args.reserve(args.size()); + auto it = args.begin(); + while (it < args.end()) + { + int n_args = (*it - 3) / 2; + int size = *(it + 1); + int active_size = min(size, active_limit); + if (active_size > 0) + { + int group_size = 2 * n_args + 3; + active_args.push_back(*it); + active_args.push_back(active_size); + for (int i = 0; i < group_size - 2; i++) + active_args.push_back(*(it + 2 + i)); + } + it += 2 * n_args + 3; + } + } + else + { + active_args = args; + } + protocol->init_mul(); - auto it = args.begin(); + auto it = active_args.begin(); T x_ext, y_ext; int total_bits = 0; - while (it < args.end()) + while (it < active_args.end()) { int n_args = (*it++ - 3) / 2; int size = *it++; @@ -297,8 +355,8 @@ void ShareThread::andrsvec(Processor& processor, const vector& args) protocol->exchange(); - it = args.begin(); - while (it < args.end()) + it = active_args.begin(); + while (it < active_args.end()) { int n_args = (*it++ - 3) / 2; int size = *it++; diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index 3bb39e2fe..064a2a25d 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -21,10 +21,11 @@ void Instruction::execute_clear_gf2n(StackedVector& registers, { auto& C2 = registers; auto& M2C = memory; + int active_size = get_effective_vector_size(Proc, size); switch (opcode) { #define X(NAME, PRE, CODE) \ - case NAME: { PRE; for (int i = 0; i < size; i++) { CODE; } } break; + case NAME: { PRE; for (int i = 0; i < active_size; i++) { CODE; } } break; CLEAR_GF2N_INSTRUCTIONS #undef X } @@ -62,10 +63,11 @@ void Instruction::execute_regint(ArithmeticProcessor& Proc, MemoryPart& { (void) Mi; auto& Ci = Proc.get_Ci(); + int active_size = get_effective_vector_size(Proc, size); switch (opcode) { #define X(NAME, PRE, CODE) \ - case NAME: { PRE; for (int i = 0; i < size; i++) { CODE; } } break; + case NAME: { PRE; for (int i = 0; i < active_size; i++) { CODE; } } break; REGINT_INSTRUCTIONS #undef X } @@ -73,18 +75,20 @@ void Instruction::execute_regint(ArithmeticProcessor& Proc, MemoryPart& void Instruction::shuffle(ArithmeticProcessor& Proc) const { - for (int i = 0; i < size; i++) + int active_size = get_effective_vector_size(Proc, size); + for (int i = 0; i < active_size; i++) Proc.write_Ci(r[0] + i, Proc.read_Ci(r[1] + i)); - for (int i = 0; i < size; i++) + for (int i = 0; i < active_size; i++) { - int j = Proc.shared_prng.get_uint(size - i); + int j = Proc.shared_prng.get_uint(active_size - i); swap(Proc.get_Ci_ref(r[0] + i), Proc.get_Ci_ref(r[0] + i + j)); } } void Instruction::bitdecint(ArithmeticProcessor& Proc) const { - for (int j = 0; j < size; j++) + int active_size = get_effective_vector_size(Proc, size); + for (int j = 0; j < active_size; j++) { long a = Proc.read_Ci(r[0] + j); for (unsigned int i = 0; i < start.size(); i++) diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index c772239a1..aeb6aaa69 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -46,6 +46,14 @@ void BaseInstruction::parse(istream& s, int inst_pos) parse_operands(s, inst_pos, pos); } +inline int get_effective_vector_size(const ProcessorBase& Proc, int size) +{ + long prefix = Proc.get_arg().get(); + if (prefix < 0) + return min(size, int(-prefix)); + return size; +} + inline void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) { @@ -947,14 +955,17 @@ inline void Instruction::execute(Processor& Proc) const { auto& Procp = Proc.Procp; auto& Proc2 = Proc.Proc2; + int active_size = get_effective_vector_size(Proc, size); + auto active_inst = *this; + active_inst.size = active_size; // optimize some instructions switch (opcode) { case CONVMODP: vector values; - values.reserve(size); - for (int i = 0; i < size; i++) + values.reserve(active_size); + for (int i = 0; i < active_size; i++) { auto source = Proc.read_Cp(r[1] + i); Integer tmp; @@ -969,14 +980,14 @@ inline void Instruction::execute(Processor& Proc) const } if (r[2]) Procp.protocol.sync(values, Proc.P); - for (int i = 0; i < size; i++) + for (int i = 0; i < active_size; i++) Proc.write_Ci(r[0] + i, values[i].get()); return; } int r[3] = {this->r[0], this->r[1], this->r[2]}; int64_t n = this->n; - for (int i = 0; i < size; i++) + for (int i = 0; i < active_size; i++) { switch (opcode) { case LDMC: @@ -1143,16 +1154,16 @@ inline void Instruction::execute(Processor& Proc) const Proc.write_Cp(r[0],Proc.temp.ansp); break; case SHRSI: - sint::shrsi(Procp, *this); + sint::shrsi(Procp, active_inst); return; case GSHRSI: - sgf2n::shrsi(Proc2, *this); + sgf2n::shrsi(Proc2, active_inst); return; case OPEN: - Proc.Procp.POpen(*this); + Proc.Procp.POpen(active_inst); return; case GOPEN: - Proc.Proc2.POpen(*this); + Proc.Proc2.POpen(active_inst); return; case MULS: Proc.Procp.muls(start); @@ -1167,48 +1178,48 @@ inline void Instruction::execute(Processor& Proc) const Proc.Proc2.protocol.mulrs(start, Proc.Proc2); return; case DOTPRODS: - Proc.Procp.dotprods(start, size); + Proc.Procp.dotprods(start, active_size); return; case GDOTPRODS: - Proc.Proc2.dotprods(start, size); + Proc.Proc2.dotprods(start, active_size); return; case MATMULS: - Proc.Procp.matmuls(Proc.Procp.get_S(), *this); + Proc.Procp.matmuls(Proc.Procp.get_S(), active_inst); return; case GMATMULS: - Proc.Proc2.matmuls(Proc.Proc2.get_S(), *this); + Proc.Proc2.matmuls(Proc.Proc2.get_S(), active_inst); return; case MATMULSM: - Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this); + Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, active_inst); return; case GMATMULSM: - Proc.Proc2.protocol.matmulsm(Proc.Proc2, Proc.machine.M2.MS, *this); + Proc.Proc2.protocol.matmulsm(Proc.Proc2, Proc.machine.M2.MS, active_inst); return; case CONV2DS: - Proc.Procp.protocol.conv2ds(Proc.Procp, *this); + Proc.Procp.protocol.conv2ds(Proc.Procp, active_inst); return; case TRUNC_PR: - Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp, + Proc.Procp.protocol.trunc_pr(start, active_size, Proc.Procp, sint::clear::characteristic_two); return; case SECSHUFFLE: - Proc.Procp.secure_shuffle(*this); + Proc.Procp.secure_shuffle(active_inst); return; case GSECSHUFFLE: - Proc.Proc2.secure_shuffle(*this); + Proc.Proc2.secure_shuffle(active_inst); return; case GENSECSHUFFLE: - Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this, + Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(active_inst, Proc.machine.shuffle_store)); return; case APPLYSHUFFLE: - Proc.Procp.apply_shuffle(*this, Proc.machine.shuffle_store); + Proc.Procp.apply_shuffle(active_inst, Proc.machine.shuffle_store); return; case DELSHUFFLE: Proc.machine.shuffle_store.del(Proc.read_Ci(r[0])); return; case INVPERM: - Proc.Procp.inverse_permutation(*this); + Proc.Procp.inverse_permutation(active_inst); return; case CHECK: { @@ -1298,7 +1309,13 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.join_tape(r[0]); break; case CALL_TAPE: - Proc.call_tape(r[0], Proc.read_Ci(r[1]), start); + { + int runtime_arg = Proc.read_Ci(r[1]); + int caller_arg = Proc.get_arg().get(); + if (caller_arg < 0) + runtime_arg = caller_arg; + Proc.call_tape(r[0], runtime_arg, start); + } break; case CRASH: if (Proc.read_Ci(r[0])) @@ -1390,20 +1407,20 @@ inline void Instruction::execute(Processor& Proc) const break; case WRITEFILESHARE: // Write shares to file system - Procp.write_shares_to_file(Proc.read_Ci(r[0]), start, size); + Procp.write_shares_to_file(Proc.read_Ci(r[0]), start, active_size); return; case READFILESHARE: // Read shares from file system - Procp.read_shares_from_file(Proc.read_Ci(r[0]), r[1], start, size, + Procp.read_shares_from_file(Proc.read_Ci(r[0]), r[1], start, active_size, Proc); return; case GWRITEFILESHARE: // Write shares to file system - Proc2.write_shares_to_file(Proc.read_Ci(r[0]), start, size); + Proc2.write_shares_to_file(Proc.read_Ci(r[0]), start, active_size); return; case GREADFILESHARE: // Read shares from file system - Proc2.read_shares_from_file(Proc.read_Ci(r[0]), r[1], start, size, + Proc2.read_shares_from_file(Proc.read_Ci(r[0]), r[1], start, active_size, Proc); return; case PUBINPUT: @@ -1429,16 +1446,16 @@ inline void Instruction::execute(Processor& Proc) const } break; case FIXINPUT: - Proc.fixinput(*this); + Proc.fixinput(active_inst); return; case PREP: - Procp.DataF.get(Proc.Procp.get_S(), r, start, size); + Procp.DataF.get(Proc.Procp.get_S(), r, start, active_size); return; case GPREP: - Proc2.DataF.get(Proc.Proc2.get_S(), r, start, size); + Proc2.DataF.get(Proc.Proc2.get_S(), r, start, active_size); return; case CISC: - Procp.protocol.cisc(Procp, *this); + Procp.protocol.cisc(Procp, active_inst); return; default: throw invalid_opcode(opcode); @@ -1460,7 +1477,7 @@ inline void Instruction::execute(Processor& Proc) const #undef X throw runtime_error("wrong case statement"); return; } - if (size > 1) + if (active_size > 1) { r[0]++; r[1]++; r[2]++; } @@ -1504,11 +1521,12 @@ void Program::execute_with_errors(Processor& Proc) const while (Proc.PC::reset(const Program& program,int arg) Ci.resize(program.num_reg(INT)); this->arg = arg; - Procb.reset(program); + Procb.reset(program, arg); } template @@ -472,6 +472,8 @@ void SubProcessor::POpen(const Instruction& inst) check(); auto& reg = inst.get_start(); int size = inst.get_size(); + if (size == 0) + return; assert(reg.size() % 2 == 0); int sz=reg.size() / 2; MC.init_open(P, sz * size); @@ -553,6 +555,11 @@ void SubProcessor::mulrs(const vector& reg) template void SubProcessor::dotprods(const vector& reg, int size) { + if (size == 0) + { + maybe_check(); + return; + } protocol.init_dotprod(); for (int i = 0; i < size; i++) {