Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unit tests for operator overloading; fix scalar bug #225

Merged
merged 19 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 10 additions & 24 deletions src/ctorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,24 +277,6 @@ torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1,
return output;
}

torch_tensor_t torch_tensor_premultiply(const torch_data_t scalar,
const torch_tensor_t tensor) {
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
torch::Tensor *output = nullptr;
output = new torch::Tensor;
*output = scalar * *t;
return output;
}

torch_tensor_t torch_tensor_postmultiply(const torch_tensor_t tensor,
const torch_data_t scalar) {
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
torch::Tensor *output = nullptr;
output = new torch::Tensor;
*output = *t * scalar;
return output;
}

torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1,
const torch_tensor_t tensor2) {
auto t1 = reinterpret_cast<torch::Tensor *const>(tensor1);
Expand All @@ -305,21 +287,25 @@ torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1,
return output;
}

torch_tensor_t torch_tensor_postdivide(const torch_tensor_t tensor,
const torch_data_t scalar) {
torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor,
const torch_int_t exponent) {
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
// NOTE: The following cast will only work for integer exponents
auto exp = reinterpret_cast<int *const>(exponent);
torch::Tensor *output = nullptr;
output = new torch::Tensor;
*output = *t / scalar;
*output = pow(*t, *exp);
return output;
}

torch_tensor_t torch_tensor_power(const torch_tensor_t tensor,
const torch_data_t exponent) {
torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor,
const torch_float_t exponent) {
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
// NOTE: The following cast will only work for floating point exponents
auto exp = reinterpret_cast<float *const>(exponent);
torch::Tensor *output = nullptr;
output = new torch::Tensor;
*output = pow(*t, exponent);
*output = pow(*t, *exp);
return output;
}

Expand Down
47 changes: 18 additions & 29 deletions src/ctorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ typedef void *torch_jit_script_module_t;
// Opaque pointer type alias for at::Tensor
typedef void *torch_tensor_t;

// Opaque pointer type alias for integer scalars
typedef void *torch_int_t;

// Opaque pointer type alias for float scalars
typedef void *torch_float_t;

// Data types
typedef enum {
torch_kUInt8,
Expand Down Expand Up @@ -171,25 +177,7 @@ EXPORT_C torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1,
const torch_tensor_t tensor2);

/**
* Overloads the premultiplication operator for a scalar and a Torch Tensor
* @param scalar to multiply by
* @param Tensor to be multiplied
* @return product of the scalar and Tensor
*/
EXPORT_C torch_tensor_t torch_tensor_premultiply(const torch_data_t scalar,
const torch_tensor_t tensor);

/**
* Overloads the postmultiplication operator for a Torch Tensor and a scalar
* @param Tensor to be multiplied
* @param scalar to multiply by
* @return product of the Tensor and scalar
*/
EXPORT_C torch_tensor_t torch_tensor_postmultiply(const torch_tensor_t tensor,
const torch_data_t scalar);

/**
* Overloads the division operator for two Torch Tensors
* Overloads the division operator for two Torch Tensors.
* @param first Tensor to be divided
* @param second Tensor to be divided
* @return quotient of the Tensors
Expand All @@ -198,22 +186,23 @@ EXPORT_C torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1,
const torch_tensor_t tensor2);

/**
* Overloads the post-division operator for a Torch Tensor and a scalar
* @param Tensor to be divided
* @param scalar to divide by
* @return quotient of the Tensor and scalar
* Overloads the exponentiation operator for a Torch Tensor and an integer exponent
* @param Tensor to take the power of
* @param integer exponent
* @return power of the Tensor
*/
EXPORT_C torch_tensor_t torch_tensor_postdivide(const torch_tensor_t tensor,
const torch_data_t scalar);
EXPORT_C torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor,
const torch_int_t exponent);

/**
* Overloads the exponentiation operator for two Torch Tensors
* Overloads the exponentiation operator for a Torch Tensor and a floating point
* exponent
* @param Tensor to take the power of
* @param scalar exponent
* @param floating point exponent
* @return power of the Tensor
*/
EXPORT_C torch_tensor_t torch_tensor_power(const torch_tensor_t tensor,
const torch_data_t exponent);
EXPORT_C torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor,
const torch_float_t exponent);

// =====================================================================================
// Module API
Expand Down
Loading
Loading