
So I have a method that uses simple lambda functions to update my weights. I want to try different functions, but I also want a default parameter for the decay:

void ema_update(int i, const NetImpl& mdl,
    void (&updwp)(torch::Tensor& w, const torch::Tensor& w1, double decay = 0.999) =
        [](torch::Tensor& w, const torch::Tensor& w1, double decay) {
            w.set_data(w.data().detach() * decay + w1.detach() * (1. - decay));
        },
    void (&updw)(torch::Tensor& w, const torch::Tensor& w1, double decay = 0.999) =
        [](torch::Tensor& w, const torch::Tensor& w1, double decay) {
            w = w.detach() * decay;
            w += w1.detach() * (1. - decay);
        }) {
    updw(layers[i].cnvtr1->weight, mdl.layers[i].cnvtr1->weight);
    updw(layers[i].cnvtr2->weight, mdl.layers[i].cnvtr2->weight);
    updw(layers[i].cnvtr3->weight, mdl.layers[i].cnvtr3->weight);
    updw(layers[i].lin1->weight, mdl.layers[i].lin1->weight);
    updw(layers[i].lin2->weight, mdl.layers[i].lin2->weight);
    updw(layers[i].lin3->weight, mdl.layers[i].lin3->weight);

    updw(layers[i].cnv1->weight, mdl.layers[i].cnv1->weight);
    updw(layers[i].cnv2->weight, mdl.layers[i].cnv2->weight);
    updw(layers[i].cnv3->weight, mdl.layers[i].cnv3->weight);
    updw(layers[i].cnv4->weight, mdl.layers[i].cnv4->weight);
    updw(layers[i].rnnresh, mdl.layers[i].rnnresh);
    if (layers[i].mha->in_proj_weight.defined())
        updw(layers[i].mha->in_proj_weight, mdl.layers[i].mha->in_proj_weight);
    if (layers[i].mha->k_proj_weight.defined())
        updw(layers[i].mha->k_proj_weight, mdl.layers[i].mha->k_proj_weight);
    if (layers[i].mha->q_proj_weight.defined())
        updw(layers[i].mha->q_proj_weight, mdl.layers[i].mha->q_proj_weight);
    if (layers[i].mha->v_proj_weight.defined())
        updw(layers[i].mha->v_proj_weight, mdl.layers[i].mha->v_proj_weight);
    for (size_t pi = 0; pi < layers[i].trans->decoder.ptr()->parameters().size(); ++pi)
        updwp(layers[i].trans->decoder.ptr()->parameters()[pi], mdl.layers[i].trans->decoder.ptr()->parameters()[pi].data());
    for (size_t pi = 0; pi < layers[i].trans->encoder.ptr()->parameters().size(); ++pi)
        updwp(layers[i].trans->encoder.ptr()->parameters()[pi], mdl.layers[i].trans->encoder.ptr()->parameters()[pi].data());
    for (size_t pi = 0; pi < layers[i].rnn1->all_weights().size(); ++pi)
        updwp(layers[i].rnn1->all_weights()[pi], mdl.layers[i].rnn1->all_weights()[pi].data());
}

Here I specify all the layers I need to update, and the small lambdas are default parameters; however, I can't set default parameters for the decay on the function-reference prototypes.

MSVC says:

error C2383: 'updwp': default-arguments are not allowed on this symbol

It doesn't matter whether I use a reference or a pointer to the function.

I'm open to alternative suggestions for making both the lambdas and the decay default parameters.

Also, it originally was:

void ema_update(int i, const NetImpl& mdl, double decay = 0.999,
    void (&updwp)(torch::Tensor& w, const torch::Tensor& w1) =
        [](torch::Tensor& w, const torch::Tensor& w1) {
            w.set_data(w.data().detach() * decay + w1.detach() * (1. - decay));
        },
    void (&updw)(torch::Tensor& w, const torch::Tensor& w1) =
        [](torch::Tensor& w, const torch::Tensor& w1) {
            w = w.detach() * decay;
            w += w1.detach() * (1. - decay);
        })

But that doesn't work either: a default argument can't refer to another parameter of the same function, and a captureless lambda can't see `decay` anyway.

Self-contained example:

void f(void (*pf)(int &a, double b=0.77) = [] (int &a, double b){
     a *= b;
}) {
    int a = 9;
    pf(a);
}

void f1(double b=0.77, void (*pf)(int &a) = [] (int &a){
    a *= b;
}) {
    int a = 9;
    pf(a);
}

https://wandbox.org/permlink/XmZtzsxmcwgUIbJy

I also tried std::function with std::bind:

#include <functional>

void f3(double b = 0.77, std::function<void(double, int&)> fa = [](double b, int& a) {
     a *= b;
}) {
    int a = 9;
    // std::bind needs _1 as a placeholder for the unbound int& argument;
    // class template argument deduction on std::function doesn't work for a
    // bind expression, so use auto.
    auto f = std::bind(fa, b, std::placeholders::_1);
    f(a);
}

https://wandbox.org/permlink/Jvv3Dw2gSemkyaxt

Comments:

  • You can in C++20, but please don't, it's ridiculously unreadable. Use a free function if you must. Commented Mar 8 at 11:38
  • Do you ever plan on calling the lambda with and without default parameters? Anyway, it would be better to make a struct with all the parameters in it and pass an instance of that. It would make the code a lot more readable too. Commented Mar 8 at 11:40
  • Wait, ignore my previous comment. You mean you want the function pointer parameter to have default arguments? That was never intended, but you can wrap it like auto wrapped = [pf, b](int& a) { return pf(a, b); }; Commented Mar 8 at 11:44
  • @PasserBy Yes. Yeah, that's the solution I've found; I was overcomplicating it :). Commented Mar 8 at 11:46
  • @PepijnKramer No, I only call with the default parameter passed or omitted by the caller. Commented Mar 8 at 11:53

1 Answer


I've found a solution, which is the simplest (though maybe not the most elegant):

I need to pass by reference because in the original code a is not a simple integer but an object with a getter and setter that I use inside the lambda.

#include <functional>

void f4(double b=0.77, std::function<void (double b, int &a)> fa= [] (double b, int &a){
     a *= b;
}) {
    int a = 9;
    auto f = [fa, b] (int &a) {fa(b, a);};
    f(a);
}

https://wandbox.org/permlink/uKwqe1reWwwEJAGC


3 Comments

Two remarks: why do you give your lambda an "in/out parameter"? (IMO it's not very idiomatic; that's what you have return values for.) And instead of using a function pointer, use a std::function<void(double, int&)>; that way you don't have to explain to your callers that they can only pass in lambdas without captures.
I need in/out because I operate on 2 objects that I need to modify (in the original code). They have get/set methods. But otherwise, thanks, I agree std::function is better. (Here is the simplified version.)
@AnArrayOfFunctions • Avoid in/out parameters for value objects. Use move-in or copy-by-value parameters, and return values for value objects. (Reserve in/out parameters for entity objects.) Strive to make calculating functions pure, not ones that mutate their caller's arguments in situ.
