Skip to content

Commit

Permalink
Add base for sgd optimizer (#3496)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3496

This adds the sgd_optimizer header to executorch. would appreciate some thoughts on where to place this file.

Reviewed By: JacobSzwejbka

Differential Revision: D56888378

fbshipit-source-id: 17d6bb3975ae2d58aee911ee91a3ff07acbc6850
  • Loading branch information
David Lin authored and facebook-github-bot committed May 13, 2024
1 parent ebe701e commit c853b3c
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 0 deletions.
8 changes: 8 additions & 0 deletions extension/training/optimizer/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
49 changes: 49 additions & 0 deletions extension/training/optimizer/sgd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

/**
* SGD (stochastic gradient descent) optimizer to perform on-device training.
* This uses the gradients calculated in the backwards pass of the loss function
* and updates the parameters such that it minimizes the loss.
*
* This is similar to the Lite Interpreter implementation of the SGD optimizer,
* but without the dependency on ATen Tensors and autograd.
*/
#pragma once

namespace torch {
namespace executor {
namespace optimizer {

/**
* SGD optimizer state. This keeps track of the state of a given parameter to
* be used in later epochs.
*/
class SGDParamState {};

/**
* SGD optimizer options. This contains options for performing training on a
* param group, such as the learning rate.
*/
class SGDOptions {};

/**
* SGD optimizer param group. This contains the parameters and
* the OptimizerOptions associated to it.
*/
class SGDParamGroup {};

/**
* SGD optimizer class. This is responsible for performing the optimization
* step.
*/
class SGD {};

} // namespace optimizer
} // namespace executor
} // namespace torch
20 changes: 20 additions & 0 deletions extension/training/optimizer/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

runtime.cxx_library(
name = "optimizer",
exported_headers = [
"sgd.h",
],
exported_deps = [
],
visibility = [
"@EXECUTORCH_CLIENTS",
],
)
8 changes: 8 additions & 0 deletions extension/training/optimizer/test/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
28 changes: 28 additions & 0 deletions extension/training/optimizer/test/sgd_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/training/optimizer/sgd.h>

#include <gtest/gtest.h>

using namespace ::testing;
using namespace torch::executor::optimizer;

class SGDOptimizerTest : public ::testing::Test {};

TEST_F(SGDOptimizerTest, InstantiateTypes) {
SGDParamState state;
SGDOptions options;
SGDParamGroup param_group;
SGD sgd;

EXPECT_TRUE(dynamic_cast<SGDParamState*>(&state) != nullptr);
EXPECT_TRUE(dynamic_cast<SGDOptions*>(&options) != nullptr);
EXPECT_TRUE(dynamic_cast<SGDParamGroup*>(&param_group) != nullptr);
EXPECT_TRUE(dynamic_cast<SGD*>(&sgd) != nullptr);
}
18 changes: 18 additions & 0 deletions extension/training/optimizer/test/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

runtime.cxx_test(
name = "sgd_test",
srcs = [
"sgd_test.cpp",
],
deps = [
"//executorch/extension/training/optimizer:optimizer",
],
)

0 comments on commit c853b3c

Please sign in to comment.