Skip to content

Commit

Permalink
Add base for sgd optimizer (#3496)
Browse files Browse the repository at this point in the history
Summary:

This adds the sgd_optimizer header to executorch. would appreciate some thoughts on where to place this file.

Differential Revision: D56888378
  • Loading branch information
David Lin authored and facebook-github-bot committed May 2, 2024
1 parent 74538f8 commit f619b7d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 0 deletions.
8 changes: 8 additions & 0 deletions extension/training/optimizers/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
46 changes: 46 additions & 0 deletions extension/training/optimizers/sgd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

/**
* SGD (Stochastic gradient descent) optimizer to perform on-device training.
* This uses the gradients calculated in the backwards pass of the loss function
* and updates the parameters such that it minimizes the loss
*/
#pragma once

namespace torch {
namespace executor {
namespace optim {

/**
* SGD optimizer state. This keeps track of the state of a given parameter to
* be used in later epochs.
*/
class SGDParamState {};

/**
* SGD optimizer options. This contains options for performing training on a
* param group, such as the learning rate.
*/
class SGDOptions {};

/**
* SGD optimizer param group. This contains the parameters and
* the OptimizerOptions associated to it.
*/
class SGDParamGroup {};

/**
* SGD optimizer class. This is responsible for performing the optimization
* step.
*/
class SGD {};

} // namespace optim
} // namespace executor
} // namespace torch
19 changes: 19 additions & 0 deletions extension/training/optimizers/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

runtime.cxx_library(
name = "optimizers",
exported_headers = [
"sgd.h",
],
exported_deps = [
],
visibility = [
],
)

0 comments on commit f619b7d

Please sign in to comment.